// Copyright (C) 2002 Samy Bengio (bengio@idiap.ch)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "Kmeans.h"
#include "log_add.h"
#include "random.h"

namespace Torch {

/** Builds a K-means model on top of a diagonal GMM.
 *
 * @param n_observations_ dimension of each frame
 * @param n_gaussians_    number of clusters (one Gaussian per cluster)
 * @param var_threshold_  per-dimension variance floor (copied into var by reset())
 * @param prior_weights_  initial value of each weight accumulator in EM
 * @param data_           dataset that reset() samples initial centers from
 */
Kmeans::Kmeans(int n_observations_, int n_gaussians_, real* var_threshold_, real prior_weights_, SeqDataSet* data_) : DiagonalGMM(n_observations_, n_gaussians_, var_threshold_, prior_weights_)
{
  data = data_;
  // min_cluster is only allocated by allocateMemory(). The destructor calls
  // freeMemory() -> free(min_cluster), so it must hold a valid value even if
  // allocateMemory() never ran; free(NULL) is a no-op.
  min_cluster = NULL;
}

void Kmeans::reset()
{
  // initialize the parameters using some examples in the dataset randomly
  for (int i=0;i<n_gaussians;i++) {
    data->setExample((int)(uniform()*data->n_examples));
    data->setFrame((int)(uniform()*data->n_frames));
    SeqExample *ex = (SeqExample*)data->inputs->ptr;
    real *x = ex->observations[data->current_frame];
    real *means_i = means[i];
    real *var_i = var[i];
    real *thresh = var_threshold;
    for(int j = 0; j < n_observations; j++) {
      *means_i++ = *x++;
      *var_i++ = *thresh++;
    }
    log_weights[i] = log(1./n_gaussians);
  }
}

void Kmeans::eMIterInitialize()
{
  // initialize the accumulators to 0
  for (int i=0;i<n_gaussians;i++) {
    real *pm = means_acc[i];
    real *ps = var_acc[i];
    for (int j=0;j<n_observations;j++) {
      *pm++ = 0.;
      *ps++ = 0.;
    }
    weights_acc[i] = prior_weights;
  }
}

/** Allocates the base-class buffers, then the per-frame cluster-index buffer.
 *
 * The base allocation runs first; min_cluster is sized from max_n_frames
 * afterwards (presumably DiagonalGMM::allocateMemory establishes that value
 * -- TODO confirm).
 */
void Kmeans::allocateMemory()
{
  DiagonalGMM::allocateMemory();
  // One int per frame: the index of the nearest center, written by
  // frameLogProbability().
  min_cluster = (int*)xalloc(max_n_frames * sizeof(int));
}

/** Releases the Kmeans-specific per-frame buffer (min_cluster).
 *
 * Base-class buffers are not released here.
 */
void Kmeans::freeMemory()
{
  free(min_cluster);
  // Leave no dangling pointer behind. A later freeMemory() (e.g. the
  // destructor after an explicit call) then frees NULL, a no-op, instead of
  // double-freeing.
  min_cluster = NULL;
}

/** Prepares for EM over one sequence.
 *
 * Grows min_cluster when the sequence has more frames than max_n_frames,
 * then delegates to DiagonalGMM::eMSequenceInitialize.
 * max_n_frames itself is expected to be updated by the base class, so the
 * comparison must happen before the delegation.
 */
void Kmeans::eMSequenceInitialize(List* inputs)
{
  SeqExample* sequence = (SeqExample*)inputs->ptr;
  if (max_n_frames < sequence->n_real_frames)
    min_cluster = (int*)xrealloc(min_cluster, sequence->n_real_frames * sizeof(int));
  DiagonalGMM::eMSequenceInitialize(inputs);
}

/** Scores frame t by its squared Euclidean distance to the nearest center.
 *
 * @param observations the frame (n_observations values)
 * @param inputs       unused
 * @param t            frame index; selects the slots written in
 *                     log_probabilities and min_cluster
 * @return minus the smallest squared distance (used as a pseudo
 *         log-probability). The nearest center's index is stored in
 *         min_cluster[t]; it stays -1 if no distance is strictly below INF.
 */
real Kmeans::frameLogProbability(real *observations, real *inputs, int t)
{
  int best = -1;
  real best_dist = INF;
  for (int g = 0; g < n_gaussians; g++) {
    const real* center = means[g];
    real sq_dist = 0;
    for (int d = 0; d < n_observations; d++) {
      const real delta = observations[d] - center[d];
      sq_dist += delta*delta;
    }
    // Strict comparison: on ties the lower-indexed center wins.
    if (sq_dist < best_dist) {
      best_dist = sq_dist;
      best = g;
    }
  }
  min_cluster[t] = best;
  log_probabilities[t] = -best_dist;
  return -best_dist;
}

/** Accumulates the statistics of frame t into its nearest cluster (hard assignment).
 *
 * The cluster comes from min_cluster[t]; presumably frameLogProbability() has
 * already been run for this frame (TODO confirm against the base EM loop).
 * Adds the frame to means_acc, its element-wise square to var_acc, and 1 to
 * weights_acc of that cluster.
 *
 * @param observations the frame (n_observations values)
 * @param log_posterior unused: k-means gives the frame full weight
 * @param inputs        unused
 * @param t             frame index
 */
void Kmeans::frameEMAccPosteriors(real *observations, real log_posterior, real *inputs,int t)
{
  int min_i = min_cluster[t];
  // frameLogProbability leaves -1 here when no center is strictly closer than
  // INF (NaN or overflowing distances, or no Gaussians at all). Indexing the
  // accumulators with -1 would be out of bounds, so such frames are skipped.
  if (min_i < 0)
    return;

  real* means_acc_i = means_acc[min_i];
  real* var_acc_i = var_acc[min_i];
  for (int j = 0; j < n_observations; j++) {
    real x = observations[j];
    var_acc_i[j] += x * x;
    means_acc_i[j] += x;
  }
  weights_acc[min_i]++;
}

void Kmeans::eMUpdate()
{
   // first the weights and var
  real* p_weights_acc = weights_acc;
  for (int i=0;i<n_gaussians;i++,p_weights_acc++) {
    if (*p_weights_acc == 0) {
      warning("Gaussian %d of Kmeans is not used in EM",i);
    } else {
      real* p_means_i = means[i];
      real* p_var_i = var[i];
      real* p_means_acc_i = means_acc[i];
      real* p_var_acc_i = var_acc[i];
      for (int j=0;j<n_observations;j++) {
        real v = *p_var_acc_i++ / *p_weights_acc - *p_means_i * *p_means_i;
        *p_var_i++ = v >= var_threshold[j] ? v : var_threshold[j];
        *p_means_i++ = *p_means_acc_i++ / *p_weights_acc;
      }
    }
  }
  // then the weights
  real sum_weights_acc = 0;
  p_weights_acc = weights_acc;
  for (int i=0;i<n_gaussians;i++)
    sum_weights_acc += *p_weights_acc++;
  if (sum_weights_acc == 0)
    warning("the posteriors of weights of Kmeans are not used");
  else {
    real *p_log_weights = log_weights;
    real log_sum = log(sum_weights_acc);
    p_weights_acc = weights_acc;
    for (int i=0;i<n_gaussians;i++)
      *p_log_weights++ = log(*p_weights_acc++) - log_sum;
  }
}

/** Destructor: releases the Kmeans-specific buffer through this class's freeMemory().
 *
 * Base-class resources are left to DiagonalGMM's own destructor (presumably;
 * not visible here).
 */
Kmeans::~Kmeans()
{
  // Qualified call makes the target explicit. During ~Kmeans an unqualified
  // call would resolve to this same function.
  Kmeans::freeMemory();
}

}

