Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] support Multi clouds for ARM SDK (not ready to review) #5925

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
7 changes: 7 additions & 0 deletions .chronus/changes/multi-clouds-2025-1-10-14-54-3.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
changeKind: feature
packages:
- "@typespec/http-client-python"
---

Improve user experience in multi clouds scenario
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
OverloadedRequestBuilder,
get_request_builder,
)
from .parameter import Parameter, ParameterMethodLocation
from .parameter import Parameter, ParameterMethodLocation, ParameterLocation
from .lro_operation import LROOperation
from .lro_paging_operation import LROPagingOperation
from ...utils import extract_original_name, NAME_LENGTH_LIMIT
Expand Down Expand Up @@ -54,7 +54,7 @@ def name(self) -> str:
return self.yaml_data["name"]


class Client(_ClientConfigBase[ClientGlobalParameterList]):
class Client(_ClientConfigBase[ClientGlobalParameterList]): # pylint: disable=too-many-public-methods
"""Model representing our service client"""

def __init__(
Expand All @@ -79,6 +79,26 @@ def __init__(
self.request_id_header_name = self.yaml_data.get("requestIdHeaderName", None)
self.has_etag: bool = yaml_data.get("hasEtag", False)

# update the host parameter value. In later logic, SDK will overwrite it
# with value from cloud_setting if users don't provide it.
if self.need_cloud_setting:
for p in self.parameters.parameters:
if p.location == ParameterLocation.ENDPOINT_PATH:
p.client_default_value = ""
break

@property
def need_cloud_setting(self) -> bool:
return bool(
self.code_model.options["azure_arm"]
and self.credential_scopes is not None
and self.endpoint_parameter is not None
)

@property
def endpoint_parameter(self) -> Optional[Parameter]:
return next((p for p in self.parameters.parameters if p.location == ParameterLocation.ENDPOINT_PATH), None)

def _build_request_builders(
self,
) -> List[Union[RequestBuilder, OverloadedRequestBuilder]]:
Expand Down Expand Up @@ -241,6 +261,17 @@ def _imports_shared(self, async_mode: bool, **kwargs) -> FileImport:
"Self",
ImportType.STDLIB,
)
if self.need_cloud_setting:
file_import.add_submodule_import(
"azure.core.settings",
"settings",
ImportType.SDKCORE,
)
file_import.add_submodule_import(
"azure.mgmt.core.tools",
"get_arm_endpoints",
ImportType.SDKCORE,
)
return file_import

@property
Expand Down Expand Up @@ -340,6 +371,18 @@ def imports_for_multiapi(self, async_mode: bool, **kwargs) -> FileImport:
)
return file_import

@property
def credential_scopes(self) -> Optional[List[str]]:
"""Credential scopes for this client"""

if self.credential:
if hasattr(getattr(self.credential.type, "policy", None), "credential_scopes"):
return self.credential.type.policy.credential_scopes # type: ignore
for t in getattr(self.credential.type, "types", []):
if hasattr(getattr(t, "policy", None), "credential_scopes"):
return t.policy.credential_scopes
return None

