/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.regression;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.Trainable;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.contants.TribuoOutputType;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.utils.ModelSerDeSer;
import org.opensearch.ml.engine.utils.TribuoUtil;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.optimisers.AdaDelta;
import org.tribuo.math.optimisers.AdaGrad;
import org.tribuo.math.optimisers.Adam;
import org.tribuo.math.optimisers.RMSProp;
import org.tribuo.math.optimisers.SGD;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.sgd.RegressionObjective;
import org.tribuo.regression.sgd.linear.LinearSGDTrainer;
import org.tribuo.regression.sgd.objectives.AbsoluteLoss;
import org.tribuo.regression.sgd.objectives.Huber;
import org.tribuo.regression.sgd.objectives.SquaredLoss;

/**
 * Linear regression backed by Tribuo's {@link LinearSGDTrainer}.
 *
 * <p>An instance built with {@link #LinearRegression(MLAlgoParams)} validates the supplied
 * hyper-parameters and prepares the loss objective and gradient optimiser needed by
 * {@link #train(MLInput)}. An instance built with the no-arg constructor has no parameters,
 * objective or optimiser set; it can load an already trained model through
 * {@link #initModel(MLModel, Map, Encryptor)} or {@link #predict(MLInput, MLModel)}.
 */
@Function(value=FunctionName.LINEAR_REGRESSION)
public class LinearRegression implements Trainable, Predictable {
    public static final String VERSION = "1.0.0";

    // Fallback values used whenever the corresponding parameter is not supplied.
    private static final LinearRegressionParams.ObjectiveType DEFAULT_OBJECTIVE_TYPE = LinearRegressionParams.ObjectiveType.SQUARED_LOSS;
    private static final LinearRegressionParams.OptimizerType DEFAULT_OPTIMIZER_TYPE = LinearRegressionParams.OptimizerType.ADA_GRAD;
    private static final double DEFAULT_LEARNING_RATE = 0.01;
    private static final double DEFAULT_MOMENTUM_FACTOR = 0.0;
    private static final LinearRegressionParams.MomentumType DEFAULT_MOMENTUM_TYPE = LinearRegressionParams.MomentumType.STANDARD;
    private static final double DEFAULT_EPSILON = 1.0E-6;
    private static final double DEFAULT_BETA1 = 0.9;
    private static final double DEFAULT_BETA2 = 0.99;
    private static final double DEFAULT_DECAY_RATE = 0.9;
    private static final int DEFAULT_EPOCHS = 1000;
    private static final int DEFAULT_INTERVAL = -1;
    private static final int DEFAULT_BATCH_SIZE = 1;
    private static final Long DEFAULT_SEED = 12345L;

    private LinearRegressionParams parameters;
    private StochasticGradientOptimiser optimiser;
    private RegressionObjective objective;
    private int loggingInterval;
    private int minibatchSize;
    private long seed;
    private Model<Regressor> regressionModel;

    /**
     * Creates an instance without training parameters. The objective, optimiser and
     * parameters stay unset, so {@link #train(MLInput)} must not be called on it.
     */
    public LinearRegression() {
    }

    /**
     * Creates an instance ready for training.
     *
     * @param parameters training parameters; {@code null} selects an empty
     *                   {@link LinearRegressionParams}, i.e. every default. A non-null value
     *                   is cast to {@link LinearRegressionParams}.
     * @throws IllegalArgumentException if any supplied parameter is outside its allowed range
     */
    public LinearRegression(MLAlgoParams parameters) {
        this.parameters = parameters == null ? LinearRegressionParams.builder().build() : (LinearRegressionParams) parameters;
        this.validateParameters();
        this.createObjective();
        this.createOptimiser();
    }

    /** Selects the loss objective; anything other than absolute loss or Huber uses squared loss. */
    private void createObjective() {
        LinearRegressionParams.ObjectiveType objectiveType = Optional.ofNullable(this.parameters.getObjectiveType()).orElse(DEFAULT_OBJECTIVE_TYPE);
        switch (objectiveType) {
            case ABSOLUTE_LOSS:
                this.objective = new AbsoluteLoss();
                break;
            case HUBER:
                this.objective = new Huber();
                break;
            default:
                this.objective = new SquaredLoss();
                break;
        }
    }

