From 02986a8d04d9033838421975c7c3714ac009db51 Mon Sep 17 00:00:00 2001 From: Guy Azran Date: Sat, 5 Aug 2023 23:00:09 +0300 Subject: [PATCH] PyMJCF nested include tags relative to base model --- dm_control/mjcf/parser.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/dm_control/mjcf/parser.py b/dm_control/mjcf/parser.py index c710f179..b9592060 100644 --- a/dm_control/mjcf/parser.py +++ b/dm_control/mjcf/parser.py @@ -80,7 +80,7 @@ def from_file(file_handle, escape_separators=False, def from_path(path, escape_separators=False, resolve_references=True, - assets=None): + assets=None, model_dir=None): """Parses an XML file into an MJCF object model. Args: @@ -94,11 +94,14 @@ def from_path(path, escape_separators=False, resolve_references=True, assets: (optional) A dictionary of pre-loaded assets, of the form `{filename: bytestring}`. If present, PyMJCF will search for assets in this dictionary before attempting to load them from the filesystem. + model_dir: (optional) Path to the directory containing the model XML file. + This is used to prefix the paths of all asset files. Returns: An `mjcf.RootElement`. """ - model_dir, _ = os.path.split(path) + if model_dir is None: + model_dir, _ = os.path.split(path) contents = resources.GetResource(path) xml_root = etree.fromstring(contents) return _parse(xml_root, escape_separators, @@ -153,7 +156,7 @@ def _parse(xml_root, escape_separators=False, path_or_xml_string, escape_separators=escape_separators, resolve_references=resolve_references, - assets=assets) + assets=assets, model_dir=model_dir) to_include.append(included_mjcf) # We must remove tags before parsing the main XML file, since # these are a schema violation.