Skip to content

Commit

Permalink
chore: input_type is now an enum
Browse files Browse the repository at this point in the history
- Updated API key to be VOYAGE_API_KEY
- Test cleanup
  • Loading branch information
tazarov committed Apr 8, 2024
1 parent c0cbbed commit 2ffd9e3
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 35 deletions.
72 changes: 44 additions & 28 deletions chromadb/test/ef/test_voyageai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,101 +5,117 @@
from chromadb.utils.embedding_functions import VoyageAIEmbeddingFunction


@pytest.fixture(scope="function")
def remove_api_key():
    """Temporarily remove ``VOYAGE_API_KEY`` from the environment for one test.

    The variable is taken out of ``os.environ`` before the test body runs
    and put back afterwards, so tests that need the key to be absent do not
    leak that state into later tests.
    """
    # pop() handles both the "set" and "not set" cases with a single lookup.
    saved_api_key = os.environ.pop("VOYAGE_API_KEY", None)
    yield
    # Compare against None (not truthiness) so an empty-string value is
    # restored as well.
    if saved_api_key is not None:
        os.environ["VOYAGE_API_KEY"] = saved_api_key


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage() -> None:
    """Embed one document with default settings and check the result shape."""
    # The skipif marker above is the only gate; the key is read from the
    # same VOYAGE_API_KEY variable the marker checks.
    ef = VoyageAIEmbeddingFunction(api_key=os.environ.get("VOYAGE_API_KEY", ""))
    embeddings = ef(["test doc"])
    assert embeddings is not None
    assert len(embeddings) == 1
    assert len(embeddings[0]) > 0


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_input_type_query() -> None:
    """Embed one document with ``input_type`` set to the QUERY enum member."""
    ef = VoyageAIEmbeddingFunction(
        api_key=os.environ.get("VOYAGE_API_KEY", ""),
        input_type=VoyageAIEmbeddingFunction.InputType.QUERY,
    )
    embeddings = ef(["test doc"])
    assert embeddings is not None
    assert len(embeddings) == 1
    assert len(embeddings[0]) > 0


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_input_type_document() -> None:
    """Embed one document with ``input_type`` set to the DOCUMENT enum member."""
    ef = VoyageAIEmbeddingFunction(
        api_key=os.environ.get("VOYAGE_API_KEY", ""),
        input_type=VoyageAIEmbeddingFunction.InputType.DOCUMENT,
    )
    embeddings = ef(["test doc"])
    assert embeddings is not None
    assert len(embeddings) == 1
    assert len(embeddings[0]) > 0


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_model() -> None:
    """Embed one snippet using an explicitly selected model name."""
    ef = VoyageAIEmbeddingFunction(
        api_key=os.environ.get("VOYAGE_API_KEY", ""), model_name="voyage-01"
    )
    embeddings = ef(["def test():\n return 1"])
    assert embeddings is not None
    assert len(embeddings) == 1
    assert len(embeddings[0]) > 0


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_truncation_default() -> None:
    """Embed a very long document without passing ``truncation`` explicitly."""
    ef = VoyageAIEmbeddingFunction(api_key=os.environ.get("VOYAGE_API_KEY", ""))
    # A single oversized input: the short message repeated 10000 times.
    embeddings = ef(["this is a test-message" * 10000])
    assert embeddings is not None
    assert len(embeddings) == 1
    assert len(embeddings[0]) > 0


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_truncation_enabled() -> None:
    """Embed a very long document with ``truncation=True`` and expect success."""
    ef = VoyageAIEmbeddingFunction(
        api_key=os.environ.get("VOYAGE_API_KEY", ""), truncation=True
    )
    # A single oversized input: the short message repeated 10000 times.
    embeddings = ef(["this is a test-message" * 10000])
    assert embeddings is not None
    assert len(embeddings) == 1
    assert len(embeddings[0]) > 0


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_truncation_disabled() -> None:
    """Expect an error when an oversized document is sent with ``truncation=False``."""
    ef = VoyageAIEmbeddingFunction(
        api_key=os.environ.get("VOYAGE_API_KEY", ""), truncation=False
    )
    with pytest.raises(Exception, match="your batch has too many tokens"):
        ef(["this is a test-message" * 10000])


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_env_api_key() -> None:
    """Construct the embedding function without passing ``api_key``.

    The skipif marker guarantees VOYAGE_API_KEY is present in the
    environment, so construction is expected to succeed without raising.
    """
    VoyageAIEmbeddingFunction()


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_no_api_key(remove_api_key) -> None:
    """Expect a ValueError when ``api_key=None`` and the env variable is removed.

    The ``remove_api_key`` fixture takes VOYAGE_API_KEY out of the
    environment for the duration of this test.
    """
    expected_message = "Please provide a VoyageAI API key"
    with pytest.raises(ValueError, match=expected_message):
        VoyageAIEmbeddingFunction(api_key=None)  # type: ignore


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_no_api_key_in_env(remove_api_key) -> None:
    """Expect a ValueError when no key is passed and none is in the environment.

    The ``remove_api_key`` fixture takes VOYAGE_API_KEY out of the
    environment for the duration of this test. The constructor is called
    with no arguments so this test exercises the default-argument path
    instead of repeating the explicit ``api_key=None`` case.
    """
    with pytest.raises(ValueError, match="Please provide a VoyageAI API key"):
        VoyageAIEmbeddingFunction()


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_max_batch_size_exceeded_in_init() -> None:
    """Expect a ValueError when ``max_batch_size`` is far above the supported limit."""
    with pytest.raises(ValueError, match="The maximum batch size supported is"):
        VoyageAIEmbeddingFunction(api_key="dummy", max_batch_size=99999999)


