chore: add type annotations and fix linting issues #85

Open · wants to merge 1 commit into base: main
llm_gateway/db/utils.py (10 additions, 7 deletions)

@@ -16,11 +16,11 @@
 # limitations under the License.

 from contextlib import contextmanager
-from typing import Iterator
+from typing import Generator

 from sqlalchemy import create_engine
 from sqlalchemy.engine.base import Engine
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import Session, sessionmaker
 from sqlalchemy.orm.decl_api import DeclarativeMeta

 from llm_gateway.constants import get_settings
@@ -33,6 +33,8 @@ class DB:
     Class for managing connections to the logging DB
     """

+    db_url: str
+
     def __init__(self) -> None:
         self.db_url = settings.DATABASE_URL

@@ -47,17 +49,17 @@ def create_db_engine(self) -> Engine:


 @contextmanager
-def db_session_scope() -> Iterator[None]:
+def db_session_scope() -> Generator[Session, None, None]:
     """
     Open a connected DB session

     :raises Exception: Raised if session fails for some reason
     :yield: DB session
-    :rtype: Iterator[None]
+    :rtype: Generator[Session, None, None]
     """
     llm_gateway_db = DB()
-    session = sessionmaker(bind=llm_gateway_db.create_db_engine())
-    session = session()
+    session_factory = sessionmaker(bind=llm_gateway_db.create_db_engine())
+    session: Session = session_factory()
     try:
         yield session
         session.commit()
@@ -76,4 +78,5 @@ def write_record_to_db(db_record: DeclarativeMeta) -> None:
     :type db_record: DeclarativeMeta
     """
     with db_session_scope() as session:
-        session.add(db_record)
+        if session is not None:
+            session.add(db_record)
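
Note: with `db_session_scope` typed as `Generator[Session, None, None]`, the `session` variable inside a `with` block is inferred as `sqlalchemy.orm.Session`. A minimal usage sketch, illustrative only and not part of this diff (`ExampleRecord` is a hypothetical model, and the sketch assumes `DATABASE_URL` points at a database where the table exists):

```python
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

from llm_gateway.db.utils import db_session_scope, write_record_to_db

Base = declarative_base()


class ExampleRecord(Base):  # hypothetical table, for illustration only
    __tablename__ = "example_records"
    id = Column(Integer, primary_key=True)
    message = Column(String)


with db_session_scope() as session:
    # `session` is now typed as Session, so mypy and IDEs can
    # check calls like .add() and .query()
    session.add(ExampleRecord(message="hello"))

# write_record_to_db opens the same scope internally and commits on success
write_record_to_db(ExampleRecord(message="world"))
```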
llm_gateway/providers/awsbedrock.py (11 additions, 11 deletions)

@@ -17,16 +17,16 @@

 import datetime
 import json
-from typing import Iterator, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple

 import boto3
-from fastapi.responses import JSONResponse

 from llm_gateway.constants import get_settings
 from llm_gateway.db.models import AWSBedrockRequests
 from llm_gateway.db.utils import write_record_to_db
 from llm_gateway.exceptions import AWSBEDROCK_EXCEPTIONS
 from llm_gateway.pii_scrubber import scrub_all
+from llm_gateway.types import AWSBedrockResponse, DBRecord
 from llm_gateway.utils import max_retries

 settings = get_settings()
@@ -113,11 +113,11 @@ def _structure_model_body(
         model: str,
         max_tokens: int,
         prompt: Optional[str] = None,
-        embedding_texts: Optional[list] = None,
+        embedding_texts: Optional[List[str]] = None,
         instruction: Optional[str] = None,
         temperature: Optional[float] = 0,
-        **kwargs,
-    ) -> Tuple[dict, str]:
+        **kwargs: Any,
+    ) -> Tuple[Dict[str, Any], str]:
         """
         Structure the body of the AWS Bedrock API request (Model specific)

@@ -219,8 +219,8 @@ def _structure_model_body(
     def _invoke_awsbedrock_model(
         self,
         model: str,
-        body: dict,
-    ) -> JSONResponse:
+        body: Dict[str, Any],
+    ) -> Dict[str, Any]:
         """
         Call the invoke model enpoint from the AWS Bedrock client and return response

@@ -249,9 +249,9 @@ def send_awsbedrock_request(
         prompt: Optional[str] = None,
         temperature: Optional[float] = 0,
         instruction: Optional[str] = None,
-        embedding_texts: Optional[str] = None,
-        **kwargs,
-    ) -> Tuple[Union[dict, Iterator[str]], dict]:
+        embedding_texts: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Tuple[AWSBedrockResponse, DBRecord]:
         """
         Send a request to the AWS Bedrock API and return response and logs for db write

@@ -309,7 +309,7 @@ def send_awsbedrock_request(

         return awsbedrock_response, db_record

-    def write_logs_to_db(self, db_logs: dict):
+    def write_logs_to_db(self, db_logs: Dict[str, Any]) -> None:
         if isinstance(db_logs["awsbedrock_response"], list):
             db_logs["awsbedrock_response"] = "".join(db_logs["awsbedrock_response"])
         write_record_to_db(AWSBedrockRequests(**db_logs))
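
Note: the `embedding_texts` change from `Optional[str]` to `Optional[List[str]]` corrects a wrong annotation rather than merely tightening one. A standalone sketch of the failure mode the old hint masked (the `structure_body` helper is hypothetical, written only to illustrate the point):

```python
from typing import List, Optional


def structure_body(embedding_texts: Optional[List[str]] = None) -> dict:
    texts = embedding_texts or []
    # With Optional[str], passing a bare string type-checked fine and this
    # loop silently iterated over characters instead of texts.
    return {"texts": [t.strip() for t in texts]}


structure_body(["alpha", "beta"])  # OK
# structure_body("alpha")         # mypy now flags: str is not List[str]
```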
llm_gateway/providers/cohere.py (15 additions, 13 deletions)

@@ -17,7 +17,7 @@

 import datetime
 import json
-from typing import Iterator, Optional
+from typing import Any, Dict, Iterator, Optional, Tuple, Union

 import cohere
 from cohere.responses.generation import StreamingText
@@ -26,6 +26,7 @@
 from llm_gateway.db.models import CohereRequests
 from llm_gateway.db.utils import write_record_to_db
 from llm_gateway.pii_scrubber import scrub_all
+from llm_gateway.types import CohereResponse, DBRecord
 from llm_gateway.utils import StreamProcessor

 settings = get_settings()
@@ -62,8 +63,8 @@ def _call_summarize_endpoint(
         additional_command: str,
         model: str,
         temperature: float,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
         """
         Call the summarize endpoint from the Cohere client and return response

@@ -93,8 +94,8 @@ def _call_generate_endpoint(
         max_tokens: int,
         temperature: float,
         stream: bool = False,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> Union[Dict[str, Any], Iterator[StreamingText]]:
         """
         Call the generate endpoint from the Cohere client and return response

@@ -119,14 +120,14 @@
         )
         return resp

-    def _flatten_cohere_response(self, cohere_response):
+    def _flatten_cohere_response(self, cohere_response: Any) -> Dict[str, Any]:
         """
         Flatten response from Cohere as JSON

         :param cohere_response: Raw response from Cohere
-        :type cohere_response: _type_
+        :type cohere_response: Any
         :return: Flattened Cohere response as JSON
-        :rtype: _type_
+        :rtype: Dict[str, Any]
         """
         return json.loads(json.dumps(cohere_response, default=lambda o: o.__dict__))

@@ -139,8 +140,8 @@ def send_cohere_request(
         temperature: Optional[float] = 0,
         stream: bool = False,
         additional_command: Optional[str] = "",
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> Tuple[Union[CohereResponse, Iterator[CohereResponse]], DBRecord]:
         """
         Send a request to the Cohere API and log interaction to the DB

@@ -203,7 +204,8 @@ def write_logs_to_db(self, db_logs: dict):
         write_record_to_db(CohereRequests(**db_logs))


-def stream_generator_cohere(generator: Iterator) -> Iterator[dict]:
-    chunk: StreamingText
+def stream_generator_cohere(
+    generator: Iterator[StreamingText],
+) -> Iterator[Dict[str, str]]:
     for chunk in generator:
-        yield chunk.text
+        yield {"content": chunk.text}
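
Note: `stream_generator_cohere` now yields `{"content": ...}` dicts instead of bare strings, so streaming consumers must key into `content`. A self-contained sketch of the new chunk shape, with `FakeChunk` standing in for cohere's `StreamingText`:

```python
from typing import Dict, Iterator, NamedTuple


class FakeChunk(NamedTuple):
    """Stand-in for cohere.responses.generation.StreamingText."""

    text: str


def stream_generator_cohere(
    generator: Iterator[FakeChunk],
) -> Iterator[Dict[str, str]]:
    for chunk in generator:
        yield {"content": chunk.text}


pieces = stream_generator_cohere(iter([FakeChunk("Hel"), FakeChunk("lo")]))
print("".join(p["content"] for p in pieces))  # prints "Hello"
```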
llm_gateway/providers/openai.py (19 additions, 12 deletions)

@@ -17,7 +17,7 @@

 import datetime
 import json
-from typing import Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

 import openai

@@ -26,6 +26,7 @@
 from llm_gateway.db.utils import write_record_to_db
 from llm_gateway.exceptions import OPENAI_EXCEPTIONS
 from llm_gateway.pii_scrubber import scrub_all
+from llm_gateway.types import DBRecord, OpenAIResponse
 from llm_gateway.utils import StreamProcessor, max_retries

 settings = get_settings()
@@ -70,7 +71,9 @@ def _validate_openai_endpoint(self, module: str, endpoint: str) -> None:
                 f"`{endpoint}` not supported action for `{module}`"
             )

-    def _call_model_endpoint(self, endpoint: str, model: Optional[str] = None):
+    def _call_model_endpoint(
+        self, endpoint: str, model: Optional[str] = None
+    ) -> Dict[str, Any]:
         """
         List or retrieve model(s) from OpenAI

@@ -79,8 +82,8 @@ def _call_model_endpoint(self, endpoint: str, model: Optional[str] = None):
         :param model: Name of model, if "retrieve" is passed, defaults to None
         :type model: Optional[str]
         :raises Exception: Raised if endpoint is "retrieve" and model is unspecified
-        :return: List of models or retrieved model
-        :rtype: _type_
+        :return: OpenAI API response containing model information
+        :rtype: Dict[str, Any]
         """
         if endpoint == "list":
             return openai.Model.list()
@@ -199,11 +202,11 @@ def send_openai_request(
         max_tokens: Optional[int] = None,
         prompt: Optional[str] = None,
         temperature: Optional[float] = 0,
-        messages: Optional[list] = None,  # TODO: add pydantic type for messages
+        messages: Optional[List[Dict[str, str]]] = None,
         instruction: Optional[str] = None,
-        embedding_texts: Optional[list] = None,
-        **kwargs,
-    ) -> Tuple[Union[dict, Iterator[str]], dict]:
+        embedding_texts: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], DBRecord]:
         """
         Send a request to the OpenAI API and return response and logs for db write

@@ -298,23 +301,27 @@ def write_logs_to_db(self, db_logs: dict):
         write_record_to_db(OpenAIRequests(**db_logs))


-def stream_generator_openai_chat(generator: Iterator) -> Iterator[str]:
+def stream_generator_openai_chat(
+    generator: Iterator[OpenAIResponse],
+) -> Iterator[OpenAIResponse]:
     for chunk in generator:
         answer = ""
         try:
             current_response = chunk["choices"][0]["delta"]["content"]
             answer += current_response
         except KeyError:
             pass
-        yield answer
+        yield {"content": answer}


-def stream_generator_openai_completion(generator: Iterator) -> Iterator[str]:
+def stream_generator_openai_completion(
+    generator: Iterator[OpenAIResponse],
+) -> Iterator[OpenAIResponse]:
     for chunk in generator:
         answer = ""
         try:
             current_response = chunk["choices"][0]["text"]
             answer += current_response
         except KeyError:
             pass
-        yield answer
+        yield {"content": answer}
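
Note: both OpenAI stream generators get the same reshaping; every yielded item is now a `{"content": ...}` dict. A runnable sketch with fake chunks that mimic OpenAI's streaming delta format (the inline alias substitutes for the import from `llm_gateway.types`):

```python
from typing import Any, Dict, Iterator

OpenAIResponse = Dict[str, Any]  # mirrors the alias in llm_gateway/types.py


def stream_generator_openai_chat(
    generator: Iterator[OpenAIResponse],
) -> Iterator[OpenAIResponse]:
    for chunk in generator:
        answer = ""
        try:
            answer += chunk["choices"][0]["delta"]["content"]
        except KeyError:
            pass  # role-only and finish chunks carry no content delta
        yield {"content": answer}


chunks = [
    {"choices": [{"delta": {"role": "assistant"}}]},  # no "content" key
    {"choices": [{"delta": {"content": "Hi"}}]},
]
print([c["content"] for c in stream_generator_openai_chat(iter(chunks))])
# ['', 'Hi']
```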
llm_gateway/types.py (new file, 13 additions)

@@ -0,0 +1,13 @@
+from typing import Any, Dict, TypeAlias
+
+# Common response types
+LLMResponse: TypeAlias = Dict[str, Any]
+PromptMetadata: TypeAlias = Dict[str, Any]
+
+# Provider-specific types
+OpenAIResponse: TypeAlias = Dict[str, Any]
+CohereResponse: TypeAlias = Dict[str, Any]
+AWSBedrockResponse: TypeAlias = Dict[str, Any]
+
+# Database record types
+DBRecord: TypeAlias = Dict[str, Any]
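
Note: `typing.TypeAlias` only exists on Python 3.10+, so this new module raises `ImportError` on 3.9 and earlier. If older interpreters need to stay supported, a guarded import keeps the file loadable; a sketch, assuming `typing_extensions` is an acceptable dependency:

```python
import sys
from typing import Any, Dict

if sys.version_info >= (3, 10):
    from typing import TypeAlias
else:  # Python <= 3.9 ships no typing.TypeAlias
    from typing_extensions import TypeAlias

LLMResponse: TypeAlias = Dict[str, Any]
```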
llm_gateway/utils.py (9 additions, 5 deletions)

@@ -17,12 +17,14 @@

 import traceback
 from functools import wraps
-from typing import Any, Callable, Iterator, List
+from typing import Any, Callable, Dict, Iterator, List, TypeVar

 from fastapi import HTTPException

 from llm_gateway.logger import get_logger

+T = TypeVar("T")
+
 logger = get_logger(__name__)


@@ -56,16 +58,18 @@ def newfn(*args, **kwargs):


 class StreamProcessor:
-    def __init__(self, stream_processor: Callable) -> None:
+    def __init__(
+        self, stream_processor: Callable[[Iterator[T]], Iterator[Dict[str, Any]]]
+    ) -> None:
         self.stream_processor = stream_processor
-        self.cached_streamed_response = []
+        self.cached_streamed_response: List[Dict[str, Any]] = []

-    def process_stream(self, response: Iterator) -> Iterator:
+    def process_stream(self, response: Iterator[T]) -> Iterator[Dict[str, Any]]:
         for item in self.stream_processor(response):
             self.cached_streamed_response.append(item)
             yield item

-    def get_cached_streamed_response(self) -> List[str]:
+    def get_cached_streamed_response(self) -> List[Dict[str, Any]]:
         return self.cached_streamed_response
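
Note: with `Callable[[Iterator[T]], Iterator[Dict[str, Any]]]`, the processor function and the cache now agree on the dict chunk shape the providers above yield. A usage sketch (the `to_chunks` processor is illustrative, not from the codebase):

```python
from typing import Dict, Iterator

from llm_gateway.utils import StreamProcessor


def to_chunks(stream: Iterator[str]) -> Iterator[Dict[str, str]]:
    # Illustrative processor: wrap each raw piece in the {"content": ...} shape
    for piece in stream:
        yield {"content": piece}


processor = StreamProcessor(to_chunks)
for item in processor.process_stream(iter(["Hel", "lo"])):
    pass  # stream each item to the client as it arrives

# After the stream is exhausted, the full response is available for DB logging:
print(processor.get_cached_streamed_response())
# [{'content': 'Hel'}, {'content': 'lo'}]
```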