/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.Generated;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.collect.Tuple;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.rescore.QueryRescorerBuilder;
import org.opensearch.search.rescore.RescorerBuilder;

/**
 * Search request processor (type {@value #TYPE}) that adds a second-phase rescore for
 * {@code neural_sparse} queries.
 *
 * <p>For every {@link NeuralSparseQueryBuilder} found in the request's top-level query
 * (directly, or nested through {@link BoolQueryBuilder} {@code should} clauses), a two-phase
 * copy is obtained via {@code getCopyNeuralSparseQueryBuilderForTwoPhase(prune_ratio)}.
 * The copies are combined into a {@code bool/should} query that is appended to the request
 * as a {@link QueryRescorerBuilder}.
 *
 * <p>Configuration:
 * <ul>
 *   <li>{@code enabled} (default {@code true}).</li>
 *   <li>{@code two_phase_parameter.prune_ratio}: default {@code 0.4}, must be within [0, 1].</li>
 *   <li>{@code two_phase_parameter.expansion_rate}: default {@code 5.0}, must be &gt;= 1.0.</li>
 *   <li>{@code two_phase_parameter.max_window_size}: default {@code 10000}, must be &gt;= 50.</li>
 * </ul>
 */
public class NeuralSparseTwoPhaseProcessor extends AbstractProcessor implements SearchRequestProcessor {
    public static final String TYPE = "neural_sparse_two_phase_processor";
    private boolean enabled;
    private float ratio;
    private float windowExpansion;
    private int maxWindowSize;
    private static final String PARAMETER_KEY = "two_phase_parameter";
    private static final String RATIO_KEY = "prune_ratio";
    private static final String ENABLE_KEY = "enabled";
    private static final String EXPANSION_KEY = "expansion_rate";
    private static final String MAX_WINDOW_SIZE_KEY = "max_window_size";
    private static final boolean DEFAULT_ENABLED = true;
    private static final float DEFAULT_RATIO = 0.4f;
    private static final float DEFAULT_WINDOW_EXPANSION = 5.0f;
    private static final int DEFAULT_MAX_WINDOW_SIZE = 10000;
    // Used as the request size when the search source leaves size unset (size() == -1).
    private static final int DEFAULT_BASE_QUERY_SIZE = 10;
    private static final int MAX_WINDOWS_SIZE_LOWER_BOUND = 50;
    private static final float WINDOW_EXPANSION_LOWER_BOUND = 1.0f;
    private static final float RATIO_LOWER_BOUND = 0.0f;
    private static final float RATIO_UPPER_BOUND = 1.0f;