@pytest.mark.skipif(
    "VOYAGE_API_KEY" not in os.environ,
    reason="VOYAGE_API_KEY not set, not going to test VoyageAI EF.",
)
def test_voyage_max_batch_size_exceeded_in_call() -> None:
    """Expect a ValueError when a call passes more documents than ``max_batch_size``."""
    ef = VoyageAIEmbeddingFunction(api_key="dummy", max_batch_size=1)
    with pytest.raises(ValueError, match="The maximum batch size supported is"):
        # Two documents against a configured batch size of one.
        ef(["test doc"] * 2)
18 changes: 11 additions & 7 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import hashlib
import logging
from enum import Enum
from functools import cached_property

from tenacity import stop_after_attempt, wait_random, retry, retry_if_exception
Expand Down Expand Up @@ -898,16 +899,20 @@ def __call__(self, input: Documents) -> Embeddings:
)


class VoyageAIEmbeddingFunction(EmbeddingFunction):
class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for Voyageai.com. API docs - https://docs.voyageai.com/reference/embeddings-api"""

class InputType(str, Enum):
    """Allowed values for the ``input_type`` argument.

    Members also subclass ``str``, so each one compares equal to its
    plain string value.
    """

    # Input text is a document.
    DOCUMENT = "document"
    # Input text is a query.
    QUERY = "query"

def __init__(
self,
api_key: str,
api_key: Optional[str] = None,
model_name: str = "voyage-2",
max_batch_size: int = 128,
truncation: Optional[bool] = True,
input_type: Optional[str] = None,
input_type: Optional[InputType] = None,
):
"""
Initialize the VoyageAIEmbeddingFunction.
Expand All @@ -919,7 +924,7 @@ def __init__(
input_type (InputType, optional): The type of input text. Can be `None`, `InputType.QUERY` or `InputType.DOCUMENT`. Defaults to `None`.
"""

if not api_key:
if not api_key and "VOYAGE_API_KEY" not in os.environ:
raise ValueError("Please provide a VoyageAI API key.")

try:
Expand All @@ -929,11 +934,10 @@ def __init__(
raise ValueError(
f"The maximum batch size supported is {voyageai.VOYAGE_EMBED_BATCH_SIZE}."
)
voyageai.api_key = api_key # Voyage API Key
self._batch_size = max_batch_size
self._model = model_name
self._truncation = truncation
self._client = voyageai.Client()
self._client = voyageai.Client(api_key=api_key)
self._input_type = input_type
except ImportError:
raise ValueError(
Expand All @@ -960,7 +964,7 @@ def __call__(self, input: Documents) -> Embeddings:
truncation=self._truncation,
input_type=self._input_type,
)
return results.embeddings
return cast(Embeddings, results.embeddings)


def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore
Expand Down

0 comments on commit 2ffd9e3

Please sign in to comment.