/*
 * Decompiled with CFR 0.152.
 */
package org.carrot2.text.vsm;

import org.carrot2.core.attribute.Processing;
import org.carrot2.mahout.math.matrix.DoubleMatrix2D;
import org.carrot2.mahout.math.matrix.impl.DenseDoubleMatrix2D;
import org.carrot2.matrix.MatrixUtils;
import org.carrot2.matrix.factorization.IMatrixFactorization;
import org.carrot2.matrix.factorization.IMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.IterationNumberGuesser;
import org.carrot2.matrix.factorization.IterativeMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.KMeansMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.LocalNonnegativeMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.NonnegativeMatrixFactorizationEDFactory;
import org.carrot2.matrix.factorization.NonnegativeMatrixFactorizationKLFactory;
import org.carrot2.matrix.factorization.PartialSingularValueDecompositionFactory;
import org.carrot2.text.vsm.ReducedVectorSpaceModelContext;
import org.carrot2.text.vsm.VectorSpaceModelContext;
import org.carrot2.util.attribute.Attribute;
import org.carrot2.util.attribute.AttributeLevel;
import org.carrot2.util.attribute.Bindable;
import org.carrot2.util.attribute.Group;
import org.carrot2.util.attribute.Input;
import org.carrot2.util.attribute.Label;
import org.carrot2.util.attribute.Level;
import org.carrot2.util.attribute.Required;
import org.carrot2.util.attribute.constraint.ImplementingClasses;

/**
 * Reduces the term-document matrix of a vector space model by factorizing it with a
 * configurable matrix factorization.
 *
 * <p>The factorization's U matrix is stored as {@code baseMatrix} and its V matrix as
 * {@code coefficientMatrix} of a {@link ReducedVectorSpaceModelContext}. Both are limited to the
 * requested number of dimensions.
 */
@Bindable(prefix="TermDocumentMatrixReducer")
public class TermDocumentMatrixReducer {
    /**
     * Factory of the factorization applied to the term-document matrix. Defaults to
     * {@link NonnegativeMatrixFactorizationEDFactory}.
     *
     * <p>{@link IterativeMatrixFactorizationFactory} instances are configured by
     * {@link #reduce} (rank and estimated iteration count). The output of any other factory is
     * trimmed to the requested number of columns instead.
     */
    @Input
    @Processing
    @Attribute
    @Required
    @ImplementingClasses(classes={PartialSingularValueDecompositionFactory.class, NonnegativeMatrixFactorizationEDFactory.class, NonnegativeMatrixFactorizationKLFactory.class, LocalNonnegativeMatrixFactorizationFactory.class, KMeansMatrixFactorizationFactory.class}, strict=false)
    @Label(value="Factorization method")
    @Level(value=AttributeLevel.ADVANCED)
    @Group(value="Matrix model")
    public IMatrixFactorizationFactory factorizationFactory = new NonnegativeMatrixFactorizationEDFactory();

    /**
     * Quality level passed to {@link IterationNumberGuesser} when estimating the iteration count
     * of an iterative factorization. Defaults to {@code HIGH}. It has no effect for non-iterative
     * factories.
     */
    @Input
    @Processing
    @Required
    @Attribute
    @Label(value="Factorization quality")
    @Level(value=AttributeLevel.ADVANCED)
    @Group(value="Matrix model")
    public IterationNumberGuesser.FactorizationQuality factorizationQuality = IterationNumberGuesser.FactorizationQuality.HIGH;

    /**
     * Factorizes {@code context.vsmContext.termDocumentMatrix} and stores the results in
     * {@code context}.
     *
     * <p>If the term-document matrix has no rows or no columns, no factorization is performed.
     * In that case {@code context.baseMatrix} is set to a new {@link DenseDoubleMatrix2D} with the
     * same (empty) shape. NOTE(review): {@code context.coefficientMatrix} is left untouched in
     * that case — confirm callers never read it for empty input.
     *
     * <p>Otherwise the term-document matrix is passed to {@link MatrixUtils#normalizeColumnL2}
     * before factorization. The same matrix object is then factorized, so the normalization
     * presumably modifies {@code vsmContext.termDocumentMatrix} in place — confirm in MatrixUtils.
     *
     * @param context    reduction context. {@code context.vsmContext} is read;
     *                   {@code baseMatrix} and, for non-empty input, {@code coefficientMatrix}
     *                   are written.
     * @param dimensions requested number of dimensions: the rank given to iterative
     *                   factorizations, or the column limit applied to other factorizations'
     *                   results
     */
    public void reduce(ReducedVectorSpaceModelContext context, int dimensions) {
        VectorSpaceModelContext vsmContext = context.vsmContext;
        DoubleMatrix2D termDocumentMatrix = vsmContext.termDocumentMatrix;
        if (termDocumentMatrix.columns() == 0 || termDocumentMatrix.rows() == 0) {
            context.baseMatrix = new DenseDoubleMatrix2D(termDocumentMatrix.rows(), termDocumentMatrix.columns());
            return;
        }
        if (this.factorizationFactory instanceof IterativeMatrixFactorizationFactory) {
            // setK presumably fixes the factorization rank up front, which is why trim() leaves
            // the results of iterative factorizations alone.
            IterativeMatrixFactorizationFactory iterativeFactory = (IterativeMatrixFactorizationFactory) this.factorizationFactory;
            iterativeFactory.setK(dimensions);
            IterationNumberGuesser.setEstimatedIterationsNumber(iterativeFactory, termDocumentMatrix, this.factorizationQuality);
        }
        MatrixUtils.normalizeColumnL2(termDocumentMatrix, null);
        IMatrixFactorization factorization = this.factorizationFactory.factorize(termDocumentMatrix);
        // Assign each result exactly once. The previous code stored getU()/getV() and then
        // immediately overwrote them with the trimmed versions.
        context.baseMatrix = this.trim(factorization.getU(), dimensions);
        context.coefficientMatrix = this.trim(factorization.getV(), dimensions);
    }

    /**
     * Limits a factorization result to its first {@code dimensions} columns.
     *
     * <p>Trimming applies only when the configured factory is not an
     * {@link IterativeMatrixFactorizationFactory} and {@code matrix} has more columns than
     * requested. Otherwise {@code matrix} is returned unchanged. The trimmed result comes from
     * {@code viewPart}, presumably a view sharing storage with {@code matrix} rather than a copy.
     *
     * @param matrix     factorization output (U or V)
     * @param dimensions maximum number of columns to keep
     * @return {@code matrix} itself, or its leading {@code dimensions} columns
     */
    private DoubleMatrix2D trim(DoubleMatrix2D matrix, int dimensions) {
        if (!(this.factorizationFactory instanceof IterativeMatrixFactorizationFactory) && matrix.columns() > dimensions) {
            return matrix.viewPart(0, 0, matrix.rows(), dimensions);
        }
        return matrix;
    }
}

