//                                               -*- C++ -*-
/**
 *  @file  KrigingEvaluation.cxx
 *  @brief Evaluation of a kriging (gaussian process regression) metamodel
 *
 *  Copyright 2005-2015 Airbus-EDF-IMACS-Phimeca
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 *  @author schueller
 */

#include "KrigingEvaluation.hxx"
#include "PersistentObjectFactory.hxx"

BEGIN_NAMESPACE_OPENTURNS

// Class-name registration macro; provides the name returned by GetClassName()
// (used by __repr__ and __str__ below). NOTE(review): the exact expansion lives
// in the macro definition, not in this file.
CLASSNAMEINIT(KrigingEvaluation);


// File-local Factory registered under the "KrigingEvaluation" tag, presumably so
// that the persistence layer can rebuild instances (default-construct, then
// load()). TODO confirm against PersistentObjectFactory.hxx.
static Factory<KrigingEvaluation> RegisteredFactory("KrigingEvaluation");


/* Default constructor.
 * The base class and every data member are default-constructed, leaving an
 * empty evaluation (no basis, no training sample, no coefficients).
 * Presumably used by the persistence Factory before load() fills the members. */
KrigingEvaluation::KrigingEvaluation()
  : NumericalMathEvaluationImplementation()
{}


/* Constructor with parameters
 *
 * @param basis            regression functions f_j; only the first output
 *                         component of each one is used (a warning is logged
 *                         when the first function has a larger output dimension)
 * @param inputSample      training input points x_i; fixes the input dimension
 * @param covarianceModel  covariance kernel C; its spatial dimension must equal
 *                         the dimension of inputSample
 * @param beta             regression coefficients, one per basis function
 * @param gamma            kernel coefficients, one per training point
 * @throws InvalidArgumentException on any of the dimension/size mismatches above
 */
KrigingEvaluation::KrigingEvaluation (const Basis & basis,
                                      const NumericalSample & inputSample,
                                      const CovarianceModel & covarianceModel,
                                      const NumericalPoint & beta,
                                      const NumericalPoint & gamma)
  : NumericalMathEvaluationImplementation()
  , basis_(basis)
  , inputSample_(inputSample)
  , covarianceModel_(covarianceModel)
  , beta_(beta)
  , gamma_(gamma)
{
  // Only the first basis function is inspected for a vector-valued output
  const Bool multiOutputBasis = (basis.getSize() > 0) && (basis[0].getOutputDimension() > 1);
  if (multiOutputBasis)
  {
    LOGWARN(OSS() << "Expected a basis of functions with output dimension=1, got output dimension=" << basis[0].getOutputDimension() << ". Only the first output component will be taken into account.");
  }

  // The kernel must act on points of the training sample's dimension
  const UnsignedInteger spatialDimension = covarianceModel.getSpatialDimension();
  const UnsignedInteger sampleDimension = inputSample.getDimension();
  if (spatialDimension != sampleDimension) throw InvalidArgumentException(HERE) << "Error: the spatial dimension=" << spatialDimension << " of the covariance model should match the dimension=" << sampleDimension << " of the input sample";

  // One regression coefficient per basis function
  const UnsignedInteger betaSize = beta.getSize();
  const UnsignedInteger basisSize = basis.getSize();
  if (betaSize != basisSize) throw InvalidArgumentException(HERE) << "Error: the number of regression coefficients=" << betaSize << " is different from the basis size=" << basisSize;

  // One kernel coefficient per training point
  const UnsignedInteger gammaSize = gamma.getSize();
  const UnsignedInteger trainingSize = inputSample.getSize();
  if (gammaSize != trainingSize) throw InvalidArgumentException(HERE) << "Error: the number of conditional covariance kernel coefficients=" << gammaSize << " is different from the input sample size=" << trainingSize;

  // Default descriptions x0, x1, ... / y0, and a zero parameter point
  const UnsignedInteger inputDimension = getInputDimension();
  setInputDescription(Description::BuildDefault(inputDimension, "x"));
  setOutputDescription(Description::BuildDefault(getOutputDimension(), "y"));
  setParameters(NumericalPointWithDescription(inputDimension));
}


/* Virtual constructor: returns a heap-allocated copy of this evaluation.
 * The returned pointer is freshly allocated with new; the caller owns it. */
KrigingEvaluation * KrigingEvaluation::clone() const
{
  KrigingEvaluation * copy = new KrigingEvaluation(*this);
  return copy;
}

/* Comparison operator
 *
 * Two evaluations compare equal when they are the same object, or when they
 * share the same regression coefficients (beta), kernel coefficients (gamma)
 * and training input sample.
 * NOTE(review): basis_ and covarianceModel_ are not compared because no
 * value-based equality for Basis / CovarianceModel is visible from this file.
 * TODO add them if those classes provide operator==.
 */
