diff --git a/diffrax/_term.py b/diffrax/_term.py index 9b0c0314..ab231b2e 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -736,7 +736,8 @@ class WrapTerm(AbstractTerm[_VF, _Control]): direction: IntScalarLike def vf(self, t: RealScalarLike, y: Y, args: Args) -> _VF: - t = t * self.direction + with jax.numpy_dtype_promotion("standard"): + t = t * self.direction return self.term.vf(t, y, args) def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> _Control: @@ -749,7 +750,8 @@ def prod(self, vf: _VF, control: _Control) -> Y: return self.term.prod(vf, control) def vf_prod(self, t: RealScalarLike, y: Y, args: Args, control: _Control) -> Y: - t = t * self.direction + with jax.numpy_dtype_promotion("standard"): + t = t * self.direction return self.term.vf_prod(t, y, args, control) def is_vf_expensive( diff --git a/test/test_integrate.py b/test/test_integrate.py index fcf3a836..9918aa20 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -631,6 +631,8 @@ def out_structure(self): return (jax.ShapeDtypeStruct((2, 3), jnp.float64),) @lx.is_symmetric.register(TestLinearOperator) + @lx.is_positive_semidefinite.register(TestLinearOperator) + @lx.is_negative_semidefinite.register(TestLinearOperator) def _(operator): del operator return False