diff --git a/conftest.py b/conftest.py index 02aa6eb..4196825 100644 --- a/conftest.py +++ b/conftest.py @@ -1,5 +1,10 @@ +import asyncio +import typing as t + import pytest +from htpy import Node, aiter_node, iter_node + @pytest.fixture(scope="session") def django_env() -> None: @@ -13,3 +18,23 @@ def django_env() -> None: ] ) django.setup() + + +@pytest.fixture(params=["sync", "async"]) +def to_list(request: pytest.FixtureRequest) -> t.Any: + def func(node: Node) -> t.Any: + if request.param == "sync": + return list(iter_node(node)) + else: + + async def get_list() -> t.Any: + return [chunk async for chunk in aiter_node(node)] + + return asyncio.run(get_list(), debug=True) + + return func + + +@pytest.fixture +def to_str(to_list: t.Any) -> t.Any: + return lambda node: "".join(to_list(node)) diff --git a/docs/assets/starlette.webm b/docs/assets/starlette.webm new file mode 100644 index 0000000..0587c96 Binary files /dev/null and b/docs/assets/starlette.webm differ diff --git a/docs/streaming.md b/docs/streaming.md index 403df6b..768c4e3 100644 --- a/docs/streaming.md +++ b/docs/streaming.md @@ -1,7 +1,7 @@ # Streaming of Contents Internally, htpy is built with generators. Most of the time, you would render -the full page with `str()`, but htpy can also incrementally generate pages which +the full page with `str()`, but htpy can also incrementally generate pages synchronously or asynchronous which can then be streamed to the browser. If your page uses a database or other services to retrieve data, you can sending the first part of the page to the client while the page is being generated. @@ -16,7 +16,6 @@ client while the page is being generated. This video shows what it looks like in the browser to generate a HTML table with [Django StreamingHttpResponse](https://docs.djangoproject.com/en/5.0/ref/request-response/#django.http.StreamingHttpResponse) ([source code](https://github.com/pelme/htpy/blob/main/examples/djangoproject/stream/views.py)): @@ -111,3 +110,59 @@ print( # output:

Fibonacci!

