Skip to content

Commit

Permalink
feat: add create method to handle token headers (#630)
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Smith authored Dec 1, 2023
2 parents b9240d8 + 4f47306 commit fd612a0
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 35 deletions.
22 changes: 11 additions & 11 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion supabase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from .__version__ import __version__
from ._sync.auth_client import SyncSupabaseAuthClient as SupabaseAuthClient
from ._sync.client import Client
from ._sync.client import SyncClient as Client
from ._sync.client import SyncStorageClient as SupabaseStorageClient
from ._sync.client import create_client
from .lib.realtime_client import SupabaseRealtimeClient
42 changes: 30 additions & 12 deletions supabase/_async/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from typing import Any, Dict, Union

from gotrue.types import AuthChangeEvent
from gotrue.types import AuthChangeEvent, Session
from httpx import Timeout
from postgrest import (
AsyncFilterRequestBuilder,
Expand All @@ -24,7 +24,7 @@ def __init__(self, message: str):
super().__init__(self.message)


class Client:
class AsyncClient:
"""Supabase client class."""

def __init__(
Expand Down Expand Up @@ -63,6 +63,9 @@ def __init__(

self.supabase_url = supabase_url
self.supabase_key = supabase_key
self._auth_token = {
"Authorization": f"Bearer {supabase_key}",
}
options.headers.update(self._get_auth_headers())
self.options = options
self.rest_url = f"{supabase_url}/rest/v1"
Expand All @@ -88,6 +91,17 @@ def __init__(
self._functions = None
self.auth.on_auth_state_change(self._listen_to_auth_events)

@classmethod
async def create(
cls,
supabase_url: str,
supabase_key: str,
options: ClientOptions = ClientOptions(),
):
client = cls(supabase_url, supabase_key, options)
client._auth_token = await client._get_token_header()
return client

def table(self, table_name: str) -> AsyncRequestBuilder:
"""Perform a table operation.
Expand Down Expand Up @@ -125,20 +139,21 @@ def rpc(self, fn: str, params: Dict[Any, Any]) -> AsyncFilterRequestBuilder:
@property
def postgrest(self):
if self._postgrest is None:
self.options.headers.update(self._get_token_header())
self.options.headers.update(self._auth_token)
self._postgrest = self._init_postgrest_client(
rest_url=self.rest_url,
headers=self.options.headers,
schema=self.options.schema,
timeout=self.options.postgrest_client_timeout,
)

return self._postgrest

@property
def storage(self):
if self._storage is None:
headers = self._get_auth_headers()
headers.update(self._get_token_header())
headers.update(self._auth_token)
self._storage = self._init_storage_client(
storage_url=self.storage_url,
headers=headers,
Expand All @@ -150,7 +165,7 @@ def storage(self):
def functions(self):
if self._functions is None:
headers = self._get_auth_headers()
headers.update(self._get_token_header())
headers.update(self._auth_token)
self._functions = AsyncFunctionsClient(self.functions_url, headers)
return self._functions

Expand Down Expand Up @@ -231,29 +246,30 @@ def _get_auth_headers(self) -> Dict[str, str]:
"Authorization": f"Bearer {self.supabase_key}",
}

def _get_token_header(self):
async def _get_token_header(self):
try:
access_token = self.auth.get_session().access_token
except:
session = await self.auth.get_session()
access_token = session.access_token
except Exception as err:
access_token = self.supabase_key

return {
"Authorization": f"Bearer {access_token}",
}

def _listen_to_auth_events(self, event: AuthChangeEvent, session):
def _listen_to_auth_events(self, event: AuthChangeEvent, session: Session):
if event in ["SIGNED_IN", "TOKEN_REFRESHED", "SIGNED_OUT"]:
# reset postgrest and storage instance on event change
self._postgrest = None
self._storage = None
self._functions = None


def create_client(
async def create_client(
supabase_url: str,
supabase_key: str,
options: ClientOptions = ClientOptions(),
) -> Client:
) -> AsyncClient:
"""Create client function to instantiate supabase client like JS runtime.
Parameters
Expand All @@ -280,4 +296,6 @@ def create_client(
-------
Client
"""
return Client(supabase_url=supabase_url, supabase_key=supabase_key, options=options)
return await AsyncClient.create(
supabase_url=supabase_url, supabase_key=supabase_key, options=options
)
38 changes: 28 additions & 10 deletions supabase/_sync/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from typing import Any, Dict, Union

from gotrue.types import AuthChangeEvent
from gotrue.types import AuthChangeEvent, Session
from httpx import Timeout
from postgrest import SyncFilterRequestBuilder, SyncPostgrestClient, SyncRequestBuilder
from postgrest.constants import DEFAULT_POSTGREST_CLIENT_TIMEOUT
Expand All @@ -20,7 +20,7 @@ def __init__(self, message: str):
super().__init__(self.message)


class Client:
class SyncClient:
"""Supabase client class."""

def __init__(
Expand Down Expand Up @@ -59,6 +59,9 @@ def __init__(

self.supabase_url = supabase_url
self.supabase_key = supabase_key
self._auth_token = {
"Authorization": f"Bearer {supabase_key}",
}
options.headers.update(self._get_auth_headers())
self.options = options
self.rest_url = f"{supabase_url}/rest/v1"
Expand All @@ -84,6 +87,17 @@ def __init__(
self._functions = None
self.auth.on_auth_state_change(self._listen_to_auth_events)

@classmethod
def create(
cls,
supabase_url: str,
supabase_key: str,
options: ClientOptions = ClientOptions(),
):
client = cls(supabase_url, supabase_key, options)
client._auth_token = client._get_token_header()
return client

def table(self, table_name: str) -> SyncRequestBuilder:
"""Perform a table operation.
Expand Down Expand Up @@ -121,20 +135,21 @@ def rpc(self, fn: str, params: Dict[Any, Any]) -> SyncFilterRequestBuilder:
@property
def postgrest(self):
if self._postgrest is None:
self.options.headers.update(self._get_token_header())
self.options.headers.update(self._auth_token)
self._postgrest = self._init_postgrest_client(
rest_url=self.rest_url,
headers=self.options.headers,
schema=self.options.schema,
timeout=self.options.postgrest_client_timeout,
)

return self._postgrest

@property
def storage(self):
if self._storage is None:
headers = self._get_auth_headers()
headers.update(self._get_token_header())
headers.update(self._auth_token)
self._storage = self._init_storage_client(
storage_url=self.storage_url,
headers=headers,
Expand All @@ -146,7 +161,7 @@ def storage(self):
def functions(self):
if self._functions is None:
headers = self._get_auth_headers()
headers.update(self._get_token_header())
headers.update(self._auth_token)
self._functions = SyncFunctionsClient(self.functions_url, headers)
return self._functions

Expand Down Expand Up @@ -229,15 +244,16 @@ def _get_auth_headers(self) -> Dict[str, str]:

def _get_token_header(self):
try:
access_token = self.auth.get_session().access_token
except:
session = self.auth.get_session()
access_token = session.access_token
except Exception as err:
access_token = self.supabase_key

return {
"Authorization": f"Bearer {access_token}",
}

def _listen_to_auth_events(self, event: AuthChangeEvent, session):
def _listen_to_auth_events(self, event: AuthChangeEvent, session: Session):
if event in ["SIGNED_IN", "TOKEN_REFRESHED", "SIGNED_OUT"]:
# reset postgrest and storage instance on event change
self._postgrest = None
Expand All @@ -249,7 +265,7 @@ def create_client(
supabase_url: str,
supabase_key: str,
options: ClientOptions = ClientOptions(),
) -> Client:
) -> SyncClient:
"""Create client function to instantiate supabase client like JS runtime.
Parameters
Expand All @@ -276,4 +292,6 @@ def create_client(
-------
Client
"""
return Client(supabase_url=supabase_url, supabase_key=supabase_key, options=options)
return SyncClient.create(
supabase_url=supabase_url, supabase_key=supabase_key, options=options
)
3 changes: 2 additions & 1 deletion supabase/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from .__version__ import __version__
from ._sync.auth_client import SyncSupabaseAuthClient as SupabaseAuthClient
from ._sync.client import Client, ClientOptions
from ._sync.client import ClientOptions
from ._sync.client import SyncClient as Client
from ._sync.client import SyncStorageClient as SupabaseStorageClient
from ._sync.client import create_client
from .lib.realtime_client import SupabaseRealtimeClient
Expand Down

0 comments on commit fd612a0

Please sign in to comment.