Skip to content

bump jax to 0.10, drop distrax to unblock the persistent compile cache#21

Merged
PaulHax merged 1 commit into
mainfrom
worktree-jax-cache-bug
May 15, 2026
Merged

bump jax to 0.10, drop distrax to unblock the persistent compile cache#21
PaulHax merged 1 commit into
mainfrom
worktree-jax-cache-bug

Conversation

@PaulHax

@PaulHax PaulHax commented May 15, 2026

Copy link
Copy Markdown
Collaborator

Summary

  • JAX 0.4.38 silently no-ops the persistent compilation cache when any jax.Array is materialized before the cache dir is set (jax-ml/jax#25768). Every process pays full cold compile (~7 min GPU / ~20 min/cell CPU). JAX 0.5.x has a worse bug: cache hits silently return stale .at[].set() updates inside lax.fori_loop / lax.cond (jax-ml/jax#31733) — corrupts the env step. Reproduced here with 17/35 blue-Remove tests failing on a warm cache under 0.5.1. Fixed from 0.7.1 onward; this PR lands on 0.10.0.
  • Drops distrax because its transitive tensorflow-probability<=0.25 imports jax.interpreters.xla.pytype_aval_mappings, removed in JAX 0.7. All 5 callsites used distrax.Categorical for .sample() / .log_prob() / .entropy(), replaced by a local 31-line src/jaxborg/policies/categorical.py (flax struct, jit/vmap/scan-compatible). Also removes TFP, gast, decorator from the dep graph.
  • Knock-on bumps from JAX 0.7+ needing ml-dtypes>=0.5 (numpy 2 C ABI):
    • numpy 1.26.4 → 2.3.5 (overrides cyborg's pin; cyborg runtime is numpy-2-compatible empirically)
    • scipy 1.12 → 1.17 (overrides jaxmarl's precautionary scipy<=1.12; cf. jaxmarl commit 3ffa5b8f and issue #175)
    • torch 2.2 → 2.10 (overrides cyborg's pin; needed for numpy-2 init)
  • Retires tests/subsystems/test_fsm_red_agent.py::test_fsm_hidden_state_applies_after_completion_step's reliance on accidental RNG alignment between CybORG's numpy and JAX's threefry. It was walking both pipelines forward on seed=0 and asserting they hit the same FSM sequence — parity-by-coincidence that breaks any time JAX's PRNG layout changes. Restructured as a pure-state check of the two-stage delayed-update mechanism that the test name actually promises.

Test plan

  • Fast suite: 772/772, cold 154s → warm 115s (cache hit, correct)
  • Full suite incl. slow: 1214 passed, 100 skipped, 3 xfailed (xfails are pre-existing, unrelated)
  • Cache miscompile probe: tests/subsystems/test_blue_remove.py 35/35 cold and 35/35 warm against the same cache dir — under 0.5.1 the same probe gave 17/35 failures on warm
  • uv run ruff check . && uv run ruff format . clean

JAX 0.4.38 silently no-ops the persistent compilation cache when any
jax.Array is materialized before the cache dir is set (jax-ml/jax#25768),
so every process paid full cold compile (~7 min GPU / ~20 min/cell CPU).
JAX 0.5.x has a worse bug: cache hits silently return stale .at[].set()
updates inside lax.fori_loop / lax.cond (jax-ml/jax#31733), corrupting
the env step — reproduced here with 17/35 blue-Remove tests failing on a
warm cache under 0.5.1. Bug is fixed from 0.7.1 onward; we land on 0.10.

Distrax is dropped because its transitive tensorflow-probability<=0.25
imports jax.interpreters.xla.pytype_aval_mappings, removed in JAX 0.7.
The 5 callsites only used distrax.Categorical for sampling / log_prob /
entropy, all easily expressed on jax.random.categorical and
jax.nn.log_softmax — see src/jaxborg/policies/categorical.py (31 lines,
flax struct so it's jit/vmap/scan-compatible). Dropping distrax also
removes TFP, gast, and decorator from the dep graph.

Knock-on bumps from JAX 0.7+ requiring ml-dtypes>=0.5 (numpy 2 C ABI):
  * numpy 1.26.4 -> 2.3.5 (overrides cyborg's pin; cyborg runtime is
    numpy-2-compatible empirically)
  * scipy 1.12 -> 1.17 (overrides jaxmarl's `scipy<=1.12` precautionary
    upper bound; cf. jaxmarl commit 3ffa5b8f and issue #175)
  * torch 2.2 -> 2.10 (overrides cyborg's pin; needed for numpy-2 init)

Also retires a flaky test that relied on accidental RNG alignment
between CybORG's numpy and JAX's threefry:
tests/subsystems/test_fsm_red_agent.py::test_fsm_hidden_state_applies_after_completion_step
walked both pipelines forward on seed=0 and asserted they hit the same
FSM sequence — a parity-by-coincidence that breaks any time JAX's PRNG
layout changes. Restructured as a pure-state check of the two-stage
delayed-update mechanism (fsm_red_schedule_post_step_update stages
red_fsm_delayed_states; fsm_red_apply_delayed_update commits it on the
next step), which is the actual invariant the test name promises.

Verification:
  * fast suite: 772/772, cold 154s -> warm 115s (cache hit, correct)
  * full suite incl. slow: 1214 passed, 100 skipped, 3 xfailed
    (xfails are pre-existing, unrelated)
  * cache miscompile probe: blue_remove.py 35/35 cold and 35/35 warm
    against the same cache dir — under 0.5.1 the same probe gave 17/35
    failures on warm
@PaulHax PaulHax force-pushed the worktree-jax-cache-bug branch from b4cd079 to b221685 Compare May 15, 2026 18:53
@PaulHax PaulHax merged commit eafc60a into main May 15, 2026
2 of 3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant