/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.agent;

import com.google.common.annotations.VisibleForTesting;
import com.google.gson.Gson;
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.get.GetRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.MediaType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLAgentType;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLMemorySpec;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.memory.Memory;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.engine.Executable;
import org.opensearch.ml.engine.algorithms.agent.MLAgentRunner;
import org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner;
import org.opensearch.ml.engine.algorithms.agent.MLConversationalFlowAgentRunner;
import org.opensearch.ml.engine.algorithms.agent.MLFlowAgentRunner;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
import org.opensearch.ml.memory.action.conversation.GetInteractionAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionRequest;

/**
 * Executable registered for {@link FunctionName#AGENT}.
 *
 * <p>Loads an agent definition from the {@code .plugins-ml-agent} index, optionally
 * prepares conversation memory for it, and hands the request to the
 * {@link MLAgentRunner} that matches the agent's type.
 */
@Function(value=FunctionName.AGENT)
public class MLAgentExecutor
implements Executable {
    @Generated
    private static final Logger log = LogManager.getLogger(MLAgentExecutor.class);
    /** Request-parameter key holding the conversation (memory) ID. */
    public static final String MEMORY_ID = "memory_id";
    /** Request-parameter key holding the user's question. */
    public static final String QUESTION = "question";
    /** Request-parameter key holding the ID of the root interaction created for this run. */
    public static final String PARENT_INTERACTION_ID = "parent_interaction_id";
    /** Request-parameter key naming an existing interaction whose answer should be regenerated. */
    public static final String REGENERATE_INTERACTION_ID = "regenerate_interaction_id";
    // Not referenced inside this class; presumably read by the agent runners -- TODO confirm.
    public static final String MESSAGE_HISTORY_LIMIT = "message_history_limit";
    // Collaborators; all of them are also passed through to the agent runners.
    private Client client;
    private Settings settings;
    private ClusterService clusterService;
    private NamedXContentRegistry xContentRegistry;
    // Tool factories keyed by name (key semantics are defined by the runners, not here).
    private Map<String, Tool.Factory> toolFactories;
    // Memory factories keyed by the memory type declared in the agent's MLMemorySpec.
    private Map<String, Memory.Factory> memoryFactoryMap;

    /**
     * Creates an executor wired with every collaborator it needs.
     *
     * @param client           client used to read the agent document and call memory actions
     * @param settings         node settings, passed on to agent runners
     * @param clusterService   used to check that the agent index exists
     * @param xContentRegistry registry used when parsing the stored agent document
     * @param toolFactories    tool factories passed on to agent runners
     * @param memoryFactoryMap memory factories keyed by memory type
     */
    public MLAgentExecutor(
        Client client,
        Settings settings,
        ClusterService clusterService,
        NamedXContentRegistry xContentRegistry,
        Map<String, Tool.Factory> toolFactories,
        Map<String, Memory.Factory> memoryFactoryMap
    ) {
        // Factories first, then infrastructure; the assignments are independent.
        this.memoryFactoryMap = memoryFactoryMap;
        this.toolFactories = toolFactories;
        this.xContentRegistry = xContentRegistry;
        this.clusterService = clusterService;
        this.settings = settings;
        this.client = client;
    }

    /**
     * Runs the agent identified by the given {@link AgentMLInput}.
     *
     * <p>The agent document is fetched from the {@code .plugins-ml-agent} index under a
     * stashed thread context. If the agent declares a memory type that has a registered
     * factory, and the request lacks either a memory ID or a parent interaction ID, a
     * conversation memory is created (or opened) and a root interaction is saved before
     * the agent runs. Otherwise the agent is run directly. Exactly one of the two paths
     * is taken, so the listener is completed once.
     *
     * @param input    must be an {@link AgentMLInput} whose dataset carries a parameter map
     * @param listener receives the agent output, or the failure if anything goes wrong
     *                 after the synchronous argument checks
     * @throws IllegalArgumentException if {@code input} is not an {@link AgentMLInput} or
     *                                  its dataset/parameters are missing
     */
    @Override
    public void execute(Input input, ActionListener<Output> listener) {
        if (!(input instanceof AgentMLInput)) {
            throw new IllegalArgumentException("wrong input");
        }
        AgentMLInput agentMLInput = (AgentMLInput) input;
        String agentId = agentMLInput.getAgentId();
        RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentMLInput.getInputDataset();
        if (inputDataSet == null || inputDataSet.getParameters() == null) {
            throw new IllegalArgumentException("Agent input data can not be empty.");
        }

        // `outputs` wraps `modelTensors`; the agent listener fills the inner list later.
        List<ModelTensors> outputs = new ArrayList<>();
        List<ModelTensor> modelTensors = new ArrayList<>();
        outputs.add(ModelTensors.builder().mlModelTensors(modelTensors).build());

        if (!this.clusterService.state().metadata().hasIndex(".plugins-ml-agent")) {
            listener.onFailure(new ResourceNotFoundException("Agent index not found"));
            return;
        }

        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            GetRequest getRequest = new GetRequest(".plugins-ml-agent").id(agentId);
            this.client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
                if (!r.isExists()) {
                    listener.onFailure(new ResourceNotFoundException("Agent not found"));
                    return;
                }
                try (XContentParser parser = this.createXContentParserFromRegistry(this.xContentRegistry, r.getSourceAsBytesRef())) {
                    XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
                    MLAgent mlAgent = MLAgent.parse(parser);
                    MLMemorySpec memorySpec = mlAgent.getMemory();
                    String memoryId = inputDataSet.getParameters().get(MEMORY_ID);
                    String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID);
                    String regenerateInteractionId = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID);
                    String appType = mlAgent.getAppType();
                    String question = inputDataSet.getParameters().get(QUESTION);

                    // Thrown inside the wrapped listener, so it reaches the failure handler below.
                    if (memoryId == null && regenerateInteractionId != null) {
                        throw new IllegalArgumentException("A memory ID must be provided to regenerate.");
                    }

                    boolean needsMemorySetup = memorySpec != null
                        && memorySpec.getType() != null
                        && this.memoryFactoryMap.containsKey(memorySpec.getType())
                        && (memoryId == null || parentInteractionId == null);

                    if (needsMemorySetup) {
                        ConversationIndexMemory.Factory conversationIndexMemoryFactory =
                            (ConversationIndexMemory.Factory) this.memoryFactoryMap.get(memorySpec.getType());
                        conversationIndexMemoryFactory.create(question, memoryId, appType, ActionListener.wrap(memory -> {
                            inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId());
                            ActionListener<Object> agentActionListener = this.createAgentActionListener(listener, outputs, modelTensors);
                            if (regenerateInteractionId != null) {
                                log.info("Regenerate for existing interaction {}", regenerateInteractionId);
                                // Reuse the original question unless the caller supplied a new one.
                                this.client.execute(GetInteractionAction.INSTANCE, new GetInteractionRequest(regenerateInteractionId), ActionListener.wrap(interactionRes -> {
                                    inputDataSet.getParameters().putIfAbsent(QUESTION, interactionRes.getInteraction().getInput());
                                    this.saveRootInteractionAndExecute(agentActionListener, memory, inputDataSet, mlAgent);
                                }, e -> {
                                    log.error("Failed to get existing interaction for regeneration", e);
                                    listener.onFailure(e);
                                }));
                            } else {
                                this.saveRootInteractionAndExecute(agentActionListener, memory, inputDataSet, mlAgent);
                            }
                        }, ex -> {
                            log.error("Failed to read conversation memory", ex);
                            listener.onFailure(ex);
                        }));
                    } else {
                        // Only run directly when the memory path was not taken; running both
                        // would execute the agent twice and complete the listener twice.
                        ActionListener<Object> agentActionListener = this.createAgentActionListener(listener, outputs, modelTensors);
                        this.executeAgent(inputDataSet, mlAgent, agentActionListener);
                    }
                }
            }, e -> {
                log.error("Failed to get agent", e);
                listener.onFailure(e);
            }), context::restore));
        }
    }

    /**
     * Saves a root ("parent") interaction for this question in the conversation memory,
     * records its ID under {@link #PARENT_INTERACTION_ID}, and then runs the agent.
     *
     * <p>When the request carries {@link #REGENERATE_INTERACTION_ID}, that interaction
     * and its traces are deleted first, and the agent runs only after the deletion succeeds.
     *
     * @param listener     agent listener that receives the run's result or failure
     * @param memory       conversation memory in which the root interaction is stored
     * @param inputDataSet request data; its parameter map is updated in place
     * @param mlAgent      agent definition to run
     */
    private void saveRootInteractionAndExecute(ActionListener<Object> listener, ConversationIndexMemory memory, RemoteInferenceInputDataSet inputDataSet, MLAgent mlAgent) {
        final String interactionToRegenerate = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID);

        // The response is left empty here; only the question is known at this point.
        ConversationIndexMessage rootMessage = ConversationIndexMessage
            .conversationIndexMessageBuilder()
            .type(mlAgent.getAppType())
            .question(inputDataSet.getParameters().get(QUESTION))
            .response("")
            .finalAnswer(true)
            .sessionId(memory.getConversationId())
            .build();

        memory.save(rootMessage, null, null, null, ActionListener.wrap(created -> {
            log.info("Created parent interaction ID: " + created.getId());
            inputDataSet.getParameters().put(PARENT_INTERACTION_ID, created.getId());

            if (interactionToRegenerate == null) {
                this.executeAgent(inputDataSet, mlAgent, listener);
                return;
            }

            // Regeneration: drop the old interaction (and its traces) before re-running.
            memory.getMemoryManager().deleteInteractionAndTrace(
                interactionToRegenerate,
                ActionListener.wrap(deleted -> this.executeAgent(inputDataSet, mlAgent, listener), failure -> {
                    log.error("Failed to regenerate for interaction {}", interactionToRegenerate, failure);
                    listener.onFailure(failure);
                })
            );
        }, failure -> {
            log.error("Failed to create parent interaction", failure);
            listener.onFailure(failure);
        }));
    }

    /**
     * Picks the runner for the agent's type and runs it with the request parameters.
     *
     * @param inputDataSet        request data whose parameter map is passed to the runner
     * @param mlAgent             agent definition to run
     * @param agentActionListener receives the runner's output or failure
     */
    private void executeAgent(RemoteInferenceInputDataSet inputDataSet, MLAgent mlAgent, ActionListener<Object> agentActionListener) {
        this.getAgentRunner(mlAgent).run(mlAgent, inputDataSet.getParameters(), agentActionListener);
    }

    /**
     * Builds the listener given to agent runners. It converts whatever the runner
     * returns into a {@link ModelTensorOutput} and forwards it to {@code listener}.
     *
     * <p>Conversion rules, in order:
     * <ul>
     *   <li>{@code null}: forwarded as {@code null};</li>
     *   <li>{@link ModelTensorOutput}: all contained tensors are collected;</li>
     *   <li>{@link ModelTensor}: added as is;</li>
     *   <li>non-empty {@link List} whose first element is a {@link ModelTensor} or
     *       {@link ModelTensors}: every element is treated as that type;</li>
     *   <li>anything else, including an empty list: wrapped in a tensor named
     *       {@code "response"} holding the string itself or its JSON form.</li>
     * </ul>
     *
     * @param listener     downstream listener that receives the final output
     * @param outputs      output holder; expected to already wrap {@code modelTensors}
     * @param modelTensors list that collected tensors are appended to
     * @return a listener suitable for {@link MLAgentRunner#run}
     */
    private ActionListener<Object> createAgentActionListener(ActionListener<Output> listener, List<ModelTensors> outputs, List<ModelTensor> modelTensors) {
        return ActionListener.wrap(output -> {
            if (output == null) {
                listener.onResponse(null);
                return;
            }
            if (output instanceof ModelTensorOutput) {
                for (ModelTensors tensors : ((ModelTensorOutput) output).getMlModelOutputs()) {
                    modelTensors.addAll(tensors.getMlModelTensors());
                }
            } else if (output instanceof ModelTensor) {
                modelTensors.add((ModelTensor) output);
            } else {
                // The first element decides how a list is read. An empty list has no first
                // element, so it falls through to JSON ("[]") instead of failing on get(0).
                List<?> items = output instanceof List ? (List<?>) output : null;
                Object first = items == null || items.isEmpty() ? null : items.get(0);
                if (first instanceof ModelTensor) {
                    for (Object item : items) {
                        modelTensors.add((ModelTensor) item);
                    }
                } else if (first instanceof ModelTensors) {
                    for (Object item : items) {
                        modelTensors.addAll(((ModelTensors) item).getMlModelTensors());
                    }
                } else {
                    modelTensors.add(ModelTensor.builder().name("response").result(toResponseText(output)).build());
                }
            }
            listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(outputs).build());
        }, ex -> {
            log.error("Failed to run flow agent", ex);
            listener.onFailure(ex);
        });
    }

    /** Returns a string output unchanged; serializes any other object to JSON in a privileged block. */
    private static String toResponseText(Object output) {
        if (output instanceof String) {
            return (String) output;
        }
        Gson gson = new Gson();
        // The explicit cast picks the PrivilegedAction overload; a bare lambda is ambiguous.
        return AccessController.doPrivileged((PrivilegedAction<String>) () -> gson.toJson(output));
    }

    /**
     * Creates the runner that matches the agent's declared type.
     *
     * <p>The type string is upper-cased with {@link Locale#ROOT} so the lookup does not
     * depend on the JVM's default locale. Under a Turkish default locale a plain
     * {@code toUpperCase()} maps {@code i} to a dotted capital I, which would break
     * types such as {@code conversational}.
     *
     * @param mlAgent agent definition whose type selects the runner
     * @return a new runner wired with this executor's collaborators
     * @throws IllegalArgumentException if the agent type has no runner here
     */
    @VisibleForTesting
    protected MLAgentRunner getAgentRunner(MLAgent mlAgent) {
        MLAgentType agentType = MLAgentType.from(mlAgent.getType().toUpperCase(Locale.ROOT));
        switch (agentType) {
            case FLOW: {
                return new MLFlowAgentRunner(this.client, this.settings, this.clusterService, this.xContentRegistry, this.toolFactories, this.memoryFactoryMap);
            }
            case CONVERSATIONAL_FLOW: {
                return new MLConversationalFlowAgentRunner(this.client, this.settings, this.clusterService, this.xContentRegistry, this.toolFactories, this.memoryFactoryMap);
            }
            case CONVERSATIONAL: {
                return new MLChatAgentRunner(this.client, this.settings, this.clusterService, this.xContentRegistry, this.toolFactories, this.memoryFactoryMap);
            }
            default: {
                throw new IllegalArgumentException("Unsupported agent type: " + mlAgent.getType());
            }
        }
    }

    /**
     * Creates a JSON parser over the given bytes. Deprecated-field usage is reported
     * through {@link LoggingDeprecationHandler#INSTANCE}.
     *
     * @param xContentRegistry registry used to resolve named objects while parsing
     * @param bytesReference   raw JSON content to parse
     * @return a parser positioned before the first token
     * @throws IOException if the parser cannot be created
     */
    public XContentParser createXContentParserFromRegistry(NamedXContentRegistry xContentRegistry, BytesReference bytesReference) throws IOException {
        DeprecationHandler deprecationHandler = LoggingDeprecationHandler.INSTANCE;
        MediaType jsonMediaType = XContentType.JSON;
        return XContentHelper.createParser(xContentRegistry, deprecationHandler, bytesReference, jsonMediaType);
    }

    /** @return the client used for index reads and memory actions */
    @Generated
    public Client getClient() {
        return client;
    }

    /** @return the settings passed on to agent runners */
    @Generated
    public Settings getSettings() {
        return settings;
    }

    /** @return the cluster service used to check for the agent index */
    @Generated
    public ClusterService getClusterService() {
        return clusterService;
    }

    /** @return the registry used when parsing stored agent documents */
    @Generated
    public NamedXContentRegistry getXContentRegistry() {
        return xContentRegistry;
    }

    /** @return the tool factories passed on to agent runners */
    @Generated
    public Map<String, Tool.Factory> getToolFactories() {
        return toolFactories;
    }

    /** @return the memory factories keyed by memory type */
    @Generated
    public Map<String, Memory.Factory> getMemoryFactoryMap() {
        return memoryFactoryMap;
    }

    /** @param client replacement client */
    @Generated
    public void setClient(final Client client) {
        this.client = client;
    }

    /** @param settings replacement settings */
    @Generated
    public void setSettings(final Settings settings) {
        this.settings = settings;
    }

    /** @param clusterService replacement cluster service */
    @Generated
    public void setClusterService(final ClusterService clusterService) {
        this.clusterService = clusterService;
    }

    /** @param xContentRegistry replacement content registry */
    @Generated
    public void setXContentRegistry(final NamedXContentRegistry xContentRegistry) {
        this.xContentRegistry = xContentRegistry;
    }

    /** @param toolFactories replacement tool factories */
    @Generated
    public void setToolFactories(final Map<String, Tool.Factory> toolFactories) {
        this.toolFactories = toolFactories;
    }

    /** @param memoryFactoryMap replacement memory factories, keyed by memory type */
    @Generated
    public void setMemoryFactoryMap(final Map<String, Memory.Factory> memoryFactoryMap) {
        this.memoryFactoryMap = memoryFactoryMap;
    }

    /**
     * Field-by-field equality over all six collaborators, with null-safe comparison.
     * Fields are compared in declaration order and comparison stops at the first mismatch.
     *
     * @param o object to compare with
     * @return {@code true} when {@code o} is an {@code MLAgentExecutor} that accepts this
     *         instance through {@link #canEqual(Object)} and every field matches
     */
    @Generated
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof MLAgentExecutor)) {
            return false;
        }
        MLAgentExecutor that = (MLAgentExecutor) o;
        if (!that.canEqual(this)) {
            return false;
        }

        Client myClient = this.getClient();
        Client theirClient = that.getClient();
        boolean sameClient = myClient == null ? theirClient == null : myClient.equals(theirClient);
        if (!sameClient) {
            return false;
        }

        Settings mySettings = this.getSettings();
        Settings theirSettings = that.getSettings();
        boolean sameSettings = mySettings == null ? theirSettings == null : mySettings.equals(theirSettings);
        if (!sameSettings) {
            return false;
        }

        ClusterService myClusterService = this.getClusterService();
        ClusterService theirClusterService = that.getClusterService();
        boolean sameClusterService = myClusterService == null ? theirClusterService == null : myClusterService.equals(theirClusterService);
        if (!sameClusterService) {
            return false;
        }

        NamedXContentRegistry myRegistry = this.getXContentRegistry();
        NamedXContentRegistry theirRegistry = that.getXContentRegistry();
        boolean sameRegistry = myRegistry == null ? theirRegistry == null : myRegistry.equals(theirRegistry);
        if (!sameRegistry) {
            return false;
        }

        Map<String, Tool.Factory> myTools = this.getToolFactories();
        Map<String, Tool.Factory> theirTools = that.getToolFactories();
        boolean sameTools = myTools == null ? theirTools == null : myTools.equals(theirTools);
        if (!sameTools) {
            return false;
        }

        Map<String, Memory.Factory> myMemories = this.getMemoryFactoryMap();
        Map<String, Memory.Factory> theirMemories = that.getMemoryFactoryMap();
        return myMemories == null ? theirMemories == null : myMemories.equals(theirMemories);
    }

    /**
     * Hook that lets a subclass refuse equality with instances of a parent type.
     *
     * @param other candidate object
     * @return {@code true} when {@code other} is an {@code MLAgentExecutor}
     */
    @Generated
    protected boolean canEqual(Object other) {
        return other instanceof MLAgentExecutor;
    }

    /**
     * Hash over the same six fields that {@link #equals(Object)} compares, folded in
     * declaration order with multiplier 59; a {@code null} field contributes 43.
     *
     * @return the combined hash code
     */
    @Generated
    public int hashCode() {
        final int prime = 59;
        Object[] components = {
            this.getClient(),
            this.getSettings(),
            this.getClusterService(),
            this.getXContentRegistry(),
            this.getToolFactories(),
            this.getMemoryFactoryMap()
        };
        int hash = 1;
        for (Object component : components) {
            hash = hash * prime + (component == null ? 43 : component.hashCode());
        }
        return hash;
    }

    /**
     * Renders all six fields as {@code MLAgentExecutor(name=value, ...)}.
     *
     * @return the string form; {@code null} fields print as {@code "null"}
     */
    @Generated
    public String toString() {
        StringBuilder text = new StringBuilder("MLAgentExecutor(client=");
        text.append(this.getClient());
        text.append(", settings=").append(this.getSettings());
        text.append(", clusterService=").append(this.getClusterService());
        text.append(", xContentRegistry=").append(this.getXContentRegistry());
        text.append(", toolFactories=").append(this.getToolFactories());
        text.append(", memoryFactoryMap=").append(this.getMemoryFactoryMap());
        text.append(")");
        return text.toString();
    }

    /**
     * Creates an executor with every collaborator left {@code null}; callers are
     * expected to populate them through the setters before use.
     */
    @Generated
    public MLAgentExecutor() {
    }
}

