diff --git a/ss2r/algorithms/mbpo/losses.py b/ss2r/algorithms/mbpo/losses.py index de95f5ade..99b586e1c 100644 --- a/ss2r/algorithms/mbpo/losses.py +++ b/ss2r/algorithms/mbpo/losses.py @@ -83,7 +83,8 @@ def critic_loss( action = transitions.action scale = cost_scaling if safe else reward_scaling gamma = safety_discounting if safe else discounting - q_old_action = qr_network.apply( + q_net = qc_network if (safe and qc_network is not None) else qr_network + q_old_action = q_net.apply( normalizer_params, q_params, transitions.observation, action ) key, another_key = jax.random.split(key) @@ -101,7 +102,7 @@ def policy(obs: jax.Array) -> tuple[jax.Array, jax.Array]: next_action = parametric_action_distribution.postprocess(next_action) return next_action, next_log_prob - q_fn = lambda obs, action: qr_network.apply( + q_fn = lambda obs, action: q_net.apply( normalizer_params, target_q_params, obs, action ) target_q = target_q_fn( diff --git a/ss2r/algorithms/mbpo/model_env.py b/ss2r/algorithms/mbpo/model_env.py index 7dd689229..5834dfb15 100644 --- a/ss2r/algorithms/mbpo/model_env.py +++ b/ss2r/algorithms/mbpo/model_env.py @@ -112,9 +112,11 @@ def step(self, state: base.State, action: jax.Array) -> base.State: ).mean(axis=-1) backup_policy = self.policy_network.apply backup_policy_params = self.backup_policy_params - backup_action = backup_policy( - self.initial_normalizer_params, backup_policy_params, state.obs - )[: self.action_size] + backup_action = jnp.tanh( + backup_policy( + self.initial_normalizer_params, backup_policy_params, state.obs + )[: self.action_size] + ) qc_backup = self.qc_network.apply( self.normalizer_params, self.backup_qc_params, @@ -130,9 +132,11 @@ def step(self, state: base.State, action: jax.Array) -> base.State: pred_backup_action = self.policy_network.apply backup_policy_params = self.backup_policy_params - backup_action = pred_backup_action( - self.normalizer_params, backup_policy_params, state.obs - )[: self.action_size] + backup_action = jnp.tanh( + pred_backup_action( + self.normalizer_params, backup_policy_params, state.obs + )[: self.action_size] + ) pred_qr = self.qr_network.apply backup_qr_params = self.backup_qr_params pessimistic_qr_pred = pred_qr( @@ -226,13 +230,13 @@ def _propagate_ensemble( """Propagate the ensemble predictions based on the selection method.""" # Calculate the nominal predictions if ensemble_selection == "nominal": - # Get the average model parameters - avg_model_params = jax.tree_util.tree_map( - lambda p: jnp.mean(p, axis=0), model_params - ) - next_obs, reward, cost = pred_fn( - normalizer_params, avg_model_params, obs, action + vmap_pred_fn = jax.vmap(pred_fn, in_axes=(None, 0, None, None)) + next_obs_pred, reward_pred, cost_pred = vmap_pred_fn( + normalizer_params, model_params, obs, action ) + next_obs = jax.tree_map(lambda x: jnp.mean(x, axis=0), next_obs_pred) + reward = jnp.mean(reward_pred, axis=0) + cost = jnp.mean(cost_pred, axis=0) elif ensemble_selection == "random": vmap_pred_fn = jax.vmap(pred_fn, in_axes=(None, 0, None, None)) next_obs_pred, reward_pred, cost_pred = vmap_pred_fn( @@ -243,14 +247,6 @@ def _propagate_ensemble( next_obs = jax.tree_map(lambda x: x[idx], next_obs_pred) reward = reward_pred[idx] cost = cost_pred[idx] - elif ensemble_selection == "mean": - vmap_pred_fn = jax.vmap(pred_fn, in_axes=(None, 0, None, None)) - next_obs_pred, reward_pred, cost_pred = vmap_pred_fn( - normalizer_params, model_params, obs, action - ) - next_obs = jax.tree_map(lambda x: jnp.mean(x, axis=0), next_obs_pred) - reward = jnp.mean(reward_pred, axis=0) - cost = jnp.mean(cost_pred, axis=0) else: raise ValueError(f"Unknown ensemble selection: {ensemble_selection}") return next_obs, reward, cost diff --git a/ss2r/algorithms/mbpo/on_policy_training_step.py b/ss2r/algorithms/mbpo/on_policy_training_step.py index 5bd20ddfc..c90d552a1 100644 --- a/ss2r/algorithms/mbpo/on_policy_training_step.py +++ b/ss2r/algorithms/mbpo/on_policy_training_step.py @@ -303,9 +303,11 @@ def relabel_transitions( pred_backup_action = planning_env.policy_network.apply backup_policy_params = planning_env.backup_policy_params - backup_action = pred_backup_action( - normalizer_params, backup_policy_params, transitions.observation - )[..., : planning_env.action_size] + backup_action = jnp.tanh( + pred_backup_action( + normalizer_params, backup_policy_params, transitions.observation + )[..., : planning_env.action_size] + ) disagreement = ( next_obs_pred.std(axis=0).mean(-1) if isinstance(next_obs_pred, jax.Array)