Skip to content

Commit

Permalink
Add mujoco.MjSpec.MjStruct type.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723552214
Change-Id: I8ecb43b4025b532ba8b48e15c9c6a2b350293d20
  • Loading branch information
quagla authored and copybara-github committed Feb 5, 2025
1 parent fe16c27 commit dec7b66
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 9 deletions.
16 changes: 11 additions & 5 deletions mjx/mujoco/mjx/_src/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
"""Engine support functions."""
from collections.abc import Sequence
from typing import Optional, Tuple, Union, Any
from typing import Optional, Tuple, Union

import jax
from jax import numpy as jp
Expand Down Expand Up @@ -289,7 +289,7 @@ def name2id(
class BindModel(object):
"""Class holding the requested MJX Model and spec id for binding a spec to Model."""

def __init__(self, model: Model, specs: Sequence[Any]):
def __init__(self, model: Model, specs: Sequence[mujoco.MjStruct]):
self.model = model
try:
iter(specs)
Expand Down Expand Up @@ -371,15 +371,19 @@ def __getattr__(self, name: str):
return getattr(self.model, self.prefix + name)[self.id, ...]


def _bind_model(self: Model, obj: Sequence[Any]) -> BindModel:
def _bind_model(
self: Model, obj: Sequence[mujoco.MjStruct]
) -> BindModel:
"""Bind a Mujoco spec to an MJX Model."""
return BindModel(self, obj)


class BindData(object):
"""Class holding the requested MJX Data and spec id for binding a spec to Data."""

def __init__(self, data: Data, model: Model, specs: Sequence[Any]):
def __init__(
self, data: Data, model: Model, specs: Sequence[mujoco.MjStruct]
):
self.data = data
self.model = model
try:
Expand Down Expand Up @@ -499,7 +503,9 @@ def set(self, name: str, value: jax.Array) -> Data:
return self.data.replace(**{self.__getname(name): array})


def _bind_data(self: Data, model: Model, obj: Sequence[Any]) -> BindData:
def _bind_data(
self: Data, model: Model, obj: Sequence[mujoco.MjStruct]
) -> BindData:
"""Bind a Mujoco spec to an MJX Data."""
return BindData(self, model, obj)

Expand Down
35 changes: 31 additions & 4 deletions python/mujoco/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import os
import platform
import subprocess
from typing import Union, IO
from typing import IO, Union
from typing_extensions import TypeAlias
import warnings
import zipfile

Expand Down Expand Up @@ -48,20 +49,46 @@
'machine. This is not supported by MuJoCo. Please install and run a '
'native, arm64 build of Python.')

from mujoco import _specs
from mujoco._callbacks import *
from mujoco._constants import *
from mujoco._enums import *
from mujoco._errors import *
from mujoco._functions import *
from mujoco._render import *
from mujoco._specs import *
from mujoco._specs import MjSpec
from mujoco._structs import *
from mujoco.gl_context import *
from mujoco.renderer import Renderer

MjStruct: TypeAlias = Union[
_specs.MjsBody,
_specs.MjsFrame,
_specs.MjsGeom,
_specs.MjsJoint,
_specs.MjsLight,
_specs.MjsMaterial,
_specs.MjsSite,
_specs.MjsMesh,
_specs.MjsSkin,
_specs.MjsTexture,
_specs.MjsText,
_specs.MjsTuple,
_specs.MjsCamera,
_specs.MjsFlex,
_specs.MjsHField,
_specs.MjsKey,
_specs.MjsNumeric,
_specs.MjsPair,
_specs.MjsExclude,
_specs.MjsEquality,
_specs.MjsTendon,
_specs.MjsSensor,
_specs.MjsActuator,
_specs.MjsPlugin,
]

def to_zip(spec: MjSpec, file: Union[str, IO[bytes]]) -> None:
def to_zip(spec: _specs.MjSpec, file: Union[str, IO[bytes]]) -> None:
"""Converts a spec to a zip file.
Args:
Expand All @@ -79,7 +106,7 @@ def to_zip(spec: MjSpec, file: Union[str, IO[bytes]]) -> None:
zip_info = zipfile.ZipInfo(os.path.join(spec.modelname, filename))
zip_file.writestr(zip_info, contents)

MjSpec.to_zip = to_zip
_specs.MjSpec.to_zip = to_zip

HEADERS_DIR = os.path.join(os.path.dirname(__file__), 'include/mujoco')
PLUGINS_DIR = os.path.join(os.path.dirname(__file__), 'plugin')
Expand Down
7 changes: 7 additions & 0 deletions python/mujoco/specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import inspect
import os
import textwrap
import typing
import zipfile

from absl.testing import absltest
Expand All @@ -32,6 +33,12 @@ def get_linenumber():

class SpecsTest(absltest.TestCase):

def test_typing(self):
spec = mujoco.MjSpec()
self.assertIsInstance(spec, mujoco.MjSpec)
self.assertIsInstance(spec.worldbody, mujoco.MjsBody)
self.assertIsInstance(spec.worldbody, typing.get_args(mujoco.MjStruct))

def test_basic(self):
# Create a spec.
spec = mujoco.MjSpec()
Expand Down

0 comments on commit dec7b66

Please sign in to comment.