/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote;

import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Predicate;
import java.io.IOException;
import java.net.URI;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import lombok.Generated;
import org.apache.commons.text.StringEscapeUtils;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.MLPostProcessFunction;
import org.opensearch.ml.common.connector.MLPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.DefaultPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.RemoteInferencePreProcessFunction;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.utils.ScriptUtils;
import org.opensearch.script.ScriptService;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.auth.signer.params.Aws4SignerParams;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.regions.Region;

public class ConnectorUtils {
    @Generated
    private static final Logger log = LogManager.getLogger(ConnectorUtils.class);
    private static final Aws4Signer signer = Aws4Signer.create();

    /** Request parameter holding a JsonPath expression applied to the model response. */
    private static final String RESPONSE_FILTER = "response_filter";
    /** When "true", a RemoteInferenceInputDataSet is still run through the connector's pre-process script. */
    private static final String PROCESS_REMOTE_INFERENCE_INPUT = "pre_process_function.process_remote_inference_input";
    /** When "true", DefaultPreProcessFunction is built with convertInputToJsonString enabled. */
    private static final String CONVERT_INPUT_TO_JSON_STRING = "pre_process_function.convert_input_to_json_string";

    /**
     * Converts an {@link MLInput} into the dataset sent to a remote model for {@code action}.
     *
     * <p>The connector's pre-process function (built-in or script) is applied first. Every
     * non-null parameter value of the result that is not already JSON is then escaped with
     * {@link StringEscapeUtils#escapeJson(String)}.
     *
     * @param action        connector action name, looked up via {@link Connector#findAction(String)}
     * @param mlInput       the user input; must not be null
     * @param connector     connector that defines the action and its pre-process function
     * @param parameters    request parameters used for placeholder substitution and feature flags
     * @param scriptService service used to execute script-based pre-process functions
     * @return the processed and escaped input dataset
     * @throws IllegalArgumentException if {@code mlInput} is null, the action does not exist, or no
     *                                  pre-process function applies to a non-remote input dataset
     */
    public static RemoteInferenceInputDataSet processInput(String action, MLInput mlInput, Connector connector, Map<String, String> parameters, ScriptService scriptService) {
        if (mlInput == null) {
            throw new IllegalArgumentException("Input is null");
        }
        if (connector.findAction(action).isEmpty()) {
            throw new IllegalArgumentException("no " + action + " action found");
        }
        RemoteInferenceInputDataSet inputData = processMLInput(action, mlInput, connector, parameters, scriptService);
        escapeRemoteInferenceInputData(inputData);
        return inputData;
    }

