Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions ss2r/algorithms/mbpo/model_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from brax import envs
from brax.envs import base

from ss2r.algorithms.mbpo.networks import make_inference_fn
from ss2r.algorithms.sac.types import float32


Expand Down Expand Up @@ -46,6 +47,15 @@ def __init__(
self.initial_normalizer_params = (
initial_normalizer_params if initial_normalizer_params is not None else {}
)
_make_policy = make_inference_fn(mbpo_network)
self.backup_policy_fn = _make_policy(
(training_state.normalizer_params, training_state.backup_policy_params),
deterministic=True,
)
self.backup_policy_fn_initial = _make_policy(
(self.initial_normalizer_params, training_state.backup_policy_params),
deterministic=True,
)

def reset(self, rng: jax.Array) -> base.State:
sample_key, model_key = jax.random.split(rng)
Expand Down Expand Up @@ -110,12 +120,8 @@ def step(self, state: base.State, action: jax.Array) -> base.State:
state.obs,
action,
).mean(axis=-1)
backup_policy = self.policy_network.apply
backup_policy_params = self.backup_policy_params
backup_action = jnp.tanh(
backup_policy(
self.initial_normalizer_params, backup_policy_params, state.obs
)[: self.action_size]
backup_action, _ = self.backup_policy_fn_initial(
state.obs, jax.random.PRNGKey(0)
)
qc_backup = self.qc_network.apply(
self.normalizer_params,
Expand All @@ -130,12 +136,8 @@ def step(self, state: base.State, action: jax.Array) -> base.State:
done,
)

pred_backup_action = self.policy_network.apply
backup_policy_params = self.backup_policy_params
backup_action = jnp.tanh(
pred_backup_action(
self.normalizer_params, backup_policy_params, state.obs
)[: self.action_size]
backup_action, _ = self.backup_policy_fn(
state.obs, jax.random.PRNGKey(0)
)
pred_qr = self.qr_network.apply
backup_qr_params = self.backup_qr_params
Expand Down
8 changes: 2 additions & 6 deletions ss2r/algorithms/mbpo/on_policy_training_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,12 +301,8 @@ def relabel_transitions(
normalizer_params, model_params, transitions.observation, transitions.action
)

pred_backup_action = planning_env.policy_network.apply
backup_policy_params = planning_env.backup_policy_params
backup_action = jnp.tanh(
pred_backup_action(
normalizer_params, backup_policy_params, transitions.observation
)[..., : planning_env.action_size]
backup_action, _ = planning_env.backup_policy_fn(
transitions.observation, jax.random.PRNGKey(0)
)
disagreement = (
next_obs_pred.std(axis=0).mean(-1)
Expand Down
Loading