/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.model;

import com.google.common.math.Quantiles;
import java.time.Instant;
import java.util.DoubleSummaryStatistics;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.stream.DoubleStream;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.util.TokenBucket;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.engine.MLExecutable;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.profile.MLPredictRequestStats;

/**
 * Cache entry for a single model: its state, function name, predictor/executor handles,
 * rate limiters, guard, interface description, worker-node bookkeeping and rolling
 * duration samples used for inference statistics.
 *
 * <p>The two node sets and the two duration queues are backed by concurrent collections.
 * All other fields are plain (non-volatile) references that are read and written through
 * simple accessors; this class adds no synchronization of its own around them.
 */
public class MLModelCache {
    @Generated
    private static final Logger log = LogManager.getLogger(MLModelCache.class);
    private MLModelState modelState;
    private FunctionName functionName;
    private Predictable predictor;
    private MLExecutable executor;
    private TokenBucket rateLimiter;
    private Map<String, TokenBucket> userRateLimiterMap;
    private Boolean isModelEnabled;
    /** Node ids the model is planned to run on. */
    private final Set<String> targetWorkerNodes = ConcurrentHashMap.newKeySet();
    /** Node ids currently registered as workers for the model. */
    private final Set<String> workerNodes = ConcurrentHashMap.newKeySet();
    private MLModel modelInfo;
    /** Duration samples recorded via {@link #addModelInferenceDuration(double, long)}. */
    private final Queue<Double> modelInferenceDurationQueue = new ConcurrentLinkedQueue<>();
    /** Duration samples recorded via {@link #addPredictRequestDuration(double, long)}. */
    private final Queue<Double> predictRequestDurationQueue = new ConcurrentLinkedQueue<>();
    private Long memSizeEstimationCPU;
    private Long memSizeEstimationGPU;
    private MLGuard mlGuard;
    private Map<String, String> modelInterface;
    private Boolean deployToAllNodes;
    private Instant lastAccessTime;
    private Boolean isAutoDeploying;

    /**
     * Replaces the target worker nodes with the given node ids.
     *
     * @param targetWorkerNodes the new target node ids; must be non-null and non-empty
     * @throws IllegalArgumentException if {@code targetWorkerNodes} is null or empty
     */
    public void setTargetWorkerNodes(List<String> targetWorkerNodes) {
        if (targetWorkerNodes == null || targetWorkerNodes.isEmpty()) {
            throw new IllegalArgumentException("Null or empty target worker nodes");
        }
        // retainAll + addAll reaches the same end state as clear + addAll, but ids present
        // both before and after stay visible to concurrent readers the whole time.
        this.targetWorkerNodes.retainAll(targetWorkerNodes);
        this.targetWorkerNodes.addAll(targetWorkerNodes);
    }

    /**
     * @return a snapshot array of the target worker node ids
     */
    public String[] getTargetWorkerNodes() {
        return this.targetWorkerNodes.toArray(new String[0]);
    }

    /**
     * Removes one node from the worker set and, when the model is deployed to all nodes or
     * the removal comes from an undeploy, from the target set as well. An undeploy also
     * switches off the deploy-to-all-nodes flag. The cached model info is dropped once either
     * node set is empty.
     *
     * @param nodeId         id of the node to remove
     * @param isFromUndeploy whether the removal is caused by an undeploy
     */
    public void removeWorkerNode(String nodeId, boolean isFromUndeploy) {
        if (this.isDeployToAllNodes() || isFromUndeploy) {
            this.targetWorkerNodes.remove(nodeId);
        }
        if (isFromUndeploy) {
            this.deployToAllNodes = false;
        }
        this.workerNodes.remove(nodeId);
        this.dropModelInfoIfNoNodes();
    }

    /**
     * Bulk variant of {@link #removeWorkerNode(String, boolean)}.
     *
     * @param removedNodes   ids of the nodes to remove
     * @param isFromUndeploy whether the removal is caused by an undeploy
     */
    public void removeWorkerNodes(Set<String> removedNodes, boolean isFromUndeploy) {
        if (this.isDeployToAllNodes() || isFromUndeploy) {
            this.targetWorkerNodes.removeAll(removedNodes);
        }
        if (isFromUndeploy) {
            this.deployToAllNodes = false;
        }
        this.workerNodes.removeAll(removedNodes);
        this.dropModelInfoIfNoNodes();
    }

    /** Drops the cached model info when either the target or the worker node set is empty. */
    private void dropModelInfoIfNoNodes() {
        if (this.targetWorkerNodes.isEmpty() || this.workerNodes.isEmpty()) {
            this.modelInfo = null;
        }
    }

    /**
     * Registers a worker node. When the model is deployed to all nodes the node is also
     * added to the target set.
     *
     * @param nodeId id of the node to add
     */
    public void addWorkerNode(String nodeId) {
        if (this.isDeployToAllNodes()) {
            this.targetWorkerNodes.add(nodeId);
        }
        this.workerNodes.add(nodeId);
    }