    /**
     * Builds the gradient optimiser from the configured optimiser type and its
     * hyper-parameters, substituting the class defaults for any value that is not set.
     * Optimiser types without a dedicated branch use {@link AdaGrad}.
     */
    private void createOptimiser() {
        LinearRegressionParams.OptimizerType optimizerType = Optional.ofNullable(this.parameters.getOptimizerType()).orElse(DEFAULT_OPTIMIZER_TYPE);
        double learningRate = Optional.ofNullable(this.parameters.getLearningRate()).orElse(DEFAULT_LEARNING_RATE);
        double momentumFactor = Optional.ofNullable(this.parameters.getMomentumFactor()).orElse(DEFAULT_MOMENTUM_FACTOR);
        double epsilon = Optional.ofNullable(this.parameters.getEpsilon()).orElse(DEFAULT_EPSILON);
        double beta1 = Optional.ofNullable(this.parameters.getBeta1()).orElse(DEFAULT_BETA1);
        double beta2 = Optional.ofNullable(this.parameters.getBeta2()).orElse(DEFAULT_BETA2);
        LinearRegressionParams.MomentumType momentumType = Optional.ofNullable(this.parameters.getMomentumType()).orElse(DEFAULT_MOMENTUM_TYPE);
        double decayRate = Optional.ofNullable(this.parameters.getDecayRate()).orElse(DEFAULT_DECAY_RATE);

        // Only NESTEROV maps to Nesterov momentum; every other momentum type is STANDARD.
        SGD.Momentum momentum = momentumType == LinearRegressionParams.MomentumType.NESTEROV
            ? SGD.Momentum.NESTEROV
            : SGD.Momentum.STANDARD;

        switch (optimizerType) {
            case SIMPLE_SGD:
                this.optimiser = SGD.getSimpleSGD(learningRate, momentumFactor, momentum);
                break;
            case LINEAR_DECAY_SGD:
                this.optimiser = SGD.getLinearDecaySGD(learningRate, momentumFactor, momentum);
                break;
            case SQRT_DECAY_SGD:
                this.optimiser = SGD.getSqrtDecaySGD(learningRate, momentumFactor, momentum);
                break;
            case ADA_DELTA:
                // AdaDelta is constructed from the momentum factor and epsilon; the learning rate is not passed.
                this.optimiser = new AdaDelta(momentumFactor, epsilon);
                break;
            case ADAM:
                this.optimiser = new Adam(learningRate, beta1, beta2, epsilon);
                break;
            case RMS_PROP:
                this.optimiser = new RMSProp(learningRate, momentumFactor, epsilon, decayRate);
                break;
            default:
                this.optimiser = new AdaGrad(learningRate, epsilon);
                break;
        }
    }

    /**
     * Range-checks every supplied parameter and then resolves the logging interval,
     * mini-batch size and seed, falling back to the class defaults when they are unset.
     *
     * @throws IllegalArgumentException if a supplied parameter is outside its allowed range
     */
    private void validateParameters() {
        if (this.parameters.getLearningRate() != null && this.parameters.getLearningRate() < 0.0) {
            throw new IllegalArgumentException("Learning rate should not be negative.");
        }
        if (this.parameters.getMomentumFactor() != null && this.parameters.getMomentumFactor() < 0.0) {
            throw new IllegalArgumentException("MomentumFactor should not be negative.");
        }
        if (this.parameters.getEpsilon() != null && this.parameters.getEpsilon() < 0.0) {
            throw new IllegalArgumentException("Epsilon should not be negative.");
        }
        if (this.parameters.getBeta1() != null && (this.parameters.getBeta1() <= 0.0 || this.parameters.getBeta1() >= 1.0)) {
            throw new IllegalArgumentException("Beta1 should be in an open interval (0,1).");
        }
        if (this.parameters.getBeta2() != null && (this.parameters.getBeta2() <= 0.0 || this.parameters.getBeta2() >= 1.0)) {
            throw new IllegalArgumentException("Beta2 should be in an open interval (0,1).");
        }
        if (this.parameters.getDecayRate() != null && this.parameters.getDecayRate() < 0.0) {
            throw new IllegalArgumentException("DecayRate should not be negative.");
        }
        if (this.parameters.getEpochs() != null && this.parameters.getEpochs() < 0) {
            throw new IllegalArgumentException("Epochs should not be negative.");
        }
        if (this.parameters.getBatchSize() != null && this.parameters.getBatchSize() < 0) {
            throw new IllegalArgumentException("MiniBatchSize should not be negative.");
        }
        // -1 is the default logging interval, so only values below it are rejected.
        if (this.parameters.getLoggingInterval() != null && this.parameters.getLoggingInterval() < DEFAULT_INTERVAL) {
            throw new IllegalArgumentException("Invalid Logging intervals");
        }
        this.loggingInterval = Optional.ofNullable(this.parameters.getLoggingInterval()).orElse(DEFAULT_INTERVAL);
        this.minibatchSize = Optional.ofNullable(this.parameters.getBatchSize()).orElse(DEFAULT_BATCH_SIZE);
        this.seed = Optional.ofNullable(this.parameters.getSeed()).orElse(DEFAULT_SEED);
    }