    /**
     * Applies the resolved pre-process function to {@code mlInput}.
     *
     * <p>Resolution order:
     * <ol>
     *   <li>No function and the input is already a RemoteInferenceInputDataSet: returned as-is.</li>
     *   <li>The function (after {@code ${parameters.*}} substitution) names a built-in
     *       {@link MLPreProcessFunction}: that function is applied.</li>
     *   <li>The input is a RemoteInferenceInputDataSet: the script runs only when
     *       {@code pre_process_function.process_remote_inference_input} is "true". It sees connector
     *       parameters overlaid by request parameters. Otherwise the input is returned unchanged.</li>
     *   <li>Otherwise the text inputs are passed through {@link #escapeMLInput(MLInput)} and the script
     *       runs through {@link DefaultPreProcessFunction}.</li>
     * </ol>
     */
    private static RemoteInferenceInputDataSet processMLInput(String action, MLInput mlInput, Connector connector, Map<String, String> parameters, ScriptService scriptService) {
        String preProcessFunction = getPreprocessFunction(action, mlInput, connector);
        MLInputDataset dataset = mlInput.getInputDataset();
        if (preProcessFunction == null) {
            if (dataset instanceof RemoteInferenceInputDataSet) {
                return (RemoteInferenceInputDataSet) dataset;
            }
            throw new IllegalArgumentException("pre_process_function not defined in connector");
        }

        preProcessFunction = fillProcessFunctionParameter(parameters, preProcessFunction);
        if (MLPreProcessFunction.contains(preProcessFunction)) {
            return (RemoteInferenceInputDataSet) MLPreProcessFunction.get(preProcessFunction).apply(mlInput);
        }

        if (dataset instanceof RemoteInferenceInputDataSet) {
            // parseBoolean(null) is false, so a missing key behaves like "false".
            if (Boolean.parseBoolean(parameters.get(PROCESS_REMOTE_INFERENCE_INPUT))) {
                Map<String, String> params = new HashMap<>();
                Map<String, String> connectorParameters = connector.getParameters();
                // A connector is not required to declare parameters; avoid NPE in putAll.
                if (connectorParameters != null) {
                    params.putAll(connectorParameters);
                }
                // Request parameters override connector defaults.
                params.putAll(parameters);
                return new RemoteInferencePreProcessFunction(scriptService, preProcessFunction, params).apply(mlInput);
            }
            return (RemoteInferenceInputDataSet) dataset;
        }

        MLInput newInput = escapeMLInput(mlInput);
        boolean convertInputToJsonString = Boolean.parseBoolean(parameters.get(CONVERT_INPUT_TO_JSON_STRING));
        DefaultPreProcessFunction function = DefaultPreProcessFunction
            .builder()
            .scriptService(scriptService)
            .preProcessFunction(preProcessFunction)
            .convertInputToJsonString(convertInputToJsonString)
            .build();
        return function.apply(newInput);
    }

    /**
     * Returns a copy of {@code mlInput} whose text fields were run through
     * {@link StringUtils#processTextDocs(List)} / {@link StringUtils#processTextDoc(String)}.
     *
     * <p>Only {@link TextDocsInputDataSet} (docs) and {@link TextSimilarityInputDataSet} (query text
     * and docs) are rewritten. Any other dataset type is returned as the same instance.
     * NOTE(review): the exact escaping done by processTextDoc(s) lives in StringUtils; confirm there.
     */
    private static MLInput escapeMLInput(MLInput mlInput) {
        MLInputDataset dataset = mlInput.getInputDataset();
        if (dataset instanceof TextDocsInputDataSet) {
            TextDocsInputDataSet textDocs = (TextDocsInputDataSet) dataset;
            List<String> newDocs = StringUtils.processTextDocs(textDocs.getDocs());
            TextDocsInputDataSet newInputData = textDocs.toBuilder().docs(newDocs).build();
            return mlInput.toBuilder().inputDataset(newInputData).build();
        }
        if (dataset instanceof TextSimilarityInputDataSet) {
            TextSimilarityInputDataSet similarity = (TextSimilarityInputDataSet) dataset;
            String newQuery = StringUtils.processTextDoc(similarity.getQueryText());
            List<String> newDocs = StringUtils.processTextDocs(similarity.getTextDocs());
            TextSimilarityInputDataSet newInputData = similarity.toBuilder().queryText(newQuery).textDocs(newDocs).build();
            return mlInput.toBuilder().inputDataset(newInputData).build();
        }
        return mlInput;
    }

    /**
     * Replaces {@code inputData}'s parameter map with an escaped copy.
     *
     * <p>Rules per value:
     * <ul>
     *   <li>null values are kept as null;</li>
     *   <li>values that {@link StringUtils#isJson(String)} accepts are kept verbatim;</li>
     *   <li>everything else is escaped with {@link StringEscapeUtils#escapeJson(String)}.</li>
     * </ul>
     *
     * <p>Does nothing when the dataset has no parameters.
     *
     * @param inputData dataset whose parameters are replaced in place
     */
    public static void escapeRemoteInferenceInputData(RemoteInferenceInputDataSet inputData) {
        Map<String, String> parameters = inputData.getParameters();
        if (parameters == null) {
            return;
        }
        Map<String, String> newParameters = new HashMap<>();
        parameters.forEach((key, value) -> {
            if (value == null) {
                newParameters.put(key, null);
            } else if (StringUtils.isJson(value)) {
                newParameters.put(key, value);
            } else {
                newParameters.put(key, StringEscapeUtils.escapeJson(value));
            }
        });
        inputData.setParameters(newParameters);
    }

