Commit 402ae46

bugfix related to key input of neural_ode_model/controller

simon-bachhuber committed Sep 6, 2023
1 parent 85cfc4b commit 402ae46
Showing 5 changed files with 81 additions and 26 deletions.
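In short: step() of both NeuralOdeModel and NeuralOdeController previously received a (input, PRNGKey) tuple and returned a (output, PRNGKey) tuple; it now receives and returns only the input/output pytrees and derives its keys from a fixed internal jrand.PRNGKey(1). Because a fixed key reproduces the same subkeys at every step, both factories now raise NotImplementedError when f_use_dropout or g_use_dropout is set (dropout is deferred, per the "KeyWrapper" TODOs), key becomes an optional argument, and the tests are updated accordingly.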
27 changes: 17 additions & 10 deletions cc/examples/neural_ode_controller.py
@@ -18,7 +18,7 @@ def make_neural_ode_controller(
     output_specs: ArraySpecs,
     control_timestep: float,
     state_dim: int,
-    key,
+    key=jrand.PRNGKey(1),
     f_integrate_method: str = "RK4",
     f_use_bias: bool = True,
     f_time_invariant: bool = True,
@@ -37,6 +37,10 @@ def make_neural_ode_controller(
     g_activation=jax.nn.relu,
     g_final_activation=lambda x: x,
 ):
+    # TODO: implement dropout; could be done using a "KeyWrapper"
+    if f_use_dropout or g_use_dropout:
+        raise NotImplementedError
+
     toy_input = sample_from_tree_of_specs(input_specs)
     toy_output = sample_from_tree_of_specs(output_specs)
     input_dim = batch_concat(toy_input, 0).size
@@ -98,28 +102,27 @@ class NeuralOdeController(AbstractController):
     def reset(self):
         return NeuralOdeController(self.f, self.g, self.init_state, self.init_state)

-    def step(self, u):  # u has shape identical to (`toy_input`, PRNGKey)
-        u, key = u
+    def step(self, u):
+        # TODO
+        # u has shape identical to (`toy_input`, PRNGKey)
+        # u, key = u
+        key = jrand.PRNGKey(1)

         if has_time_state:
             (x, t) = self.state  # pytype: disable=attribute-error
         else:
             x = self.state
             t = jnp.array(0.0)

-        if f_use_dropout:
-            key, consume = jrand.split(key)
-
+        key, consume = jrand.split(key)
         if not f_time_invariant:
             rhs = lambda t, x: self.f(batch_concat((x, t, u), 0), key=consume)
         else:
             rhs = lambda t, x: self.f(batch_concat((x, u), 0), key=consume)

         x_next = integrate(rhs, x, t, control_timestep, f_integrate_method)

-        if g_use_dropout:
-            key, consume = jrand.split(key)
-
+        key, consume = jrand.split(key)
         if not g_time_invariant:
             y_next = self.g(batch_concat((x_next, t), 0), key=consume)
         else:
@@ -132,10 +135,14 @@ def step(self, u):  # u has shape identical to (`toy_input`, PRNGKey)
         else:
             state_next = x_next

-        return NeuralOdeController(self.f, self.g, state_next, self.init_state), (
+        # TODO
+        out = (
             y_next,
             key,
         )  # y_next has shape identical to (`toy_output`, PRNGKey)
+        out = y_next
+
+        return NeuralOdeController(self.f, self.g, state_next, self.init_state), out

     def grad_filter_spec(self) -> PyTree[bool]:
         filter_spec = super().grad_filter_spec()
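Why the fixed internal key makes dropout a no-op in practice: the sketch below (plain JAX, not this repo's API) shows that re-deriving subkeys from jrand.PRNGKey(1) inside step() produces the identical "random" mask on every call, which is presumably why the factories now reject dropout outright.

import jax.numpy as jnp
import jax.random as jrand

def step(u):
    key = jrand.PRNGKey(1)           # fixed seed, as in the patched step()
    key, consume = jrand.split(key)  # `consume` is the same on every call
    mask = jrand.bernoulli(consume, 0.5, u.shape)  # would-be dropout mask
    return jnp.where(mask, u, 0.0)

u = jnp.ones(4)
assert (step(u) == step(u)).all()  # identical mask each step: no real dropout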
42 changes: 28 additions & 14 deletions cc/examples/neural_ode_model.py
@@ -18,7 +18,7 @@ def make_neural_ode_model(
     output_specs: ArraySpecs,
     control_timestep: float,
     state_dim: int,
-    key,
+    key=jrand.PRNGKey(1),
     f_integrate_method: str = "RK4",
     f_use_bias: bool = True,
     f_time_invariant: bool = True,
@@ -38,6 +38,10 @@ def make_neural_ode_model(
     g_final_activation=lambda x: x,
     u_transform=lambda u: u,
 ):
+    # TODO: implement dropout; could be done using a "KeyWrapper"
+    if f_use_dropout or g_use_dropout:
+        raise NotImplementedError
+
     toy_input = sample_from_tree_of_specs(input_specs)
     toy_output = sample_from_tree_of_specs(output_specs)
     input_dim = batch_concat(toy_input, 0).size
@@ -99,8 +103,11 @@ class NeuralOdeModel(AbstractModel):
     def reset(self):
         return NeuralOdeModel(self.f, self.g, self.init_state, self.init_state)

-    def step(self, u):  # u has shape identical to (`toy_input`, PRNGKey)
-        u, key = u
+    def step(self, u):
+        # TODO
+        # u has shape identical to (`toy_input`, PRNGKey)
+        # u, key = u
+        key = jrand.PRNGKey(1)

         # e.g. the model might saturate for large controls u
         u = u_transform(u)
@@ -111,19 +118,15 @@ def step(self, u):  # u has shape identical to (`toy_input`, PRNGKey)
             x = self.state
             t = jnp.array(0.0)

-        if f_use_dropout:
-            key, consume = jrand.split(key)
-
+        key, consume = jrand.split(key)
         if not f_time_invariant:
             rhs = lambda t, x: self.f(batch_concat((x, t, u), 0), key=consume)
         else:
             rhs = lambda t, x: self.f(batch_concat((x, u), 0), key=consume)

         x_next = integrate(rhs, x, t, control_timestep, f_integrate_method)

-        if g_use_dropout:
-            key, consume = jrand.split(key)
-
+        key, consume = jrand.split(key)
         if not g_time_invariant:
             y_next = self.g(batch_concat((x_next, t), 0), key=consume)
         else:
@@ -136,10 +139,14 @@ def step(self, u):  # u has shape identical to (`toy_input`, PRNGKey)
         else:
             state_next = x_next

-        return NeuralOdeModel(self.f, self.g, state_next, self.init_state), (
+        # TODO
+        out = (
             y_next,
             key,
         )  # y_next has shape identical to (`toy_output`, PRNGKey)
+        out = y_next
+
+        return NeuralOdeModel(self.f, self.g, state_next, self.init_state), out

     def grad_filter_spec(self) -> PyTree[bool]:
         filter_spec = super().grad_filter_spec()
@@ -151,12 +158,19 @@ def grad_filter_spec(self) -> PyTree[bool]:
         return filter_spec

     def y0(self):
-        g = eqx.tree_inference(self.g)
+        # TODO; Implememnt Dropout;
+        # g = eqx.tree_inference(self.g, True)
+        g = self.g

+        if has_time_state:
+            x, t = self.init_state
+        else:
+            x, t = self.init_state, jnp.array(0.0)
+
         if not g_time_invariant:
             t = jnp.array(0.0)
-            inp = batch_concat((self.init_state, t), 0)
+            inp = batch_concat((x, t), 0)
         else:
-            inp = batch_concat((self.init_state,), 0)
+            inp = batch_concat(x, 0)

         return postprocess_fn(g(inp))
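The y0() rewrite also fixes how the initial state is flattened: with has_time_state, init_state is an (x, t) pair, and the old call batch_concat((self.init_state, t), 0) apparently flattened the whole pair and appended t again. A self-contained sketch of the corrected unpacking (jnp.concatenate stands in for batch_concat, which is assumed to flatten-and-concatenate):

import jax.numpy as jnp

def y0_input(init_state, has_time_state, g_time_invariant):
    # unpack (x, t) when the state carries time, else synthesize t = 0
    if has_time_state:
        x, t = init_state
    else:
        x, t = init_state, jnp.array(0.0)
    parts = [x, jnp.atleast_1d(t)] if not g_time_invariant else [x]
    return jnp.concatenate(parts)

# the fix concatenates x and the time entry exactly once
print(y0_input((jnp.zeros(3), jnp.array(0.0)), True, False).shape)  # (4,)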
3 changes: 3 additions & 0 deletions cc/examples/pole_placement_controller.py
@@ -130,6 +130,9 @@ def make_pole_placed_controller(
     pretrained_model=None,
     **kwargs,
 ):
+    """
+    NOTE: len(poles) == kwargs["state_dim"] * 2
+    """
     if pretrained_model is None:
         model_trainer = _train_linear_model(env, **kwargs)
         model = model_trainer._model
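The new docstring makes the sizing contract explicit; the updated test in cc/examples/test_controllers.py satisfies it with state_dim=1 and two poles:

state_dim = 1
poles = [-0.1] * (state_dim * 2)  # len(poles) == state_dim * 2, per the NOTE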
32 changes: 31 additions & 1 deletion cc/examples/test_controllers.py
@@ -1,17 +1,33 @@
 import jax

+from cc.utils.high_level.defaults import Env
+
 from ..env import make_env
 from ..env.collect import collect
 from ..env.collect import sample_feedforward_collect_and_make_source
 from ..env.wrappers import AddRefSignalRewardFnWrapper
 from ..utils.utils import timestep_array_from_env
 from .feedforward_controller import make_feedforward_controller
-from .neural_ode_controller_compact_example import make_neural_ode_controller
+from .neural_ode_controller import make_neural_ode_controller
+from .neural_ode_controller_compact_example import (
+    make_neural_ode_controller as make_neural_ode_controller_compact,
+)
 from .pid_controller import make_pid_controller
+from .pole_placement_controller import make_pole_placed_controller


 def dummy_env():
     return make_env("two_segments_v1", random=1)


+_env_data = {
+    "train_gp": list(range(1)),
+    "train_cos": list(range(1)),
+    "val_gp": list(range(2)),
+    "val_cos": [2],
+}
+
+
 def test_controllers():
     env = dummy_env()
     source, _, _ = sample_feedforward_collect_and_make_source(env, seeds=[0])
@@ -24,8 +40,22 @@ def test_controllers():
             env.action_spec(),
             env.control_timestep(),
             10,
             jax.random.PRNGKey(1),
+            f_time_invariant=False,
         ),
+        make_neural_ode_controller_compact(
+            env_w_source.observation_spec(),
+            env.action_spec(),
+            env.control_timestep(),
+            10,
+        ),
         make_feedforward_controller(timestep_array_from_env(env)),
+        make_pole_placed_controller(
+            Env("two_segments_v1", {}, {}, data=_env_data),
+            [-0.1] * 2,
+            verbose=False,
+            state_dim=1,
+        )[0](),
     ]

     for controller in controllers:
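Two details of the new test worth noting: the compact controller is imported under the alias make_neural_ode_controller_compact so it can coexist with the non-compact make_neural_ode_controller that the test now also covers, and the trailing [0]() after make_pole_placed_controller(...) suggests the factory returns a tuple whose first element is a callable producing the controller instance.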
3 changes: 2 additions & 1 deletion cc/train/test_trainer.py
@@ -7,7 +7,7 @@
 from cc.env.collect import sample_feedforward_and_collect
 from cc.env.wrappers import AddRefSignalRewardFnWrapper
 from cc.examples.neural_ode_controller_compact_example import make_neural_ode_controller
-from cc.examples.neural_ode_model_compact_example import make_neural_ode_model
+from cc.examples.neural_ode_model import make_neural_ode_model
 from cc.train import DictLogger
 from cc.train import EvaluationMetrices
 from cc.train import make_dataloader
@@ -44,6 +44,7 @@ def test_trainer():
         state_dim=1,
         f_depth=0,
         u_transform=jnp.arctan,
+        f_time_invariant=False,
     )

     model_train_dataloader = make_dataloader(
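test_trainer now sets f_time_invariant=False, exercising the time-varying branch of the vector field, where the network input is the concatenation of (x, t, u) rather than (x, u). A minimal sketch of the resulting size difference (batch_concat is assumed to flatten-and-concatenate):

import jax.numpy as jnp

x = jnp.zeros(3)     # latent state
u = jnp.zeros(2)     # control input
t = jnp.array(0.01)  # scalar time

inp_tv = jnp.concatenate([x, jnp.atleast_1d(t), u])  # time-varying: (6,)
inp_ti = jnp.concatenate([x, u])                     # time-invariant: (5,)
print(inp_tv.shape, inp_ti.shape)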
