diff --git a/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java b/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java index bab5b97bb..f2b411278 100644 --- a/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java +++ b/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java @@ -26,6 +26,7 @@ public class KNNFeatureFlags { // Feature flags private static final String KNN_FORCE_EVICT_CACHE_ENABLED = "knn.feature.cache.force_evict.enabled"; + private static final String KNN_REMOTE_VECTOR_BUILD = "knn.feature.remote_index_build.enabled"; @VisibleForTesting public static final Setting KNN_FORCE_EVICT_CACHE_ENABLED_SETTING = Setting.boolSetting( @@ -35,8 +36,15 @@ public class KNNFeatureFlags { Dynamic ); + public static final Setting KNN_REMOTE_VECTOR_BUILD_SETTING = Setting.boolSetting( + KNN_REMOTE_VECTOR_BUILD, + false, + NodeScope, + Dynamic + ); + public static List> getFeatureFlags() { - return ImmutableList.of(KNN_FORCE_EVICT_CACHE_ENABLED_SETTING); + return ImmutableList.of(KNN_FORCE_EVICT_CACHE_ENABLED_SETTING, KNN_REMOTE_VECTOR_BUILD_SETTING); } /** @@ -46,4 +54,11 @@ public static List> getFeatureFlags() { public static boolean isForceEvictCacheEnabled() { return Booleans.parseBoolean(KNNSettings.state().getSettingValue(KNN_FORCE_EVICT_CACHE_ENABLED).toString(), false); } + + /** + * @return true if remote vector index build feature flag is enabled + */ + public static boolean isKNNRemoteVectorBuildEnabled() { + return Booleans.parseBooleanStrict(KNNSettings.state().getSettingValue(KNN_REMOTE_VECTOR_BUILD).toString(), false); + } } diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index ebb55ea2b..02bc11338 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -41,9 +41,9 @@ import static java.util.stream.Collectors.toUnmodifiableMap; import static org.opensearch.common.settings.Setting.Property.Dynamic; +import static org.opensearch.common.settings.Setting.Property.Final; import static org.opensearch.common.settings.Setting.Property.IndexScope; import static org.opensearch.common.settings.Setting.Property.NodeScope; -import static org.opensearch.common.settings.Setting.Property.Final; import static org.opensearch.common.settings.Setting.Property.UnmodifiableOnRestore; import static org.opensearch.common.unit.MemorySizeValue.parseBytesSizeValueOrHeapRatio; import static org.opensearch.core.common.unit.ByteSizeValue.parseBytesSizeValue; @@ -94,6 +94,8 @@ public class KNNSettings { public static final String KNN_FAISS_AVX512_SPR_DISABLED = "knn.faiss.avx512_spr.disabled"; public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled"; public static final String KNN_DERIVED_SOURCE_ENABLED = "index.knn.derived_source.enabled"; + public static final String KNN_INDEX_REMOTE_VECTOR_BUILD = "index.knn.remote_index_build.enabled"; + public static final String KNN_REMOTE_VECTOR_REPO = "knn.remote_index_build.vector_repo"; /** * Default setting values @@ -371,6 +373,15 @@ public class KNNSettings { NodeScope ); + public static final Setting KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING = Setting.boolSetting( + KNN_INDEX_REMOTE_VECTOR_BUILD, + false, + Dynamic, + IndexScope + ); + + public static final Setting KNN_REMOTE_VECTOR_REPO_SETTING = Setting.simpleString(KNN_REMOTE_VECTOR_REPO, Dynamic, NodeScope); + /** * Dynamic settings */ @@ -525,6 +536,14 @@ private Setting getSetting(String key) { return KNN_DERIVED_SOURCE_ENABLED_SETTING; } + if (KNN_INDEX_REMOTE_VECTOR_BUILD.equals(key)) { + return KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING; + } + + if (KNN_REMOTE_VECTOR_REPO.equals(key)) { + return KNN_REMOTE_VECTOR_REPO_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -550,7 +569,9 @@ public List> getSettings() { QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING, - KNN_DERIVED_SOURCE_ENABLED_SETTING + KNN_DERIVED_SOURCE_ENABLED_SETTING, + KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING, + KNN_REMOTE_VECTOR_REPO_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index f3a125838..e5dacc423 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -21,6 +21,7 @@ import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; +import org.opensearch.knn.index.remote.RemoteIndexBuilder; import java.util.Map; import java.util.Optional; @@ -44,6 +45,7 @@ public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFor private final Supplier defaultFormatSupplier; private final Function vectorsFormatSupplier; private Function scalarQuantizedVectorsFormatSupplier; + private final RemoteIndexBuilder remoteIndexBuilder; private static final String MAX_CONNECTIONS = "max_connections"; private static final String BEAM_WIDTH = "beam_width"; @@ -54,11 +56,26 @@ public BasePerFieldKnnVectorsFormat( Supplier defaultFormatSupplier, Function vectorsFormatSupplier ) { - this.mapperService = mapperService; - this.defaultMaxConnections = defaultMaxConnections; - this.defaultBeamWidth = defaultBeamWidth; - this.defaultFormatSupplier = defaultFormatSupplier; - this.vectorsFormatSupplier = vectorsFormatSupplier; + this(mapperService, defaultMaxConnections, defaultBeamWidth, defaultFormatSupplier, vectorsFormatSupplier, null, null); + } + + public BasePerFieldKnnVectorsFormat( + Optional mapperService, + int defaultMaxConnections, + int defaultBeamWidth, + Supplier defaultFormatSupplier, + Function vectorsFormatSupplier, + Function scalarQuantizedVectorsFormatSupplier + ) { + this( + mapperService, + defaultMaxConnections, + defaultBeamWidth, + defaultFormatSupplier, + vectorsFormatSupplier, + scalarQuantizedVectorsFormatSupplier, + null + ); } @Override @@ -141,7 +158,8 @@ private NativeEngines990KnnVectorsFormat nativeEngineVectorsFormat() { int approximateThreshold = getApproximateThresholdValue(); return new NativeEngines990KnnVectorsFormat( new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()), - approximateThreshold + approximateThreshold, + remoteIndexBuilder ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java index afebae2e6..882908ea7 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java @@ -13,6 +13,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.remote.RemoteIndexBuilder; import java.util.Optional; import java.util.concurrent.ExecutorService; @@ -25,6 +26,10 @@ public class KNN9120PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsForma private static final Tuple DEFAULT_MERGE_THREAD_COUNT_AND_EXECUTOR_SERVICE = Tuple.tuple(1, null); public KNN9120PerFieldKnnVectorsFormat(final Optional mapperService) { + this(mapperService, null); + } + + public KNN9120PerFieldKnnVectorsFormat(final Optional mapperService, RemoteIndexBuilder remoteIndexBuilder) { super( mapperService, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, @@ -67,7 +72,8 @@ public KNN9120PerFieldKnnVectorsFormat(final Optional mapperServi // Executor service mergeThreadCountAndExecutorService.v2() ); - } + }, + remoteIndexBuilder ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java index 17304c146..c234e738b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java @@ -21,6 +21,7 @@ import org.apache.lucene.index.SegmentWriteState; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.remote.RemoteIndexBuilder; import java.io.IOException; @@ -33,6 +34,7 @@ public class NativeEngines990KnnVectorsFormat extends KnnVectorsFormat { private static FlatVectorsFormat flatVectorsFormat; private static final String FORMAT_NAME = "NativeEngines990KnnVectorsFormat"; private static int approximateThreshold; + private final RemoteIndexBuilder remoteIndexBuilder; public NativeEngines990KnnVectorsFormat() { this(new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer())); @@ -47,9 +49,18 @@ public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsForma } public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsFormat, int approximateThreshold) { + this(flatVectorsFormat, approximateThreshold, null); + } + + public NativeEngines990KnnVectorsFormat( + final FlatVectorsFormat flatVectorsFormat, + int approximateThreshold, + RemoteIndexBuilder remoteIndexBuilder + ) { super(FORMAT_NAME); NativeEngines990KnnVectorsFormat.flatVectorsFormat = flatVectorsFormat; NativeEngines990KnnVectorsFormat.approximateThreshold = approximateThreshold; + this.remoteIndexBuilder = remoteIndexBuilder; } /** @@ -59,7 +70,7 @@ public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsForma */ @Override public KnnVectorsWriter fieldsWriter(final SegmentWriteState state) throws IOException { - return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state), approximateThreshold); + return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state), approximateThreshold, remoteIndexBuilder); } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index fb93bfc07..995f957b5 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -28,6 +28,7 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.index.remote.RemoteIndexBuilder; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.plugin.stats.KNNGraphValue; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; @@ -54,15 +55,26 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private final List> fields = new ArrayList<>(); private boolean finished; private final Integer approximateThreshold; + private final RemoteIndexBuilder remoteIndexBuilder; public NativeEngines990KnnVectorsWriter( SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter, Integer approximateThreshold + ) { + this(segmentWriteState, flatVectorsWriter, approximateThreshold, null); + } + + public NativeEngines990KnnVectorsWriter( + SegmentWriteState segmentWriteState, + FlatVectorsWriter flatVectorsWriter, + Integer approximateThreshold, + RemoteIndexBuilder remoteIndexBuilder ) { this.segmentWriteState = segmentWriteState; this.flatVectorsWriter = flatVectorsWriter; this.approximateThreshold = approximateThreshold; + this.remoteIndexBuilder = remoteIndexBuilder; } /** @@ -114,7 +126,13 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { ); continue; } - final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); + final NativeIndexWriter writer = NativeIndexWriter.getWriter( + fieldInfo, + segmentWriteState, + quantizationState, + remoteIndexBuilder, + knnVectorValuesSupplier + ); final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); StopWatch stopWatch = new StopWatch().start(); @@ -153,7 +171,13 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState ); return; } - final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); + final NativeIndexWriter writer = NativeIndexWriter.getWriter( + fieldInfo, + segmentWriteState, + quantizationState, + remoteIndexBuilder, + knnVectorValuesSupplier + ); final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); StopWatch stopWatch = new StopWatch().start(); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java index 9e210fcd9..6aa2ad996 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java @@ -5,10 +5,11 @@ package org.opensearch.knn.index.codec; -import org.opensearch.index.codec.CodecServiceConfig; import org.apache.lucene.codecs.Codec; import org.opensearch.index.codec.CodecService; +import org.opensearch.index.codec.CodecServiceConfig; import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.remote.RemoteIndexBuilder; /** * KNNCodecService to inject the right KNNCodec version @@ -16,10 +17,12 @@ public class KNNCodecService extends CodecService { private final MapperService mapperService; + private final RemoteIndexBuilder remoteIndexBuilder; - public KNNCodecService(CodecServiceConfig codecServiceConfig) { + public KNNCodecService(CodecServiceConfig codecServiceConfig, RemoteIndexBuilder remoteIndexBuilder) { super(codecServiceConfig.getMapperService(), codecServiceConfig.getIndexSettings(), codecServiceConfig.getLogger()); mapperService = codecServiceConfig.getMapperService(); + this.remoteIndexBuilder = remoteIndexBuilder; } /** @@ -30,6 +33,6 @@ public KNNCodecService(CodecServiceConfig codecServiceConfig) { */ @Override public Codec codec(String name) { - return KNNCodecVersion.current().getKnnCodecSupplier().apply(super.codec(name), mapperService); + return KNNCodecVersion.current().getKnnCodecSupplier().apply(super.codec(name), mapperService, remoteIndexBuilder); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java index 0f03170c2..107006b58 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java @@ -8,14 +8,15 @@ import lombok.AllArgsConstructor; import lombok.Getter; import org.apache.lucene.backward_codecs.lucene91.Lucene91Codec; +import org.apache.lucene.backward_codecs.lucene912.Lucene912Codec; import org.apache.lucene.backward_codecs.lucene92.Lucene92Codec; import org.apache.lucene.backward_codecs.lucene94.Lucene94Codec; -import org.apache.lucene.codecs.Codec; import org.apache.lucene.backward_codecs.lucene95.Lucene95Codec; import org.apache.lucene.backward_codecs.lucene99.Lucene99Codec; -import org.apache.lucene.backward_codecs.lucene912.Lucene912Codec; +import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.lucene101.Lucene101Codec; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.common.TriFunction; import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.codec.KNN10010Codec.KNN10010Codec; import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; @@ -31,9 +32,9 @@ import org.opensearch.knn.index.codec.KNN950Codec.KNN950PerFieldKnnVectorsFormat; import org.opensearch.knn.index.codec.KNN990Codec.KNN990Codec; import org.opensearch.knn.index.codec.KNN990Codec.KNN990PerFieldKnnVectorsFormat; +import org.opensearch.knn.index.remote.RemoteIndexBuilder; import java.util.Optional; -import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; @@ -53,7 +54,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> new KNN910Codec(userCodec), + (userCodec, mapperService, remoteIndexBuilder) -> new KNN910Codec(userCodec), KNN910Codec::new ), @@ -65,7 +66,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN920Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN920Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN920PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), @@ -80,7 +81,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN940Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN940Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN940PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), @@ -95,7 +96,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN950Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN950Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN950PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), @@ -110,7 +111,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN990Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN990Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN990PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), @@ -125,7 +126,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN9120Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN9120Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .mapperService(mapperService) @@ -140,9 +141,9 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN10010Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN10010Codec.builder() .delegate(userCodec) - .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) + .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService), remoteIndexBuilder)) .mapperService(mapperService) .build(), KNN10010Codec::new @@ -154,7 +155,7 @@ public enum KNNCodecVersion { private final Codec defaultCodecDelegate; private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; private final Function knnFormatFacadeSupplier; - private final BiFunction knnCodecSupplier; + private final TriFunction knnCodecSupplier; private final Supplier defaultKnnCodecSupplier; public static final KNNCodecVersion current() { diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/LocalNativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/LocalNativeIndexWriter.java new file mode 100644 index 000000000..e90f12edd --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/LocalNativeIndexWriter.java @@ -0,0 +1,260 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.store.IndexOutput; +import org.opensearch.common.Nullable; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.knn.common.FieldInfoExtractor; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.qframe.QuantizationConfig; +import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.index.store.IndexOutputWithBuffer; +import org.opensearch.knn.index.util.IndexUtil; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.indices.Model; +import org.opensearch.knn.indices.ModelCache; +import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine; +import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; +import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; + +/** + * Writes KNN Index for a field in a segment. This is intended to be used for native engines. This class uses the local CPU path for building segments. + * See {@link RemoteNativeIndexWriter} for remote vector index build path. + */ +@AllArgsConstructor +@Log4j2 +public class LocalNativeIndexWriter implements NativeIndexWriter { + private static final Long CRC32_CHECKSUM_SANITY = 0xFFFFFFFF00000000L; + + private final SegmentWriteState state; + private final FieldInfo fieldInfo; + private final NativeIndexBuildStrategy indexBuilder; + @Nullable + private final QuantizationState quantizationState; + + @Override + public void flushIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { + initializeVectorValues(knnVectorValues); + buildAndWriteIndex(knnVectorValues, totalLiveDocs); + recordRefreshStats(); + } + + @Override + public void mergeIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { + initializeVectorValues(knnVectorValues); + if (knnVectorValues.docId() == NO_MORE_DOCS) { + // This is in place so we do not add metrics + log.debug("Skipping mergeIndex, vector values are already iterated for {}", fieldInfo.name); + return; + } + + long bytesPerVector = knnVectorValues.bytesPerVector(); + startMergeStats(totalLiveDocs, bytesPerVector); + buildAndWriteIndex(knnVectorValues, totalLiveDocs); + endMergeStats(totalLiveDocs, bytesPerVector); + } + + private void buildAndWriteIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { + if (totalLiveDocs == 0) { + log.debug("No live docs for field {}", fieldInfo.name); + return; + } + + final KNNEngine knnEngine = extractKNNEngine(fieldInfo); + final String engineFileName = buildEngineFileName( + state.segmentInfo.name, + knnEngine.getVersion(), + fieldInfo.name, + knnEngine.getExtension() + ); + try (IndexOutput output = state.directory.createOutput(engineFileName, state.context)) { + final IndexOutputWithBuffer indexOutputWithBuffer = new IndexOutputWithBuffer(output); + final BuildIndexParams nativeIndexParams = indexParams( + fieldInfo, + indexOutputWithBuffer, + knnEngine, + knnVectorValues, + totalLiveDocs + ); + indexBuilder.buildAndWriteIndex(nativeIndexParams); + CodecUtil.writeFooter(output); + } + } + + // The logic for building parameters need to be cleaned up. There are various cases handled here + // Currently it falls under two categories - with model and without model. Without model is further divided based on vector data type + // TODO: Refactor this so its scalable. Possibly move it out of this class + private BuildIndexParams indexParams( + FieldInfo fieldInfo, + IndexOutputWithBuffer indexOutputWithBuffer, + KNNEngine knnEngine, + KNNVectorValues vectorValues, + int totalLiveDocs + ) throws IOException { + final Map parameters; + VectorDataType vectorDataType; + if (quantizationState != null) { + vectorDataType = QuantizationService.getInstance().getVectorDataTypeForTransfer(fieldInfo); + } else { + vectorDataType = extractVectorDataType(fieldInfo); + } + if (fieldInfo.attributes().containsKey(MODEL_ID)) { + Model model = getModel(fieldInfo); + parameters = getTemplateParameters(fieldInfo, model); + } else { + parameters = getParameters(fieldInfo, vectorDataType, knnEngine); + } + + return BuildIndexParams.builder() + .fieldName(fieldInfo.name) + .parameters(parameters) + .vectorDataType(vectorDataType) + .knnEngine(knnEngine) + .indexOutputWithBuffer(indexOutputWithBuffer) + .quantizationState(quantizationState) + .vectorValues(vectorValues) + .totalLiveDocs(totalLiveDocs) + .build(); + } + + private Map getParameters(FieldInfo fieldInfo, VectorDataType vectorDataType, KNNEngine knnEngine) throws IOException { + Map parameters = new HashMap<>(); + Map fieldAttributes = fieldInfo.attributes(); + String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); + + // parametersString will be null when legacy mapper is used + if (parametersString == null) { + parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); + + String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); + Map algoParams = new HashMap<>(); + if (efConstruction != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); + } + + String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); + if (m != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); + } + parameters.put(PARAMETERS, algoParams); + } else { + parameters.putAll( + XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(parametersString), + MediaTypeRegistry.getDefaultMediaType() + ).map() + ); + } + + parameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); + // In OpenSearch 2.16, we added the prefix for binary indices in the index description in the codec logic. + // After 2.16, we added the binary prefix in the faiss library code. However, to ensure backwards compatibility, + // we need to ensure that if the description does not contain the prefix but the type is binary, we add the + // description. + maybeAddBinaryPrefixForFaissBWC(knnEngine, parameters, fieldAttributes); + + // Used to determine how many threads to use when indexing + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.getIndexThreadQty()); + + return parameters; + } + + private void maybeAddBinaryPrefixForFaissBWC(KNNEngine knnEngine, Map parameters, Map fieldAttributes) { + if (KNNEngine.FAISS != knnEngine) { + return; + } + + if (!VectorDataType.BINARY.getValue() + .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()))) { + return; + } + + if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) == null) { + return; + } + + if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_DESCRIPTION_PREFIX)) { + return; + } + + parameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() + ); + IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY); + } + + private Map getTemplateParameters(FieldInfo fieldInfo, Model model) throws IOException { + Map parameters = new HashMap<>(); + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.getIndexThreadQty()); + parameters.put(KNNConstants.MODEL_ID, fieldInfo.attributes().get(MODEL_ID)); + parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, model.getModelBlob()); + if (FieldInfoExtractor.extractQuantizationConfig(fieldInfo) != QuantizationConfig.EMPTY) { + IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY); + } else { + IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType()); + } + + return parameters; + } + + private Model getModel(FieldInfo fieldInfo) { + String modelId = fieldInfo.attributes().get(MODEL_ID); + Model model = ModelCache.getInstance().get(modelId); + if (model.getModelBlob() == null) { + throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); + } + return model; + } + + private void startMergeStats(int numDocs, long bytesPerVector) { + KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); + KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(numDocs); + KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(bytesPerVector); + KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment(); + KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(numDocs); + KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(bytesPerVector); + } + + private void endMergeStats(int numDocs, long arraySize) { + KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement(); + KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(numDocs); + KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.decrementBy(arraySize); + } + + private void recordRefreshStats() { + KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java index de535c39e..d919887dc 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -5,61 +5,70 @@ package org.opensearch.knn.index.codec.nativeindex; -import lombok.AllArgsConstructor; -import lombok.extern.log4j.Log4j2; -import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.store.IndexOutput; import org.opensearch.common.Nullable; -import org.opensearch.common.xcontent.XContentHelper; -import org.opensearch.core.common.bytes.BytesArray; -import org.opensearch.core.xcontent.DeprecationHandler; -import org.opensearch.core.xcontent.MediaTypeRegistry; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.knn.common.FieldInfoExtractor; -import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.index.engine.qframe.QuantizationConfig; -import org.opensearch.knn.index.quantizationservice.QuantizationService; -import org.opensearch.knn.index.store.IndexOutputWithBuffer; -import org.opensearch.knn.index.util.IndexUtil; +import org.opensearch.knn.index.remote.RemoteIndexBuilder; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; -import org.opensearch.knn.indices.Model; -import org.opensearch.knn.indices.ModelCache; -import org.opensearch.knn.plugin.stats.KNNGraphValue; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; +import java.util.function.Supplier; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine; -import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; -import static org.opensearch.knn.common.KNNConstants.PARAMETERS; -import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; -import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; -import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; /** - * Writes KNN Index for a field in a segment. This is intended to be used for native engines + * Interface for writing a KNN index field in a segment. This is intended to be used for native engines. */ -@AllArgsConstructor -@Log4j2 -public class NativeIndexWriter { - private static final Long CRC32_CHECKSUM_SANITY = 0xFFFFFFFF00000000L; +public interface NativeIndexWriter { + Logger log = LoggerFactory.getLogger(NativeIndexWriter.class); - private final SegmentWriteState state; - private final FieldInfo fieldInfo; - private final NativeIndexBuildStrategy indexBuilder; - @Nullable - private final QuantizationState quantizationState; + /** + * flushes the index + * + * @param knnVectorValues + * @throws IOException + */ + void flushIndex(KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException; + + /** + * Merges kNN index + * @param knnVectorValues + * @throws IOException + */ + void mergeIndex(KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException; + + /** + * + * @param fieldInfo + * @param segmentWriteState + * @param quantizationState + * @param remoteIndexBuilder + * @param knnVectorValuesSupplier + * @return {@link RemoteNativeIndexWriter} if {@link RemoteIndexBuilder} is available configured and configured properly. Uses {@link LocalNativeIndexWriter} as fallback. + */ + static NativeIndexWriter getWriter( + final FieldInfo fieldInfo, + final SegmentWriteState segmentWriteState, + final QuantizationState quantizationState, + final RemoteIndexBuilder remoteIndexBuilder, + final Supplier> knnVectorValuesSupplier + ) { + if (remoteIndexBuilder != null && remoteIndexBuilder.shouldBuildIndexRemotely()) { + log.debug("Using RemoteNativeIndexWriter"); + return new RemoteNativeIndexWriter( + createWriter(fieldInfo, segmentWriteState, quantizationState), + remoteIndexBuilder, + knnVectorValuesSupplier + ); + } else { + return createWriter(fieldInfo, segmentWriteState, quantizationState); + } + } /** * Gets the correct writer type from fieldInfo @@ -67,7 +76,7 @@ public class NativeIndexWriter { * @param fieldInfo * @return correct NativeIndexWriter to make index specified in fieldInfo */ - public static NativeIndexWriter getWriter(final FieldInfo fieldInfo, SegmentWriteState state) { + static NativeIndexWriter getWriter(final FieldInfo fieldInfo, SegmentWriteState state) { return createWriter(fieldInfo, state, null); } @@ -85,7 +94,7 @@ public static NativeIndexWriter getWriter(final FieldInfo fieldInfo, SegmentWrit * @param quantizationState The QuantizationState that contains quantization state required for quantization * @return A NativeIndexWriter instance appropriate for the specified field, configured with or without quantization. */ - public static NativeIndexWriter getWriter( + static NativeIndexWriter getWriter( final FieldInfo fieldInfo, final SegmentWriteState state, final QuantizationState quantizationState @@ -93,211 +102,6 @@ public static NativeIndexWriter getWriter( return createWriter(fieldInfo, state, quantizationState); } - /** - * flushes the index - * - * @param knnVectorValues - * @throws IOException - */ - public void flushIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { - initializeVectorValues(knnVectorValues); - buildAndWriteIndex(knnVectorValues, totalLiveDocs); - recordRefreshStats(); - } - - /** - * Merges kNN index - * @param knnVectorValues - * @throws IOException - */ - public void mergeIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { - initializeVectorValues(knnVectorValues); - if (knnVectorValues.docId() == NO_MORE_DOCS) { - // This is in place so we do not add metrics - log.debug("Skipping mergeIndex, vector values are already iterated for {}", fieldInfo.name); - return; - } - - long bytesPerVector = knnVectorValues.bytesPerVector(); - startMergeStats(totalLiveDocs, bytesPerVector); - buildAndWriteIndex(knnVectorValues, totalLiveDocs); - endMergeStats(totalLiveDocs, bytesPerVector); - } - - private void buildAndWriteIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { - if (totalLiveDocs == 0) { - log.debug("No live docs for field {}", fieldInfo.name); - return; - } - - final KNNEngine knnEngine = extractKNNEngine(fieldInfo); - final String engineFileName = buildEngineFileName( - state.segmentInfo.name, - knnEngine.getVersion(), - fieldInfo.name, - knnEngine.getExtension() - ); - try (IndexOutput output = state.directory.createOutput(engineFileName, state.context)) { - final IndexOutputWithBuffer indexOutputWithBuffer = new IndexOutputWithBuffer(output); - final BuildIndexParams nativeIndexParams = indexParams( - fieldInfo, - indexOutputWithBuffer, - knnEngine, - knnVectorValues, - totalLiveDocs - ); - indexBuilder.buildAndWriteIndex(nativeIndexParams); - CodecUtil.writeFooter(output); - } - } - - // The logic for building parameters need to be cleaned up. There are various cases handled here - // Currently it falls under two categories - with model and without model. Without model is further divided based on vector data type - // TODO: Refactor this so its scalable. Possibly move it out of this class - private BuildIndexParams indexParams( - FieldInfo fieldInfo, - IndexOutputWithBuffer indexOutputWithBuffer, - KNNEngine knnEngine, - KNNVectorValues vectorValues, - int totalLiveDocs - ) throws IOException { - final Map parameters; - VectorDataType vectorDataType; - if (quantizationState != null) { - vectorDataType = QuantizationService.getInstance().getVectorDataTypeForTransfer(fieldInfo); - } else { - vectorDataType = extractVectorDataType(fieldInfo); - } - if (fieldInfo.attributes().containsKey(MODEL_ID)) { - Model model = getModel(fieldInfo); - parameters = getTemplateParameters(fieldInfo, model); - } else { - parameters = getParameters(fieldInfo, vectorDataType, knnEngine); - } - - return BuildIndexParams.builder() - .fieldName(fieldInfo.name) - .parameters(parameters) - .vectorDataType(vectorDataType) - .knnEngine(knnEngine) - .indexOutputWithBuffer(indexOutputWithBuffer) - .quantizationState(quantizationState) - .vectorValues(vectorValues) - .totalLiveDocs(totalLiveDocs) - .build(); - } - - private Map getParameters(FieldInfo fieldInfo, VectorDataType vectorDataType, KNNEngine knnEngine) throws IOException { - Map parameters = new HashMap<>(); - Map fieldAttributes = fieldInfo.attributes(); - String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); - - // parametersString will be null when legacy mapper is used - if (parametersString == null) { - parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); - - String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); - Map algoParams = new HashMap<>(); - if (efConstruction != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); - } - - String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); - if (m != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); - } - parameters.put(PARAMETERS, algoParams); - } else { - parameters.putAll( - XContentHelper.createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - new BytesArray(parametersString), - MediaTypeRegistry.getDefaultMediaType() - ).map() - ); - } - - parameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); - // In OpenSearch 2.16, we added the prefix for binary indices in the index description in the codec logic. - // After 2.16, we added the binary prefix in the faiss library code. However, to ensure backwards compatibility, - // we need to ensure that if the description does not contain the prefix but the type is binary, we add the - // description. - maybeAddBinaryPrefixForFaissBWC(knnEngine, parameters, fieldAttributes); - - // Used to determine how many threads to use when indexing - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.getIndexThreadQty()); - - return parameters; - } - - private void maybeAddBinaryPrefixForFaissBWC(KNNEngine knnEngine, Map parameters, Map fieldAttributes) { - if (KNNEngine.FAISS != knnEngine) { - return; - } - - if (!VectorDataType.BINARY.getValue() - .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()))) { - return; - } - - if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) == null) { - return; - } - - if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_DESCRIPTION_PREFIX)) { - return; - } - - parameters.put( - KNNConstants.INDEX_DESCRIPTION_PARAMETER, - FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() - ); - IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY); - } - - private Map getTemplateParameters(FieldInfo fieldInfo, Model model) throws IOException { - Map parameters = new HashMap<>(); - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.getIndexThreadQty()); - parameters.put(KNNConstants.MODEL_ID, fieldInfo.attributes().get(MODEL_ID)); - parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, model.getModelBlob()); - if (FieldInfoExtractor.extractQuantizationConfig(fieldInfo) != QuantizationConfig.EMPTY) { - IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY); - } else { - IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType()); - } - - return parameters; - } - - private Model getModel(FieldInfo fieldInfo) { - String modelId = fieldInfo.attributes().get(MODEL_ID); - Model model = ModelCache.getInstance().get(modelId); - if (model.getModelBlob() == null) { - throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); - } - return model; - } - - private void startMergeStats(int numDocs, long bytesPerVector) { - KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); - KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(numDocs); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(bytesPerVector); - KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment(); - KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(numDocs); - KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(bytesPerVector); - } - - private void endMergeStats(int numDocs, long arraySize) { - KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement(); - KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(numDocs); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.decrementBy(arraySize); - } - - private void recordRefreshStats() { - KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); - } - /** * Helper method to create the appropriate NativeIndexWriter based on the field info and quantization state. * @@ -317,6 +121,6 @@ private static NativeIndexWriter createWriter( NativeIndexBuildStrategy strategy = iterative ? MemOptimizedNativeIndexBuildStrategy.getInstance() : DefaultIndexBuildStrategy.getInstance(); - return new NativeIndexWriter(state, fieldInfo, strategy, quantizationState); + return new LocalNativeIndexWriter(state, fieldInfo, strategy, quantizationState); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/RemoteNativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/RemoteNativeIndexWriter.java new file mode 100644 index 000000000..fad3b1436 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/RemoteNativeIndexWriter.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import org.opensearch.common.annotation.ExperimentalApi; +import org.opensearch.knn.index.remote.RemoteIndexBuilder; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; +import java.util.function.Supplier; + +/** + * Writes KNN Index for a field in a segment. This is intended to be used for native engines. This class uses a remote index build service for building segments. + * See {@link LocalNativeIndexWriter} for local vector index build path. + */ +@ExperimentalApi +public class RemoteNativeIndexWriter implements NativeIndexWriter { + + private final NativeIndexWriter fallbackWriter; + private final RemoteIndexBuilder remoteIndexBuilder; + private final Supplier> knnVectorValuesSupplier; + + public RemoteNativeIndexWriter( + NativeIndexWriter fallbackWriter, + RemoteIndexBuilder remoteIndexBuilder, + Supplier> knnVectorValuesSupplier + ) { + this.fallbackWriter = fallbackWriter; + this.remoteIndexBuilder = remoteIndexBuilder; + this.knnVectorValuesSupplier = knnVectorValuesSupplier; + } + + @Override + public void flushIndex(KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { + fallbackWriter.flushIndex(knnVectorValues, totalLiveDocs); + } + + @Override + public void mergeIndex(KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { + fallbackWriter.mergeIndex(knnVectorValues, totalLiveDocs); + } +} diff --git a/src/main/java/org/opensearch/knn/index/remote/RemoteIndexBuilder.java b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexBuilder.java new file mode 100644 index 000000000..17c46dab1 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexBuilder.java @@ -0,0 +1,161 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.remote; + +import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang.NotImplementedException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.opensearch.common.StopWatch; +import org.opensearch.common.annotation.ExperimentalApi; +import org.opensearch.index.IndexSettings; +import org.opensearch.knn.common.featureflags.KNNFeatureFlags; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.repositories.RepositoriesService; +import org.opensearch.repositories.Repository; +import org.opensearch.repositories.RepositoryMissingException; +import org.opensearch.repositories.blobstore.BlobStoreRepository; + +import java.io.IOException; +import java.util.function.Supplier; + +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING; +import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_VECTOR_REPO_SETTING; + +/** + * This class orchestrates building vector indices. It handles uploading data to a repository, submitting a remote + * build request, awaiting upon the build request to complete, and finally downloading the data from a repository. + * This class is intended to be used by {@link org.opensearch.knn.index.codec.nativeindex.RemoteNativeIndexWriter}. + * This class is created per-index, so we should not store per-segment information as class fields. + */ +@Log4j2 +@ExperimentalApi +public class RemoteIndexBuilder { + + private final Supplier repositoriesServiceSupplier; + private final IndexSettings indexSettings; + private static final String VECTOR_BLOB_FILE_EXTENSION = ".knnvec"; + private static final String DOC_ID_FILE_EXTENSION = ".knndid"; + private static final String GRAPH_FILE_EXTENSION = ".knngraph"; + + /** + * Public constructor + * + * @param repositoriesServiceSupplier A supplier for {@link RepositoriesService} used for interacting with repository + * @param indexSettings + */ + public RemoteIndexBuilder(Supplier repositoriesServiceSupplier, IndexSettings indexSettings) { + this.repositoriesServiceSupplier = repositoriesServiceSupplier; + this.indexSettings = indexSettings; + } + + /** + * @return whether to use the remote build feature + */ + public boolean shouldBuildIndexRemotely() { + String vectorRepo = KNNSettings.state().getSettingValue(KNN_REMOTE_VECTOR_REPO_SETTING.getKey()); + return KNNFeatureFlags.isKNNRemoteVectorBuildEnabled() + && indexSettings.getValue(KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING) + && vectorRepo != null + && !vectorRepo.isEmpty(); + } + + /** + * Gets the KNN repository container from the repository service. + * + * @return {@link RepositoriesService} + * @throws RepositoryMissingException if repository is not registered or if {@link KNN_REMOTE_VECTOR_REPO_SETTING} is not set + */ + private BlobStoreRepository getRepository() throws RepositoryMissingException { + RepositoriesService repositoriesService = repositoriesServiceSupplier.get(); + String vectorRepo = KNNSettings.state().getSettingValue(KNN_REMOTE_VECTOR_REPO_SETTING.getKey()); + if (vectorRepo == null || vectorRepo.isEmpty()) { + throw new RepositoryMissingException("Vector repository " + KNN_REMOTE_VECTOR_REPO_SETTING.getKey() + " is not registered"); + } + final Repository repository = repositoriesService.repository(vectorRepo); + assert repository instanceof BlobStoreRepository : "Repository should be instance of BlobStoreRepository"; + return (BlobStoreRepository) repository; + } + + /** + * 1. upload files, 2. trigger build and wait for completion, 3. download the graph, 4. write to indexoutput + * + * @param fieldInfo + * @param knnVectorValuesSupplier + * @param totalLiveDocs + */ + public void buildIndexRemotely( + FieldInfo fieldInfo, + Supplier> knnVectorValuesSupplier, + int totalLiveDocs, + SegmentWriteState segmentWriteState + ) { + StopWatch stopWatch; + long time_in_millis; + + stopWatch = new StopWatch().start(); + writeToRepository(fieldInfo, knnVectorValuesSupplier, totalLiveDocs, segmentWriteState); + time_in_millis = stopWatch.stop().totalTime().millis(); + log.debug("Repository write took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); + + stopWatch = new StopWatch().start(); + submitVectorBuild(); + time_in_millis = stopWatch.stop().totalTime().millis(); + log.debug("Submit vector build took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); + + stopWatch = new StopWatch().start(); + awaitVectorBuild(); + time_in_millis = stopWatch.stop().totalTime().millis(); + log.debug("Await vector build took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); + + stopWatch = new StopWatch().start(); + readFromRepository(); + time_in_millis = stopWatch.stop().totalTime().millis(); + log.debug("Repository read took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); + } + + /** + * Write relevant vector data to repository + * + * @param fieldInfo + * @param knnVectorValuesSupplier + * @param totalLiveDocs + * @param segmentWriteState + * @throws IOException + * @throws InterruptedException + */ + private void writeToRepository( + FieldInfo fieldInfo, + Supplier> knnVectorValuesSupplier, + int totalLiveDocs, + SegmentWriteState segmentWriteState + ) { + throw new NotImplementedException(); + } + + /** + * Submit vector build request to remote vector build service + * + */ + private void submitVectorBuild() { + throw new NotImplementedException(); + } + + /** + * Wait on remote vector build to complete + */ + private void awaitVectorBuild() { + throw new NotImplementedException(); + } + + /** + * Read constructed vector file from remote repository and write to IndexOutput + */ + private void readFromRepository() { + throw new NotImplementedException(); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 44c824862..7582006fc 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -5,37 +5,56 @@ package org.opensearch.knn.plugin; +import com.google.common.collect.ImmutableList; +import org.opensearch.action.ActionRequest; import org.opensearch.cluster.NamedDiff; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.IndexScopedSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.settings.SettingsFilter; import org.opensearch.core.ParseField; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.env.Environment; +import org.opensearch.env.NodeEnvironment; +import org.opensearch.index.IndexModule; +import org.opensearch.index.IndexSettings; import org.opensearch.index.codec.CodecServiceFactory; import org.opensearch.index.engine.EngineFactory; +import org.opensearch.index.mapper.Mapper; import org.opensearch.indices.SystemIndexDescriptor; import org.opensearch.knn.index.KNNCircuitBreaker; -import org.opensearch.knn.index.memory.NativeMemoryCacheManager; -import org.opensearch.knn.plugin.search.KNNConcurrentSearchRequestDecider; -import org.opensearch.knn.index.util.KNNClusterUtil; -import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; - -import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser; -import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.index.codec.KNNCodecService; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; -import org.opensearch.knn.indices.ModelGraveyard; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.index.query.KNNWeight; +import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser; +import org.opensearch.knn.index.remote.RemoteIndexBuilder; +import org.opensearch.knn.index.util.KNNClusterUtil; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelGraveyard; +import org.opensearch.knn.plugin.rest.RestClearCacheHandler; import org.opensearch.knn.plugin.rest.RestDeleteModelHandler; import org.opensearch.knn.plugin.rest.RestGetModelHandler; import org.opensearch.knn.plugin.rest.RestKNNStatsHandler; import org.opensearch.knn.plugin.rest.RestKNNWarmupHandler; import org.opensearch.knn.plugin.rest.RestSearchModelHandler; import org.opensearch.knn.plugin.rest.RestTrainModelHandler; -import org.opensearch.knn.plugin.rest.RestClearCacheHandler; import org.opensearch.knn.plugin.script.KNNScoringScriptEngine; +import org.opensearch.knn.plugin.search.KNNConcurrentSearchRequestDecider; import org.opensearch.knn.plugin.stats.KNNStats; +import org.opensearch.knn.plugin.transport.ClearCacheAction; +import org.opensearch.knn.plugin.transport.ClearCacheTransportAction; import org.opensearch.knn.plugin.transport.DeleteModelAction; import org.opensearch.knn.plugin.transport.DeleteModelTransportAction; import org.opensearch.knn.plugin.transport.GetModelAction; @@ -44,27 +63,6 @@ import org.opensearch.knn.plugin.transport.KNNStatsTransportAction; import org.opensearch.knn.plugin.transport.KNNWarmupAction; import org.opensearch.knn.plugin.transport.KNNWarmupTransportAction; -import org.opensearch.knn.plugin.transport.ClearCacheAction; -import org.opensearch.knn.plugin.transport.ClearCacheTransportAction; -import com.google.common.collect.ImmutableList; - -import org.opensearch.action.ActionRequest; -import org.opensearch.transport.client.Client; -import org.opensearch.cluster.metadata.IndexNameExpressionResolver; -import org.opensearch.cluster.node.DiscoveryNodes; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.core.common.io.stream.NamedWriteableRegistry; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.IndexScopedSettings; -import org.opensearch.common.settings.Setting; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.settings.SettingsFilter; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.env.Environment; -import org.opensearch.env.NodeEnvironment; -import org.opensearch.index.IndexModule; -import org.opensearch.index.IndexSettings; -import org.opensearch.index.mapper.Mapper; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheTransportAction; import org.opensearch.knn.plugin.transport.SearchModelAction; @@ -76,10 +74,10 @@ import org.opensearch.knn.plugin.transport.TrainingModelAction; import org.opensearch.knn.plugin.transport.TrainingModelRequest; import org.opensearch.knn.plugin.transport.TrainingModelTransportAction; -import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; -import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardTransportAction; +import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; +import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache; import org.opensearch.knn.training.TrainingJobClusterStateListener; import org.opensearch.knn.training.TrainingJobRunner; @@ -102,10 +100,11 @@ import org.opensearch.threadpool.ExecutorBuilder; import org.opensearch.threadpool.FixedExecutorBuilder; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.Client; import org.opensearch.watcher.ResourceWatcherService; -import java.util.Arrays; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -163,6 +162,7 @@ public class KNNPlugin extends Plugin private KNNStats knnStats; private ClusterService clusterService; + private Supplier repositoriesServiceSupplier; @Override public Map getMappers() { @@ -192,6 +192,7 @@ public Collection createComponents( Supplier repositoriesServiceSupplier ) { this.clusterService = clusterService; + this.repositoriesServiceSupplier = repositoriesServiceSupplier; // Initialize Native Memory loading strategies VectorReader vectorReader = new VectorReader(client); @@ -281,10 +282,14 @@ public Optional getEngineFactory(IndexSettings indexSettings) { return Optional.empty(); } + /** + * Get KNN codec service. + * We create {@link RemoteIndexBuilder} per-index in order to consume the {@link IndexSettings} from the {@link org.apache.lucene.codecs.KnnVectorsWriter} implementation. + */ @Override public Optional getCustomCodecServiceFactory(IndexSettings indexSettings) { if (indexSettings.getValue(KNNSettings.IS_KNN_INDEX_SETTING)) { - return Optional.of(KNNCodecService::new); + return Optional.of((config) -> new KNNCodecService(config, new RemoteIndexBuilder(repositoriesServiceSupplier, indexSettings))); } return Optional.empty(); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index f87ed6bcf..ebe9a22a6 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -21,6 +21,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.LocalNativeIndexWriter; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; @@ -47,6 +48,7 @@ import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockConstruction; @@ -71,7 +73,7 @@ public class NativeEngines990KnnVectorsWriterFlushTests extends OpenSearchTestCa @Mock private QuantizationService quantizationService; @Mock - private NativeIndexWriter nativeIndexWriter; + private LocalNativeIndexWriter nativeIndexWriter; private FlatFieldVectorsWriter mockedFlatFieldVectorsWriter; @@ -167,8 +169,9 @@ public void testFlush() { ).thenReturn(expectedVectorValues.get(i)); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(eq(fieldInfo), eq(segmentWriteState), eq(null), any(), any()) + ).thenReturn(nativeIndexWriter); }); doAnswer(answer -> { @@ -263,8 +266,9 @@ public void testFlush_WithQuantization() { throw new RuntimeException(e); } - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(eq(fieldInfo), eq(segmentWriteState), eq(quantizationState), any(), any()) + ).thenReturn(nativeIndexWriter); }); doAnswer(answer -> { Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion @@ -358,8 +362,9 @@ public void testFlush_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled ).thenReturn(expectedVectorValues.get(i)); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(eq(fieldInfo), eq(segmentWriteState), eq(null), any(), any()) + ).thenReturn(nativeIndexWriter); }); doAnswer(answer -> { @@ -435,8 +440,9 @@ public void testFlush_whenThresholdIsGreaterThanVectorSize_thenNativeIndexWriter ).thenReturn(expectedVectorValues.get(i)); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(eq(fieldInfo), eq(segmentWriteState), eq(quantizationState), any(), any()) + ).thenReturn(nativeIndexWriter); }); doAnswer(answer -> { @@ -513,8 +519,9 @@ public void testFlush_whenThresholdIsEqualToMinNumberOfVectors_thenNativeIndexWr ).thenReturn(expectedVectorValues.get(i)); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(eq(fieldInfo), eq(segmentWriteState), eq(null), any(), any()) + ).thenReturn(nativeIndexWriter); }); doAnswer(answer -> { @@ -599,8 +606,9 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr ).thenReturn(expectedVectorValues.get(i)); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(eq(fieldInfo), eq(segmentWriteState), eq(null), any(), any()) + ).thenReturn(nativeIndexWriter); }); doAnswer(answer -> { @@ -695,8 +703,9 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres throw new RuntimeException(e); } - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(eq(fieldInfo), eq(segmentWriteState), eq(quantizationState), any(), any()) + ).thenReturn(nativeIndexWriter); }); doAnswer(answer -> { Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion @@ -797,8 +806,9 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres throw new RuntimeException(e); } - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(eq(fieldInfo), eq(segmentWriteState), eq(quantizationState), any(), any()) + ).thenReturn(mock(LocalNativeIndexWriter.class)); }); doAnswer(answer -> { Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java index cdc372bda..eee4dd6f2 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -24,6 +24,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.LocalNativeIndexWriter; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; @@ -44,6 +45,7 @@ import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockConstruction; @@ -67,7 +69,7 @@ public class NativeEngines990KnnVectorsWriterMergeTests extends OpenSearchTestCa @Mock private QuantizationService quantizationService; @Mock - private NativeIndexWriter nativeIndexWriter; + private LocalNativeIndexWriter nativeIndexWriter; @Mock private FloatVectorValues floatVectorValues; @Mock @@ -139,8 +141,9 @@ public void testMerge() { .thenReturn(knnVectorValues); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(eq(fieldInfo), eq(segmentWriteState), eq(null), any(), any()) + ).thenReturn(nativeIndexWriter); doAnswer(answer -> { Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion return null; @@ -206,8 +209,9 @@ public void testMerge_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled .thenReturn(knnVectorValues); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(eq(fieldInfo), eq(segmentWriteState), eq(null), any(), any()) + ).thenReturn(nativeIndexWriter); doAnswer(answer -> { Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion return null; @@ -264,8 +268,9 @@ public void testMerge_whenThresholdIsEqualToNumberOfVectors_thenNativeIndexWrite .thenReturn(knnVectorValues); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(eq(fieldInfo), eq(segmentWriteState), eq(null), any(), any()) + ).thenReturn(nativeIndexWriter); doAnswer(answer -> { Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion return null; @@ -330,8 +335,9 @@ public void testMerge_WithQuantization() { throw new RuntimeException(e); } - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(eq(fieldInfo), eq(segmentWriteState), eq(quantizationState), any(), any()) + ).thenReturn(nativeIndexWriter); doAnswer(answer -> { Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion return null; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java index dfe4e7f22..1c08f1daa 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.codec; +import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.Codec; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.settings.Settings; @@ -13,8 +14,7 @@ import org.opensearch.index.codec.CodecServiceConfig; import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.KNNTestCase; - -import org.apache.logging.log4j.Logger; +import org.opensearch.knn.index.remote.RemoteIndexBuilder; import java.util.UUID; @@ -46,7 +46,7 @@ public void testGetCodecByName() { MapperService mapperService = mock(MapperService.class); Logger loggerMock = mock(Logger.class); CodecServiceConfig codecServiceConfig = new CodecServiceConfig(indexSettings, mapperService, loggerMock); - KNNCodecService knnCodecService = new KNNCodecService(codecServiceConfig); + KNNCodecService knnCodecService = new KNNCodecService(codecServiceConfig, mock(RemoteIndexBuilder.class)); Codec codec = knnCodecService.codec(KNNCodecVersion.current().getCodecName()); assertNotNull(codec); } @@ -61,7 +61,7 @@ public void testGetCodecByName() { public void testGetCodecByNameWithNoMapperService() { Logger loggerMock = mock(Logger.class); CodecServiceConfig codecServiceConfig = new CodecServiceConfig(indexSettings, null, loggerMock); - KNNCodecService knnCodecService = new KNNCodecService(codecServiceConfig); + KNNCodecService knnCodecService = new KNNCodecService(codecServiceConfig, mock(RemoteIndexBuilder.class)); Codec codec = knnCodecService.codec(KNNCodecVersion.current().getCodecName()); assertNotNull(codec); }