Skip to content

Commit

Permalink
feat: Rebase + minor test improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Jun 21, 2024
1 parent b3a4a70 commit 104ef71
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 13 deletions.
66 changes: 53 additions & 13 deletions chromadb/test/ef/test_voyageai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

import pytest

from chromadb.utils.embedding_functions import VoyageAIEmbeddingFunction
from chromadb.utils.embedding_functions.voyage_ai_embedding_function import (
VoyageAIEmbeddingFunction,
)

voyageai = pytest.importorskip("voyageai", reason="voyageai not installed")


@pytest.fixture(scope="function")
Expand All @@ -18,7 +22,10 @@ def remove_api_key():
os.environ["VOYAGE_API_KEY"] = existing_api_key


@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.")
@pytest.mark.skipif(
"VOYAGE_API_KEY" not in os.environ,
reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage() -> None:
ef = VoyageAIEmbeddingFunction(api_key=os.environ.get("VOYAGE_API_KEY", ""))
embeddings = ef(["test doc"])
Expand All @@ -27,7 +34,10 @@ def test_voyage() -> None:
assert len(embeddings[0]) > 0


@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.")
@pytest.mark.skipif(
"VOYAGE_API_KEY" not in os.environ,
reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_input_type_query() -> None:
ef = VoyageAIEmbeddingFunction(
api_key=os.environ.get("VOYAGE_API_KEY", ""),
Expand All @@ -39,7 +49,10 @@ def test_voyage_input_type_query() -> None:
assert len(embeddings[0]) > 0


@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.")
@pytest.mark.skipif(
"VOYAGE_API_KEY" not in os.environ,
reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_input_type_document() -> None:
ef = VoyageAIEmbeddingFunction(
api_key=os.environ.get("VOYAGE_API_KEY", ""),
Expand All @@ -51,7 +64,10 @@ def test_voyage_input_type_document() -> None:
assert len(embeddings[0]) > 0


@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.")
@pytest.mark.skipif(
"VOYAGE_API_KEY" not in os.environ,
reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_model() -> None:
ef = VoyageAIEmbeddingFunction(
api_key=os.environ.get("VOYAGE_API_KEY", ""), model_name="voyage-01"
Expand All @@ -62,7 +78,10 @@ def test_voyage_model() -> None:
assert len(embeddings[0]) > 0


@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.")
@pytest.mark.skipif(
"VOYAGE_API_KEY" not in os.environ,
reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_truncation_default() -> None:
ef = VoyageAIEmbeddingFunction(api_key=os.environ.get("VOYAGE_API_KEY", ""))
embeddings = ef(["this is a test-message" * 10000])
Expand All @@ -71,7 +90,10 @@ def test_voyage_truncation_default() -> None:
assert len(embeddings[0]) > 0


@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.")
@pytest.mark.skipif(
"VOYAGE_API_KEY" not in os.environ,
reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_truncation_enabled() -> None:
ef = VoyageAIEmbeddingFunction(
api_key=os.environ.get("VOYAGE_API_KEY", ""), truncation=True
Expand All @@ -82,7 +104,10 @@ def test_voyage_truncation_enabled() -> None:
assert len(embeddings[0]) > 0


@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.")
@pytest.mark.skipif(
"VOYAGE_API_KEY" not in os.environ,
reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_truncation_disabled() -> None:
ef = VoyageAIEmbeddingFunction(
api_key=os.environ.get("VOYAGE_API_KEY", ""), truncation=False
Expand All @@ -91,30 +116,45 @@ def test_voyage_truncation_disabled() -> None:
ef(["this is a test-message" * 10000])


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_env_api_key() -> None:
    """Construct the embedding function with no explicit key.

    The test is skipped unless VOYAGE_API_KEY is present in the environment,
    so the constructor is expected to pick the key up from there and not raise.
    """
    VoyageAIEmbeddingFunction()


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_no_api_key(remove_api_key) -> None:
    """Expect a ValueError when no API key is passed to the constructor.

    Uses the ``remove_api_key`` fixture (presumably clears VOYAGE_API_KEY for
    the duration of the test and restores it afterwards -- confirm against the
    fixture definition).
    """
    with pytest.raises(ValueError, match="Please provide a VoyageAI API key"):
        VoyageAIEmbeddingFunction(api_key=None)  # type: ignore


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_no_api_key_in_env(remove_api_key) -> None:
    """Expect a ValueError when the key is neither passed nor available.

    NOTE(review): the body is identical to ``test_voyage_no_api_key``; one of
    the two is probably redundant or was meant to exercise a different path.
    """
    with pytest.raises(ValueError, match="Please provide a VoyageAI API key"):
        VoyageAIEmbeddingFunction(api_key=None)  # type: ignore


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_max_batch_size_exceeded_in_init() -> None:
    """Expect a ValueError when ``max_batch_size`` is far above the limit.

    NOTE(review): a dummy key is used here, so this test may not need a real
    VOYAGE_API_KEY at all -- confirm whether the skip guard is required.
    """
    with pytest.raises(ValueError, match="The maximum batch size supported is"):
        VoyageAIEmbeddingFunction(api_key="dummy", max_batch_size=99999999)


@pytest.mark.skipif("VOYAGE_API_KEY" not in os.environ, reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.")
@pytest.mark.skipif(
"VOYAGE_API_KEY" not in os.environ,
reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_max_batch_size_exceeded_in_call() -> None:
ef = VoyageAIEmbeddingFunction(api_key="dummy", max_batch_size=1)
with pytest.raises(ValueError, match="The maximum batch size supported is"):
Expand Down
77 changes: 77 additions & 0 deletions chromadb/utils/embedding_functions/voyage_ai_embedding_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import os
from enum import Enum
from typing import Optional, cast

from chromadb.api.types import (
Documents,
EmbeddingFunction,
Embeddings,
)


class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embedding function backed by the VoyageAI embeddings API.

    API docs - https://docs.voyageai.com/reference/embeddings-api
    """

    class InputType(str, Enum):
        """String values accepted for the ``input_type`` argument."""

        DOCUMENT = "document"
        QUERY = "query"

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "voyage-2",
        max_batch_size: int = 128,
        truncation: Optional[bool] = True,
        input_type: Optional[InputType] = None,
    ):
        """
        Initialize the VoyageAIEmbeddingFunction.

        Args:
            api_key (str, optional): Your API key for the VoyageAI API. When
                omitted (or empty), the `VOYAGE_API_KEY` environment variable
                is used instead.
            model_name (str, optional): The name of the model to use for text
                embeddings. Defaults to "voyage-2".
            max_batch_size (int, optional): The maximum number of documents
                accepted in a single call. Defaults to 128 (the max supported
                as of 7th Apr 2024); see voyageai.VOYAGE_EMBED_BATCH_SIZE for
                the actual max.
            truncation (bool, optional): Whether to truncate the input (`True`)
                or raise an error if the input is too long (`False`).
                Defaults to `True`.
            input_type (str, optional): The type of input text. Can be `None`,
                `query`, `document`. Defaults to `None`.

        Raises:
            ValueError: If no API key is available, if the `voyageai` package
                is not installed, or if `max_batch_size` exceeds
                `voyageai.VOYAGE_EMBED_BATCH_SIZE`.
        """
        # Resolve the key once so that an empty argument or an empty
        # environment value is rejected here with a clear message.
        resolved_api_key = api_key or os.environ.get("VOYAGE_API_KEY")
        if not resolved_api_key:
            raise ValueError("Please provide a VoyageAI API key.")

        # Only the import itself is guarded; a ValueError is kept as the
        # exception type so existing callers are unaffected.
        try:
            import voyageai
        except ImportError as exc:
            raise ValueError(
                "The VoyageAI python package is not installed. Please install it with `pip install voyageai`"
            ) from exc

        if max_batch_size > voyageai.VOYAGE_EMBED_BATCH_SIZE:
            raise ValueError(
                f"The maximum batch size supported is {voyageai.VOYAGE_EMBED_BATCH_SIZE}."
            )

        self._batch_size = max_batch_size
        self._model = model_name
        self._truncation = truncation
        self._client = voyageai.Client(api_key=resolved_api_key)
        self._input_type = input_type

    def __call__(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Documents): A list of texts to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts.

        Raises:
            ValueError: If `input` holds more items than the configured
                maximum batch size.

        Example:
            >>> voyage_ef = VoyageAIEmbeddingFunction(api_key="your_api_key")
            >>> input = ["Hello, world!", "How are you?"]
            >>> embeddings = voyage_ef(input)
        """
        # The whole input is sent in one request, so it must fit in one batch.
        if len(input) > self._batch_size:
            raise ValueError(f"The maximum batch size supported is {self._batch_size}.")
        results = self._client.embed(
            texts=input,
            model=self._model,
            truncation=self._truncation,
            input_type=self._input_type,
        )
        return cast(Embeddings, results.embeddings)

0 comments on commit 104ef71

Please sign in to comment.