/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.internal.vectorization;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.util.Optional;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.vectorization.MemorySegmentBulkVectorOps;
import org.apache.lucene.store.FilterIndexInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;

/**
 * {@link RandomVectorScorerSupplier} for float vectors that scores directly against a {@link
 * MemorySegment} view of the raw vector data, delegating the arithmetic to {@link
 * MemorySegmentBulkVectorOps}.
 *
 * <p>The vector for ordinal {@code i} is read from byte offset {@code i * vectorByteSize} of the
 * segment, i.e. vectors are laid out back to back starting at offset 0.
 */
public abstract sealed class Lucene99MemorySegmentFloatVectorScorerSupplier
implements RandomVectorScorerSupplier {
    /** Size in bytes of one stored vector. */
    final int vectorByteSize;
    /** Number of vectors; valid ordinals are {@code [0, maxOrd)}. */
    final int maxOrd;
    /** Number of float components per vector. */
    final int dims;
    /** Segment covering the whole input, starting at the first vector. */
    final MemorySegment seg;
    /** Vector values this supplier was built from (used for metadata and by the scorer base class). */
    final FloatVectorValues values;

    /**
     * Creates a memory-segment-backed supplier for {@code type}, if the input allows it.
     *
     * @param type the similarity function to score with
     * @param input input holding the raw vector data; test-only filter wrappers are unwrapped first
     * @param values vector values describing the number, dimension and byte size of the vectors
     * @return the supplier, or {@link Optional#empty()} when the (unwrapped) input is not a
     *     {@link MemorySegmentAccessInput} or cannot expose its full contents as one segment
     * @throws IllegalArgumentException if the input is shorter than {@code size * vectorByteLength}
     * @throws IOException if an I/O error occurs while accessing the input
     */
    static Optional<RandomVectorScorerSupplier> create(VectorSimilarityFunction type, IndexInput input, FloatVectorValues values) throws IOException {
        IndexInput unwrapped = FilterIndexInput.unwrapOnlyTest(input);
        if (!(unwrapped instanceof MemorySegmentAccessInput msInput)) {
            return Optional.empty();
        }
        MemorySegment seg = msInput.segmentSliceOrNull(0L, msInput.length());
        if (seg == null) {
            // The input cannot be viewed as a single contiguous segment.
            return Optional.empty();
        }
        checkInvariants(values.size(), values.getVectorByteLength(), unwrapped);
        // Exhaustive switch over the enum: no default arm is needed.
        return switch (type) {
            case COSINE -> Optional.of(new CosineSupplier(seg, values));
            case DOT_PRODUCT -> Optional.of(new DotProductSupplier(seg, values));
            case EUCLIDEAN -> Optional.of(new EuclideanSupplier(seg, values));
            case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(seg, values));
        };
    }

    Lucene99MemorySegmentFloatVectorScorerSupplier(MemorySegment seg, FloatVectorValues values) {
        this.seg = seg;
        this.values = values;
        this.vectorByteSize = values.getVectorByteLength();
        this.maxOrd = values.size();
        this.dims = values.dimension();
    }

    /**
     * Verifies that {@code input} is long enough to hold {@code maxOrd} vectors of
     * {@code vectorByteLength} bytes each.
     *
     * @throws IllegalArgumentException if the input is too short
     */
    static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
        // Multiply in long arithmetic so large indexes cannot overflow int.
        if (input.length() < (long)vectorByteLength * (long)maxOrd) {
            throw new IllegalArgumentException("input length is less than expected vector data");
        }
    }

    /** Cosine similarity, normalized with {@link VectorUtil#normalizeToUnitInterval(float)}. */
    static final class CosineSupplier
    extends Lucene99MemorySegmentFloatVectorScorerSupplier {
        static final MemorySegmentBulkVectorOps.Cosine COS_OPS = MemorySegmentBulkVectorOps.COS_INSTANCE;

        CosineSupplier(MemorySegment seg, FloatVectorValues values) {
            super(seg, values);
        }

        @Override
        public UpdateableRandomVectorScorer scorer() {
            return new AbstractBulkScorer(this.values){

                @Override
                float vectorOp(MemorySegment seg, long queryOffset, long nodeOffset, int elementCount) {
                    return COS_OPS.cosine(seg, queryOffset, nodeOffset, elementCount);
                }

                @Override
                void vectorOp(MemorySegment seg, float[] scores, long queryOffset, long node1Offset, long node2Offset, long node3Offset, long node4Offset, int elementCount) {
                    COS_OPS.cosineBulk(seg, scores, queryOffset, node1Offset, node2Offset, node3Offset, node4Offset, elementCount);
                }

                @Override
                float normalizeRawScore(float rawScore) {
                    return VectorUtil.normalizeToUnitInterval(rawScore);
                }
            };
        }

        /** Returns a supplier over the same segment with its own copy of the vector values. */
        @Override
        public CosineSupplier copy() throws IOException {
            return new CosineSupplier(this.seg, this.values.copy());
        }
    }

    /** Dot product, normalized with {@link VectorUtil#normalizeToUnitInterval(float)}. */
    static final class DotProductSupplier
    extends Lucene99MemorySegmentFloatVectorScorerSupplier {
        static final MemorySegmentBulkVectorOps.DotProduct DOT_OPS = MemorySegmentBulkVectorOps.DOT_INSTANCE;

        DotProductSupplier(MemorySegment seg, FloatVectorValues values) {
            super(seg, values);
        }

        @Override
        public UpdateableRandomVectorScorer scorer() {
            return new AbstractBulkScorer(this.values){

                @Override
                float vectorOp(MemorySegment seg, long queryOffset, long nodeOffset, int elementCount) {
                    return DOT_OPS.dotProduct(seg, queryOffset, nodeOffset, elementCount);
                }

                @Override
                void vectorOp(MemorySegment seg, float[] scores, long queryOffset, long node1Offset, long node2Offset, long node3Offset, long node4Offset, int elementCount) {
                    DOT_OPS.dotProductBulk(seg, scores, queryOffset, node1Offset, node2Offset, node3Offset, node4Offset, elementCount);
                }

                @Override
                float normalizeRawScore(float rawScore) {
                    return VectorUtil.normalizeToUnitInterval(rawScore);
                }
            };
        }

        /**
         * Returns a supplier over the same segment with its own copy of the vector values, so the
         * copy does not share the {@link FloatVectorValues} instance with this supplier.
         */
        @Override
        public DotProductSupplier copy() throws IOException {
            return new DotProductSupplier(this.seg, this.values.copy());
        }
    }

    /** Squared Euclidean distance, normalized with {@link VectorUtil#normalizeDistanceToUnitInterval(float)}. */
    static final class EuclideanSupplier
    extends Lucene99MemorySegmentFloatVectorScorerSupplier {
        static final MemorySegmentBulkVectorOps.SqrDistance SQR_OPS = MemorySegmentBulkVectorOps.SQR_INSTANCE;

        EuclideanSupplier(MemorySegment seg, FloatVectorValues values) {
            super(seg, values);
        }

        @Override
        public UpdateableRandomVectorScorer scorer() {
            return new AbstractBulkScorer(this.values){

                @Override
                float vectorOp(MemorySegment seg, long queryOffset, long nodeOffset, int elementCount) {
                    return SQR_OPS.sqrDistance(seg, queryOffset, nodeOffset, elementCount);
                }

                @Override
                void vectorOp(MemorySegment seg, float[] scores, long queryOffset, long node1Offset, long node2Offset, long node3Offset, long node4Offset, int elementCount) {
                    SQR_OPS.sqrDistanceBulk(seg, scores, queryOffset, node1Offset, node2Offset, node3Offset, node4Offset, elementCount);
                }

                @Override
                float normalizeRawScore(float rawScore) {
                    return VectorUtil.normalizeDistanceToUnitInterval(rawScore);
                }
            };
        }

        /**
         * Returns a supplier over the same segment with its own copy of the vector values, so the
         * copy does not share the {@link FloatVectorValues} instance with this supplier.
         */
        @Override
        public EuclideanSupplier copy() throws IOException {
            return new EuclideanSupplier(this.seg, this.values.copy());
        }
    }

    /** Dot product, scaled with {@link VectorUtil#scaleMaxInnerProductScore(float)}. */
    static final class MaxInnerProductSupplier
    extends Lucene99MemorySegmentFloatVectorScorerSupplier {
        static final MemorySegmentBulkVectorOps.DotProduct DOT_OPS = MemorySegmentBulkVectorOps.DOT_INSTANCE;

        MaxInnerProductSupplier(MemorySegment seg, FloatVectorValues values) {
            super(seg, values);
        }

        @Override
        public UpdateableRandomVectorScorer scorer() {
            return new AbstractBulkScorer(this.values){

                @Override
                float vectorOp(MemorySegment seg, long queryOffset, long nodeOffset, int elementCount) {
                    return DOT_OPS.dotProduct(seg, queryOffset, nodeOffset, elementCount);
                }

                @Override
                void vectorOp(MemorySegment seg, float[] scores, long queryOffset, long node1Offset, long node2Offset, long node3Offset, long node4Offset, int elementCount) {
                    DOT_OPS.dotProductBulk(seg, scores, queryOffset, node1Offset, node2Offset, node3Offset, node4Offset, elementCount);
                }

                @Override
                float normalizeRawScore(float rawScore) {
                    return VectorUtil.scaleMaxInnerProductScore(rawScore);
                }
            };
        }

        /**
         * Returns a supplier over the same segment with its own copy of the vector values, so the
         * copy does not share the {@link FloatVectorValues} instance with this supplier.
         */
        @Override
        public MaxInnerProductSupplier copy() throws IOException {
            return new MaxInnerProductSupplier(this.seg, this.values.copy());
        }
    }

    /**
     * Scorer that compares the vector at the current scoring ordinal against other stored vectors,
     * scoring up to four nodes per call to the bulk {@link #vectorOp(MemorySegment, float[], long,
     * long, long, long, long, int)}. Until {@link #setScoringOrdinal(int)} is called the query
     * ordinal is 0.
     */
    abstract class AbstractBulkScorer
    extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer {
        private int queryOrd;
        /** Receives the four raw scores produced by each bulk call; reused across calls. */
        final float[] scratchScores;

        AbstractBulkScorer(FloatVectorValues values) {
            super(values);
            this.scratchScores = new float[4];
        }

        /** @throws IllegalArgumentException if {@code ord} is outside {@code [0, maxOrd)} */
        final void checkOrdinal(int ord) {
            if (ord < 0 || ord >= maxOrd) {
                throw new IllegalArgumentException("illegal ordinal: " + ord);
            }
        }

        /** Byte offset of the vector for {@code ord} within the segment (long math avoids overflow). */
        private long offsetOf(int ord) {
            return (long)ord * (long)vectorByteSize;
        }

        /**
         * Computes the raw similarity between the vectors at {@code queryOffset} and
         * {@code nodeOffset} of {@code seg}, each of {@code elementCount} floats.
         */
        abstract float vectorOp(MemorySegment seg, long queryOffset, long nodeOffset, int elementCount);

        /**
         * Computes the raw similarity between the query vector and four node vectors, writing the
         * results to {@code scores[0..3]} in node order.
         */
        abstract void vectorOp(MemorySegment seg, float[] scores, long queryOffset, long node1Offset, long node2Offset, long node3Offset, long node4Offset, int elementCount);

        /** Maps a raw similarity value to the score returned to callers. */
        abstract float normalizeRawScore(float rawScore);

        @Override
        public float score(int node) {
            this.checkOrdinal(node);
            float raw = this.vectorOp(seg, offsetOf(this.queryOrd), offsetOf(node), dims);
            return this.normalizeRawScore(raw);
        }

        /**
         * Scores {@code nodes[0..numNodes)} into {@code scores[0..numNodes)} and returns the maximum
         * score, or {@link Float#NEGATIVE_INFINITY} when {@code numNodes == 0}. Node ordinals are
         * not range-checked here.
         */
        @Override
        public float bulkScore(int[] nodes, float[] scores, int numNodes) {
            final long queryOffset = offsetOf(this.queryOrd);
            float maxScore = Float.NEGATIVE_INFINITY;
            // Largest multiple of four not exceeding numNodes; these are scored in groups of four.
            final int limit = numNodes & ~3;
            int i = 0;
            for (; i < limit; i += 4) {
                this.vectorOp(seg, this.scratchScores, queryOffset,
                        offsetOf(nodes[i]), offsetOf(nodes[i + 1]), offsetOf(nodes[i + 2]), offsetOf(nodes[i + 3]), dims);
                for (int j = 0; j < 4; j++) {
                    scores[i + j] = this.normalizeRawScore(this.scratchScores[j]);
                    maxScore = Math.max(maxScore, scores[i + j]);
                }
            }
            final int remaining = numNodes - i;
            if (remaining > 0) {
                // Score the 1-3 leftover nodes with one more bulk call. Unused slots are padded with
                // the first leftover node, and their results are ignored.
                final long first = offsetOf(nodes[i]);
                final long second = remaining > 1 ? offsetOf(nodes[i + 1]) : first;
                final long third = remaining > 2 ? offsetOf(nodes[i + 2]) : first;
                this.vectorOp(seg, this.scratchScores, queryOffset, first, second, third, first, dims);
                for (int j = 0; j < remaining; j++) {
                    scores[i + j] = this.normalizeRawScore(this.scratchScores[j]);
                    maxScore = Math.max(maxScore, scores[i + j]);
                }
            }
            return maxScore;
        }

        /**
         * Sets the ordinal of the vector used as the query for subsequent scoring calls.
         *
         * @throws IllegalArgumentException if {@code node} is outside {@code [0, maxOrd)}
         */
        @Override
        public void setScoringOrdinal(int node) {
            this.checkOrdinal(node);
            this.queryOrd = node;
        }
    }
}