Bool KrigingEvaluation::operator==(const KrigingEvaluation & other) const
{
  // An object is always equal to itself; avoids comparing large samples
  if (this == &other) return true;
  // Cheapest comparisons first: beta (basis size), gamma (training size),
  // then the full training sample
  return (beta_ == other.beta_)
         && (gamma_ == other.gamma_)
         && (inputSample_ == other.inputSample_);
}

/* String converter: full textual representation with the class name, the
 * object name, the covariance model and the beta/gamma coefficients.
 * NOTE(review): the covariance model is printed under the label
 * "correlationModel"; the label is left unchanged so existing repr output
 * does not change. */
String KrigingEvaluation::__repr__() const
{
  OSS oss;
  oss << "class=" << GetClassName();
  oss << " name=" << getName();
  oss << " correlationModel=" << covarianceModel_;
  oss << " beta=" << beta_;
  oss << " gamma=" << gamma_;
  return oss;
}

/* String converter: short form made of the given offset followed by the
 * class name only. */
String KrigingEvaluation::__str__(const String & offset) const
{
  OSS oss(false);
  oss << offset;
  oss << GetClassName();
  return oss;
}

/* Test for actual implementation: this class always reports itself as an
 * actual (concrete) evaluation implementation. */
Bool KrigingEvaluation::isActualImplementation() const
{
  const Bool isActual = true;
  return isActual;
}

// Helper for the parallel version of the point-based evaluation operator.
// Each TBB sub-range accumulates the partial sum
//   sum_i C(input, x_i)(0, 0) * gamma_i
// over its share of the training points x_i; join() merges partial sums.
struct KrigingEvaluationPointFunctor
{
  const NumericalPoint & input_;
  const KrigingEvaluation & evaluation_;
  NumericalScalar accumulator_;

  // Main constructor: starts from a zero partial sum
  KrigingEvaluationPointFunctor(const NumericalPoint & input,
                                const KrigingEvaluation & evaluation)
    : input_(input), evaluation_(evaluation), accumulator_(0.0)
  {}

  // Splitting constructor used by the parallel reduction: shares the inputs
  // of `other` but owns a fresh partial sum starting at zero
  KrigingEvaluationPointFunctor(const KrigingEvaluationPointFunctor & other,
                                TBB::Split)
    : input_(other.input_), evaluation_(other.evaluation_), accumulator_(0.0)
  {}

  // Adds the kernel contribution of training points r.begin() .. r.end()-1
  inline void operator()( const TBB::BlockedRange<UnsignedInteger> & r )
  {
    const UnsignedInteger last = r.end();
    for (UnsignedInteger k = r.begin(); k != last; ++k)
    {
      const NumericalScalar kernelValue = evaluation_.covarianceModel_(input_, evaluation_.inputSample_[k])(0, 0);
      accumulator_ += kernelValue * evaluation_.gamma_[k];
    }
  }

  // Merges the partial sum computed by a split-off functor
  inline void join(const KrigingEvaluationPointFunctor & other)
  {
    accumulator_ += other.accumulator_;
  }
}; // struct KrigingEvaluationPointFunctor

/* Operator ()
 *
 * Evaluates the kriging predictor at a single point x:
 *   y(x) = sum_i C(x, x_i)(0, 0) * gamma_i + sum_j f_j(x)[0] * beta_j
 * where x_i are the training input points and f_j the basis functions.
 *
 * @param inP  point of dimension getInputDimension()
 * @return     point of dimension 1 holding the prediction
 * @throws InvalidArgumentException if inP does not have the input dimension
 */
NumericalPoint KrigingEvaluation::operator()(const NumericalPoint & inP) const
{
  // Fail early with an explicit message rather than deep inside the
  // covariance model or a basis function
  const UnsignedInteger inputDimension(getInputDimension());
  if (inP.getDimension() != inputDimension) throw InvalidArgumentException(HERE) << "Error: the given point has an invalid dimension. Expect a dimension " << inputDimension << ", got " << inP.getDimension();
  const UnsignedInteger trainingSize(inputSample_.getSize());
  // Evaluate the kernel part in parallel over the training points
  KrigingEvaluationPointFunctor functor( inP, *this );
  TBB::ParallelReduce( 0, trainingSize, functor );
  NumericalScalar value(functor.accumulator_);
  // Evaluate the basis part sequentially; only the first output component of
  // each basis function is used, as announced by the constructor warning
  const UnsignedInteger basisSize(basis_.getSize());
  for (UnsignedInteger i = 0; i < basisSize; ++i) value += basis_[i](inP)[0] * beta_[i];
  ++callsNumber_;
  return NumericalPoint(1, value);
}

// Helper for the parallel version of the sample-based evaluation operator.
// For each sub-range of input points it builds the cross-covariance block
//   R(row, col) = C(input_[offset + row], x_col)(0, 0)
// against every training point x_col, then adds R * gamma to the matching
// rows of the output sample. Distinct sub-ranges write to disjoint rows.
struct KrigingEvaluationSampleFunctor
{
  const NumericalSample & input_;
  NumericalSample & output_;
  const KrigingEvaluation & evaluation_;
  UnsignedInteger trainingSize_;

