From d305d8cb66ec85b6f9d621aca0e18a3cb58d6822 Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Wed, 6 Nov 2024 14:57:02 +0100 Subject: [PATCH] Fix invalid mapping for `oauth_cb` in BaseSettings Also remove `oauthbearer_token_refresh_cb` since it's the same as `oauth_cb` --- quixstreams/kafka/configuration.py | 11 +++++++++-- .../test_quixstreams/test_kafka/test_configuration.py | 7 +++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/quixstreams/kafka/configuration.py b/quixstreams/kafka/configuration.py index fd8bafa8b..a3ad4b009 100644 --- a/quixstreams/kafka/configuration.py +++ b/quixstreams/kafka/configuration.py @@ -1,5 +1,6 @@ from typing import Callable, Literal, Optional, Tuple, Type +import pydantic from pydantic import AliasChoices, Field, SecretStr from pydantic.functional_validators import BeforeValidator from pydantic_settings import PydanticBaseSettingsSource @@ -44,11 +45,17 @@ class ConnectionConfig(BaseSettings): sasl_kerberos_min_time_before_relogin: Optional[int] = None sasl_kerberos_service_name: Optional[str] = None sasl_kerberos_principal: Optional[str] = None + # for oauth_cb, see https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#pythonclient-configuration - oauth_cb: Optional[Callable[[str], Tuple[str, float]]] = None + oauth_cb: Optional[Callable[[str], Tuple[str, float]]] = pydantic.Field( + # Prevent the AliasGenerator from changing the field name to "oauth.cb" + default=None, + alias_priority=2, + serialization_alias="oauth_cb", + ) + sasl_oauthbearer_config: Optional[str] = None enable_sasl_oauthbearer_unsecure_jwt: Optional[bool] = None - oauthbearer_token_refresh_cb: Optional[Callable] = None sasl_oauthbearer_method: Annotated[ Optional[Literal["default", "oidc"]], BeforeValidator(lambda v: v.lower() if v is not None else v), diff --git a/tests/test_quixstreams/test_kafka/test_configuration.py b/tests/test_quixstreams/test_kafka/test_configuration.py index 4ee886219..0831400d5 100644 --- a/tests/test_quixstreams/test_kafka/test_configuration.py +++ b/tests/test_quixstreams/test_kafka/test_configuration.py @@ -31,12 +31,14 @@ def test_from_librdkafka_dict(self, mechanism_casing): "bootstrap.servers": "url", "sasl.mechanism": mechanism_casing, "sasl.username": "my-username", + "oauth_cb": lambda _: _, } config = ConnectionConfig.from_librdkafka_dict(librdkafka_dict) assert config.bootstrap_servers == librdkafka_dict["bootstrap.servers"] assert config.sasl_mechanism == librdkafka_dict["sasl.mechanism"].upper() assert config.sasl_username == librdkafka_dict["sasl.username"] + assert config.oauth_cb == librdkafka_dict["oauth_cb"] def test_from_librdkafka_dict_extras_raise(self): librdkafka_dict = { @@ -96,6 +98,11 @@ def test_sasl_mechanism_aliases(self): assert "sasl.mechanism" in d assert "sasl.mechanisms" not in d + def test_oauth_cb(self): + config = ConnectionConfig(bootstrap_servers="url", oauth_cb=lambda _: _) + rd_config = config.as_librdkafka_dict() + assert config.oauth_cb == rd_config["oauth_cb"] + def test_secret_field(self): """ Confirm a secret field is obscured