// Copyright (C) 2002 Samy Bengio (bengio@idiap.ch)
//                and Ronan Collobert (collober@iro.umontreal.ca)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "MLP.h"

namespace Torch {

// Constructs a multi-layer perceptron with n_inputs_ inputs, n_hidden_
// hidden units and n_outputs_ outputs, and registers its options.
// No sub-machine is created here: the network is built by init(), which
// reads the options. Every layer pointer is set to NULL so that the
// destructor is well-defined even if init() is never called, and so that
// the pointers init() does not assign (e.g. the dense hidden layer when
// "sparse inputs" is set) remain NULL instead of indeterminate.
MLP::MLP(int n_inputs_, int n_hidden_, int n_outputs_)
{
  n_inputs = n_inputs_;
  n_hidden = n_hidden_;
  n_outputs = n_outputs_;

  // Layers allocated by init() and released by the destructor.
  hidden_layer = NULL;
  sparse_hidden_layer = NULL;
  add_layer = NULL;
  hidden_tanh_layer = NULL;
  outputs_layer = NULL;
  sparse_outputs_layer = NULL;
  sum_layer = NULL;
  outputs_softmax_layer = NULL;
  outputs_sigmoid_layer = NULL;
  outputs_log_softmax_layer = NULL;
  outputs_tanh_layer = NULL;

  addBOption("inputs to outputs", &inputs_to_outputs, false, "connections from inputs to outputs");
  addROption("weight decay", &weight_decay, 0, "weight decay");
  addBOption("softmax outputs", &is_softmax_outputs, false, "softmax outputs");
  addBOption("sigmoid outputs", &is_sigmoid_outputs, false, "sigmoid outputs");
  addBOption("log-softmax outputs", &is_log_softmax_outputs, false, "log-softmax outputs");
  addBOption("tanh outputs", &is_tanh_outputs, false, "tanh outputs");
  addBOption("sparse inputs", &is_sparse_inputs, false, "sparse inputs");
}

// Builds the network from the current options, then initializes the
// underlying ConnectedMachine.
//
// With n_hidden > 0 the machine layers are, in order:
//   1. inputs -> hidden Linear (SparseLinear if "sparse inputs"), plus,
//      if "inputs to outputs", a direct inputs -> outputs Linear (add_layer);
//   2. Tanh on the hidden units, connected to the layer-1 hidden machine;
//   3. hidden -> outputs Linear, connected to the Tanh;
//   4. only if "inputs to outputs": a SumMachine connected to add_layer and
//      to the hidden -> outputs Linear.
// With n_hidden <= 0 there is a single inputs -> outputs Linear (or
// SparseLinear); "inputs to outputs" has no effect in that case.
//
// An optional output non-linearity is then appended. At most one is used,
// with precedence softmax > sigmoid > log-softmax > tanh; any further
// output options that are also set are ignored.
//
// "weight decay" is forwarded to every Linear / SparseLinear layer.
// Combining "inputs to outputs" with "sparse inputs" is reported via error().
void MLP::init()
{
  if(inputs_to_outputs && is_sparse_inputs)
    error("MLP: sorry, connections from inputs to outputs and sparse inputs aren't compatible");

  add_layer = NULL;
  sum_layer = NULL;
  if (n_hidden>0) {

    // Layer 1: inputs -> hidden units.
    if(is_sparse_inputs)
    {
      sparse_hidden_layer = new SparseLinear(n_inputs,n_hidden);
      sparse_hidden_layer->setROption("weight decay",weight_decay);
      sparse_hidden_layer->init();
      addMachine(sparse_hidden_layer);
    }
    else
    {
      hidden_layer = new Linear(n_inputs,n_hidden);
      hidden_layer->setROption("weight decay",weight_decay);
      hidden_layer->init();
      addMachine(hidden_layer);
    }

    // Optional direct inputs -> outputs connection, in the same layer.
    if (inputs_to_outputs) {
      add_layer = new Linear(n_inputs,n_outputs);
      add_layer->setROption("weight decay",weight_decay);
      add_layer->init();
      addMachine(add_layer);
    }

    // Layer 2: tanh on the hidden units (not on add_layer's outputs).
    hidden_tanh_layer = new Tanh(n_hidden);
    hidden_tanh_layer->init();
    addLayer();
    addMachine(hidden_tanh_layer);
    if(is_sparse_inputs)
      connectOn(sparse_hidden_layer);
    else
      connectOn(hidden_layer);

    // Layer 3: hidden units -> outputs.
    addLayer();
    outputs_layer = new Linear(n_hidden,n_outputs);
    outputs_layer->setROption("weight decay",weight_decay);
    outputs_layer->init();
    addMachine(outputs_layer);
    connectOn(hidden_tanh_layer);

    // Layer 4: sum of the direct connection and the hidden path.
    if (inputs_to_outputs) {
      sum_layer = new SumMachine(n_outputs,2);
      sum_layer->init();
      addLayer();
      addMachine(sum_layer);
      connectOn(add_layer);
      connectOn(outputs_layer);
    }
  } else {

    // No hidden units: a single inputs -> outputs layer.
    if(is_sparse_inputs)
    {
      sparse_outputs_layer = new SparseLinear(n_inputs,n_outputs);
      sparse_outputs_layer->setROption("weight decay",weight_decay);
      sparse_outputs_layer->init();
      addFCL(sparse_outputs_layer);
    }
    else
    {
      outputs_layer = new Linear(n_inputs,n_outputs);
      outputs_layer->setROption("weight decay",weight_decay);
      outputs_layer->init();
      addFCL(outputs_layer);
    }
  }

  // Optional output non-linearity; the first enabled option wins.
  if(is_softmax_outputs)
  {
    outputs_softmax_layer = new Softmax(n_outputs);
    outputs_softmax_layer->init();
    addFCL(outputs_softmax_layer);
  }
  else if(is_sigmoid_outputs)
  {
    outputs_sigmoid_layer = new Sigmoid(n_outputs);
    outputs_sigmoid_layer->init();
    addFCL(outputs_sigmoid_layer);
  }
  else if(is_log_softmax_outputs)
  {
    outputs_log_softmax_layer = new LogSoftmax(n_outputs);
    outputs_log_softmax_layer->init();
    addFCL(outputs_log_softmax_layer);
  }
  else if(is_tanh_outputs)
  {
    outputs_tanh_layer = new Tanh(n_outputs);
    outputs_tanh_layer->init();
    addFCL(outputs_tanh_layer);
  }

  ConnectedMachine::init();
}

// Releases every layer allocated by init().
// Each pointer is either NULL (set by the constructor and left untouched by
// init()) or owns exactly one layer allocated by init(). Deleting a NULL
// pointer is a no-op, so all of them are released unconditionally. This
// avoids re-deriving from the options which layers exist, because the
// options may have been changed after init() ran.
MLP::~MLP()
{
  delete hidden_layer;
  delete sparse_hidden_layer;
  delete add_layer;
  delete hidden_tanh_layer;
  delete outputs_layer;
  delete sparse_outputs_layer;
  delete sum_layer;
  delete outputs_softmax_layer;
  delete outputs_sigmoid_layer;
  delete outputs_log_softmax_layer;
  delete outputs_tanh_layer;
}

}

