Skip to content

Commit

Permalink
Merge pull request #5 from shyamsn97/notebook-fix
Browse files — browse the repository at this point in the history
fixing notebooks
  • Loading branch information
shyamsn97 authored Jan 17, 2023
2 parents b73fe72 + 83c3796 commit 2765728
Show file tree
Hide file tree
Showing 10 changed files with 545 additions and 704 deletions.
4 changes: 3 additions & 1 deletion hypernn/jax/dynamic_hypernet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple

import flax.linen as nn
Expand Down Expand Up @@ -91,7 +93,7 @@ def from_target(
weight_chunk_dim: Optional[int] = None,
*args,
**kwargs,
) -> JaxDynamicEmbeddingModule:
) -> JaxDynamicHyperNetwork:
num_target_parameters, variables = cls.count_params(
target_network, target_input_shape, inputs=inputs, return_variables=True
)
Expand Down
11 changes: 9 additions & 2 deletions hypernn/jax/hypernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,10 @@ def count_params(
return_variables: bool = False,
):
return count_jax_params(
target, target_input_shape, inputs=inputs, return_variables=return_variables
target,
target_input_shape,
inputs=inputs,
return_variables=return_variables,
)

@classmethod
Expand All @@ -139,8 +142,12 @@ def from_target(
**kwargs,
) -> JaxHyperNetwork:
num_target_parameters, variables = cls.count_params(
target_network, target_input_shape, inputs=inputs, return_variables=True
target_network,
target_input_shape,
inputs=inputs,
return_variables=True,
)

_value_flat, target_treedef = jax.tree_util.tree_flatten(variables)
target_weight_shapes = [v.shape for v in _value_flat]

Expand Down
3 changes: 2 additions & 1 deletion hypernn/jax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ def count_jax_params(
input_shape: Optional[Tuple[int, ...]] = None,
inputs: Optional[List[jnp.array]] = None,
return_variables: bool = False,
**kwargs
) -> int:
if input_shape is None and inputs is None:
raise ValueError("Input shape or inputs must be specified")
if inputs is None:
inputs = [jnp.zeros(shape) for shape in input_shape]
variables = model.init(jax.random.PRNGKey(0), *inputs)
variables = model.init(jax.random.PRNGKey(0), *inputs, **kwargs)

def count_recursive(d):
s = 0
Expand Down
11 changes: 7 additions & 4 deletions notebooks/Intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -251,7 +251,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -295,7 +295,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -368,10 +368,13 @@
}
],
"metadata": {
"interpreter": {
"hash": "7e5f2d8038e6c8941d283a9a145e7dfd2f60905a23729a9c846dae091a9571f4"
},
"kernelspec": {
"display_name": "Python [conda env:py39] *",
"language": "python",
"name": "conda-env-py39-py"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down
230 changes: 66 additions & 164 deletions notebooks/dynamic_hypernetworks/JaxDynamicHyperRNN.ipynb

Large diffs are not rendered by default.

194 changes: 91 additions & 103 deletions notebooks/dynamic_hypernetworks/TorchDynamicHyperRNN.ipynb

Large diffs are not rendered by default.

205 changes: 79 additions & 126 deletions notebooks/mnist/JaxHyperMNIST.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 2765728

Please sign in to comment.