fib(12)=6765
``` + + +## Asynchronous streaming + +It is also possible to use htpy to stream fully asynchronous. This intended to be used +with ASGI/async web frameworks/servers such as Starlette and Django. You can +build htpy components using Python's `asyncio` module and the `async`/`await` +syntax. + +### Starlette, ASGI and uvicorn example + +```python +title="starlette_demo.py" +import asyncio +from collections.abc import AsyncIterator + +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import StreamingResponse + +from htpy import Element, div, h1, li, p, ul + +app = Starlette(debug=True) + + +@app.route("/") +async def index(request: Request) -> StreamingResponse: + return StreamingResponse(await index_page(), media_type="text/html") + + +async def index_page() -> Element: + return div[ + h1["Starlette Async example"], + p["This page is generated asynchronously using Starlette and ASGI."], + ul[(li[str(num)] async for num in slow_numbers(1, 10))], + ] + + +async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]: + for number in range(minimum, maximum + 1): + yield number + await asyncio.sleep(0.5) + +``` + +Run with [uvicorn](https://www.uvicorn.org/): + + +``` +$ uvicorn starlette_demo:app +``` + +In the browser, it looks like this: + diff --git a/examples/async_coroutine.py b/examples/async_coroutine.py new file mode 100644 index 0000000..cd0d9a2 --- /dev/null +++ b/examples/async_coroutine.py @@ -0,0 +1,24 @@ +import asyncio +import random + +from htpy import Element, b, div, h1 + + +async def magic_number() -> Element: + await asyncio.sleep(2) + return b[f"The Magic Number is: {random.randint(1, 100)}"] + + +async def my_component() -> Element: + return div[ + h1["The Magic Number"], + magic_number(), + ] + + +async def main() -> None: + async for chunk in await my_component(): + print(chunk) + + +asyncio.run(main()) diff --git a/examples/starlette_app.py b/examples/starlette_app.py new file mode 100644 index 0000000..f06b57c --- /dev/null +++ b/examples/starlette_app.py @@ -0,0 +1,35 @@ +import asyncio +from collections.abc import AsyncIterator + +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import StreamingResponse +from starlette.routing import Route + +from htpy import Element, div, h1, li, p, ul + + +async def index(request: Request) -> StreamingResponse: + return StreamingResponse(await index_page(), media_type="text/html") + + +async def index_page() -> Element: + return div[ + h1["Starlette Async example"], + p["This page is generated asynchronously using Starlette and ASGI."], + ul[(li[str(num)] async for num in slow_numbers(1, 10))], + ] + + +async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]: + for number in range(minimum, maximum + 1): + yield number + await asyncio.sleep(0.5) + + +app = Starlette( + debug=True, + routes=[ + Route("/", index), + ], +) diff --git a/htpy/__init__.py b/htpy/__init__.py index 46ee5c0..bf3d050 100644 --- a/htpy/__init__.py +++ b/htpy/__init__.py @@ -6,7 +6,15 @@ import dataclasses import functools import typing as t -from collections.abc import Callable, Generator, Iterable, Iterator +from collections.abc import ( + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + Generator, + Iterable, + Iterator, +) from markupsafe import Markup as _Markup from markupsafe import escape as _escape @@ -137,6 +145,10 @@ class ContextConsumer(t.Generic[T]): func: Callable[[T], Node] +def _is_noop_node(x: Node) -> bool: + return x is None or x is True or x is False + + class _NO_DEFAULT: pass @@ -168,15 +180,8 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It while not isinstance(x, BaseElement) and callable(x): x = x() - if x is None: - return - - if x is True: - return - - if x is False: + if _is_noop_node(x): return - if isinstance(x, BaseElement): yield from x._iter_context(context_dict) # pyright: ignore [reportPrivateUsage] elif isinstance(x, ContextProvider): @@ -196,10 +201,72 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It elif isinstance(x, Iterable) and not isinstance(x, _KnownInvalidChildren): # pyright: ignore [reportUnnecessaryIsInstance] for child in x: yield from _iter_node_context(child, context_dict) + elif isinstance(x, Awaitable | AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance] + raise ValueError( + f"{x!r} is not a valid child element. " + "Use async iteration to retrieve element content: https://htpy.dev/streaming/" + ) else: raise ValueError(f"{x!r} is not a valid child element") +def aiter_node(x: Node) -> AsyncIterator[str]: + return _aiter_node_context(x, {}) + + +async def _aiter_node_context( + x: Node, context_dict: dict[Context[t.Any], t.Any] +) -> AsyncIterator[str]: + while True: + if isinstance(x, Awaitable): + x = await x + continue + + if not isinstance(x, BaseElement) and callable(x): + x = x() + continue + + break + + if _is_noop_node(x): + return + + if isinstance(x, BaseElement): + async for child in x._aiter_context(context_dict): # pyright: ignore [reportPrivateUsage] + yield child + elif isinstance(x, ContextProvider): + async for chunk in _aiter_node_context( + x.func(), + {**context_dict, x.context: x.value}, # pyright: ignore [reportUnknownMemberType] + ): + yield chunk + + elif isinstance(x, ContextConsumer): + context_value = context_dict.get(x.context, x.context.default) + if context_value is _NO_DEFAULT: + raise LookupError( + f'Context value for "{x.context.name}" does not exist, ' + f"requested by {x.debug_name}()." + ) + async for chunk in _aiter_node_context(x.func(context_value), context_dict): + yield chunk + + elif isinstance(x, str | _HasHtml): + yield str(_escape(x)) + elif isinstance(x, int): + yield str(x) + elif isinstance(x, Iterable): + for child in x: # type: ignore[assignment] + async for chunk in _aiter_node_context(child, context_dict): + yield chunk + elif isinstance(x, AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance] + async for child in x: # type: ignore[assignment] + async for chunk in _aiter_node_context(child, context_dict): # pyright: ignore[reportUnknownArgumentType] + yield chunk + else: + raise ValueError(f"{x!r} is not a valid async child element") + + @functools.lru_cache(maxsize=300) def _get_element(name: str) -> Element: if not name.islower(): @@ -226,7 +293,10 @@ def __str__(self) -> _Markup: @t.overload def __call__( - self: BaseElementSelf, id_class: str, attrs: dict[str, Attribute], **kwargs: Attribute + self: BaseElementSelf, + id_class: str, + attrs: dict[str, Attribute], + **kwargs: Attribute, ) -> BaseElementSelf: ... @t.overload def __call__( @@ -267,6 +337,15 @@ def __call__(self: BaseElementSelf, *args: t.Any, **kwargs: t.Any) -> BaseElemen self._children, ) + async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]: + yield f"<{self._name}{self._attrs}>" + async for x in _aiter_node_context(self._children, context): + yield x + yield f"" + + def __aiter__(self) -> AsyncIterator[str]: + return self._aiter_context({}) + def __iter__(self) -> Iterator[str]: return self._iter_context({}) @@ -275,9 +354,6 @@ def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]: yield from _iter_node_context(self._children, ctx) yield f"" - def __repr__(self) -> str: - return f"<{self.__class__.__name__} '{self}'>" - # Allow starlette Response.render to directly render this element without # explicitly casting to str: # https://github.com/encode/starlette/blob/5ed55c441126687106109a3f5e051176f88cd3e6/starlette/responses.py#L44-L49 @@ -308,17 +384,31 @@ def __getitem__(self: ElementSelf, children: Node) -> ElementSelf: _validate_children(children) return self.__class__(self._name, self._attrs, children) # pyright: ignore [reportUnknownArgumentType] + def __repr__(self) -> str: + return f"<{self.__class__.__name__} '<{self._name}{self._attrs}>...'>" + class HTMLElement(Element): def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]: yield "" yield from super()._iter_context(ctx) + async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]: + yield "" + async for x in super()._aiter_context(context): + yield x + class VoidElement(BaseElement): + async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]: + yield f"<{self._name}{self._attrs}>" + def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]: yield f"<{self._name}{self._attrs}>" + def __repr__(self) -> str: + return f"<{self.__class__.__name__} '<{self._name}{self._attrs}>'>" + def render_node(node: Node) -> _Markup: return _Markup("".join(iter_node(node))) @@ -347,6 +437,8 @@ def __html__(self) -> str: ... | Callable[[], "Node"] | ContextProvider[t.Any] | ContextConsumer[t.Any] + | AsyncIterable["Node"] + | Awaitable["Node"] ) Attribute: t.TypeAlias = None | bool | str | int | _HasHtml | _ClassNames @@ -480,6 +572,8 @@ def __html__(self) -> str: ... _KnownValidChildren: UnionType = ( # pyright: ignore [reportUnknownVariableType] None | BaseElement + | AsyncIterable # pyright: ignore [reportMissingTypeArgument] + | Awaitable # pyright: ignore [reportMissingTypeArgument] | ContextProvider # pyright: ignore [reportMissingTypeArgument] | ContextConsumer # pyright: ignore [reportMissingTypeArgument] | Callable # pyright: ignore [reportMissingTypeArgument] diff --git a/htpy/starlette.py b/htpy/starlette.py new file mode 100644 index 0000000..e448bd9 --- /dev/null +++ b/htpy/starlette.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import typing as t + +from starlette.responses import StreamingResponse + +from . import aiter_node + +if t.TYPE_CHECKING: + from starlette.background import BackgroundTask + + from . import Node + + +class HtpyResponse(StreamingResponse): + def __init__( + self, + content: Node, + status_code: int = 200, + headers: t.Mapping[str, str] | None = None, + media_type: str | None = "text/html", + background: BackgroundTask | None = None, + ): + super().__init__( + aiter_node(content), + status_code=status_code, + headers=headers, + media_type=media_type, + background=background, + ) diff --git a/pyproject.toml b/pyproject.toml index 24cd8d2..6c5bf64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ optional-dependencies.dev = [ "mypy", "pyright", "pytest", + "pytest-asyncio", "black", "ruff", "django", diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 0000000..245cb04 --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,43 @@ +from collections.abc import AsyncIterator + +import pytest + +from htpy import Element, li, ul + + +async def async_lis() -> AsyncIterator[Element]: + yield li["a"] + yield li["b"] + + +async def hi() -> Element: + return li["hi"] + + +@pytest.mark.asyncio +async def test_async_iterator() -> None: + result = [chunk async for chunk in ul[async_lis()]] + assert result == [""] + + +@pytest.mark.asyncio +async def test_cororoutinefunction_children() -> None: + result = [chunk async for chunk in ul[hi]] + assert result == [""] + + +@pytest.mark.asyncio +async def test_cororoutine_children() -> None: + result = [chunk async for chunk in ul[hi()]] + assert result == [""] + + +def test_sync_iteration_with_async_children() -> None: + with pytest.raises( + ValueError, + match=( + r" is not a valid child element\. " + r"Use async iteration to retrieve element content: https://htpy.dev/streaming/" + ), + ): + str(ul[async_lis()]) diff --git a/tests/test_attributes.py b/tests/test_attributes.py index 703a73f..cb5f6dc 100644 --- a/tests/test_attributes.py +++ b/tests/test_attributes.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import typing as t import pytest @@ -5,33 +7,36 @@ from htpy import button, div, th +if t.TYPE_CHECKING: + from .types import ToStr + def test_attribute() -> None: assert str(div(id="hello")["hi"]) == '
hi
' class Test_class_names: - def test_str(self) -> None: + def test_str(self, to_str: ToStr) -> None: result = div(class_='">foo bar') - assert str(result) == '
' + assert to_str(result) == '
' - def test_safestring(self) -> None: + def test_safestring(self, to_str: ToStr) -> None: result = div(class_=Markup('">foo bar')) - assert str(result) == '
' + assert to_str(result) == '
' - def test_list(self) -> None: + def test_list(self, to_str: ToStr) -> None: result = div(class_=['">foo', Markup('">bar'), False, None, "", "baz"]) - assert str(result) == '
' + assert to_str(result) == '
' - def test_tuple(self) -> None: + def test_tuple(self, to_str: ToStr) -> None: result = div(class_=('">foo', Markup('">bar'), False, None, "", "baz")) - assert str(result) == '
' + assert to_str(result) == '
' - def test_dict(self) -> None: + def test_dict(self, to_str: ToStr) -> None: result = div(class_={'">foo': True, Markup('">bar'): True, "x": False, "baz": True}) - assert str(result) == '
' + assert to_str(result) == '
' - def test_nested_dict(self) -> None: + def test_nested_dict(self, to_str: ToStr) -> None: result = div( class_=[ '">list-foo', @@ -39,54 +44,54 @@ def test_nested_dict(self) -> None: {'">dict-foo': True, Markup('">list-bar'): True, "x": False}, ] ) - assert str(result) == ( + assert to_str(result) == ( '
' ) - def test_false(self) -> None: - result = str(div(class_=False)) + def test_false(self, to_str: ToStr) -> None: + result = to_str(div(class_=False)) assert result == "
" - def test_none(self) -> None: - result = str(div(class_=None)) + def test_none(self, to_str: ToStr) -> None: + result = to_str(div(class_=None)) assert result == "
" - def test_no_classes(self) -> None: - result = str(div(class_={"foo": False})) + def test_no_classes(self, to_str: ToStr) -> None: + result = to_str(div(class_={"foo": False})) assert result == "
" -def test_dict_attributes() -> None: +def test_dict_attributes(to_str: ToStr) -> None: result = div({"@click": 'hi = "hello"'}) - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_underscore() -> None: +def test_underscore(to_str: ToStr) -> None: # Hyperscript (https://hyperscript.org/) uses _, make sure it works good. result = div(_="foo") - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_dict_attributes_avoid_replace() -> None: +def test_dict_attributes_avoid_replace(to_str: ToStr) -> None: result = div({"class_": "foo", "hello_hi": "abc"}) - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_dict_attribute_false() -> None: +def test_dict_attribute_false(to_str: ToStr) -> None: result = div({"bool-false": False}) - assert str(result) == "
" + assert to_str(result) == "
" -def test_dict_attribute_true() -> None: +def test_dict_attribute_true(to_str: ToStr) -> None: result = div({"bool-true": True}) - assert str(result) == "
" + assert to_str(result) == "
" -def test_underscore_replacement() -> None: +def test_underscore_replacement(to_str: ToStr) -> None: result = button(hx_post="/foo")["click me!"] - assert str(result) == """""" + assert to_str(result) == """""" class Test_attribute_escape: @@ -98,54 +103,53 @@ class Test_attribute_escape: ], ) - def test_dict(self, x: str) -> None: + def test_dict(self, x: str, to_str: ToStr) -> None: result = div({x: x}) - assert str(result) == """
""" + assert to_str(result) == """
""" - def test_kwarg(self, x: str) -> None: + def test_kwarg(self, x: str, to_str: ToStr) -> None: result = div(**{x: x}) - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_boolean_attribute_true() -> None: +def test_boolean_attribute_true(to_str: ToStr) -> None: result = button(disabled=True) - assert str(result) == "" + assert to_str(result) == "" -def test_kwarg_attribute_none() -> None: +def test_kwarg_attribute_none(to_str: ToStr) -> None: result = div(foo=None) - assert str(result) == "
" + assert to_str(result) == "
" -def test_dict_attribute_none() -> None: +def test_dict_attribute_none(to_str: ToStr) -> None: result = div({"foo": None}) - assert str(result) == "
" + assert to_str(result) == "
" -def test_boolean_attribute_false() -> None: +def test_boolean_attribute_false(to_str: ToStr) -> None: result = button(disabled=False) - assert str(result) == "" + assert to_str(result) == "" -def test_integer_attribute() -> None: +def test_integer_attribute(to_str: ToStr) -> None: result = th(colspan=123) - assert str(result) == '' + assert to_str(result) == '' -def test_id_class() -> None: +def test_id_class(to_str: ToStr) -> None: result = div("#myid.cls1.cls2") - - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_id_class_only_id() -> None: +def test_id_class_only_id(to_str: ToStr) -> None: result = div("#myid") - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_id_class_only_classes() -> None: +def test_id_class_only_classes(to_str: ToStr) -> None: result = div(".foo.bar") - assert str(result) == """
""" + assert to_str(result) == """
""" def test_id_class_wrong_order() -> None: @@ -163,39 +167,39 @@ def test_id_class_bad_type() -> None: div({"oops": "yes"}, {}) # type: ignore -def test_id_class_and_kwargs() -> None: +def test_id_class_and_kwargs(to_str: ToStr) -> None: result = div("#theid", for_="hello", data_foo="""" + assert to_str(result) == """
""" -def test_attrs_and_kwargs() -> None: +def test_attrs_and_kwargs(to_str: ToStr) -> None: result = div({"a": "1", "for": "a"}, for_="b", b="2") - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_class_priority() -> None: +def test_class_priority(to_str: ToStr) -> None: result = div(".a", {"class": "b"}, class_="c") - assert str(result) == """
""" + assert to_str(result) == """
""" result = div(".a", {"class": "b"}) - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_attribute_priority() -> None: +def test_attribute_priority(to_str: ToStr) -> None: result = div({"foo": "a"}, foo="b") - assert str(result) == """
""" + assert to_str(result) == """
""" @pytest.mark.parametrize("not_an_attr", [1234, b"foo", object(), object, 1, 0, None]) -def test_invalid_attribute_key(not_an_attr: t.Any) -> None: +def test_invalid_attribute_key(not_an_attr: t.Any, to_str: ToStr) -> None: with pytest.raises(ValueError, match="Attribute key must be a string"): - str(div({not_an_attr: "foo"})) + to_str(div({not_an_attr: "foo"})) @pytest.mark.parametrize( "not_an_attr", [12.34, b"foo", object(), object], ) -def test_invalid_attribute_value(not_an_attr: t.Any) -> None: +def test_invalid_attribute_value(not_an_attr: t.Any, to_str: ToStr) -> None: with pytest.raises(ValueError, match="Attribute value must be a string"): div(foo=not_an_attr) diff --git a/tests/test_children.py b/tests/test_children.py index 2e35528..039da53 100644 --- a/tests/test_children.py +++ b/tests/test_children.py @@ -18,58 +18,59 @@ from htpy import Node + from .types import ToList, ToStr -def test_void_element() -> None: - element = input(name="foo") - assert_type(element, VoidElement) - assert isinstance(element, VoidElement) - result = str(element) - assert str(result) == '' +def test_void_element(to_str: ToStr) -> None: + result = input(name="foo") + assert_type(result, VoidElement) + assert isinstance(result, VoidElement) + assert to_str(result) == '' -def test_children() -> None: - assert str(div[img]) == "
" +def test_integer_child(to_str: ToStr) -> None: + assert to_str(div[123]) == "
123
" -def test_integer_child() -> None: - assert str(div[123]) == "
123
" +def test_children(to_str: ToStr) -> None: + assert to_str(div[img]) == "
" -def test_multiple_children() -> None: + +def test_multiple_children(to_str: ToStr) -> None: result = ul[li, li] - assert str(result) == "" + assert to_str(result) == "" -def test_list_children() -> None: +def test_list_children(to_str: ToStr) -> None: children: list[Element] = [li["a"], li["b"]] result = ul[children] - assert str(result) == "" + assert to_str(result) == "" -def test_tuple_children() -> None: +def test_tuple_children(to_str: ToStr) -> None: result = ul[(li["a"], li["b"])] - assert str(result) == "" + assert to_str(result) == "" -def test_flatten_nested_children() -> None: +def test_flatten_nested_children(to_str: ToStr) -> None: result = dl[ [ (dt["a"], dd["b"]), (dt["c"], dd["d"]), ] ] - assert str(result) == """
a
b
c
d
""" + assert to_str(result) == """
a
b
c
d
""" -def test_flatten_very_nested_children() -> None: +def test_flatten_very_nested_children(to_str: ToStr) -> None: # maybe not super useful but the nesting may be arbitrarily deep result = div[[([["a"]],)], [([["b"]],)]] - assert str(result) == """
ab
""" + assert to_str(result) == """
ab
""" -def test_flatten_nested_generators() -> None: +def test_flatten_nested_generators(to_str: ToStr) -> None: def cols() -> Generator[str, None, None]: yield "a" yield "b" @@ -82,43 +83,43 @@ def rows() -> Generator[Generator[str, None, None], None, None]: result = div[rows()] - assert str(result) == """
abcabcabc
""" + assert to_str(result) == """
abcabcabc
""" -def test_generator_children() -> None: +def test_generator_children(to_str: ToStr) -> None: gen: Generator[Element, None, None] = (li[x] for x in ["a", "b"]) result = ul[gen] - assert str(result) == "" + assert to_str(result) == "" -def test_html_tag_with_doctype() -> None: +def test_html_tag_with_doctype(to_str: ToStr) -> None: result = html(foo="bar")["hello"] - assert str(result) == 'hello' + assert to_str(result) == 'hello' -def test_void_element_children() -> None: +def test_void_element_children(to_str: ToStr) -> None: with pytest.raises(TypeError): img["hey"] # type: ignore[index] -def test_call_without_args() -> None: +def test_call_without_args(to_str: ToStr) -> None: result = img() - assert str(result) == "" + assert to_str(result) == "" -def test_custom_element() -> None: - el = my_custom_element() - assert_type(el, Element) - assert isinstance(el, Element) - assert str(el) == "" +def test_custom_element(to_str: ToStr) -> None: + result = my_custom_element() + assert_type(result, Element) + assert isinstance(result, Element) + assert to_str(result) == "" @pytest.mark.parametrize("ignored_value", [None, True, False]) -def test_ignored(ignored_value: t.Any) -> None: - assert str(div[ignored_value]) == "
" +def test_ignored(to_str: ToStr, ignored_value: t.Any) -> None: + assert to_str(div[ignored_value]) == "
" -def test_iter() -> None: +def test_sync_iter() -> None: trace = "not started" def generate_list() -> Generator[Element, None, None]: @@ -143,8 +144,8 @@ def generate_list() -> Generator[Element, None, None]: assert trace == "done" -def test_iter_str() -> None: - _, child, _ = div["a"] +def test_iter_str(to_list: ToList) -> None: + _, child, _ = to_list(div["a"]) assert child == "a" # Make sure we dont get Markup (subclass of str) diff --git a/tests/test_comment.py b/tests/test_comment.py index 4844edd..9049c4f 100644 --- a/tests/test_comment.py +++ b/tests/test_comment.py @@ -1,17 +1,24 @@ +from __future__ import annotations + +import typing as t + from htpy import comment, div +if t.TYPE_CHECKING: + from .types import ToStr + -def test_simple() -> None: - assert str(div[comment("hi")]) == "
" +def test_simple(to_str: ToStr) -> None: + assert to_str(div[comment("hi")]) == "
" -def test_escape_two_dashes() -> None: - assert str(div[comment("foo--bar")]) == "
" +def test_escape_two_dashes(to_str: ToStr) -> None: + assert to_str(div[comment("foo--bar")]) == "
" -def test_escape_three_dashes() -> None: - assert str(div[comment("foo---bar")]) == "
" +def test_escape_three_dashes(to_str: ToStr) -> None: + assert to_str(div[comment("foo---bar")]) == "
" -def test_escape_four_dashes() -> None: - assert str(div[comment("foo----bar")]) == "
" +def test_escape_four_dashes(to_str: ToStr) -> None: + assert to_str(div[comment("foo----bar")]) == "
" diff --git a/tests/test_context.py b/tests/test_context.py index 77a88a2..c538222 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,9 +1,14 @@ +from __future__ import annotations + import typing as t import pytest from htpy import Context, Node, div +if t.TYPE_CHECKING: + from .types import ToStr + letter_ctx: Context[t.Literal["a", "b", "c"]] = Context("letter", default="a") no_default_ctx = Context[str]("no_default") @@ -18,25 +23,25 @@ def display_no_default(value: str) -> str: return f"{value=}" -def test_context_default() -> None: +def test_context_default(to_str: ToStr) -> None: result = div[display_letter("Yo")] - assert str(result) == "
Yo: a!
" + assert to_str(result) == "
Yo: a!
" -def test_context_provider() -> None: +def test_context_provider(to_str: ToStr) -> None: result = letter_ctx.provider("c", lambda: div[display_letter("Hello")]) - assert str(result) == "
Hello: c!
" + assert to_str(result) == "
Hello: c!
" -def test_no_default() -> None: +def test_no_default(to_str: ToStr) -> None: with pytest.raises( LookupError, match='Context value for "no_default" does not exist, requested by display_no_default()', ): - str(div[display_no_default()]) + to_str(div[display_no_default()]) -def test_nested_override() -> None: +def test_nested_override(to_str: ToStr) -> None: result = div[ letter_ctx.provider( "b", @@ -46,10 +51,10 @@ def test_nested_override() -> None: ), ) ] - assert str(result) == "
Nested: c!
" + assert to_str(result) == "
Nested: c!
" -def test_multiple_consumers() -> None: +def test_multiple_consumers(to_str: ToStr) -> None: a_ctx: Context[t.Literal["a"]] = Context("a_ctx", default="a") b_ctx: Context[t.Literal["b"]] = Context("b_ctx", default="b") @@ -59,10 +64,10 @@ def ab_display(a: t.Literal["a"], b: t.Literal["b"], greeting: str) -> str: return f"{greeting} a={a}, b={b}" result = div[ab_display("Hello")] - assert str(result) == "
Hello a=a, b=b
" + assert to_str(result) == "
Hello a=a, b=b
" -def test_nested_consumer() -> None: +def test_nested_consumer(to_str: ToStr) -> None: ctx: Context[str] = Context("ctx") @ctx.consumer @@ -75,10 +80,10 @@ def inner(value: str, from_outer: str) -> Node: result = div[ctx.provider("foo", outer)] - assert str(result) == "
outer: foo, inner: foo
" + assert to_str(result) == "
outer: foo, inner: foo
" -def test_context_passed_via_iterable() -> None: +def test_context_passed_via_iterable(to_str: ToStr) -> None: ctx: Context[str] = Context("ctx") @ctx.consumer @@ -87,4 +92,4 @@ def echo(value: str) -> str: result = div[ctx.provider("foo", lambda: [echo()])] - assert str(result) == "
foo
" + assert to_str(result) == "
foo
" diff --git a/tests/test_django.py b/tests/test_django.py index b3ca37a..4ca3fba 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -1,9 +1,10 @@ -from typing import Any +from __future__ import annotations + +import typing as t import pytest from django.core import management from django.forms.utils import ErrorList -from django.http import HttpRequest from django.template import Context, Template, TemplateDoesNotExist from django.template.loader import render_to_string from django.utils.html import escape @@ -11,6 +12,12 @@ from htpy import Element, Node, div, li, ul +if t.TYPE_CHECKING: + from django.http import HttpRequest + + from .types import ToStr + + pytestmark = pytest.mark.usefixtures("django_env") @@ -21,26 +28,26 @@ def test_template_injection() -> None: assert result == '' -def test_SafeString() -> None: +def test_SafeString(to_str: ToStr) -> None: result = ul[SafeString("
  • hello
  • ")] - assert str(result) == "" + assert to_str(result) == "" -def test_explicit_escape() -> None: +def test_explicit_escape(to_str: ToStr) -> None: result = ul[escape("")] - assert str(result) == "
      <hello>
    " + assert to_str(result) == "
      <hello>
    " -def test_errorlist() -> None: +def test_errorlist(to_str: ToStr) -> None: result = div[ErrorList(["my error"])] - assert str(result) == """
    • my error
    """ + assert to_str(result) == """
    • my error
    """ -def my_template(context: dict[str, Any], request: HttpRequest | None) -> Element: +def my_template(context: dict[str, t.Any], request: HttpRequest | None) -> Element: return div[f"hey {context['name']}"] -def my_template_fragment(context: dict[str, Any], request: HttpRequest | None) -> Node: +def my_template_fragment(context: dict[str, t.Any], request: HttpRequest | None) -> Node: return [div[f"hey {context['name']}"]] diff --git a/tests/test_element.py b/tests/test_element.py index eac1c85..2579316 100644 --- a/tests/test_element.py +++ b/tests/test_element.py @@ -19,7 +19,7 @@ def test_invalid_element_name() -> None: def test_element_repr() -> None: - assert repr(htpy.div("#a")) == """'>""" + assert repr(htpy.div("#a")) == """...'>""" def test_void_element_repr() -> None: diff --git a/tests/test_starlette.py b/tests/test_starlette.py index 6ca9a3b..ec39229 100644 --- a/tests/test_starlette.py +++ b/tests/test_starlette.py @@ -7,7 +7,8 @@ from starlette.routing import Route from starlette.testclient import TestClient -from htpy import h1 +from htpy import Element, h1, p +from htpy.starlette import HtpyResponse if t.TYPE_CHECKING: from starlette.requests import Request @@ -17,11 +18,25 @@ async def html_response(request: Request) -> HTMLResponse: return HTMLResponse(h1["Hello, HTMLResponse!"]) +async def stuff() -> Element: + return p["stuff"] + + +async def htpy_response(request: Request) -> HtpyResponse: + return HtpyResponse( + ( + h1["Hello, HtpyResponse!"], + stuff(), + ) + ) + + client = TestClient( Starlette( debug=True, routes=[ Route("/html-response", html_response), + Route("/htpy-response", htpy_response), ], ) ) @@ -29,4 +44,10 @@ async def html_response(request: Request) -> HTMLResponse: def test_html_response() -> None: response = client.get("/html-response") - assert response.content == b"

    Hello, HTMLResponse!

    " + assert response.text == "

    Hello, HTMLResponse!

    " + + +def test_htpy_response() -> None: + response = client.get("/htpy-response") + assert response.headers["content-type"] == "text/html; charset=utf-8" + assert response.text == "

    Hello, HtpyResponse!

    stuff

    " diff --git a/tests/types.py b/tests/types.py new file mode 100644 index 0000000..8691d5a --- /dev/null +++ b/tests/types.py @@ -0,0 +1,7 @@ +import typing as t +from collections.abc import Callable + +from htpy import Node + +ToStr: t.TypeAlias = Callable[[Node], str] +ToList: t.TypeAlias = Callable[[Node], list[str]]