Skip to content

Commit

Permalink
grpc-web-text support for ASGI backend
Browse files Browse the repository at this point in the history
  • Loading branch information
public committed Apr 26, 2021
1 parent 9884f8f commit 6c781e9
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 48 deletions.
149 changes: 102 additions & 47 deletions sonora/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ async def __call__(self, scope, receive, send):
if rpc_method:
if request_method == "POST":
context = self._create_context(scope)

try:
async with timeout(context.time_remaining()):
await self._do_grpc_request(rpc_method, context, receive, send)
except asyncio.TimeoutError:
context.code = grpc.StatusCode.DEADLINE_EXCEEDED
context.details = "request timed out at the server"
await self._do_grpc_error(context, send)
await self._do_grpc_error(send, context)

elif request_method == "OPTIONS":
await self._do_cors_preflight(scope, receive, send)
Expand Down Expand Up @@ -87,10 +88,9 @@ def _create_context(self, scope):
return ServicerContext(timeout, metadata)

async def _do_grpc_request(self, rpc_method, context, receive, send):
request_proto_iterator = (
rpc_method.request_deserializer(message)
async for _, _, message in protocol.unwrap_message_asgi(receive)
)
headers = context._response_headers
wrap_message = context._wrap_message
unwrap_message = context._unwrap_message

if not rpc_method.request_streaming and not rpc_method.response_streaming:
method = rpc_method.unary_unary
Expand All @@ -103,40 +103,48 @@ async def _do_grpc_request(self, rpc_method, context, receive, send):
else:
raise NotImplementedError

if rpc_method.request_streaming:
coroutine = method(request_proto_iterator, context)
else:
request_proto = await anext(request_proto_iterator)
coroutine = method(request_proto, context)
request_proto_iterator = (
rpc_method.request_deserializer(message)
async for _, _, message in unwrap_message(receive)
)

headers = [
(b"Content-Type", b"application/grpc-web+proto"),
(b"Access-Control-Allow-Origin", b"*"),
(b"Access-Control-Expose-Headers", b"*"),
]
try:
if rpc_method.request_streaming:
coroutine = method(request_proto_iterator, context)
else:
request_proto = await anext(
request_proto_iterator, None
) or rpc_method.request_deserializer(b"")
coroutine = method(request_proto, context)
except NotImplementedError:
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
coroutine = None

try:
if rpc_method.response_streaming:
await self._do_streaming_response(
rpc_method, receive, send, context, headers, coroutine
rpc_method, receive, send, wrap_message, context, coroutine
)
else:
await self._do_unary_response(
rpc_method, receive, send, context, headers, coroutine
rpc_method, receive, send, wrap_message, context, coroutine
)
except grpc.RpcError:
await self._do_grpc_error(context, send)
await self._do_grpc_error(send, context)

async def _do_streaming_response(
self, rpc_method, receive, send, context, headers, coroutine
self, rpc_method, receive, send, wrap_message, context, coroutine
):
message = await anext(coroutine)
headers = context._response_headers

status = protocol.grpc_status_to_http_status(context.code)
if coroutine:
message = await anext(coroutine)
else:
message = b""

body = protocol.wrap_message(
False, False, rpc_method.response_serializer(message)
)
status = 200

body = wrap_message(False, False, rpc_method.response_serializer(message))

