from __future__ import annotations

import argparse
import importlib
import logging
import os
import sys
from argparse import Namespace
from pathlib import Path
from typing import ClassVar, Sequence

import simple_parsing
from simple_parsing.wrappers.field_wrapper import ArgumentGenerationMode, FieldWrapper
14
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
16
17
class ArgumentParser(simple_parsing.ArgumentParser):
    """Project argument parser built on :class:`simple_parsing.ArgumentParser`.

    Before the actual parse, this parser performs two extra steps:

    * It applies default values from ``self.config_path``. When
      ``self.add_config_path_arg`` is set, it also applies defaults from a
      ``--config_path`` command-line option.
    * When ``self.add_config_path_arg`` is set, it imports the package given by
      ``--plugin_dir`` (see :meth:`import_plugin`).

    Only after both steps does it call ``self._preprocessing``.
    """

    # Absolute plugin directories already imported through this class. The set
    # is shared by every instance, so each directory is imported at most once
    # per process.
    IMPORTED_MODULES: ClassVar[set[Path]] = set()

    def add_plugin_argumnets(self, parser: ArgumentParser) -> None:
        """Add the ``--plugin_dir`` option to ``parser``.

        The method name keeps its historical misspelling for backward
        compatibility. :meth:`add_plugin_arguments` is a correctly spelled alias.

        Args:
            parser (ArgumentParser): Parser that receives the option. This is
                either ``self`` or a temporary pre-parser.
        """
        parser.add_argument(
            "--plugin_dir",
            type=Path,
            default=None,
            help="Path to a directory containing user defined plugins.",
        )

    def add_plugin_arguments(self, parser: ArgumentParser) -> None:
        """Correctly spelled alias of :meth:`add_plugin_argumnets`.

        Args:
            parser (ArgumentParser): Parser that receives ``--plugin_dir``.
        """
        self.add_plugin_argumnets(parser)

    def import_plugin(self, plugin_dir: Path) -> None:
        """Import a plugin directory as a top-level Python package.

        The parent of ``plugin_dir`` is prepended to ``sys.path``, and the
        directory's base name is imported with :func:`importlib.import_module`.
        Each absolute directory is imported at most once per process.

        Args:
            plugin_dir (pathlib.Path): A directory containing user defined
                plugins. A ``str`` path is accepted as well.

        Raises:
            FileNotFoundError: If ``plugin_dir`` does not exist.
            NotADirectoryError: If ``plugin_dir`` exists but is not a directory.
            ImportError: If a different module with the same name as the
                directory has already been imported.
        """
        # os.path.abspath also collapses ".." and trailing separators. A path
        # such as "plugins/sub/.." would otherwise yield the module name "..".
        plugin_dir = Path(os.path.abspath(plugin_dir))
        if not plugin_dir.exists():
            raise FileNotFoundError(plugin_dir)
        if not plugin_dir.is_dir():
            raise NotADirectoryError(plugin_dir)

        if plugin_dir in ArgumentParser.IMPORTED_MODULES:
            return

        module_parent, module_name = os.path.split(plugin_dir)
        if module_name not in sys.modules:
            sys.path.insert(0, module_parent)
            importlib.import_module(module_name)
        else:
            # A module with this name is already loaded. It is only "our"
            # plugin if the directory is one of its package paths. __path__
            # holds *strings* and is absent on plain, non-package modules,
            # so compare normalized string paths rather than the Path object.
            loaded_paths = {
                os.path.abspath(path)
                for path in getattr(sys.modules[module_name], "__path__", ())
            }
            if str(plugin_dir) not in loaded_paths:
                raise ImportError(
                    f"Failed to import --plugin_dir={plugin_dir} because the module name "
                    f"({module_name}) is not globally unique."
                )
            logger.info(f"--plugin_dir={plugin_dir} has already been imported.")
        ArgumentParser.IMPORTED_MODULES.add(plugin_dir)

    def preprocess_parser(self) -> None:
        """Run :meth:`parse_known_args_preprocess` on ``sys.argv[1:]``.

        This applies config defaults, imports ``--plugin_dir`` and calls
        ``self._preprocessing`` without performing the actual parse.
        """
        self.parse_known_args_preprocess(sys.argv[1:])

    def parse_known_args_preprocess(
        self,
        args: Sequence[str] | None = None,
        namespace: Namespace | None = None,
        attempt_to_reorder: bool = False,
    ) -> None:
        """Prepare the parser before the actual parse.

        The steps are:

        1. Apply defaults from every file in ``self.config_path``.
        2. If ``self.add_config_path_arg`` is set, use a throw-away parser to
           read ``--config_path`` and ``--plugin_dir`` from ``args``. Apply the
           config defaults, import the plugin package, and register both
           options on ``self`` once, so the main parse accepts them and lists
           them in ``--help``.
        3. Call ``self._preprocessing`` with the remaining arguments.

        Args:
            args: Command-line arguments. ``None`` means ``sys.argv[1:]``. The
                caller's sequence is never modified.
            namespace: Namespace handed to ``self._preprocessing``. A new
                :class:`argparse.Namespace` is used when ``None``.
            attempt_to_reorder: Unused here. It is accepted so the signature
                mirrors :meth:`parse_known_args`.
        """
        # Work on a private list: None means the process arguments, and any
        # other sequence (e.g. a tuple) is copied.
        args = sys.argv[1:] if args is None else list(args)

        # default Namespace built from parser defaults
        if namespace is None:
            namespace = Namespace()

        if self.config_path:
            config_paths = (
                [self.config_path]
                if isinstance(self.config_path, Path)
                else self.config_path
            )
            for config_file in config_paths:
                self.set_defaults(config_file)

        if self.add_config_path_arg:
            # Throw-away pre-parser: it only extracts --config_path and
            # --plugin_dir. All other arguments are returned in ``args``.
            temp_parser = ArgumentParser(
                add_config_path_arg=False,
                add_help=False,
                add_option_string_dash_variants=FieldWrapper.add_dash_variants,
                argument_generation_mode=FieldWrapper.argument_generation_mode,
                nested_mode=FieldWrapper.nested_mode,
            )
            temp_parser.add_argument(
                "--config_path",
                type=Path,
                default=self.config_path,
                help="Path to a config file containing default values to use.",
            )
            args_with_config_path, args = temp_parser.parse_known_args(args)
            config_path = args_with_config_path.config_path

            if config_path is not None:
                config_paths = (
                    config_path if isinstance(config_path, list) else [config_path]
                )
                for config_file in config_paths:
                    self.set_defaults(config_file)

            # Adding it here just so it shows up in the help message.
            # Register it only once, because this method may run more than
            # once on the same parser (e.g. preprocess_parser() followed by
            # parse_args()). argparse's default conflict handler raises on a
            # second registration of the same option string.
            if self._option_string_actions.get("--config_path") is None:
                self.add_argument(
                    "--config_path",
                    type=Path,
                    default=config_path,
                    help="Path to a config file containing default values to use.",
                )

            # Plugin loader: import --plugin_dir before self._preprocessing runs.
            self.add_plugin_argumnets(temp_parser)
            args_with_plugin_dir, args = temp_parser.parse_known_args(args)
            plugin_dir: Path | None = args_with_plugin_dir.plugin_dir
            if plugin_dir is not None:
                self.import_plugin(plugin_dir)
            # Same register-once guard as for --config_path above.
            if self._option_string_actions.get("--plugin_dir") is None:
                self.add_plugin_argumnets(self)

        self._preprocessing(args=args, namespace=namespace)

    def parse_known_args(
        self,
        args: Sequence[str] | None = None,
        namespace: Namespace | None = None,
        attempt_to_reorder: bool = False,
    ):
        """Parse known arguments after running :meth:`parse_known_args_preprocess`.

        Args:
            args: Command-line arguments. ``None`` means ``sys.argv[1:]``.
            namespace: Namespace to populate. When ``None``, the underlying
                parser creates one.
            attempt_to_reorder: When subparsers leave arguments unparsed, retry
                once with the argument list rotated to start at the first
                unparsed argument.

        Returns:
            tuple: ``(parsed_args, unparsed_args)``. ``parsed_args`` is
            whatever ``self._postprocessing`` returns for the parsed
            namespace, and ``unparsed_args`` lists the unrecognized arguments.
        """
        # NOTE: since the usual ArgumentParser.parse_args() calls
        # parse_known_args, we therefore just need to overload the
        # parse_known_args method to support both.
        # args default to the system args; otherwise copy so they are mutable.
        args = sys.argv[1:] if args is None else list(args)

        self.parse_known_args_preprocess(
            args=args, namespace=namespace, attempt_to_reorder=attempt_to_reorder
        )
        # ``args`` still contains --config_path / --plugin_dir here, because the
        # stripped list is local to parse_known_args_preprocess. Both options
        # are registered on ``self`` (when add_config_path_arg is set), so the
        # main parse accepts them.
        logger.debug(
            "Parser %s is parsing args: %s, namespace: %s", id(self), args, namespace
        )
        # super(simple_parsing.ArgumentParser, self) starts the MRO lookup
        # *after* simple_parsing.ArgumentParser, bypassing its
        # parse_known_args. Presumably this is because its pre-processing was
        # already done above.
        parsed_args, unparsed_args = super(
            simple_parsing.ArgumentParser, self
        ).parse_known_args(args, namespace)

        if unparsed_args and self._subparsers and attempt_to_reorder:
            logger.warning(
                f"Unparsed arguments when using subparsers. Will "
                f"attempt to automatically re-order the unparsed arguments "
                f"{unparsed_args}."
            )
            index_in_start = args.index(unparsed_args[0])
            # Simply 'cycle' the args to the right ordering.
            new_start_args = args[index_in_start:] + args[:index_in_start]
            parsed_args, unparsed_args = super(
                simple_parsing.ArgumentParser, self
            ).parse_known_args(new_start_args)

        parsed_args = self._postprocessing(parsed_args)
        return parsed_args, unparsed_args
