Skip to content

Commit

Permalink
SNOW-1507212 create an interface for Snowflake restful class(es)
Browse files Browse the repository at this point in the history
Description
- create an interface SnowflakeRestfulInterface
  - both the client side restful class and the server side restful class shall conform to this interface
- update the type annotation to use SnowflakeRestfulInterface instead of the concrete class SnowflakeRestful
Testing
  • Loading branch information
sfc-gh-zyao committed Jun 27, 2024
1 parent 42fa6eb commit f7fe991
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 9 deletions.
3 changes: 2 additions & 1 deletion src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
ReauthenticationRequest,
SnowflakeRestful,
)
from .snowflake_restful_interface import SnowflakeRestfulInterface
from .sqlstate import SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_FEATURE_NOT_SUPPORTED
from .telemetry import TelemetryClient, TelemetryData, TelemetryField
from .telemetry_oob import TelemetryService
Expand Down Expand Up @@ -584,7 +585,7 @@ def client_prefetch_threads(self, value) -> None:
self._validate_client_prefetch_threads()

@property
def rest(self) -> SnowflakeRestful | None:
def rest(self) -> SnowflakeRestfulInterface | None:
return self._rest

@property
Expand Down
7 changes: 4 additions & 3 deletions src/snowflake/connector/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
ServiceUnavailableError,
TooManyRequests,
)
from .snowflake_restful_interface import SnowflakeRestfulInterface
from .sqlstate import (
SQLSTATE_CONNECTION_NOT_EXISTS,
SQLSTATE_CONNECTION_REJECTED,
Expand Down Expand Up @@ -319,11 +320,11 @@ def __call__(self, r: PreparedRequest) -> PreparedRequest:


class SessionPool:
def __init__(self, rest: SnowflakeRestful) -> None:
def __init__(self, rest: SnowflakeRestfulInterface) -> None:
# A stack of the idle sessions
self._idle_sessions: list[Session] = []
self._active_sessions: set[Session] = set()
self._rest: SnowflakeRestful = rest
self._rest: SnowflakeRestfulInterface = rest

def get_session(self) -> Session:
"""Returns a session from the session pool or creates a new one."""
Expand Down Expand Up @@ -361,7 +362,7 @@ def close(self) -> None:
self._idle_sessions.clear()


class SnowflakeRestful:
class SnowflakeRestful(SnowflakeRestfulInterface):
"""Snowflake Restful class."""

def __init__(
Expand Down
142 changes: 142 additions & 0 deletions src/snowflake/connector/snowflake_restful_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#!/usr/bin/env python
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from .connection import SnowflakeConnection
from .vendored.requests import Session


class SnowflakeRestfulInterface(ABC):
"""Snowflake Restful Interface
Defines all the interfaces that we expose in the Snowflake restful classes. Both the client side restful class and
the server side one shall conform to this interface. And whenever we introduce a new public method, it should be
defined in this interface, and implemented in both restful classes.
"""

@property
@abstractmethod
def token(self) -> str | None:
pass

@property
@abstractmethod
def master_token(self) -> str | None:
pass

@property
@abstractmethod
def master_validity_in_seconds(self) -> int:
pass

@master_validity_in_seconds.setter
@abstractmethod
def master_validity_in_seconds(self, value) -> None:
pass

@property
@abstractmethod
def id_token(self):
pass

@id_token.setter
@abstractmethod
def id_token(self, value) -> None:
pass

@property
@abstractmethod
def mfa_token(self) -> str | None:
pass

@mfa_token.setter
@abstractmethod
def mfa_token(self, value: str) -> None:
pass

@property
@abstractmethod
def server_url(self) -> str:
pass

@abstractmethod
def close(self) -> None:
pass

@abstractmethod
def request(
self,
url,
body=None,
method: str = "post",
client: str = "sfsql",
timeout: int | None = None,
_no_results: bool = False,
_include_retry_params: bool = False,
_no_retry: bool = False,
):
pass

@abstractmethod
def update_tokens(
self,
session_token,
master_token,
master_validity_in_seconds=None,
id_token=None,
mfa_token=None,
) -> None:
"""Updates session and master tokens and optionally temporary credential."""
pass

@abstractmethod
def delete_session(self, retry: bool = False) -> None:
"""Deletes the session."""
pass

@abstractmethod
def fetch(
self,
method: str,
full_url: str,
headers: dict[str, Any],
data: dict[str, Any] | None = None,
timeout: int | None = None,
**kwargs,
) -> dict[Any, Any]:
"""Carry out API request with session management."""
pass

@staticmethod
@abstractmethod
def add_request_guid(full_url: str) -> str:
"""Adds request_guid parameter for HTTP request tracing."""
pass

@abstractmethod
def log_and_handle_http_error_with_cause(
self,
e: Exception,
full_url: str,
method: str,
retry_timeout: int,
retry_count: int,
conn: SnowflakeConnection,
timed_out: bool = True,
) -> None:
pass

@abstractmethod
def handle_invalid_certificate_error(self, conn, full_url, cause) -> None:
pass

@abstractmethod
def make_requests_session(self) -> Session:
pass
6 changes: 3 additions & 3 deletions src/snowflake/connector/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

if TYPE_CHECKING:
from .connection import SnowflakeConnection
from .network import SnowflakeRestful
from .snowflake_restful_interface import SnowflakeRestfulInterface

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -131,8 +131,8 @@ class TelemetryClient:
SF_PATH_TELEMETRY = "/telemetry/send"
DEFAULT_FORCE_FLUSH_SIZE = 100

def __init__(self, rest: SnowflakeRestful, flush_size=None) -> None:
self._rest: SnowflakeRestful | None = rest
def __init__(self, rest: SnowflakeRestfulInterface, flush_size=None) -> None:
self._rest: SnowflakeRestfulInterface | None = rest
self._log_batch = []
self._flush_size = flush_size or TelemetryClient.DEFAULT_FORCE_FLUSH_SIZE
self._lock = Lock()
Expand Down
5 changes: 3 additions & 2 deletions test/unit/test_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from unittest import mock

from snowflake.connector.network import SnowflakeRestful
from snowflake.connector.snowflake_restful_interface import SnowflakeRestfulInterface

try:
from snowflake.connector.ssl_wrap_socket import DEFAULT_OCSP_MODE
Expand All @@ -32,15 +33,15 @@ class OCSPMode(Enum):
mock_conn._ocsp_mode = lambda: DEFAULT_OCSP_MODE


def close_sessions(rest: SnowflakeRestful, num_session_pools: int) -> None:
def close_sessions(rest: SnowflakeRestfulInterface, num_session_pools: int) -> None:
"""Helper function to call SnowflakeRestful.close(). Asserts close was called on all SessionPools."""
with mock.patch("snowflake.connector.network.SessionPool.close") as close_mock:
rest.close()
assert close_mock.call_count == num_session_pools


def create_session(
rest: SnowflakeRestful, num_sessions: int = 1, url: str | None = None
rest: SnowflakeRestfulInterface, num_sessions: int = 1, url: str | None = None
) -> None:
"""
Creates 'num_sessions' sessions to 'url'. This is recursive so that idle sessions
Expand Down

0 comments on commit f7fe991

Please sign in to comment.