From dec7b661c14074768259e770a8425bf84c93eb1b Mon Sep 17 00:00:00 2001 From: Alessio Quaglino Date: Wed, 5 Feb 2025 10:00:40 -0800 Subject: [PATCH] Add mujoco.MjSpec.MjStruct type. PiperOrigin-RevId: 723552214 Change-Id: I8ecb43b4025b532ba8b48e15c9c6a2b350293d20 --- mjx/mujoco/mjx/_src/support.py | 16 +++++++++++----- python/mujoco/__init__.py | 35 ++++++++++++++++++++++++++++++---- python/mujoco/specs_test.py | 7 +++++++ 3 files changed, 49 insertions(+), 9 deletions(-) diff --git a/mjx/mujoco/mjx/_src/support.py b/mjx/mujoco/mjx/_src/support.py index 09fc536505..9556c930c7 100644 --- a/mjx/mujoco/mjx/_src/support.py +++ b/mjx/mujoco/mjx/_src/support.py @@ -14,7 +14,7 @@ # ============================================================================== """Engine support functions.""" from collections.abc import Sequence -from typing import Optional, Tuple, Union, Any +from typing import Optional, Tuple, Union import jax from jax import numpy as jp @@ -289,7 +289,7 @@ def name2id( class BindModel(object): """Class holding the requested MJX Model and spec id for binding a spec to Model.""" - def __init__(self, model: Model, specs: Sequence[Any]): + def __init__(self, model: Model, specs: Sequence[mujoco.MjStruct]): self.model = model try: iter(specs) @@ -371,7 +371,9 @@ def __getattr__(self, name: str): return getattr(self.model, self.prefix + name)[self.id, ...] -def _bind_model(self: Model, obj: Sequence[Any]) -> BindModel: +def _bind_model( + self: Model, obj: Sequence[mujoco.MjStruct] +) -> BindModel: """Bind a Mujoco spec to an MJX Model.""" return BindModel(self, obj) @@ -379,7 +381,9 @@ def _bind_model(self: Model, obj: Sequence[Any]) -> BindModel: class BindData(object): """Class holding the requested MJX Data and spec id for binding a spec to Data.""" - def __init__(self, data: Data, model: Model, specs: Sequence[Any]): + def __init__( + self, data: Data, model: Model, specs: Sequence[mujoco.MjStruct] + ): self.data = data self.model = model try: @@ -499,7 +503,9 @@ def set(self, name: str, value: jax.Array) -> Data: return self.data.replace(**{self.__getname(name): array}) -def _bind_data(self: Data, model: Model, obj: Sequence[Any]) -> BindData: +def _bind_data( + self: Data, model: Model, obj: Sequence[mujoco.MjStruct] +) -> BindData: """Bind a Mujoco spec to an MJX Data.""" return BindData(self, model, obj) diff --git a/python/mujoco/__init__.py b/python/mujoco/__init__.py index 1637a2e636..e8376a3487 100644 --- a/python/mujoco/__init__.py +++ b/python/mujoco/__init__.py @@ -19,7 +19,8 @@ import os import platform import subprocess -from typing import Union, IO +from typing import IO, Union +from typing_extensions import TypeAlias import warnings import zipfile @@ -48,6 +49,7 @@ 'machine. This is not supported by MuJoCo. Please install and run a ' 'native, arm64 build of Python.') +from mujoco import _specs from mujoco._callbacks import * from mujoco._constants import * from mujoco._enums import * @@ -55,13 +57,38 @@ from mujoco._functions import * from mujoco._render import * from mujoco._specs import * -from mujoco._specs import MjSpec from mujoco._structs import * from mujoco.gl_context import * from mujoco.renderer import Renderer +MjStruct: TypeAlias = Union[ + _specs.MjsBody, + _specs.MjsFrame, + _specs.MjsGeom, + _specs.MjsJoint, + _specs.MjsLight, + _specs.MjsMaterial, + _specs.MjsSite, + _specs.MjsMesh, + _specs.MjsSkin, + _specs.MjsTexture, + _specs.MjsText, + _specs.MjsTuple, + _specs.MjsCamera, + _specs.MjsFlex, + _specs.MjsHField, + _specs.MjsKey, + _specs.MjsNumeric, + _specs.MjsPair, + _specs.MjsExclude, + _specs.MjsEquality, + _specs.MjsTendon, + _specs.MjsSensor, + _specs.MjsActuator, + _specs.MjsPlugin, +] -def to_zip(spec: MjSpec, file: Union[str, IO[bytes]]) -> None: +def to_zip(spec: _specs.MjSpec, file: Union[str, IO[bytes]]) -> None: """Converts a spec to a zip file. Args: @@ -79,7 +106,7 @@ def to_zip(spec: MjSpec, file: Union[str, IO[bytes]]) -> None: zip_info = zipfile.ZipInfo(os.path.join(spec.modelname, filename)) zip_file.writestr(zip_info, contents) -MjSpec.to_zip = to_zip +_specs.MjSpec.to_zip = to_zip HEADERS_DIR = os.path.join(os.path.dirname(__file__), 'include/mujoco') PLUGINS_DIR = os.path.join(os.path.dirname(__file__), 'plugin') diff --git a/python/mujoco/specs_test.py b/python/mujoco/specs_test.py index 0cdf51a032..2806d01a5a 100644 --- a/python/mujoco/specs_test.py +++ b/python/mujoco/specs_test.py @@ -17,6 +17,7 @@ import inspect import os import textwrap +import typing import zipfile from absl.testing import absltest @@ -32,6 +33,12 @@ def get_linenumber(): class SpecsTest(absltest.TestCase): + def test_typing(self): + spec = mujoco.MjSpec() + self.assertIsInstance(spec, mujoco.MjSpec) + self.assertIsInstance(spec.worldbody, mujoco.MjsBody) + self.assertIsInstance(spec.worldbody, typing.get_args(mujoco.MjStruct)) + def test_basic(self): # Create a spec. spec = mujoco.MjSpec()