// Copyright (C) 2002 Ronan Collobert (collober@iro.umontreal.ca)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "MultiClassFormat.h"

namespace Torch {

extern "C" int multiClassTriMelanie(const void *a, const void *b)
{
  real *ar = (real *)a;
  real *br = (real *)b;

  if(*ar < *br)
    return -1;
  else
    return  1;
}

// Builds the class table by scanning a data set.
//
// Every example of `data` is visited; the first target value of each example
// is recorded the first time it is seen. The distinct values are then sorted
// in ascending order into `tabclasses`, and `class_labels[i]` is made to
// point at `tabclasses[i]`.
MultiClassFormat::MultiClassFormat(DataSet *data)
{
  tabclasses = NULL;

  if(data->n_targets != 1)
    warning("MultiClassFormat: the data has %d ouputs", data->n_targets);

  // Gather the distinct target values in order of first appearance.
  int n_distinct = 0;
  for(int ex = 0; ex < data->n_examples; ex++)
  {
    data->setExample(ex);
    const real label = ((real *)data->targets)[0];

    // Linear search among the values already collected.
    int pos = 0;
    while(pos < n_distinct && label != tabclasses[pos])
      pos++;

    if(pos == n_distinct)
    {
      // Unseen value: grow the table by one slot and append it.
      tabclasses = (real *)xrealloc(tabclasses, sizeof(real)*(n_distinct+1));
      tabclasses[n_distinct] = label;
      n_distinct++;
    }
  }

  // Report what was found.
  if(n_distinct == 0)
    error("MultiClassFormat: you have no examples");
  else if(n_distinct == 1)
    warning("MultiClassFormat: you have only one class [%g]", tabclasses[0]);
  else
    message("MultiClassFormat: %d classes detected", n_distinct);

  // Sort the labels, then expose one pointer per class into the sorted table.
  n_classes = n_distinct;
  qsort(tabclasses, n_classes, sizeof(real), multiClassTriMelanie);

  class_labels = (real **)xalloc(sizeof(real *)*n_classes);
  for(int c = 0; c < n_classes; c++)
    class_labels[c] = &tabclasses[c];
}

// Builds the class table from an explicit list of labels.
//
// `n_classes_` is the number of classes. If `class_labels_` is non-null its
// first `n_classes_` values are copied as the labels; otherwise the labels
// default to 0, 1, ..., n_classes_-1. `class_labels[i]` is made to point at
// `tabclasses[i]`.
MultiClassFormat::MultiClassFormat(int n_classes_, real *class_labels_)
{
  n_classes = n_classes_;

  tabclasses = (real *)xalloc(sizeof(real)*n_classes);
  for(int c = 0; c < n_classes; c++)
    tabclasses[c] = (class_labels_ ? class_labels_[c] : (real)c);

  class_labels = (real **)xalloc(sizeof(real *)*n_classes);
  for(int c = 0; c < n_classes; c++)
    class_labels[c] = &tabclasses[c];
}

// Returns the number of output values used by this format: always 1.
int MultiClassFormat::getOutputSize()
{
  return(1);
}

// Converts a one-hot style score vector into a single value.
//
// `one_hot_outputs` is walked node by node, its elements being numbered
// consecutively across nodes. The position of the strictly largest element
// (first one wins on ties) is stored, converted to `real`, into the first
// element of `outputs`. If no element is greater than -INF, -1 is stored.
void MultiClassFormat::fromOneHot(List *outputs, List *one_hot_outputs)
{
  real *out = (real*)outputs->ptr;

  real best_score = -INF;
  int best_pos = -1;
  int pos = 0;  // running position across all nodes of the list

  for(List *node = one_hot_outputs; node; node = node->next)
  {
    const real *scores = (real*)node->ptr;
    for(int k = 0; k < node->n; k++, pos++)
    {
      if(scores[k] > best_score)
      {
        best_score = scores[k];
        best_pos = pos;
      }
    }
  }

  *out = (real)best_pos;
}

void MultiClassFormat::toOneHot(List *outputs, List *one_hot_outputs)
{
  real out = *(real*)outputs->ptr;
  // heuristic: find the one or two labels that are closer to "out" and
  // attribute them the difference between out and their label. put 0 for
  // all the other values

  // first initialize one_hot_outputs with all zeros
  List* one_hot = one_hot_outputs;
  while (one_hot) {
    real *one = (real*)one_hot->ptr;
    for (int i=0;i<one_hot->n;i++) {
      *one++ = 0.;
    }
    one_hot = one_hot->next;
  }

  // then there are 3 different cases
  one_hot = one_hot_outputs;
  if (out > n_classes-1) {
    real diff = fabs(out - tabclasses[n_classes-1]);
    int j = 0;
    while (one_hot && j != n_classes-1) {
      real *one = (real*)one_hot->ptr;
      for (int i=0;i<one_hot->n;i++,j++,one++) {
        if (j==n_classes-1) {
          *one = diff;
          break;
        }
      }
      one_hot = one_hot->next;
    }
  } else if (out < 0) {
    real diff = fabs(out - tabclasses[0]);
    real *one = (real*)one_hot->ptr;
    *one = diff;
  } else {
    int before = (int)floor(out);
    int after = (int)ceil(out);
    // the scores are reversed so the max score is given to the neirest
    real diff_before = after - out;
    real diff_after = out - before;
    if (before == after) {
      diff_before = diff_after = 1.;
    }
    int j = 0;
    while (one_hot && j <= after) {
      real *one = (real*)one_hot->ptr;
      for (int i=0;i<one_hot->n;i++,j++,one++) {
        if (j==before)
          *one = diff_before;
        if (j==after) {
          *one = diff_after;
          break;
        }
      }
      one_hot = one_hot->next;
    }
  }
}

int MultiClassFormat::getTargetClass(void *target)
{
  real out = *((real *)target);
  real dist = fabs(out - tabclasses[0]);
  int index = 0;

  for(int i = 1; i < n_classes; i++)
  {
    real z = fabs(out - tabclasses[i]);
    if(z < dist)
    {
      index = i;
      dist = z;
    }
  }
  
  return(index);
}

// Returns the index of the class whose label is closest to the output.
//
// The first element of `outputs` is read as one `real` and matched against
// the labels in `tabclasses` (smallest absolute difference, lowest index on
// ties). This is the same search as getTargetClass(), so it is reused here;
// the call is qualified so that it is always this class's implementation.
int MultiClassFormat::getOutputClass(List *outputs)
{
  return MultiClassFormat::getTargetClass((void *)outputs->ptr);
}

// Releases the two arrays set up by the constructors.
MultiClassFormat::~MultiClassFormat()
{
  // class_labels only holds pointers into tabclasses, so just the array of
  // pointers itself is released here, followed by the label table.
  free(class_labels);
  free(tabclasses);
}

}

