diff --git a/build.gradle b/build.gradle index 72d66abade..5b6526ac7b 100644 --- a/build.gradle +++ b/build.gradle @@ -316,6 +316,10 @@ dependencies { } testFixturesImplementation "org.opensearch:common-utils:${version}" implementation 'com.github.oshi:oshi-core:6.4.13' + + implementation 'io.github.jbellis:jvector:4.0.0-beta.2-SNAPSHOT' + implementation 'org.agrona:agrona:1.20.0' + api "net.java.dev.jna:jna:5.13.0" api "net.java.dev.jna:jna-platform:5.13.0" // OpenSearch core is using slf4j 1.7.36. Therefore, we cannot change the version here. @@ -331,7 +335,7 @@ task windowsPatches(type:Exec) { task cmakeJniLib(type:Exec) { workingDir 'jni' def args = [] - args.add("cmake") + args.add("/opt/homebrew/bin/cmake") args.add(".") args.add("-DKNN_PLUGIN_VERSION=${opensearch_version}") args.add("-DAVX2_ENABLED=${avx2_enabled}") @@ -364,6 +368,8 @@ test { dependsOn buildJniLib systemProperty 'tests.security.manager', 'false' systemProperty "java.library.path", "$rootDir/jni/release" + systemProperty 'log4j.configurationFile', "$rootDir/src/test/resources/log4j2.properties" + //this change enables mockito-inline that supports mocking of static classes/calls systemProperty "jdk.attach.allowAttachSelf", true if (Os.isFamily(Os.FAMILY_WINDOWS)) { @@ -378,6 +384,11 @@ integTest { dependsOn buildJniLib } systemProperty 'tests.security.manager', 'false' + println "Project root directory: ${project.rootDir}" + systemProperty "java.security.policy", "file://${project.rootDir}/src/main/plugin-metadata/plugin-security.policy" + systemProperty 'log4j.configurationFile', "${project.rootDir}/src/test/resources/log4j2.properties" + testLogging.showStandardStreams = true + systemProperty 'tests.output', 'true' systemProperty 'java.io.tmpdir', opensearch_tmp_dir.absolutePath systemProperty "java.library.path", "$rootDir/jni/release" // allows integration test classes to access test resource from project root path @@ -421,7 +432,8 @@ integTest { testClusters.integTest { testDistribution = 
"ARCHIVE" - + systemProperty "java.security.policy", "file://${project.rootDir}/src/main/plugin-metadata/plugin-security.policy" + systemProperty 'log4j.configurationFile', "${project.rootDir}/src/test/resources/log4j2.properties" // Optionally install security if (System.getProperty("security.enabled") != null) { configureSecurityPlugin(testClusters.integTest) @@ -460,7 +472,7 @@ task integTestRemote(type: RestIntegTestTask) { systemProperty 'cluster.number_of_nodes', "${_numNodes}" systemProperty 'tests.security.manager', 'false' - + systemProperty 'tests.output', 'true' // Run tests with remote cluster only if rest case is defined if (System.getProperty("tests.rest.cluster") != null) { filter { diff --git a/demo.sh b/demo.sh new file mode 100644 index 0000000000..69d5c267ec --- /dev/null +++ b/demo.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +./gradlew run -PcustomDistributionUrl=file://${HOME}/projects/OpenSearch/distribution/archives/darwin-tar/build/distributions/opensearch-min-3.0.0-SNAPSHOT-darwin-x64.tar.gz + +# ping local cluster +curl localhost:9200 + +# Check test cluster status +curl -X GET "http://localhost:9200/_cluster/health?pretty" + +# Create new knn index with 1 shard and 0 replicas +curl -X PUT "localhost:9200/my_knn_index?pretty" -H 'Content-Type: application/json' -d' +{ + "settings": { + "index.knn": true, + "index.number_of_shards": 1, + "index.number_of_replicas": 0, + "index.use_compound_file": false + } +}' + +# Check index settings +curl -X GET "localhost:9200/my_knn_index/_settings?pretty" + +# Add mapping for knn_vector field with jVector engine +curl -X PUT "localhost:9200/my_knn_index/_mapping?pretty" -H 'Content-Type: application/json' -d' +{ + "properties": { + "my_vector": { + "type": "knn_vector", + "dimension": 3, + "method": { + "name": "disk_ann", + "space_type": "l2", + "engine": "jvector" + } + } + } +}' + + +# Check index mapping +curl -X GET "localhost:9200/my_knn_index/_mapping?pretty" + +# Add document with knn_vector field 
+curl -X POST "localhost:9200/_bulk?pretty" -H 'Content-Type: application/json' -d' +{"index": {"_index": "my_knn_index"}} +{"my_vector": [1, 2, 3]} +{"index": {"_index": "my_knn_index"}} +{"my_vector": [4, 5, 6]} +{"index": {"_index": "my_knn_index"}} +{"my_vector": [7, 8, 9]} +' + +# refresh index +curl -X POST "localhost:9200/my_knn_index/_refresh?pretty" + + +# Search for nearest neighbors +curl -X GET "localhost:9200/my_knn_index/_search?pretty" -H 'Content-Type: application/json' -d' +{ + "query": { + "knn": { + "my_vector": { + "vector": [1, 2, 3], + "k": 3 + } + } + } +}' + +# Delete index +curl -X DELETE "localhost:9200/my_knn_index?pretty" + + +# Check test cluster location +ls -lah build/testclusters/integTest-0/data/nodes/0/indices \ No newline at end of file diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 7f93135c49..a4b76b9530 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index b42703b944..9355b41557 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,12 +1,7 @@ -# -# Copyright OpenSearch Contributors -# SPDX-License-Identifier: Apache-2.0 -# - distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=f2b9ed0faf8472cbe469255ae6c86eddb77076c75191741b4a462f33128dd419 -distributionUrl=https\://services.gradle.org/distributions/gradle-8.4-all.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.10-bin.zip networkTimeout=10000 +validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index a71dbba9b2..f5feea6d6b 100755 --- a/gradlew +++ b/gradlew @@ -1,8 +1,4 @@ #!/bin/sh -# -# Copyright OpenSearch Contributors -# SPDX-License-Identifier: Apache-2.0 -# # # Copyright © 2015-2021 the original authors. 
@@ -19,6 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +# SPDX-License-Identifier: Apache-2.0 +# ############################################################################## # @@ -59,7 +57,7 @@ # Darwin, MinGW, and NonStop. # # (3) This script is generated from the Groovy template -# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt # within the Gradle project. # # You can find Gradle at https://github.com/gradle/gradle/. @@ -88,7 +86,8 @@ done # shellcheck disable=SC2034 APP_BASE_NAME=${0##*/} # Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) -APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit +APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s +' "$PWD" ) || exit # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD=maximum diff --git a/gradlew.bat b/gradlew.bat index cd89102823..9b42019c79 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -1,8 +1,4 @@ @rem -@rem Copyright OpenSearch Contributors -@rem SPDX-License-Identifier: Apache-2.0 -@rem -@rem @rem Copyright 2015 the original author or authors. @rem @rem Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,6 +13,8 @@ @rem See the License for the specific language governing permissions and @rem limitations under the License. @rem +@rem SPDX-License-Identifier: Apache-2.0 +@rem @if "%DEBUG%"=="" @echo off @rem ########################################################################## @@ -47,11 +45,11 @@ set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 if %ERRORLEVEL% equ 0 goto execute -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. 
-echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail @@ -61,11 +59,11 @@ set JAVA_EXE=%JAVA_HOME%/bin/java.exe if exist "%JAVA_EXE%" goto execute -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index ce6095fd05..a247eab110 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -17,6 +17,7 @@ public class KNNConstants { public static final String NAME = "name"; public static final String PARAMETERS = "parameters"; public static final String METHOD_HNSW = "hnsw"; + public static final String DISK_ANN = "disk_ann"; public static final String TYPE = "type"; public static final String TYPE_NESTED = "nested"; public static final String PATH = "path"; @@ -121,6 +122,9 @@ public class KNNConstants { public static final String FAISS_SIGNED_BYTE_SQ = "SQ8_direct_signed"; public static final String FAISS_SQ_CLIP = "clip"; + // JVector specific constants + public static final String JVECTOR_NAME = "jvector"; + // Parameter defaults/limits public static final Integer ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT = 1; public static final Integer ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT = 1024; diff 
--git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index f3a1258387..0fa383ec31 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -15,6 +15,7 @@ import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat; +import org.opensearch.knn.index.codec.jvector.JVectorFormat; import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams; import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams; import org.opensearch.knn.index.engine.KNNEngine; @@ -24,6 +25,7 @@ import java.util.Map; import java.util.Optional; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; @@ -42,7 +44,7 @@ public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFor private final int defaultMaxConnections; private final int defaultBeamWidth; private final Supplier defaultFormatSupplier; - private final Function vectorsFormatSupplier; + private final BiFunction vectorsFormatSupplier; private Function scalarQuantizedVectorsFormatSupplier; private static final String MAX_CONNECTIONS = "max_connections"; private static final String BEAM_WIDTH = "beam_width"; @@ -52,7 +54,7 @@ public BasePerFieldKnnVectorsFormat( int defaultMaxConnections, int defaultBeamWidth, Supplier defaultFormatSupplier, - Function vectorsFormatSupplier + BiFunction vectorsFormatSupplier ) { this.mapperService = mapperService; this.defaultMaxConnections = defaultMaxConnections; @@ -89,50 +91,52 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { .orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); final 
KNNEngine engine = knnMethodContext.getKnnEngine(); final Map params = knnMethodContext.getMethodComponentContext().getParameters(); + switch (engine) { + // All Java engines to use Lucene extensions directly + case JVECTOR: + case LUCENE: + if (params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) { + KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( + params, + defaultMaxConnections, + defaultBeamWidth + ); + if (knnScalarQuantizedVectorsFormatParams.validate(params)) { + log.debug( + "Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"", + field, + MAX_CONNECTIONS, + knnScalarQuantizedVectorsFormatParams.getMaxConnections(), + BEAM_WIDTH, + knnScalarQuantizedVectorsFormatParams.getBeamWidth(), + LUCENE_SQ_CONFIDENCE_INTERVAL, + knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(), + LUCENE_SQ_BITS, + knnScalarQuantizedVectorsFormatParams.getBits() + ); + return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams); + } + } - if (engine == KNNEngine.LUCENE) { - if (params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) { - KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( - params, - defaultMaxConnections, - defaultBeamWidth + KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams( + params, + defaultMaxConnections, + defaultBeamWidth, + knnMethodContext.getSpaceType() ); - if (knnScalarQuantizedVectorsFormatParams.validate(params)) { - log.debug( - "Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"", + log.debug( + "Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"", field, MAX_CONNECTIONS, - knnScalarQuantizedVectorsFormatParams.getMaxConnections(), + 
knnVectorsFormatParams.getMaxConnections(), BEAM_WIDTH, - knnScalarQuantizedVectorsFormatParams.getBeamWidth(), - LUCENE_SQ_CONFIDENCE_INTERVAL, - knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(), - LUCENE_SQ_BITS, - knnScalarQuantizedVectorsFormatParams.getBits() - ); - return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams); - } - } - - KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams( - params, - defaultMaxConnections, - defaultBeamWidth, - knnMethodContext.getSpaceType() - ); - log.debug( - "Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"", - field, - MAX_CONNECTIONS, - knnVectorsFormatParams.getMaxConnections(), - BEAM_WIDTH, - knnVectorsFormatParams.getBeamWidth() - ); - return vectorsFormatSupplier.apply(knnVectorsFormatParams); + knnVectorsFormatParams.getBeamWidth() + ); + return vectorsFormatSupplier.apply(engine, knnVectorsFormatParams); + default: + // All native engines to use NativeEngines990KnnVectorsFormat + return nativeEngineVectorsFormat(); } - - // All native engines to use NativeEngines990KnnVectorsFormat - return nativeEngineVectorsFormat(); } private NativeEngines990KnnVectorsFormat nativeEngineVectorsFormat() { diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java index a370197ecc..3d0a917ccf 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java @@ -12,6 +12,7 @@ import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.codec.KNNCodecVersion; import org.opensearch.knn.index.codec.KNNFormatFacade; diff --git 
a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java index 6e8fc767ec..3563fac2fd 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java @@ -10,6 +10,7 @@ import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat; +import org.opensearch.knn.index.codec.jvector.JVectorFormat; import org.opensearch.knn.index.engine.KNNEngine; import java.util.Optional; @@ -26,16 +27,28 @@ public KNN9120PerFieldKnnVectorsFormat(final Optional mapperServi Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, Lucene99HnswVectorsFormat::new, - knnVectorsFormatParams -> { - // There is an assumption here that hamming space will only be used for binary vectors. This will need to be fixed if that - // changes in the future. - if (knnVectorsFormatParams.getSpaceType() == SpaceType.HAMMING) { - return new KNN9120HnswBinaryVectorsFormat( - knnVectorsFormatParams.getMaxConnections(), - knnVectorsFormatParams.getBeamWidth() - ); - } else { - return new Lucene99HnswVectorsFormat(knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth()); + (knnEngine, knnVectorsFormatParams) -> { + + + switch (knnEngine) { + case LUCENE: + // There is an assumption here that hamming space will only be used for binary vectors. This will need to be fixed if that + // changes in the future. 
+ if (knnVectorsFormatParams.getSpaceType() == SpaceType.HAMMING) { + return new KNN9120HnswBinaryVectorsFormat( + knnVectorsFormatParams.getMaxConnections(), + knnVectorsFormatParams.getBeamWidth() + ); + } else { + return new Lucene99HnswVectorsFormat(knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth()); + } + case JVECTOR: + return new JVectorFormat( + knnVectorsFormatParams.getMaxConnections(), + knnVectorsFormatParams.getBeamWidth() + ); + default: + throw new IllegalArgumentException("Unsupported java engine: " + knnEngine); } }, knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat( diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120WraperCodec.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120WraperCodec.java new file mode 100644 index 0000000000..7345f0cc9b --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120WraperCodec.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.lucene912.Lucene912Codec; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; +import org.opensearch.knn.index.codec.KNNFormatFacade; +import org.opensearch.knn.index.codec.WrapperCodecForKNNPlugin; +import org.opensearch.knn.index.codec.jvector.JVectorCompoundFormat; + +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +/** + * Example for a specific codec version that extends {@link WrapperCodecForKNNPlugin} and provides codec specific delegates + * This makes allows for all the constructors and resources to be encapsulated in a single class without going back and forth 
between providers + */ +public class KNN9120WraperCodec extends WrapperCodecForKNNPlugin { + private static final String CODEC_NAME = "KNN9120Codec"; + + public KNN9120WraperCodec(String name, MapperService mapperService) { + super(name, + new Lucene912Codec(), + CODEC_NAME, + mapperService, + new KNN9120PerFieldKnnVectorsFormat(Optional.empty()), + (delegate) -> new KNNFormatFacade( + new KNN80DocValuesFormat(delegate.docValuesFormat()), + new JVectorCompoundFormat(delegate.compoundFormat()) + )); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java index 7cca04319b..0d46557ec5 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java @@ -22,7 +22,7 @@ public KNN920PerFieldKnnVectorsFormat(final Optional mapperServic Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH, () -> new Lucene92HnswVectorsFormat(), - knnVectorsFormatParams -> new Lucene92HnswVectorsFormat( + (knnEngine, knnVectorsFormatParams) -> new Lucene92HnswVectorsFormat( knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth() ) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java index 1ed9c929c1..70aa738ee1 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java @@ -22,7 +22,7 @@ public KNN940PerFieldKnnVectorsFormat(final Optional mapperServic Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH, () -> new 
Lucene94HnswVectorsFormat(), - knnVectorsFormatParams -> new Lucene94HnswVectorsFormat( + (knnEngine, knnVectorsFormatParams) -> new Lucene94HnswVectorsFormat( knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth() ) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN950Codec/KNN950PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN950Codec/KNN950PerFieldKnnVectorsFormat.java index 7a14580576..7cb2d35b9c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN950Codec/KNN950PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN950Codec/KNN950PerFieldKnnVectorsFormat.java @@ -23,7 +23,7 @@ public KNN950PerFieldKnnVectorsFormat(final Optional mapperServic Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH, () -> new Lucene95HnswVectorsFormat(), - knnVectorsFormatParams -> new Lucene95HnswVectorsFormat( + (knnEngine, knnVectorsFormatParams) -> new Lucene95HnswVectorsFormat( knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth() ) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java index 67ea7b5444..b61e2e5350 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java @@ -25,7 +25,7 @@ public KNN990PerFieldKnnVectorsFormat(final Optional mapperServic Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, Lucene99HnswVectorsFormat::new, - knnVectorsFormatParams -> new Lucene99HnswVectorsFormat( + (knnEngine, knnVectorsFormatParams) -> new Lucene99HnswVectorsFormat( knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth() ), diff --git 
a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java index 4343c845b0..147b832240 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java @@ -29,6 +29,8 @@ import org.opensearch.knn.index.codec.KNN950Codec.KNN950PerFieldKnnVectorsFormat; import org.opensearch.knn.index.codec.KNN990Codec.KNN990Codec; import org.opensearch.knn.index.codec.KNN990Codec.KNN990PerFieldKnnVectorsFormat; +import org.opensearch.knn.index.codec.jvector.JVectorCodec; +import org.opensearch.knn.index.codec.jvector.JVectorCompoundFormat; import java.util.Optional; import java.util.function.BiFunction; @@ -128,6 +130,22 @@ public enum KNNCodecVersion { .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), KNN9120Codec::new + ), + + /** + * Version 9.12.0 with JVector codec. This is for demosntration purposes of direct encoding of JVector and + * potentially other JVM KNN libraries direct encoding into Lucene compound file without creating new files + */ + V_9_12_0_WITH_JVECTOR( + JVectorCodec.CODEC_NAME, + new JVectorCodec(), + new KNN9120PerFieldKnnVectorsFormat(Optional.empty()), + (delegate) -> new KNNFormatFacade( + new KNN80DocValuesFormat(delegate.docValuesFormat()), + new JVectorCompoundFormat(delegate.compoundFormat()) + ), + (userCodec, mapperService) -> new JVectorCodec(), + JVectorCodec::new ); private static final KNNCodecVersion CURRENT = V_9_12_0; diff --git a/src/main/java/org/opensearch/knn/index/codec/WrapperCodecForKNNPlugin.java b/src/main/java/org/opensearch/knn/index/codec/WrapperCodecForKNNPlugin.java new file mode 100644 index 0000000000..36abe6d331 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/WrapperCodecForKNNPlugin.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + 
+package org.opensearch.knn.index.codec; + +import org.apache.lucene.codecs.*; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.mapper.MapperService; + +import java.util.function.Function; + +/** + * Wrapper codec for KNN plugin + * + * All specific codecs should inherit from this class, this insures that all the relevant format elements are all in the same class + * and that we are not delegating back and forth between suppliers/consumers of various codec elements from other classes and thus keeps better encapsulation and readability. + */ +public class WrapperCodecForKNNPlugin extends FilterCodec { + private final String codecName; + private final Codec codecDelegate; + private final MapperService mapperService; + private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; + private final Function knnFormatFacadeSupplier; + private final KNNFormatFacade knnFormatFacade; + + public WrapperCodecForKNNPlugin(String name, Codec codecDelegate, String codecName, MapperService mapperService, PerFieldKnnVectorsFormat perFieldKnnVectorsFormat, Function knnFormatFacadeSupplier) { + super(name, codecDelegate); + this.codecName = codecName; + this.codecDelegate = codecDelegate; + this.mapperService = mapperService; + this.perFieldKnnVectorsFormat = perFieldKnnVectorsFormat; + this.knnFormatFacadeSupplier = knnFormatFacadeSupplier; + this.knnFormatFacade = knnFormatFacadeSupplier.apply(codecDelegate); + } + + @Override + public DocValuesFormat docValuesFormat() { + return knnFormatFacade.docValuesFormat(); + } + + @Override + public CompoundFormat compoundFormat() { + return knnFormatFacade.compoundFormat(); + } + + @Override + public KnnVectorsFormat knnVectorsFormat() { + return perFieldKnnVectorsFormat; + } + +} diff --git a/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorCodec.java b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorCodec.java new file mode 100644 index 0000000000..828a2657ad --- 
/dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorCodec.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.jvector; + +import org.apache.lucene.codecs.CompoundFormat; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene912.Lucene912Codec; + +public class JVectorCodec extends FilterCodec { + + public static final String CODEC_NAME = "JVectorCodec"; + + public JVectorCodec() { + super(CODEC_NAME, new Lucene912Codec()); + } + + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new JVectorFormat(); + } + + @Override + public CompoundFormat compoundFormat() { + return new JVectorCompoundFormat(delegate.compoundFormat()); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorCompoundFormat.java b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorCompoundFormat.java new file mode 100644 index 0000000000..3e44494175 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorCompoundFormat.java @@ -0,0 +1,182 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.jvector; + +import lombok.Getter; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.CompoundDirectory; +import org.apache.lucene.codecs.CompoundFormat; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.store.*; +import org.apache.lucene.util.CollectionUtil; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.Collections; +import java.util.Map; +import java.util.Set; + +// TODO: This needs to be moved under the same package name as the Lucene internal 
package name for {@link Lucene90CompoundReader} +// this way the internal package constants can be accessed directly and we can avoid duplicating them. +@Log4j2 +public class JVectorCompoundFormat extends CompoundFormat { + private final CompoundFormat delegate; + static final String DATA_EXTENSION = "cfs"; + + /** Extension of compound file entries */ + static final String ENTRIES_EXTENSION = "cfe"; + + static final String DATA_CODEC = "Lucene90CompoundData"; + static final String ENTRY_CODEC = "Lucene90CompoundEntries"; + static final int VERSION_START = 0; + static final int VERSION_CURRENT = VERSION_START; + public JVectorCompoundFormat(CompoundFormat delegate) { + super(); + this.delegate = delegate; + } + + @Override + public CompoundDirectory getCompoundReader(Directory dir, SegmentInfo si, IOContext context) throws IOException { + return new JVectorCompoundReader(delegate.getCompoundReader(dir, si, context), dir, si, context); + } + + @Override + public void write(Directory dir, SegmentInfo si, IOContext context) throws IOException { + delegate.write(dir, si, context); + } + + public static class JVectorCompoundReader extends CompoundDirectory { + private final CompoundDirectory delegate; + private final String segmentName; + private final Map entries; + @Getter + private final Path directoryBasePath; + @Getter + private final String compoundFileName; + @Getter + private final Path compoundFilePath; + private int version; + + public JVectorCompoundReader(CompoundDirectory delegate, Directory directory, SegmentInfo si, IOContext context) throws IOException { + this.delegate = delegate; + this.segmentName = si.name; + this.directoryBasePath = resolveDirectoryPath(directory); + String entriesFileName = + IndexFileNames.segmentFileName(segmentName, "", JVectorCompoundFormat.ENTRIES_EXTENSION); + this.entries = readEntries(si.getId(), directory, entriesFileName); + this.compoundFileName = + IndexFileNames.segmentFileName(segmentName, "", 
JVectorCompoundFormat.DATA_EXTENSION); + this.compoundFilePath = directoryBasePath.resolve(compoundFileName); + } + + @Override + public void checkIntegrity() throws IOException { + delegate.checkIntegrity(); + } + + @Override + public String[] listAll() throws IOException { + return delegate.listAll(); + } + + @Override + public long fileLength(String name) throws IOException { + return delegate.fileLength(name); + } + + /** + * Returns the offset of the given file in the compound file. + */ + public long getOffsetInCompoundFile(String name) { + FileEntry entry = entries.get(name); + if (entry == null) { + throw new IllegalArgumentException("No sub-file with id " + name + " found in compound file"); + } + return entry.offset; + } + + @Override + public IndexInput openInput(String name, IOContext context) throws IOException { + return delegate.openInput(name, context); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + + @Override + public Set getPendingDeletions() throws IOException { + return delegate.getPendingDeletions(); + } + + private Map readEntries( + byte[] segmentID, Directory dir, String entriesFileName) throws IOException { + Map mapping = null; + try (ChecksumIndexInput entriesStream = + dir.openChecksumInput(entriesFileName, IOContext.READONCE)) { + Throwable priorE = null; + try { + version = + CodecUtil.checkIndexHeader( + entriesStream, + JVectorCompoundFormat.ENTRY_CODEC, + JVectorCompoundFormat.VERSION_START, + JVectorCompoundFormat.VERSION_CURRENT, + segmentID, + ""); + + mapping = readMapping(entriesStream); + + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(entriesStream, priorE); + } + } + return Collections.unmodifiableMap(mapping); + } + + private Map readMapping(IndexInput entriesStream) throws IOException { + final int numEntries = entriesStream.readVInt(); + Map mapping = CollectionUtil.newHashMap(numEntries); + for (int i = 0; i < numEntries; i++) { + 
final FileEntry fileEntry = new FileEntry(); + final String id = segmentName + entriesStream.readString(); + FileEntry previous = mapping.put(id, fileEntry); + if (previous != null) { + throw new CorruptIndexException("Duplicate cfs entry id=" + id + " in CFS ", entriesStream); + } + fileEntry.offset = entriesStream.readLong(); + fileEntry.length = entriesStream.readLong(); + } + return mapping; + } + + public static final class FileEntry { + long offset; + long length; + } + + private Path resolveDirectoryPath(Directory dir) { + while (!(dir instanceof FSDirectory)) { + final String dirType = dir.getClass().getName(); + log.debug("unwrapping dir of type: {} to find path", dirType); + if (dir instanceof FilterDirectory) { + dir = ((FilterDirectory) dir).getDelegate(); + } else { + throw new IllegalArgumentException("directory must be FSDirectory or a wrapper around it but instead had type: " + dirType); + } + } + final Path path = ((FSDirectory) dir).getDirectory(); + log.debug("resolved directory path from FSDirectory: {}", path); + return path; + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorFloatVectorValues.java b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorFloatVectorValues.java new file mode 100644 index 0000000000..84072f803d --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorFloatVectorValues.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.jvector; + +import io.github.jbellis.jvector.disk.ReaderSupplier; +import io.github.jbellis.jvector.graph.NodesIterator; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; 
+import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.search.VectorScorer; + +import java.io.IOException; + +public class JVectorFloatVectorValues extends FloatVectorValues { + private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport(); + + private final OnDiskGraphIndex onDiskGraphIndex; + private final OnDiskGraphIndex.View view; + private int docId = -1; + private final NodesIterator nodesIterator; + private final VectorSimilarityFunction similarityFunction; + + public JVectorFloatVectorValues(OnDiskGraphIndex onDiskGraphIndex, VectorSimilarityFunction similarityFunction) throws IOException { + this.onDiskGraphIndex = onDiskGraphIndex; + this.view = onDiskGraphIndex.getView(); + this.nodesIterator = onDiskGraphIndex.getNodes(); + this.similarityFunction = similarityFunction; + } + + @Override + public int dimension() { + return onDiskGraphIndex.getDimension(); + } + + @Override + public int size() { + return onDiskGraphIndex.size(); + } + + public VectorFloat vectorFloatValue() { + if (!onDiskGraphIndex.containsNode(docId)) { + throw new RuntimeException("DocId " + docId + " not found in graph"); + } + + return view.getVector(docId); + } + + @Override + public float[] vectorValue() throws IOException { + try { + final VectorFloat vector = vectorFloatValue(); + return (float[])vector.get(); + } catch (Throwable e) { + throw new RuntimeException(e); + } + + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + return new JVectorVectorScorer(this, VECTOR_TYPE_SUPPORT.createFloatVector(query), similarityFunction); + } + + @Override + public int docID() { + return docId; + } + + @Override + public int nextDoc() throws IOException { + if (nodesIterator.hasNext()) { + docId = nodesIterator.next(); + } else { + docId = NO_MORE_DOCS; + } + + return docId; + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } 
}
diff --git a/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorFormat.java b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorFormat.java new file mode 100644 index 0000000000..d986530a94 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorFormat.java @@ -0,0 +1,64 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.knn.index.codec.jvector;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;

import java.io.IOException;
import java.nio.file.Path;

/**
 * KnnVectorsFormat backed by the jVector graph index. Writes a metadata file
 * (META_EXTENSION) plus a graph data file (VECTOR_INDEX_EXTENSION) per segment.
 */
public class JVectorFormat extends KnnVectorsFormat {
    public static final String NAME = "JVectorFormat";
    public static final String META_CODEC_NAME = "JVectorVectorsFormatMeta";
    public static final String VECTOR_INDEX_CODEC_NAME = "JVectorVectorsFormatIndex";
    public static final String JVECTOR_FILES_SUFFIX = "jvector";
    public static final String META_EXTENSION = "meta-" + JVECTOR_FILES_SUFFIX;
    public static final String VECTOR_INDEX_EXTENSION = "data-" + JVECTOR_FILES_SUFFIX;
    public static final int VERSION_START = 0;
    public static final int VERSION_CURRENT = VERSION_START;
    // NOTE(review): not referenced in this file — presumably reserved for future use; confirm.
    public static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
    // Graph construction defaults (maxConn/beamWidth are the HNSW-style M / beam parameters).
    private static final int DEFAULT_MAX_CONN = 16;
    private static final int DEFAULT_BEAM_WIDTH = 100;
    private static final float DEFAULT_DEGREE_OVERFLOW = 1.2f;
    private static final float DEFAULT_ALPHA = 1.2f;

    private final int maxConn;
    private final int beamWidth;

    public JVectorFormat() {
        this(NAME, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH);
    }

    public JVectorFormat(int maxConn, int beamWidth) {
        this(NAME, maxConn, beamWidth);
    }

    public JVectorFormat(String name, int maxConn, int beamWidth) {
        super(name);
        this.maxConn = maxConn;
        this.beamWidth = beamWidth;
    }



    @Override
    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
        return new JVectorWriter(state, maxConn, beamWidth, DEFAULT_DEGREE_OVERFLOW, DEFAULT_ALPHA);
    }

    @Override
    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
        return new JVectorReader(state);
    }

    // Resolves the per-field graph file path: <base>_<field>.data-jvector under the directory.
    static Path getVectorIndexPath(Path directoryBasePath, String baseDataFileName, String field) {
        return directoryBasePath.resolve(baseDataFileName + "_" + field + "." + JVectorFormat.VECTOR_INDEX_EXTENSION);
    }
}
diff --git a/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorIndexOutput.java b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorIndexOutput.java new file mode 100644 index 0000000000..7cb1424967 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorIndexOutput.java @@ -0,0 +1,52 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.knn.index.codec.jvector;

import io.github.jbellis.jvector.disk.RandomAccessWriter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.store.IndexOutput;

import java.io.IOException;

/**
 * Adapts jVector's {@link RandomAccessWriter} to Lucene's {@link IndexOutput}.
 */
@Log4j2
public class JVectorIndexOutput extends IndexOutput {
    private final RandomAccessWriter randomAccessWriter;

    public JVectorIndexOutput(RandomAccessWriter randomAccessWriter) {
        super("JVectorIndexOutput", "JVectorIndexOutput");
        this.randomAccessWriter = randomAccessWriter;
    }

    @Override
    public void close() throws IOException {
        randomAccessWriter.close();
    }

    @Override
    public long getFilePointer() {
        try {
            return randomAccessWriter.position();
        } catch (IOException e) {
            // IndexOutput#getFilePointer cannot throw IOException; rewrap as unchecked.
            log.error("Error getting file pointer", e);
            throw new RuntimeException(e);
        }
    }

    @Override
    public long getChecksum() throws IOException {
        // Checksum over everything written so far (from offset 0 to the current position).
        return randomAccessWriter.checksum(0, randomAccessWriter.position());
    }

    @Override
    public
void writeByte(byte b) throws IOException {
        randomAccessWriter.write(b);
    }

    @Override
    public void writeBytes(byte[] b, int offset, int length) throws IOException {
        randomAccessWriter.write(b, offset, length);
    }
}
diff --git a/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorReader.java b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorReader.java new file mode 100644 index 0000000000..f068360fe4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorReader.java @@ -0,0 +1,281 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.knn.index.codec.jvector;

import io.github.jbellis.jvector.disk.ReaderSupplier;
import io.github.jbellis.jvector.graph.GraphSearcher;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.*;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.*;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.disk.ReaderSupplierFactory;
import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundDirectory;

import java.io.IOException;
import java.nio.file.Path;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;

// Reads per-field jVector graph indexes written by JVectorWriter; one FieldEntry
// (reader supplier + loaded OnDiskGraphIndex) is kept per vector field.
@Log4j2
public class JVectorReader
extends KnnVectorsReader { + private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport(); + + private final FieldInfos fieldInfos; + private final String indexDataFileName; + private final String baseDataFileName; + private final Path directoryBasePath; + // Maps field name to field entries + private final Map fieldEntryMap = new HashMap<>(1); + private final Directory directory; + private final SegmentReadState state; + + + public JVectorReader(SegmentReadState state) throws IOException { + this.state = state; + this.fieldInfos = state.fieldInfos; + this.baseDataFileName = state.segmentInfo.name + "_" + state.segmentSuffix; + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, state.segmentSuffix, JVectorFormat.META_EXTENSION); + this.directory = state.directory; + this.directoryBasePath = resolveDirectoryPath(directory); + boolean success = false; + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName, state.context)) { + CodecUtil.checkIndexHeader( + meta, + JVectorFormat.META_CODEC_NAME, + JVectorFormat.VERSION_START, + JVectorFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + Set filenames = state.segmentInfo.files(); + readFields(meta); + CodecUtil.checkFooter(meta); + + this.indexDataFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + JVectorFormat.VECTOR_INDEX_EXTENSION); + + success = true; + } finally { + if (!success) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public void checkIntegrity() throws IOException { + // This is already done when loading the fields + // TODO: Implement this, for now this will always pass + //CodecUtil.checksumEntireFile(vectorIndex); + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + return new JVectorFloatVectorValues(fieldEntryMap.get(field).index, 
fieldEntryMap.get(field).similarityFunction); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + /** + * Byte vector values are not supported in jVector library. Instead use PQ. + */ + return null; + } + + @Override + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + final OnDiskGraphIndex index = fieldEntryMap.get(field).index; + + // search for a random vector using a GraphSearcher and SearchScoreProvider + VectorFloat q = VECTOR_TYPE_SUPPORT.createFloatVector(target); + try (GraphSearcher searcher = new GraphSearcher(index)) { + SearchScoreProvider ssp = SearchScoreProvider.exact(q, fieldEntryMap.get(field).similarityFunction, index.getView()); + SearchResult sr = searcher.search(ssp, knnCollector.k(), io.github.jbellis.jvector.util.Bits.ALL); + for (SearchResult.NodeScore ns : sr.getNodes()) { + knnCollector.collect(ns.node, ns.score); + } + } + } + + @Override + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + // TODO: implement this + } + + @Override + public void close() throws IOException { + for (FieldEntry fieldEntry : fieldEntryMap.values()) { + IOUtils.close(fieldEntry.readerSupplier::close); + } + } + + @Override + public long ramBytesUsed() { + return 0; + } + + + private void readFields(ChecksumIndexInput meta) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + final FieldInfo fieldInfo = fieldInfos.fieldInfo(fieldNumber); // read field number) + final VectorEncoding vectorEncoding = readVectorEncoding(meta); + final VectorSimilarityFunction similarityFunction = VectorSimilarityMapper.ordToDistFunc(meta.readInt()); + final long vectorIndexOffset = meta.readVLong(); + final long vectorIndexLength = meta.readVLong(); + final int dimension = meta.readVInt(); + fieldEntryMap.put(fieldInfo.name, new 
FieldEntry(fieldInfo, similarityFunction, vectorEncoding, vectorIndexOffset, vectorIndexLength, dimension)); + } + } + + class FieldEntry { + private final FieldInfo fieldInfo; + private final VectorEncoding vectorEncoding; + private final VectorSimilarityFunction similarityFunction; + private final long vectorIndexOffset; + private final long vectorIndexLength; + private final int dimension; + private final ReaderSupplier readerSupplier; + private final OnDiskGraphIndex index; + + public FieldEntry( + FieldInfo fieldInfo, + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + long vectorIndexOffset, + long vectorIndexLength, + int dimension) throws IOException { + this.fieldInfo = fieldInfo; + this.similarityFunction = similarityFunction; + this.vectorEncoding = vectorEncoding; + this.vectorIndexOffset = vectorIndexOffset; + this.vectorIndexLength = vectorIndexLength; + this.dimension = dimension; + // TODO: do not depend on the actual nio.Path switch to file name only! 
+ final Path expectedIndexFilePath = JVectorFormat.getVectorIndexPath(directoryBasePath, baseDataFileName, fieldInfo.name); + final String originalIndexFileName = expectedIndexFilePath.getFileName().toString(); + final Path indexFilePath; + final long sliceOffset; + if (state.segmentInfo.getUseCompoundFile()) { + if (directory instanceof JVectorCompoundFormat.JVectorCompoundReader) { + JVectorCompoundFormat.JVectorCompoundReader jVectorCompoundReader = (JVectorCompoundFormat.JVectorCompoundReader) directory; + sliceOffset = jVectorCompoundReader.getOffsetInCompoundFile(originalIndexFileName); + indexFilePath = jVectorCompoundReader.getCompoundFilePath(); + } else if (directory instanceof KNN80CompoundDirectory) { + KNN80CompoundDirectory knn80CompoundDirectory = (KNN80CompoundDirectory) directory; + JVectorCompoundFormat.JVectorCompoundReader jVectorCompoundReader = new JVectorCompoundFormat.JVectorCompoundReader( + knn80CompoundDirectory.getDelegate(), + knn80CompoundDirectory.getDir(), + state.segmentInfo, + state.context + ); + sliceOffset = jVectorCompoundReader.getOffsetInCompoundFile(originalIndexFileName); + indexFilePath = jVectorCompoundReader.getCompoundFilePath(); + } else { + throw new IllegalArgumentException("directory must be JVectorCompoundFormat or KNN80CompoundDirectory but instead had type: " + directory.getClass().getName()); + } + + } else { + sliceOffset = 0; + indexFilePath = expectedIndexFilePath; + } + + // Check the header + try (IndexInput indexInput = directory.openInput(originalIndexFileName, state.context)) { + CodecUtil.checkIndexHeader( + indexInput, + JVectorFormat.VECTOR_INDEX_CODEC_NAME, + JVectorFormat.VERSION_START, + JVectorFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + } + + // Load the graph index + this.readerSupplier = ReaderSupplierFactory.open(indexFilePath); + this.index = OnDiskGraphIndex.load(readerSupplier, sliceOffset + vectorIndexOffset); + + // Check the footer + try 
(ChecksumIndexInput indexInput = directory.openChecksumInput(originalIndexFileName, state.context)) { + indexInput.seek(vectorIndexOffset + vectorIndexLength); + CodecUtil.checkFooter(indexInput); + } + + } + } + + + /** + * Utility class to map between Lucene and jVector similarity functions and metadata ordinals. + */ + public static class VectorSimilarityMapper { + /** + List of vector similarity functions supported by jVector library + The similarity functions orders matter in this list because it is later used to resolve the similarity function by ordinal. + */ + public static final List JVECTOR_SUPPORTED_SIMILARITY_FUNCTIONS = + List.of( + VectorSimilarityFunction.EUCLIDEAN, + VectorSimilarityFunction.DOT_PRODUCT, + VectorSimilarityFunction.COSINE); + + public static final Map luceneToJVectorMap = Map.of( + org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN, + org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT, VectorSimilarityFunction.DOT_PRODUCT, + org.apache.lucene.index.VectorSimilarityFunction.COSINE, VectorSimilarityFunction.COSINE + ); + + public static int distFuncToOrd(org.apache.lucene.index.VectorSimilarityFunction func) { + if (luceneToJVectorMap.containsKey(func)) { + return JVECTOR_SUPPORTED_SIMILARITY_FUNCTIONS.indexOf(luceneToJVectorMap.get(func)); + } + + throw new IllegalArgumentException("invalid distance function: " + func); + } + + public static VectorSimilarityFunction ordToDistFunc(int ord) { + return JVECTOR_SUPPORTED_SIMILARITY_FUNCTIONS.get(ord); + } + } + + public static Path resolveDirectoryPath(Directory dir) { + while (!(dir instanceof FSDirectory)) { + final String dirType = dir.getClass().getName(); + log.info("unwrapping dir of type: {} to find path", dirType); + if (dir instanceof FilterDirectory) { + dir = ((FilterDirectory) dir).getDelegate(); + } else if (dir instanceof JVectorCompoundFormat.JVectorCompoundReader) { + return 
((JVectorCompoundFormat.JVectorCompoundReader) dir).getDirectoryBasePath(); + } else if (dir instanceof KNN80CompoundDirectory) { + dir = ((KNN80CompoundDirectory) dir).getDir(); + } else { + throw new IllegalArgumentException("directory must be FSDirectory or a wrapper around it but instead had type: " + dirType); + } + } + final Path path = ((FSDirectory) dir).getDirectory(); + log.info("resolved directory path from FSDirectory: {}", path); + return path; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorVectorScorer.java b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorVectorScorer.java new file mode 100644 index 0000000000..676c1b47d4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorVectorScorer.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.jvector; + +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; + +import java.io.IOException; + +public class JVectorVectorScorer implements VectorScorer { + private final JVectorFloatVectorValues floatVectorValues; + private final VectorFloat target; + private final VectorSimilarityFunction similarityFunction; + + public JVectorVectorScorer(JVectorFloatVectorValues vectorValues, VectorFloat target, VectorSimilarityFunction similarityFunction) { + this.floatVectorValues = vectorValues; + this.target = target; + this.similarityFunction = similarityFunction; + } + + @Override + public float score() throws IOException { + return similarityFunction.compare(target, floatVectorValues.vectorFloatValue()); + } + + @Override + public DocIdSetIterator iterator() { + return floatVectorValues; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorWriter.java 
b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorWriter.java new file mode 100644 index 0000000000..9eb67d69f6 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/jvector/JVectorWriter.java @@ -0,0 +1,376 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.jvector; + +import io.github.jbellis.jvector.graph.GraphIndexBuilder; +import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; +import io.github.jbellis.jvector.graph.OnHeapGraphIndex; +import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import io.github.jbellis.jvector.graph.disk.*; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.index.*; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.store.*; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.opensearch.common.collect.Tuple; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.*; + +@Log4j2 +public class JVectorWriter extends KnnVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = + RamUsageEstimator.shallowSizeOfInstance(JVectorWriter.class); + private final List> fields = new ArrayList<>(); + + private final IndexOutput meta; + private final IndexOutput vectorIndex; + private final String indexDataFileName; + private final String baseDataFileName; + private final Path directoryBasePath; + private final SegmentWriteState segmentWriteState; + private final int maxConn; + 
private final int beamWidth; + private final float degreeOverflow; + private final float alpha; + private boolean finished = false; + + + public JVectorWriter(SegmentWriteState segmentWriteState, int maxConn, int beamWidth, float degreeOverflow, float alpha) throws IOException { + this.segmentWriteState = segmentWriteState; + this.maxConn = maxConn; + this.beamWidth = beamWidth; + this.degreeOverflow = degreeOverflow; + this.alpha = alpha; + String metaFileName = + IndexFileNames.segmentFileName( + segmentWriteState.segmentInfo.name, segmentWriteState.segmentSuffix, JVectorFormat.META_EXTENSION); + + this.indexDataFileName = + IndexFileNames.segmentFileName( + segmentWriteState.segmentInfo.name, + segmentWriteState.segmentSuffix, + JVectorFormat.VECTOR_INDEX_EXTENSION); + this.baseDataFileName = segmentWriteState.segmentInfo.name + "_" + segmentWriteState.segmentSuffix; + + Directory dir = segmentWriteState.directory; + this.directoryBasePath = JVectorReader.resolveDirectoryPath(dir); + + boolean success = false; + try { + meta = segmentWriteState.directory.createOutput(metaFileName, segmentWriteState.context); + vectorIndex = segmentWriteState.directory.createOutput(indexDataFileName, segmentWriteState.context); + CodecUtil.writeIndexHeader( + meta, + JVectorFormat.META_CODEC_NAME, + JVectorFormat.VERSION_CURRENT, + segmentWriteState.segmentInfo.getId(), + segmentWriteState.segmentSuffix); + + CodecUtil.writeIndexHeader( + vectorIndex, + JVectorFormat.VECTOR_INDEX_CODEC_NAME, + JVectorFormat.VERSION_CURRENT, + segmentWriteState.segmentInfo.getId(), + segmentWriteState.segmentSuffix); + + success = true; + } finally { + if (!success) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + log.info("Adding field {} in segment {}", fieldInfo.name, segmentWriteState.segmentInfo.name); + if (fieldInfo.getVectorEncoding() == VectorEncoding.BYTE) { + final String 
errorMessage = "byte[] vectors are not supported in JVector. " + + "Instead you should only use float vectors and leverage product quantization during indexing." + + "This can provides much greater savings in storage and memory"; + log.error(errorMessage); + throw new UnsupportedOperationException(errorMessage); + } + JVectorWriter.FieldWriter newField = + new JVectorWriter.FieldWriter<>(fieldInfo, segmentWriteState.segmentInfo.name); + fields.add(newField); + return newField; + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + log.info("Merging field {} into segment {}", fieldInfo.name, segmentWriteState.segmentInfo.name); + var success = false; + try { + switch (fieldInfo.getVectorEncoding()) { + case BYTE: + var byteWriter = + (JVectorWriter.FieldWriter) addField(fieldInfo); + ByteVectorValues mergedBytes = + MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); + for (int doc = mergedBytes.nextDoc(); + doc != DocIdSetIterator.NO_MORE_DOCS; + doc = mergedBytes.nextDoc()) { + byteWriter.addValue(doc, mergedBytes.vectorValue()); + } + writeField(byteWriter); + break; + case FLOAT32: + var floatVectorFieldWriter = + (JVectorWriter.FieldWriter) addField(fieldInfo); + int baseDocId = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + FloatVectorValues floatVectorValues = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name); + var floatVectors = new ArrayList(); + for (int doc = floatVectorValues.nextDoc(); + doc != DocIdSetIterator.NO_MORE_DOCS; + doc = floatVectorValues.nextDoc()) { + floatVectors.add(floatVectorValues.vectorValue()); + } + for (int doc = 0; doc < floatVectors.size(); doc++) { + floatVectorFieldWriter.addValue(baseDocId + doc, floatVectors.get(doc)); + } + + baseDocId += floatVectorValues.size(); + } + writeField(floatVectorFieldWriter); + break; + } + success = true; + log.info("Completed Merge field {} into segment {}", fieldInfo.name, 
segmentWriteState.segmentInfo.name); + } finally { + if (success) { + //IOUtils.close(scorerSupplier); + } else { + //IOUtils.closeWhileHandlingException(scorerSupplier); + } + } + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + log.info("Flushing {} fields", fields.size()); + + for (JVectorWriter.FieldWriter field : fields) { + if (sortMap == null) { + writeField(field); + } else { + throw new UnsupportedOperationException("Not implemented yet"); + //writeSortingField(field, sortMap); + } + } + } + + private void writeField(JVectorWriter.FieldWriter fieldData) throws IOException { + // write graph + //long vectorIndexOffset = vectorIndex.getFilePointer(); + OnHeapGraphIndex graph = fieldData.getGraph(); + Tuple vectorIndexOffsetAndLength = writeGraph(graph, fieldData); + + + writeMeta( + fieldData.fieldInfo, + vectorIndexOffsetAndLength.v1(), // vectorIndexOffset + vectorIndexOffsetAndLength.v2() // vectorIndexLength); + ); + } + + + + private void writeMeta( + FieldInfo field, + long vectorIndexOffset, + long vectorIndexLength) + throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(JVectorReader.VectorSimilarityMapper.distFuncToOrd(field.getVectorSimilarityFunction())); + meta.writeVLong(vectorIndexOffset); + meta.writeVLong(vectorIndexLength); + meta.writeVInt(field.getVectorDimension()); + } + + + + /** + * Writes the graph to the vector index file + * @param graph graph + * @param fieldData fieldData + * @return Tuple of start offset and length of the graph + * @throws IOException IOException + */ + private Tuple writeGraph(OnHeapGraphIndex graph, FieldWriter fieldData) throws IOException { + // TODO: use the vector index inputStream instead of this! 
+ final Path jvecFilePath = JVectorFormat.getVectorIndexPath(directoryBasePath, baseDataFileName, fieldData.fieldInfo.name); + /** This is an ugly hack to make sure Lucene actually knows about our input stream files, otherwise it will delete them */ + IndexOutput indexOutput = segmentWriteState.directory.createOutput(jvecFilePath.getFileName().toString(), segmentWriteState.context); + CodecUtil.writeIndexHeader( + indexOutput, + JVectorFormat.VECTOR_INDEX_CODEC_NAME, + JVectorFormat.VERSION_CURRENT, + segmentWriteState.segmentInfo.getId(), + segmentWriteState.segmentSuffix); + final long startOffset = indexOutput.getFilePointer(); + indexOutput.close(); + /** End of ugly hack */ + + log.info("Writing graph to {}", jvecFilePath); + final Tuple result; + try (var writer = new OnDiskGraphIndexWriter.Builder(graph, jvecFilePath) + .with(new InlineVectors(fieldData.randomAccessVectorValues.dimension())) + .withStartOffset(startOffset) + .build()) { + var suppliers = Feature.singleStateFactory(FeatureId.INLINE_VECTORS, + nodeId -> new InlineVectors.State(fieldData.randomAccessVectorValues.getVector(nodeId))); + writer.write(suppliers); + long endOffset = writer.getOutput().position(); + result = new Tuple<>(startOffset, endOffset - startOffset); + // write footer by wrapping jVector RandomAccessOutput to IndexOutput object + // This mostly helps to interface with the existing Lucene CodecUtil + IndexOutput jvecIndexOutput = new JVectorIndexOutput(writer.getOutput()); + CodecUtil.writeFooter(jvecIndexOutput); + } + + + return result; + } + + @Override + public void finish() throws IOException { + log.info("Finishing segment {}", segmentWriteState.segmentInfo.name); + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + + if (vectorIndex != null) { + CodecUtil.writeFooter(vectorIndex); + } + } + + @Override + public void 
close() throws IOException { + IOUtils.close(meta, vectorIndex); + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (JVectorWriter.FieldWriter field : fields) { + // the field tracks the delegate field usage + total += field.ramBytesUsed(); + } + return total; + } + + class FieldWriter extends KnnFieldVectorsWriter { + private final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport(); + private final long SHALLOW_SIZE = + RamUsageEstimator.shallowSizeOfInstance(JVectorWriter.FieldWriter.class); + @Getter + private final FieldInfo fieldInfo; + private int lastDocID = -1; + private final GraphIndexBuilder graphIndexBuilder; + private final List> floatVectors = new ArrayList<>(); + private final String segmentName; + private final RandomAccessVectorValues randomAccessVectorValues; + private final BuildScoreProvider buildScoreProvider; + + + FieldWriter(FieldInfo fieldInfo, String segmentName) { + this.fieldInfo = fieldInfo; + this.segmentName = segmentName; + var originalDimension = fieldInfo.getVectorDimension(); + this.randomAccessVectorValues = new ListRandomAccessVectorValues(floatVectors, originalDimension); + this.buildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(randomAccessVectorValues, getVectorSimilarityFunction(fieldInfo)); + this.graphIndexBuilder = new GraphIndexBuilder(buildScoreProvider, + randomAccessVectorValues.dimension(), + maxConn, + beamWidth, + degreeOverflow, + alpha); + } + + @Override + public void addValue(int docID, T vectorValue) throws IOException { + log.debug("Adding value {} to field {} in segment {}", vectorValue, fieldInfo.name, segmentName); + if (docID == lastDocID) { + throw new IllegalArgumentException( + "VectorValuesField \"" + + fieldInfo.name + + "\" appears more than once in this document (only one value is allowed per field)"); + } + if (vectorValue instanceof float[]) { + var floats = (float[]) vectorValue; + var 
vector = VECTOR_TYPE_SUPPORT.createFloatVector(floats); + floatVectors.add(vector); + graphIndexBuilder.addGraphNode(docID, vector); + } else if (vectorValue instanceof byte[]) { + final String errorMessage = "byte[] vectors are not supported in JVector. " + + "Instead you should only use float vectors and leverage product quantization during indexing." + + "This can provides much greater savings in storage and memory"; + log.error(errorMessage); + throw new UnsupportedOperationException(errorMessage); + } else { + throw new IllegalArgumentException("Unsupported vector type: " + vectorValue.getClass()); + } + + + lastDocID = docID; + } + + @Override + public T copyValue(T vectorValue) { + throw new UnsupportedOperationException("copyValue not supported"); + } + + @Override + public long ramBytesUsed() { + return SHALLOW_SIZE + + graphIndexBuilder.getGraph().ramBytesUsed(); + } + + io.github.jbellis.jvector.vector.VectorSimilarityFunction getVectorSimilarityFunction(FieldInfo fieldInfo) { + log.info("Matching vector similarity function {} for field {}", fieldInfo.getVectorSimilarityFunction(), fieldInfo.name); + switch (fieldInfo.getVectorSimilarityFunction()) { + case EUCLIDEAN: + return io.github.jbellis.jvector.vector.VectorSimilarityFunction.EUCLIDEAN; + case COSINE: + return io.github.jbellis.jvector.vector.VectorSimilarityFunction.COSINE; + case DOT_PRODUCT: + return io.github.jbellis.jvector.vector.VectorSimilarityFunction.DOT_PRODUCT; + default: + throw new IllegalArgumentException("Unsupported similarity function: " + fieldInfo.getVectorSimilarityFunction()); + } + } + + /** + * This method will return the graph index for the field + * @return OnHeapGraphIndex + * @throws IOException IOException + */ + public OnHeapGraphIndex getGraph() throws IOException { + return graphIndexBuilder.getGraph(); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java index 
1e560a11ba..28b73892a7 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java @@ -11,6 +11,7 @@ import org.opensearch.knn.index.engine.faiss.Faiss; import org.opensearch.knn.index.engine.lucene.Lucene; import org.opensearch.knn.index.engine.nmslib.Nmslib; +import org.opensearch.knn.index.engine.jvector.JVector; import java.util.List; import java.util.Map; @@ -19,6 +20,7 @@ import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; +import static org.opensearch.knn.common.KNNConstants.JVECTOR_NAME; /** * KNNEngine provides the functionality to validate and transform user defined indices into information that can be @@ -27,7 +29,8 @@ public enum KNNEngine implements KNNLibrary { NMSLIB(NMSLIB_NAME, Nmslib.INSTANCE), FAISS(FAISS_NAME, Faiss.INSTANCE), - LUCENE(LUCENE_NAME, Lucene.INSTANCE); + LUCENE(LUCENE_NAME, Lucene.INSTANCE), + JVECTOR(JVECTOR_NAME, JVector.INSTANCE); public static final KNNEngine DEFAULT = FAISS; @@ -77,6 +80,10 @@ public static KNNEngine getEngine(String name) { return LUCENE; } + if (JVECTOR.getName().equalsIgnoreCase(name)) { + return JVECTOR; + } + throw new IllegalArgumentException(String.format("Invalid engine type: %s", name)); } diff --git a/src/main/java/org/opensearch/knn/index/engine/jvector/JVector.java b/src/main/java/org/opensearch/knn/index/engine/jvector/JVector.java new file mode 100644 index 0000000000..f4872906af --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/jvector/JVector.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.jvector; + +import com.google.common.collect.ImmutableMap; +import org.apache.lucene.util.Version; +import org.opensearch.knn.index.SpaceType; +import 
org.opensearch.knn.index.codec.jvector.JVectorFormat; +import org.opensearch.knn.index.engine.*; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import static org.opensearch.knn.common.KNNConstants.DISK_ANN; + +/** + * JVector engine is a JVM-based k-NN library that leverages the JVector library . + * This class provides specific functionalities for handling k-NN methods along with transformations + * for distance and scoring based on different space types. + * JVector unique advantage over other libraries are it's very lightweight, SIMD support and implementation of Disk ANN. + * + * JVector also manages method resolution during query execution and facilitates interaction with + * Lucene's k-NN framework by providing conversions between scores and distance thresholds. + * + * The class makes use of pre-defined static mappings for supported k-NN methods and distance translations, + * which determine how scores and distances should be interpreted and converted in various space types. + */ +public class JVector extends JVMLibrary { + private final static String CUSTOM_COMPOUND_FILE_EXTENSION = "cjvec"; + private final Map> distanceTransform; + private final MethodResolver methodResolver = new JVectorMethodResolver(); + + public JVector(Map methods, String version, Map> distanceTransform) { + super(methods, version); + this.distanceTransform = distanceTransform; + } + + final static Map METHODS = Map.of(DISK_ANN, new JVectorDiskANNMethod()); + + private final static Map> DISTANCE_TRANSLATIONS = ImmutableMap.< + SpaceType, + Function>builder() + .put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2) + .put(SpaceType.INNER_PRODUCT, distance -> distance <= 0 ? 
1 / (1 - distance) : distance + 1) + .build(); + + public final static JVector INSTANCE = new JVector(METHODS, Version.LATEST.toString(), DISTANCE_TRANSLATIONS); + + @Override + public String getExtension() { + return JVectorFormat.JVECTOR_FILES_SUFFIX; + } + + @Override + public String getCompoundExtension() { + return CUSTOM_COMPOUND_FILE_EXTENSION; + } + + @Override + public float score(float rawScore, SpaceType spaceType) { + return rawScore; + } + + @Override + public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { + return distance; + } + + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return score; + } + + // TODO: Implement this + @Override + public ResolvedMethodContext resolveMethod( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + boolean shouldRequireTraining, + SpaceType spaceType + ) { + return methodResolver.resolveMethod(knnMethodContext, knnMethodConfigContext, shouldRequireTraining, spaceType); + } + + // TODO: add actual file suffix there + @Override + public List mmapFileExtensions() { + return Collections.emptyList(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/engine/jvector/JVectorDiskANNMethod.java b/src/main/java/org/opensearch/knn/index/engine/jvector/JVectorDiskANNMethod.java new file mode 100644 index 0000000000..05d6cd396d --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/jvector/JVectorDiskANNMethod.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.jvector; + +import com.google.common.collect.ImmutableSet; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.*; + +import java.util.Arrays; +import java.util.List; +import java.util.Set; + +import static 
org.opensearch.knn.common.KNNConstants.*; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; + +public class JVectorDiskANNMethod extends AbstractKNNMethod { + + private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT); + + public final static List SUPPORTED_SPACES = Arrays.asList( + SpaceType.UNDEFINED, + SpaceType.L2, + SpaceType.L1, + SpaceType.LINF, + SpaceType.COSINESIMIL, + SpaceType.INNER_PRODUCT + ); + + final static MethodComponent HNSW_METHOD_COMPONENT = initMethodComponent(); + + private static MethodComponent initMethodComponent() { + return MethodComponent.Builder.builder(METHOD_HNSW) + .addSupportedDataTypes(SUPPORTED_DATA_TYPES) + .addParameter( + METHOD_PARAMETER_M, + new Parameter.IntegerParameter(METHOD_PARAMETER_M, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M, (v, context) -> v > 0) + ) + .addParameter( + METHOD_PARAMETER_EF_CONSTRUCTION, + new Parameter.IntegerParameter( + METHOD_PARAMETER_EF_CONSTRUCTION, + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION, + (v, context) -> v > 0 + ) + ) + .build(); + } + + public JVectorDiskANNMethod() { + super(HNSW_METHOD_COMPONENT, Set.copyOf(SUPPORTED_SPACES), new DefaultHnswSearchContext()); + } +} diff --git a/src/main/java/org/opensearch/knn/index/engine/jvector/JVectorMethodResolver.java b/src/main/java/org/opensearch/knn/index/engine/jvector/JVectorMethodResolver.java new file mode 100644 index 0000000000..994798f1aa --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/jvector/JVectorMethodResolver.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine.jvector; + +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.AbstractMethodResolver; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import 
org.opensearch.knn.index.engine.ResolvedMethodContext; +import org.opensearch.knn.index.mapper.CompressionLevel; + +public class JVectorMethodResolver extends AbstractMethodResolver { + + @Override + public ResolvedMethodContext resolveMethod( + KNNMethodContext knnMethodContext, + KNNMethodConfigContext knnMethodConfigContext, + boolean shouldRequireTraining, + SpaceType spaceType + ) { + return ResolvedMethodContext.builder().knnMethodContext(knnMethodContext).compressionLevel(CompressionLevel.x1).build(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index c17abc2e1e..62add9aefc 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -283,8 +283,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { originalParameters ); } - - if (originalParameters.getResolvedKnnMethodContext().getKnnEngine() == KNNEngine.LUCENE) { + final KNNEngine knnEngine = originalParameters.getResolvedKnnMethodContext().getKnnEngine(); + if (knnEngine == KNNEngine.LUCENE || knnEngine == KNNEngine.JVECTOR) { log.debug(String.format(Locale.ROOT, "Use [LuceneFieldMapper] mapper for field [%s]", name)); LuceneFieldMapper.CreateLuceneFieldMapperInput createLuceneFieldMapperInput = LuceneFieldMapper.CreateLuceneFieldMapperInput .builder() diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 7fb880f192..93192a8c26 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -161,7 +161,6 @@ public class KNNPlugin extends Plugin public static final String LEGACY_KNN_BASE_URI = "/_opendistro/_knn"; public static final String KNN_BASE_URI = "/_plugins/_knn"; - private KNNStats knnStats; private ClusterService 
clusterService; @Override @@ -212,8 +211,7 @@ public Collection createComponents( clusterService.addListener(TrainingJobClusterStateListener.getInstance()); - knnStats = new KNNStats(); - return ImmutableList.of(knnStats); + return ImmutableList.of(new KNNStats()); } @Override diff --git a/src/main/plugin-metadata/plugin-security.policy b/src/main/plugin-metadata/plugin-security.policy index ed329740f5..a4692a1cb1 100644 --- a/src/main/plugin-metadata/plugin-security.policy +++ b/src/main/plugin-metadata/plugin-security.policy @@ -4,7 +4,13 @@ grant { permission java.lang.RuntimePermission "loadLibrary.opensearchknn_common"; permission java.lang.RuntimePermission "loadLibrary.opensearchknn_faiss_avx2"; permission java.lang.RuntimePermission "loadLibrary.opensearchknn_faiss_avx512"; + permission java.lang.RuntimePermission "modifyThreadGroup"; + permission java.lang.RuntimePermission "modifyThread"; + permission java.lang.RuntimePermission "getStackTrace"; + permission java.lang.RuntimePermission "setContextClassLoader"; permission java.net.SocketPermission "*", "connect,resolve"; + permission java.util.PropertyPermission "*", "write"; + permission java.util.PropertyPermission "*", "read"; permission java.lang.RuntimePermission "accessDeclaredMembers"; permission java.io.FilePermission "/proc/cpuinfo", "read"; }; diff --git a/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec b/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec index 7a8916981e..24ebc00e68 100644 --- a/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec +++ b/src/main/resources/META-INF/services/org.apache.lucene.codecs.Codec @@ -8,4 +8,5 @@ org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec org.opensearch.knn.index.codec.KNN950Codec.KNN950Codec org.opensearch.knn.index.codec.KNN990Codec.KNN990Codec org.opensearch.knn.index.codec.KNN9120Codec.KNN9120Codec +org.opensearch.knn.index.codec.jvector.JVectorCodec 
org.opensearch.knn.index.codec.KNN990Codec.UnitTestCodec diff --git a/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 0fa7314c85..8e078a8318 100644 --- a/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -11,3 +11,4 @@ org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat org.opensearch.knn.index.codec.KNN9120Codec.KNN9120HnswBinaryVectorsFormat +org.opensearch.knn.index.codec.jvector.JVectorFormat diff --git a/src/test/java/org/opensearch/knn/index/JVectorEngineIT.java b/src/test/java/org/opensearch/knn/index/JVectorEngineIT.java new file mode 100644 index 0000000000..bb68d3797c --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/JVectorEngineIT.java @@ -0,0 +1,453 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; +import org.apache.commons.lang.math.RandomUtils; +import org.apache.hc.core5.http.ParseException; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.After; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.KNNResult; +import org.opensearch.knn.TestUtils; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.KNNQueryBuilder; 
+import org.opensearch.test.OpenSearchIntegTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.opensearch.index.engine.EngineConfig.INDEX_USE_COMPOUND_FILE; +import static org.opensearch.knn.common.KNNConstants.*; + +@OpenSearchIntegTestCase.ClusterScope(numDataNodes = 1) +public class JVectorEngineIT extends KNNRestTestCase { + + private static final int DIMENSION = 3; + private static final String DOC_ID = "doc1"; + private static final String DOC_ID_2 = "doc2"; + private static final String DOC_ID_3 = "doc3"; + private static final int EF_CONSTRUCTION = 128; + private static final String COLOR_FIELD_NAME = "color"; + private static final String TASTE_FIELD_NAME = "taste"; + private static final int M = 16; + + private static final Float[][] TEST_INDEX_VECTORS = { { 1.0f, 1.0f, 1.0f }, { 2.0f, 2.0f, 2.0f }, { 3.0f, 3.0f, 3.0f } }; + private static final Float[][] TEST_COSINESIMIL_INDEX_VECTORS = { { 6.0f, 7.0f, 3.0f }, { 3.0f, 2.0f, 5.0f }, { 4.0f, 5.0f, 7.0f } }; + private static final Float[][] TEST_INNER_PRODUCT_INDEX_VECTORS = { + { 1.0f, 1.0f, 1.0f }, + { 2.0f, 2.0f, 2.0f }, + { 3.0f, 3.0f, 3.0f }, + { -1.0f, -1.0f, -1.0f }, + { -2.0f, -2.0f, -2.0f }, + { -3.0f, -3.0f, -3.0f } }; + + private static final float[][] TEST_QUERY_VECTORS = { { 1.0f, 1.0f, 1.0f }, { 2.0f, 2.0f, 2.0f }, { 3.0f, 3.0f, 3.0f } }; + + private static final Map> VECTOR_SIMILARITY_TO_SCORE = ImmutableMap.of( + KNNVectorSimilarityFunction.EUCLIDEAN, + (similarity) -> 1 / (1 + similarity), + KNNVectorSimilarityFunction.DOT_PRODUCT, + (similarity) -> (1 + similarity) / 2, + KNNVectorSimilarityFunction.COSINE, + (similarity) -> (1 + similarity) / 2, + KNNVectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, + (similarity) -> similarity <= 0 ? 
1 / (1 - similarity) : similarity + 1 + ); + private static final String DIMENSION_FIELD_NAME = "dimension"; + private static final String KNN_VECTOR_TYPE = "knn_vector"; + private static final String PROPERTIES_FIELD_NAME = "properties"; + private static final String TYPE_FIELD_NAME = "type"; + private static final String INTEGER_FIELD_NAME = "int_field"; + private static final String FILED_TYPE_INTEGER = "integer"; + private static final String NON_EXISTENT_INTEGER_FIELD_NAME = "nonexistent_int_field"; + + @After + public final void cleanUp() throws IOException { + deleteKNNIndex(INDEX_NAME); + } + + public void testQuery_l2() throws Exception { + baseQueryTest(SpaceType.L2); + } + + public void testQuery_cosine() throws Exception { + baseQueryTest(SpaceType.COSINESIMIL); + } + + public void testQuery_invalidVectorDimensionInQuery() throws Exception { + + createKnnIndexMappingWithJVectorEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); + for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); + } + + float[] invalidQuery = new float[DIMENSION - 1]; + int validK = 1; + expectThrows( + ResponseException.class, + () -> searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, invalidQuery, validK), validK) + ); + } + + public void testQuery_documentsMissingField() throws Exception { + + SpaceType spaceType = SpaceType.L2; + + createKnnIndexMappingWithJVectorEngine(DIMENSION, spaceType, VectorDataType.FLOAT); + for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); + } + + // Add a doc without the lucene field set + String secondField = "field-2"; + addDocWithNumericField(INDEX_NAME, Integer.toString(TEST_INDEX_VECTORS.length + 1), secondField, 0L); + + validateQueries(spaceType, FIELD_NAME); + } + + public void testAddDoc() throws Exception { + List mValues = ImmutableList.of(16, 32, 64, 128); + 
List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD_NAME) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, DIMENSION) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, DISK_ANN) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.JVECTOR.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + Settings indexSettings = getDefaultIndexSettings(); + /* For now we will disable compound file to make it easier to read the index files */ + // TODO: Enable compound file once we have a way to read it with jVector files + indexSettings = Settings.builder().put(indexSettings).put(INDEX_USE_COMPOUND_FILE.getKey(), false).build(); + + createKnnIndex(INDEX_NAME, indexSettings, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(INDEX_NAME))); + + Float[] vector = new Float[] { 2.0f, 4.5f, 6.5f }; + addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); + + refreshIndex(INDEX_NAME); + assertEquals(1, getDocCount(INDEX_NAME)); + } + + public void testUpdateDoc() throws Exception { + createKnnIndexMappingWithJVectorEngine(2, SpaceType.L2, VectorDataType.FLOAT); + Float[] vector = { 6.0f, 6.0f }; + addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); + + Float[] updatedVector = { 8.0f, 8.0f }; + updateKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, updatedVector); + + refreshIndex(INDEX_NAME); + assertEquals(1, getDocCount(INDEX_NAME)); 
+ } + + public void testDeleteDoc() throws Exception { + createKnnIndexMappingWithJVectorEngine(2, SpaceType.L2, VectorDataType.FLOAT); + Float[] vector = { 6.0f, 6.0f }; + addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); + + deleteKnnDoc(INDEX_NAME, DOC_ID); + + refreshIndex(INDEX_NAME); + assertEquals(0, getDocCount(INDEX_NAME)); + } + + public void testQueryWithFilterUsingFloatVectorDataType() throws Exception { + createKnnIndexMappingWithJVectorEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); + + addKnnDocWithAttributes( + DOC_ID, + new float[] { 6.0f, 7.9f, 3.1f }, + ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet") + ); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + + refreshIndex(INDEX_NAME); + + final float[] searchVector = { 6.0f, 6.0f, 4.1f }; + List expectedDocIdsKGreaterThanFilterResult = Arrays.asList(DOC_ID, DOC_ID_3); + List expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID); + validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult); + } + + @SneakyThrows + public void testQueryWithFilterMultipleShards() { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD_NAME) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, DIMENSION) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, DISK_ANN) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.JVECTOR.getName()) + .endObject() + .endObject() + .startObject(INTEGER_FIELD_NAME) + .field(TYPE_FIELD_NAME, FILED_TYPE_INTEGER) + .endObject() + .endObject() + .endObject(); + String mapping = builder.toString(); + + createIndex(INDEX_NAME, 
Settings.builder().put("number_of_shards", 10).put("number_of_replicas", 0).put("index.knn", true).build()); + putMappingRequest(INDEX_NAME, mapping); + + addKnnDocWithAttributes("doc1", new float[] { 7.0f, 7.0f, 3.0f }, ImmutableMap.of("dateReceived", "2024-10-01")); + + refreshIndex(INDEX_NAME); + + final float[] searchVector = { 6.0f, 7.0f, 3.0f }; + final Response response = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder( + FIELD_NAME, + searchVector, + 1, + QueryBuilders.boolQuery().must(QueryBuilders.rangeQuery("dateReceived").gte("2023-11-01")) + ), + 10 + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(1, knnResults.size()); + } + + @SneakyThrows + public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful() { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD_NAME) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, DIMENSION) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, DISK_ANN) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.JVECTOR.getName()) + .endObject() + .endObject() + .startObject(INTEGER_FIELD_NAME) + .field(TYPE_FIELD_NAME, FILED_TYPE_INTEGER) + .endObject() + .endObject() + .endObject(); + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(INDEX_NAME, mapping); + + Float[] vector = new Float[] { 2.0f, 4.5f, 6.5f }; + + String documentAsString = XContentFactory.jsonBuilder() + .startObject() + .field(INTEGER_FIELD_NAME, 5) + .field(FIELD_NAME, vector) + .endObject() + .toString(); + + addKnnDoc(INDEX_NAME, DOC_ID, documentAsString); + + refreshIndex(INDEX_NAME); + assertEquals(1, getDocCount(INDEX_NAME)); + + float[] searchVector = new float[] { 1.0f, 
2.1f, 3.9f }; + int k = 10; + + // use filter where nonexistent field is must, we should have no results + QueryBuilder filterWithRequiredNonExistentField = QueryBuilders.boolQuery() + .must(QueryBuilders.rangeQuery(NON_EXISTENT_INTEGER_FIELD_NAME).gte(1)); + Response searchWithRequiredNonExistentFiledInFilterResponse = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, k, filterWithRequiredNonExistentField), + k + ); + List resultsQuery1 = parseSearchResponse( + EntityUtils.toString(searchWithRequiredNonExistentFiledInFilterResponse.getEntity()), + FIELD_NAME + ); + assertTrue(resultsQuery1.isEmpty()); + + // use filter with non existent field as optional, we should have some results + QueryBuilder filterWithOptionalNonExistentField = QueryBuilders.boolQuery() + .should(QueryBuilders.rangeQuery(NON_EXISTENT_INTEGER_FIELD_NAME).gte(1)) + .must(QueryBuilders.rangeQuery(INTEGER_FIELD_NAME).gte(1)); + Response searchWithOptionalNonExistentFiledInFilterResponse = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, k, filterWithOptionalNonExistentField), + k + ); + List resultsQuery2 = parseSearchResponse( + EntityUtils.toString(searchWithOptionalNonExistentFiledInFilterResponse.getEntity()), + FIELD_NAME + ); + assertEquals(1, resultsQuery2.size()); + } + + public void testIndexReopening() throws Exception { + createKnnIndexMappingWithJVectorEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); + + for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); + } + + final float[] searchVector = TEST_QUERY_VECTORS[0]; + final int k = 1 + RandomUtils.nextInt(TEST_INDEX_VECTORS.length); + + final List knnResultsBeforeIndexClosure = queryResults(searchVector, k); + + closeIndex(INDEX_NAME); + openIndex(INDEX_NAME); + + ensureGreen(INDEX_NAME); + + final List knnResultsAfterIndexClosure = queryResults(searchVector, k); + + 
assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray()); + } + + private List queryResults(final float[] searchVector, final int k) throws Exception { + final String responseBody = EntityUtils.toString( + searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, searchVector, k), k).getEntity() + ); + final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + assertNotNull(knnResults); + return knnResults.stream().map(KNNResult::getVector).collect(Collectors.toUnmodifiableList()); + } + + private void validateQueryResultsWithFilters( + float[] searchVector, + int kGreaterThanFilterResult, + int kLimitsFilterResult, + List expectedDocIdsKGreaterThanFilterResult, + List expectedDocIdsKLimitsFilterResult + ) throws IOException, ParseException { + final Response response = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kGreaterThanFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kGreaterThanFilterResult + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(expectedDocIdsKGreaterThanFilterResult.size(), knnResults.size()); + assertTrue( + knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIdsKGreaterThanFilterResult) + ); + + final Response responseKLimitsFilterResult = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, searchVector, kLimitsFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")), + kLimitsFilterResult + ); + final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity()); + final List knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME); + + assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size()); + assertTrue( + 
knnResultsKLimitsFilterResult.stream() + .map(KNNResult::getDocId) + .collect(Collectors.toList()) + .containsAll(expectedDocIdsKLimitsFilterResult) + ); + } + + private void createKnnIndexMappingWithJVectorEngine(int dimension, SpaceType spaceType, VectorDataType vectorDataType) throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD_NAME) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .field(VECTOR_DATA_TYPE_FIELD, vectorDataType) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, DISK_ANN) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.JVECTOR.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, M) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, EF_CONSTRUCTION) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + String mapping = builder.toString(); + Settings indexSettings = getDefaultIndexSettings(); + //indexSettings = Settings.builder().put(indexSettings).put(INDEX_USE_COMPOUND_FILE.getKey(), false).build(); + createKnnIndex(INDEX_NAME, indexSettings, mapping); + } + + private void baseQueryTest(SpaceType spaceType) throws Exception { + + createKnnIndexMappingWithJVectorEngine(DIMENSION, spaceType, VectorDataType.FLOAT); + for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); + } + + validateQueries(spaceType, FIELD_NAME); + validateQueries(spaceType, FIELD_NAME, Map.of("ef_search", 100)); + } + + private void validateQueries(SpaceType spaceType, String fieldName) throws Exception { + validateQueries(spaceType, fieldName, null); + } + + private void validateQueries(SpaceType spaceType, String fieldName, Map methodParameters) throws Exception { + + int k = 
JVectorEngineIT.TEST_INDEX_VECTORS.length; + for (float[] queryVector : TEST_QUERY_VECTORS) { + Response response = searchKNNIndex(INDEX_NAME, buildSearchQuery(fieldName, k, queryVector, methodParameters), k); + String responseBody = EntityUtils.toString(response.getEntity()); + List knnResults = parseSearchResponse(responseBody, fieldName); + assertEquals(k, knnResults.size()); + + List actualScores = parseSearchResponseScore(responseBody, fieldName); + for (int j = 0; j < k; j++) { + float[] primitiveArray = knnResults.get(j).getVector(); + float distance = TestUtils.computeDistFromSpaceType(spaceType, primitiveArray, queryVector); + float rawScore = VECTOR_SIMILARITY_TO_SCORE.get(spaceType.getKnnVectorSimilarityFunction()).apply(distance); + assertEquals(KNNEngine.LUCENE.score(rawScore, spaceType), actualScores.get(j), 0.0001); + } + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/ForceMergesOnlyMergePolicy.java b/src/test/java/org/opensearch/knn/index/codec/ForceMergesOnlyMergePolicy.java new file mode 100644 index 0000000000..eb460e5590 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/ForceMergesOnlyMergePolicy.java @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec; + +import org.apache.lucene.index.*; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class ForceMergesOnlyMergePolicy extends MergePolicy { + private final boolean useCompoundFile; + + public ForceMergesOnlyMergePolicy() { + this(false); + } + + public ForceMergesOnlyMergePolicy(boolean useCompoundFile) { + super(); + this.useCompoundFile = useCompoundFile; + } + + @Override + public MergeSpecification findMerges(MergeTrigger mergeTrigger, SegmentInfos segmentInfos, MergeContext mergeContext) throws IOException { + return null; + } + + @Override + public MergeSpecification findForcedMerges(SegmentInfos segmentInfos, int 
maxSegmentCount, Map<SegmentCommitInfo, Boolean> segmentsToMerge, MergeContext mergeContext) throws IOException { + // If the segments are already merged (e.g. there's only 1 segment), or + // there are <= maxSegmentCount segments, return null to indicate no merge is needed + if (isMerged(segmentInfos, maxSegmentCount, segmentsToMerge, mergeContext)) { + return null; + } + + List<SegmentCommitInfo> segments = segmentInfos.asList(); + MergeSpecification spec = new MergeSpecification(); + + final OneMerge merge = new OneMerge(segments); + spec.add(merge); + return spec; + } + + @Override + public boolean useCompoundFile(SegmentInfos segmentInfos, SegmentCommitInfo newSegment, MergeContext mergeContext) throws IOException { + return useCompoundFile; + } + + @Override + public MergeSpecification findForcedDeletesMerges(SegmentInfos segmentInfos, MergeContext mergeContext) throws IOException { + return null; + } + + /** + * Returns true if the number of segments eligible for merging is less than or equal to the + * specified {@code maxNumSegments}. + */ + protected boolean isMerged( + SegmentInfos infos, + int maxNumSegments, + Map<SegmentCommitInfo, Boolean> segmentsToMerge, + MergeContext mergeContext) + throws IOException { + final int numSegments = infos.size(); + int numToMerge = 0; + SegmentCommitInfo mergeInfo = null; + boolean segmentIsOriginal = false; + for (int i = 0; i < numSegments && numToMerge <= maxNumSegments; i++) { + final SegmentCommitInfo info = infos.info(i); + final Boolean isOriginal = segmentsToMerge.get(info); + if (isOriginal != null) { + segmentIsOriginal = isOriginal; + numToMerge++; + mergeInfo = info; + } + } + + return numToMerge <= maxNumSegments + && (numToMerge != 1 || !segmentIsOriginal || isMerged(infos, mergeInfo, mergeContext)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/jvector/KNNJVectorTests.java b/src/test/java/org/opensearch/knn/index/codec/jvector/KNNJVectorTests.java new file mode 100644 index 0000000000..c383a5bc1a --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/jvector/KNNJVectorTests.java @@ -0,0 +1,395 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.knn.index.codec.jvector; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.document.*; +import org.apache.lucene.index.*; +import org.apache.lucene.search.*; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.store.FSLockFactory; +import org.apache.lucene.store.NIOFSDirectory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.knn.index.codec.ForceMergesOnlyMergePolicy; + +import java.io.IOException; +import java.nio.file.Path; + +/** + * Test used specifically for JVector + */ +// Currently {@link IndexGraphBuilder} is using the default ForkJoinPool.commonPool() which is not being shutdown. +// Ignore thread leaks until we remove the ForkJoinPool.commonPool() usage from IndexGraphBuilder +// TODO: Wire the execution thread pool to {@link IndexGraphBuilder} to avoid the failure of the UT due to leaked thread pool warning. +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +@LuceneTestCase.SuppressSysoutChecks(bugUrl = "") +@Log4j2 +public class KNNJVectorTests extends LuceneTestCase { + + /** + * Test to verify that the JVector codec is able to successfully search for the nearest neighbours + * in the index. + * Single field is used to store the vectors. + * All the documents are stored in a single segment. + * Single commit without refreshing the index. + * No merge. 
+ */ + @Test + public void testJVectorKnnIndex_simpleCase() throws IOException { + int k = 3; // The number of nearest neighbours to gather + int totalNumberOfDocs = 10; + IndexWriterConfig indexWriterConfig = LuceneTestCase.newIndexWriterConfig(); + // TODO: re-enable this after fixing the compound file augmentation for JVector + indexWriterConfig.setUseCompoundFile(false); + indexWriterConfig.setCodec(new JVectorCodec()); + indexWriterConfig.setMergePolicy(new ForceMergesOnlyMergePolicy()); + final Path indexPath = createTempDir(); + log.info("Index path: {}", indexPath); + try (Directory dir = newFSDirectory(indexPath); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, indexWriterConfig)) { + final float[] target = new float[] { 0.0f, 0.0f }; + for (int i = 1; i < totalNumberOfDocs + 1; i++) { + final float[] source = new float[] { 0.0f, 1.0f / i }; + final Document doc = new Document(); + doc.add(new KnnFloatVectorField("test_field", source, VectorSimilarityFunction.EUCLIDEAN)); + w.addDocument(doc); + } + log.info("Flushing docs to make them discoverable on the file system"); + w.commit(); + + + try (IndexReader reader = w.getReader()) { + log.info("We should now have a single segment with 10 documents"); + Assert.assertEquals(1, reader.getContext().leaves().size()); + Assert.assertEquals(totalNumberOfDocs, reader.numDocs()); + + final Query filterQuery = new MatchAllDocsQuery(); + final IndexSearcher searcher = newSearcher(reader); + KnnFloatVectorQuery knnFloatVectorQuery = new KnnFloatVectorQuery("test_field", target, k, filterQuery); + TopDocs topDocs = searcher.search(knnFloatVectorQuery, k); + assertEquals(k, topDocs.totalHits.value); + assertEquals(9, topDocs.scoreDocs[0].doc); + Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(target, new float[]{0.0f, 1.0f / 10.0f}), topDocs.scoreDocs[0].score, 0.001f); + assertEquals(8, topDocs.scoreDocs[1].doc); + Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(target, new 
float[]{0.0f, 1.0f / 9.0f}), topDocs.scoreDocs[1].score, 0.001f); + assertEquals(7, topDocs.scoreDocs[2].doc); + Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(target, new float[]{0.0f, 1.0f / 8.0f}), topDocs.scoreDocs[2].score, 0.001f); + log.info("successfully completed search tests"); + } + } + } + + /** + * Test to verify that the JVector codec is able to successfully search for the nearest neighbours + * in the index. + * Single field is used to store the vectors. + * Documents are stored in a multiple segments. + * Multiple commits without refreshing the index. + * No merge. + */ + @Test + public void testJVectorKnnIndex_multipleSegments() throws IOException { + int k = 3; // The number of nearest neighbours to gather + int totalNumberOfDocs = 10; + IndexWriterConfig indexWriterConfig = LuceneTestCase.newIndexWriterConfig(); + // TODO: re-enable this after fixing the compound file augmentation for JVector + indexWriterConfig.setUseCompoundFile(false); + indexWriterConfig.setCodec(new JVectorCodec()); + indexWriterConfig.setMergePolicy(new ForceMergesOnlyMergePolicy(false)); + final Path indexPath = createTempDir(); + log.info("Index path: {}", indexPath); + try (FSDirectory dir = new NIOFSDirectory(indexPath, FSLockFactory.getDefault()); + IndexWriter w = new IndexWriter(dir, indexWriterConfig)) { + final float[] target = new float[] { 0.0f, 0.0f }; + for (int i = 1; i < totalNumberOfDocs + 1; i++) { + final float[] source = new float[] { 0.0f, 1.0f / i }; + final Document doc = new Document(); + doc.add(new KnnFloatVectorField("test_field", source, VectorSimilarityFunction.EUCLIDEAN)); + w.addDocument(doc); + w.commit(); // this creates a new segment + } + log.info("Done writing all files to the file system"); + + try (IndexReader reader = DirectoryReader.open(w)) { + log.info("We should now have 10 segments, each with a single document"); + Assert.assertEquals(10, reader.getContext().leaves().size()); + Assert.assertEquals(totalNumberOfDocs, 
reader.numDocs()); + final Query filterQuery = new MatchAllDocsQuery(); + final IndexSearcher searcher = newSearcher(reader); + KnnFloatVectorQuery knnFloatVectorQuery = new KnnFloatVectorQuery("test_field", target, k, filterQuery); + TopDocs topDocs = searcher.search(knnFloatVectorQuery, k); + assertEquals(k, topDocs.totalHits.value); + assertEquals(9, topDocs.scoreDocs[0].doc); + Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(target, new float[]{0.0f, 1.0f / 10.0f}), topDocs.scoreDocs[0].score, 0.001f); + assertEquals(8, topDocs.scoreDocs[1].doc); + Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(target, new float[]{0.0f, 1.0f / 9.0f}), topDocs.scoreDocs[1].score, 0.001f); + assertEquals(7, topDocs.scoreDocs[2].doc); + Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(target, new float[]{0.0f, 1.0f / 8.0f}), topDocs.scoreDocs[2].score, 0.001f); + log.info("successfully completed search tests"); + } + } + } + + /** + * Test to verify that the JVector codec is able to successfully search for the nearest neighbours + * in the index. + * Single field is used to store the vectors. + * Documents are stored in a multiple segments. + * Multiple commits without refreshing the index. + * Merge is enabled. 
+ */ + @Test + public void testJVectorKnnIndex_mergeEnabled() throws IOException { + int k = 3; // The number of nearest neighbours to gather + int totalNumberOfDocs = 10; + IndexWriterConfig indexWriterConfig = LuceneTestCase.newIndexWriterConfig(); + // TODO: re-enable this after fixing the compound file augmentation for JVector + indexWriterConfig.setUseCompoundFile(false); + indexWriterConfig.setCodec(new JVectorCodec()); + indexWriterConfig.setMergePolicy(new ForceMergesOnlyMergePolicy()); + indexWriterConfig.setMergeScheduler(new SerialMergeScheduler()); + final Path indexPath = createTempDir(); + log.info("Index path: {}", indexPath); + try (FSDirectory dir = new NIOFSDirectory(indexPath, FSLockFactory.getDefault()); + IndexWriter w = new IndexWriter(dir, indexWriterConfig)) { + final float[] target = new float[]{0.0f, 0.0f}; + for (int i = 1; i < totalNumberOfDocs + 1; i++) { + final float[] source = new float[]{0.0f, 1.0f / i}; + final Document doc = new Document(); + doc.add(new KnnFloatVectorField("test_field", source, VectorSimilarityFunction.EUCLIDEAN)); + doc.add(new StringField("my_doc_id", Integer.toString(i, 10), Field.Store.YES)); + w.addDocument(doc); + w.commit(); // this creates a new segment without triggering a merge + } + log.info("Done writing all files to the file system"); + + w.forceMerge(1); // this merges all segments into a single segment + log.info("Done merging all segments"); + try (IndexReader reader = DirectoryReader.open(w)) { + log.info("We should now have 1 segment with 10 documents"); + Assert.assertEquals(1, reader.getContext().leaves().size()); + Assert.assertEquals(totalNumberOfDocs, reader.numDocs()); + final Query filterQuery = new MatchAllDocsQuery(); + final IndexSearcher searcher = newSearcher(reader); + KnnFloatVectorQuery knnFloatVectorQuery = new KnnFloatVectorQuery("test_field", target, k, filterQuery); + TopDocs topDocs = searcher.search(knnFloatVectorQuery, k); + assertEquals(k, topDocs.totalHits.value); + 
Document doc = reader.document(topDocs.scoreDocs[0].doc); + assertEquals("10", doc.get("my_doc_id")); + Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(target, new float[]{0.0f, 1.0f / 10.0f}), topDocs.scoreDocs[0].score, 0.001f); + doc = reader.document(topDocs.scoreDocs[1].doc); + assertEquals("9", doc.get("my_doc_id")); + Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(target, new float[]{0.0f, 1.0f / 9.0f}), topDocs.scoreDocs[1].score, 0.001f); + doc = reader.document(topDocs.scoreDocs[2].doc); + assertEquals("8", doc.get("my_doc_id")); + Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(target, new float[]{0.0f, 1.0f / 8.0f}), topDocs.scoreDocs[2].score, 0.001f); + log.info("successfully completed search tests"); + } + } + } + + /** + * Test to verify that the Lucene codec is able to successfully search for the nearest neighbours + * in the index. + * Single field is used to store the vectors. + * Documents are stored in potentially multiple segments. + * Multiple commits. + * Multiple merges. 
+ */ + @Test + public void testLuceneKnnIndex_multipleMerges() throws IOException { + int k = 3; // The number of nearest neighbours to gather + int totalNumberOfDocs = 10; + IndexWriterConfig indexWriterConfig = LuceneTestCase.newIndexWriterConfig(); + // TODO: re-enable this after fixing the compound file augmentation for JVector + indexWriterConfig.setUseCompoundFile(false); + indexWriterConfig.setCodec(new JVectorCodec()); + indexWriterConfig.setMergePolicy(new ForceMergesOnlyMergePolicy()); + indexWriterConfig.setMergeScheduler(new SerialMergeScheduler()); + final Path indexPath = createTempDir(); + log.info("Index path: {}", indexPath); + try (FSDirectory dir = new NIOFSDirectory(indexPath, FSLockFactory.getDefault()); + IndexWriter w = new IndexWriter(dir, indexWriterConfig)) { + final float[] target = new float[]{0.0f, 0.0f}; + for (int i = 1; i < totalNumberOfDocs + 1; i++) { + final float[] source = new float[]{0.0f, 1.0f / i}; + final Document doc = new Document(); + doc.add(new KnnFloatVectorField("test_field", source, VectorSimilarityFunction.EUCLIDEAN)); + doc.add(new StringField("my_doc_id", Integer.toString(i, 10), Field.Store.YES)); + w.addDocument(doc); + w.commit(); // this creates a new segment without triggering a merge + w.forceMerge(1); // this merges all segments into a single segment + } + log.info("Done writing all files to the file system"); + + w.forceMerge(1); // this merges all segments into a single segment + log.info("Done merging all segments"); + try (IndexReader reader = DirectoryReader.open(w)) { + log.info("We should now have 1 segment with 10 documents"); + Assert.assertEquals(1, reader.getContext().leaves().size()); + Assert.assertEquals(totalNumberOfDocs, reader.numDocs()); + final Query filterQuery = new MatchAllDocsQuery(); + final IndexSearcher searcher = newSearcher(reader); + KnnFloatVectorQuery knnFloatVectorQuery = new KnnFloatVectorQuery("test_field", target, k, filterQuery); + TopDocs topDocs = 
searcher.search(knnFloatVectorQuery, k); + assertEquals(k, topDocs.totalHits.value); + Document doc = reader.document(topDocs.scoreDocs[0].doc); + assertEquals("10", doc.get("my_doc_id")); + Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(target, new float[]{0.0f, 1.0f / 10.0f}), topDocs.scoreDocs[0].score, 0.001f); + doc = reader.document(topDocs.scoreDocs[1].doc); + assertEquals("9", doc.get("my_doc_id")); + Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(target, new float[]{0.0f, 1.0f / 9.0f}), topDocs.scoreDocs[1].score, 0.001f); + doc = reader.document(topDocs.scoreDocs[2].doc); + assertEquals("8", doc.get("my_doc_id")); + Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(target, new float[]{0.0f, 1.0f / 8.0f}), topDocs.scoreDocs[2].score, 0.001f); + log.info("successfully completed search tests"); + } + } + } + + /** + * Test to verify that the Lucene codec is able to successfully search for the nearest neighbours + * in the index. + * Single field is used to store the vectors. + * Documents are stored in potentially multiple segments. + * Multiple commits. + * Multiple merges. + * Merge is enabled. + * compound file is enabled. 
+ */ + @Test + public void testLuceneKnnIndex_mergeEnabled_withCompoundFile() throws IOException { + int k = 3; // The number of nearest neighbours to gather + int totalNumberOfDocs = 10; + IndexWriterConfig indexWriterConfig = LuceneTestCase.newIndexWriterConfig(); + // TODO: re-enable this after fixing the compound file augmentation for JVector + indexWriterConfig.setUseCompoundFile(true); + indexWriterConfig.setCodec(new JVectorCodec()); + indexWriterConfig.setMergePolicy(new ForceMergesOnlyMergePolicy(true)); + indexWriterConfig.setMergeScheduler(new SerialMergeScheduler()); + final Path indexPath = createTempDir(); + log.info("Index path: {}", indexPath); + try (FSDirectory dir = new NIOFSDirectory(indexPath, FSLockFactory.getDefault()); + IndexWriter w = new IndexWriter(dir, indexWriterConfig)) { + final float[] target = new float[]{0.0f, 0.0f}; + for (int i = 1; i < totalNumberOfDocs + 1; i++) { + final float[] source = new float[]{0.0f, 1.0f / i}; + final Document doc = new Document(); + doc.add(new KnnFloatVectorField("test_field", source, VectorSimilarityFunction.EUCLIDEAN)); + w.addDocument(doc); + w.flush(); // this creates a new segment without triggering a merge + } + log.info("Done writing all files to the file system"); + + w.forceMerge(1); // this merges all segments into a single segment + log.info("Done merging all segments"); + try (IndexReader reader = DirectoryReader.open(w)) { + log.info("We should now have 1 segment with 10 documents"); + Assert.assertEquals(1, reader.getContext().leaves().size()); + Assert.assertEquals(totalNumberOfDocs, reader.numDocs()); + final Query filterQuery = new MatchAllDocsQuery(); + final IndexSearcher searcher = newSearcher(reader); + KnnFloatVectorQuery knnFloatVectorQuery = new KnnFloatVectorQuery("test_field", target, k, filterQuery); + TopDocs topDocs = searcher.search(knnFloatVectorQuery, k); + assertEquals(k, topDocs.totalHits.value); + assertEquals(9, topDocs.scoreDocs[0].doc); + 
Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(target, new float[]{0.0f, 1.0f / 10.0f}), topDocs.scoreDocs[0].score, 0.01f); + assertEquals(8, topDocs.scoreDocs[1].doc); + Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(target, new float[]{0.0f, 1.0f / 9.0f}), topDocs.scoreDocs[1].score, 0.01f); + assertEquals(7, topDocs.scoreDocs[2].doc); + Assert.assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(target, new float[]{0.0f, 1.0f / 8.0f}), topDocs.scoreDocs[2].score, 0.01f); + log.info("successfully completed search tests"); + } + } + } + + /** + * Test to verify that the Lucene codec is able to successfully search for the nearest neighbours + * in the index. + * Single field is used to store the vectors. + * Documents are stored in potentially multiple segments. + * Multiple commits. + * Multiple merges. + * Merge is enabled. + * compound file is enabled. + * cosine similarity is used. + */ + @Test + public void testLuceneKnnIndex_mergeEnabled_withCompoundFile_cosine() throws IOException { + int k = 3; // The number of nearest neighbours to gather + int totalNumberOfDocs = 10; + IndexWriterConfig indexWriterConfig = LuceneTestCase.newIndexWriterConfig(); + // TODO: re-enable this after fixing the compound file augmentation for JVector + indexWriterConfig.setUseCompoundFile(true); + indexWriterConfig.setCodec(new JVectorCodec()); + indexWriterConfig.setMergePolicy(new ForceMergesOnlyMergePolicy(true)); + indexWriterConfig.setMergeScheduler(new SerialMergeScheduler()); + final Path indexPath = createTempDir(); + log.info("Index path: {}", indexPath); + try (FSDirectory dir = new NIOFSDirectory(indexPath, FSLockFactory.getDefault()); + IndexWriter w = new IndexWriter(dir, indexWriterConfig)) { + final float[] target = new float[]{1.0f, 1.0f}; + for (int i = 1; i < totalNumberOfDocs + 1; i++) { + final float[] source = new float[]{1.0f + i, 2.0f * i}; + final Document doc = new Document(); + doc.add(new 
KnnFloatVectorField("test_field", source, VectorSimilarityFunction.COSINE)); + w.addDocument(doc); + w.flush(); // this creates a new segment without triggering a merge + } + log.info("Done writing all files to the file system"); + + w.forceMerge(1); // this merges all segments into a single segment + log.info("Done merging all segments"); + try (IndexReader reader = DirectoryReader.open(w)) { + log.info("We should now have 1 segment with 10 documents"); + Assert.assertEquals(1, reader.getContext().leaves().size()); + Assert.assertEquals(totalNumberOfDocs, reader.numDocs()); + final Query filterQuery = new MatchAllDocsQuery(); + final IndexSearcher searcher = newSearcher(reader); + KnnFloatVectorQuery knnFloatVectorQuery = new KnnFloatVectorQuery("test_field", target, k, filterQuery); + TopDocs topDocs = searcher.search(knnFloatVectorQuery, k); + assertEquals(k, topDocs.totalHits.value); + assertEquals(0, topDocs.scoreDocs[0].doc); + Assert.assertEquals(VectorSimilarityFunction.COSINE.compare(target, new float[]{2.0f, 2.0f}), topDocs.scoreDocs[0].score, 0.001f); + assertEquals(1, topDocs.scoreDocs[1].doc); + Assert.assertEquals(VectorSimilarityFunction.COSINE.compare(target, new float[]{3.0f, 4.0f}), topDocs.scoreDocs[1].score, 0.001f); + assertEquals(2, topDocs.scoreDocs[2].doc); + Assert.assertEquals(VectorSimilarityFunction.COSINE.compare(target, new float[]{4.0f, 6.0f}), topDocs.scoreDocs[2].score, 0.001f); + log.info("successfully completed search tests"); + } + } + } + + /** + * Test to verify that the JVector codec is providing proper error if used with byte vector + * TODO: Create Product Quantization support for JVector codec + */ + @Test + public void testJVectorKnnIndex_simpleCase_withBinaryVector() throws IOException { + int k = 3; // The number of nearest neighbours to gather + int totalNumberOfDocs = 10; + IndexWriterConfig indexWriterConfig = LuceneTestCase.newIndexWriterConfig(); + // TODO: re-enable this after fixing the compound file augmentation 
for JVector + indexWriterConfig.setUseCompoundFile(false); + indexWriterConfig.setCodec(new JVectorCodec()); + indexWriterConfig.setMergePolicy(new ForceMergesOnlyMergePolicy()); + final Path indexPath = createTempDir(); + log.info("Index path: {}", indexPath); + try (Directory dir = newFSDirectory(indexPath); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, indexWriterConfig)) { + final byte[] source = new byte[] { (byte) 0, (byte)0 }; + final Document doc = new Document(); + doc.add(new KnnByteVectorField("test_field", source, VectorSimilarityFunction.EUCLIDEAN)); + Assert.assertThrows(UnsupportedOperationException.class, () -> w.addDocument(doc)); + } + } + +} diff --git a/src/test/resources/log4j2.properties b/src/test/resources/log4j2.properties new file mode 100644 index 0000000000..a06597803a --- /dev/null +++ b/src/test/resources/log4j2.properties @@ -0,0 +1,22 @@ +# +# SPDX-License-Identifier: Apache-2.0 +# +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# +# Modifications Copyright OpenSearch Contributors. See +# GitHub history for details. 
+# + +appender.console.type = Console +appender.console.name = console +appender.console.layout.type = PatternLayout +#appender.console.layout.pattern = [%d{ISO8601}][%-5p][%-25c] %marker %m%n +#appender.console.layout.pattern = [%d{ISO8601}][%-5p][%-25c] [%t] [%node_name]%marker %m%n +appender.console.layout.pattern = [%d{yyyy-MM-dd HH:mm:ss.SSS}][%-5level][%15.15t][%-25c] %X{correlationId} %msg%n + + +# Do not log at all if it is not really critical - we're in a benchmark +rootLogger.level = debug +rootLogger.appenderRef.console.ref = console diff --git a/src/test/resources/log4j2.xml b/src/test/resources/log4j2.xml new file mode 100644 index 0000000000..28ba86bb3b --- /dev/null +++ b/src/test/resources/log4j2.xml @@ -0,0 +1,22 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- + ~ Copyright OpenSearch Contributors + ~ SPDX-License-Identifier: Apache-2.0 + --> +<Configuration status="WARN"> + <Appenders> + <Console name="Console" target="SYSTEM_OUT"> + <PatternLayout> + <Pattern>%d{yyyy-MM-dd HH:mm:ss.SSS} [%t] %-5level %c{1.} - %msg%n%throwable</Pattern> + </PatternLayout> + </Console> + </Appenders> + <Loggers> + <Root level="debug"> + <AppenderRef ref="Console"/> + </Root> + </Loggers> +</Configuration> \ No newline at end of file