diff --git a/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java b/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java index bab5b97bb1..f2b411278d 100644 --- a/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java +++ b/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java @@ -26,6 +26,7 @@ public class KNNFeatureFlags { // Feature flags private static final String KNN_FORCE_EVICT_CACHE_ENABLED = "knn.feature.cache.force_evict.enabled"; + private static final String KNN_REMOTE_VECTOR_BUILD = "knn.feature.remote_index_build.enabled"; @VisibleForTesting public static final Setting KNN_FORCE_EVICT_CACHE_ENABLED_SETTING = Setting.boolSetting( @@ -35,8 +36,15 @@ public class KNNFeatureFlags { Dynamic ); + public static final Setting KNN_REMOTE_VECTOR_BUILD_SETTING = Setting.boolSetting( + KNN_REMOTE_VECTOR_BUILD, + false, + NodeScope, + Dynamic + ); + public static List> getFeatureFlags() { - return ImmutableList.of(KNN_FORCE_EVICT_CACHE_ENABLED_SETTING); + return ImmutableList.of(KNN_FORCE_EVICT_CACHE_ENABLED_SETTING, KNN_REMOTE_VECTOR_BUILD_SETTING); } /** @@ -46,4 +54,11 @@ public static List> getFeatureFlags() { public static boolean isForceEvictCacheEnabled() { return Booleans.parseBoolean(KNNSettings.state().getSettingValue(KNN_FORCE_EVICT_CACHE_ENABLED).toString(), false); } + + /** + * @return true if remote vector index build feature flag is enabled + */ + public static boolean isKNNRemoteVectorBuildEnabled() { + return Booleans.parseBooleanStrict(KNNSettings.state().getSettingValue(KNN_REMOTE_VECTOR_BUILD).toString(), false); + } } diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index ebb55ea2ba..02bc113384 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -41,9 +41,9 @@ import static 
java.util.stream.Collectors.toUnmodifiableMap; import static org.opensearch.common.settings.Setting.Property.Dynamic; +import static org.opensearch.common.settings.Setting.Property.Final; import static org.opensearch.common.settings.Setting.Property.IndexScope; import static org.opensearch.common.settings.Setting.Property.NodeScope; -import static org.opensearch.common.settings.Setting.Property.Final; import static org.opensearch.common.settings.Setting.Property.UnmodifiableOnRestore; import static org.opensearch.common.unit.MemorySizeValue.parseBytesSizeValueOrHeapRatio; import static org.opensearch.core.common.unit.ByteSizeValue.parseBytesSizeValue; @@ -94,6 +94,8 @@ public class KNNSettings { public static final String KNN_FAISS_AVX512_SPR_DISABLED = "knn.faiss.avx512_spr.disabled"; public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled"; public static final String KNN_DERIVED_SOURCE_ENABLED = "index.knn.derived_source.enabled"; + public static final String KNN_INDEX_REMOTE_VECTOR_BUILD = "index.knn.remote_index_build.enabled"; + public static final String KNN_REMOTE_VECTOR_REPO = "knn.remote_index_build.vector_repo"; /** * Default setting values @@ -371,6 +373,15 @@ public class KNNSettings { NodeScope ); + public static final Setting KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING = Setting.boolSetting( + KNN_INDEX_REMOTE_VECTOR_BUILD, + false, + Dynamic, + IndexScope + ); + + public static final Setting KNN_REMOTE_VECTOR_REPO_SETTING = Setting.simpleString(KNN_REMOTE_VECTOR_REPO, Dynamic, NodeScope); + /** * Dynamic settings */ @@ -525,6 +536,14 @@ private Setting getSetting(String key) { return KNN_DERIVED_SOURCE_ENABLED_SETTING; } + if (KNN_INDEX_REMOTE_VECTOR_BUILD.equals(key)) { + return KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING; + } + + if (KNN_REMOTE_VECTOR_REPO.equals(key)) { + return KNN_REMOTE_VECTOR_REPO_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key 
+ "]"); } @@ -550,7 +569,9 @@ public List> getSettings() { QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING, - KNN_DERIVED_SOURCE_ENABLED_SETTING + KNN_DERIVED_SOURCE_ENABLED_SETTING, + KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING, + KNN_REMOTE_VECTOR_REPO_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index f3a1258387..d7e051cbd0 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -135,7 +135,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { return nativeEngineVectorsFormat(); } - private NativeEngines990KnnVectorsFormat nativeEngineVectorsFormat() { + protected KnnVectorsFormat nativeEngineVectorsFormat() { // mapperService is already checked for null or valid instance type at caller, hence we don't need // addition isPresent check here. int approximateThreshold = getApproximateThresholdValue(); @@ -145,7 +145,7 @@ private NativeEngines990KnnVectorsFormat nativeEngineVectorsFormat() { ); } - private int getApproximateThresholdValue() { + protected int getApproximateThresholdValue() { // This is private method and mapperService is already checked for null or valid instance type before this call // at caller, hence we don't need additional isPresent check here. 
final IndexSettings indexSettings = mapperService.get().getIndexSettings();
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010PerFieldKnnVectorsFormat.java
new file mode 100644
index 0000000000..115490f8b4
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010PerFieldKnnVectorsFormat.java
@@ -0,0 +1,120 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index.codec.KNN10010Codec;
+
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
+import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
+import org.opensearch.common.Nullable;
+import org.opensearch.common.collect.Tuple;
+import org.opensearch.index.mapper.MapperService;
+import org.opensearch.knn.index.KNNSettings;
+import org.opensearch.knn.index.SpaceType;
+import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;
+import org.opensearch.knn.index.codec.KNN9120Codec.KNN9120HnswBinaryVectorsFormat;
+import org.opensearch.knn.index.engine.KNNEngine;
+import org.opensearch.knn.index.remote.RemoteIndexBuilder;
+
+import java.util.Optional;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+/**
+ * Class provides per field format implementation for Lucene Knn vector type. Native engine formats created by this
+ * class are {@link NativeEngines10010KnnVectorsFormat} instances that receive the optional {@link RemoteIndexBuilder}.
+ */
+public class KNN10010PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat {
+    /** Returned when at most one merge thread is configured: a thread count of 1 paired with a {@code null} executor. */
+    private static final Tuple<Integer, ExecutorService> DEFAULT_MERGE_THREAD_COUNT_AND_EXECUTOR_SERVICE = Tuple.tuple(1, null);
+    /** Passed to every native engine format created by this instance; {@code null} when constructed without one. */
+    @Nullable
+    private final RemoteIndexBuilder remoteIndexBuilder;
+
+    /**
+     * @param mapperService      mapper service handed to the base format
+     * @param remoteIndexBuilder builder forwarded to the native engine vectors format, may be {@code null}
+     */
+    public KNN10010PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService, final RemoteIndexBuilder remoteIndexBuilder) {
+        super(
+            mapperService,
+            Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
+            Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
+            Lucene99HnswVectorsFormat::new,
+            knnVectorsFormatParams -> {
+                final Tuple<Integer, ExecutorService> mergeThreadCountAndExecutorService = getMergeThreadCountAndExecutorService();
+                // There is an assumption here that hamming space will only be used for binary vectors. This will need to be fixed if that
+                // changes in the future.
+                if (knnVectorsFormatParams.getSpaceType() == SpaceType.HAMMING) {
+                    return new KNN9120HnswBinaryVectorsFormat(
+                        knnVectorsFormatParams.getMaxConnections(),
+                        knnVectorsFormatParams.getBeamWidth(),
+                        // number of merge threads
+                        mergeThreadCountAndExecutorService.v1(),
+                        // executor service
+                        mergeThreadCountAndExecutorService.v2()
+                    );
+                } else {
+                    return new Lucene99HnswVectorsFormat(
+                        knnVectorsFormatParams.getMaxConnections(),
+                        knnVectorsFormatParams.getBeamWidth(),
+                        // number of merge threads
+                        mergeThreadCountAndExecutorService.v1(),
+                        // executor service
+                        mergeThreadCountAndExecutorService.v2()
+                    );
+                }
+            },
+            knnScalarQuantizedVectorsFormatParams -> {
+                final Tuple<Integer, ExecutorService> mergeThreadCountAndExecutorService = getMergeThreadCountAndExecutorService();
+                return new Lucene99HnswScalarQuantizedVectorsFormat(
+                    knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
+                    knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
+                    // Number of merge threads
+                    mergeThreadCountAndExecutorService.v1(),
+                    knnScalarQuantizedVectorsFormatParams.getBits(),
+                    knnScalarQuantizedVectorsFormatParams.isCompressFlag(),
+                    knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
+                    // Executor service
+                    mergeThreadCountAndExecutorService.v2()
+                );
+            }
+        );
+        this.remoteIndexBuilder = remoteIndexBuilder;
+    }
+
+    /**
+     * Creates a format without a remote index builder.
+     *
+     * @param mapperService mapper service handed to the base format
+     */
+    public KNN10010PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService) {
+        this(mapperService, null);
+    }
+
+    /**
+     * This method returns the maximum dimension allowed from KNNEngine for Lucene codec
+     *
+     * @param fieldName Name of the field, ignored
+     * @return Maximum constant dimension set by KNNEngine
+     */
+    @Override
+    public int getMaxDimensions(String fieldName) {
+        return KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE);
+    }
+
+    /**
+     * Reads the index thread quantity and pairs it with an executor for that many threads.
+     *
+     * @return {@code (1, null)} when at most one thread is configured, otherwise the thread count together with a
+     *         newly created fixed thread pool of that size
+     */
+    private static Tuple<Integer, ExecutorService> getMergeThreadCountAndExecutorService() {
+        // To ensure that only once we are fetching the settings per segment, we are fetching the num threads once while
+        // creating the executors
+        int mergeThreadCount = KNNSettings.getIndexThreadQty();
+        // We need to return null whenever the merge threads are <=1, as lucene assumes that if number of threads are 1
+        // then we should be giving a null value of the executor
+        if (mergeThreadCount <= 1) {
+            return DEFAULT_MERGE_THREAD_COUNT_AND_EXECUTOR_SERVICE;
+        } else {
+            return Tuple.tuple(mergeThreadCount, Executors.newFixedThreadPool(mergeThreadCount));
+        }
+    }
+
+    /**
+     * Builds the native engine format from the approximate threshold value and this instance's remote index builder.
+     */
+    @Override
+    protected KnnVectorsFormat nativeEngineVectorsFormat() {
+        int approximateThreshold = getApproximateThresholdValue();
+        return new NativeEngines10010KnnVectorsFormat(
+            new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()),
+            approximateThreshold,
+            remoteIndexBuilder
+        );
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/NativeEngines10010KnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/NativeEngines10010KnnVectorsFormat.java
new file mode 100644
index 0000000000..4a83525780
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/NativeEngines10010KnnVectorsFormat.java
@@ -0,0 +1,112 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.knn.index.codec.KNN10010Codec;
+
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.opensearch.common.Nullable;
+import org.opensearch.knn.index.KNNSettings;
+import org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsReader;
+import org.opensearch.knn.index.engine.KNNEngine;
+import org.opensearch.knn.index.remote.RemoteIndexBuilder;
+
+import java.io.IOException;
+
+/**
+ * This is a Vector format that will be used for Native engines like Faiss and Nmslib for reading and writing vector
+ * related data structures.
+ */
+public class NativeEngines10010KnnVectorsFormat extends KnnVectorsFormat {
+    /**
+     * The format for storing, reading, merging vectors on disk. Kept per instance: a static field assigned from the
+     * constructor would let every newly created format overwrite the configuration of all existing ones.
+     */
+    private final FlatVectorsFormat flatVectorsFormat;
+    private static final String FORMAT_NAME = "NativeEngines10010KnnVectorsFormat";
+    /** Threshold handed to each writer created by this format; per instance for the same reason as the flat format. */
+    private final int approximateThreshold;
+    /** Remote index builder handed to each writer created by this format; may be {@code null}. */
+    @Nullable
+    private final RemoteIndexBuilder remoteIndexBuilder;
+
+    // For Testing Only
+    public NativeEngines10010KnnVectorsFormat() {
+        this(new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer()));
+    }
+
+    // For Testing Only
+    public NativeEngines10010KnnVectorsFormat(int approximateThreshold) {
+        this(new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer()), approximateThreshold, null);
+    }
+
+    // For Testing Only
+    public NativeEngines10010KnnVectorsFormat(final FlatVectorsFormat flatVectorsFormat) {
+        this(flatVectorsFormat, KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_DEFAULT_VALUE, null);
+    }
+
+    /**
+     * @param flatVectorsFormat    format used to create the flat vectors writer and reader
+     * @param approximateThreshold threshold passed to every {@link NativeEngines10010KnnVectorsWriter} created by this format
+     * @param remoteIndexBuilder   remote index builder passed to every writer created by this format, may be {@code null}
+     */
+    public NativeEngines10010KnnVectorsFormat(
+        final FlatVectorsFormat flatVectorsFormat,
+        int approximateThreshold,
+        RemoteIndexBuilder remoteIndexBuilder
+    ) {
+        super(FORMAT_NAME);
+        this.flatVectorsFormat = flatVectorsFormat;
+        this.approximateThreshold = approximateThreshold;
+        this.remoteIndexBuilder = remoteIndexBuilder;
+    }
+
+    /**
+     * Returns a {@link org.apache.lucene.codecs.KnnVectorsWriter} to write the vectors to the index.
+     *
+     * @param state {@link org.apache.lucene.index.SegmentWriteState}
+     */
+    @Override
+    public KnnVectorsWriter fieldsWriter(final SegmentWriteState state) throws IOException {
+        return new NativeEngines10010KnnVectorsWriter(
+            state,
+            flatVectorsFormat.fieldsWriter(state),
+            approximateThreshold,
+            remoteIndexBuilder
+        );
+    }
+
+    /**
+     * Returns a {@link org.apache.lucene.codecs.KnnVectorsReader} to read the vectors from the index.
+     *
+     * @param state {@link org.apache.lucene.index.SegmentReadState}
+     */
+    @Override
+    public KnnVectorsReader fieldsReader(final SegmentReadState state) throws IOException {
+        return new NativeEngines990KnnVectorsReader(state, flatVectorsFormat.fieldsReader(state));
+    }
+
+    /**
+     * Returns the maximum dimension that {@link KNNEngine} reports for {@link KNNEngine#LUCENE}.
+     *
+     * @param s field name, not used
+     * @return the value of {@code KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE)}
+     */
+    @Override
+    public int getMaxDimensions(String s) {
+        return KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE);
+    }
+
+    /**
+     * Describes this format by its class name, flat vectors format and approximate threshold.
+     */
+    @Override
+    public String toString() {
+        return "NativeEngines10010KnnVectorsFormat(name="
+            + this.getClass().getSimpleName()
+            + ", flatVectorsFormat="
+            + flatVectorsFormat
+            + ", approximateThreshold="
+            + approximateThreshold
+            + ")";
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/NativeEngines10010KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/NativeEngines10010KnnVectorsWriter.java
new file mode 100644
index 0000000000..0c02257676
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/NativeEngines10010KnnVectorsWriter.java
@@ -0,0 +1,338 @@
+/*
+ *
SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.codec.KNN10010Codec; + +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.opensearch.common.StopWatch; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.KNN990Codec.KNN990QuantizationStateWriter; +import org.opensearch.knn.index.codec.NativeEngineFieldVectorsWriter; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.index.remote.RemoteIndexBuilder; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; +import static 
org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getVectorValues;
+
+/**
+ * A KNNVectorsWriter class for writing the vector data structures and flat vectors for Native Engines.
+ */
+@Log4j2
+public class NativeEngines10010KnnVectorsWriter extends KnnVectorsWriter {
+    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngines10010KnnVectorsWriter.class);
+
+    private final SegmentWriteState segmentWriteState;
+    private final FlatVectorsWriter flatVectorsWriter;
+    // Stays null until quantization state has to be written for the first time
+    private KNN990QuantizationStateWriter quantizationStateWriter;
+    // One entry per field registered through addField
+    private final List<NativeEngineFieldVectorsWriter<?>> fields = new ArrayList<>();
+    private boolean finished;
+    private final Integer approximateThreshold;
+    // null when the writer is created through the three-argument constructor
+    private final RemoteIndexBuilder remoteIndexBuilder;
+
+    /**
+     * @param segmentWriteState    state of the segment being written
+     * @param flatVectorsWriter    writer that receives the flat vectors of every field
+     * @param approximateThreshold document count below which building the vector data structure is skipped; a negative
+     *                             value always skips it
+     * @param remoteIndexBuilder   builder consulted before each index build, may be {@code null}
+     */
+    public NativeEngines10010KnnVectorsWriter(
+        SegmentWriteState segmentWriteState,
+        FlatVectorsWriter flatVectorsWriter,
+        Integer approximateThreshold,
+        RemoteIndexBuilder remoteIndexBuilder
+    ) {
+        this.segmentWriteState = segmentWriteState;
+        this.flatVectorsWriter = flatVectorsWriter;
+        this.approximateThreshold = approximateThreshold;
+        this.remoteIndexBuilder = remoteIndexBuilder;
+    }
+
+    // For Testing Only
+    public NativeEngines10010KnnVectorsWriter(
+        SegmentWriteState segmentWriteState,
+        FlatVectorsWriter flatVectorsWriter,
+        Integer buildGraphAlwaysThreshold
+    ) {
+        this(segmentWriteState, flatVectorsWriter, buildGraphAlwaysThreshold, null);
+    }
+
+    /**
+     * Add new field for indexing.
+     * @param fieldInfo {@link org.apache.lucene.index.FieldInfo}
+     */
+    @Override
+    public KnnFieldVectorsWriter<?> addField(final FieldInfo fieldInfo) throws IOException {
+        final NativeEngineFieldVectorsWriter<?> newField = NativeEngineFieldVectorsWriter.create(
+            fieldInfo,
+            flatVectorsWriter.addField(fieldInfo),
+            segmentWriteState.infoStream
+        );
+        fields.add(newField);
+        return newField;
+    }
+
+    /**
+     * Flush all buffered data on disk. This is not fsync. This is lucene flush.
+     *
+     * For every field that needs a vector data structure, the index is built through the {@link RemoteIndexBuilder}
+     * when one is present and reports that it should be used; otherwise, or when the remote build is interrupted, the
+     * index is built locally through {@link NativeIndexWriter}.
+     *
+     * @param maxDoc int
+     * @param sortMap {@link org.apache.lucene.index.Sorter.DocMap}
+     */
+    @Override
+    public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
+        flatVectorsWriter.flush(maxDoc, sortMap);
+
+        // Set when a remote build is interrupted. The interrupt status is restored only after all fields are handled,
+        // because restoring it earlier could abort interruptible I/O performed by the local fallback build.
+        boolean interrupted = false;
+        try {
+            for (final NativeEngineFieldVectorsWriter<?> field : fields) {
+                final FieldInfo fieldInfo = field.getFieldInfo();
+                final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
+                int totalLiveDocs = field.getVectors().size();
+                if (totalLiveDocs == 0) {
+                    log.debug("[Flush] No live docs for field {}", fieldInfo.getName());
+                    continue;
+                }
+                final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getVectorValues(
+                    vectorDataType,
+                    field.getFlatFieldVectorsWriter().getDocsWithFieldSet(),
+                    field.getVectors()
+                );
+                final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
+                // should skip graph building only for non quantization use case and if threshold is met
+                if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
+                    log.info(
+                        "Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during flush",
+                        fieldInfo.name,
+                        totalLiveDocs,
+                        approximateThreshold
+                    );
+                    continue;
+                }
+
+                boolean requiresFallback = false;
+                if (remoteIndexBuilder != null && remoteIndexBuilder.shouldBuildIndexRemotely()) {
+                    try {
+                        remoteIndexBuilder.buildIndexRemotely(fieldInfo, knnVectorValuesSupplier, totalLiveDocs, segmentWriteState);
+                    } catch (InterruptedException e) {
+                        interrupted = true;
+                        requiresFallback = true;
+                        log.warn(
+                            "Interrupted while building index remotely for field [{}], falling back to local build",
+                            fieldInfo.getName(),
+                            e
+                        );
+                    }
+                } else {
+                    requiresFallback = true;
+                }
+
+                if (requiresFallback) {
+                    final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
+                    final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
+
+                    StopWatch stopWatch = new StopWatch().start();
+                    writer.flushIndex(knnVectorValues, totalLiveDocs);
+                    long time_in_millis = stopWatch.stop().totalTime().millis();
+                    KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
+                    log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName());
+                }
+            }
+        } finally {
+            if (interrupted) {
+                // Do not lose the interruption: callers further up the stack can still observe it
+                Thread.currentThread().interrupt();
+            }
+        }
+    }
+
+    /**
+     * Merges the flat vectors of one field and then builds the vector data structure for the merged vectors, through
+     * the {@link RemoteIndexBuilder} when one is present and reports that it should be used, and locally through
+     * {@link NativeIndexWriter} otherwise or when the remote build is interrupted.
+     *
+     * @param fieldInfo  {@link org.apache.lucene.index.FieldInfo} of the field being merged
+     * @param mergeState {@link org.apache.lucene.index.MergeState} of the running merge
+     */
+    @Override
+    public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState) throws IOException {
+        // This will ensure that we are merging the FlatIndex during force merge.
+        flatVectorsWriter.mergeOneField(fieldInfo, mergeState);
+
+        final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
+        final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getKNNVectorValuesForMerge(
+            vectorDataType,
+            fieldInfo,
+            mergeState
+        );
+        int totalLiveDocs = getLiveDocs(knnVectorValuesSupplier.get());
+        if (totalLiveDocs == 0) {
+            log.debug("[Merge] No live docs for field {}", fieldInfo.getName());
+            return;
+        }
+
+        final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);
+        // should skip graph building only for non quantization use case and if threshold is met
+        if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
+            log.info(
+                "Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during merge",
+                fieldInfo.name,
+                totalLiveDocs,
+                approximateThreshold
+            );
+            return;
+        }
+
+        boolean requiresFallback = false;
+        // Set when the remote build is interrupted; the status is restored after the local fallback has run
+        boolean interrupted = false;
+        if (remoteIndexBuilder != null && remoteIndexBuilder.shouldBuildIndexRemotely()) {
+            try {
+                remoteIndexBuilder.buildIndexRemotely(fieldInfo, knnVectorValuesSupplier, totalLiveDocs, segmentWriteState);
+            } catch (InterruptedException e) {
+                interrupted = true;
+                requiresFallback = true;
+                log.warn(
+                    "Interrupted while building index remotely for field [{}], falling back to local build",
+                    fieldInfo.getName(),
+                    e
+                );
+            }
+        } else {
+            requiresFallback = true;
+        }
+
+        try {
+            if (requiresFallback) {
+                final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
+                final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
+
+                StopWatch stopWatch = new StopWatch().start();
+
+                writer.mergeIndex(knnVectorValues, totalLiveDocs);
+
+                long time_in_millis = stopWatch.stop().totalTime().millis();
+                KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
+                log.debug("Merge took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName());
+            }
+        } finally {
+            if (interrupted) {
+                // Restored late on purpose: an interrupted thread could abort interruptible I/O of the fallback build
+                Thread.currentThread().interrupt();
+            }
+        }
+    }
+
+    /**
+     * Called once at the end before close
+     */
+    @Override
+    public void finish() throws IOException {
+        if (finished) {
+            throw new IllegalStateException("NativeEnginesKNNVectorsWriter is already finished");
+        }
+        finished = true;
+        if (quantizationStateWriter != null) {
+            quantizationStateWriter.writeFooter();
+        }
+        flatVectorsWriter.finish();
+    }
+
+    /**
+     * Closes this stream and releases any system resources associated
+     * with it. If the stream is already closed then invoking this
+     * method has no effect.
+     *
+     *

As noted in {@link AutoCloseable#close()}, cases where the
+     * close may fail require careful attention. It is strongly advised
+     * to relinquish the underlying resources and to internally
+     * mark the {@code Closeable} as closed, prior to throwing
+     * the {@code IOException}.
+     *
+     * @throws java.io.IOException if an I/O error occurs
+     */
+    @Override
+    public void close() throws IOException {
+        try {
+            if (quantizationStateWriter != null) {
+                quantizationStateWriter.closeOutput();
+            }
+        } finally {
+            // Close the flat vectors writer even when closing the quantization state output fails
+            IOUtils.close(flatVectorsWriter);
+        }
+    }
+
+    /**
+     * Return the memory usage of this object in bytes. Negative values are illegal.
+     */
+    @Override
+    public long ramBytesUsed() {
+        return SHALLOW_SIZE + flatVectorsWriter.ramBytesUsed() + fields.stream()
+            .mapToLong(NativeEngineFieldVectorsWriter::ramBytesUsed)
+            .sum();
+    }
+
+    /**
+     * Retrieves the {@link org.opensearch.knn.index.vectorvalues.KNNVectorValues} for a specific field during a merge operation, based on the vector data type.
+     *
+     * @param vectorDataType The {@link org.opensearch.knn.index.VectorDataType} representing the type of vectors stored.
+     * @param fieldInfo      The {@link org.apache.lucene.index.FieldInfo} object containing metadata about the field.
+     * @param mergeState     The {@link org.apache.lucene.index.MergeState} representing the state of the merge operation.
+     * @param <T>            The type of vectors being processed.
+     * @return The {@link org.opensearch.knn.index.vectorvalues.KNNVectorValues} associated with the field during the merge.
+     * @throws IllegalStateException If the vector encoding is unsupported, or if an I/O error occurs during the retrieval
+     *                               (the {@link java.io.IOException} is attached as the cause).
+     */
+    private <T> KNNVectorValues<T> getKNNVectorValuesForMerge(
+        final VectorDataType vectorDataType,
+        final FieldInfo fieldInfo,
+        final MergeState mergeState
+    ) {
+        try {
+            switch (fieldInfo.getVectorEncoding()) {
+                case FLOAT32:
+                    FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
+                    return getVectorValues(vectorDataType, mergedFloats);
+                case BYTE:
+                    ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
+                    return getVectorValues(vectorDataType, mergedBytes);
+                default:
+                    throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
+            }
+        } catch (final IOException e) {
+            log.error("Unable to merge vectors for field [{}]", fieldInfo.getName(), e);
+            throw new IllegalStateException("Unable to merge vectors for field [" + fieldInfo.getName() + "]", e);
+        }
+    }
+
+    /**
+     * Trains and persists the quantization state of a field when the field has quantization params and live docs.
+     *
+     * @param fieldInfo               field whose quantization params are looked up
+     * @param knnVectorValuesSupplier supplies a fresh iterator over the field's vectors for training
+     * @param totalLiveDocs           number of live docs of the field
+     * @return the trained state, or {@code null} when no quantization params exist or there are no live docs
+     */
+    private QuantizationState train(
+        final FieldInfo fieldInfo,
+        final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
+        final int totalLiveDocs
+    ) throws IOException {
+
+        final QuantizationService quantizationService = QuantizationService.getInstance();
+        final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
+        QuantizationState quantizationState = null;
+        if (quantizationParams != null && totalLiveDocs > 0) {
+            initQuantizationStateWriterIfNecessary();
+            KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
+            quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs);
+            quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
+        }
+
+        return quantizationState;
+    }
+
+    /**
+     * The {@link org.opensearch.knn.index.vectorvalues.KNNVectorValues} will be exhausted after this function run. So make sure that you are not sending the
+     * vectorsValues object which you plan to use later
+     */
+    private int getLiveDocs(KNNVectorValues<?> vectorValues) throws IOException {
+        // Count all the live docs as there vectorValues.totalLiveDocs() just gives the cost for the FloatVectorValues,
+        // and doesn't tell the correct number of docs, if there are deleted docs in the segment. So we are counting
+        // the total live docs here.
+        int liveDocs = 0;
+        while (vectorValues.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
+            liveDocs++;
+        }
+        return liveDocs;
+    }
+
+    /**
+     * Creates the quantization state writer and writes its header the first time it is needed.
+     */
+    private void initQuantizationStateWriterIfNecessary() throws IOException {
+        if (quantizationStateWriter == null) {
+            quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState);
+            quantizationStateWriter.writeHeader(segmentWriteState);
+        }
+    }
+
+    /**
+     * @param docCount number of docs the data structure would be built for
+     * @return {@code true} when the threshold is negative or {@code docCount} is below the threshold
+     */
+    private boolean shouldSkipBuildingVectorDataStructure(final long docCount) {
+        if (approximateThreshold < 0) {
+            return true;
+        }
+        return docCount < approximateThreshold;
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java
index fb93bfc073..78bd25a01e 100644
--- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java
+++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java
@@ -26,6 +26,7 @@
 import org.apache.lucene.util.RamUsageEstimator;
 import org.opensearch.common.StopWatch;
 import org.opensearch.knn.index.VectorDataType;
+import org.opensearch.knn.index.codec.NativeEngineFieldVectorsWriter;
 import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter;
 import org.opensearch.knn.index.quantizationservice.QuantizationService;
 import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java
b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java index 9e210fcd93..6aa2ad9969 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java @@ -5,10 +5,11 @@ package org.opensearch.knn.index.codec; -import org.opensearch.index.codec.CodecServiceConfig; import org.apache.lucene.codecs.Codec; import org.opensearch.index.codec.CodecService; +import org.opensearch.index.codec.CodecServiceConfig; import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.remote.RemoteIndexBuilder; /** * KNNCodecService to inject the right KNNCodec version @@ -16,10 +17,12 @@ public class KNNCodecService extends CodecService { private final MapperService mapperService; + private final RemoteIndexBuilder remoteIndexBuilder; - public KNNCodecService(CodecServiceConfig codecServiceConfig) { + public KNNCodecService(CodecServiceConfig codecServiceConfig, RemoteIndexBuilder remoteIndexBuilder) { super(codecServiceConfig.getMapperService(), codecServiceConfig.getIndexSettings(), codecServiceConfig.getLogger()); mapperService = codecServiceConfig.getMapperService(); + this.remoteIndexBuilder = remoteIndexBuilder; } /** @@ -30,6 +33,6 @@ public KNNCodecService(CodecServiceConfig codecServiceConfig) { */ @Override public Codec codec(String name) { - return KNNCodecVersion.current().getKnnCodecSupplier().apply(super.codec(name), mapperService); + return KNNCodecVersion.current().getKnnCodecSupplier().apply(super.codec(name), mapperService, remoteIndexBuilder); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java index 0f03170c25..3d538009fe 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java @@ -8,16 +8,18 @@ import lombok.AllArgsConstructor; import lombok.Getter; import 
org.apache.lucene.backward_codecs.lucene91.Lucene91Codec; +import org.apache.lucene.backward_codecs.lucene912.Lucene912Codec; import org.apache.lucene.backward_codecs.lucene92.Lucene92Codec; import org.apache.lucene.backward_codecs.lucene94.Lucene94Codec; -import org.apache.lucene.codecs.Codec; import org.apache.lucene.backward_codecs.lucene95.Lucene95Codec; import org.apache.lucene.backward_codecs.lucene99.Lucene99Codec; -import org.apache.lucene.backward_codecs.lucene912.Lucene912Codec; +import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.lucene101.Lucene101Codec; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.common.TriFunction; import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.codec.KNN10010Codec.KNN10010Codec; +import org.opensearch.knn.index.codec.KNN10010Codec.KNN10010PerFieldKnnVectorsFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec; @@ -31,9 +33,9 @@ import org.opensearch.knn.index.codec.KNN950Codec.KNN950PerFieldKnnVectorsFormat; import org.opensearch.knn.index.codec.KNN990Codec.KNN990Codec; import org.opensearch.knn.index.codec.KNN990Codec.KNN990PerFieldKnnVectorsFormat; +import org.opensearch.knn.index.remote.RemoteIndexBuilder; import java.util.Optional; -import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; @@ -53,7 +55,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> new KNN910Codec(userCodec), + (userCodec, mapperService, remoteIndexBuilder) -> new KNN910Codec(userCodec), KNN910Codec::new ), @@ -65,7 +67,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new 
KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN920Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN920Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN920PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), @@ -80,7 +82,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN940Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN940Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN940PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), @@ -95,7 +97,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN950Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN950Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN950PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), @@ -110,7 +112,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN990Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN990Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN990PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), @@ -125,7 +127,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN9120Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN9120Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .mapperService(mapperService) @@ -135,14 +137,14 @@ public enum KNNCodecVersion { 
V_10_01_0( "KNN10010Codec", new Lucene101Codec(), - new KNN9120PerFieldKnnVectorsFormat(Optional.empty()), + new KNN10010PerFieldKnnVectorsFormat(Optional.empty()), (delegate) -> new KNNFormatFacade( new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN10010Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN10010Codec.builder() .delegate(userCodec) - .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) + .knnVectorsFormat(new KNN10010PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService), remoteIndexBuilder)) .mapperService(mapperService) .build(), KNN10010Codec::new @@ -154,7 +156,7 @@ public enum KNNCodecVersion { private final Codec defaultCodecDelegate; private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; private final Function knnFormatFacadeSupplier; - private final BiFunction knnCodecSupplier; + private final TriFunction knnCodecSupplier; private final Supplier defaultKnnCodecSupplier; public static final KNNCodecVersion current() { diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/NativeEngineFieldVectorsWriter.java similarity index 91% rename from src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java rename to src/main/java/org/opensearch/knn/index/codec/NativeEngineFieldVectorsWriter.java index 88eee0ee7e..204938efdb 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/NativeEngineFieldVectorsWriter.java @@ -1,15 +1,9 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source 
license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. */ -package org.opensearch.knn.index.codec.KNN990Codec; +package org.opensearch.knn.index.codec; import lombok.Getter; import org.apache.lucene.codecs.KnnFieldVectorsWriter; @@ -29,7 +23,7 @@ * * @param float[] or byte[] */ -class NativeEngineFieldVectorsWriter extends KnnFieldVectorsWriter { +public class NativeEngineFieldVectorsWriter extends KnnFieldVectorsWriter { private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngineFieldVectorsWriter.class); @Getter private final FieldInfo fieldInfo; @@ -47,7 +41,7 @@ class NativeEngineFieldVectorsWriter extends KnnFieldVectorsWriter { private final FlatFieldVectorsWriter flatFieldVectorsWriter; @SuppressWarnings("unchecked") - static NativeEngineFieldVectorsWriter create( + public static NativeEngineFieldVectorsWriter create( final FieldInfo fieldInfo, final FlatFieldVectorsWriter flatFieldVectorsWriter, final InfoStream infoStream diff --git a/src/main/java/org/opensearch/knn/index/remote/RemoteIndexBuilder.java b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexBuilder.java new file mode 100644 index 0000000000..f036b3c701 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexBuilder.java @@ -0,0 +1,154 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.remote; + +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.opensearch.common.StopWatch; +import org.opensearch.index.IndexSettings; +import org.opensearch.knn.common.featureflags.KNNFeatureFlags; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.repositories.RepositoriesService; +import org.opensearch.repositories.Repository; +import 
org.opensearch.repositories.RepositoryMissingException; +import org.opensearch.repositories.blobstore.BlobStoreRepository; + +import java.io.IOException; +import java.util.function.Supplier; + +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING; +import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_VECTOR_REPO_SETTING; + +/** + * This class orchestrates building vector indices. It handles uploading data to a repository, submitting a remote + * build request, awaiting upon the build request to complete, and finally downloading the data from a repository. + * Note: + * - This does not implement {@link org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy} as we need to use + * multiple iterators to upload the vectors in parallel. In order to do so we need to refactor NativeIndexBuildStrategy first. + * - This class is created per-index, so we cannot store per-segment information as class fields. //TODO: Validate this, it's probably created per shard still + */ +@Log4j2 +public class RemoteIndexBuilder { + + private final Supplier repositoriesServiceSupplier; + private final IndexSettings indexSettings; + private static final String VECTOR_BLOB_FILE_EXTENSION = ".knnvec"; + private static final String DOC_ID_FILE_EXTENSION = ".knndid"; + private static final String GRAPH_FILE_EXTENSION = ".knngraph"; + + /** + * Public constructor + * + * @param repositoriesServiceSupplier A supplier for {@link RepositoriesService} used for interacting with repository + * @param indexSettings + */ + public RemoteIndexBuilder(Supplier repositoriesServiceSupplier, IndexSettings indexSettings) { + this.repositoriesServiceSupplier = repositoriesServiceSupplier; + this.indexSettings = indexSettings; + } + + /** + * @return whether to use the remote build feature + */ + public boolean shouldBuildIndexRemotely() { + return KNNFeatureFlags.isKNNRemoteVectorBuildEnabled() && 
KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING.get(indexSettings.getSettings()); + } + + /** + * Gets the KNN repository container from the repository service. + * + * @return {@link RepositoriesService} + * @throws RepositoryMissingException if repository is not registered or if {@link KNN_REMOTE_VECTOR_REPO_SETTING} is not set + */ + private BlobStoreRepository getRepository() throws RepositoryMissingException { + RepositoriesService repositoriesService = repositoriesServiceSupplier.get(); + String vectorRepo = KNNSettings.state().getSettingValue(KNN_REMOTE_VECTOR_REPO_SETTING.getKey()); + if (vectorRepo == null || vectorRepo.isEmpty()) { + throw new RepositoryMissingException("Vector repository " + KNN_REMOTE_VECTOR_REPO_SETTING.getKey() + " is not registered"); + } + final Repository repository = repositoriesService.repository(vectorRepo); + assert repository instanceof BlobStoreRepository : "Repository should be instance of BlobStoreRepository"; + return (BlobStoreRepository) repository; + } + + /** + * 1. 
upload files, 2 trigger build, 3 download the graph, 4 write to indexoutput + * + * @param fieldInfo + * @param knnVectorValuesSupplier + * @param totalLiveDocs + */ + public void buildIndexRemotely( + FieldInfo fieldInfo, + Supplier> knnVectorValuesSupplier, + int totalLiveDocs, + SegmentWriteState segmentWriteState + ) throws IOException, InterruptedException { + StopWatch stopWatch; + long time_in_millis; + + stopWatch = new StopWatch().start(); + writeToRepository(fieldInfo, knnVectorValuesSupplier, totalLiveDocs, segmentWriteState); + time_in_millis = stopWatch.stop().totalTime().millis(); + log.debug("Repository write took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); + + stopWatch = new StopWatch().start(); + submitVectorBuild(); + time_in_millis = stopWatch.stop().totalTime().millis(); + log.debug("Submit vector build took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); + + stopWatch = new StopWatch().start(); + awaitVectorBuild(); + time_in_millis = stopWatch.stop().totalTime().millis(); + log.debug("Await vector build took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); + + stopWatch = new StopWatch().start(); + readFromRepository(); + time_in_millis = stopWatch.stop().totalTime().millis(); + log.debug("Repository read took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); + } + + /** + * Write relevant vector data to repository + * + * @param fieldInfo + * @param knnVectorValuesSupplier + * @param totalLiveDocs + * @param segmentWriteState + * @throws IOException + * @throws InterruptedException + */ + private void writeToRepository( + FieldInfo fieldInfo, + Supplier> knnVectorValuesSupplier, + int totalLiveDocs, + SegmentWriteState segmentWriteState + ) { + + } + + /** + * Submit vector build request to remote vector build service + * + */ + private void submitVectorBuild() {} + + /** + * Wait on remote vector build to complete + */ + private void awaitVectorBuild() { + + } 
+ + /** + * Read constructed vector file from remote repository and write to IndexOutput + */ + private void readFromRepository() { + + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 44c824862a..7582006fca 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -5,37 +5,56 @@ package org.opensearch.knn.plugin; +import com.google.common.collect.ImmutableList; +import org.opensearch.action.ActionRequest; import org.opensearch.cluster.NamedDiff; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.IndexScopedSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.settings.SettingsFilter; import org.opensearch.core.ParseField; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.env.Environment; +import org.opensearch.env.NodeEnvironment; +import org.opensearch.index.IndexModule; +import org.opensearch.index.IndexSettings; import org.opensearch.index.codec.CodecServiceFactory; import org.opensearch.index.engine.EngineFactory; +import org.opensearch.index.mapper.Mapper; import org.opensearch.indices.SystemIndexDescriptor; import org.opensearch.knn.index.KNNCircuitBreaker; -import org.opensearch.knn.index.memory.NativeMemoryCacheManager; -import org.opensearch.knn.plugin.search.KNNConcurrentSearchRequestDecider; -import org.opensearch.knn.index.util.KNNClusterUtil; -import org.opensearch.knn.index.query.KNNQueryBuilder; 
import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; - -import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser; -import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.index.codec.KNNCodecService; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; -import org.opensearch.knn.indices.ModelGraveyard; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.index.query.KNNWeight; +import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser; +import org.opensearch.knn.index.remote.RemoteIndexBuilder; +import org.opensearch.knn.index.util.KNNClusterUtil; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelGraveyard; +import org.opensearch.knn.plugin.rest.RestClearCacheHandler; import org.opensearch.knn.plugin.rest.RestDeleteModelHandler; import org.opensearch.knn.plugin.rest.RestGetModelHandler; import org.opensearch.knn.plugin.rest.RestKNNStatsHandler; import org.opensearch.knn.plugin.rest.RestKNNWarmupHandler; import org.opensearch.knn.plugin.rest.RestSearchModelHandler; import org.opensearch.knn.plugin.rest.RestTrainModelHandler; -import org.opensearch.knn.plugin.rest.RestClearCacheHandler; import org.opensearch.knn.plugin.script.KNNScoringScriptEngine; +import org.opensearch.knn.plugin.search.KNNConcurrentSearchRequestDecider; import org.opensearch.knn.plugin.stats.KNNStats; +import org.opensearch.knn.plugin.transport.ClearCacheAction; +import org.opensearch.knn.plugin.transport.ClearCacheTransportAction; import org.opensearch.knn.plugin.transport.DeleteModelAction; import org.opensearch.knn.plugin.transport.DeleteModelTransportAction; import org.opensearch.knn.plugin.transport.GetModelAction; @@ -44,27 +63,6 @@ import 
org.opensearch.knn.plugin.transport.KNNStatsTransportAction; import org.opensearch.knn.plugin.transport.KNNWarmupAction; import org.opensearch.knn.plugin.transport.KNNWarmupTransportAction; -import org.opensearch.knn.plugin.transport.ClearCacheAction; -import org.opensearch.knn.plugin.transport.ClearCacheTransportAction; -import com.google.common.collect.ImmutableList; - -import org.opensearch.action.ActionRequest; -import org.opensearch.transport.client.Client; -import org.opensearch.cluster.metadata.IndexNameExpressionResolver; -import org.opensearch.cluster.node.DiscoveryNodes; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.core.common.io.stream.NamedWriteableRegistry; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.IndexScopedSettings; -import org.opensearch.common.settings.Setting; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.settings.SettingsFilter; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.env.Environment; -import org.opensearch.env.NodeEnvironment; -import org.opensearch.index.IndexModule; -import org.opensearch.index.IndexSettings; -import org.opensearch.index.mapper.Mapper; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheTransportAction; import org.opensearch.knn.plugin.transport.SearchModelAction; @@ -76,10 +74,10 @@ import org.opensearch.knn.plugin.transport.TrainingModelAction; import org.opensearch.knn.plugin.transport.TrainingModelRequest; import org.opensearch.knn.plugin.transport.TrainingModelTransportAction; -import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; -import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardTransportAction; +import 
org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; +import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache; import org.opensearch.knn.training.TrainingJobClusterStateListener; import org.opensearch.knn.training.TrainingJobRunner; @@ -102,10 +100,11 @@ import org.opensearch.threadpool.ExecutorBuilder; import org.opensearch.threadpool.FixedExecutorBuilder; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.Client; import org.opensearch.watcher.ResourceWatcherService; -import java.util.Arrays; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -163,6 +162,7 @@ public class KNNPlugin extends Plugin private KNNStats knnStats; private ClusterService clusterService; + private Supplier repositoriesServiceSupplier; @Override public Map getMappers() { @@ -192,6 +192,7 @@ public Collection createComponents( Supplier repositoriesServiceSupplier ) { this.clusterService = clusterService; + this.repositoriesServiceSupplier = repositoriesServiceSupplier; // Initialize Native Memory loading strategies VectorReader vectorReader = new VectorReader(client); @@ -281,10 +282,14 @@ public Optional getEngineFactory(IndexSettings indexSettings) { return Optional.empty(); } + /** + * Get KNN codec service. + * We create {@link RemoteIndexBuilder} per-index in order to consume the {@link IndexSettings} from the {@link org.apache.lucene.codecs.KnnVectorsWriter} implementation. 
+ */ @Override public Optional getCustomCodecServiceFactory(IndexSettings indexSettings) { if (indexSettings.getValue(KNNSettings.IS_KNN_INDEX_SETTING)) { - return Optional.of(KNNCodecService::new); + return Optional.of((config) -> new KNNCodecService(config, new RemoteIndexBuilder(repositoriesServiceSupplier, indexSettings))); } return Optional.empty(); } diff --git a/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 0fa7314c85..bf9475ed51 100644 --- a/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -11,3 +11,4 @@ org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat org.opensearch.knn.index.codec.KNN9120Codec.KNN9120HnswBinaryVectorsFormat +org.opensearch.knn.index.codec.KNN10010Codec.NativeEngines10010KnnVectorsFormat \ No newline at end of file diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010CodecTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010CodecTests.java index 2170276136..72c0a6eb75 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010CodecTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010CodecTests.java @@ -9,7 +9,6 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.opensearch.index.mapper.MapperService; -import org.opensearch.knn.index.codec.KNN9120Codec.KNN9120PerFieldKnnVectorsFormat; import org.opensearch.knn.index.codec.KNNCodecTestCase; import org.opensearch.knn.index.codec.KNNCodecVersion; @@ -31,7 +30,7 @@ public void testBuildFromModelTemplate() { // Ensure that the codec is able to return the correct per field knn vectors format for codec public void testCodecSetsCustomPerFieldKnnVectorsFormat() { final 
Codec codec = new KNN10010Codec(); - assertTrue(codec.knnVectorsFormat() instanceof KNN9120PerFieldKnnVectorsFormat); + assertTrue(codec.knnVectorsFormat() instanceof KNN10010PerFieldKnnVectorsFormat); } // IMPORTANT: When this Codec is moved to a backwards Codec, this test needs to be removed, because it attempts to @@ -39,7 +38,7 @@ public void testCodecSetsCustomPerFieldKnnVectorsFormat() { @SneakyThrows public void testKnnVectorIndex() { Function perFieldKnnVectorsFormatProvider = ( - mapperService) -> new KNN9120PerFieldKnnVectorsFormat(Optional.of(mapperService)); + mapperService) -> new KNN10010PerFieldKnnVectorsFormat(Optional.of(mapperService)); Function knnCodecProvider = (knnVectorFormat) -> KNN10010Codec.builder() .delegate(KNNCodecVersion.V_10_01_0.getDefaultCodecDelegate()) diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/NativeEngines10010KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/NativeEngines10010KnnVectorsFormatTests.java new file mode 100644 index 0000000000..ed7339324f --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/NativeEngines10010KnnVectorsFormatTests.java @@ -0,0 +1,377 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.knn.index.codec.KNN10010Codec; + +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.SerialMergeScheduler; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Sort; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.store.BaseDirectoryWrapper; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.Version; +import org.junit.After; +import 
org.junit.Assert; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.stubbing.Answer; +import org.opensearch.common.lucene.Lucene; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsReader; +import org.opensearch.knn.index.codec.KNN990Codec.UnitTestCodec; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.qframe.QuantizationConfig; +import org.opensearch.knn.index.engine.qframe.QuantizationConfigParser; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; + +@Log4j2 +public class NativeEngines10010KnnVectorsFormatTests extends KNNTestCase { + private static final Codec TESTING_CODEC = new UnitTestCodec(); + private static final String FLAT_VECTOR_FILE_EXT = ".vec"; + private static final String FAISS_ENGINE_FILE_EXT = ".faiss"; + private static final String FLOAT_VECTOR_FIELD = "float_field"; + private static final String FLOAT_VECTOR_FIELD_BINARY = "float_binary_field"; + private static final String BYTE_VECTOR_FIELD = "byte_field"; + private Directory dir; + private RandomIndexWriter indexWriter; + + @After + public void tearDown() throws Exception { + if (dir != null) { + dir.close(); + } + super.tearDown(); + } + + @SneakyThrows + public void testReaderAndWriter_whenValidInput_thenSuccess() { + final Lucene99FlatVectorsFormat mockedFlatVectorsFormat = 
Mockito.mock(Lucene99FlatVectorsFormat.class); + + final String segmentName = "test-segment-name"; + + final SegmentInfo mockedSegmentInfo = new SegmentInfo( + Mockito.mock(Directory.class), + Mockito.mock(Version.class), + Mockito.mock(Version.class), + segmentName, + 0, + false, + false, + Mockito.mock(Codec.class), + Mockito.mock(Map.class), + new byte[16], + Mockito.mock(Map.class), + Mockito.mock(Sort.class) + ); + + final String segmentSuffix = "test-segment-suffix"; + + Directory directory = Mockito.mock(Directory.class); + IndexInput input = Mockito.mock(IndexInput.class); + Mockito.when(directory.openInput(any(), any())).thenReturn(input); + + String fieldName = "test-field"; + FieldInfos fieldInfos = Mockito.mock(FieldInfos.class); + FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); + Mockito.when(fieldInfo.getName()).thenReturn(fieldName); + Mockito.when(fieldInfos.fieldInfo(anyInt())).thenReturn(fieldInfo); + Mockito.when(fieldInfos.iterator()).thenReturn(new Iterator() { + @Override + public boolean hasNext() { + return false; + } + + @Override + public FieldInfo next() { + return null; + } + }); + + final SegmentReadState mockedSegmentReadState = new SegmentReadState( + directory, + mockedSegmentInfo, + fieldInfos, + Mockito.mock(IOContext.class), + segmentSuffix + ); + + final SegmentWriteState mockedSegmentWriteState = new SegmentWriteState( + Mockito.mock(InfoStream.class), + Mockito.mock(Directory.class), + mockedSegmentInfo, + Mockito.mock(FieldInfos.class), + null, + Mockito.mock(IOContext.class) + ); + Mockito.when(mockedFlatVectorsFormat.fieldsReader(mockedSegmentReadState)).thenReturn(Mockito.mock(FlatVectorsReader.class)); + Mockito.when(mockedFlatVectorsFormat.fieldsWriter(mockedSegmentWriteState)).thenReturn(Mockito.mock(FlatVectorsWriter.class)); + + final NativeEngines10010KnnVectorsFormat nativeEngines10010KnnVectorsFormat = new NativeEngines10010KnnVectorsFormat( + mockedFlatVectorsFormat + ); + try (MockedStatic 
mockedStaticCodecUtil = Mockito.mockStatic(CodecUtil.class)) { + mockedStaticCodecUtil.when( + () -> CodecUtil.writeIndexHeader(any(IndexOutput.class), anyString(), anyInt(), any(byte[].class), anyString()) + ).thenAnswer((Answer) invocation -> null); + mockedStaticCodecUtil.when(() -> CodecUtil.retrieveChecksum(any(IndexInput.class))) + .thenAnswer((Answer) invocation -> null); + Assert.assertTrue( + nativeEngines10010KnnVectorsFormat.fieldsReader(mockedSegmentReadState) instanceof NativeEngines990KnnVectorsReader + ); + + Assert.assertTrue( + nativeEngines10010KnnVectorsFormat.fieldsWriter(mockedSegmentWriteState) instanceof NativeEngines10010KnnVectorsWriter + ); + } + } + + @SneakyThrows + public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSuccess() { + setup(); + float[] floatVector = { 1.0f, 3.0f, 4.0f }; + byte[] byteVector = { 6, 14 }; + + FieldType fieldTypeForFloat = createVectorField(3, VectorEncoding.FLOAT32, VectorDataType.FLOAT); + fieldTypeForFloat.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"HNSW16,Flat\", \"spaceType\": \"l2\"}"); + fieldTypeForFloat.freeze(); + addFieldToIndex(new KnnFloatVectorField(FLOAT_VECTOR_FIELD, floatVector, fieldTypeForFloat), indexWriter); + FieldType fieldTypeForByte = createVectorField(2, VectorEncoding.BYTE, VectorDataType.BINARY); + fieldTypeForByte.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"HNSW16,Flat\", \"spaceType\": \"l2\"}"); + fieldTypeForByte.freeze(); + addFieldToIndex(new KnnByteVectorField(BYTE_VECTOR_FIELD, byteVector, fieldTypeForByte), indexWriter); + + float[] floatVectorForBinaryQuantization_1 = { 1.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f }; + float[] floatVectorForBinaryQuantization_2 = { 1.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f }; + FieldType fieldTypeForBinaryQuantization = createVectorField(8, VectorEncoding.FLOAT32, VectorDataType.FLOAT); + fieldTypeForBinaryQuantization.putAttribute(KNNConstants.PARAMETERS, "{ 
\"index_description\":\"BHNSW32\", \"spaceType\": \"l2\"}"); + QuantizationConfig quantizationConfig = QuantizationConfig.builder().quantizationType(ScalarQuantizationType.ONE_BIT).build(); + fieldTypeForBinaryQuantization.putAttribute(KNNConstants.QFRAMEWORK_CONFIG, QuantizationConfigParser.toCsv(quantizationConfig)); + fieldTypeForBinaryQuantization.freeze(); + + addFieldToIndex( + new KnnFloatVectorField(FLOAT_VECTOR_FIELD_BINARY, floatVectorForBinaryQuantization_1, fieldTypeForBinaryQuantization), + indexWriter + ); + addFieldToIndex( + new KnnFloatVectorField(FLOAT_VECTOR_FIELD_BINARY, floatVectorForBinaryQuantization_2, fieldTypeForBinaryQuantization), + indexWriter + ); + + final IndexReader indexReader = indexWriter.getReader(); + // ensuring segments are created + indexWriter.flush(); + indexWriter.commit(); + indexWriter.close(); + + // Validate to see if correct values are returned, assumption here is only 1 segment is getting created + IndexSearcher searcher = new IndexSearcher(indexReader); + final LeafReader leafReader = searcher.getLeafContexts().get(0).reader(); + SegmentReader segmentReader = Lucene.segmentReader(leafReader); + final List hnswfiles = getFilesFromSegment(dir, FAISS_ENGINE_FILE_EXT); + assertEquals(3, hnswfiles.size()); + assertEquals(hnswfiles.stream().filter(x -> x.contains(FLOAT_VECTOR_FIELD)).count(), 1); + assertEquals(hnswfiles.stream().filter(x -> x.contains(BYTE_VECTOR_FIELD)).count(), 1); + assertEquals(hnswfiles.stream().filter(x -> x.contains(FLOAT_VECTOR_FIELD_BINARY)).count(), 1); + + // Even setting IWC to not use compound file it still uses compound file, hence ensuring we don't check .vec + // file in case segment uses compound format. use this seed once we fix this to validate everything is + // working or not. 
-Dtests.seed=CAAE1B8D573EEB7E + if (segmentReader.getSegmentInfo().info.getUseCompoundFile() == false) { + final List vecfiles = getFilesFromSegment(dir, FLAT_VECTOR_FILE_EXT); + // 2 .vec files will be created as we are using per field vectors format. + assertEquals(3, vecfiles.size()); + } + + final FloatVectorValues floatVectorValues = leafReader.getFloatVectorValues(FLOAT_VECTOR_FIELD); + floatVectorValues.iterator().nextDoc(); + assertArrayEquals(floatVector, floatVectorValues.vectorValue(floatVectorValues.iterator().index()), 0.0f); + assertEquals(1, floatVectorValues.size()); + assertEquals(3, floatVectorValues.dimension()); + + final ByteVectorValues byteVectorValues = leafReader.getByteVectorValues(BYTE_VECTOR_FIELD); + byteVectorValues.iterator().nextDoc(); + assertArrayEquals(byteVector, byteVectorValues.vectorValue(byteVectorValues.iterator().index())); + assertEquals(1, byteVectorValues.size()); + assertEquals(2, byteVectorValues.dimension()); + + final FloatVectorValues floatVectorValuesForBinaryQuantization = leafReader.getFloatVectorValues(FLOAT_VECTOR_FIELD_BINARY); + floatVectorValuesForBinaryQuantization.iterator().nextDoc(); + assertArrayEquals( + floatVectorForBinaryQuantization_1, + floatVectorValuesForBinaryQuantization.vectorValue(floatVectorValuesForBinaryQuantization.iterator().index()), + 0.0f + ); + assertEquals(2, floatVectorValuesForBinaryQuantization.size()); + assertEquals(8, floatVectorValuesForBinaryQuantization.dimension()); + + Assert.assertThrows( + UnsupportedOperationException.class, + () -> leafReader.searchNearestVectors(FLOAT_VECTOR_FIELD, floatVector, 10, new Bits.MatchAllBits(1), 10) + ); + + Assert.assertThrows( + UnsupportedOperationException.class, + () -> leafReader.searchNearestVectors(BYTE_VECTOR_FIELD, byteVector, 10, new Bits.MatchAllBits(1), 10) + ); + // do it at the end so that all search is completed + indexReader.close(); + } + + @SneakyThrows + public void 
testNativeEngineVectorFormat_whenBinaryQuantizationApplied_thenSuccess() { /* Round trip: index a single 8-dimensional float vector on a field configured for ONE_BIT quantization, then read it back through the leaf reader. */ + setup(); + float[] floatVectorForBinaryQuantization = { 1.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f }; + FieldType fieldTypeForBinaryQuantization = createVectorField(8, VectorEncoding.FLOAT32, VectorDataType.FLOAT); + fieldTypeForBinaryQuantization.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"BHNSW32\", \"spaceType\": \"l2\"}"); + QuantizationConfig quantizationConfig = QuantizationConfig.builder().quantizationType(ScalarQuantizationType.ONE_BIT).build(); + fieldTypeForBinaryQuantization.putAttribute(KNNConstants.QFRAMEWORK_CONFIG, QuantizationConfigParser.toCsv(quantizationConfig)); + + addFieldToIndex( + new KnnFloatVectorField(FLOAT_VECTOR_FIELD_BINARY, floatVectorForBinaryQuantization, fieldTypeForBinaryQuantization), + indexWriter + ); + + final IndexReader indexReader = indexWriter.getReader(); + // ensuring segments are created + indexWriter.flush(); + indexWriter.commit(); + indexWriter.close(); + + IndexSearcher searcher = new IndexSearcher(indexReader); + final LeafReader leafReader = searcher.getLeafContexts().get(0).reader(); + SegmentReader segmentReader = Lucene.segmentReader(leafReader); + if (segmentReader.getSegmentInfo().info.getUseCompoundFile() == false) { + final List vecfiles = getFilesFromSegment(dir, FLAT_VECTOR_FILE_EXT); + // 1 .vec file is expected because this test indexes a single vector field (we are using per field vectors format).
+ assertEquals(1, vecfiles.size()); + } + + final FloatVectorValues floatVectorValues = leafReader.getFloatVectorValues(FLOAT_VECTOR_FIELD_BINARY); + KnnVectorValues.DocIndexIterator docIndexIterator = floatVectorValues.iterator(); + docIndexIterator.nextDoc(); + assertArrayEquals(floatVectorForBinaryQuantization, floatVectorValues.vectorValue(docIndexIterator.index()), 0.0f); + assertEquals(1, floatVectorValues.size()); + assertEquals(8, floatVectorValues.dimension()); + indexReader.close(); + } + + /** + * Checks that getName() returns "NativeEngines10010KnnVectorsFormat" both for the no-arg constructor and for the + * constructor that takes a Lucene99FlatVectorsFormat. + */ public void testFormatName_withValidInput_thenSuccess() { + final String validFormatName = "NativeEngines10010KnnVectorsFormat"; + Assert.assertEquals(validFormatName, new NativeEngines10010KnnVectorsFormat().getName()); + Assert.assertEquals( + validFormatName, + new NativeEngines10010KnnVectorsFormat(new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())) + .getName() + ); + } + + /** + * Lists the entries of {@code dir} whose name contains {@code fileFormat}. The match is a substring match + * (String#contains), not a suffix match. + * @param dir directory whose entries are listed via listAll() + * @param fileFormat text that a file name must contain to be returned + * @return the matching file names + * @throws IOException if listing the directory fails + */ private List getFilesFromSegment(Directory dir, String fileFormat) throws IOException { + return Arrays.stream(dir.listAll()).filter(x -> x.contains(fileFormat)).collect(Collectors.toList()); + } + + /** + * Creates the temp-dir backed directory and the index writer shared by the tests; each test calls this explicitly. + * This should have been annotated with @Before, but somehow when I annotate with @Before apart from running + * before tests, it is also running independently and failing. Need to figure this out. + * @throws java.io.IOException if creating the directory or the index writer fails + */ + private void setup() throws IOException { + dir = newFSDirectory(createTempDir()); + // on the mock directory Lucene goes ahead and does a search on different fields. We want to avoid that as of + // now. Given we have not implemented search for the native engine format using codec, the dir.close fails + // with exception. Hence, marking this as false.
+ ((BaseDirectoryWrapper) dir).setCheckIndexOnClose(false); + indexWriter = createIndexWriter(dir); + } + + /** + * Builds a RandomIndexWriter over {@code dir} configured with TESTING_CODEC, a SerialMergeScheduler, + * compound files switched off and NoMergePolicy. + * @param dir directory the writer writes to + * @return the configured writer + * @throws IOException if the writer cannot be created + */ private RandomIndexWriter createIndexWriter(final Directory dir) throws IOException { + final IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setMergeScheduler(new SerialMergeScheduler()); + iwc.setCodec(TESTING_CODEC); + iwc.setUseCompoundFile(false); + // Set merge policy to no merges so that we create a predictable number of segments. + iwc.setMergePolicy(NoMergePolicy.INSTANCE); + return new RandomIndexWriter(random(), dir, iwc); + } + + /** + * Wraps {@code vectorField} in a new Document and adds it through {@code indexWriter}, so each call indexes one document. + * @param vectorField field to index + * @param indexWriter writer that receives the document + * @throws IOException if adding the document fails + */ private void addFieldToIndex(final Field vectorField, final RandomIndexWriter indexWriter) throws IOException { + final Document doc1 = new Document(); + doc1.add(vectorField); + indexWriter.addDocument(doc1); + } + + /** + * Creates a non-tokenized FieldType with IndexOptions.NONE that carries the KNN attributes set below + * (KNN_FIELD "true", METHOD_HNSW, FAISS engine, L2 space type, HNSW_ALGO_M "32", HNSW_ALGO_EF_CONSTRUCTION "512", + * the given data type) and vector attributes built from the given dimension and encoding. + * @param dimension vector dimension passed to setVectorAttributes + * @param vectorEncoding Lucene encoding passed to setVectorAttributes + * @param vectorDataType data type stored under VECTOR_DATA_TYPE_FIELD + * @return the populated field type; it is not frozen here + */ private FieldType createVectorField(int dimension, VectorEncoding vectorEncoding, VectorDataType vectorDataType) { + FieldType nativeVectorField = new FieldType(); + // TODO: Replace this with the default field which will be created in mapper for Native Engines with KNNVectorsFormat + nativeVectorField.setTokenized(false); + nativeVectorField.setIndexOptions(IndexOptions.NONE); + nativeVectorField.putAttribute(KNNVectorFieldMapper.KNN_FIELD, "true"); + nativeVectorField.putAttribute(KNNConstants.KNN_METHOD, KNNConstants.METHOD_HNSW); + nativeVectorField.putAttribute(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()); + nativeVectorField.putAttribute(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()); + nativeVectorField.putAttribute(KNNConstants.HNSW_ALGO_M, "32"); + nativeVectorField.putAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512"); + nativeVectorField.putAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); + nativeVectorField.setVectorAttributes( + dimension, + vectorEncoding, + SpaceType.L2.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() + ); + return nativeVectorField; + } +} + diff --git
a/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/NativeEngines10010KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/NativeEngines10010KnnVectorsWriterFlushTests.java new file mode 100644 index 0000000000..b89e2840dc --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/NativeEngines10010KnnVectorsWriterFlushTests.java @@ -0,0 +1,859 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN10010Codec; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorEncoding; +import org.mockito.Mock; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.KNN990Codec.KNN990QuantizationStateWriter; +import org.opensearch.knn.index.codec.NativeEngineFieldVectorsWriter; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import 
org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +@RequiredArgsConstructor +public class NativeEngines10010KnnVectorsWriterFlushTests extends OpenSearchTestCase { + + @Mock + private FlatVectorsWriter flatVectorsWriter; + @Mock + private SegmentWriteState segmentWriteState; + @Mock + private QuantizationParams quantizationParams; + @Mock + private QuantizationState quantizationState; + @Mock + private QuantizationService quantizationService; + @Mock + private NativeIndexWriter nativeIndexWriter; + + private FlatFieldVectorsWriter mockedFlatFieldVectorsWriter; + + private NativeEngines10010KnnVectorsWriter objectUnderTest; + + private final String description; + private final List> vectorsPerField; + private static final Integer BUILD_GRAPH_ALWAYS_THRESHOLD = 0; + private static final Integer BUILD_GRAPH_NEVER_THRESHOLD = -1; + + @Override + public void setUp() throws Exception { + super.setUp(); + 
/* setUp continued: initialise the @Mock fields, build the writer under test with BUILD_GRAPH_ALWAYS_THRESHOLD, and make flatVectorsWriter.addField return a mocked FlatFieldVectorsWriter. */ MockitoAnnotations.openMocks(this); + objectUnderTest = new NativeEngines10010KnnVectorsWriter(segmentWriteState, flatVectorsWriter, BUILD_GRAPH_ALWAYS_THRESHOLD); + mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); + Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(Mockito.anyInt(), Mockito.any()); + Mockito.when(flatVectorsWriter.addField(Mockito.any())).thenReturn(mockedFlatFieldVectorsWriter); + } + + /** + * Supplies three parameter sets, each a description plus a list holding one integer-keyed vector map per field: + * a single field with three vectors, an empty field list, and three fields of which the second is empty. + * NOTE(review): the integer keys are presumably doc ids - confirm against nativeEngineFieldVectorsWriter. + */ @ParametersFactory + public static Collection data() { + return Arrays.asList( + $$( + $("Single field", List.of(Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 }))), + $("Single field, no total live docs", List.of()), + $( + "Multi Field", + List.of( + Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 }), + Collections.emptyMap(), + Map.of( + 0, + new float[] { 1, 2, 3, 4 }, + 1, + new float[] { 2, 3, 4, 5 }, + 2, + new float[] { 3, 4, 5, 6 }, + 3, + new float[] { 4, 5, 6, 7 } + ) + ) + ) + ) + ); + } + + /** + * Flushes the writer built in setUp (BUILD_GRAPH_ALWAYS_THRESHOLD) with getQuantizationParams stubbed to return null. + * Expects flatVectorsWriter.flush(5, null), no KNN990QuantizationStateWriter construction and a non-zero + * REFRESH_TOTAL_TIME_IN_MILLIS when at least one field exists, flushIndex for every non-empty field and never for an + * empty one, and one KNNVectorValuesFactory.getVectorValues call per non-empty field. + */ @SneakyThrows + public void testFlush() { + // Given + final List> expectedVectorValues = vectorsPerField.stream().map(vectors -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + return knnVectorValues; + }).collect(Collectors.toList()); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( +
KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + + try { + objectUnderTest.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + objectUnderTest.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + assertNotEquals(0L, (long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue()); + } + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).isEmpty()) { + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), 
vectorsPerField.get(i).size()); + } else { + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) + ); + } + } + + /** + * Flushes with quantization in play: getQuantizationParams returns the mocked params and train returns the mocked state. + * Expects writeHeader(segmentWriteState) on the first constructed KNN990QuantizationStateWriter when at least one field + * exists, writeState and flushIndex for each non-empty field (never for an empty one), and two + * KNNVectorValuesFactory.getVectorValues calls per non-empty field. + */ @SneakyThrows + public void testFlush_WithQuantization() { + // Given + List> expectedVectorValues = new ArrayList<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + expectedVectorValues.add(knnVectorValues); + + }); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field =
nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + + try { + objectUnderTest.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) + .thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + }); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + objectUnderTest.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + assertTrue(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).isEmpty()) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState); + verify(nativeIndexWriter, 
never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } else { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled) * 2) + ); + } + } + + /** + * Flushes a writer created with BUILD_GRAPH_NEVER_THRESHOLD (-1) and no quantization params. + * Expects flatVectorsWriter.flush(5, null), no KNN990QuantizationStateWriter construction when at least one field + * exists, and no interaction at all with nativeIndexWriter. + * @throws IOException declared by the signature; propagated from the writer calls + */ public void testFlush_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + expectedVectorValues.add(knnVectorValues); + + }); + + final NativeEngines10010KnnVectorsWriter nativeEngineWriter = new NativeEngines10010KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + BUILD_GRAPH_NEVER_THRESHOLD + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { +
quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + verifyNoInteractions(nativeIndexWriter); + } + } + + /** + * Flushes a writer whose threshold is one more than the largest non-empty field size (or 1 when every field is + * empty), with no quantization params. Expects flatVectorsWriter.flush(5, null), no KNN990QuantizationStateWriter + * construction when at least one field exists, and no interaction at all with nativeIndexWriter. + * @throws IOException declared by the signature; propagated from the writer calls + */ public void testFlush_whenThresholdIsGreaterThanVectorSize_thenNativeIndexWriterIsNeverCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0,
vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + final int maxThreshold = sizeMap.values().stream().filter(count -> count != 0).max(Integer::compareTo).orElse(0); + final NativeEngines10010KnnVectorsWriter nativeEngineWriter = new NativeEngines10010KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + maxThreshold + 1 + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet 
= field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + verifyNoInteractions(nativeIndexWriter); + } + } + + /** + * Flushes a writer whose threshold equals the smallest non-empty field size (0 when every field is empty), with no + * quantization params. Expects flatVectorsWriter.flush(5, null), no KNN990QuantizationStateWriter construction and a + * positive REFRESH_TOTAL_TIME_IN_MILLIS when at least one field exists, and flushIndex for every non-empty field but + * never for an empty one. + * @throws IOException declared by the signature; propagated from the writer calls + */ public void testFlush_whenThresholdIsEqualToMinNumberOfVectors_thenNativeIndexWriterIsCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + + final int minThreshold = sizeMap.values().stream().filter(count -> count != 0).min(Integer::compareTo).orElse(0); + final NativeEngines10010KnnVectorsWriter nativeEngineWriter = new NativeEngines10010KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + minThreshold + ); + + try ( + MockedStatic fieldWriterMockedStatic =
mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, 
null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + assertTrue((long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0); + } + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).size() > 0) { + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } else { + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + } + + /** + * Flushes a writer created with a fixed threshold of 4 and no quantization params. Expects + * flatVectorsWriter.flush(5, null), no KNN990QuantizationStateWriter construction when at least one field exists, + * and flushIndex only for fields holding at least 4 vectors (never for smaller ones). + * @throws IOException declared by the signature; propagated from the writer calls + */ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWriterIsCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + expectedVectorValues.add(knnVectorValues); + + }); + final int threshold = 4; + final NativeEngines10010KnnVectorsWriter nativeEngineWriter = new NativeEngines10010KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + threshold + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() ->
QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).size() >= threshold) { + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } else { + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } + } catch (Exception 
e) { + throw new RuntimeException(e); + } + }); + } + } + + /** + * Flushes a writer whose threshold is one more than the largest non-empty field size while quantization params are + * supplied. Expects flatVectorsWriter.flush(5, null), writeHeader(segmentWriteState) on the first constructed + * KNN990QuantizationStateWriter when at least one field exists, writeState for each non-empty field (never for an + * empty one), and two KNNVectorValuesFactory.getVectorValues calls per non-empty field. + * NOTE(review): despite the "thenStillBuildGraph" name, no assertion here verifies nativeIndexWriter.flushIndex - confirm + * whether that check is intended. + * @throws IOException declared by the signature; propagated from the writer calls + */ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThresholdIsNotMet_thenStillBuildGraph() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + final int maxThreshold = sizeMap.values().stream().filter(count -> count != 0).max(Integer::compareTo).orElse(0); + final NativeEngines10010KnnVectorsWriter nativeEngineWriter = new NativeEngines10010KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + maxThreshold + 1 // to avoid building graph using max doc threshold, the same can be achieved by -1 too + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + +
NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) + .thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + }); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).isEmpty()) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState); + } else { + 
verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled) * 2) + ); + } + } + + public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThresholdIsNegative_thenStillBuildGraph() + throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + final NativeEngines10010KnnVectorsWriter nativeEngineWriter = new NativeEngines10010KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + BUILD_GRAPH_NEVER_THRESHOLD + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> 
QuantizationService.getInstance()).thenReturn(quantizationService); + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) + .thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + }); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + 
IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).isEmpty()) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState); + } else { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled) * 2) + ); + } + } + + private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map attributes) { + FieldInfo fieldInfo = mock(FieldInfo.class); + when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); + when(fieldInfo.getVectorEncoding()).thenReturn(vectorEncoding); + when(fieldInfo.attributes()).thenReturn(attributes); + attributes.forEach((key, value) -> when(fieldInfo.getAttribute(key)).thenReturn(value)); + return fieldInfo; + } + + private NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map vectors) { + NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class); + FlatFieldVectorsWriter flatFieldVectorsWriter = mock(FlatFieldVectorsWriter.class); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add); + when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo); + when(fieldVectorsWriter.getVectors()).thenReturn(vectors); + when(fieldVectorsWriter.getFlatFieldVectorsWriter()).thenReturn(flatFieldVectorsWriter); + when(flatFieldVectorsWriter.getDocsWithFieldSet()).thenReturn(docsWithFieldSet); + return fieldVectorsWriter; + } +} diff --git 
a/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/NativeEngines10010KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/NativeEngines10010KnnVectorsWriterMergeTests.java new file mode 100644 index 0000000000..564eb733ef --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN10010Codec/NativeEngines10010KnnVectorsWriterMergeTests.java @@ -0,0 +1,384 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN10010Codec; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorEncoding; +import org.mockito.Mock; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.KNN990Codec.KNN990QuantizationStateWriter; +import org.opensearch.knn.index.codec.NativeEngineFieldVectorsWriter; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import org.opensearch.knn.plugin.stats.KNNGraphValue; +import 
org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +@RequiredArgsConstructor +public class NativeEngines10010KnnVectorsWriterMergeTests extends OpenSearchTestCase { + + @Mock + private FlatVectorsWriter flatVectorsWriter; + @Mock + private SegmentWriteState segmentWriteState; + @Mock + private QuantizationParams quantizationParams; + @Mock + private QuantizationState quantizationState; + @Mock + private QuantizationService quantizationService; + @Mock + private NativeIndexWriter nativeIndexWriter; + @Mock + private FloatVectorValues floatVectorValues; + @Mock + private MergeState mergeState; + + private NativeEngines10010KnnVectorsWriter objectUnderTest; + + private final String description; + private final Map mergedVectors; + private FlatFieldVectorsWriter mockedFlatFieldVectorsWriter; + private static final Integer BUILD_GRAPH_ALWAYS_THRESHOLD = 0; + private static final Integer BUILD_GRAPH_NEVER_THRESHOLD = -1; + + @Override + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.openMocks(this); + objectUnderTest = new 
NativeEngines10010KnnVectorsWriter(segmentWriteState, flatVectorsWriter, BUILD_GRAPH_ALWAYS_THRESHOLD); + mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); + Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(Mockito.anyInt(), Mockito.any()); + Mockito.when(flatVectorsWriter.addField(Mockito.any())).thenReturn(mockedFlatFieldVectorsWriter); + } + + @ParametersFactory + public static Collection data() { + return Arrays.asList( + $$( + $("Merge one field", Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 })), + $("Merge, no live docs", Map.of()) + ) + ); + } + + @SneakyThrows + public void testMerge() { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedStatic mergedVectorValuesMockedStatic = mockStatic( + KnnVectorsWriter.MergedVectorValues.class + ); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + 
fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + + mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) + .thenReturn(floatVectorValues); + knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) + .thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + objectUnderTest.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + if (!mergedVectors.isEmpty()) { + verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); + assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues), + times(2) + ); + } else { + verifyNoInteractions(nativeIndexWriter); + } + } + } + + public void testMerge_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled() throws IOException { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + final NativeEngines10010KnnVectorsWriter 
nativeEngineWriter = new NativeEngines10010KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + BUILD_GRAPH_NEVER_THRESHOLD + ); + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedStatic mergedVectorValuesMockedStatic = mockStatic( + KnnVectorsWriter.MergedVectorValues.class + ); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + + mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) + .thenReturn(floatVectorValues); + knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) + .thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + 
}).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + nativeEngineWriter.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + verifyNoInteractions(nativeIndexWriter); + } + } + + public void testMerge_whenThresholdIsEqualToNumberOfVectors_thenNativeIndexWriterIsCalled() throws IOException { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + final NativeEngines10010KnnVectorsWriter nativeEngineWriter = new NativeEngines10010KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + mergedVectors.size() + ); + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedStatic mergedVectorValuesMockedStatic = mockStatic( + KnnVectorsWriter.MergedVectorValues.class + ); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, 
mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + + mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) + .thenReturn(floatVectorValues); + knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) + .thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + nativeEngineWriter.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + if (!mergedVectors.isEmpty()) { + verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); + } else { + verifyNoInteractions(nativeIndexWriter); + } + } + } + + @SneakyThrows + public void testMerge_WithQuantization() { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction 
knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + MockedStatic mergedVectorValuesMockedStatic = mockStatic( + KnnVectorsWriter.MergedVectorValues.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + + mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) + .thenReturn(floatVectorValues); + knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) + .thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size())).thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + objectUnderTest.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + if (!mergedVectors.isEmpty()) { + 
verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(0, quantizationState); + verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); + assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues), + times(3) + ); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + verifyNoInteractions(nativeIndexWriter); + } + + } + } + + private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map attributes) { + FieldInfo fieldInfo = mock(FieldInfo.class); + when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); + when(fieldInfo.getVectorEncoding()).thenReturn(vectorEncoding); + when(fieldInfo.attributes()).thenReturn(attributes); + attributes.forEach((key, value) -> when(fieldInfo.getAttribute(key)).thenReturn(value)); + return fieldInfo; + } + + private NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map vectors) { + NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class); + FlatFieldVectorsWriter flatFieldVectorsWriter = mock(FlatFieldVectorsWriter.class); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add); + when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo); + when(fieldVectorsWriter.getVectors()).thenReturn(vectors); + when(fieldVectorsWriter.getFlatFieldVectorsWriter()).thenReturn(flatFieldVectorsWriter); + when(flatFieldVectorsWriter.getDocsWithFieldSet()).thenReturn(docsWithFieldSet); + return fieldVectorsWriter; + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java 
b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java index 707ebb2a69..bbfe4ca3b1 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java @@ -20,6 +20,7 @@ import org.junit.Assert; import org.mockito.Mockito; import org.opensearch.knn.index.codec.KNNCodecTestCase; +import org.opensearch.knn.index.codec.NativeEngineFieldVectorsWriter; public class NativeEngineFieldVectorsWriterTests extends KNNCodecTestCase { diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index f87ed6bcf6..7a222156ba 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -21,6 +21,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.NativeEngineFieldVectorsWriter; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java index cdc372bda2..d30e6a4bbc 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java +++ 
b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -24,6 +24,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.NativeEngineFieldVectorsWriter; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java index dfe4e7f22a..7b17899c5c 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java @@ -15,6 +15,7 @@ import org.opensearch.knn.KNNTestCase; import org.apache.logging.log4j.Logger; +import org.opensearch.knn.index.remote.RemoteIndexBuilder; import java.util.UUID; @@ -46,7 +47,7 @@ public void testGetCodecByName() { MapperService mapperService = mock(MapperService.class); Logger loggerMock = mock(Logger.class); CodecServiceConfig codecServiceConfig = new CodecServiceConfig(indexSettings, mapperService, loggerMock); - KNNCodecService knnCodecService = new KNNCodecService(codecServiceConfig); + KNNCodecService knnCodecService = new KNNCodecService(codecServiceConfig, mock(RemoteIndexBuilder.class)); Codec codec = knnCodecService.codec(KNNCodecVersion.current().getCodecName()); assertNotNull(codec); } @@ -61,7 +62,7 @@ public void testGetCodecByName() { public void testGetCodecByNameWithNoMapperService() { Logger loggerMock = mock(Logger.class); CodecServiceConfig codecServiceConfig = new CodecServiceConfig(indexSettings, null, loggerMock); - KNNCodecService knnCodecService = new KNNCodecService(codecServiceConfig); + KNNCodecService knnCodecService = new 
KNNCodecService(codecServiceConfig, mock(RemoteIndexBuilder.class)); Codec codec = knnCodecService.codec(KNNCodecVersion.current().getCodecName()); assertNotNull(codec); }