// Copyright (C) 2002 Ronan Collobert (collober@iro.umontreal.ca)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "SVMCache.h"

namespace Torch {

// Builds an empty cache bound to a kernel; no memory is reserved until
// allocate() is called (the subclass allocate() sets l first).
//   kernel_          kernel used by rempliColonne() to fill columns
//   taille_en_megs_  memory budget for the column storage, in megabytes
SVMCache::SVMCache(Kernel *kernel_, /*int n_variables,*/ real taille_en_megs_)
{
  kernel = kernel_;
  taille_en_megs = taille_en_megs_;

  // Give every member a defined value. l and taille are otherwise only set by
  // allocate(), so clear() called beforehand would loop over garbage sizes and
  // dereference an uninitialized `cached`; with these values it is a no-op.
  l = 0;
  taille = 0;
  cached = NULL;
  index_dans_liste = NULL;
  cached_sauve = NULL;
  memory_cache = NULL;

  // Booster (active-set) mode is off until setBoosterMode() is called.
  booster_mode = false;
  n_active_var = NULL;
  active_var = NULL;
}

void SVMCache::allocate()
{
  destroy();

  booster_mode = false;

  // Allocs...
  taille = (int)(taille_en_megs*1048576./((real)sizeof(real)*l));
  index_dans_liste = (SVMCacheListe **)xalloc(sizeof(SVMCacheListe *)*l);
  cached = (SVMCacheListe *)xalloc(sizeof(SVMCacheListe)*taille);
  cached_sauve = cached;

  message("SVMCache: max columns in cache: %d", taille);
  if(taille < 2)
    error("SVMCache: please change the cache size : it's too small");

  // Init
  SVMCacheListe *ptr = cached;
  for(int i = 0; i < l; i++)
    index_dans_liste[i] = NULL;

  memory_cache = (real *)xalloc(sizeof(real)*taille*l);
  for(int i = 0; i < taille; i++)
  {
    ptr->adr = memory_cache+i*l;
    ptr->index = -1;
    if(i != 0)
      ptr->prev = (ptr-1);
    else
      ptr->prev = &cached[taille-1];
    if(i != taille-1)
      ptr->suiv = (ptr+1);
    else
      ptr->suiv = cached;

    ptr++;
  }
}

void SVMCache::destroy()
{
  free(index_dans_liste);
  free(cached_sauve);
  free(memory_cache);

  index_dans_liste = NULL;
  cached_sauve = NULL;
  memory_cache = NULL;
}

void SVMCache::clear()
{
  SVMCacheListe *ptr = cached;
  for(int i = 0; i < taille; i++)
  {
    ptr->index = -1;
    ptr = ptr->suiv;
  }

  for(int i = 0; i < l; i++)
    index_dans_liste[i] = NULL;
}

// Returns the column for variable `index`, computing it with rempliColonne()
// only on a cache miss. The returned slot becomes the head (`cached`) of the
// LRU ring. A miss reuses the least recently used slot, which is the one just
// before the head. The pointer stays valid until that slot is evicted. Because
// taille >= 2, the next call with a different index cannot evict it.
real *SVMCache::adressCache(int index)
{
  // Remark (translated from the original French): in regression we should take
  // care not to compute the same things twice... but then the -1/+1 would have
  // to be flipped in the matrix... so, never mind.

  SVMCacheListe *ptr = index_dans_liste[index];

  if(ptr != NULL)
  {
    // Hit. The original treated "already at the head" as a miss. That
    // recomputed the column into another slot and left two slots tagged with
    // the same index.
    if(ptr != cached)
    {
      // Unlink the slot...
      ptr->prev->suiv = ptr->suiv;
      ptr->suiv->prev = ptr->prev;

      // ...and reinsert it just before the current head, making it the head.
      ptr->suiv = cached;
      ptr->prev = cached->prev;
      cached->prev->suiv = ptr;
      cached->prev = ptr;
      cached = ptr;
    }
    return(cached->adr);
  }

  // Miss: recycle the least recently used slot and make it the head.
  cached = cached->prev;
  if(cached->index != -1)
    index_dans_liste[cached->index] = NULL;
  cached->index = index;
  index_dans_liste[index] = cached;
  rempliColonne(index, cached->adr);

  return(cached->adr);
}

// Switches column filling to "active variables only". The pointers are stored,
// not copied: rempliColonne() dereferences them on every fill, so later changes
// the caller makes to the count or the list are picked up. They are not freed
// by destroy().
void SVMCache::setBoosterMode(int *n_active_var_, int *active_var_)
{
  active_var = active_var_;
  n_active_var = n_active_var_;
  booster_mode = true;
}

// Releases every buffer obtained in allocate().
SVMCache::~SVMCache()
{
  SVMCache::destroy();
}

// Classification cache: shares the SVM's kernel and keeps its label array,
// which rempliColonne() uses to sign each kernel entry.
SVMCacheClassification::SVMCacheClassification(SVMClassification *svm, real taille_en_megs)
  : SVMCache(svm->kernel, taille_en_megs)
{
  this->y = svm->y;
}

void SVMCacheClassification::allocate()
{
  l = kernel->data->n_examples;
  SVMCache::allocate();
}

void SVMCacheClassification::rempliColonne(int index, real *adr)
{
  if(booster_mode)
  {
    if(y[index] > 0)
    {
      for(int it = 0; it < *n_active_var; it++)
      {
        int t = active_var[it];
        adr[t] =  y[t]*kernel->eval(index, t);
      }
    }
    else
    {
      for(int it = 0; it < *n_active_var; it++)
      {
        int t = active_var[it];
        adr[t] = -y[t]*kernel->eval(index, t);
      }
    }
  }
  else
  {
    if(y[index] > 0)
    {
      for(int i = 0; i < l; i++)
        adr[i] =  y[i]*kernel->eval(index, i);
    }
    else
    {
      for(int i = 0; i < l; i++)
        adr[i] = -y[i]*kernel->eval(index, i);
    }
  }
}

// Regression cache: shares the SVM's kernel. Its sizes (lm examples, l = 2*lm
// variables) depend on the dataset and are set in allocate().
SVMCacheRegression::SVMCacheRegression(SVMRegression *svm, real taille_en_megs)
  : SVMCache(svm->kernel, taille_en_megs)
{
}

// Two cached variables per example: indices v and v+lm both refer to
// example v (rempliColonne() reduces indices modulo lm).
void SVMCacheRegression::allocate()
{
  lm = kernel->data->n_examples;
  l = 2*lm;
  SVMCache::allocate();
}

void SVMCacheRegression::rempliColonne(int index, real *adr)
{
  int indexm = index%lm;
  if(booster_mode)
  {
    for(int i = 0; i < *n_active_var; i++)
    {
      int k = active_var[i]%lm;
      adr[k] = kernel->eval(indexm, k);
    }
  }
  else
  {
    for(int i = 0; i < lm; i++)
      adr[i] = kernel->eval(indexm, i);
  }

  for(int i = 0; i < lm; i++)
    adr[i+lm] = adr[i];
}

}

