/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.agents.plan.actions;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.apache.flink.agents.api.Event;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.chat.messages.MessageRole;
import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
import org.apache.flink.agents.api.context.MemoryObject;
import org.apache.flink.agents.api.context.RunnerContext;
import org.apache.flink.agents.api.event.ChatRequestEvent;
import org.apache.flink.agents.api.event.ChatResponseEvent;
import org.apache.flink.agents.api.event.ToolRequestEvent;
import org.apache.flink.agents.api.event.ToolResponseEvent;
import org.apache.flink.agents.api.resource.ResourceType;
import org.apache.flink.agents.api.tools.ToolResponse;
import org.apache.flink.agents.plan.JavaFunction;
import org.apache.flink.agents.plan.actions.Action;

/**
 * Built-in action that sends conversations to a chat model resource and handles tool-call round
 * trips.
 *
 * <p>The action is triggered by {@link ChatRequestEvent}s and {@link ToolResponseEvent}s. Two maps
 * are kept in short-term memory so that a tool response can be routed back to its conversation:
 *
 * <ul>
 *   <li>{@code _TOOL_CALL_CONTEXT}: conversation (initial request) id to the history accumulated
 *       so far (request messages, model replies that asked for tools, and tool results).
 *   <li>{@code _TOOL_REQUEST_EVENT_CONTEXT}: tool request event id to the conversation id and the
 *       chat model resource name that issued it.
 * </ul>
 */
public class ChatModelAction {
    /** Short-term memory key: conversation id to its accumulated message history. */
    private static final String TOOL_CALL_CONTEXT = "_TOOL_CALL_CONTEXT";
    /** Short-term memory key: tool request event id to its originating conversation and model. */
    private static final String TOOL_REQUEST_EVENT_CONTEXT = "_TOOL_REQUEST_EVENT_CONTEXT";
    private static final String INITIAL_REQUEST_ID = "initialRequestId";
    private static final String MODEL = "model";

    /**
     * Builds the {@link Action} that routes {@link ChatRequestEvent}s and {@link
     * ToolResponseEvent}s to {@link #processChatRequestOrToolResponse(Event, RunnerContext)}.
     *
     * @return the {@code chat_model_action} action
     * @throws Exception propagated from the {@code Action} or {@code JavaFunction} constructors
     */
    public static Action getChatModelAction() throws Exception {
        return new Action(
                "chat_model_action",
                new JavaFunction(
                        ChatModelAction.class,
                        "processChatRequestOrToolResponse",
                        new Class<?>[] {Event.class, RunnerContext.class}),
                List.of(ChatRequestEvent.class.getName(), ToolResponseEvent.class.getName()));
    }

