/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.flowframework.workflow;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.common.WorkflowResources;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.model.ResourceCreated;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.model.WorkflowEdge;
import org.opensearch.flowframework.model.WorkflowNode;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowDataStep;
import org.opensearch.flowframework.workflow.WorkflowStep;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.plugins.PluginInfo;
import org.opensearch.plugins.PluginsService;
import org.opensearch.threadpool.ThreadPool;

/**
 * Converts a {@link Workflow} into an ordered list of {@link ProcessNode}s and validates it.
 * <p>
 * Nodes are ordered with Kahn's topological sort, so every node comes after all of its predecessors.
 * The class also checks the workflow size limit, denylisted step types, installed plugins and
 * required step inputs. It builds the (possibly partial) execution sequence used when a template
 * is reprovisioned.
 */
public class WorkflowProcessSorter {
    private static final Logger logger = LogManager.getLogger(WorkflowProcessSorter.class);

    /** Step types that are rejected if they appear in a workflow. */
    public static final Set<String> WORKFLOW_STEP_DENYLIST = Set.of("delete_index", "delete_ingest_pipeline", "delete_search_pipeline");

    /** Key of the workflow inside a template that is provisioned / reprovisioned. */
    private static final String PROVISION_WORKFLOW = "provision";
    /** Thread pool name handed to every {@link ProcessNode} built here. */
    private static final String PROVISION_WORKFLOW_THREAD_POOL = "opensearch_provision_workflow";
    /** Step name of the {@link WorkflowDataStep} used for unchanged nodes during reprovisioning. */
    private static final String WORKFLOW_DATA_STEP_NAME = "workflow_data_step";
    /** User input key that overrides a node's default timeout. */
    private static final String NODE_TIMEOUT_FIELD = "node_timeout";

    private final WorkflowStepFactory workflowStepFactory;
    private final ThreadPool threadPool;
    private final Integer maxWorkflowSteps;

