diff --git a/conftest.py b/conftest.py
index 02aa6eb..4196825 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,5 +1,10 @@
+import asyncio
+import typing as t
+
import pytest
+from htpy import Node, aiter_node, iter_node
+
@pytest.fixture(scope="session")
def django_env() -> None:
@@ -13,3 +18,23 @@ def django_env() -> None:
]
)
django.setup()
+
+
+@pytest.fixture(params=["sync", "async"])
+def to_list(request: pytest.FixtureRequest) -> t.Any:
+ def func(node: Node) -> t.Any:
+ if request.param == "sync":
+ return list(iter_node(node))
+ else:
+
+ async def get_list() -> t.Any:
+ return [chunk async for chunk in aiter_node(node)]
+
+ return asyncio.run(get_list(), debug=True)
+
+ return func
+
+
+@pytest.fixture
+def to_str(to_list: t.Any) -> t.Any:
+ return lambda node: "".join(to_list(node))
diff --git a/docs/assets/starlette.webm b/docs/assets/starlette.webm
new file mode 100644
index 0000000..0587c96
Binary files /dev/null and b/docs/assets/starlette.webm differ
diff --git a/docs/streaming.md b/docs/streaming.md
index 403df6b..768c4e3 100644
--- a/docs/streaming.md
+++ b/docs/streaming.md
@@ -1,7 +1,7 @@
# Streaming of Contents
Internally, htpy is built with generators. Most of the time, you would render
-the full page with `str()`, but htpy can also incrementally generate pages which
+the full page with `str()`, but htpy can also incrementally generate pages synchronously or asynchronous which
can then be streamed to the browser. If your page uses a database or other
services to retrieve data, you can sending the first part of the page to the
client while the page is being generated.
@@ -16,7 +16,6 @@ client while the page is being generated.
This video shows what it looks like in the browser to generate a HTML table with [Django StreamingHttpResponse](https://docs.djangoproject.com/en/5.0/ref/request-response/#django.http.StreamingHttpResponse) ([source code](https://github.com/pelme/htpy/blob/main/examples/djangoproject/stream/views.py)):
@@ -111,3 +110,59 @@ print(
# output:
Fibonacci!
fib(12)=6765
```
+
+
+## Asynchronous streaming
+
+It is also possible to use htpy to stream fully asynchronous. This intended to be used
+with ASGI/async web frameworks/servers such as Starlette and Django. You can
+build htpy components using Python's `asyncio` module and the `async`/`await`
+syntax.
+
+### Starlette, ASGI and uvicorn example
+
+```python
+title="starlette_demo.py"
+import asyncio
+from collections.abc import AsyncIterator
+
+from starlette.applications import Starlette
+from starlette.requests import Request
+from starlette.responses import StreamingResponse
+
+from htpy import Element, div, h1, li, p, ul
+
+app = Starlette(debug=True)
+
+
+@app.route("/")
+async def index(request: Request) -> StreamingResponse:
+ return StreamingResponse(await index_page(), media_type="text/html")
+
+
+async def index_page() -> Element:
+ return div[
+ h1["Starlette Async example"],
+ p["This page is generated asynchronously using Starlette and ASGI."],
+ ul[(li[str(num)] async for num in slow_numbers(1, 10))],
+ ]
+
+
+async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]:
+ for number in range(minimum, maximum + 1):
+ yield number
+ await asyncio.sleep(0.5)
+
+```
+
+Run with [uvicorn](https://www.uvicorn.org/):
+
+
+```
+$ uvicorn starlette_demo:app
+```
+
+In the browser, it looks like this:
+
diff --git a/examples/async_coroutine.py b/examples/async_coroutine.py
new file mode 100644
index 0000000..cd0d9a2
--- /dev/null
+++ b/examples/async_coroutine.py
@@ -0,0 +1,24 @@
+import asyncio
+import random
+
+from htpy import Element, b, div, h1
+
+
+async def magic_number() -> Element:
+ await asyncio.sleep(2)
+ return b[f"The Magic Number is: {random.randint(1, 100)}"]
+
+
+async def my_component() -> Element:
+ return div[
+ h1["The Magic Number"],
+ magic_number(),
+ ]
+
+
+async def main() -> None:
+ async for chunk in await my_component():
+ print(chunk)
+
+
+asyncio.run(main())
diff --git a/examples/starlette_app.py b/examples/starlette_app.py
new file mode 100644
index 0000000..f06b57c
--- /dev/null
+++ b/examples/starlette_app.py
@@ -0,0 +1,35 @@
+import asyncio
+from collections.abc import AsyncIterator
+
+from starlette.applications import Starlette
+from starlette.requests import Request
+from starlette.responses import StreamingResponse
+from starlette.routing import Route
+
+from htpy import Element, div, h1, li, p, ul
+
+
+async def index(request: Request) -> StreamingResponse:
+ return StreamingResponse(await index_page(), media_type="text/html")
+
+
+async def index_page() -> Element:
+ return div[
+ h1["Starlette Async example"],
+ p["This page is generated asynchronously using Starlette and ASGI."],
+ ul[(li[str(num)] async for num in slow_numbers(1, 10))],
+ ]
+
+
+async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]:
+ for number in range(minimum, maximum + 1):
+ yield number
+ await asyncio.sleep(0.5)
+
+
+app = Starlette(
+ debug=True,
+ routes=[
+ Route("/", index),
+ ],
+)
diff --git a/htpy/__init__.py b/htpy/__init__.py
index 46ee5c0..bf3d050 100644
--- a/htpy/__init__.py
+++ b/htpy/__init__.py
@@ -6,7 +6,15 @@
import dataclasses
import functools
import typing as t
-from collections.abc import Callable, Generator, Iterable, Iterator
+from collections.abc import (
+ AsyncIterable,
+ AsyncIterator,
+ Awaitable,
+ Callable,
+ Generator,
+ Iterable,
+ Iterator,
+)
from markupsafe import Markup as _Markup
from markupsafe import escape as _escape
@@ -137,6 +145,10 @@ class ContextConsumer(t.Generic[T]):
func: Callable[[T], Node]
+def _is_noop_node(x: Node) -> bool:
+ return x is None or x is True or x is False
+
+
class _NO_DEFAULT:
pass
@@ -168,15 +180,8 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It
while not isinstance(x, BaseElement) and callable(x):
x = x()
- if x is None:
- return
-
- if x is True:
- return
-
- if x is False:
+ if _is_noop_node(x):
return
-
if isinstance(x, BaseElement):
yield from x._iter_context(context_dict) # pyright: ignore [reportPrivateUsage]
elif isinstance(x, ContextProvider):
@@ -196,10 +201,72 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It
elif isinstance(x, Iterable) and not isinstance(x, _KnownInvalidChildren): # pyright: ignore [reportUnnecessaryIsInstance]
for child in x:
yield from _iter_node_context(child, context_dict)
+ elif isinstance(x, Awaitable | AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance]
+ raise ValueError(
+ f"{x!r} is not a valid child element. "
+ "Use async iteration to retrieve element content: https://htpy.dev/streaming/"
+ )
else:
raise ValueError(f"{x!r} is not a valid child element")
+def aiter_node(x: Node) -> AsyncIterator[str]:
+ return _aiter_node_context(x, {})
+
+
+async def _aiter_node_context(
+ x: Node, context_dict: dict[Context[t.Any], t.Any]
+) -> AsyncIterator[str]:
+ while True:
+ if isinstance(x, Awaitable):
+ x = await x
+ continue
+
+ if not isinstance(x, BaseElement) and callable(x):
+ x = x()
+ continue
+
+ break
+
+ if _is_noop_node(x):
+ return
+
+ if isinstance(x, BaseElement):
+ async for child in x._aiter_context(context_dict): # pyright: ignore [reportPrivateUsage]
+ yield child
+ elif isinstance(x, ContextProvider):
+ async for chunk in _aiter_node_context(
+ x.func(),
+ {**context_dict, x.context: x.value}, # pyright: ignore [reportUnknownMemberType]
+ ):
+ yield chunk
+
+ elif isinstance(x, ContextConsumer):
+ context_value = context_dict.get(x.context, x.context.default)
+ if context_value is _NO_DEFAULT:
+ raise LookupError(
+ f'Context value for "{x.context.name}" does not exist, '
+ f"requested by {x.debug_name}()."
+ )
+ async for chunk in _aiter_node_context(x.func(context_value), context_dict):
+ yield chunk
+
+ elif isinstance(x, str | _HasHtml):
+ yield str(_escape(x))
+ elif isinstance(x, int):
+ yield str(x)
+ elif isinstance(x, Iterable):
+ for child in x: # type: ignore[assignment]
+ async for chunk in _aiter_node_context(child, context_dict):
+ yield chunk
+ elif isinstance(x, AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance]
+ async for child in x: # type: ignore[assignment]
+ async for chunk in _aiter_node_context(child, context_dict): # pyright: ignore[reportUnknownArgumentType]
+ yield chunk
+ else:
+ raise ValueError(f"{x!r} is not a valid async child element")
+
+
@functools.lru_cache(maxsize=300)
def _get_element(name: str) -> Element:
if not name.islower():
@@ -226,7 +293,10 @@ def __str__(self) -> _Markup:
@t.overload
def __call__(
- self: BaseElementSelf, id_class: str, attrs: dict[str, Attribute], **kwargs: Attribute
+ self: BaseElementSelf,
+ id_class: str,
+ attrs: dict[str, Attribute],
+ **kwargs: Attribute,
) -> BaseElementSelf: ...
@t.overload
def __call__(
@@ -267,6 +337,15 @@ def __call__(self: BaseElementSelf, *args: t.Any, **kwargs: t.Any) -> BaseElemen
self._children,
)
+ async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]:
+ yield f"<{self._name}{self._attrs}>"
+ async for x in _aiter_node_context(self._children, context):
+ yield x
+ yield f"{self._name}>"
+
+ def __aiter__(self) -> AsyncIterator[str]:
+ return self._aiter_context({})
+
def __iter__(self) -> Iterator[str]:
return self._iter_context({})
@@ -275,9 +354,6 @@ def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
yield from _iter_node_context(self._children, ctx)
yield f"{self._name}>"
- def __repr__(self) -> str:
- return f"<{self.__class__.__name__} '{self}'>"
-
# Allow starlette Response.render to directly render this element without
# explicitly casting to str:
# https://github.com/encode/starlette/blob/5ed55c441126687106109a3f5e051176f88cd3e6/starlette/responses.py#L44-L49
@@ -308,17 +384,31 @@ def __getitem__(self: ElementSelf, children: Node) -> ElementSelf:
_validate_children(children)
return self.__class__(self._name, self._attrs, children) # pyright: ignore [reportUnknownArgumentType]
+ def __repr__(self) -> str:
+ return f"<{self.__class__.__name__} '<{self._name}{self._attrs}>...{self._name}>'>"
+
class HTMLElement(Element):
def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
yield ""
yield from super()._iter_context(ctx)
+ async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]:
+ yield ""
+ async for x in super()._aiter_context(context):
+ yield x
+
class VoidElement(BaseElement):
+ async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]:
+ yield f"<{self._name}{self._attrs}>"
+
def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
yield f"<{self._name}{self._attrs}>"
+ def __repr__(self) -> str:
+ return f"<{self.__class__.__name__} '<{self._name}{self._attrs}>'>"
+
def render_node(node: Node) -> _Markup:
return _Markup("".join(iter_node(node)))
@@ -347,6 +437,8 @@ def __html__(self) -> str: ...
| Callable[[], "Node"]
| ContextProvider[t.Any]
| ContextConsumer[t.Any]
+ | AsyncIterable["Node"]
+ | Awaitable["Node"]
)
Attribute: t.TypeAlias = None | bool | str | int | _HasHtml | _ClassNames
@@ -480,6 +572,8 @@ def __html__(self) -> str: ...
_KnownValidChildren: UnionType = ( # pyright: ignore [reportUnknownVariableType]
None
| BaseElement
+ | AsyncIterable # pyright: ignore [reportMissingTypeArgument]
+ | Awaitable # pyright: ignore [reportMissingTypeArgument]
| ContextProvider # pyright: ignore [reportMissingTypeArgument]
| ContextConsumer # pyright: ignore [reportMissingTypeArgument]
| Callable # pyright: ignore [reportMissingTypeArgument]
diff --git a/htpy/starlette.py b/htpy/starlette.py
new file mode 100644
index 0000000..e448bd9
--- /dev/null
+++ b/htpy/starlette.py
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+import typing as t
+
+from starlette.responses import StreamingResponse
+
+from . import aiter_node
+
+if t.TYPE_CHECKING:
+ from starlette.background import BackgroundTask
+
+ from . import Node
+
+
+class HtpyResponse(StreamingResponse):
+ def __init__(
+ self,
+ content: Node,
+ status_code: int = 200,
+ headers: t.Mapping[str, str] | None = None,
+ media_type: str | None = "text/html",
+ background: BackgroundTask | None = None,
+ ):
+ super().__init__(
+ aiter_node(content),
+ status_code=status_code,
+ headers=headers,
+ media_type=media_type,
+ background=background,
+ )
diff --git a/pyproject.toml b/pyproject.toml
index 24cd8d2..6c5bf64 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,7 @@ optional-dependencies.dev = [
"mypy",
"pyright",
"pytest",
+ "pytest-asyncio",
"black",
"ruff",
"django",
diff --git a/tests/test_async.py b/tests/test_async.py
new file mode 100644
index 0000000..245cb04
--- /dev/null
+++ b/tests/test_async.py
@@ -0,0 +1,43 @@
+from collections.abc import AsyncIterator
+
+import pytest
+
+from htpy import Element, li, ul
+
+
+async def async_lis() -> AsyncIterator[Element]:
+ yield li["a"]
+ yield li["b"]
+
+
+async def hi() -> Element:
+ return li["hi"]
+
+
+@pytest.mark.asyncio
+async def test_async_iterator() -> None:
+ result = [chunk async for chunk in ul[async_lis()]]
+ assert result == ["
", "
", "a", "
", "
", "b", "
", "
"]
+
+
+@pytest.mark.asyncio
+async def test_cororoutinefunction_children() -> None:
+ result = [chunk async for chunk in ul[hi]]
+ assert result == ["
", "
", "hi", "
", "
"]
+
+
+@pytest.mark.asyncio
+async def test_cororoutine_children() -> None:
+ result = [chunk async for chunk in ul[hi()]]
+ assert result == ["
", "
", "hi", "
", "
"]
+
+
+def test_sync_iteration_with_async_children() -> None:
+ with pytest.raises(
+ ValueError,
+ match=(
+ r" is not a valid child element\. "
+ r"Use async iteration to retrieve element content: https://htpy.dev/streaming/"
+ ),
+ ):
+ str(ul[async_lis()])
diff --git a/tests/test_attributes.py b/tests/test_attributes.py
index 703a73f..cb5f6dc 100644
--- a/tests/test_attributes.py
+++ b/tests/test_attributes.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import typing as t
import pytest
@@ -5,33 +7,36 @@
from htpy import button, div, th
+if t.TYPE_CHECKING:
+ from .types import ToStr
+
def test_attribute() -> None:
assert str(div(id="hello")["hi"]) == '
"""
-def test_flatten_very_nested_children() -> None:
+def test_flatten_very_nested_children(to_str: ToStr) -> None:
# maybe not super useful but the nesting may be arbitrarily deep
result = div[[([["a"]],)], [([["b"]],)]]
- assert str(result) == """