@classmethod
def from_yaml(
cls,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,9 @@ def response_deserialization( # pylint: disable=too-many-statements
retval.extend(deserialize_code)
return retval

def handle_error_response(self, builder: OperationType) -> List[str]: # pylint: disable=too-many-statements, too-many-branches
def handle_error_response( # pylint: disable=too-many-statements, too-many-branches
self, builder: OperationType
) -> List[str]:
async_await = "await " if self.async_mode else ""
retval = [f"if response.status_code not in {str(builder.success_status_codes)}:"]
response_read = [
Expand Down Expand Up @@ -1084,9 +1086,7 @@ def handle_error_response(self, builder: OperationType) -> List[str]: # pylint
f" error = _failsafe_deserialize_xml({type_annotation}, response.text())"
)
else:
retval.append(
f" error = _failsafe_deserialize({type_annotation}, response.json())"
)
retval.append(f" error = _failsafe_deserialize({type_annotation}, response.json())")
else:
retval.append(
f" error = self._deserialize.failsafe_deserialize({type_annotation}, "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import List
from typing import List, cast

from . import utils
from ..models import Client, ParameterMethodLocation
from ..models import Client, ParameterMethodLocation, Parameter
from .parameter_serializer import ParameterSerializer, PopKwargType
from ...utils import build_policies

Expand Down Expand Up @@ -77,17 +77,33 @@ def property_descriptions(self, async_mode: bool) -> List[str]:
retval.append('"""')
return retval

def initialize_config(self) -> str:
def initialize_config(self) -> List[str]:
retval = []
additional_signatures = []
if self.client.need_cloud_setting:
additional_signatures.append("credential_scopes=credential_scopes")
endpoint_parameter = cast(Parameter, self.client.endpoint_parameter)
retval.extend(
[
'_cloud = kwargs.pop("cloud_setting", None) or settings.current.azure_cloud # type: ignore',
"_endpoints = get_arm_endpoints(_cloud)",
f"if not {endpoint_parameter.client_name}:",
f' {endpoint_parameter.client_name} = _endpoints["resource_manager"]',
'credential_scopes = kwargs.pop("credential_scopes", _endpoints["credential_scopes"])',
]
)
config_name = f"{self.client.name}Configuration"
config_call = ", ".join(
[
f"{p.client_name}={p.client_name}"
for p in self.client.config.parameters.method
if p.method_location != ParameterMethodLocation.KWARG
]
+ additional_signatures
+ ["**kwargs"]
)
return f"self._config = {config_name}({config_call})"
retval.append(f"self._config = {config_name}({config_call})")
return retval

@property
def host_variable_name(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _imports(self) -> FileImportSerializer:
ImportType.SDKCORE,
)
for param in self.operation.parameters.positional + self.operation.parameters.keyword_only:
if not param.client_default_value and not param.optional and param.wire_name in self.sample_params:
if param.client_default_value is None and not param.optional and param.wire_name in self.sample_params:
imports.merge(param.type.imports_for_sample())
return FileImportSerializer(imports, True)

Expand All @@ -80,7 +80,7 @@ def _client_params(self) -> Dict[str, Any]:
for p in (
self.code_model.clients[0].parameters.positional + self.code_model.clients[0].parameters.keyword_only
)
if not (p.optional or p.client_default_value)
if not p.optional and p.client_default_value is None
]
client_params = {
p.client_name: special_param.get(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
{% if client.has_parameterized_host %}
{{ serializer.host_variable_name }} = {{ keywords.escape_str(client.url) }}{{ client.url_pylint_disable }}
{% endif %}
{{ serializer.initialize_config() }}
{{ op_tools.serialize(serializer.initialize_config()) | indent(8) }}
{{ op_tools.serialize(serializer.initialize_pipeline_client(async_mode)) | indent(8) }}

{{ op_tools.serialize(serializer.serializers_and_operation_groups_properties()) | indent(8) }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,8 @@ class {{ client.name }}Configuration: {{ client.config.pylint_disable() }}
{% if serializer.set_constants() %}
{{ op_tools.serialize(serializer.set_constants()) | indent(8) -}}
{% endif %}
{% if client.credential %}
{% set cred_scopes = client.credential.type if client.credential.type.policy is defined and client.credential.type.policy.credential_scopes is defined %}
{% if not cred_scopes %}
{% set cred_scopes = client.credential.type.types | selectattr("policy.credential_scopes") | first if client.credential.type.types is defined %}
{% endif %}
{% if cred_scopes %}
self.credential_scopes = kwargs.pop('credential_scopes', {{ cred_scopes.policy.credential_scopes }})
{% endif %}
{% if client.credential_scopes is not none %}
self.credential_scopes = kwargs.pop('credential_scopes', {{ client.credential_scopes }})
{% endif %}
kwargs.setdefault('sdk_moniker', '{{ client.config.sdk_moniker }}/{}'.format(VERSION))
self.polling_interval = kwargs.get("polling_interval", 30)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ setup(
"isodate>=0.6.1",
{% endif %}
{% if azure_arm %}
"azure-mgmt-core>=1.3.2",
"azure-mgmt-core>=1.5.0",
{% elif code_model.is_azure_flavor %}
"azure-core>=1.30.0",
{% else %}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
-r ../dev_requirements.txt
-e ../../
azure-core==1.30.0
azure-mgmt-core==1.3.2
azure-mgmt-core==1.5.0

# only for azure
-e ./generated/azure-client-generator-core-access
Expand Down
Loading