142 lines
4.5 KiB
Python
142 lines
4.5 KiB
Python
"""Utility functions: config-file discovery, Python version bounds, and dataclass conversion."""
|
|
from __future__ import annotations
|
|
|
|
import sys
|
|
import warnings
|
|
from dataclasses import Field, fields, is_dataclass, replace
|
|
from enum import EnumMeta
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, Generator, Mapping, TypeVar
|
|
|
|
from packaging.requirements import Requirement
|
|
from packaging.specifiers import SpecifierSet
|
|
|
|
from .types import Key
|
|
|
|
if TYPE_CHECKING:
|
|
from _typeshed import DataclassInstance
|
|
|
|
|
|
def find_config_file(
    working_directory: Path,
    filename: str,
    bases: list[str] | None = None,
) -> Path | None:
    """Find a configuration file by walking up from a working directory.

    Starting at *working_directory*, each directory is checked for *filename*.
    The search then moves to the parent directory. It stops when the file is
    found, when a directory containing one of *bases* is reached, or when the
    filesystem root is reached. A directory that holds both the file and a
    base marker still yields the file, because the file is checked first.

    Args:
        working_directory: Directory to start the search from.
        filename: Name of the configuration file to look for.
        bases: Names (files or directories) whose presence stops the search.
            Defaults to ``[".git", ".hg"]``.

    Returns:
        Path of the first matching file, or ``None`` if no file was found
        before a base marker or the filesystem root.
    """
    if bases is None:
        bases = [".git", ".hg"]

    directory = working_directory
    while True:
        target = directory / filename
        if target.exists():
            return target
        # A base marker (e.g. a repository root) bounds the search.
        if any((directory / base).exists() for base in bases):
            return None
        # The parent of the filesystem root is the root itself.
        if directory == directory.parent:
            return None
        directory = directory.parent
|
|
|
|
|
|
def min_py_version(specifier: str) -> tuple[int, int]:
    """Return the lowest Python 3 ``(major, minor)`` allowed by *specifier*.

    Minor versions from 3.4 up to, but not including, the running
    interpreter's minor version are tried in ascending order. If none of
    them satisfies *specifier*, the running interpreter's version is
    returned.

    Args:
        specifier: Version specifier string understood by ``SpecifierSet``,
            e.g. ``">=3.8"``.

    Returns:
        A ``(3, minor)`` tuple.
    """
    allowed = SpecifierSet(specifier)
    current_minor = sys.version_info.minor
    oldest_match = next(
        (minor for minor in range(4, current_minor) if allowed.contains(f"3.{minor}")),
        current_minor,
    )
    return (3, oldest_match)
|
|
|
|
|
|
def max_py_version(specifier: str) -> tuple[int, int]:
    """Return the highest Python 3 ``(major, minor)`` allowed by *specifier*.

    Minor versions from the running interpreter's minor version down to 3.5
    are tried in descending order. If none of them satisfies *specifier*,
    ``(3, 4)`` is returned. 3.4 itself is never tested; it is only the
    fallback.

    Args:
        specifier: Version specifier string understood by ``SpecifierSet``,
            e.g. ``"<3.11"``.

    Returns:
        A ``(3, minor)`` tuple.
    """
    allowed = SpecifierSet(specifier)
    newest_first = range(sys.version_info.minor, 4, -1)
    newest_match = next(
        (minor for minor in newest_first if allowed.contains(f"3.{minor}")),
        None,
    )
    if newest_match is None:
        return (3, 4)  # Please don't cap your project at python3.4
    return (3, newest_match)
|
|
|
|
|
|
def parse_dependencies(dependencies: list[str]) -> Generator[Requirement, None, None]:
    """Lazily turn dependency strings (as read from TOML) into ``Requirement`` objects.

    Each string is parsed only when the generator reaches it, so a malformed
    entry raises at that point rather than up front.
    """
    yield from map(Requirement, dependencies)