  KrigingEvaluationSampleFunctor(const NumericalSample & input,
                                 NumericalSample & output,
                                 const KrigingEvaluation & evaluation)
    : input_(input), output_(output), evaluation_(evaluation), trainingSize_(evaluation.inputSample_.getSize())
  {}

  inline void operator()( const TBB::BlockedRange<UnsignedInteger> & r ) const
  {
    const UnsignedInteger offset(r.begin());
    const UnsignedInteger blockSize(r.end() - offset);
    // Cross-covariance between this block of input points and the training sample
    Matrix crossCovariance(blockSize, trainingSize_);
    for (UnsignedInteger row = 0; row < blockSize; ++row)
      for (UnsignedInteger col = 0; col < trainingSize_; ++col)
        crossCovariance(row, col) = evaluation_.covarianceModel_(input_[offset + row], evaluation_.inputSample_[col])(0, 0);
    // Kernel contribution of each point of the block, accumulated on top of
    // whatever the output rows already hold (the regression part, in the
    // sample-based operator of this file)
    const NumericalPoint kernelPart(crossCovariance * evaluation_.gamma_);
    for (UnsignedInteger row = 0; row < blockSize; ++row) output_[offset + row][0] += kernelPart[row];
  } // operator()
}; // struct KrigingEvaluationSampleFunctor

/* Operator () on a sample
 *
 * Evaluates the kriging predictor at every point of inS. The regression part
 * sum_j f_j(inS) * beta_j is computed first, then the kernel part is added in
 * parallel over blocks of points.
 *
 * @param inS  sample of dimension getInputDimension()
 * @return     sample of size inS.getSize() and dimension 1
 * @throws InvalidArgumentException if inS does not have the input dimension
 */
NumericalSample KrigingEvaluation::operator()(const NumericalSample & inS) const
{
  // Fail early with an explicit message, as in the point-based operator
  const UnsignedInteger inputDimension(getInputDimension());
  if (inS.getDimension() != inputDimension) throw InvalidArgumentException(HERE) << "Error: the given sample has an invalid dimension. Expect a dimension " << inputDimension << ", got " << inS.getDimension();
  const UnsignedInteger size(inS.getSize());
  NumericalSample result(size, getOutputDimension());
  // Regression part, evaluated sequentially on the whole sample
  const UnsignedInteger basisSize(basis_.getSize());
  for (UnsignedInteger i = 0; i < basisSize; ++i)
  {
    NumericalSample basisValues(basis_[i](inS));
    // Keep only the first output component, consistently with the
    // point-based operator and the warning emitted by the constructor;
    // otherwise a multi-output basis would not match the 1-d result
    if (basisValues.getDimension() > 1) basisValues = basisValues.getMarginal(0);
    result += basisValues * beta_[i];
  }
  // Kernel part, added in parallel on disjoint blocks of rows
  const KrigingEvaluationSampleFunctor functor( inS, result, *this );
  TBB::ParallelFor( 0, size, functor );
  callsNumber_ += size;
  return result;
}


/* Accessor for input point dimension: the dimension of the training input sample */
UnsignedInteger KrigingEvaluation::getInputDimension() const
{
  const UnsignedInteger dimension(inputSample_.getDimension());
  return dimension;
}

/* Accessor for output point dimension: the predictor is always scalar-valued */
UnsignedInteger KrigingEvaluation::getOutputDimension() const
{
  const UnsignedInteger scalarOutput(1);
  return scalarOutput;
}

/* Method save() stores the object through the StorageManager.
 * Saves the base-class state first, then each data member under a key equal
 * to its member name; load() reads the same keys back. */
void KrigingEvaluation::save(Advocate & adv) const
{
  NumericalMathEvaluationImplementation::save(adv);
  adv.saveAttribute("basis_", basis_);
  adv.saveAttribute("inputSample_", inputSample_);
  adv.saveAttribute("covarianceModel_", covarianceModel_);
  adv.saveAttribute("beta_", beta_);
  adv.saveAttribute("gamma_", gamma_);
}

/* Method load() reloads the object from the StorageManager.
 * Restores the base-class state, then reads each data member back from the
 * keys written by save(). The reloaded values are not re-validated: the
 * dimension/size checks of the parameterized constructor are not applied here. */
void KrigingEvaluation::load(Advocate & adv)
{
  NumericalMathEvaluationImplementation::load(adv);
  adv.loadAttribute("basis_", basis_);
  adv.loadAttribute("inputSample_", inputSample_);
  adv.loadAttribute("covarianceModel_", covarianceModel_);
  adv.loadAttribute("beta_", beta_);
  adv.loadAttribute("gamma_", gamma_);
}


END_NAMESPACE_OPENTURNS
