/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.codecs.lucene99;

import java.io.IOException;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;

/**
 * A {@link FlatVectorsScorer} that scores scalar-quantized vectors directly in their quantized
 * byte form.
 *
 * <p>When the supplied {@link KnnVectorValues} are {@link QuantizedByteVectorValues}, float
 * queries are quantized with the values' own {@link ScalarQuantizer} and compared against the
 * stored bytes. Everything else (non-quantized vector values, and every {@code byte[]} query) is
 * forwarded to the non-quantized delegate given at construction.
 */
public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
    /** Scorer used for any input this class does not handle in quantized form. */
    private final FlatVectorsScorer nonQuantizedDelegate;

    /**
     * Creates a scorer that falls back to {@code flatVectorsScorer} for non-quantized inputs.
     *
     * @param flatVectorsScorer scorer used for vector values that are not {@link
     *     QuantizedByteVectorValues} and for all {@code byte[]} queries
     */
    public Lucene99ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) {
        this.nonQuantizedDelegate = flatVectorsScorer;
    }

    /**
     * Returns a supplier of scorers that compare stored vectors with each other.
     *
     * <p>Quantized values are scored in quantized form; any other {@link KnnVectorValues} are
     * handled by the non-quantized delegate.
     */
    @Override
    public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
            VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
            throws IOException {
        if (vectorValues instanceof QuantizedByteVectorValues quantizedValues) {
            return new ScalarQuantizedRandomVectorScorerSupplier(quantizedValues, similarityFunction);
        }
        return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
    }

    /**
     * Returns a scorer for a float query.
     *
     * <p>For quantized values, the query is first quantized into a fresh byte array using the
     * values' {@link ScalarQuantizer}. The correction returned by quantization is used by the
     * cosine, dot-product and maximum-inner-product scorers; the Euclidean scorer ignores it.
     * Other values are handled by the non-quantized delegate.
     */
    @Override
    public RandomVectorScorer getRandomVectorScorer(
            VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
            throws IOException {
        if (vectorValues instanceof QuantizedByteVectorValues quantizedValues) {
            ScalarQuantizer scalarQuantizer = quantizedValues.getScalarQuantizer();
            byte[] targetBytes = new byte[target.length];
            float offsetCorrection =
                    ScalarQuantizedVectorScorer.quantizeQuery(
                            target, targetBytes, similarityFunction, scalarQuantizer);
            return fromVectorSimilarity(
                    targetBytes,
                    offsetCorrection,
                    similarityFunction,
                    scalarQuantizer.getConstantMultiplier(),
                    quantizedValues);
        }
        return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
    }

    /** Byte queries are never scored in quantized form here; they always go to the delegate. */
    @Override
    public RandomVectorScorer getRandomVectorScorer(
            VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
            throws IOException {
        return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
    }

    @Override
    public String toString() {
        return "ScalarQuantizedVectorScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")";
    }

    /**
     * Creates a scorer that compares the already-quantized {@code targetBytes} with the vectors in
     * {@code values}.
     *
     * <p>The returned scorer keeps a reference to {@code targetBytes}. Calling {@code
     * setScoringOrdinal} overwrites that array in place (and, for the dot-product variants, also
     * replaces the offset correction).
     *
     * @param targetBytes quantized query
     * @param offsetCorrection query-side correction added to every dot-product score; unused for
     *     {@link VectorSimilarityFunction#EUCLIDEAN}
     * @param sim similarity function selecting the scorer implementation
     * @param constMultiplier quantizer constant that scales the raw integer distance or product
     * @param values quantized vectors to score against
     * @return a scorer for {@code sim}
     */
    static UpdateableRandomVectorScorer fromVectorSimilarity(
            byte[] targetBytes,
            float offsetCorrection,
            VectorSimilarityFunction sim,
            float constMultiplier,
            QuantizedByteVectorValues values) {
        // Exhaustive over the enum: adding a constant becomes a compile error, not a silent gap.
        return switch (sim) {
            case EUCLIDEAN -> new Euclidean(values, constMultiplier, targetBytes);
            // The adjusted dot product d is mapped to max((1 + d) / 2, 0).
            case COSINE, DOT_PRODUCT ->
                    dotProductFactory(
                            targetBytes,
                            offsetCorrection,
                            constMultiplier,
                            values,
                            f -> Math.max((1.0f + f) / 2.0f, 0.0f));
            case MAXIMUM_INNER_PRODUCT ->
                    dotProductFactory(
                            targetBytes,
                            offsetCorrection,
                            constMultiplier,
                            values,
                            VectorUtil::scaleMaxInnerProductScore);
        };
    }

    /**
     * Picks the dot-product scorer implementation that matches the quantizer's bit width and the
     * storage layout of {@code values}.
     *
     * <ul>
     *   <li>At most 4 bits, with a stored byte length that differs from the dimension and a
     *       readable slice: {@link CompressedInt4DotProduct}. The stored vectors are presumably
     *       packed.
     *   <li>At most 4 bits otherwise: {@link Int4DotProduct}.
     *   <li>More than 4 bits: {@link DotProduct}.
     * </ul>
     */
    private static UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer dotProductFactory(
            byte[] targetBytes,
            float offsetCorrection,
            float constMultiplier,
            QuantizedByteVectorValues values,
            FloatToFloatFunction scoreAdjustmentFunction) {
        if (values.getScalarQuantizer().getBits() <= 4) {
            if (values.getVectorByteLength() != values.dimension() && values.getSlice() != null) {
                return new CompressedInt4DotProduct(
                        values, constMultiplier, targetBytes, offsetCorrection, scoreAdjustmentFunction);
            }
            return new Int4DotProduct(
                    values, constMultiplier, targetBytes, offsetCorrection, scoreAdjustmentFunction);
        }
        return new DotProduct(
                values, constMultiplier, targetBytes, offsetCorrection, scoreAdjustmentFunction);
    }

    /**
     * Supplies scorers whose query is one of the stored vectors, selected later through {@code
     * setScoringOrdinal}.
     */
    private static final class ScalarQuantizedRandomVectorScorerSupplier
            implements RandomVectorScorerSupplier {
        private final VectorSimilarityFunction vectorSimilarityFunction;
        /** Source of the dimension and quantizer; this is what {@link #copy()} duplicates. */
        private final QuantizedByteVectorValues values;
        /** Copy of {@code values} made at construction, through which all scorers read vectors. */
        private final QuantizedByteVectorValues targetVectors;

        public ScalarQuantizedRandomVectorScorerSupplier(
                QuantizedByteVectorValues values, VectorSimilarityFunction vectorSimilarityFunction)
                throws IOException {
            this.values = values;
            this.targetVectors = values.copy();
            this.vectorSimilarityFunction = vectorSimilarityFunction;
        }

        /**
         * Returns a new scorer whose query buffer starts as all zeros with a zero offset
         * correction; callers are expected to call {@code setScoringOrdinal} before scoring.
         *
         * <p>NOTE(review): every scorer from one supplier reads through the same {@code
         * targetVectors}; presumably each thread obtains its own supplier via {@link #copy()}.
         * Confirm against callers.
         */
        @Override
        public UpdateableRandomVectorScorer scorer() throws IOException {
            byte[] queryBuffer = new byte[values.dimension()];
            return fromVectorSimilarity(
                    queryBuffer,
                    0.0f,
                    vectorSimilarityFunction,
                    values.getScalarQuantizer().getConstantMultiplier(),
                    targetVectors);
        }

        @Override
        public ScalarQuantizedRandomVectorScorerSupplier copy() throws IOException {
            return new ScalarQuantizedRandomVectorScorerSupplier(values.copy(), vectorSimilarityFunction);
        }

        @Override
        public String toString() {
            return "ScalarQuantizedRandomVectorScorerSupplier(vectorSimilarityFunction="
                    + vectorSimilarityFunction
                    + ")";
        }
    }

    /** Scores as {@code 1 / (1 + squareDistance * constMultiplier)} on quantized bytes. */
    private static class Euclidean
            extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer {
        private final float constMultiplier;
        private final byte[] targetBytes;
        private final QuantizedByteVectorValues values;

        private Euclidean(QuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes) {
            super(values);
            this.values = values;
            this.constMultiplier = constMultiplier;
            this.targetBytes = targetBytes;
        }

        @Override
        public float score(int node) throws IOException {
            byte[] nodeVector = values.vectorValue(node);
            int squareDistance = VectorUtil.squareDistance(nodeVector, targetBytes);
            float adjustedDistance = squareDistance * constMultiplier;
            return 1.0f / (1.0f + adjustedDistance);
        }

        /** Replaces the query with the stored vector {@code node}, copied into the query buffer. */
        @Override
        public void setScoringOrdinal(int node) throws IOException {
            System.arraycopy(values.vectorValue(node), 0, targetBytes, 0, targetBytes.length);
        }
    }

    /**
     * A primitive {@code float -> float} function; {@code java.util.function} has no float
     * specialization.
     */
    @FunctionalInterface
    private interface FloatToFloatFunction {
        float apply(float value);
    }

    /**
     * Int4 dot-product scorer that reads each stored vector's bytes straight from {@link
     * QuantizedByteVectorValues#getSlice()} and compares them with the query via {@link
     * VectorUtil#int4DotProductPacked}.
     *
     * <p>The query buffer has one byte per dimension, while stored vectors are {@code
     * getVectorByteLength()} bytes.
     */
    private static class CompressedInt4DotProduct
            extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer {
        private final float constMultiplier;
        private final QuantizedByteVectorValues values;
        /** Reusable buffer for one stored vector as read from the slice. */
        private final byte[] compressedVector;
        private final byte[] targetBytes;
        private float offsetCorrection;
        private final FloatToFloatFunction scoreAdjustmentFunction;

        private CompressedInt4DotProduct(
                QuantizedByteVectorValues values,
                float constMultiplier,
                byte[] targetBytes,
                float offsetCorrection,
                FloatToFloatFunction scoreAdjustmentFunction) {
            super(values);
            this.constMultiplier = constMultiplier;
            this.values = values;
            this.compressedVector = new byte[values.getVectorByteLength()];
            this.targetBytes = targetBytes;
            this.offsetCorrection = offsetCorrection;
            this.scoreAdjustmentFunction = scoreAdjustmentFunction;
        }

        @Override
        public float score(int vectorOrdinal) throws IOException {
            // The seek stride assumes each record is the vector bytes followed by one float,
            // presumably the per-vector score correction.
            values.getSlice()
                    .seek((long) vectorOrdinal * (values.getVectorByteLength() + Float.BYTES));
            values.getSlice().readBytes(compressedVector, 0, compressedVector.length);
            float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
            int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector);
            // Quantized dot products are expected to be non-negative (checked only with -ea).
            assert dotProduct >= 0;
            float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
            return scoreAdjustmentFunction.apply(adjustedDistance);
        }

        /** Makes stored vector {@code node}, and its correction constant, the new query. */
        @Override
        public void setScoringOrdinal(int node) throws IOException {
            System.arraycopy(values.vectorValue(node), 0, targetBytes, 0, targetBytes.length);
            offsetCorrection = values.getScoreCorrectionConstant(node);
        }
    }

    /** Int4 dot-product scorer over vectors obtained from {@code vectorValue}. */
    private static class Int4DotProduct
            extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer {
        private final float constMultiplier;
        private final QuantizedByteVectorValues values;
        private final byte[] targetBytes;
        private float offsetCorrection;
        private final FloatToFloatFunction scoreAdjustmentFunction;

        public Int4DotProduct(
                QuantizedByteVectorValues values,
                float constMultiplier,
                byte[] targetBytes,
                float offsetCorrection,
                FloatToFloatFunction scoreAdjustmentFunction) {
            super(values);
            this.constMultiplier = constMultiplier;
            this.values = values;
            this.targetBytes = targetBytes;
            this.offsetCorrection = offsetCorrection;
            this.scoreAdjustmentFunction = scoreAdjustmentFunction;
        }

        @Override
        public float score(int vectorOrdinal) throws IOException {
            byte[] storedVector = values.vectorValue(vectorOrdinal);
            float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
            int dotProduct = VectorUtil.int4DotProduct(storedVector, targetBytes);
            // Quantized dot products are expected to be non-negative (checked only with -ea).
            assert dotProduct >= 0;
            float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
            return scoreAdjustmentFunction.apply(adjustedDistance);
        }

        /** Makes stored vector {@code node}, and its correction constant, the new query. */
        @Override
        public void setScoringOrdinal(int node) throws IOException {
            System.arraycopy(values.vectorValue(node), 0, targetBytes, 0, targetBytes.length);
            offsetCorrection = values.getScoreCorrectionConstant(node);
        }
    }

    /** Byte dot-product scorer, used when the quantizer has more than 4 bits. */
    private static class DotProduct
            extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer {
        private final float constMultiplier;
        private final QuantizedByteVectorValues values;
        private final byte[] targetBytes;
        private float offsetCorrection;
        private final FloatToFloatFunction scoreAdjustmentFunction;

        public DotProduct(
                QuantizedByteVectorValues values,
                float constMultiplier,
                byte[] targetBytes,
                float offsetCorrection,
                FloatToFloatFunction scoreAdjustmentFunction) {
            super(values);
            this.constMultiplier = constMultiplier;
            this.values = values;
            this.targetBytes = targetBytes;
            this.offsetCorrection = offsetCorrection;
            this.scoreAdjustmentFunction = scoreAdjustmentFunction;
        }

        @Override
        public float score(int vectorOrdinal) throws IOException {
            byte[] storedVector = values.vectorValue(vectorOrdinal);
            float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
            int dotProduct = VectorUtil.dotProduct(storedVector, targetBytes);
            // Quantized dot products are expected to be non-negative (checked only with -ea).
            assert dotProduct >= 0;
            float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
            return scoreAdjustmentFunction.apply(adjustedDistance);
        }

        /** Makes stored vector {@code node}, and its correction constant, the new query. */
        @Override
        public void setScoringOrdinal(int node) throws IOException {
            System.arraycopy(values.vectorValue(node), 0, targetBytes, 0, targetBytes.length);
            offsetCorrection = values.getScoreCorrectionConstant(node);
        }
    }
}

