Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mujoco_warp/_src/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,10 @@ def make_data(
),
# equality constraints
"eq_active": wp.array(np.tile(mjm.eq_active0.astype(bool), (nworld, 1)), shape=(nworld, mjm.neq), dtype=bool),
# sleep state: all trees start fully awake
"tree_asleep": wp.array(np.full((nworld, mjm.ntree), -(1 + types.MJ_MINAWAKE)), dtype=int),
"tree_awake": wp.array(np.ones((nworld, mjm.ntree)), dtype=int),
"body_awake": wp.array(np.ones((nworld, mjm.nbody)), dtype=int),
# flexedge
"flexedge_J": None,
}
Expand Down Expand Up @@ -889,6 +893,10 @@ def put_data(
"actuator_moment": None,
"flexedge_J": None,
"nacon": None,
# sleep state: all trees start fully awake
"tree_asleep": wp.array(np.full((nworld, mjm.ntree), -(1 + types.MJ_MINAWAKE)), dtype=int),
"tree_awake": wp.array(np.ones((nworld, mjm.ntree)), dtype=int),
"body_awake": wp.array(np.ones((nworld, mjm.nbody)), dtype=int),
}
for f in dataclasses.fields(types.Data):
if f.name in d_kwargs:
Expand Down
60 changes: 60 additions & 0 deletions mujoco_warp/_src/io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,66 @@ def test_eq_active(self, active, make_data):

_assert_eq(d.eq_active.numpy()[0], mjd.eq_active, "eq_active")

@parameterized.parameters(True, False)
def test_sleep_state_initial(self, use_make_data):
"""Tests that make_data and put_data initialize all trees awake."""
mjm = mujoco.MjModel.from_xml_string("""
<mujoco>
<worldbody>
<body>
<joint/>
<geom size=".1"/>
</body>
</worldbody>
</mujoco>
""")
mjd = mujoco.MjData(mjm)

if use_make_data:
d = mjwarp.make_data(mjm)
else:
d = mjwarp.put_data(mjm, mjd)

# All trees should be awake (tree_asleep < 0)
tree_asleep = d.tree_asleep.numpy()
self.assertTrue((tree_asleep < 0).all(), "tree_asleep should be < 0 (awake)")
# tree_awake should all be 1
tree_awake = d.tree_awake.numpy()
np.testing.assert_array_equal(tree_awake, 1, "tree_awake should be 1")
# body_awake should all be 1
body_awake = d.body_awake.numpy()
np.testing.assert_array_equal(body_awake, 1, "body_awake should be 1")

def test_sleep_policy_import(self):
"""Tests that tree_sleep_policy matches MuJoCo."""
mjm = mujoco.MjModel.from_xml_string("""
<mujoco>
<worldbody>
<body>
<joint/>
<geom size=".1"/>
</body>
</worldbody>
</mujoco>
""")
m = mjwarp.put_model(mjm)
np.testing.assert_array_equal(m.tree_sleep_policy.numpy(), mjm.tree_sleep_policy)

def test_dof_length_import(self):
"""Tests that dof_length matches MuJoCo."""
mjm = mujoco.MjModel.from_xml_string("""
<mujoco>
<worldbody>
<body>
<joint/>
<geom size=".1"/>
</body>
</worldbody>
</mujoco>
""")
m = mjwarp.put_model(mjm)
np.testing.assert_allclose(m.dof_length.numpy(), mjm.dof_length)

def test_tree_structure_fields(self):
"""Tests that tree structure fields match between types.Model and mjModel."""
mjm, _, m, _ = test_data.fixture("pendula.xml")
Expand Down
61 changes: 61 additions & 0 deletions mujoco_warp/_src/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
MJ_MAXIMP = mujoco.mjMAXIMP # maximum constraint impedance
MJ_MAXCONPAIR = mujoco.mjMAXCONPAIR
MJ_MINMU = mujoco.mjMINMU # minimum friction
# TODO(team): set with mujoco.mjMINAWAKE after mjwarp depends
# on mujoco > 3.4.0 in pyproject.toml
MJ_MINAWAKE = 10 # minimum number of timesteps before sleeping
# maximum size (by number of edges) of an horizon in EPA algorithm
MJ_MAX_EPAHORIZON = 24
# maximum average number of trianglarfaces EPA can insert at each iteration
Expand Down Expand Up @@ -175,14 +178,50 @@ class EnableBit(enum.IntFlag):
ENERGY: energy computation
INVDISCRETE: discrete-time inverse dynamics
MULTICCD: multiple contacts with CCD
SLEEP: sleeping
"""

ENERGY = mujoco.mjtEnableBit.mjENBL_ENERGY
INVDISCRETE = mujoco.mjtEnableBit.mjENBL_INVDISCRETE
MULTICCD = mujoco.mjtEnableBit.mjENBL_MULTICCD
SLEEP = mujoco.mjtEnableBit.mjENBL_SLEEP
# unsupported: OVERRIDE, FWDINV, ISLAND


class SleepPolicy(enum.IntEnum):
"""Per-tree sleep policy.

Attributes:
AUTO: compiler chooses sleep policy
AUTO_NEVER: compiler sleep policy: never
AUTO_ALLOWED: compiler sleep policy: allowed
NEVER: user sleep policy: never
ALLOWED: user sleep policy: allowed
INIT: user sleep policy: initialized asleep
"""

AUTO = mujoco.mjtSleepPolicy.mjSLEEP_AUTO
AUTO_NEVER = mujoco.mjtSleepPolicy.mjSLEEP_AUTO_NEVER
AUTO_ALLOWED = mujoco.mjtSleepPolicy.mjSLEEP_AUTO_ALLOWED
NEVER = mujoco.mjtSleepPolicy.mjSLEEP_NEVER
ALLOWED = mujoco.mjtSleepPolicy.mjSLEEP_ALLOWED
INIT = mujoco.mjtSleepPolicy.mjSLEEP_INIT


class SleepState(enum.IntEnum):
"""Sleep state for bodies.

