"""Tool to configure Python tools.""" from __future__ import annotations from argparse import SUPPRESS, ArgumentParser from dataclasses import is_dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Generic, Sequence, TypeVar if TYPE_CHECKING: from _typeshed import DataclassInstance DataclassT = TypeVar("DataclassT", bound=DataclassInstance) else: DataclassT = TypeVar("DataclassT") from pytoolconfig.fields import _gather_config_fields from pytoolconfig.sources import PyProject, PyTool, Source from pytoolconfig.types import ConfigField from pytoolconfig.universal_config import UniversalConfig from pytoolconfig.utils import _dict_to_dataclass, _recursive_merge class PyToolConfig(Generic[DataclassT]): """Python Tool Configuration Aggregator.""" sources: list[Source] tool: str working_directory: Path model: type[DataclassT] fall_through: bool = False arg_parser: ArgumentParser | None = None _config_fields: dict[str, ConfigField] def __init__( # noqa: PLR0913 self, tool: str, working_directory: Path, model: type[DataclassT], arg_parser: ArgumentParser | None = None, custom_sources: Sequence[Source] | None = None, global_config: bool = False, global_sources: Sequence[Source] | None = None, fall_through: bool = False, *args: Any, **kwargs: Any, ) -> None: """Initialize the configuration object. :param tool: name of the tool to use. :param working_directory: working directory in use. :param model: Model of configuration. :param arg_parser: Arugument Parser. :param custom_sources: Custom sources :param global_config: Enable global configuration :param global_sources: Custom global sources :param fall_through: Configuration options should fall through between sources. :param args: Passed to constructor for PyProject :param kwargs: Passed to constructor for PyProject """ assert is_dataclass(model) self.model = model self._config_fields = _gather_config_fields(model) self.tool = tool self.sources = [PyProject(working_directory, tool, *args, **kwargs)] if custom_sources: self.sources.extend(custom_sources) if global_config: self.sources.append(PyTool(tool)) if global_sources: self.sources.extend(global_sources) self.arg_parser = arg_parser self.fall_through = fall_through self._setup_arg_parser() def parse(self, args: list[str] | None = None) -> DataclassT: """Parse the configuration. :param args: any additional command line overwrites. """ configuration = self._parse_sources() assert isinstance(self.sources[0], PyProject) universal: UniversalConfig = self.sources[0].universalconfig() if self.arg_parser: if args is None: args = [] parsed = self.arg_parser.parse_args(args) for name, value in parsed._get_kwargs(): setattr(configuration, name, value) for name, field in self._config_fields.items(): if field.universal_config: universal_value = vars(universal)[field.universal_config.name] if universal_value is not None: setattr( configuration, name, universal_value, ) return configuration def _setup_arg_parser(self) -> None: if self.arg_parser: for name, field in self._config_fields.items(): if field.command_line: flags = field.command_line self.arg_parser.add_argument( *flags, type=field._type, help=field.description, default=SUPPRESS, metavar=name, dest=name, ) def _parse_sources(self) -> DataclassT: configuration = self.model() if self.fall_through: for source in reversed(self.sources): parsed = source.parse() if parsed is not None: configuration = _recursive_merge(configuration, parsed) else: for source in self.sources: parsed = source.parse() if parsed: return _dict_to_dataclass(self.model, parsed) return configuration