|
|
|
|
|
|
# Type variable for helpers that take and return the same dataclass type.
# The bound is a string because DataclassInstance is only imported under
# TYPE_CHECKING (from the type-checker-only _typeshed module).
T = TypeVar("T", bound="DataclassInstance")
|
|
|
|
|
|
def _subtables(dataclass_fields: dict[str, Field]) -> dict[str, type[Any]]:
|
|
return {
|
|
name: field.type
|
|
for name, field in dataclass_fields.items()
|
|
if is_dataclass(field.type)
|
|
}
|
|
|
|
|
|
def _fields(dataclass: DataclassInstance | type[DataclassInstance]) -> dict[str, Field]:
|
|
return {field.name: field for field in fields(dataclass) if field.init}
|
|
|
|
|
|
def _format_enum(option: Any) -> str:
|
|
if isinstance(option, str):
|
|
return f'"{option}"'
|
|
return str(option)
|
|
|
|
|
|
def _dict_to_dataclass(
    dataclass: type[T],
    dictionary: Mapping[str, Key],
) -> T:
    """Build an instance of *dataclass* from a (possibly nested) mapping.

    Keys that do not name an ``init=True`` field of *dataclass* are ignored.
    A field whose declared type is itself a dataclass is built recursively
    from the corresponding sub-mapping. A field whose declared type is an
    ``Enum`` is converted with ``keytype(value)``. If that conversion fails,
    a warning is emitted and the key is left out of the constructor call, so
    the field's default (if it has one) applies.

    Args:
        dataclass: Dataclass type to instantiate.
        dictionary: Mapping of field names to raw values.

    Returns:
        A new instance of *dataclass*.
    """
    filtered_arg_dict: dict[str, Any] = {}
    dataclass_fields = _fields(dataclass)
    sub_tables = _subtables(dataclass_fields)
    for key_name, value in dictionary.items():
        if key_name in sub_tables:
            sub_table = sub_tables[key_name]
            assert isinstance(value, Mapping)
            filtered_arg_dict[key_name] = _dict_to_dataclass(sub_table, value)
        elif key_name in dataclass_fields:
            keytype = dataclass_fields[key_name].type
            if isinstance(keytype, EnumMeta):
                try:
                    filtered_arg_dict[key_name] = keytype(value)
                except ValueError:
                    # Iterate the public members (definition order) instead of a
                    # set of values: set order for str values depends on hash
                    # randomization, which made the message vary between runs.
                    valid = [member.value for member in keytype]
                    warnings.warn(
                        f"{value} is not a valid option for {key_name}, skipping. "
                        f"Valid options are: {','.join(map(_format_enum, valid))}.",
                        stacklevel=1,
                    )
            else:
                filtered_arg_dict[key_name] = value
    return dataclass(**filtered_arg_dict)
|
|
|
|
|
|
def _recursive_merge(dataclass: T, dictionary: Mapping[str, Key]) -> T:
    """Return a copy of *dataclass* with the values in *dictionary* applied on top.

    Keys naming a nested-dataclass field are merged recursively into the
    current nested instance. Keys naming any other ``init=True`` field replace
    that field's value as-is. All other keys are ignored. The input instance
    is not modified; a new one is built with ``dataclasses.replace``.

    NOTE(review): unlike ``_dict_to_dataclass``, Enum-typed values are not
    converted here. Confirm that callers pass already-converted values.
    """
    init_fields = _fields(dataclass)
    nested_types = _subtables(init_fields)
    overrides: dict[str, Any] = {}
    for name, new_value in dictionary.items():
        if name not in init_fields:
            continue
        if name not in nested_types:
            overrides[name] = new_value
            continue
        current_child = getattr(dataclass, name)
        assert isinstance(new_value, Mapping)
        overrides[name] = _recursive_merge(current_child, new_value)
    return replace(dataclass, **overrides)
|