    /**
     * Returns the pre-process function configured on the connector action.
     *
     * <p>If the action has none, {@code "connector.pre_process.default.embedding"} is returned for
     * {@link TextDocsInputDataSet} input and {@code null} for any other input.
     *
     * @throws IllegalArgumentException if the connector has no such action
     */
    private static String getPreprocessFunction(String action, MLInput mlInput, Connector connector) {
        ConnectorAction connectorAction = connector
            .findAction(action)
            .orElseThrow(() -> new IllegalArgumentException("no " + action + " action found"));
        String preProcessFunction = connectorAction.getPreProcessFunction();
        if (preProcessFunction != null) {
            return preProcessFunction;
        }
        if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
            return "connector.pre_process.default.embedding";
        }
        return null;
    }

    /**
     * Converts a raw remote-model response into {@link ModelTensors}.
     *
     * <p>When {@code mlGuard} is given, the response is validated first. The action's post-process
     * function, after {@code ${parameters.*}} substitution, is then applied:
     * <ul>
     *   <li>A built-in {@link MLPostProcessFunction}: the response is filtered with the
     *       {@code response_filter} parameter (or the function's default filter when that is blank)
     *       and passed to the function.</li>
     *   <li>Otherwise: the function is run as a script. Its output, or the raw response if the script
     *       yields nothing, is optionally filtered with {@code response_filter} and handed to
     *       {@link Connector#parseResponse}.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if the response is null, guardrails reject it, or the action
     *                                  does not exist
     * @throws IOException              propagated from script execution / response parsing
     */
    @SuppressWarnings({ "unchecked", "rawtypes" }) // element type (ModelTensor) is not imported in this file
    public static ModelTensors processOutput(String action, String modelResponse, Connector connector, ScriptService scriptService, Map<String, String> parameters, MLGuard mlGuard) throws IOException {
        if (modelResponse == null) {
            throw new IllegalArgumentException("model response is null");
        }
        if (mlGuard != null && !mlGuard.validate(modelResponse, MLGuard.Type.OUTPUT, Map.of("question", StringUtils.processTextDoc(modelResponse)))) {
            throw new IllegalArgumentException("guardrails triggered for LLM output");
        }
        Optional<ConnectorAction> connectorAction = connector.findAction(action);
        if (connectorAction.isEmpty()) {
            throw new IllegalArgumentException("no " + action + " action found");
        }
        String postProcessFunction = fillProcessFunctionParameter(parameters, connectorAction.get().getPostProcessFunction());
        String responseFilter = parameters.get(RESPONSE_FILTER);

        if (MLPostProcessFunction.contains(postProcessFunction)) {
            if (org.apache.commons.lang3.StringUtils.isBlank(responseFilter)) {
                responseFilter = MLPostProcessFunction.getResponseFilter(postProcessFunction);
            }
            Object filteredOutput = JsonPath.read(modelResponse, responseFilter);
            List processedResponse = (List) MLPostProcessFunction.get(postProcessFunction).apply(filteredOutput);
            return ModelTensors.builder().mlModelTensors(processedResponse).build();
        }

        Optional<String> processedResponse = ScriptUtils.executePostProcessFunction(scriptService, postProcessFunction, modelResponse);
        String response = processedResponse.orElse(modelResponse);
        // Treat the script output as ready-made model-tensor JSON only when a script actually ran and produced JSON.
        boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent() && StringUtils.isJson(response);

        List modelTensors = new ArrayList();
        if (responseFilter == null) {
            connector.parseResponse(response, modelTensors, scriptReturnModelTensor);
        } else {
            Object filteredResponse = JsonPath.parse(response).read(responseFilter);
            connector.parseResponse(filteredResponse, modelTensors, scriptReturnModelTensor);
        }
        return ModelTensors.builder().mlModelTensors(modelTensors).build();
    }

    /**
     * Substitutes {@code ${parameters.<key>}} placeholders in {@code processFunction}.
     *
     * <p>Each parameter value is Gson-serialized first, so string values are inserted as JSON
     * string literals (quoted and escaped).
     *
     * @return the substituted function text; the input unchanged if it is null or has no placeholder
     */
    private static String fillProcessFunctionParameter(Map<String, String> parameters, String processFunction) {
        if (processFunction == null || !processFunction.contains("${parameters.")) {
            return processFunction;
        }
        Map<String, String> tmpParameters = new HashMap<>();
        for (Map.Entry<String, String> entry : parameters.entrySet()) {
            tmpParameters.put(entry.getKey(), StringUtils.gson.toJson((Object) entry.getValue()));
        }
        StringSubstitutor substitutor = new StringSubstitutor(tmpParameters, "${parameters.", "}");
        return substitutor.replace(processFunction);
    }

    /**
     * Signs {@code request} with AWS Signature Version 4.
     *
     * <p>Session credentials are used when {@code sessionToken} is non-null; basic credentials
     * otherwise.
     *
     * @return a signed copy of the request
     */
    public static SdkHttpFullRequest signRequest(SdkHttpFullRequest request, String accessKey, String secretKey, String sessionToken, String signingName, String region) {
        // AwsBasicCredentials and AwsSessionCredentials only share the AwsCredentials supertype.
        AwsCredentials credentials = sessionToken == null
            ? AwsBasicCredentials.create(accessKey, secretKey)
            : AwsSessionCredentials.create(accessKey, secretKey, sessionToken);
        Aws4SignerParams params = Aws4SignerParams
            .builder()
            .awsCredentials(credentials)
            .signingName(signingName)
            .signingRegion(Region.of(region))
            .build();
        return signer.sign(request, params);
    }

    /**
     * Builds an HTTP request to the connector's endpoint for {@code action}.
     *
     * <p>Request layout:
     * <ul>
     *   <li>The body is {@code payload} encoded with the {@code charset} parameter (default UTF-8),
     *       or empty when {@code payload} is null.</li>
     *   <li>Connector (decrypted) headers are copied onto the request.</li>
     *   <li>{@code Content-Type: application/json} and {@code Content-Length} are added when not
     *       already present.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if {@code method} is POST and the body length is 0
     */
    public static SdkHttpFullRequest buildSdkRequest(String action, Connector connector, Map<String, String> parameters, String payload, SdkHttpMethod method) {
        String charset = parameters.getOrDefault("charset", "UTF-8");
        RequestBody requestBody = payload != null ? RequestBody.fromString(payload, Charset.forName(charset)) : RequestBody.empty();
        Optional<Long> contentLength = requestBody.optionalContentLength();
        if (SdkHttpMethod.POST == method && 0L == contentLength.get()) {
            log.error("Content length is 0. Aborting request to remote model");
            throw new IllegalArgumentException("Content length is 0. Aborting request to remote model");
        }
        String endpoint = connector.getActionEndpoint(action, parameters);
        SdkHttpFullRequest.Builder builder = SdkHttpFullRequest
            .builder()
            .method(method)
            .uri(URI.create(endpoint))
            .contentStreamProvider(requestBody.contentStreamProvider());
        Map<String, String> headers = connector.getDecryptedHeaders();
        if (headers != null) {
            for (Map.Entry<String, String> header : headers.entrySet()) {
                builder.putHeader(header.getKey(), header.getValue());
            }
        }
        if (builder.matchingHeaders("Content-Type").isEmpty()) {
            builder.putHeader("Content-Type", "application/json");
        }
        if (builder.matchingHeaders("Content-Length").isEmpty()) {
            builder.putHeader("Content-Length", contentLength.get().toString());
        }
        return builder.build();
    }
}