    /**
     * Sends {@code messages} to the chat model resource named {@code model} and emits the follow-up
     * event.
     *
     * <p>If the model's reply contains tool calls, the conversation history plus the reply is
     * stored under {@code initialRequestId}, the tool request is recorded so its response can be
     * routed back here, and a {@link ToolRequestEvent} is sent. Otherwise any stored history for
     * {@code initialRequestId} is discarded and a {@link ChatResponseEvent} carrying the reply is
     * sent.
     *
     * <p>{@code messages} itself is never modified; the stored history is always a fresh copy, so
     * unmodifiable lists (e.g. from {@code List.of}) are accepted.
     *
     * @param initialRequestId id of the conversation; for requests routed through this action it is
     *     the id of the originating {@link ChatRequestEvent}
     * @param model name of the {@link ResourceType#CHAT_MODEL} resource to call
     * @param messages conversation to send to the model; not modified
     * @param ctx runner context used to resolve the model, access short-term memory and send events
     * @throws Exception propagated from resource lookup, the model call or short-term memory access
     */
    public static void chat(
            UUID initialRequestId, String model, List<ChatMessage> messages, RunnerContext ctx)
            throws Exception {
        BaseChatModelSetup chatModel =
                (BaseChatModelSetup) ctx.getResource(model, ResourceType.CHAT_MODEL);
        ChatMessage response = chatModel.chat(messages, Map.of());
        MemoryObject stm = ctx.getShortTermMemory();
        Map<UUID, List<ChatMessage>> toolCallContext = getMapOrEmpty(stm, TOOL_CALL_CONTEXT);

        if (response.getToolCalls().isEmpty()) {
            // Conversation finished: drop its stored history (if any) and answer the request.
            if (toolCallContext.containsKey(initialRequestId)) {
                toolCallContext.remove(initialRequestId);
                stm.set(TOOL_CALL_CONTEXT, toolCallContext);
            }
            ctx.sendEvent(new ChatResponseEvent(initialRequestId, response));
            return;
        }

        // The model asked for tools: persist history + reply. Copy instead of appending in place so
        // neither the caller's list nor a previously stored list is mutated.
        List<ChatMessage> history =
                new ArrayList<>(toolCallContext.getOrDefault(initialRequestId, messages));
        history.add(response);
        toolCallContext.put(initialRequestId, history);
        stm.set(TOOL_CALL_CONTEXT, toolCallContext);

        // Remember which conversation/model issued this tool request for when the response lands.
        ToolRequestEvent toolRequestEvent = new ToolRequestEvent(model, response.getToolCalls());
        Map<UUID, Map<String, Object>> toolRequestEventContext =
                getMapOrEmpty(stm, TOOL_REQUEST_EVENT_CONTEXT);
        toolRequestEventContext.put(
                toolRequestEvent.getId(),
                Map.of(INITIAL_REQUEST_ID, initialRequestId, MODEL, model));
        stm.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext);
        ctx.sendEvent(toolRequestEvent);
    }

    /**
     * Entry point of the action: starts a conversation for a {@link ChatRequestEvent} or continues
     * one with the results carried by a {@link ToolResponseEvent}.
     *
     * @param event the triggering event
     * @param ctx runner context
     * @throws IllegalArgumentException if {@code event} is neither a chat request nor a tool
     *     response
     * @throws IllegalStateException if a tool response cannot be matched to a pending request or
     *     conversation history in short-term memory
     * @throws Exception propagated from {@link #chat(UUID, String, List, RunnerContext)} or memory
     *     access
     */
    public static void processChatRequestOrToolResponse(Event event, RunnerContext ctx)
            throws Exception {
        if (event instanceof ChatRequestEvent) {
            ChatRequestEvent chatRequestEvent = (ChatRequestEvent) event;
            chat(
                    chatRequestEvent.getId(),
                    chatRequestEvent.getModel(),
                    chatRequestEvent.getMessages(),
                    ctx);
        } else if (event instanceof ToolResponseEvent) {
            handleToolResponse((ToolResponseEvent) event, ctx);
        } else {
            throw new IllegalArgumentException(
                    String.format("Unexpected type event %s", event));
        }
    }

    /**
     * Appends the tool results of {@code toolResponseEvent} to the conversation that issued the
     * matching tool request, then calls the model again with the extended history.
     */
    private static void handleToolResponse(ToolResponseEvent toolResponseEvent, RunnerContext ctx)
            throws Exception {
        MemoryObject stm = ctx.getShortTermMemory();
        UUID toolRequestId = toolResponseEvent.getRequestId();

        // Resolve and consume the bookkeeping recorded when the tool request was sent.
        Map<UUID, Map<String, Object>> toolRequestEventContext =
                getMapOrEmpty(stm, TOOL_REQUEST_EVENT_CONTEXT);
        Map<String, Object> requestContext = toolRequestEventContext.get(toolRequestId);
        if (requestContext == null) {
            throw new IllegalStateException(
                    String.format(
                            "No pending tool request %s found in short-term memory", toolRequestId));
        }
        UUID initialRequestId = (UUID) requestContext.get(INITIAL_REQUEST_ID);
        String model = (String) requestContext.get(MODEL);
        toolRequestEventContext.remove(toolRequestId);
        stm.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext);

        Map<UUID, List<ChatMessage>> toolCallContext = getMapOrEmpty(stm, TOOL_CALL_CONTEXT);
        List<ChatMessage> storedHistory = toolCallContext.get(initialRequestId);
        if (storedHistory == null) {
            throw new IllegalStateException(
                    String.format(
                            "No conversation history for request %s (tool request %s)",
                            initialRequestId, toolRequestId));
        }

        // Work on a copy so the stored list is replaced rather than mutated in place.
        List<ChatMessage> messages = new ArrayList<>(storedHistory);
        Map<String, ToolResponse> responses = toolResponseEvent.getResponses();
        Map<String, Boolean> success = toolResponseEvent.getSuccess();
        Map<?, ?> externalIds = toolResponseEvent.getExternalIds();
        for (Map.Entry<String, ToolResponse> entry : responses.entrySet()) {
            messages.add(toToolMessage(entry.getKey(), entry.getValue(), success, externalIds));
        }
        toolCallContext.put(initialRequestId, messages);
        stm.set(TOOL_CALL_CONTEXT, toolCallContext);

        chat(initialRequestId, model, messages, ctx);
    }

    /**
     * Converts one tool call's outcome into a {@link MessageRole#TOOL} message. The message content
     * is the tool's result when both the event-level success flag and the response report success,
     * otherwise the response's error; the tool call's external id, if present, is attached as the
     * {@code externalId} extra argument.
     */
    private static ChatMessage toToolMessage(
            String toolCallId,
            ToolResponse response,
            Map<String, Boolean> success,
            Map<?, ?> externalIds) {
        Map<String, Object> extraArgs = new HashMap<>();
        if (externalIds.containsKey(toolCallId)) {
            extraArgs.put("externalId", externalIds.get(toolCallId));
        }
        // A missing success flag counts as failure instead of throwing on null unboxing.
        boolean succeeded = Boolean.TRUE.equals(success.get(toolCallId)) && response.isSuccess();
        String content =
                succeeded
                        ? String.valueOf(response.getResult())
                        : String.valueOf(response.getError());
        return new ChatMessage(MessageRole.TOOL, content, extraArgs);
    }

    /**
     * Returns the map stored in short-term memory under {@code key}, or a new empty mutable map
     * when nothing is stored there.
     */
    // Unchecked: the keys passed here are private to this class, which only stores maps of the
    // requested shape under them.
    @SuppressWarnings("unchecked")
    private static <K, V> Map<K, V> getMapOrEmpty(MemoryObject stm, String key) throws Exception {
        return stm.isExist(key) ? (Map<K, V>) stm.get(key).getValue() : new HashMap<K, V>();
    }
}

