Source code for mbrs.args

  1from __future__ import annotations
  2
  3import argparse
  4import importlib
  5import logging
  6import os
  7import sys
  8from argparse import Namespace
  9from pathlib import Path
 10from typing import Sequence
 11
 12import simple_parsing
 13from simple_parsing.wrappers.field_wrapper import ArgumentGenerationMode, FieldWrapper
 14
 15logger = logging.getLogger(__name__)
 16
 17
class ArgumentParser(simple_parsing.ArgumentParser):
    """``simple_parsing.ArgumentParser`` with config-file and plugin loading.

    When ``add_config_path_arg`` is enabled, every parse first consumes
    ``--config_path`` (files whose values are applied with ``set_defaults``)
    and ``--plugin_dir`` (a user-defined package to import) using a temporary
    parser. The remaining arguments are then parsed as usual.
    """

    # Absolute paths of plugin directories that have already been imported.
    # This is a class attribute, so it is shared by every parser instance.
    IMPORTED_MODULES: set[Path] = set()

    def add_plugin_argumnets(self, parser: ArgumentParser) -> None:
        """Add arguments for plugins.

        The method name keeps its historical spelling for backward
        compatibility. If ``parser`` already has a ``--plugin_dir`` option,
        nothing is added. Preprocessing may run more than once on the same
        parser, and with argparse's default conflict handler a duplicate
        option string raises an error.

        Args:
            parser (ArgumentParser): Argument parser.
        """
        if "--plugin_dir" in parser._option_string_actions:
            return
        parser.add_argument(
            "--plugin_dir",
            type=Path,
            default=None,
            help="Path to a directory containing user defined plugins.",
        )

    def import_plugin(self, plugin_dir: Path) -> None:
        """Import plugin modules.

        The directory is imported as a top-level module named after its
        basename, after prepending its parent directory to ``sys.path``.

        Args:
            plugin_dir (pathlib.Path): A directory containing user defined plugins.

        Raises:
            FileNotFoundError: If ``plugin_dir`` does not exist.
            NotADirectoryError: If ``plugin_dir`` exists but is not a directory.
            ImportError: If a module with the same name as the plugin directory
                is already imported but was not loaded from ``plugin_dir``.
        """
        plugin_dir = plugin_dir.absolute()
        if not plugin_dir.exists():
            raise FileNotFoundError(plugin_dir)
        # The plugin itself (not its parent) must be a directory.
        if not plugin_dir.is_dir():
            raise NotADirectoryError(plugin_dir)

        if plugin_dir in ArgumentParser.IMPORTED_MODULES:
            return

        module_parent, module_name = os.path.split(plugin_dir)
        if module_name not in sys.modules:
            sys.path.insert(0, module_parent)
            importlib.import_module(module_name)
        elif self._is_package_loaded_from(sys.modules[module_name], plugin_dir):
            logger.info(f"--plugin_dir={plugin_dir} has already been imported.")
        else:
            raise ImportError(
                f"Failed to import --plugin_dir={plugin_dir} because the module name "
                f"({module_name}) is not globally unique."
            )
        ArgumentParser.IMPORTED_MODULES.add(plugin_dir)

    @staticmethod
    def _is_package_loaded_from(module: object, directory: Path) -> bool:
        """Return whether ``module``'s ``__path__`` contains ``directory``."""
        # ``__path__`` holds ``str`` entries, so compare normalized strings
        # rather than the ``Path`` object. Plain (non-package) modules have no
        # ``__path__`` and therefore never match.
        target = os.path.abspath(directory)
        return any(
            os.path.abspath(entry) == target
            for entry in getattr(module, "__path__", ())
        )

    def preprocess_parser(self) -> None:
        """Preprocess ArgumentParser using the process arguments (``sys.argv[1:]``)."""
        self.parse_known_args_preprocess(sys.argv[1:])

    def parse_known_args_preprocess(
        self,
        args: Sequence[str] | None = None,
        namespace: Namespace | None = None,
        attempt_to_reorder: bool = False,
    ) -> None:
        """Load config-file defaults and plugins, then run ``_preprocessing``.

        Args:
            args: Command-line arguments; ``None`` means ``sys.argv[1:]``.
                ``--config_path`` and ``--plugin_dir`` are consumed from a
                local copy, so the caller's sequence is not modified.
            namespace: Namespace handed to ``_preprocessing``. A fresh
                ``Namespace`` is created when ``None``.
            attempt_to_reorder: Unused here. It is accepted so the signature
                matches ``parse_known_args``.
        """
        # Normalize to a private list, matching ``parse_known_args``.
        args = sys.argv[1:] if args is None else list(args)

        # default Namespace built from parser defaults
        if namespace is None:
            namespace = Namespace()

        # Config files given to the constructor supply the initial defaults.
        if self.config_path:
            if isinstance(self.config_path, Path):
                config_paths = [self.config_path]
            else:
                config_paths = self.config_path
            for config_file in config_paths:
                self.set_defaults(config_file)

        if self.add_config_path_arg:
            # A temporary parser that only knows --config_path / --plugin_dir,
            # used to consume them before the real arguments are parsed.
            temp_parser = ArgumentParser(
                add_config_path_arg=False,
                add_help=False,
                add_option_string_dash_variants=FieldWrapper.add_dash_variants,
                argument_generation_mode=FieldWrapper.argument_generation_mode,
                nested_mode=FieldWrapper.nested_mode,
            )
            temp_parser.add_argument(
                "--config_path",
                type=Path,
                default=self.config_path,
                help="Path to a config file containing default values to use.",
            )
            args_with_config_path, args = temp_parser.parse_known_args(args)
            config_path = args_with_config_path.config_path

            if config_path is not None:
                config_paths = (
                    config_path if isinstance(config_path, list) else [config_path]
                )
                for config_file in config_paths:
                    self.set_defaults(config_file)

            # Adding it here just so it shows up in the help message. The
            # default will be set in the help string. The guard keeps
            # repeated preprocessing from registering the option twice.
            if self._option_string_actions.get("--config_path", None) is None:
                self.add_argument(
                    "--config_path",
                    type=Path,
                    default=config_path,
                    help="Path to a config file containing default values to use.",
                )

            # Plugin loader: import the plugin package before the remaining
            # arguments are parsed.
            self.add_plugin_argumnets(temp_parser)
            args_with_plugin_dir, args = temp_parser.parse_known_args(args)
            plugin_dir: Path | None = args_with_plugin_dir.plugin_dir
            if plugin_dir is not None:
                self.import_plugin(plugin_dir)
            # Also expose --plugin_dir on this parser. This does nothing if it
            # is already registered.
            self.add_plugin_argumnets(self)

        self._preprocessing(args=args, namespace=namespace)

    def parse_known_args(
        self,
        args: Sequence[str] | None = None,
        namespace: Namespace | None = None,
        attempt_to_reorder: bool = False,
    ):
        """Parse known arguments after config-file/plugin preprocessing.

        Args:
            args: Arguments to parse; defaults to ``sys.argv[1:]``.
            namespace: Namespace to populate.
            attempt_to_reorder: If there are unparsed arguments and the parser
                has subparsers, retry once with the argument list rotated so
                that the first unparsed argument comes first.

        Returns:
            A ``(parsed_args, unparsed_args)`` tuple. ``parsed_args`` has been
            passed through ``_postprocessing``.
        """
        # NOTE: since the usual ArgumentParser.parse_args() calls
        # parse_known_args, we therefore just need to overload the
        # parse_known_args method to support both.
        if args is None:
            # args default to the system args
            args = sys.argv[1:]
        else:
            # make sure that args are mutable
            args = list(args)

        self.parse_known_args_preprocess(
            args=args, namespace=namespace, attempt_to_reorder=attempt_to_reorder
        )
        logger.debug(
            f"Parser {id(self)} is parsing args: {args}, namespace: {namespace}"
        )
        # Skip simple_parsing's own parse_known_args (its preprocessing was
        # already done above) and call the next implementation in the MRO.
        parsed_args, unparsed_args = super(
            simple_parsing.ArgumentParser, self
        ).parse_known_args(args, namespace)

        if unparsed_args and self._subparsers and attempt_to_reorder:
            logger.warning(
                f"Unparsed arguments when using subparsers. Will "
                f"attempt to automatically re-order the unparsed arguments "
                f"{unparsed_args}."
            )
            index_in_start = args.index(unparsed_args[0])
            # Simply 'cycle' the args to the right ordering.
            new_start_args = args[index_in_start:] + args[:index_in_start]
            parsed_args, unparsed_args = super(
                simple_parsing.ArgumentParser, self
            ).parse_known_args(new_start_args)

        parsed_args = self._postprocessing(parsed_args)
        return parsed_args, unparsed_args
class DataclassWrapper(simple_parsing.wrappers.DataclassWrapper):
    """Dataclass wrapper that uses FLAT argument generation for ``common``.

    The dataclass registered under the name ``"common"`` is added with
    ``ArgumentGenerationMode.FLAT``. Every other dataclass uses
    ``ArgumentGenerationMode.NESTED``.
    """

    def add_arguments(self, parser: argparse.ArgumentParser) -> None:
        """Add this dataclass's arguments to ``parser``.

        ``FieldWrapper.argument_generation_mode`` is a class attribute shared
        by all wrappers. It is switched for the duration of this call and
        reset to ``NESTED`` afterwards, even if adding the arguments raises.

        Args:
            parser (argparse.ArgumentParser): Parser to add the arguments to.
        """
        FieldWrapper.argument_generation_mode = (
            ArgumentGenerationMode.FLAT
            if self._name == "common"
            else ArgumentGenerationMode.NESTED
        )
        try:
            super().add_arguments(parser)
        finally:
            # Do not leak FLAT mode into later wrappers if super() fails.
            FieldWrapper.argument_generation_mode = ArgumentGenerationMode.NESTED