From 30cd9c04d82bac93ef53f43443bd969cd8e74f06 Mon Sep 17 00:00:00 2001 From: Taylor Howell Date: Mon, 26 Jan 2026 23:56:41 +0000 Subject: [PATCH 1/2] Add sleep state fields - Add MJ_MINAWAKE constant, SLEEP enable bit, SleepPolicy and SleepState enums - Add sleep_tolerance to Option, tree_sleep_policy/dof_length to Model - Add tree_asleep, tree_awake, body_awake, dof_awake_ind, body_awake_ind, nv_awake, nbody_awake, ntree_awake to Data - Initialize sleep state in make_data and put_data (all trees start awake) - Add tests for sleep state initialization, sleep policy and dof_length import --- mujoco_warp/_src/io.py | 15 ++++++++- mujoco_warp/_src/io_test.py | 60 ++++++++++++++++++++++++++++++++++++ mujoco_warp/_src/types.py | 61 +++++++++++++++++++++++++++++++++++++ 3 files changed, 135 insertions(+), 1 deletion(-) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 7873f7e04..6c8c33265 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -594,7 +594,10 @@ def geom_trid_index(i, j): m.qM_madr_ij.append(madr_ij) # place m on device - sizes = dict({"*": 1}, **{f.name: getattr(m, f.name) for f in dataclasses.fields(types.Model) if f.type is int}) + # TODO(team): remove ntree once field is added to types.Model + sizes = dict( + {"*": 1, "ntree": mjm.ntree}, **{f.name: getattr(m, f.name) for f in dataclasses.fields(types.Model) if f.type is int} + ) for f in dataclasses.fields(types.Model): if isinstance(f.type, wp.array): setattr(m, f.name, _create_array(getattr(m, f.name), f.type, sizes)) @@ -691,6 +694,7 @@ def make_data( sizes["nworld"] = nworld sizes["naconmax"] = naconmax sizes["njmax"] = njmax + sizes["ntree"] = mjm.ntree contact = types.Contact(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Contact)}) efc = types.Constraint(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Constraint)}) @@ -728,6 +732,10 @@ def make_data( ), # equality constraints "eq_active": wp.array(np.tile(mjm.eq_active0.astype(bool), (nworld, 1)), shape=(nworld, mjm.neq), dtype=bool), + # sleep state: all trees start fully awake + "tree_asleep": wp.array(np.full((nworld, mjm.ntree), -(1 + types.MJ_MINAWAKE)), dtype=int), + "tree_awake": wp.array(np.ones((nworld, mjm.ntree)), dtype=int), + "body_awake": wp.array(np.ones((nworld, mjm.nbody)), dtype=int), } for f in dataclasses.fields(types.Data): if f.name in d_kwargs: @@ -805,6 +813,7 @@ def put_data( sizes["nworld"] = nworld sizes["naconmax"] = naconmax sizes["njmax"] = njmax + sizes["ntree"] = mjm.ntree # ensure static geom positions are computed # TODO: remove once MjData creation semantics are fixed @@ -882,6 +891,10 @@ def put_data( "ne_jnt": None, "ne_ten": None, "ne_flex": None, + # sleep state: all trees start fully awake + "tree_asleep": wp.array(np.full((nworld, mjm.ntree), -(1 + types.MJ_MINAWAKE)), dtype=int), + "tree_awake": wp.array(np.ones((nworld, mjm.ntree)), dtype=int), + "body_awake": wp.array(np.ones((nworld, mjm.nbody)), dtype=int), } for f in dataclasses.fields(types.Data): if f.name in d_kwargs: diff --git a/mujoco_warp/_src/io_test.py b/mujoco_warp/_src/io_test.py index 91f81611e..09f9134bc 100644 --- a/mujoco_warp/_src/io_test.py +++ b/mujoco_warp/_src/io_test.py @@ -645,6 +645,66 @@ def test_eq_active(self, active, make_data): _assert_eq(d.eq_active.numpy()[0], mjd.eq_active, "eq_active") + @parameterized.parameters(True, False) + def test_sleep_state_initial(self, use_make_data): + """Tests that make_data and put_data initialize all trees awake.""" + mjm = mujoco.MjModel.from_xml_string(""" + + + + + + + + + """) + mjd = mujoco.MjData(mjm) + + if use_make_data: + d = mjwarp.make_data(mjm) + else: + d = mjwarp.put_data(mjm, mjd) + + # All trees should be awake (tree_asleep < 0) + tree_asleep = d.tree_asleep.numpy() + self.assertTrue((tree_asleep < 0).all(), "tree_asleep should be < 0 (awake)") + # tree_awake should all be 1 + tree_awake = d.tree_awake.numpy() + np.testing.assert_array_equal(tree_awake, 1, "tree_awake should be 1") + # body_awake should all be 1 + body_awake = d.body_awake.numpy() + np.testing.assert_array_equal(body_awake, 1, "body_awake should be 1") + + def test_sleep_policy_import(self): + """Tests that tree_sleep_policy matches MuJoCo.""" + mjm = mujoco.MjModel.from_xml_string(""" + + + + + + + + + """) + m = mjwarp.put_model(mjm) + np.testing.assert_array_equal(m.tree_sleep_policy.numpy(), mjm.tree_sleep_policy) + + def test_dof_length_import(self): + """Tests that dof_length matches MuJoCo.""" + mjm = mujoco.MjModel.from_xml_string(""" + + + + + + + + + """) + m = mjwarp.put_model(mjm) + np.testing.assert_allclose(m.dof_length.numpy(), mjm.dof_length) + if __name__ == "__main__": wp.init() diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 1e1b61e8d..fca8d44e4 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -24,6 +24,9 @@ MJ_MAXIMP = mujoco.mjMAXIMP # maximum constraint impedance MJ_MAXCONPAIR = mujoco.mjMAXCONPAIR MJ_MINMU = mujoco.mjMINMU # minimum friction +# TODO(team): set with mujoco.mjMINAWAKE after mjwarp depends +# on mujoco > 3.4.0 in pyproject.toml +MJ_MINAWAKE = 10 # minimum number of timesteps before sleeping # maximum size (by number of edges) of an horizon in EPA algorithm MJ_MAX_EPAHORIZON = 24 # maximum average number of trianglarfaces EPA can insert at each iteration @@ -174,14 +177,50 @@ class EnableBit(enum.IntFlag): ENERGY: energy computation INVDISCRETE: discrete-time inverse dynamics MULTICCD: multiple contacts with CCD + SLEEP: sleeping """ ENERGY = mujoco.mjtEnableBit.mjENBL_ENERGY INVDISCRETE = mujoco.mjtEnableBit.mjENBL_INVDISCRETE MULTICCD = mujoco.mjtEnableBit.mjENBL_MULTICCD + SLEEP = mujoco.mjtEnableBit.mjENBL_SLEEP # unsupported: OVERRIDE, FWDINV, ISLAND +class SleepPolicy(enum.IntEnum): + """Per-tree sleep policy. + + Attributes: + AUTO: compiler chooses sleep policy + AUTO_NEVER: compiler sleep policy: never + AUTO_ALLOWED: compiler sleep policy: allowed + NEVER: user sleep policy: never + ALLOWED: user sleep policy: allowed + INIT: user sleep policy: initialized asleep + """ + + AUTO = mujoco.mjtSleepPolicy.mjSLEEP_AUTO + AUTO_NEVER = mujoco.mjtSleepPolicy.mjSLEEP_AUTO_NEVER + AUTO_ALLOWED = mujoco.mjtSleepPolicy.mjSLEEP_AUTO_ALLOWED + NEVER = mujoco.mjtSleepPolicy.mjSLEEP_NEVER + ALLOWED = mujoco.mjtSleepPolicy.mjSLEEP_ALLOWED + INIT = mujoco.mjtSleepPolicy.mjSLEEP_INIT + + +class SleepState(enum.IntEnum): + """Sleep state for bodies. + + Attributes: + ASLEEP: body is asleep + AWAKE: body is awake + STATIC: body is static (world body or mocap) + """ + + ASLEEP = 0 + AWAKE = 1 + STATIC = 2 + + class TrnType(enum.IntEnum): """Type of actuator transmission. @@ -647,6 +686,7 @@ class Option: tolerance: main solver tolerance ls_tolerance: CG/Newton linesearch tolerance ccd_tolerance: convex collision detection tolerance + sleep_tolerance: sleep velocity tolerance density: density of medium viscosity: viscosity of medium gravity: gravitational acceleration @@ -683,6 +723,7 @@ class Option: tolerance: array("*", float) ls_tolerance: array("*", float) ccd_tolerance: array("*", float) + sleep_tolerance: float density: array("*", float) viscosity: array("*", float) gravity: array("*", wp.vec3) @@ -835,6 +876,8 @@ class Model: dof_armature: dof armature inertia/mass (*, nv) dof_damping: damping coefficient (*, nv) dof_invweight0: diag. inverse inertia in qpos0 (*, nv) + dof_length: dof length for weighting velocity norm (nv,) + tree_sleep_policy: tree sleep policy (SleepPolicy) (ntree,) geom_type: geometric type (GeomType) (ngeom,) geom_contype: geom contact type (ngeom,) geom_conaffinity: geom contact affinity (ngeom,) @@ -1186,6 +1229,8 @@ class Model: dof_armature: array("*", "nv", float) dof_damping: array("*", "nv", float) dof_invweight0: array("*", "nv", float) + dof_length: array("nv", float) + tree_sleep_policy: array("ntree", int) geom_type: array("ngeom", int) geom_contype: array("ngeom", int) geom_conaffinity: array("ngeom", int) @@ -1562,6 +1607,9 @@ class Data: nf: number of friction constraints (nworld,) nl: number of limit constraints (nworld,) nefc: number of constraints (nworld,) + ntree_awake: number of awake trees (nworld,) + nbody_awake: number of awake bodies (nworld,) + nv_awake: number of awake dofs (nworld,) time: simulation time (nworld,) energy: potential, kinetic energy (nworld, 2) qpos: position (nworld, nq) @@ -1577,6 +1625,7 @@ class Data: qacc: acceleration (nworld, nv) act_dot: time-derivative of actuator activation (nworld, na) sensordata: sensor data array (nworld, nsensordata,) + tree_asleep: tree asleep counter; >=0: asleep cycle (nworld, ntree) xpos: Cartesian position of body frame (nworld, nbody, 3) xquat: Cartesian orientation of body frame (nworld, nbody, 4) xmat: Cartesian orientation of body frame (nworld, nbody, 3, 3) @@ -1612,6 +1661,10 @@ class Data: qLD: L'*D*L factorization of M (nworld, nv, nv) if dense (nworld, 1, nC) if sparse qLDiagInv: 1/diag(D) (nworld, nv) + tree_awake: is tree awake; 0: asleep; 1: awake (nworld, ntree) + body_awake: body sleep state (SleepState) (nworld, nbody) + body_awake_ind: indices of awake/static bodies (nworld, nbody) + dof_awake_ind: indices of awake dofs (nworld, nv) flexedge_velocity: flex edge velocities (nworld, nflexedge) ten_velocity: tendon velocities (nworld, ntendon) actuator_velocity: actuator velocities (nworld, nu) @@ -1660,6 +1713,9 @@ class Data: nf: array("nworld", int) nl: array("nworld", int) nefc: array("nworld", int) + ntree_awake: array("nworld", int) + nbody_awake: array("nworld", int) + nv_awake: array("nworld", int) time: array("nworld", float) energy: array("nworld", wp.vec2) qpos: array("nworld", "nq", float) @@ -1675,6 +1731,7 @@ class Data: qacc: array("nworld", "nv", float) act_dot: array("nworld", "na", float) sensordata: array("nworld", "nsensordata", float) + tree_asleep: array("nworld", "ntree", int) xpos: array("nworld", "nbody", wp.vec3) xquat: array("nworld", "nbody", wp.quat) xmat: array("nworld", "nbody", wp.mat33) @@ -1708,6 +1765,10 @@ class Data: qM: wp.array3d(dtype=float) qLD: wp.array3d(dtype=float) qLDiagInv: array("nworld", "nv", float) + tree_awake: array("nworld", "ntree", int) + body_awake: array("nworld", "nbody", int) + body_awake_ind: array("nworld", "nbody", int) + dof_awake_ind: array("nworld", "nv", int) flexedge_velocity: array("nworld", "nflexedge", float) ten_velocity: array("nworld", "ntendon", float) actuator_velocity: array("nworld", "nu", float) From ad4540132fd103d28db6c866755728a13610f1e1 Mon Sep 17 00:00:00 2001 From: Taylor Howell Date: Mon, 2 Feb 2026 14:46:17 +0000 Subject: [PATCH 2/2] ntree --- mujoco_warp/_src/io.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index d9287a671..cca424cd0 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -602,10 +602,7 @@ def geom_trid_index(i, j): m.flexedge_J_colind = mjd.flexedge_J_colind.reshape(-1) # place m on device - # TODO(team): remove ntree once field is added to types.Model - sizes = dict( - {"*": 1, "ntree": mjm.ntree}, **{f.name: getattr(m, f.name) for f in dataclasses.fields(types.Model) if f.type is int} - ) + sizes = dict({"*": 1}, **{f.name: getattr(m, f.name) for f in dataclasses.fields(types.Model) if f.type is int}) for f in dataclasses.fields(types.Model): if isinstance(f.type, wp.array): setattr(m, f.name, _create_array(getattr(m, f.name), f.type, sizes)) @@ -702,7 +699,6 @@ def make_data( sizes["nworld"] = nworld sizes["naconmax"] = naconmax sizes["njmax"] = njmax - sizes["ntree"] = mjm.ntree contact = types.Contact(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Contact)}) efc = types.Constraint(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Constraint)}) @@ -825,7 +821,6 @@ def put_data( sizes["nworld"] = nworld sizes["naconmax"] = naconmax sizes["njmax"] = njmax - sizes["ntree"] = mjm.ntree # ensure static geom positions are computed # TODO: remove once MjData creation semantics are fixed