186 lines
6.3 KiB
Python
186 lines
6.3 KiB
Python
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
|
|
# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
|
|
# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
|
|
|
|
"""Classes representing different types of constraints on inference values."""
|
|
from __future__ import annotations
|
|
|
|
import sys
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Iterator
|
|
from typing import TYPE_CHECKING
|
|
|
|
from astroid import nodes, util
|
|
from astroid.typing import InferenceResult
|
|
|
|
if sys.version_info >= (3, 11):
|
|
from typing import Self
|
|
else:
|
|
from typing_extensions import Self
|
|
|
|
if TYPE_CHECKING:
|
|
from astroid import bases
|
|
|
|
# Union of the node classes accepted as the constrained ``node`` argument of
# ``Constraint.match`` and as ``expr`` in ``get_constraints``.
_NameNodes = nodes.AssignAttr | nodes.Attribute | nodes.AssignName | nodes.Name
|
|
|
|
|
|
class Constraint(ABC):
    """Abstract base for one constraint placed on a variable.

    Subclasses implement :meth:`match` to recognise their pattern in an
    expression and :meth:`satisfied_by` to check an inferred value against
    the constraint.
    """

    def __init__(self, node: nodes.NodeNG, negate: bool) -> None:
        self.node = node
        """The node this constraint is attached to."""
        self.negate = negate
        """Whether the constraint is negated, e.g. "is not" rather than "is"."""

    @classmethod
    @abstractmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Build a constraint on ``node`` from ``expr`` when ``expr`` fits the pattern.

        :param node: The node the constraint would apply to.
        :param expr: The expression checked against this constraint's pattern.
        :param negate: When True, the resulting constraint is negated.
        :returns: A new constraint, or None when ``expr`` does not match.
        """

    @abstractmethod
    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Tell whether the inferred value ``inferred`` satisfies this constraint."""
|
|
|
|
|
|
class NoneConstraint(Constraint):
    """Constraint of the form "x is None" or "x is not None"."""

    CONST_NONE: nodes.Const = nodes.Const(None)

    @classmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Build a constraint on ``node`` if ``expr`` is an identity test against None.

        Only a comparison with exactly one operator is recognised, with
        ``node`` on the left-hand side and the constant None on the right.

        :param node: The node the constraint would apply to.
        :param expr: The candidate test expression.
        :param negate: When True, the polarity implied by the operator is inverted.
        :returns: A new constraint, or None when ``expr`` does not match.
        """
        if not isinstance(expr, nodes.Compare) or len(expr.ops) != 1:
            return None

        operator, comparand = expr.ops[0]
        if operator not in {"is", "is not"}:
            return None
        if not _matches(expr.left, node) or not _matches(comparand, cls.CONST_NONE):
            return None

        # "is not" flips the requested polarity; "is" keeps it unchanged.
        flipped = (not negate) if operator == "is not" else negate
        return cls(node=node, negate=flipped)

    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Tell whether the inferred value ``inferred`` satisfies this constraint."""
        if isinstance(inferred, util.UninferableBase):
            # An uninferable value is assumed to satisfy the constraint.
            return True

        # Satisfied when exactly one of "is None" and "negated" holds.
        is_none = _matches(inferred, self.CONST_NONE)
        return is_none != self.negate
|
|
|
|
|
|
class BooleanConstraint(Constraint):
    """Constraint of the form "x" or "not x" (a truthiness test)."""

    @classmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Build a constraint on ``node`` if ``expr`` tests its truthiness.

        Two shapes are recognised:

        - ``expr`` matches ``node`` itself: ``negate`` is used as given
        - ``expr`` is ``not node``: ``negate`` is flipped

        :returns: A new constraint, or None when neither shape matches.
        """
        if _matches(expr, node):
            return cls(node=node, negate=negate)

        is_not_expr = isinstance(expr, nodes.UnaryOp) and expr.op == "not"
        if is_not_expr and _matches(expr.operand, node):
            return cls(node=node, negate=not negate)

        return None

    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Tell whether the inferred value ``inferred`` satisfies this constraint.

        Returns True when either ``inferred`` or its boolean value is
        uninferable. Otherwise:

        - negate=False: satisfied if the boolean value is True
        - negate=True: satisfied if the boolean value is False
        """
        truthiness = inferred.bool_value()
        if isinstance(inferred, util.UninferableBase):
            return True
        if isinstance(truthiness, util.UninferableBase):
            return True

        # XOR: exactly one of "negated" and "truthy" must hold.
        return self.negate ^ truthiness
|
|
|
|
|
|
def get_constraints(
    expr: _NameNodes, frame: nodes.LocalsDictNodeNG
) -> dict[nodes.If | nodes.IfExp, set[Constraint]]:
    """Collect the constraints that apply to ``expr`` within ``frame``.

    The ancestors of ``expr`` are visited from the innermost outwards,
    stopping at ``frame``. For every enclosing ``If`` / ``IfExp`` whose test
    yields at least one constraint on ``expr``, the result maps that node to
    the matched constraint(s).

    Constraints are derived statically from the surrounding code; only
    conditions of if statements and if expressions are supported for now.

    :param expr: The node whose constraints are wanted.
    :param frame: The node at which the upward walk stops.
    :returns: Mapping from the generating node to its constraints.
    """
    found: dict[nodes.If | nodes.IfExp, set[Constraint]] = {}
    child: nodes.NodeNG | None = expr
    while child is not None and child is not frame:
        ancestor = child.parent
        if isinstance(ancestor, (nodes.If, nodes.IfExp)):
            branch, _ = ancestor.locate_child(child)
            # Only the two branches are constrained by the test; in the
            # "orelse" branch the test is known to be false, hence inverted.
            if branch in ("body", "orelse"):
                matched = set(
                    _match_constraint(
                        expr, ancestor.test, invert=(branch == "orelse")
                    )
                )
                if matched:
                    found[ancestor] = matched
        child = ancestor

    return found
|
|
|
|
|
|
ALL_CONSTRAINT_CLASSES = frozenset([NoneConstraint, BooleanConstraint])
"""Every constraint class that is tried when matching a test expression."""
|
|
|
|
|
|
def _matches(node1: nodes.NodeNG | bases.Proxy, node2: nodes.NodeNG) -> bool:
    """Tell whether ``node1`` and ``node2`` match.

    Two ``Name`` nodes are compared by name, two ``Attribute`` nodes by
    attribute name and, recursively, by the expression they are accessed on,
    and two ``Const`` nodes by value. Any other pairing does not match.
    """
    if isinstance(node1, nodes.Name) and isinstance(node2, nodes.Name):
        same = node1.name == node2.name
    elif isinstance(node1, nodes.Attribute) and isinstance(node2, nodes.Attribute):
        same = node1.attrname == node2.attrname and _matches(node1.expr, node2.expr)
    elif isinstance(node1, nodes.Const) and isinstance(node2, nodes.Const):
        same = node1.value == node2.value
    else:
        same = False

    return same
|
|
|
|
|
|
def _match_constraint(
    node: _NameNodes, expr: nodes.NodeNG, invert: bool = False
) -> Iterator[Constraint]:
    """Yield a constraint on ``node`` for each constraint class matching ``expr``.

    ``invert`` is passed to each class's ``match`` as its ``negate`` argument.
    """
    attempts = (
        candidate_cls.match(node, expr, invert)
        for candidate_cls in ALL_CONSTRAINT_CLASSES
    )
    # Drop the classes whose pattern did not match.
    yield from (hit for hit in attempts if hit)
|