167
168
class DataclassWrapper(simple_parsing.wrappers.DataclassWrapper):
    """Dataclass wrapper that generates FLAT arguments for the ``common`` group.

    Every other wrapper uses ``ArgumentGenerationMode.NESTED``. The exact option
    spelling of each mode is defined by ``simple_parsing``.
    """

    def add_arguments(self, parser: argparse.ArgumentParser) -> None:
        """Add this wrapper's fields to ``parser``.

        The wrapper named ``"common"`` is generated with
        ``ArgumentGenerationMode.FLAT``. All others use
        ``ArgumentGenerationMode.NESTED``. Afterwards the mode is always reset
        to ``NESTED``.

        Args:
            parser (argparse.ArgumentParser): Parser that receives the arguments.
        """
        # FieldWrapper.argument_generation_mode is a class attribute shared by
        # all field wrappers. Switch it only for the duration of this call, and
        # reset it in ``finally`` so an exception during argument generation
        # cannot leak FLAT mode into later wrappers.
        FieldWrapper.argument_generation_mode = (
            ArgumentGenerationMode.FLAT
            if self._name == "common"
            else ArgumentGenerationMode.NESTED
        )
        try:
            super().add_arguments(parser)
        finally:
            FieldWrapper.argument_generation_mode = ArgumentGenerationMode.NESTED