    /**
     * Deserializes the given model and keeps it for subsequent {@link #predict(MLInput)} calls.
     *
     * @param model     serialized regression model
     * @param params    not used by this implementation
     * @param encryptor not used by this implementation
     */
    @Override
    public void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
        this.regressionModel = deserializeModel(model);
    }

    /** Drops the reference to the loaded model. */
    @Override
    public void close() {
        this.regressionModel = null;
    }

    /**
     * @return {@code true} when a model is currently loaded
     */
    @Override
    public boolean isModelReady() {
        return this.regressionModel != null;
    }

    /**
     * Runs the loaded model over the input data frame.
     *
     * @param mlInput input whose dataset is cast to {@link DataFrameInputDataset}
     * @return a prediction output with one row per prediction; each row maps the first
     *         output name of the regressor to its first output value
     * @throws IllegalArgumentException if no model has been loaded
     */
    @Override
    public MLOutput predict(MLInput mlInput) {
        if (this.regressionModel == null) {
            throw new IllegalArgumentException("model not deployed");
        }
        DataFrame dataFrame = ((DataFrameInputDataset) mlInput.getInputDataset()).getDataFrame();
        MutableDataset<Regressor> predictionDataset = TribuoUtil.generateDataset(dataFrame, new RegressionFactory(), "Linear regression prediction data from opensearch", TribuoOutputType.REGRESSOR);

        List<Map<String, Object>> rows = new ArrayList<>();
        this.regressionModel.predict(predictionDataset).forEach(prediction -> {
            Regressor output = prediction.getOutput();
            // Only the first regression dimension is reported.
            rows.add(Collections.<String, Object>singletonMap(output.getNames()[0], output.getValues()[0]));
        });
        return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(rows)).build();
    }

    /**
     * Deserializes {@code model}, stores it as the loaded model of this instance and
     * predicts with it.
     *
     * @param mlInput input data to predict on
     * @param model   serialized regression model
     * @return the prediction output, as produced by {@link #predict(MLInput)}
     * @throws IllegalArgumentException if {@code model} is {@code null}
     */
    @Override
    public MLOutput predict(MLInput mlInput, MLModel model) {
        if (model == null) {
            throw new IllegalArgumentException("No model found for linear regression prediction.");
        }
        this.regressionModel = deserializeModel(model);
        return this.predict(mlInput);
    }

    /**
     * Trains a linear regression model on the input data frame using the objective,
     * optimiser and settings prepared by {@link #LinearRegression(MLAlgoParams)}.
     *
     * @param mlInput input whose dataset is cast to {@link DataFrameInputDataset}
     * @return a model in {@link MLModelState#TRAINED} state whose content is the
     *         Base64-serialized trained regressor
     */
    @Override
    public MLModel train(MLInput mlInput) {
        DataFrame dataFrame = ((DataFrameInputDataset) mlInput.getInputDataset()).getDataFrame();
        MutableDataset<Regressor> trainDataset = TribuoUtil.generateDatasetWithTarget(dataFrame, new RegressionFactory(), "Linear regression training data from opensearch", TribuoOutputType.REGRESSOR, this.parameters.getTarget());
        int epochs = Optional.ofNullable(this.parameters.getEpochs()).orElse(DEFAULT_EPOCHS);
        LinearSGDTrainer linearSGDTrainer = new LinearSGDTrainer(this.objective, this.optimiser, epochs, this.loggingInterval, this.minibatchSize, this.seed);
        // Named differently from the field on purpose: training does not load the model into this instance.
        Model<Regressor> trainedModel = linearSGDTrainer.train(trainDataset);
        return MLModel.builder()
            .name(FunctionName.LINEAR_REGRESSION.name())
            .algorithm(FunctionName.LINEAR_REGRESSION)
            .version(VERSION)
            .content(ModelSerDeSer.serializeToBase64(trainedModel))
            .modelState(MLModelState.TRAINED)
            .build();
    }

    /** Deserializes a stored model and views it as a Tribuo regression model. */
    @SuppressWarnings("unchecked") // the payload's type cannot be checked at runtime; it is treated as a regression model
    private static Model<Regressor> deserializeModel(MLModel model) {
        return (Model<Regressor>) ModelSerDeSer.deserialize(model);
    }
}

