diff --git a/ss2r/algorithms/mbpo/model_env.py b/ss2r/algorithms/mbpo/model_env.py index 5834dfb1..c7b07aae 100644 --- a/ss2r/algorithms/mbpo/model_env.py +++ b/ss2r/algorithms/mbpo/model_env.py @@ -3,6 +3,7 @@ from brax import envs from brax.envs import base +from ss2r.algorithms.mbpo.networks import make_inference_fn from ss2r.algorithms.sac.types import float32 @@ -46,6 +47,15 @@ def __init__( self.initial_normalizer_params = ( initial_normalizer_params if initial_normalizer_params is not None else {} ) + _make_policy = make_inference_fn(mbpo_network) + self.backup_policy_fn = _make_policy( + (training_state.normalizer_params, training_state.backup_policy_params), + deterministic=True, + ) + self.backup_policy_fn_initial = _make_policy( + (self.initial_normalizer_params, training_state.backup_policy_params), + deterministic=True, + ) def reset(self, rng: jax.Array) -> base.State: sample_key, model_key = jax.random.split(rng) @@ -110,12 +120,8 @@ def step(self, state: base.State, action: jax.Array) -> base.State: state.obs, action, ).mean(axis=-1) - backup_policy = self.policy_network.apply - backup_policy_params = self.backup_policy_params - backup_action = jnp.tanh( - backup_policy( - self.initial_normalizer_params, backup_policy_params, state.obs - )[: self.action_size] + backup_action, _ = self.backup_policy_fn_initial( + state.obs, jax.random.PRNGKey(0) ) qc_backup = self.qc_network.apply( self.normalizer_params, @@ -130,12 +136,8 @@ def step(self, state: base.State, action: jax.Array) -> base.State: done, ) - pred_backup_action = self.policy_network.apply - backup_policy_params = self.backup_policy_params - backup_action = jnp.tanh( - pred_backup_action( - self.normalizer_params, backup_policy_params, state.obs - )[: self.action_size] + backup_action, _ = self.backup_policy_fn( + state.obs, jax.random.PRNGKey(0) ) pred_qr = self.qr_network.apply backup_qr_params = self.backup_qr_params diff --git a/ss2r/algorithms/mbpo/on_policy_training_step.py b/ss2r/algorithms/mbpo/on_policy_training_step.py index c90d552a..ce7ffdaa 100644 --- a/ss2r/algorithms/mbpo/on_policy_training_step.py +++ b/ss2r/algorithms/mbpo/on_policy_training_step.py @@ -301,12 +301,8 @@ def relabel_transitions( normalizer_params, model_params, transitions.observation, transitions.action ) - pred_backup_action = planning_env.policy_network.apply - backup_policy_params = planning_env.backup_policy_params - backup_action = jnp.tanh( - pred_backup_action( - normalizer_params, backup_policy_params, transitions.observation - )[..., : planning_env.action_size] + backup_action, _ = planning_env.backup_policy_fn( + transitions.observation, jax.random.PRNGKey(0) ) disagreement = ( next_obs_pred.std(axis=0).mean(-1)