Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ss2r/algorithms/mbpo/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def critic_loss(
action = transitions.action
scale = cost_scaling if safe else reward_scaling
gamma = safety_discounting if safe else discounting
q_old_action = qr_network.apply(
q_net = qc_network if (safe and qc_network is not None) else qr_network
q_old_action = q_net.apply(
normalizer_params, q_params, transitions.observation, action
)
key, another_key = jax.random.split(key)
Expand All @@ -101,7 +102,7 @@ def policy(obs: jax.Array) -> tuple[jax.Array, jax.Array]:
next_action = parametric_action_distribution.postprocess(next_action)
return next_action, next_log_prob

q_fn = lambda obs, action: qr_network.apply(
q_fn = lambda obs, action: q_net.apply(
normalizer_params, target_q_params, obs, action
)
target_q = target_q_fn(
Expand Down
36 changes: 16 additions & 20 deletions ss2r/algorithms/mbpo/model_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,11 @@ def step(self, state: base.State, action: jax.Array) -> base.State:
).mean(axis=-1)
backup_policy = self.policy_network.apply
backup_policy_params = self.backup_policy_params
backup_action = backup_policy(
self.initial_normalizer_params, backup_policy_params, state.obs
)[: self.action_size]
backup_action = jnp.tanh(
backup_policy(
self.initial_normalizer_params, backup_policy_params, state.obs
)[: self.action_size]
)
qc_backup = self.qc_network.apply(
self.normalizer_params,
self.backup_qc_params,
Expand All @@ -130,9 +132,11 @@ def step(self, state: base.State, action: jax.Array) -> base.State:

pred_backup_action = self.policy_network.apply
backup_policy_params = self.backup_policy_params
backup_action = pred_backup_action(
self.normalizer_params, backup_policy_params, state.obs
)[: self.action_size]
backup_action = jnp.tanh(
pred_backup_action(
self.normalizer_params, backup_policy_params, state.obs
)[: self.action_size]
)
pred_qr = self.qr_network.apply
backup_qr_params = self.backup_qr_params
pessimistic_qr_pred = pred_qr(
Expand Down Expand Up @@ -226,13 +230,13 @@ def _propagate_ensemble(
"""Propagate the ensemble predictions based on the selection method."""
# Calculate the nominal predictions
if ensemble_selection == "nominal":
# Get the average model parameters
avg_model_params = jax.tree_util.tree_map(
lambda p: jnp.mean(p, axis=0), model_params
)
next_obs, reward, cost = pred_fn(
normalizer_params, avg_model_params, obs, action
vmap_pred_fn = jax.vmap(pred_fn, in_axes=(None, 0, None, None))
next_obs_pred, reward_pred, cost_pred = vmap_pred_fn(
normalizer_params, model_params, obs, action
)
next_obs = jax.tree_map(lambda x: jnp.mean(x, axis=0), next_obs_pred)
reward = jnp.mean(reward_pred, axis=0)
cost = jnp.mean(cost_pred, axis=0)
elif ensemble_selection == "random":
vmap_pred_fn = jax.vmap(pred_fn, in_axes=(None, 0, None, None))
next_obs_pred, reward_pred, cost_pred = vmap_pred_fn(
Expand All @@ -243,14 +247,6 @@ def _propagate_ensemble(
next_obs = jax.tree_map(lambda x: x[idx], next_obs_pred)
reward = reward_pred[idx]
cost = cost_pred[idx]
elif ensemble_selection == "mean":
vmap_pred_fn = jax.vmap(pred_fn, in_axes=(None, 0, None, None))
next_obs_pred, reward_pred, cost_pred = vmap_pred_fn(
normalizer_params, model_params, obs, action
)
next_obs = jax.tree_map(lambda x: jnp.mean(x, axis=0), next_obs_pred)
reward = jnp.mean(reward_pred, axis=0)
cost = jnp.mean(cost_pred, axis=0)
else:
raise ValueError(f"Unknown ensemble selection: {ensemble_selection}")
return next_obs, reward, cost
8 changes: 5 additions & 3 deletions ss2r/algorithms/mbpo/on_policy_training_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,11 @@ def relabel_transitions(

pred_backup_action = planning_env.policy_network.apply
backup_policy_params = planning_env.backup_policy_params
backup_action = pred_backup_action(
normalizer_params, backup_policy_params, transitions.observation
)[..., : planning_env.action_size]
backup_action = jnp.tanh(
pred_backup_action(
normalizer_params, backup_policy_params, transitions.observation
)[..., : planning_env.action_size]
)
disagreement = (
next_obs_pred.std(axis=0).mean(-1)
if isinstance(next_obs_pred, jax.Array)
Expand Down
Loading