    /**
     * Creates the processor after validating its numeric parameters.
     *
     * @param tag             processor tag
     * @param description     processor description
     * @param ignoreFailure   whether pipeline failures of this processor are ignored
     * @param enabled         when {@code false}, {@link #processRequest} is a no-op
     * @param ratio           prune ratio; must be within [0, 1]
     * @param windowExpansion multiplier applied to the request size to get the rescore window; must be &gt;= 1.0
     * @param maxWindowSize   upper bound for the computed rescore window; must be &gt;= 50
     * @throws IllegalArgumentException if any parameter is out of range (NaN included)
     */
    protected NeuralSparseTwoPhaseProcessor(String tag, String description, boolean ignoreFailure, boolean enabled, float ratio, float windowExpansion, int maxWindowSize) {
        super(tag, description, ignoreFailure);
        this.enabled = enabled;
        // Negated inclusive range checks: NaN fails every comparison, so it is rejected here
        // instead of slipping through a "x < lo || x > hi" test.
        if (!(ratio >= RATIO_LOWER_BOUND && ratio <= RATIO_UPPER_BOUND)) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "The two_phase_parameter.prune_ratio must be within [0, 1]. Received: %f", ratio));
        }
        this.ratio = ratio;
        if (!(windowExpansion >= WINDOW_EXPANSION_LOWER_BOUND)) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "The two_phase_parameter.expansion_rate must >= 1.0. Received: %f", windowExpansion));
        }
        this.windowExpansion = windowExpansion;
        if (maxWindowSize < MAX_WINDOWS_SIZE_LOWER_BOUND) {
            // %d (not %n, which is a line separator) so the received value is actually formatted.
            throw new IllegalArgumentException(String.format(Locale.ROOT, "The two_phase_parameter.max_window_size must >= 50. Received: %d", maxWindowSize));
        }
        this.maxWindowSize = maxWindowSize;
    }

    /**
     * Appends a two-phase rescorer to {@code request} when its query contains neural sparse queries.
     *
     * <p>The request is returned unchanged when the processor is disabled, the prune ratio is 0,
     * the request has no search source, or no {@link NeuralSparseQueryBuilder} is reachable
     * through {@code bool/should} clauses. Otherwise the request's source is mutated in place
     * (a rescorer is added) and the same request object is returned.
     *
     * @param request the incoming search request
     * @return {@code request}, possibly with an additional rescorer
     * @throws IllegalArgumentException if the computed rescore window exceeds {@code max_window_size}
     */
    @Override
    public SearchRequest processRequest(SearchRequest request) {
        if (!this.enabled || this.ratio == RATIO_LOWER_BOUND) {
            return request;
        }
        SearchSourceBuilder source = request.source();
        if (Objects.isNull(source)) {
            // A request without a body has no query to split; nothing to do.
            return request;
        }
        Multimap<NeuralSparseQueryBuilder, Float> queryBuilderMap = this.collectNeuralSparseQueryBuilder(source.query(), 1.0f);
        if (queryBuilderMap.isEmpty()) {
            return request;
        }
        QueryBuilder nestedTwoPhaseQueryBuilder = this.getNestedQueryBuilderFromNeuralSparseQueryBuilderMap(queryBuilderMap);
        // Scale the second-phase query by the product of the query weights of rescorers already
        // on the request, so it is weighted like the original query after those rescorers.
        nestedTwoPhaseQueryBuilder.boost(this.getOriginQueryWeightAfterRescore(source));
        RescorerBuilder<QueryRescorerBuilder> twoPhaseRescorer = this.buildRescoreQueryBuilderForTwoPhase(nestedTwoPhaseQueryBuilder, request);
        source.addRescorer(twoPhaseRescorer);
        return request;
    }

    @Override
    public String getType() {
        return TYPE;
    }

    /**
     * Splits query tokens into those scoring at or above {@code max * thresholdRatio} and the rest.
     *
     * <p>{@code max} starts at 0, so when every token weight is negative the threshold is 0 and
     * all tokens land in the low-score map.
     *
     * @param queryTokens    token to weight map; must not be {@code null}
     * @param thresholdRatio fraction of the maximum weight used as the threshold
     * @return tuple of (high-score tokens, low-score tokens); neither element is {@code null}
     * @throws IllegalArgumentException if {@code queryTokens} is {@code null}
     */
    public static Tuple<Map<String, Float>, Map<String, Float>> splitQueryTokensByRatioedMaxScoreAsThreshold(Map<String, Float> queryTokens, float thresholdRatio) {
        if (Objects.isNull(queryTokens)) {
            throw new IllegalArgumentException("Query tokens cannot be null or empty.");
        }
        float max = 0.0f;
        for (Float value : queryTokens.values()) {
            max = Math.max(value, max);
        }
        final float threshold = max * thresholdRatio;
        Map<Boolean, Map<String, Float>> queryTokensByScore = queryTokens.entrySet()
            .stream()
            .collect(Collectors.partitioningBy(entry -> entry.getValue() >= threshold, Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
        Map<String, Float> highScoreTokens = queryTokensByScore.get(Boolean.TRUE);
        Map<String, Float> lowScoreTokens = queryTokensByScore.get(Boolean.FALSE);
        // Defensive: partitioningBy normally yields both keys, but never hand back null maps.
        if (Objects.isNull(highScoreTokens)) {
            highScoreTokens = Collections.emptyMap();
        }
        if (Objects.isNull(lowScoreTokens)) {
            lowScoreTokens = Collections.emptyMap();
        }
        return Tuple.tuple(highScoreTokens, lowScoreTokens);
    }

    /**
     * Builds a {@code bool} query with one {@code should} clause per distinct two-phase query.
     * Each clause is boosted by the sum of all boosts collected for that query.
     */
    private QueryBuilder getNestedQueryBuilderFromNeuralSparseQueryBuilderMap(Multimap<NeuralSparseQueryBuilder, Float> queryBuilderFloatMap) {
        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        queryBuilderFloatMap.asMap().forEach((neuralSparseQueryBuilder, boosts) -> {
            float reduceBoost = boosts.stream().reduce(0.0f, Float::sum);
            boolQueryBuilder.should(neuralSparseQueryBuilder.boost(reduceBoost));
        });
        return boolQueryBuilder;
    }

    /**
     * Returns the product of {@code query_weight} over the query rescorers already on the source.
     * Returns 1.0 when there are none.
     *
     * <p>Rescorers that are not {@link QueryRescorerBuilder}s carry no query weight, so they are
     * skipped. Casting them blindly would throw {@link ClassCastException}.
     */
    private float getOriginQueryWeightAfterRescore(SearchSourceBuilder searchSourceBuilder) {
        if (Objects.isNull(searchSourceBuilder.rescores())) {
            return 1.0f;
        }
        float weight = 1.0f;
        for (Object rescorer : searchSourceBuilder.rescores()) {
            if (rescorer instanceof QueryRescorerBuilder) {
                weight *= ((QueryRescorerBuilder) rescorer).getQueryWeight();
            }
        }
        return weight;
    }

    /**
     * Recursively collects two-phase copies of neural sparse queries together with their effective boosts.
     *
     * <p>Only {@code should} clauses of {@link BoolQueryBuilder}s are traversed. Each bool query
     * multiplies its own boost into the boost passed down. Any other query type contributes nothing.
     * A {@code null} query yields an empty result.
     *
     * @param queryBuilder query to inspect (may be {@code null})
     * @param baseBoost    product of the boosts of the enclosing bool queries
     * @return multimap from two-phase copy to every effective boost it was reached with
     */
    private Multimap<NeuralSparseQueryBuilder, Float> collectNeuralSparseQueryBuilder(QueryBuilder queryBuilder, float baseBoost) {
        Multimap<NeuralSparseQueryBuilder, Float> result = ArrayListMultimap.create();
        if (queryBuilder instanceof BoolQueryBuilder) {
            BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) queryBuilder;
            float updatedBoost = baseBoost * boolQueryBuilder.boost();
            for (QueryBuilder subQuery : boolQueryBuilder.should()) {
                result.putAll(this.collectNeuralSparseQueryBuilder(subQuery, updatedBoost));
            }
        } else if (queryBuilder instanceof NeuralSparseQueryBuilder) {
            NeuralSparseQueryBuilder neuralSparseQueryBuilder = (NeuralSparseQueryBuilder) queryBuilder;
            float updatedBoost = baseBoost * neuralSparseQueryBuilder.boost();
            // NOTE(review): presumably the copy scores the tokens pruned from the first phase
            // at this ratio; see NeuralSparseQueryBuilder#getCopyNeuralSparseQueryBuilderForTwoPhase.
            NeuralSparseQueryBuilder modifiedQueryBuilder = neuralSparseQueryBuilder.getCopyNeuralSparseQueryBuilderForTwoPhase(this.ratio);
            result.put(modifiedQueryBuilder, updatedBoost);
        }
        return result;
    }

    /**
     * Wraps the two-phase query in a {@link QueryRescorerBuilder}.
     *
     * <p>The window size is {@code (request size, or 10 when unset) * expansion_rate}, truncated to an int.
     *
     * @throws IllegalArgumentException if the window is negative or larger than {@code max_window_size}
     */
    private RescorerBuilder<QueryRescorerBuilder> buildRescoreQueryBuilderForTwoPhase(QueryBuilder nestedTwoPhaseQueryBuilder, SearchRequest searchRequest) {
        QueryRescorerBuilder twoPhaseRescorer = new QueryRescorerBuilder(nestedTwoPhaseQueryBuilder);
        int requestSize = searchRequest.source().size();
        int baseSize = requestSize == -1 ? DEFAULT_BASE_QUERY_SIZE : requestSize;
        int windowSize = (int) ((float) baseSize * this.windowExpansion);
        if (windowSize > this.maxWindowSize || windowSize < 0) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "The two-phase window size of neural_sparse_two_phase_processor should be [0,%d], but get the value of %d", this.maxWindowSize, windowSize));
        }
        twoPhaseRescorer.windowSize(windowSize);
        return twoPhaseRescorer;
    }

    @Generated
    public void setEnabled(boolean enabled) {
        this.enabled = enabled;
    }

    @Generated
    public void setRatio(float ratio) {
        this.ratio = ratio;
    }

    @Generated
    public void setWindowExpansion(float windowExpansion) {
        this.windowExpansion = windowExpansion;
    }

    @Generated
    public void setMaxWindowSize(int maxWindowSize) {
        this.maxWindowSize = maxWindowSize;
    }

    @Generated
    public boolean isEnabled() {
        return this.enabled;
    }

    @Generated
    public float getRatio() {
        return this.ratio;
    }

    @Generated
    public float getWindowExpansion() {
        return this.windowExpansion;
    }

    @Generated
    public int getMaxWindowSize() {
        return this.maxWindowSize;
    }

    /**
     * Factory that reads {@code enabled} and the optional {@code two_phase_parameter} map from the
     * pipeline config. Missing values fall back to the defaults declared on the processor.
     */
    public static class Factory implements Processor.Factory<SearchRequestProcessor> {
        @Override
        public NeuralSparseTwoPhaseProcessor create(Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories, String tag, String description, boolean ignoreFailure, Map<String, Object> config, Processor.PipelineContext pipelineContext) throws IllegalArgumentException {
            boolean enabled = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, ENABLE_KEY, DEFAULT_ENABLED);
            Map<String, Object> twoPhaseConfigMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, PARAMETER_KEY);
            float ratio = DEFAULT_RATIO;
            float windowExpansion = DEFAULT_WINDOW_EXPANSION;
            int maxWindowSize = DEFAULT_MAX_WINDOW_SIZE;
            if (Objects.nonNull(twoPhaseConfigMap)) {
                ratio = readNumber(twoPhaseConfigMap, RATIO_KEY, ratio).floatValue();
                windowExpansion = readNumber(twoPhaseConfigMap, EXPANSION_KEY, windowExpansion).floatValue();
                maxWindowSize = readNumber(twoPhaseConfigMap, MAX_WINDOW_SIZE_KEY, maxWindowSize).intValue();
            }
            return new NeuralSparseTwoPhaseProcessor(tag, description, ignoreFailure, enabled, ratio, windowExpansion, maxWindowSize);
        }

        /**
         * Reads {@code key} from the two-phase config as a {@link Number}.
         * Returns {@code defaultValue} when the key is absent.
         *
         * @throws IllegalArgumentException if the value is present but not a number (including explicit null)
         */
        private static Number readNumber(Map<String, Object> configMap, String key, Number defaultValue) {
            Object value = configMap.getOrDefault(key, defaultValue);
            if (!(value instanceof Number)) {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "The %s.%s must be a number. Received: %s", PARAMETER_KEY, key, value));
            }
            return (Number) value;
        }
    }
}

