#include <cstdlib>
#include <exception>
#include <iostream>
#include <utility>
#include <vector>

#include "Base.h"

#include <gaml-libsvm.hpp>
#include <gaml.hpp>

// Let us define (i.e. rename) a few types.

typedef uci::Database::imagette X;     // This is samples.
typedef bool                    Y;     // this is the label (bi-class here)

typedef std::pair<X,Y>          Data;  // What stands in the databasis for each sample.
typedef std::vector<Data>       Basis; // The databasis.

// The next two functions convert samples into the libsvm sample structure.

// Returns how many struct svm_node entries are needed to encode x sparsely:
// one entry per non-zero pixel, plus one extra entry for the terminating
// node whose index is -1.
int nb_nodes_of(const X& x) {
  int non_zero_count = 0;
  for(int col = 0; col < uci::Database::imagette::width; ++col) {
    for(int row = 0; row < uci::Database::imagette::height; ++row) {
      if(x(row,col) != 0)
        ++non_zero_count;
    }
  }
  return non_zero_count + 1; // +1 for the terminating node (index -1).
}

void fill_nodes(const X& x,struct svm_node* nodes) {
  int i=0;
  int index,w,h;
  unsigned char tmp;
  for(w = 0, index = 0; w < uci::Database::imagette::width; ++w)
    for(h = 0; h < uci::Database::imagette::height; ++h, ++index)
      if( (tmp = x(h,w)) != 0 ) {
	nodes[i].index = index;
	nodes[i].value = tmp/255.0; // rescale values as floats in [0,1]
	++i;
      }
  nodes[i].index = -1;
}

// The two following functions extract the sample and the label from a
// data item. Here, since data items are std::pairs, they are straightforward.

// Returns the sample (the image) stored in a data item.
const X& input_of (const Data& d) { return std::get<0>(d); }
// Returns the label stored in a data item.
const Y& output_of(const Data& d) { return std::get<1>(d); }

// Number of labelled images drawn from the database to train the SVM.
constexpr int NB_SAMPLES     = 50;
// Digit forming the positive class: an image is labelled true iff it shows
// this digit, and false for every other digit.
constexpr int POSITIVE_CLASS =  0;

int main(int argc, char* argv[]) {

  try {

    // Let us make libsvm quiet
    gaml::libsvm::quiet();

    Basis basis;
    uci::Database digits;

    // Let us fill some databasis.
    basis.resize(NB_SAMPLES);
    for(auto& data : basis) {
      digits.Next();
      data = {digits.input, 
	      digits.what == POSITIVE_CLASS};
    }

    // Let us set configure a svm. See the libsvm doc for the meaniing
    // of struct svm_parameter fields.
  
    struct svm_parameter params;
    gaml::libsvm::init(params);
    params.kernel_type = LINEAR;  
    params.svm_type    = C_SVC;
    params.C           = 1;
    params.eps         = 1e-5; 
  
    // Let us build a SVM learner.
    auto learner = gaml::libsvm::supervized::learner<X,Y>(params,nb_nodes_of,fill_nodes);

    // Let us train it and get some predictor f. f is a function.
    std::cout << "Learning..." << std::endl;
    auto f = learner(basis.begin(),basis.end(),input_of,output_of);

    // Let us use the predictor to label the next images.
    for(int i=0; i< 20; ++i) {
      digits.Next();
      std::cout << "We extract a " << digits.what << " from the databasis." << std::endl;
      bool v = f(digits.input);
      std::cout << "  Our predictor answers " << v << " for it." << std::endl
		<< "  According to the predictor, the extracted " <<  digits.what;
      if(v)
	std::cout << " is a " << POSITIVE_CLASS;
      else
	std::cout << " is not a " << POSITIVE_CLASS;
      std::cout <<  "." << std::endl;
    }
  }
  catch(gaml::exception::Any& e) {
    std::cout << e.what() << std::endl;
  }
  

  return 0;
}