// Copyright (C) 2002 Ronan Collobert (collober@iro.umontreal.ca)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "QCTrainer.h"

namespace Torch {

QCTrainer::QCTrainer(QCMachine *qcmachine_, DataSet *data_, QCCache *cache_) : Trainer(qcmachine_, data_)
{
  // Keep typed handles: the cache provides the kernel rows, the machine holds
  // the problem data that prepareToLaunch() copies into the trainer.
  cache = cache_;
  qcmachine = qcmachine_;

  // Per-variable work buffers. They only exist between prepareToLaunch()
  // and atomiseAll(); NULL means "not allocated".
  status_alpha = NULL;
  not_at_bound_at_iter = NULL;
  active_var_new = NULL;
  active_var = NULL;

  // ---- Options ----

  // The default shrinking tolerance depends on the precision of `real`.
#ifdef USEDOUBLE
  addROption("eps shrink", &eps_shrink, 1E-9, "shrinking accuracy", true);
#else
  addROption("eps shrink", &eps_shrink, 1E-4, "shrinking accuracy", true);
#endif

  addBOption("unshrink", &unshrink_mode, false, "unshrink or not unshrink", true);
  addIOption("max unshrink", &n_max_unshrink, 1, "maximal number of unshrinking", true);
  addIOption("iter shrink", &n_iter_min_to_shrink, 100, "minimal number of iterations to shrink", true);
  addROption("end accuracy", &eps_fin, 0.01, "end accuracy", true);
  addIOption("iter message", &n_iter_message, 1000, "number of iterations between messages", true);
}

// Sets up one training run:
//  - copies the problem description (size, gradient, bounds, alpha, labels,
//    bound tolerance) from the machine into the trainer,
//  - allocates the kernel cache and the per-variable work buffers
//    (released again by atomiseAll()),
//  - starts with every variable active, its bound status computed and its
//    shrinking counter at 0.
void QCTrainer::prepareToLaunch()
{
  l = qcmachine->l;
  grad = qcmachine->grad;
  Cup = qcmachine->Cup;
  Cdown = qcmachine->Cdown;
  alpha = qcmachine->alpha;
  y = qcmachine->y;
  eps_bornes = qcmachine->eps_bornes;

  cache->allocate();
  deja_shrink = false;
  n_active_var = l;

  active_var = (int *)xalloc(sizeof(int)*l);
  active_var_new = (int *)xalloc(sizeof(int)*l);
  not_at_bound_at_iter = (int *)xalloc(sizeof(int)*l);
  status_alpha = (char *)xalloc(sizeof(char)*l);

  for(int i = 0; i < l; i++)
  {
    active_var[i] = i;
    // Reads alpha, Cup, Cdown and eps_bornes copied above; writes status_alpha[i].
    updateStatus(i);
    not_at_bound_at_iter[i] = 0;
  }
}

// Tears down what prepareToLaunch() set up: destroys the cache contents and
// frees the per-variable work buffers, leaving every buffer pointer NULL so
// a later free() on it is harmless.
void QCTrainer::atomiseAll()
{
  cache->destroy();

  free(active_var);
  active_var = NULL;

  free(active_var_new);
  active_var_new = NULL;

  free(not_at_bound_at_iter);
  not_at_bound_at_iter = NULL;

  free(status_alpha);
  status_alpha = NULL;
}

bool QCTrainer::selectVariables(int *i, int *j)
{
  real gmax_i = -INF;
  real gmin_j =  INF;
  int i_ = -1;
  int j_ = -1;

  for(int it = 0; it < n_active_var; it++)   
  {
    int t = active_var[it];

    if(y[t] > 0)
    {
      if(isNotDown(t))
      {
        if(grad[t] > gmax_i)
        {
          gmax_i = grad[t];
          i_ = t;
        }
      }

      if(isNotUp(t))
      {
        if(grad[t] < gmin_j)
        {
          gmin_j = grad[t];
          j_ = t;
        }
      }            
    }
    else
    {
      if(isNotUp(t))
      {
        if(-grad[t] > gmax_i)
        {
          gmax_i = -grad[t];
          i_ = t;
        }
      }
      
      if(isNotDown(t))
      {
        if(-grad[t] < gmin_j)
        {
          gmin_j = -grad[t];
          j_ = t;
        }
      }            
    }
  }

  current_error =  gmax_i - gmin_j;

  if(current_error < eps_fin)
    return(true);
  
  if( (i_ == -1) || (j_ == -1) )
    return(true);

  *i = i_;
  *j = j_;

  return(false);
}

// Returns the number of variables that could be shrunk (removed from the
// active set).
//
// The active variables to keep are written into active_var_new
// (count in n_active_var_new); shrink() later adopts that list.
// b_mid is the midpoint of the two values computed by the caller.
// Bookkeeping per variable t:
//  - strictly between its bounds: always kept; not_at_bound_at_iter[t] = iter.
//  - at a bound: if the reset test below holds, not_at_bound_at_iter[t] = iter
//    and t is kept; otherwise t is dropped once more than
//    n_iter_min_to_shrink iterations have passed since its last reset.
int QCTrainer::checkShrinking(real bmin, real bmax)
{
  const real b_mid = (bmin+bmax)/2.;

  n_active_var_new = 0;
  for(int pos = 0; pos < n_active_var; pos++)
  {
    const int t = active_var[pos];

    if(isNotDown(t) && isNotUp(t))
    {
      not_at_bound_at_iter[t] = iter;
      active_var_new[n_active_var_new++] = t;
      continue;
    }

    // The variable sits at a bound; isNotUp(t) true here means it is at the
    // lower bound, false means it is at the upper bound.
    const bool reset = isNotUp(t) ? (grad[t] + y[t]*b_mid < eps_shrink)
                                  : (grad[t] + y[t]*b_mid > -eps_shrink);
    if(reset)
      not_at_bound_at_iter[t] = iter;

    const bool removable = !reset && ((iter - not_at_bound_at_iter[t]) > n_iter_min_to_shrink);
    if(!removable)
      active_var_new[n_active_var_new++] = t;
  }

  return(n_active_var - n_active_var_new);
}

// Adopts the reduced active set prepared by checkShrinking().
// The two index buffers are swapped rather than copied: the previous active
// list becomes the scratch buffer for the next checkShrinking() call.
void QCTrainer::shrink()
{
  int *previous = active_var;
  active_var = active_var_new;
  active_var_new = previous;
  n_active_var = n_active_var_new;

  deja_shrink = true;

  // When variables will never be unshrunk, hand the active set to the cache
  // (it keeps a pointer to the count, so later updates are visible to it).
  if(unshrink_mode)
    return;
  cache->setBoosterMode(&n_active_var, active_var);
}

// Puts every variable back into the active set.
//
// Each call counts against the "max unshrink" budget. Once the budget is
// used up, unshrinking is turned off and shrinking is effectively disabled
// by raising its iteration threshold to an unreachable value.
// The test is ">=" rather than "==": with "==", a non-positive
// "max unshrink" option could never be met and unshrinking would be
// unbounded.
void QCTrainer::unShrink()
{
  n_active_var = l;
  for(int i = 0; i < l; i++)
    active_var[i] = i;

  deja_shrink = false;

  if(++n_unshrink >= n_max_unshrink)
  {
    unshrink_mode = false;
    n_iter_min_to_shrink = 666666666;
    warning("QCTrainer: shrinking and unshrinking desactived...");
  }
}

// Main optimisation loop.
//
// Each iteration selects a working pair, solves the two-variable subproblem
// analytically and updates the gradient. Every n_iter_min_to_shrink
// iterations the active set may be shrunk. On convergence, if unshrinking is
// enabled and some variables were shrunk, the full set is restored and
// re-checked before stopping.
//
// Non-positive "iter message" / "iter shrink" options disable progress
// messages / shrinking. Without this guard they would cause a modulo by zero.
// `measurers` is not used by this trainer.
void QCTrainer::train(List *measurers)
{
  prepareToLaunch();

  int xi, xj;
  int n_to_shrink = 0;
  n_unshrink = 0;

  message("QCTrainer: training...");

  iter = 0;
  while(1)
  {
    // true => stopping test met on the current active set (or no valid pair).
    if(selectVariables(&xi, &xj))
    {
      // Unshrinking only helps if shrink() actually removed variables
      // (deja_shrink). Otherwise the active set already holds every variable,
      // and re-selecting would give the same answer while still consuming
      // one unshrink.
      if(unshrink_mode && deja_shrink)
      {
        message("QCTrainer: unshrink...");
        unShrink();
        if(selectVariables(&xi, &xj))
        {
          message("QCTrainer: ...finished");
          break;
        }
        else
          message("QCTrainer: ...restart");
      }
      else
        break;
    }

    // checkShrinking() also maintains per-variable counters, so it runs on
    // every iteration once the warm-up is over, not only at shrink points.
    if((n_iter_min_to_shrink > 0) && (iter >= n_iter_min_to_shrink))
      n_to_shrink = checkShrinking(-y[xi]*grad[xi], -y[xj]*grad[xj]);

    // Cached rows for xi and xj, used by analyticSolve() and the gradient
    // update below.
    k_xi = cache->adressCache(xi);
    k_xj = cache->adressCache(xj);

    old_alpha_xi = alpha[xi];
    old_alpha_xj = alpha[xj];

    analyticSolve(xi, xj);

    real delta_alpha_xi = alpha[xi] - old_alpha_xi;
    real delta_alpha_xj = alpha[xj] - old_alpha_xj;

    // After a shrink with no planned unshrink, only active gradients are
    // kept up to date; otherwise every gradient entry is updated.
    if(deja_shrink && !unshrink_mode)
    {
      for(int t = 0; t < n_active_var; t++)
      {
        int it = active_var[t];
        grad[it] += k_xi[it]*delta_alpha_xi + k_xj[it]*delta_alpha_xj;
      }
    }
    else
    {
      for(int t = 0; t < l; t++)
        grad[t] += k_xi[t]*delta_alpha_xi + k_xj[t]*delta_alpha_xj;
    }

    iter++;
    if((n_iter_message > 0) && !(iter % n_iter_message))
    {
      // So as not to frighten the novice: never display a negative error.
      if(current_error < 0)
        current_error = 0;
      print("  + Iteration %d\n", iter);
      print("   --> Current error    = %g\n", current_error);
      print("   --> Active variables = %d\n", n_active_var);
    }

    // Shrinking: only worth it when it removes more than 10% of the active
    // variables and leaves more than 100 of them.
    if((n_iter_min_to_shrink > 0) && !(iter % n_iter_min_to_shrink))
    {
      if( (n_to_shrink > n_active_var/10) && (n_active_var-n_to_shrink > 100) )
        shrink();
    }
  }

  // So as not to frighten the novice: never display a negative error.
  if(current_error < 0)
    current_error = 0;
  print("  + Iteration %d\n", iter);
  print("   --> Current error    = %g\n", current_error);
  print("   --> Active variables = %d\n", n_active_var);

  qcmachine->checkSupportVectors();
  atomiseAll();
}

// Recomputes the bound-status bits of alpha[i] into status_alpha[i]:
//   bit value 1 : alpha[i] < Cup[i]   - eps_bornes (strictly below the upper bound)
//   bit value 2 : alpha[i] > Cdown[i] + eps_bornes (strictly above the lower bound)
// Presumably these are the bits tested by isNotUp() / isNotDown() - confirm
// against the header.
void QCTrainer::updateStatus(int i)
{
  char flags = 0;

  if(alpha[i] < Cup[i] - eps_bornes)
    flags |= 1;

  if(alpha[i] > Cdown[i] + eps_bornes)
    flags |= 2;

  status_alpha[i] = flags;
}

void QCTrainer::analyticSolve(int xi, int xj)
{
  real ww, H, L;

  real s = y[xi]*y[xj];
  if(s < 0)
  {
    ww = old_alpha_xi - old_alpha_xj;
    L = ((Cdown[xj]+ww >   Cdown[xi]) ? Cdown[xj]+ww :  Cdown[xi]);
    H = ((Cup[xj]+ww   >     Cup[xi]) ? Cup[xi]      : Cup[xj]+ww);
  }
  else
  {
    ww = old_alpha_xi + old_alpha_xj;
    L = ((ww-Cup[xj]   >    Cdown[xi]) ? ww-Cup[xj] :     Cdown[xi]);
    H = ((ww-Cdown[xj] >      Cup[xi]) ?    Cup[xi] :  ww-Cdown[xj]);
  }

  real eta = k_xi[xi] - 2.*s*k_xi[xj] + k_xj[xj];
  if(eta > 0)
  {
    real alph = old_alpha_xi + (s*grad[xj] - grad[xi])/eta;
	
    if(alph > H)
      alph = H;
    else
    {
      if(alph < L)
        alph = L;
    }
    
    alpha[xi] = alph;
    alpha[xj] -= s*(alpha[xi]-old_alpha_xi);
  }
  else
  {
    real alph = grad[xi] - s*grad[xj];
    if(alph > 0)
    {
      alpha[xi] = L;
      alpha[xj] += s*(alpha[xi]-old_alpha_xi);
    }
    else
    {
      alpha[xi] = H;
      alpha[xj] += s*(alpha[xi]-old_alpha_xi);
    }
  }

  updateStatus(xi);
  updateStatus(xj);
}

// Releases any work buffers that are still allocated. After a completed
// train(), atomiseAll() has already freed them and reset them to NULL (and
// the constructor starts them at NULL), so these calls are then no-ops;
// they only matter if the trainer is destroyed between prepareToLaunch()
// and atomiseAll().
QCTrainer::~QCTrainer()
{
  free(active_var);
  free(active_var_new);
  free(not_at_bound_at_iter);
  free(status_alpha);
}

}