Attributes:
ASLEEP: body is asleep
AWAKE: body is awake
STATIC: body is static (world body or mocap)
"""

ASLEEP = 0
AWAKE = 1
STATIC = 2


class TrnType(enum.IntEnum):
"""Type of actuator transmission.

Expand Down Expand Up @@ -662,6 +701,7 @@ class Option:
tolerance: main solver tolerance
ls_tolerance: CG/Newton linesearch tolerance
ccd_tolerance: convex collision detection tolerance
sleep_tolerance: sleep velocity tolerance
density: density of medium
viscosity: viscosity of medium
gravity: gravitational acceleration
Expand Down Expand Up @@ -698,6 +738,7 @@ class Option:
tolerance: array("*", float)
ls_tolerance: array("*", float)
ccd_tolerance: array("*", float)
sleep_tolerance: float
density: array("*", float)
viscosity: array("*", float)
gravity: array("*", wp.vec3)
Expand Down Expand Up @@ -853,9 +894,11 @@ class Model:
dof_armature: dof armature inertia/mass (*, nv)
dof_damping: damping coefficient (*, nv)
dof_invweight0: diag. inverse inertia in qpos0 (*, nv)
dof_length: dof length for weighting velocity norm (nv,)
tree_bodynum: number of bodies in tree (incl. root) (ntree,)
tree_dofadr: start address of tree's dofs (ntree,)
tree_dofnum: number of dofs in tree (ntree,)
tree_sleep_policy: tree sleep policy (SleepPolicy) (ntree,)
geom_type: geometric type (GeomType) (ngeom,)
geom_contype: geom contact type (ngeom,)
geom_conaffinity: geom contact affinity (ngeom,)
Expand Down Expand Up @@ -1213,9 +1256,11 @@ class Model:
dof_armature: array("*", "nv", float)
dof_damping: array("*", "nv", float)
dof_invweight0: array("*", "nv", float)
dof_length: array("nv", float)
tree_bodynum: array("ntree", int)
tree_dofadr: array("ntree", int)
tree_dofnum: array("ntree", int)
tree_sleep_policy: array("ntree", int)
geom_type: array("ngeom", int)
geom_contype: array("ngeom", int)
geom_conaffinity: array("ngeom", int)
Expand Down Expand Up @@ -1561,6 +1606,9 @@ class Data:
nf: number of friction constraints (nworld,)
nl: number of limit constraints (nworld,)
nefc: number of constraints (nworld,)
ntree_awake: number of awake trees (nworld,)
nbody_awake: number of awake bodies (nworld,)
nv_awake: number of awake dofs (nworld,)
time: simulation time (nworld,)
energy: potential, kinetic energy (nworld, 2)
qpos: position (nworld, nq)
Expand All @@ -1576,6 +1624,7 @@ class Data:
qacc: acceleration (nworld, nv)
act_dot: time-derivative of actuator activation (nworld, na)
sensordata: sensor data array (nworld, nsensordata,)
tree_asleep: tree asleep counter; >=0: asleep cycle (nworld, ntree)
xpos: Cartesian position of body frame (nworld, nbody, 3)
xquat: Cartesian orientation of body frame (nworld, nbody, 4)
xmat: Cartesian orientation of body frame (nworld, nbody, 3, 3)
Expand Down Expand Up @@ -1611,6 +1660,10 @@ class Data:
qLD: L'*D*L factorization of M (nworld, nv, nv) if dense
(nworld, 1, nC) if sparse
qLDiagInv: 1/diag(D) (nworld, nv)
tree_awake: is tree awake; 0: asleep; 1: awake (nworld, ntree)
body_awake: body sleep state (SleepState) (nworld, nbody)
body_awake_ind: indices of awake/static bodies (nworld, nbody)
dof_awake_ind: indices of awake dofs (nworld, nv)
flexedge_velocity: flex edge velocities (nworld, nflexedge)
ten_velocity: tendon velocities (nworld, ntendon)
actuator_velocity: actuator velocities (nworld, nu)
Expand Down Expand Up @@ -1654,6 +1707,9 @@ class Data:
nf: array("nworld", int)
nl: array("nworld", int)
nefc: array("nworld", int)
ntree_awake: array("nworld", int)
nbody_awake: array("nworld", int)
nv_awake: array("nworld", int)
time: array("nworld", float)
energy: array("nworld", wp.vec2)
qpos: array("nworld", "nq", float)
Expand All @@ -1669,6 +1725,7 @@ class Data:
qacc: array("nworld", "nv", float)
act_dot: array("nworld", "na", float)
sensordata: array("nworld", "nsensordata", float)
tree_asleep: array("nworld", "ntree", int)
xpos: array("nworld", "nbody", wp.vec3)
xquat: array("nworld", "nbody", wp.quat)
xmat: array("nworld", "nbody", wp.mat33)
Expand Down Expand Up @@ -1702,6 +1759,10 @@ class Data:
qM: wp.array3d(dtype=float)
qLD: wp.array3d(dtype=float)
qLDiagInv: array("nworld", "nv", float)
tree_awake: array("nworld", "ntree", int)
body_awake: array("nworld", "nbody", int)
body_awake_ind: array("nworld", "nbody", int)
dof_awake_ind: array("nworld", "nv", int)
flexedge_velocity: array("nworld", "nflexedge", float)
ten_velocity: array("nworld", "ntendon", float)
actuator_velocity: array("nworld", "nu", float)
Expand Down
Loading