if context._initial_metadata:
headers.extend(context._initial_metadata)
Expand All @@ -148,9 +156,7 @@ async def _do_streaming_response(
await send({"type": "http.response.body", "body": body, "more_body": True})

async for message in coroutine:
body = protocol.wrap_message(
False, False, rpc_method.response_serializer(message)
)
body = wrap_message(False, False, rpc_method.response_serializer(message))

send_task = asyncio.create_task(
send({"type": "http.response.body", "body": body, "more_body": True})
Expand Down Expand Up @@ -178,33 +184,42 @@ async def _do_streaming_response(
trailers.extend(context._trailing_metadata)

trailer_message = protocol.pack_trailers(trailers)
body = protocol.wrap_message(True, False, trailer_message)
body = wrap_message(True, False, trailer_message)
await send({"type": "http.response.body", "body": body, "more_body": False})

async def _do_unary_response(
self, rpc_method, receive, send, context, headers, coroutine
self, rpc_method, receive, send, wrap_message, context, coroutine
):
message = await coroutine
headers = context._response_headers

if coroutine is None:
message = None
else:
message = await coroutine

status = 200

status = protocol.grpc_status_to_http_status(context.code)
headers.append((b"grpc-status", str(context.code.value[0]).encode()))
if context.details:
headers.append((b"grpc-message", quote(context.details)))
headers.append(
(b"grpc-message", quote(context.details.encode("utf8")).encode("ascii"))
)

if context._initial_metadata:
headers.extend(context._initial_metadata)

if message is not None:
message_data = protocol.wrap_message(
message_data = wrap_message(
False, False, rpc_method.response_serializer(message)
)
else:
message_data = b""

if context._trailing_metadata:
trailers = context._trailing_metadata

trailer_message = protocol.pack_trailers(trailers)
trailer_data = protocol.wrap_message(True, False, trailer_message)
trailer_data = wrap_message(True, False, trailer_message)
else:
trailer_data = b""

Expand All @@ -222,25 +237,27 @@ async def _do_unary_response(
{"type": "http.response.body", "body": trailer_data, "more_body": False}
)

async def _do_grpc_error(self, context, send):
headers = [
(b"Content-Type", b"application/grpc-web+proto"),
(b"Access-Control-Allow-Origin", b"*"),
(b"Access-Control-Expose-Headers", b"*"),
]

status = protocol.grpc_status_to_http_status(context.code)
async def _do_grpc_error(self, send, context):
status = 200
headers = context._response_headers
headers.append((b"grpc-status", str(context.code.value[0]).encode()))

if context.details:
headers.append((b"grpc-message", quote(context.details).encode()))
headers.append(
(b"grpc-message", quote(context.details.encode("utf8")).encode("ascii"))
)

await send(
{"type": "http.response.start", "status": status, "headers": headers}
)
await send({"type": "http.response.body", "body": b"", "more_body": False})

async def _do_cors_preflight(self, scope, receive, send):
origin = next(
(value for header, value in scope["headers"] if header == "host"),
scope["server"][0],
)

await send(
{
"type": "http.response.start",
Expand All @@ -250,7 +267,7 @@ async def _do_cors_preflight(self, scope, receive, send):
(b"Content-Length", b"0"),
(b"Access-Control-Allow-Methods", b"POST, OPTIONS"),
(b"Access-Control-Allow-Headers", b"*"),
(b"Access-Control-Allow-Origin", b"*"),
(b"Access-Control-Allow-Origin", origin),
(b"Access-Control-Allow-Credentials", b"true"),
(b"Access-Control-Expose-Headers", b"*"),
],
Expand Down Expand Up @@ -290,8 +307,46 @@ def __init__(self, timeout=None, metadata=None):
self._initial_metadata = None
self._trailing_metadata = None

response_content_type = "application/grpc-web+proto"

self._wrap_message = protocol.wrap_message
self._unwrap_message = protocol.unwrap_message_asgi
origin = None

for header, value in metadata:
if header == "content-type":
if value == "application/grpc-web-text":
self._wrap_message = protocol.b64_wrap_message
self._unwrap_message = protocol.b64_unwrap_message_asgi
elif header == "accept":
response_content_type = value.split(",")[0].strip()
elif header == "host":
origin = value

if not origin:
raise ValueError("Request is missing the host header")

self._response_headers = [
(b"Content-Type", response_content_type.encode("ascii")),
(b"Access-Control-Allow-Origin", origin.encode("ascii")),
(b"Access-Control-Expose-Headers", b"*"),
]

def set_code(self, code):
self.code = code
if isinstance(code, grpc.StatusCode):
self.code = code

elif isinstance(code, int):
for status_code in grpc.StatusCode:
if status_code.value[0] == code:
self.code = status_code
break
else:
raise ValueError(f"Unknown StatusCode: {code}")
else:
raise NotImplementedError(
f"Unsupported status code type: {type(code)} with value {code}"
)

def set_details(self, details):
self.details = details
Expand All @@ -315,11 +370,11 @@ async def abort_with_status(self, status):

async def send_initial_metadata(self, initial_metadata):
self._initial_metadata = [
(key.encode("ascii"), value.encode("ascii"))
(key.encode("ascii"), value.encode("utf8"))
for key, value in protocol.encode_headers(initial_metadata)
]

async def set_trailing_metadata(self, trailing_metadata):
def set_trailing_metadata(self, trailing_metadata):
self._trailing_metadata = protocol.encode_headers(trailing_metadata)

def invocation_metadata(self):
Expand Down
5 changes: 4 additions & 1 deletion sonora/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ async def unwrap_message_asgi(receive, decoder=None):
break


b64_unwrap_message_asgi = functools.partial(unwrap_message_asgi, decoder=base64.b64decode)
b64_unwrap_message_asgi = functools.partial(
unwrap_message_asgi, decoder=base64.b64decode
)


def pack_trailers(trailers):
Expand Down Expand Up @@ -149,6 +151,7 @@ def encode_headers(metadata):

yield header, value


class WebRpcError(grpc.RpcError):
_code_to_enum = {code.value[0]: code for code in grpc.StatusCode}

Expand Down

0 comments on commit 6c781e9

Please sign in to comment.