diff --git a/chromadb/test/ef/test_voyageai.py b/chromadb/test/ef/test_voyageai.py index ced093192d9..5baaf70fe26 100644 --- a/chromadb/test/ef/test_voyageai.py +++ b/chromadb/test/ef/test_voyageai.py @@ -2,7 +2,11 @@ import pytest -from chromadb.utils.embedding_functions import VoyageAIEmbeddingFunction +from chromadb.utils.embedding_functions.voyage_ai_embedding_function import ( + VoyageAIEmbeddingFunction, +) + +voyageai = pytest.importorskip("voyageai", reason="voyageai not installed") @pytest.fixture(scope="function") @@ -18,7 +22,10 @@ def remove_api_key(): os.environ["VOYAGE_API_KEY"] = existing_api_key -@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.") +@pytest.mark.skipif( + "VOYAGE_API_KEY" not in os.environ, + reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.", +) def test_voyage() -> None: ef = VoyageAIEmbeddingFunction(api_key=os.environ.get("VOYAGE_API_KEY", "")) embeddings = ef(["test doc"]) @@ -27,7 +34,10 @@ def test_voyage() -> None: assert len(embeddings[0]) > 0 -@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.") +@pytest.mark.skipif( + "VOYAGE_API_KEY" not in os.environ, + reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.", +) def test_voyage_input_type_query() -> None: ef = VoyageAIEmbeddingFunction( api_key=os.environ.get("VOYAGE_API_KEY", ""), @@ -39,7 +49,10 @@ def test_voyage_input_type_query() -> None: assert len(embeddings[0]) > 0 -@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.") +@pytest.mark.skipif( + "VOYAGE_API_KEY" not in os.environ, + reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.", +) def test_voyage_input_type_document() -> None: ef = VoyageAIEmbeddingFunction( api_key=os.environ.get("VOYAGE_API_KEY", ""), @@ -51,7 +64,10 @@ def test_voyage_input_type_document() -> None: assert len(embeddings[0]) > 0 -@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.") +@pytest.mark.skipif( + "VOYAGE_API_KEY" not in os.environ, + reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.", +) def test_voyage_model() -> None: ef = VoyageAIEmbeddingFunction( api_key=os.environ.get("VOYAGE_API_KEY", ""), model_name="voyage-01" @@ -62,7 +78,10 @@ def test_voyage_model() -> None: assert len(embeddings[0]) > 0 -@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.") +@pytest.mark.skipif( + "VOYAGE_API_KEY" not in os.environ, + reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.", +) def test_voyage_truncation_default() -> None: ef = VoyageAIEmbeddingFunction(api_key=os.environ.get("VOYAGE_API_KEY", "")) embeddings = ef(["this is a test-message" * 10000]) @@ -71,7 +90,10 @@ def test_voyage_truncation_default() -> None: assert len(embeddings[0]) > 0 -@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.") +@pytest.mark.skipif( + "VOYAGE_API_KEY" not in os.environ, + reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.", +) def test_voyage_truncation_enabled() -> None: ef = VoyageAIEmbeddingFunction( api_key=os.environ.get("VOYAGE_API_KEY", ""), truncation=True @@ -82,7 +104,10 @@ def test_voyage_truncation_enabled() -> None: assert len(embeddings[0]) > 0 -@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.") +@pytest.mark.skipif( + "VOYAGE_API_KEY" not in os.environ, + reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.", +) def test_voyage_truncation_disabled() -> None: ef = VoyageAIEmbeddingFunction( api_key=os.environ.get("VOYAGE_API_KEY", ""), truncation=False @@ -91,30 +116,45 @@ def test_voyage_truncation_disabled() -> None: ef(["this is a test-message" * 10000]) -@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.") +@pytest.mark.skipif( + "VOYAGE_API_KEY" not in os.environ, + reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.", +) def test_voyage_env_api_key() -> None: VoyageAIEmbeddingFunction() -@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.") +@pytest.mark.skipif( + "VOYAGE_API_KEY" not in os.environ, + reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.", +) def test_voyage_no_api_key(remove_api_key) -> None: with pytest.raises(ValueError, match="Please provide a VoyageAI API key"): VoyageAIEmbeddingFunction(api_key=None) # type: ignore -@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.") +@pytest.mark.skipif( + "VOYAGE_API_KEY" not in os.environ, + reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.", +) def test_voyage_no_api_key_in_env(remove_api_key) -> None: with pytest.raises(ValueError, match="Please provide a VoyageAI API key"): VoyageAIEmbeddingFunction(api_key=None) # type: ignore -@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.") +@pytest.mark.skipif( + "VOYAGE_API_KEY" not in os.environ, + reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.", +) def test_voyage_max_batch_size_exceeded_in_init() -> None: with pytest.raises(ValueError, match="The maximum batch size supported is"): VoyageAIEmbeddingFunction(api_key="dummy", max_batch_size=99999999) -@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.") +@pytest.mark.skipif( + "VOYAGE_API_KEY" not in os.environ, + reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.", +) def test_voyage_max_batch_size_exceeded_in_call() -> None: ef = VoyageAIEmbeddingFunction(api_key="dummy", max_batch_size=1) with pytest.raises(ValueError, match="The maximum batch size supported is"): diff --git a/chromadb/utils/embedding_functions/voyage_ai_embedding_function.py b/chromadb/utils/embedding_functions/voyage_ai_embedding_function.py new file mode 100644 index 00000000000..cbe1534bc70 --- /dev/null +++ b/chromadb/utils/embedding_functions/voyage_ai_embedding_function.py @@ -0,0 +1,77 @@ +import os +from enum import Enum +from typing import Optional, cast + +from chromadb.api.types import ( + Documents, + EmbeddingFunction, + Embeddings, +) + + +class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]): + """Embedding function for Voyageai.com. API docs - https://docs.voyageai.com/reference/embeddings-api""" + + class InputType(str, Enum): + DOCUMENT = "document" + QUERY = "query" + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "voyage-2", + max_batch_size: int = 128, + truncation: Optional[bool] = True, + input_type: Optional[InputType] = None, + ): + """ + Initialize the VoyageAIEmbeddingFunction. + Args: + api_key (str): Your API key for the HuggingFace API. + model_name (str, optional): The name of the model to use for text embeddings. Defaults to "voyage-01". + batch_size (int, optional): The number of documents to send at a time. Defaults to 128 (The max supported 7th Apr 2024). see voyageai.VOYAGE_EMBED_BATCH_SIZE for actual max. + truncation (bool, optional): Whether to truncate the input (`True`) or raise an error if the input is too long (`False`). Defaults to `False`. + input_type (str, optional): The type of input text. Can be `None`, `query`, `document`. Defaults to `None`. + """ + + if not api_key and "VOYAGE_API_KEY" not in os.environ: + raise ValueError("Please provide a VoyageAI API key.") + + try: + import voyageai + + if max_batch_size > voyageai.VOYAGE_EMBED_BATCH_SIZE: + raise ValueError( + f"The maximum batch size supported is {voyageai.VOYAGE_EMBED_BATCH_SIZE}." + ) + self._batch_size = max_batch_size + self._model = model_name + self._truncation = truncation + self._client = voyageai.Client(api_key=api_key) + self._input_type = input_type + except ImportError: + raise ValueError( + "The VoyageAI python package is not installed. Please install it with `pip install voyageai`" + ) + + def __call__(self, input: Documents) -> Embeddings: + """ + Get the embeddings for a list of texts. + Args: + input (Documents): A list of texts to get embeddings for. + Returns: + Embeddings: The embeddings for the texts. + Example: + >>> voyage_ef = VoyageAIEmbeddingFunction(api_key="your_api_key") + >>> input = ["Hello, world!", "How are you?"] + >>> embeddings = voyage_ef(input) + """ + if len(input) > self._batch_size: + raise ValueError(f"The maximum batch size supported is {self._batch_size}.") + results = self._client.embed( + texts=input, + model=self._model, + truncation=self._truncation, + input_type=self._input_type, + ) + return cast(Embeddings, results.embeddings)