    /**
     * Creates a sorter.
     *
     * @param workflowStepFactory factory used to instantiate the step for each node type
     * @param threadPool thread pool passed to each created {@link ProcessNode}
     * @param flowFrameworkSettings settings; the maximum workflow step count is read once here
     */
    public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool threadPool, FlowFrameworkSettings flowFrameworkSettings) {
        this.workflowStepFactory = workflowStepFactory;
        this.threadPool = threadPool;
        this.maxWorkflowSteps = flowFrameworkSettings.getMaxWorkflowSteps();
    }

    /**
     * Validates a workflow and returns its nodes as process nodes in topological order.
     *
     * @param workflow the workflow to sort
     * @param workflowId id of the workflow, used in data and error messages
     * @param params parameters handed to every process node
     * @param tenantId tenant id handed to every process node
     * @return process nodes ordered so that each appears after all of its predecessors
     * @throws FlowFrameworkException BAD_REQUEST if the workflow is too large or its graph is invalid,
     *         FORBIDDEN if a node uses a denylisted step type
     */
    public List<ProcessNode> sortProcessNodes(Workflow workflow, String workflowId, Map<String, String> params, String tenantId) {
        validateWorkflowSize(workflow, workflowId);
        for (WorkflowNode node : workflow.nodes()) {
            if (WORKFLOW_STEP_DENYLIST.contains(node.type())) {
                throw new FlowFrameworkException(
                    "The step type [" + node.type() + "] for node [" + node.id() + "] can not be used in a workflow.",
                    RestStatus.FORBIDDEN
                );
            }
        }
        List<WorkflowNode> sortedNodes = topologicalSort(workflow.nodes(), workflow.edges());

        List<ProcessNode> nodes = new ArrayList<>();
        Map<String, ProcessNode> idToNodeMap = new HashMap<>();
        for (WorkflowNode node : sortedNodes) {
            WorkflowStep step = workflowStepFactory.createStep(node.type());
            WorkflowData data = new WorkflowData(node.userInputs(), workflow.userParams(), workflowId, node.id());
            // Topological order guarantees every predecessor is already in idToNodeMap here
            List<ProcessNode> predecessorNodes = findPredecessors(workflow, node, idToNodeMap);
            TimeValue nodeTimeout = parseTimeout(node);
            ProcessNode processNode = new ProcessNode(
                node.id(),
                step,
                node.previousNodeInputs(),
                params,
                data,
                predecessorNodes,
                threadPool,
                PROVISION_WORKFLOW_THREAD_POOL,
                nodeTimeout,
                tenantId
            );
            idToNodeMap.put(processNode.id(), processNode);
            nodes.add(processNode);
        }
        return nodes;
    }

    /**
     * Builds the process-node sequence needed to move from {@code originalTemplate} to {@code updatedTemplate}.
     * <p>
     * Nodes are handled differently depending on how they compare to the original template:
     * <ul>
     *   <li>New nodes get their normal step.</li>
     *   <li>Nodes with changed inputs get their update step.</li>
     *   <li>Unchanged nodes with a recorded resource get a {@link WorkflowDataStep}.</li>
     *   <li>Unchanged nodes without a recorded resource are omitted.</li>
     * </ul>
     *
     * @param workflowId id of the workflow being reprovisioned
     * @param originalTemplate the currently provisioned template
     * @param updatedTemplate the requested template
     * @param resourcesCreated resources recorded for the current provisioning
     * @param tenantId tenant id handed to every process node
     * @return the reprovision sequence in topological order
     * @throws FlowFrameworkException BAD_REQUEST if the workflow is too large or invalid, deletes a step,
     *         updates a step that cannot be updated, or contains no modification
     * @throws Exception if comparing user inputs fails
     */
    public List<ProcessNode> createReprovisionSequence(String workflowId, Template originalTemplate, Template updatedTemplate, List<ResourceCreated> resourcesCreated, String tenantId) throws Exception {
        Workflow updatedWorkflow = updatedTemplate.workflows().get(PROVISION_WORKFLOW);
        validateWorkflowSize(updatedWorkflow, workflowId);
        List<WorkflowNode> sortedUpdatedNodes = topologicalSort(updatedWorkflow.nodes(), updatedWorkflow.edges());
        Map<String, WorkflowNode> originalTemplateMap = originalTemplate.workflows()
            .get(PROVISION_WORKFLOW)
            .nodes()
            .stream()
            .collect(Collectors.toMap(WorkflowNode::id, node -> node));
        // Every original node must still be present (per WorkflowNode equality) in the updated workflow
        if (!originalTemplateMap.values().stream().allMatch(sortedUpdatedNodes::contains)) {
            throw new FlowFrameworkException("Workflow Step deletion is not supported when reprovisioning a template.", RestStatus.BAD_REQUEST);
        }
        List<ProcessNode> reprovisionSequence = createReprovisionSequence(
            workflowId,
            updatedWorkflow,
            sortedUpdatedNodes,
            originalTemplateMap,
            resourcesCreated,
            tenantId
        );
        // Only data-replay nodes (or none at all) means nothing actually changed
        if (reprovisionSequence.stream().allMatch(n -> n.workflowStep().getName().equals(WORKFLOW_DATA_STEP_NAME))) {
            throw new FlowFrameworkException("Template does not contain any modifications", RestStatus.BAD_REQUEST);
        }
        return reprovisionSequence;
    }

    /** Converts the sorted updated nodes into process nodes, skipping nodes for which none is created. */
    private List<ProcessNode> createReprovisionSequence(String workflowId, Workflow updatedWorkflow, List<WorkflowNode> sortedUpdatedNodes, Map<String, WorkflowNode> originalTemplateMap, List<ResourceCreated> resourcesCreated, String tenantId) throws Exception {
        Map<String, ProcessNode> idToNodeMap = new HashMap<>();
        List<ProcessNode> reprovisionSequence = new ArrayList<>();
        for (WorkflowNode node : sortedUpdatedNodes) {
            ProcessNode processNode = createProcessNode(updatedWorkflow, node, originalTemplateMap, resourcesCreated, workflowId, idToNodeMap, tenantId);
            if (processNode != null) {
                idToNodeMap.put(processNode.id(), processNode);
                reprovisionSequence.add(processNode);
            }
        }
        return reprovisionSequence;
    }

    /**
     * Chooses the process node for one updated node: new, update, or data-replay.
     *
     * @return the process node, or {@code null} if the node is unchanged and has no recorded resource
     */
    private ProcessNode createProcessNode(Workflow updatedWorkflow, WorkflowNode node, Map<String, WorkflowNode> originalTemplateMap, List<ResourceCreated> resourcesCreated, String workflowId, Map<String, ProcessNode> idToNodeMap, String tenantId) throws Exception {
        WorkflowData data = new WorkflowData(node.userInputs(), updatedWorkflow.userParams(), workflowId, node.id());
        // Predecessors skipped during reprovisioning are not in idToNodeMap; findPredecessors drops them
        List<ProcessNode> predecessorNodes = findPredecessors(updatedWorkflow, node, idToNodeMap);
        TimeValue nodeTimeout = parseTimeout(node);
        if (!originalTemplateMap.containsKey(node.id())) {
            return createNewProcessNode(node, data, predecessorNodes, nodeTimeout, tenantId);
        }
        WorkflowNode originalNode = originalTemplateMap.get(node.id());
        if (shouldUpdateNode(node, originalNode)) {
            return createUpdateProcessNode(node, data, predecessorNodes, nodeTimeout, tenantId);
        }
        return createWorkflowDataStepNode(node, data, predecessorNodes, nodeTimeout, resourcesCreated, tenantId);
    }

    /** Builds a process node that runs the node's own step type (used for nodes added by the update). */
    private ProcessNode createNewProcessNode(WorkflowNode node, WorkflowData data, List<ProcessNode> predecessorNodes, TimeValue nodeTimeout, String tenantId) {
        WorkflowStep step = workflowStepFactory.createStep(node.type());
        return new ProcessNode(node.id(), step, node.previousNodeInputs(), Collections.emptyMap(), data, predecessorNodes, threadPool, PROVISION_WORKFLOW_THREAD_POOL, nodeTimeout, tenantId);
    }

    /**
     * Builds a process node that runs the update step registered for the node's type.
     *
     * @throws FlowFrameworkException BAD_REQUEST if the node's type has no update step
     */
    private ProcessNode createUpdateProcessNode(WorkflowNode node, WorkflowData data, List<ProcessNode> predecessorNodes, TimeValue nodeTimeout, String tenantId) throws FlowFrameworkException {
        String updateStepName = WorkflowResources.getUpdateStepByWorkflowStep(node.type());
        if (updateStepName == null) {
            throw new FlowFrameworkException("Workflow Step " + node.id() + " does not support updates when reprovisioning.", RestStatus.BAD_REQUEST);
        }
        WorkflowStep step = workflowStepFactory.createStep(updateStepName);
        return new ProcessNode(node.id(), step, node.previousNodeInputs(), Collections.emptyMap(), data, predecessorNodes, threadPool, PROVISION_WORKFLOW_THREAD_POOL, nodeTimeout, tenantId);
    }

    /**
     * Builds a {@link WorkflowDataStep} node around the first resource recorded for this node id.
     *
     * @return the node, or {@code null} if no resource was recorded for the node
     */
    private ProcessNode createWorkflowDataStepNode(WorkflowNode node, WorkflowData data, List<ProcessNode> predecessorNodes, TimeValue nodeTimeout, List<ResourceCreated> resourcesCreated, String tenantId) {
        return resourcesCreated.stream()
            .filter(rc -> rc.workflowStepId().equals(node.id()))
            .findFirst()
            .map(nodeResource -> new ProcessNode(
                node.id(),
                new WorkflowDataStep(nodeResource),
                node.previousNodeInputs(),
                Collections.emptyMap(),
                data,
                predecessorNodes,
                threadPool,
                PROVISION_WORKFLOW_THREAD_POOL,
                nodeTimeout,
                tenantId
            ))
            .orElse(null);
    }

    /** Returns true if the node's previous-node inputs or user inputs differ from the original node's. */
    private boolean shouldUpdateNode(WorkflowNode node, WorkflowNode originalNode) throws Exception {
        return !node.previousNodeInputs().equals(originalNode.previousNodeInputs())
            || !ParseUtils.userInputsEquals(originalNode.userInputs(), node.userInputs());
    }

    /**
     * Validates that required plugins are installed and that every node's required inputs are available.
     *
     * @param processNodes the nodes to validate
     * @param pluginsService source of the installed plugin names
     * @throws FlowFrameworkException BAD_REQUEST on the first validation failure
     */
    public void validate(List<ProcessNode> processNodes, PluginsService pluginsService) throws Exception {
        List<String> installedPlugins = pluginsService.info().getPluginInfos().stream().map(PluginInfo::getName).collect(Collectors.toList());
        validatePluginsInstalled(processNodes, installedPlugins);
        validateGraph(processNodes);
    }

    /**
     * Checks that every plugin required by each node's step is installed.
     *
     * @param processNodes the nodes to check
     * @param installedPlugins names of installed plugins
     * @throws FlowFrameworkException BAD_REQUEST listing the missing plugins of the first failing node
     */
    public void validatePluginsInstalled(List<ProcessNode> processNodes, List<String> installedPlugins) throws Exception {
        for (ProcessNode processNode : processNodes) {
            String nodeType = processNode.workflowStep().getName();
            List<String> missingPlugins = new ArrayList<>(WorkflowStepFactory.WorkflowSteps.getRequiredPluginsByWorkflowType(nodeType));
            missingPlugins.removeAll(installedPlugins);
            if (!missingPlugins.isEmpty()) {
                throw new FlowFrameworkException(
                    "The workflowStep " + nodeType + " requires the following plugins to be installed : " + missingPlugins,
                    RestStatus.BAD_REQUEST
                );
            }
        }
    }

    /**
     * Checks that each node's required inputs are provided by its own inputs or its predecessors' outputs.
     *
     * @param processNodes the nodes to check
     * @throws FlowFrameworkException BAD_REQUEST listing the missing inputs of the first failing node
     */
    public void validateGraph(List<ProcessNode> processNodes) throws Exception {
        for (ProcessNode processNode : processNodes) {
            Set<String> availableInputs = new HashSet<>(processNode.input().getContent().keySet());
            for (ProcessNode predecessor : processNode.predecessors()) {
                availableInputs.addAll(WorkflowStepFactory.WorkflowSteps.getOutputByWorkflowType(predecessor.workflowStep().getName()));
            }
            List<String> missingInputs = new ArrayList<>(
                WorkflowStepFactory.WorkflowSteps.getInputByWorkflowType(processNode.workflowStep().getName())
            );
            missingInputs.removeAll(availableInputs);
            if (!missingInputs.isEmpty()) {
                throw new FlowFrameworkException(
                    "Invalid workflow, node [" + processNode.id() + "] missing the following required inputs : " + missingInputs,
                    RestStatus.BAD_REQUEST
                );
            }
        }
    }

    /**
     * Resolves a node's timeout.
     * <p>
     * The user-supplied {@code node_timeout} is used when present. Otherwise the step type's default
     * timeout applies, falling back to {@link WorkflowNode#NODE_TIMEOUT_DEFAULT_VALUE}.
     *
     * @param node the node whose timeout is resolved
     * @return the parsed timeout
     * @throws FlowFrameworkException BAD_REQUEST if the parsed value is negative
     */
    protected TimeValue parseTimeout(WorkflowNode node) {
        TimeValue defaultTimeout = Optional.ofNullable(WorkflowStepFactory.WorkflowSteps.getTimeoutByWorkflowType(node.type()))
            .orElse(WorkflowNode.NODE_TIMEOUT_DEFAULT_VALUE);
        String defaultTimeoutAsString = defaultTimeout.getSeconds() + "s";
        // Stringify instead of casting. A non-string value (e.g. a bare number) then fails in
        // parseTimeValue with a descriptive error rather than a ClassCastException. An explicit
        // null falls back to the default instead of producing a null TimeValue.
        String timeoutValue = Objects.toString(node.userInputs().get(NODE_TIMEOUT_FIELD), defaultTimeoutAsString);
        String fieldName = String.join(".", node.id(), "user_inputs", NODE_TIMEOUT_FIELD);
        TimeValue userInputTimeValue = TimeValue.parseTimeValue(timeoutValue, fieldName);
        if (userInputTimeValue.millis() < 0L) {
            throw new FlowFrameworkException(
                "Failed to parse timeout value [" + timeoutValue + "] for field [" + fieldName + "]. Must be positive",
                RestStatus.BAD_REQUEST
            );
        }
        return userInputTimeValue;
    }

    /**
     * Returns the already-built process nodes that have an edge into {@code node}.
     * <p>
     * Sources with no entry in {@code idToNodeMap} are dropped instead of being returned as {@code null}.
     * This happens during reprovisioning for unchanged nodes that had no recorded resource. A null
     * predecessor would otherwise fail later when it is dereferenced.
     */
    private static List<ProcessNode> findPredecessors(Workflow workflow, WorkflowNode node, Map<String, ProcessNode> idToNodeMap) {
        return workflow.edges()
            .stream()
            .filter(e -> e.destination().equals(node.id()))
            .map(e -> idToNodeMap.get(e.source()))
            .filter(Objects::nonNull)
            .collect(Collectors.toList());
    }

    /**
     * Rejects workflows with more nodes than the configured maximum.
     *
     * @throws FlowFrameworkException BAD_REQUEST if the limit is exceeded
     */
    private void validateWorkflowSize(Workflow workflow, String workflowId) {
        int nodeCount = workflow.nodes().size();
        if (nodeCount > maxWorkflowSteps) {
            throw new FlowFrameworkException(
                "Workflow " + workflowId + " has " + nodeCount + " nodes, which exceeds the maximum of " + maxWorkflowSteps
                    + ". Change the setting [" + FlowFrameworkSettings.MAX_WORKFLOW_STEPS.getKey() + "] to increase this.",
                RestStatus.BAD_REQUEST
            );
        }
    }

    /**
     * Topologically sorts the nodes using Kahn's algorithm.
     *
     * @param workflowNodes the nodes; ids must be unique
     * @param workflowEdges the edges; both endpoints must be existing, distinct nodes
     * @return the nodes in an order where each comes after all of its predecessors
     * @throws FlowFrameworkException BAD_REQUEST on duplicate ids, dangling or self edges, no start node, or a cycle
     */
    private static List<WorkflowNode> topologicalSort(List<WorkflowNode> workflowNodes, List<WorkflowEdge> workflowEdges) {
        Map<String, WorkflowNode> nodeMap = new HashMap<>();
        for (WorkflowNode node : workflowNodes) {
            if (nodeMap.containsKey(node.id())) {
                throw new FlowFrameworkException("Duplicate node id " + node.id() + ".", RestStatus.BAD_REQUEST);
            }
            nodeMap.put(node.id(), node);
        }
        for (WorkflowEdge edge : workflowEdges) {
            String source = edge.source();
            if (!nodeMap.containsKey(source)) {
                throw new FlowFrameworkException("Edge source " + source + " does not correspond to a node.", RestStatus.BAD_REQUEST);
            }
            String dest = edge.destination();
            if (!nodeMap.containsKey(dest)) {
                throw new FlowFrameworkException("Edge destination " + dest + " does not correspond to a node.", RestStatus.BAD_REQUEST);
            }
            if (source.equals(dest)) {
                throw new FlowFrameworkException("Edge connects node " + source + " to itself.", RestStatus.BAD_REQUEST);
            }
        }

        Map<WorkflowNode, Set<WorkflowEdge>> predecessorEdges = new HashMap<>();
        Map<WorkflowNode, Set<WorkflowEdge>> successorEdges = new HashMap<>();
        for (WorkflowEdge edge : workflowEdges) {
            WorkflowNode source = nodeMap.get(edge.source());
            WorkflowNode dest = nodeMap.get(edge.destination());
            predecessorEdges.computeIfAbsent(dest, k -> new HashSet<>()).add(edge);
            successorEdges.computeIfAbsent(source, k -> new HashSet<>()).add(edge);
        }

        // Edges not yet "removed"; anything left at the end belongs to a cycle
        Set<WorkflowEdge> graph = new HashSet<>(workflowEdges);
        List<WorkflowNode> sortedNodes = new ArrayList<>();
        Queue<WorkflowNode> sourceNodes = workflowNodes.stream()
            .filter(n -> !predecessorEdges.containsKey(n))
            .collect(Collectors.toCollection(ArrayDeque::new));
        if (sourceNodes.isEmpty()) {
            throw new FlowFrameworkException("No start node detected: all nodes have a predecessor.", RestStatus.BAD_REQUEST);
        }
        logger.debug("Start node(s): {}", sourceNodes);

        while (!sourceNodes.isEmpty()) {
            WorkflowNode current = sourceNodes.poll();
            sortedNodes.add(current);
            for (WorkflowEdge edge : successorEdges.getOrDefault(current, Collections.emptySet())) {
                WorkflowNode successor = nodeMap.get(edge.destination());
                graph.remove(edge);
                // The successor becomes ready once none of its incoming edges remain
                if (predecessorEdges.get(successor).stream().noneMatch(graph::contains)) {
                    sourceNodes.add(successor);
                }
            }
        }
        if (!graph.isEmpty()) {
            throw new FlowFrameworkException("Cycle detected: " + graph, RestStatus.BAD_REQUEST);
        }
        logger.debug("Execution sequence: {}", sortedNodes);
        return sortedNodes;
    }
}