    /**
     * @return a snapshot array of the worker node ids
     */
    public String[] getWorkerNodes() {
        return this.workerNodes.toArray(new String[0]);
    }

    /**
     * @param modelInfo model info to cache; may be null
     */
    public void setModelInfo(MLModel modelInfo) {
        this.modelInfo = modelInfo;
    }

    /**
     * @return the cached model info, or null if none is cached
     */
    public MLModel getCachedModelInfo() {
        return this.modelInfo;
    }

    /**
     * Replaces the worker node set with the given node ids.
     *
     * @param workerNodes the node ids the worker set should contain afterwards
     * @throws NullPointerException if {@code workerNodes} is null (the set is left untouched)
     */
    public void syncWorkerNode(Set<String> workerNodes) {
        // Same end state as clear + addAll, without a window in which the set looks empty.
        this.workerNodes.retainAll(workerNodes);
        this.workerNodes.addAll(workerNodes);
    }

    /**
     * Replaces the target worker node set with the given node ids.
     *
     * @param planningWorkerNodes the node ids the target set should contain afterwards
     * @throws NullPointerException if {@code planningWorkerNodes} is null (the set is left untouched)
     */
    public void syncPlanningWorkerNodes(Set<String> planningWorkerNodes) {
        // Same end state as clear + addAll, without a window in which the set looks empty.
        this.targetWorkerNodes.retainAll(planningWorkerNodes);
        this.targetWorkerNodes.addAll(planningWorkerNodes);
    }

    /**
     * @return true only if the deploy-to-all-nodes flag is set and true; an unset flag counts as false
     */
    public boolean isDeployToAllNodes() {
        return Boolean.TRUE.equals(this.deployToAllNodes);
    }

    /** Removes every worker node; the target set is left unchanged. */
    public void clearWorkerNodes() {
        this.workerNodes.clear();
    }

    /**
     * Resets this entry: clears state, function name, worker nodes, cached model info and
     * duration samples, closes the predictor and executor if present, zeroes the memory
     * estimations and drops the enabled flag, rate limiters, guard and model interface.
     *
     * <p>The target worker nodes, the deploy-to-all-nodes flag, the last access time and the
     * auto-deploying flag are not touched. The predictor and executor references are closed
     * but not set to null.
     *
     * <p>If closing the predictor throws, the executor is still closed and the remaining
     * fields are still reset before the exception propagates.
     */
    public void clear() {
        this.modelState = null;
        this.functionName = null;
        this.workerNodes.clear();
        this.modelInfo = null;
        this.modelInferenceDurationQueue.clear();
        this.predictRequestDurationQueue.clear();
        try {
            if (this.predictor != null) {
                this.predictor.close();
            }
        } finally {
            try {
                if (this.executor != null) {
                    this.executor.close();
                }
            } finally {
                this.memSizeEstimationCPU = 0L;
                this.memSizeEstimationGPU = 0L;
                this.isModelEnabled = null;
                this.rateLimiter = null;
                this.userRateLimiterMap = null;
                this.mlGuard = null;
                this.modelInterface = null;
            }
        }
    }

    /**
     * Records one model inference duration sample.
     *
     * @param duration        the duration sample
     * @param maxRequestCount maximum number of samples to keep; a value of zero or less
     *                        clears the samples and records nothing
     */
    public void addModelInferenceDuration(double duration, long maxRequestCount) {
        this.addInferenceDuration(duration, maxRequestCount, this.modelInferenceDurationQueue);
    }

    /**
     * Records one predict request duration sample.
     *
     * @param duration        the duration sample
     * @param maxRequestCount maximum number of samples to keep; a value of zero or less
     *                        clears the samples and records nothing
     */
    public void addPredictRequestDuration(double duration, long maxRequestCount) {
        this.addInferenceDuration(duration, maxRequestCount, this.predictRequestDurationQueue);
    }

    /** Makes room in {@code queue} for one more sample, then appends it if sampling is enabled. */
    private void addInferenceDuration(double duration, long maxRequestCount, Queue<Double> queue) {
        this.resizeInferenceQueue(maxRequestCount, queue);
        if (maxRequestCount > 0L) {
            queue.add(duration);
        }
    }

    /**
     * Trims both duration queues for a new maximum sample count.
     *
     * @param maxRequestCount the new maximum; zero or less clears both queues
     */
    public void resizeMonitoringQueue(long maxRequestCount) {
        log.debug("resize inference duration monitoring queue with size {}", maxRequestCount);
        this.resizeInferenceQueue(maxRequestCount, this.predictRequestDurationQueue);
        this.resizeInferenceQueue(maxRequestCount, this.modelInferenceDurationQueue);
    }

    /**
     * Clears {@code queue} when {@code maxRequestCount} is zero or less; otherwise discards the
     * oldest samples until the queue holds strictly fewer than {@code maxRequestCount} entries,
     * which leaves room for the next sample to be appended.
     */
    private void resizeInferenceQueue(long maxRequestCount, Queue<Double> queue) {
        if (maxRequestCount <= 0L) {
            queue.clear();
        } else {
            while ((long)queue.size() >= maxRequestCount) {
                queue.poll();
            }
        }
    }

