From a5b6247a4614ad69f0e463a01f23015b136621ff Mon Sep 17 00:00:00 2001 From: Jake Vanderplas Date: Tue, 3 Dec 2024 12:46:41 -0800 Subject: [PATCH] Adds styling, base content, new template (#121) Co-authored-by: barnesjoseph --- .readthedocs.yaml | 2 +- docs/index.md | 51 ---- docs/requirements.txt | 1 + .../{ => source}/JAX_Vision_transformer.ipynb | 0 docs/{ => source}/JAX_Vision_transformer.md | 0 .../JAX_basic_text_classification.ipynb | 0 .../JAX_basic_text_classification.md | 0 .../JAX_examples_image_segmentation.ipynb | 2 +- .../JAX_examples_image_segmentation.md | 2 +- .../JAX_for_LLM_pretraining.ipynb | 0 docs/{ => source}/JAX_for_LLM_pretraining.md | 0 docs/{ => source}/JAX_for_PyTorch_users.ipynb | 0 docs/{ => source}/JAX_for_PyTorch_users.md | 0 docs/{ => source}/JAX_image_captioning.ipynb | 0 docs/{ => source}/JAX_image_captioning.md | 0 .../JAX_machine_translation.ipynb | 0 docs/{ => source}/JAX_machine_translation.md | 0 .../JAX_porting_PyTorch_model.ipynb | 0 .../{ => source}/JAX_porting_PyTorch_model.md | 0 .../JAX_time_series_classification.ipynb | 0 .../JAX_time_series_classification.md | 0 .../JAX_transformer_text_classification.ipynb | 0 .../JAX_transformer_text_classification.md | 0 .../JAX_visualizing_models_metrics.ipynb | 12 +- .../JAX_visualizing_models_metrics.md | 12 +- docs/source/_static/css/custom.css | 284 ++++++++++++++++++ docs/source/_static/images/ai-stack-logo.svg | 26 ++ docs/source/_static/images/compiler.svg | 239 +++++++++++++++ docs/source/_static/images/favicon.png | Bin 0 -> 3756 bytes docs/source/_static/images/hardware.svg | 96 ++++++ docs/source/_static/images/hero-radial.svg | 101 +++++++ docs/source/_static/images/hero.svg | 128 ++++++++ .../_static/images}/loss_acc_example.png | Bin .../_static/images}/model_display_example.png | Bin .../_static/images}/nnx_display_example.png | Bin .../images}/testsheet_start_example.png | Bin .../_static/images}/testsheets_500_3000.png | Bin .../_static/images}/training_data_example.png | Bin .../_static/images}/unetr_architecture.png | Bin .../images/what-is-the-jax-ai-stack.svg | 31 ++ docs/{ => source}/conf.py | 4 +- docs/{ => source}/contributing.md | 2 +- .../data_loaders_on_cpu_with_jax.ipynb | 0 .../data_loaders_on_cpu_with_jax.md | 0 .../{ => source}/digits_diffusion_model.ipynb | 0 docs/{ => source}/digits_diffusion_model.md | 0 docs/{ => source}/digits_vae.ipynb | 0 docs/{ => source}/digits_vae.md | 0 .../getting_started_with_jax_for_AI.ipynb | 0 .../getting_started_with_jax_for_AI.md | 0 docs/source/index.html | 71 +++++ docs/source/index.rst | 12 + docs/{ => source}/tutorials.md | 0 pyproject.toml | 2 +- 54 files changed, 1009 insertions(+), 69 deletions(-) delete mode 100644 docs/index.md rename docs/{ => source}/JAX_Vision_transformer.ipynb (100%) rename docs/{ => source}/JAX_Vision_transformer.md (100%) rename docs/{ => source}/JAX_basic_text_classification.ipynb (100%) rename docs/{ => source}/JAX_basic_text_classification.md (100%) rename docs/{ => source}/JAX_examples_image_segmentation.ipynb (99%) rename docs/{ => source}/JAX_examples_image_segmentation.md (99%) rename docs/{ => source}/JAX_for_LLM_pretraining.ipynb (100%) rename docs/{ => source}/JAX_for_LLM_pretraining.md (100%) rename docs/{ => source}/JAX_for_PyTorch_users.ipynb (100%) rename docs/{ => source}/JAX_for_PyTorch_users.md (100%) rename docs/{ => source}/JAX_image_captioning.ipynb (100%) rename docs/{ => source}/JAX_image_captioning.md (100%) rename docs/{ => source}/JAX_machine_translation.ipynb (100%) rename docs/{ => source}/JAX_machine_translation.md (100%) rename docs/{ => source}/JAX_porting_PyTorch_model.ipynb (100%) rename docs/{ => source}/JAX_porting_PyTorch_model.md (100%) rename docs/{ => source}/JAX_time_series_classification.ipynb (100%) rename docs/{ => source}/JAX_time_series_classification.md (100%) rename docs/{ => source}/JAX_transformer_text_classification.ipynb (100%) rename docs/{ => source}/JAX_transformer_text_classification.md (100%) rename docs/{ => source}/JAX_visualizing_models_metrics.ipynb (99%) rename docs/{ => source}/JAX_visualizing_models_metrics.md (96%) create mode 100644 docs/source/_static/css/custom.css create mode 100644 docs/source/_static/images/ai-stack-logo.svg create mode 100644 docs/source/_static/images/compiler.svg create mode 100644 docs/source/_static/images/favicon.png create mode 100644 docs/source/_static/images/hardware.svg create mode 100644 docs/source/_static/images/hero-radial.svg create mode 100644 docs/source/_static/images/hero.svg rename docs/{_static => source/_static/images}/loss_acc_example.png (100%) rename docs/{_static => source/_static/images}/model_display_example.png (100%) rename docs/{_static => source/_static/images}/nnx_display_example.png (100%) rename docs/{_static => source/_static/images}/testsheet_start_example.png (100%) rename docs/{_static => source/_static/images}/testsheets_500_3000.png (100%) rename docs/{_static => source/_static/images}/training_data_example.png (100%) rename docs/{_static => source/_static/images}/unetr_architecture.png (100%) create mode 100644 docs/source/_static/images/what-is-the-jax-ai-stack.svg rename docs/{ => source}/conf.py (96%) rename docs/{ => source}/contributing.md (98%) rename docs/{ => source}/data_loaders_on_cpu_with_jax.ipynb (100%) rename docs/{ => source}/data_loaders_on_cpu_with_jax.md (100%) rename docs/{ => source}/digits_diffusion_model.ipynb (100%) rename docs/{ => source}/digits_diffusion_model.md (100%) rename docs/{ => source}/digits_vae.ipynb (100%) rename docs/{ => source}/digits_vae.md (100%) rename docs/{ => source}/getting_started_with_jax_for_AI.ipynb (100%) rename docs/{ => source}/getting_started_with_jax_for_AI.md (100%) create mode 100644 docs/source/index.html create mode 100644 docs/source/index.rst rename docs/{ => source}/tutorials.md (100%) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 2511979..4ad7d00 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -6,7 +6,7 @@ build: python: "3.12" sphinx: - configuration: docs/conf.py + configuration: docs/source/conf.py fail_on_warning: true python: diff --git a/docs/index.md b/docs/index.md deleted file mode 100644 index 220e3db..0000000 --- a/docs/index.md +++ /dev/null @@ -1,51 +0,0 @@ -# JAX AI Stack - -`jax-ai-stack` provides a one-line installation command for JAX and associated -packages: -``` -pip install jax-ai-stack -``` -This will install mutually-compatible versions of the following packages: - -- [JAX](http://github.com/google/jax): the core JAX package, which includes array operations - and program transformations like `jit`, `vmap`, `grad`, etc. -- [flax](http://github.com/google/flax): build neural networks with JAX -- [ml_dtypes](http://github.com/jax-ml/ml_dtypes): NumPy dtype extensions for machine learning. -- [optax](https://github.com/google-deepmind/optax): gradient processing and optimization in JAX. -- [orbax](https://github.com/google/orbax): checkpointing and persistence utilities for JAX. - -To get started using the stack, see {doc}`getting_started_with_jax_for_AI`. - -## Why the JAX AI stack? - -[JAX](http://github.com/jax-ml/jax) is a Python package for array-oriented -computation and program transformation. Built around it is a growing ecosystem -of packages for specialized numerical computing across a range of domains; an -up-to-date list of such projects can be found at -[Awesome JAX](https://github.com/n2cholas/awesome-jax). - -Though JAX is often compared to neural network libraries like pytorch, the JAX -core package itself contains very little that is specific to neural network -models. Instead, JAX encourages modularity, where domain-specific libraries -are developed separately from the core package: this helps drive innovation -as researchers and other users explore what is possible. - -Within this larger, distributed ecosystem, there are a number of projects that -Google researchers and engineers have found useful for implementing and deploying -the models behind generative AI tools like [Imagen](https://imagen.research.google/), -[Gemini](https://gemini.google.com/), and more. The JAX AI stack serves as a -single point-of-entry for this suite of libraries, so you can install and begin -using many of the same open source packages that Google developers are using -in their everyday work. - -To get started with the JAX AI stack, you can check out {doc}`getting_started_with_jax_for_AI`. -This is still a work-in-progress, please check back for more documentation and tutorials -in the coming weeks! - -```{toctree} -:maxdepth: 2 -:hidden: - -tutorials -contributing -``` diff --git a/docs/requirements.txt b/docs/requirements.txt index 568f828..5dc8941 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,6 @@ # Sphinx-related requirements. sphinx +sphinx-book-theme>=1.0.1 myst-nb myst-parser[linkify] sphinx-book-theme diff --git a/docs/JAX_Vision_transformer.ipynb b/docs/source/JAX_Vision_transformer.ipynb similarity index 100% rename from docs/JAX_Vision_transformer.ipynb rename to docs/source/JAX_Vision_transformer.ipynb diff --git a/docs/JAX_Vision_transformer.md b/docs/source/JAX_Vision_transformer.md similarity index 100% rename from docs/JAX_Vision_transformer.md rename to docs/source/JAX_Vision_transformer.md diff --git a/docs/JAX_basic_text_classification.ipynb b/docs/source/JAX_basic_text_classification.ipynb similarity index 100% rename from docs/JAX_basic_text_classification.ipynb rename to docs/source/JAX_basic_text_classification.ipynb diff --git a/docs/JAX_basic_text_classification.md b/docs/source/JAX_basic_text_classification.md similarity index 100% rename from docs/JAX_basic_text_classification.md rename to docs/source/JAX_basic_text_classification.md diff --git a/docs/JAX_examples_image_segmentation.ipynb b/docs/source/JAX_examples_image_segmentation.ipynb similarity index 99% rename from docs/JAX_examples_image_segmentation.ipynb rename to docs/source/JAX_examples_image_segmentation.ipynb index c6820da..1bac5d4 100644 --- a/docs/JAX_examples_image_segmentation.ipynb +++ b/docs/source/JAX_examples_image_segmentation.ipynb @@ -693,7 +693,7 @@ "In this section we will implement the [UNETR](https://arxiv.org/abs/2103.10504) model from scratch using Flax NNX. The reference PyTorch implementation of this model can be found on the [MONAI Library GitHub repository](https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/nets/unetr.py).\n", "\n", "The UNETR model utilizes a transformer as the encoder to learn sequence representations of the input and to capture the global multi-scale information, while also following the “U-shaped” network design like [UNet](https://arxiv.org/abs/1505.04597) model:\n", - "![image.png](./_static/unetr_architecture.png)\n", + "![image.png](./_static/images/unetr_architecture.png)\n", "\n", "The UNETR architecture on the image above is processing 3D inputs, but it can be easily adapted to 2D input.\n", "\n", diff --git a/docs/JAX_examples_image_segmentation.md b/docs/source/JAX_examples_image_segmentation.md similarity index 99% rename from docs/JAX_examples_image_segmentation.md rename to docs/source/JAX_examples_image_segmentation.md index 1f0fbeb..c25e1c5 100644 --- a/docs/JAX_examples_image_segmentation.md +++ b/docs/source/JAX_examples_image_segmentation.md @@ -367,7 +367,7 @@ for img, mask in zip(images[:3], masks[:3]): In this section we will implement the [UNETR](https://arxiv.org/abs/2103.10504) model from scratch using Flax NNX. The reference PyTorch implementation of this model can be found on the [MONAI Library GitHub repository](https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/nets/unetr.py). The UNETR model utilizes a transformer as the encoder to learn sequence representations of the input and to capture the global multi-scale information, while also following the “U-shaped” network design like [UNet](https://arxiv.org/abs/1505.04597) model: -![image.png](./_static/unetr_architecture.png) +![image.png](./_static/images/unetr_architecture.png) The UNETR architecture on the image above is processing 3D inputs, but it can be easily adapted to 2D input. diff --git a/docs/JAX_for_LLM_pretraining.ipynb b/docs/source/JAX_for_LLM_pretraining.ipynb similarity index 100% rename from docs/JAX_for_LLM_pretraining.ipynb rename to docs/source/JAX_for_LLM_pretraining.ipynb diff --git a/docs/JAX_for_LLM_pretraining.md b/docs/source/JAX_for_LLM_pretraining.md similarity index 100% rename from docs/JAX_for_LLM_pretraining.md rename to docs/source/JAX_for_LLM_pretraining.md diff --git a/docs/JAX_for_PyTorch_users.ipynb b/docs/source/JAX_for_PyTorch_users.ipynb similarity index 100% rename from docs/JAX_for_PyTorch_users.ipynb rename to docs/source/JAX_for_PyTorch_users.ipynb diff --git a/docs/JAX_for_PyTorch_users.md b/docs/source/JAX_for_PyTorch_users.md similarity index 100% rename from docs/JAX_for_PyTorch_users.md rename to docs/source/JAX_for_PyTorch_users.md diff --git a/docs/JAX_image_captioning.ipynb b/docs/source/JAX_image_captioning.ipynb similarity index 100% rename from docs/JAX_image_captioning.ipynb rename to docs/source/JAX_image_captioning.ipynb diff --git a/docs/JAX_image_captioning.md b/docs/source/JAX_image_captioning.md similarity index 100% rename from docs/JAX_image_captioning.md rename to docs/source/JAX_image_captioning.md diff --git a/docs/JAX_machine_translation.ipynb b/docs/source/JAX_machine_translation.ipynb similarity index 100% rename from docs/JAX_machine_translation.ipynb rename to docs/source/JAX_machine_translation.ipynb diff --git a/docs/JAX_machine_translation.md b/docs/source/JAX_machine_translation.md similarity index 100% rename from docs/JAX_machine_translation.md rename to docs/source/JAX_machine_translation.md diff --git a/docs/JAX_porting_PyTorch_model.ipynb b/docs/source/JAX_porting_PyTorch_model.ipynb similarity index 100% rename from docs/JAX_porting_PyTorch_model.ipynb rename to docs/source/JAX_porting_PyTorch_model.ipynb diff --git a/docs/JAX_porting_PyTorch_model.md b/docs/source/JAX_porting_PyTorch_model.md similarity index 100% rename from docs/JAX_porting_PyTorch_model.md rename to docs/source/JAX_porting_PyTorch_model.md diff --git a/docs/JAX_time_series_classification.ipynb b/docs/source/JAX_time_series_classification.ipynb similarity index 100% rename from docs/JAX_time_series_classification.ipynb rename to docs/source/JAX_time_series_classification.ipynb diff --git a/docs/JAX_time_series_classification.md b/docs/source/JAX_time_series_classification.md similarity index 100% rename from docs/JAX_time_series_classification.md rename to docs/source/JAX_time_series_classification.md diff --git a/docs/JAX_transformer_text_classification.ipynb b/docs/source/JAX_transformer_text_classification.ipynb similarity index 100% rename from docs/JAX_transformer_text_classification.ipynb rename to docs/source/JAX_transformer_text_classification.ipynb diff --git a/docs/JAX_transformer_text_classification.md b/docs/source/JAX_transformer_text_classification.md similarity index 100% rename from docs/JAX_transformer_text_classification.md rename to docs/source/JAX_transformer_text_classification.md diff --git a/docs/JAX_visualizing_models_metrics.ipynb b/docs/source/JAX_visualizing_models_metrics.ipynb similarity index 99% rename from docs/JAX_visualizing_models_metrics.ipynb rename to docs/source/JAX_visualizing_models_metrics.ipynb index b28c603..347b473 100644 --- a/docs/JAX_visualizing_models_metrics.ipynb +++ b/docs/source/JAX_visualizing_models_metrics.ipynb @@ -123,7 +123,7 @@ "source": [ "After running all above and launching `tensorboard --logdir runs/test` from the same folder, you should see the following in the supplied URL:\n", "\n", - "![image.png](./_static/training_data_example.png)" + "![image.png](./_static/images/training_data_example.png)" ] }, { @@ -225,7 +225,7 @@ "source": [ "We've now created the basic model - the above cell will render an interactive view of the model. Which, when fully expanded, should look something like this:\n", "\n", - "![image.png](./_static/nnx_display_example.png)" + "![image.png](./_static/images/nnx_display_example.png)" ] }, { @@ -328,7 +328,7 @@ "\n", "The output there should look something like the following:\n", "\n", - "![image.png](./_static/loss_acc_example.png)" + "![image.png](./_static/images/loss_acc_example.png)" ] }, { @@ -339,11 +339,11 @@ "\n", "At step 1, we see poor accuracy, as you would expect\n", "\n", - "![image.png](./_static/testsheet_start_example.png)\n", + "![image.png](./_static/images/testsheet_start_example.png)\n", "\n", "By 500, the model is essentially done, but we see the bottom row `7` get lost and recovered at higher epochs as we go far into an overfitting regime. This kind of stored data can be very useful when the training routines become automated and a human is potentially only looking when something has gone wrong.\n", "\n", - "![image.png](./_static/testsheets_500_3000.png)" + "![image.png](./_static/images/testsheets_500_3000.png)" ] }, { @@ -427,7 +427,7 @@ "source": [ "The above cell output will give you an interactive plot that looks like this image below, where here we've 'clicked' in the bottom plot for entry `7` and hover over the corresponding value in the top plot.\n", "\n", - "![image.png](./_static/model_display_example.png)" + "![image.png](./_static/images/model_display_example.png)" ] }, { diff --git a/docs/JAX_visualizing_models_metrics.md b/docs/source/JAX_visualizing_models_metrics.md similarity index 96% rename from docs/JAX_visualizing_models_metrics.md rename to docs/source/JAX_visualizing_models_metrics.md index 2033d14..6007c3f 100644 --- a/docs/JAX_visualizing_models_metrics.md +++ b/docs/source/JAX_visualizing_models_metrics.md @@ -83,7 +83,7 @@ with test_summary_writer.as_default(): After running all above and launching `tensorboard --logdir runs/test` from the same folder, you should see the following in the supplied URL: -![image.png](./_static/training_data_example.png) +![image.png](./_static/images/training_data_example.png) ```{code-cell} ipython3 :id: 6jrYisoPh6TL @@ -131,7 +131,7 @@ nnx.display(model) # Interactive display if penzai is installed. We've now created the basic model - the above cell will render an interactive view of the model. Which, when fully expanded, should look something like this: -![image.png](./_static/nnx_display_example.png) +![image.png](./_static/images/nnx_display_example.png) +++ @@ -211,7 +211,7 @@ During the training has run, and after, the added `Loss` and `Accuracy` scalars The output there should look something like the following: -![image.png](./_static/loss_acc_example.png) +![image.png](./_static/images/loss_acc_example.png) +++ @@ -219,11 +219,11 @@ Since we've stored the example test sheet every 500 epochs, it's easy to go back At step 1, we see poor accuracy, as you would expect -![image.png](./_static/testsheet_start_example.png) +![image.png](./_static/images/testsheet_start_example.png) By 500, the model is essentially done, but we see the bottom row `7` get lost and recovered at higher epochs as we go far into an overfitting regime. This kind of stored data can be very useful when the training routines become automated and a human is potentially only looking when something has gone wrong. -![image.png](./_static/testsheets_500_3000.png) +![image.png](./_static/images/testsheets_500_3000.png) +++ @@ -235,7 +235,7 @@ nnx.display(model(images_test[:35])), nnx.display(model(images_test[:35]).argmax The above cell output will give you an interactive plot that looks like this image below, where here we've 'clicked' in the bottom plot for entry `7` and hover over the corresponding value in the top plot. -![image.png](./_static/model_display_example.png) +![image.png](./_static/images/model_display_example.png) +++ diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css new file mode 100644 index 0000000..d75bfce --- /dev/null +++ b/docs/source/_static/css/custom.css @@ -0,0 +1,284 @@ +/* Base file modifications */ +body:has(.hero) .bd-sidebar-primary, +body:has(.hero) .sidebar-toggle, +body:has(.hero) .bd-sidebar-secondary { + display: none !important; +} + +body:has(.hero) .prev-next-footer { + display: none; +} + +body:has(.hero) .bd-article-container { + max-width: unset !important; +} + +body:has(.hero) .bd-page-width { + max-width: unset !important; +} + +body:has(.hero) .bd-article { + display: flex; + flex-direction: column; + padding: 0; +} + +body:has(.hero) .bd-container { + flex-direction: column; +} + +body:has(.hero) .bd-article > section > h1 { + display: none; +} + +@media (min-width: 960px) { + body:has(.hero) .bd-header-article { + justify-content: center; + } + + body:has(.hero) .header-article-items, + body:has(.hero) .doc-body > section { + max-width: 80rem !important; + align-self: center; + width: -moz-available; + width: -webkit-fill-available; + width: fill-available; + } + + body:has(.hero) .doc-body > section.hero, + body:has(.hero) .doc-body > section.banner { + max-width: 90rem !important; + } +} + +/* -------------- Page styles ---------------- */ +.doc-body { + display: flex; + flex-direction: column; +} + +.hero { + width: 100%; + display: grid; + grid: auto-flow / 1fr .8fr; + gap: 20px; + align-items: center; + background: url(../../_static/images/hero-radial.svg) no-repeat top -750px right -450px / 1500px #202124; + border-radius: 24px; + padding-left: 5%; +} + +.hero span { + display: flex; + align-items: center; + margin-bottom: 20px; +} + +.hero span img { + margin-right: 24px; + background: transparent !important; +} + +.hero h1 { + font: 700 52px 'Google Sans', 'Roboto', sans-serif; + color: white; + margin: 0; +} + +.hero-image { + background: none !important; +} + +.button-primary { + background: #1A73E8; + border-radius: 4px; + color: white; + font: 400 14px 'Google Sans', 'Roboto', sans-serif; + text-decoration: none; + padding: 9px 26px; + transition: background-color .2s, border .2s, box-shadow .2s; + width: max-content; +} + +.button-primary:visited:hover { + color: white !important; +} + +.button-primary:hover { + background-color: #1765cc; + color: white; + transition: background-color .2s, border .2s, box-shadow .2s; + box-shadow: 0 1px 2px 0 rgba(60,64,67,.3),0 1px 3px 1px rgba(60,64,67,.15); +} + +.button-primary:active { + background-color: #185abc; + color: white; + box-shadow: 0 1px 2px 0 rgba(60,64,67,.3),0 2px 6px 2px rgba(60,64,67,.15); +} + +.button-primary:visited { + color: white; + text-decoration: none; +} + +.banner { + background: #E8F0FE; + border-radius: 24px; + margin-block: 80px 100px; + padding-inline: 50px; + padding-bottom: 24px; +} + +.three-up { + display: grid; + grid: auto-flow / 1fr 1fr 1fr; + gap: 24px; + padding-inline: 60px; + margin-top: 20px; +} + +.doc-body .image-section h3 { + font: 500 32px 'Google Sans', 'Roboto', sans-serif; + margin-top: 0; +} + +.doc-body .hero p { + color: #bdc1c6; +} + +.doc-body p, +.doc-body ul { + font: 400 16px 'Roboto', sans-serif; + color: #5F6368; + line-height: 24px; +} + +.image-section { + display: flex; + flex-direction: row; + justify-content: space-between; + gap: 85px; + margin-bottom: 60px; +} + + +.image-section img { + background-color: transparent !important; +} + +.image-section.image-right { + flex-direction: row-reverse; +} + +.image-section .text-body { + padding-inline: 50px; + display: flex; + flex-direction: column; + justify-content: center; +} + +.image-section .text-body *:last-child { + margin-bottom: 0; +} + +.image-section li:not(:last-of-type) { + margin-bottom: 12px; +} + +.image-section .button { + margin-top: 12px; +} + +@media (max-width: 1240px) { + .hero h1 { + font-size: 42px; + margin-bottom: 24px; + } + + .image-section { + gap: 32px; + } + + .image-section img { + min-width: 50%; + } + + .image-section .text-body { + padding-inline: 0; + } +} + +@media (max-width: 1020px) { + .hero h1 { + font-size: 32px; + } + + .image-section { + margin-bottom: 40px; + } +} + +@media (max-width: 860px) { + .three-up { + grid: auto-flow / 1fr; + padding-inline: 0; + } + + .three-up h3 { + margin-top: 8px; + } +} + +@media (max-width: 800px) { + .hero { + grid: auto-flow / 1fr; + padding: 24px; + } + + .hero > img { + display: none; + } + + .image-section, + .image-section.image-right { + flex-direction: column; + margin-bottom: 80px; + } + + .image-section h3 { + margin-top: 0; + } + + .image-section .button { + margin-top: 0; + } +} + +html[data-theme="dark"] .banner { + background: url(../../_static/images/hero-radial.svg) no-repeat bottom -1050px left -350px / 1500px #202124; +} + +html[data-theme="dark"] .doc-body p, +html[data-theme="dark"] .doc-body ul { + color: #bdc1c6; +} + +html[data-theme="dark"] .button-primary { + background-color: #8ab4f8; + color: #121212; +} + +html[data-theme="dark"] .button-primary:hover { + background-color: #98bdf9; + color: #121212; +} + +html[data-theme="dark"] .button-primary:active { + background-color: #aecbfa; + color: #121212; +} + +html[data-theme="dark"] .button-primary:visited:hover { + color: #121212 !important; +} diff --git a/docs/source/_static/images/ai-stack-logo.svg b/docs/source/_static/images/ai-stack-logo.svg new file mode 100644 index 0000000..677e107 --- /dev/null +++ b/docs/source/_static/images/ai-stack-logo.svg @@ -0,0 +1,26 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/_static/images/compiler.svg b/docs/source/_static/images/compiler.svg new file mode 100644 index 0000000..1bf2bcd --- /dev/null +++ b/docs/source/_static/images/compiler.svg @@ -0,0 +1,239 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/_static/images/favicon.png b/docs/source/_static/images/favicon.png new file mode 100644 index 0000000000000000000000000000000000000000..862712fa8c65522ee1b4148cd3b36b5a9ce056c8 GIT binary patch literal 3756 zcmV;d4pZ@oP)*J~r7}L2Z4*e9)S{^9RRj)9Nb@4S zzkY~{SPUS&&e?@zaQ{o6wwt!4mdMMDW=(@nUNHah6JeRaglQ&A!>@Q+BG_x$N^4kA zeKbvAE;{$m54_t6uow1p&K9h?H4%}WX){v<)(~}q;ZX#-ZRlL4w4~q=5Pe~x_ZLbZ zGZGj}OL~OB;Rz|drYuHW=XwFOnSCZ13p|e_kwZA6pT?t(-FQgdW~-g!+F=}z9QK`% zC%jgki|OMYO`tWCx1}Y}T}>l7=$zfO1VHpp|99jN7I{8?&8gJl>7w65sn}qujVXNs z|G6}T%X|^c@SU5Ec7QA1r51+hQGovYypG{aZ5aBRa$vOCp183~-6lC7q!1@K1w zkcV0HAcwSMN}s@cOK)P1ucy6dtGW%l9J|vJImg%WZ%c1tTAzqtV}|eOcgFlGijtzg z@M{SBHhb5~Om=1B!jb7gGca_H*}l%D`qSxfTKfpo`owK+y}$&TT*W=Y~ z-+x9$QP1tIu}4Lb=Za7BMmT&R8fWb64-u>p@d zb^`!Dhzuq%5;Z>|JUT>y33Bb3>>+mQ4kk=k)=$t0F3!Gn_vx+&sj`x<{ zyzP2p`NMdz@l|}S{9!B25rBcuypbLaORZ*mqhjR^0B`{Nx<1xF^ulJ4n9?Wkv+(Qq zVCnWAlP_#pSAoszDs4S`GklZ$C-(L=3uRh+o=A>KIkL!6lXFn&!} z1ODEQ`fA}{zJpCWG|9{RsFcC&nEEOAYByuR8mN-h*qw3LDu%+n* zQ$|G=ejQTRtn2>i%j^8cwWcRtYPfsL){%-wUNoh>i*pw-ePmzUtqr``@G8Dp_bmSX z>@V>4s6XniDPXn-=9R!ySXZ?R^=*HLlDn+O92OW};>GS_QSuCQ_Cg5Dja;4j6lVs1 z9~}AI(sL{3!r3sdZ^Eaa%%FmF+`HLaz%d{gZN(^Q0|OA&X5O@jhMC|jE?<}j)VlzS z1e!&`2?T)fLKwmEXq+36=@e75E(hn9b{8GOo>DN9gJt?yua4Ev|CA&BPI22qQQ) zjhmrcIk$ya^f&Y`8!V4Rq?B?(O0Op10^@*n`=Zby8fN4yJ~=juYnRf`zGQ797;PgD zGU|ye#vmPVOZ@Hl~uR-{`q4hboL5hkA0J~fkP~WZ|v^s z9+h#j#!0sGS->QL_0Nw`-2)E?1)~pEXQERWI{wbCz}UeZE}c0!a@ztk0$^?&8{F9? zha2PWn()EBX&x;T0$z7D1?}#yHrE3q4l4R;FR2ygXLpm3($Mk2oqkqCE0|5CI>|`X zxzWL$@tyuuNiFVb8nN)f2?_=hy~Qjnxv-=1Srj+)L%Hu8Rvt@4)X-`gNp$Pc!JS>Y zB3ul!BO49@?3)j-Z5)4RS0M2hcNKY56(g{-okYO11r@=+cK^uqjta|Zj40WNBKLQp zmp_Tf*?#~pUbWf;0JOyPF)rleZ*7+VJoV2V4G1B6(xyEu{GTta?M`&n#nBG z$tmeFE{?suqp$w?W5dwd%gllC!JYnGZB2$#sm0?(?YLi5WWPBL@-mbYB^%+W{sXAx zyAiqY3+PvWX{Q-VzF@dD5Dg4=cabKeRBXWGj&^)m-I{Y-IZq}e!qCT@+&T$HL{(Lc z1@Gt%nDW~dvR5xMj}^vmzt3N%Dirt0X0Shn}ZNH?189#<$4 zYq3VGwQNs1hQ;KtTEOtn3Rn>Ce?Qp(9`=}ED~kPQJA+`9KnTStd% zH=U^4j>7Sw*Yn4`iq1m-a&{l_^@kv8s-YkK2z>N-*8Au0-+<2^X+Y$z+oP&6^IS{h z)ft&)#6!By-DVHwBqQf2JpR^pIeKtMW8U=fo>C}JKLPc7uRwh5p_I2Q-g7r*9)AQg zk8Mv$ULq!>^y&hwsoa_c`_x5t*(DFiOkjhR6d_R`E!r?F)|A}tGWV21dG6_DZmkzN zaP8|4;=+@=Fu&D^+H$?T~ijB9}8{siy9=E5AhK z+)om33kU(RDuO(_x7(+Eq zGF@84(8s#xj`lf(7)sKg5CpodE-5b667&n_tvq52@guy91|i^^n94l<&Z>bEQU=aV z$wr}I%4%tmvx7TN8B_Owek!(g!j+Hp*h0WvMIMDQA{2Z)IVIc8Z1J$v-D7Uu647#}IG2(vmeQCQVO@qkx}CJ@rI&I+?0@tCfZYyZD-TTASmaR| z1MVs+NJWf7VVTY;)=Cv5Bd0uYQcC~ilx!6$%G}j7LRDnzc&iXvE5NyqY64)9^>C!S zirQj2H)U4mG9hnePx4V(r2gmoNoREgDfBRd21+pd@Y;|48SN`j>8ZRI-ZAngW(0Hv z2ka18_p|*TA=G6@XHyUu0*ihuIQr1nmzIoNN)?zv!Ad{nT3V$3=le<56%T~y0g?&i zVX)qf*M1^1wNc}w9wE?TBa|+_4}cay!XyXX}K4ihX~db+*UYa_4qZ9}%rcg^(4kU>PedG6IG_pA{!5 zPZcqLW022^lZB{=X$7lRgqDcP96(o0AKR@)$3j=cG>cU)#BDWyG&{|-Qdo5Y4b&2~ zFr$YVfmHgKOQpE~Q%I$c?ke)sRMSZ}-Fa2a@0cE#h{S0#09b9E`;7&~u-bNH!v6vU Ww~7pwDBHmR0000 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/_static/images/hero-radial.svg b/docs/source/_static/images/hero-radial.svg new file mode 100644 index 0000000..e969824 --- /dev/null +++ b/docs/source/_static/images/hero-radial.svg @@ -0,0 +1,101 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/_static/images/hero.svg b/docs/source/_static/images/hero.svg new file mode 100644 index 0000000..98bf98d --- /dev/null +++ b/docs/source/_static/images/hero.svg @@ -0,0 +1,128 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/_static/loss_acc_example.png b/docs/source/_static/images/loss_acc_example.png similarity index 100% rename from docs/_static/loss_acc_example.png rename to docs/source/_static/images/loss_acc_example.png diff --git a/docs/_static/model_display_example.png b/docs/source/_static/images/model_display_example.png similarity index 100% rename from docs/_static/model_display_example.png rename to docs/source/_static/images/model_display_example.png diff --git a/docs/_static/nnx_display_example.png b/docs/source/_static/images/nnx_display_example.png similarity index 100% rename from docs/_static/nnx_display_example.png rename to docs/source/_static/images/nnx_display_example.png diff --git a/docs/_static/testsheet_start_example.png b/docs/source/_static/images/testsheet_start_example.png similarity index 100% rename from docs/_static/testsheet_start_example.png rename to docs/source/_static/images/testsheet_start_example.png diff --git a/docs/_static/testsheets_500_3000.png b/docs/source/_static/images/testsheets_500_3000.png similarity index 100% rename from docs/_static/testsheets_500_3000.png rename to docs/source/_static/images/testsheets_500_3000.png diff --git a/docs/_static/training_data_example.png b/docs/source/_static/images/training_data_example.png similarity index 100% rename from docs/_static/training_data_example.png rename to docs/source/_static/images/training_data_example.png diff --git a/docs/_static/unetr_architecture.png b/docs/source/_static/images/unetr_architecture.png similarity index 100% rename from docs/_static/unetr_architecture.png rename to docs/source/_static/images/unetr_architecture.png diff --git a/docs/source/_static/images/what-is-the-jax-ai-stack.svg b/docs/source/_static/images/what-is-the-jax-ai-stack.svg new file mode 100644 index 0000000..95b6842 --- /dev/null +++ b/docs/source/_static/images/what-is-the-jax-ai-stack.svg @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/conf.py b/docs/source/conf.py similarity index 96% rename from docs/conf.py rename to docs/source/conf.py index b27a425..61ea950 100644 --- a/docs/conf.py +++ b/docs/source/conf.py @@ -30,6 +30,8 @@ html_theme = 'sphinx_book_theme' html_title = 'JAX AI Stack' html_static_path = ['_static'] +html_css_files = ['css/custom.css'] +html_favicon = '_static/images/favicon.png' # Theme-specific options # https://sphinx-book-theme.readthedocs.io/en/stable/reference.html @@ -37,7 +39,7 @@ 'show_navbar_depth': 2, 'show_toc_level': 2, 'repository_url': 'https://github.com/jax-ml/jax-ai-stack', - 'path_to_docs': 'docs/', + 'path_to_docs': 'docs/source/', 'use_repository_button': True, 'navigation_with_keys': True, } diff --git a/docs/contributing.md b/docs/source/contributing.md similarity index 98% rename from docs/contributing.md rename to docs/source/contributing.md index ed2c914..9bf1810 100644 --- a/docs/contributing.md +++ b/docs/source/contributing.md @@ -81,7 +81,7 @@ git commit -m "update new tutorial" # commit to the branch To build the documentation locally, you can run the following command: ```bash -sphinx-build -b html docs/ docs/_build/html +sphinx-build -b html docs/source docs/_build/html ``` You can then open the generated HTML files in your browser by opening `docs/_build/html/index.html`. diff --git a/docs/data_loaders_on_cpu_with_jax.ipynb b/docs/source/data_loaders_on_cpu_with_jax.ipynb similarity index 100% rename from docs/data_loaders_on_cpu_with_jax.ipynb rename to docs/source/data_loaders_on_cpu_with_jax.ipynb diff --git a/docs/data_loaders_on_cpu_with_jax.md b/docs/source/data_loaders_on_cpu_with_jax.md similarity index 100% rename from docs/data_loaders_on_cpu_with_jax.md rename to docs/source/data_loaders_on_cpu_with_jax.md diff --git a/docs/digits_diffusion_model.ipynb b/docs/source/digits_diffusion_model.ipynb similarity index 100% rename from docs/digits_diffusion_model.ipynb rename to docs/source/digits_diffusion_model.ipynb diff --git a/docs/digits_diffusion_model.md b/docs/source/digits_diffusion_model.md similarity index 100% rename from docs/digits_diffusion_model.md rename to docs/source/digits_diffusion_model.md diff --git a/docs/digits_vae.ipynb b/docs/source/digits_vae.ipynb similarity index 100% rename from docs/digits_vae.ipynb rename to docs/source/digits_vae.ipynb diff --git a/docs/digits_vae.md b/docs/source/digits_vae.md similarity index 100% rename from docs/digits_vae.md rename to docs/source/digits_vae.md diff --git a/docs/getting_started_with_jax_for_AI.ipynb b/docs/source/getting_started_with_jax_for_AI.ipynb similarity index 100% rename from docs/getting_started_with_jax_for_AI.ipynb rename to docs/source/getting_started_with_jax_for_AI.ipynb diff --git a/docs/getting_started_with_jax_for_AI.md b/docs/source/getting_started_with_jax_for_AI.md similarity index 100% rename from docs/getting_started_with_jax_for_AI.md rename to docs/source/getting_started_with_jax_for_AI.md diff --git a/docs/source/index.html b/docs/source/index.html new file mode 100644 index 0000000..eb64b49 --- /dev/null +++ b/docs/source/index.html @@ -0,0 +1,71 @@ +
+
+
+

JAX AI Stack

+

Flexible, scalable components for AI research and development

+ Get started +
+ +
+
+
+

Flexible and scalable

+

Iterate quickly and with efficient out-of-the-box scaling

+
+
+

Run anywhere

+

Execute the same code on any CPU, GPU, & TPU

+
+
+

Reliability and compatibility

+

JAX AI Stack tested releases provide high reliability by ensuring compatibility across its libraries

+
+
+ +
+ +
+

JAX AI Stack

+

The JAX AI Stack is a curated collection of libraries that researchers and engineers, both inside and outside of Google, have found useful for implementing and deploying the models behind generative AI tools like Imagen, Gemini, and more.

+
    +
  • JAX - core array operations and program transformations
  • +
  • Flax - For building neural networks
  • +
  • Orbax -For checkpointing and persistence utilities
  • +
  • Optax - For gradient processing and optimization
  • +
  • ml_dtypes - NumPy dtype extensions for machine learning.
  • +
  • Optional data loading libraries (Grain or tf.data)
  • +
+ Get started +
+
+
+ +
+

Powered by JAX

+

JAX is a Python library for efficient array-oriented computation and program transformation. JAX's flexible and modular approach has encouraged communities across AI, scientific computing, simulation and more to build ecosystems on top of it.

+

JAX is often compared to neural network libraries like PyTorch, but the core JAX package contains very little specific to deep learning. Instead, JAX encourages modularity, where domain-specific libraries are developed separately from the core package. This helps drive innovation as researchers and other users explore what's possible.

+ Learn more about JAX +
+
+
+ +
+

Part of a wider AI Ecosystem

+

The JAX AI Stack is part of a growing AI community and ecosystem around JAX. Modularity and choice are important principles for JAX and we are excited to see the development happening around it!

+ Learn more about the AI Ecosystem +
+
+
diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..ffa2c91 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,12 @@ +Jax AI Stack +============ + +.. raw:: html + :file: index.html + +.. toctree:: + :hidden: + :maxdepth: 2 + + tutorials + contributing diff --git a/docs/tutorials.md b/docs/source/tutorials.md similarity index 100% rename from docs/tutorials.md rename to docs/source/tutorials.md diff --git a/pyproject.toml b/pyproject.toml index 1a6a8f6..d42719b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,4 +66,4 @@ include-package-data = false [tool.ruff.lint.per-file-ignores] # F811: Redefinition of unused name. -"docs/digits_vae.ipynb" = ["F811"] +"docs/source/digits_vae.ipynb" = ["F811"]