    /**
     * Computes count, max, min, average and the 50th/90th/99th percentiles over the recorded
     * duration samples.
     *
     * @param modelInference true to use the model inference samples, false to use the
     *                       predict request samples
     * @return the statistics, or null if no samples are recorded
     */
    public MLPredictRequestStats getInferenceStats(boolean modelInference) {
        Queue<Double> queue = modelInference ? this.modelInferenceDurationQueue : this.predictRequestDurationQueue;
        // Take one snapshot so every statistic is computed over the same samples, and so a
        // concurrent clear/resize cannot empty the data between the emptiness check and the
        // percentile computation (which rejects empty datasets).
        double[] durations = queue.stream().mapToDouble(Double::doubleValue).toArray();
        if (durations.length == 0) {
            return null;
        }
        DoubleSummaryStatistics summary = DoubleStream.of(durations).summaryStatistics();
        Quantiles.Scale percentiles = Quantiles.percentiles();
        MLPredictRequestStats.MLPredictRequestStatsBuilder statsBuilder = MLPredictRequestStats.builder();
        statsBuilder.count(summary.getCount());
        statsBuilder.max(summary.getMax());
        statsBuilder.min(summary.getMin());
        statsBuilder.average(summary.getAverage());
        // compute(double...) works on a copy, so the snapshot can be reused for each index.
        statsBuilder.p50(percentiles.index(50).compute(durations));
        statsBuilder.p90(percentiles.index(90).compute(durations));
        statsBuilder.p99(percentiles.index(99).compute(durations));
        return statsBuilder.build();
    }

    /**
     * @return true if a model state is set or at least one worker node is registered
     */
    public boolean isValidCache() {
        return this.modelState != null || !this.workerNodes.isEmpty();
    }

    @Generated
    protected void setModelState(MLModelState modelState) {
        this.modelState = modelState;
    }

    @Generated
    protected MLModelState getModelState() {
        return this.modelState;
    }

    @Generated
    protected void setFunctionName(FunctionName functionName) {
        this.functionName = functionName;
    }

    @Generated
    protected FunctionName getFunctionName() {
        return this.functionName;
    }

    @Generated
    protected void setPredictor(Predictable predictor) {
        this.predictor = predictor;
    }

    @Generated
    protected Predictable getPredictor() {
        return this.predictor;
    }

    @Generated
    protected void setExecutor(MLExecutable executor) {
        this.executor = executor;
    }

    @Generated
    protected MLExecutable getExecutor() {
        return this.executor;
    }

    @Generated
    protected void setRateLimiter(TokenBucket rateLimiter) {
        this.rateLimiter = rateLimiter;
    }

    @Generated
    protected TokenBucket getRateLimiter() {
        return this.rateLimiter;
    }

    @Generated
    protected void setUserRateLimiterMap(Map<String, TokenBucket> userRateLimiterMap) {
        this.userRateLimiterMap = userRateLimiterMap;
    }

    @Generated
    protected Map<String, TokenBucket> getUserRateLimiterMap() {
        return this.userRateLimiterMap;
    }

    @Generated
    protected void setIsModelEnabled(Boolean isModelEnabled) {
        this.isModelEnabled = isModelEnabled;
    }

    @Generated
    protected Boolean getIsModelEnabled() {
        return this.isModelEnabled;
    }

    @Generated
    protected void setMemSizeEstimationCPU(Long memSizeEstimationCPU) {
        this.memSizeEstimationCPU = memSizeEstimationCPU;
    }

    @Generated
    protected Long getMemSizeEstimationCPU() {
        return this.memSizeEstimationCPU;
    }

    @Generated
    protected void setMemSizeEstimationGPU(Long memSizeEstimationGPU) {
        this.memSizeEstimationGPU = memSizeEstimationGPU;
    }

    @Generated
    protected Long getMemSizeEstimationGPU() {
        return this.memSizeEstimationGPU;
    }

    @Generated
    protected void setMlGuard(MLGuard mlGuard) {
        this.mlGuard = mlGuard;
    }

    @Generated
    protected MLGuard getMlGuard() {
        return this.mlGuard;
    }

    @Generated
    protected void setModelInterface(Map<String, String> modelInterface) {
        this.modelInterface = modelInterface;
    }

    @Generated
    protected Map<String, String> getModelInterface() {
        return this.modelInterface;
    }

    @Generated
    public void setDeployToAllNodes(Boolean deployToAllNodes) {
        this.deployToAllNodes = deployToAllNodes;
    }

    @Generated
    protected void setLastAccessTime(Instant lastAccessTime) {
        this.lastAccessTime = lastAccessTime;
    }

    @Generated
    protected Instant getLastAccessTime() {
        return this.lastAccessTime;
    }

    @Generated
    protected void setIsAutoDeploying(Boolean isAutoDeploying) {
        this.isAutoDeploying = isAutoDeploying;
    }

    @Generated
    protected Boolean getIsAutoDeploying() {
        return this.isAutoDeploying;
    }
}

