Source code for static_frame.core.type_clinic

from __future__ import annotations

import re
import types
import typing
import warnings
from collections import deque
from collections.abc import MutableMapping
from collections.abc import Sequence
from enum import Enum
from functools import partial
from functools import wraps
from inspect import BoundArguments
from inspect import Signature
from itertools import chain
from itertools import repeat

import numpy as np
import typing_extensions as tp
from numpy.typing import NBitBase

from static_frame.core.bus import Bus
from static_frame.core.frame import Frame
from static_frame.core.index import Index
from static_frame.core.index_base import IndexBase
from static_frame.core.index_datetime import IndexDatetime
from static_frame.core.index_hierarchy import IndexHierarchy
from static_frame.core.series import Series
from static_frame.core.util import DTYPE_COMPLEX_KIND
from static_frame.core.util import INT_TYPES
from static_frame.core.util import TLabel
from static_frame.core.yarn import Yarn

# A Frame generic with any index type, any columns type, and any number of column dtypes.
TFrameAny = Frame[tp.Any, tp.Any, tp.Unpack[tp.Tuple[tp.Any, ...]]]

# A callable used as a per-label validator; it must return a Boolean.
TValidator = tp.Callable[..., bool]
# A label matcher: a label, a compiled regular expression, or a set of labels.
TLabelMatchSpecifier = tp.Union[TLabel, tp.Pattern[tp.Any], tp.Set[TLabel]]

# Aliases below are only needed by static type checkers and are not defined at runtime.
if tp.TYPE_CHECKING:
    from types import EllipsisType  # pragma: no cover
    TDtypeAny = np.dtype[tp.Any] #pragma: no cover
    # One position of a shape specification: an integer or `...`.
    TShapeComponent = tp.Union[int, EllipsisType] #pragma: no cover # pyright: ignore
    # A full shape specification: a tuple of shape components.
    TShapeSpecifier = tp.Tuple[TShapeComponent, ...] #pragma: no cover

def _iter_generic_classes() -> tp.Iterable[tp.Type[tp.Any]]:
    '''Yield each generic-alias class found on the ``types`` and ``typing`` modules; a missing attribute is skipped.
    '''
    for module, attr in ((types, 'GenericAlias'), (typing, '_GenericAlias')):
        found = getattr(module, attr, None)
        if found:
            yield found # pyright: ignore

# classes usable with isinstance() to identify a generic alias
GENERIC_TYPES = tuple(_iter_generic_classes())

def _iter_union_classes() -> tp.Iterable[tp.Type[tp.Any]]:
    if t := getattr(types, 'UnionType', None):
        yield t
    if t := getattr(typing, '_UnionGenericAlias', None):
        yield t # pyright: ignore

UNION_TYPES = tuple(_iter_union_classes())

def _iter_unpack_classes() -> tp.Iterable[tp.Type[tp.Any]]:
    '''Yield the ``Unpack`` object of ``typing_extensions`` and of ``typing``, skipping any that is not available.
    '''
    # NOTE: type extensions Unpack is not equal to typing.Unpack
    for module in (tp, typing):
        found = getattr(module, 'Unpack', None)
        if found:
            yield found # pyright: ignore

# objects that an origin is compared against to identify an Unpack
UNPACK_TYPES = tuple(_iter_unpack_classes())

def is_generic(hint: tp.Any) -> bool:
    '''Return True if ``hint`` is an instance of one of the classes in ``GENERIC_TYPES``.
    '''
    generic = isinstance(hint, GENERIC_TYPES)
    return generic

def is_union(hint: tp.Any) -> bool:
    '''Return True if ``hint`` is a union. When no union classes were discovered, fall back to comparing the origin of a generic to ``tp.Union``.
    '''
    if not UNION_TYPES: #pragma: no cover
        # this might only be possible pre 3.9
        if isinstance(hint, GENERIC_TYPES): #pragma: no cover
            return tp.get_origin(hint) is tp.Union #pragma: no cover
        return False #pragma: no cover
    return isinstance(hint, UNION_TYPES)

def is_unpack(origin: tp.Any, generic_alias: tp.Any) -> bool:
    '''Return True if ``origin`` is a known ``Unpack`` object, or if ``generic_alias`` has a truthy ``__unpacked__`` attribute.
    '''
    # NOTE: the * syntax on *tuple does not create an Unpack instances
    # tp.get_origin(*tuple[tp.Any, ...]) is tuple
    # NOTE: cannot use isinstance or issubclass with Unpack
    by_origin = origin in UNPACK_TYPES
    # a truthy hint.__unpacked__ identifies a *tuple, which is already in unpacked form
    return by_origin or bool(getattr(generic_alias, '__unpacked__', False))

def get_args_unpack(hint: tp.Any) -> tp.Any:
    '''Normalize the heterogeneity of dealing with *tuple[tp.Any, ...] and tp.Unpack[tp.Tuple[tp.Any, *]]; always return the contained tuple generic alias
    '''
    if getattr(hint, '__unpacked__', False):
        # if hint.__unpacked__ is True, this is a *tuple that does not need to be unpacked
        tga = hint
    else:
        # unpack from tp.Unpack to return the tuple generic alias
        [tga] = tp.get_args(hint)
    assert issubclass(tuple, tp.get_origin(tga))
    return (tga,) # clients expect a tuple of size 1

#-------------------------------------------------------------------------------

# A chain of parent hints or parent values accumulated while descending into nested containers.
TParent = tp.Tuple[tp.Any, ...]
# A validation record can be used to queue checks or report errors
# Fields are, in order: value, hint, parent hints, parent values.
TValidation = tp.Tuple[tp.Any, tp.Any, TParent, TParent]

#-------------------------------------------------------------------------------
# error reporting, presentation

# Sentinel stored in the value position of a validation record to signal that the hint position holds a ready-made error message.
ERROR_MESSAGE_TYPE = object()

def to_name(v: tp.Any,
        func_to_str: tp.Callable[..., str] = str,
        ) -> str:
    '''Return a compact display string for a hint, type, tuple, or arbitrary value.

    Args:
        v: The object to describe.
        func_to_str: Function used to convert ``v`` when no more specific rule applies; it is not forwarded to nested calls.
    '''
    if is_generic(v):
        if hasattr(v, '__name__'):
            label = v.__name__
        else:
            # for older Python, not all generics have __name__
            origin = tp.get_origin(v)
            if hasattr(origin, '__name__'):
                label = origin.__name__
            elif is_unpack(origin, v): # needed for backwards compat
                label = 'Unpack'
            else:
                label = str(origin)
        args = ', '.join(to_name(q) for q in tp.get_args(v))
        return f'{label}[{args}]'

    if isinstance(v, tp.TypeVar):
        # str() gets tilde, __name__ does not have tilde
        if v.__bound__:
            return f'{v}: {to_name(v.__bound__)}'
        if v.__constraints__:
            return f'{v}: {to_name(v.__constraints__)}'
        return str(v)

    if hasattr(v, '__name__'):
        return v.__name__ # type: ignore
    if v is ...:
        return '...'
    if isinstance(v, tuple):
        members = ', '.join(to_name(w) for w in v)
        return f'({members})'
    return func_to_str(v)

def to_signature(
        sig: BoundArguments,
        hints: tp.Mapping[str, tp.Any]) -> str:
    '''Return a signature-like string for the bound arguments, using the display name of each hint; a missing hint is shown as ``Any``.
    '''
    params = ', '.join(
            f'{k}: {to_name(hints.get(k, tp.Any))}' for k in sig.arguments
            )
    returns = to_name(hints.get('return', tp.Any))
    return f'({params}) -> {returns}'



class ClinicResult:
    '''A ``ClinicResult`` instance stores zero or more error messages resulting from a check.
    '''
    __slots__ = ('_log',)

    # characters and width used to draw the indented path to an error
    _LINE = '─'
    _CORNER = '└'
    _WIDTH = 4

    def __init__(self, log: tp.Sequence[TValidation]) -> None:
        '''Store the sequence of validation records; the sequence is held by reference.
        '''
        self._log = log

    def __iter__(self) -> tp.Iterator[TValidation]:
        '''Iterate over the stored validation records.
        '''
        return self._log.__iter__()

    def __len__(self) -> int:
        '''Return the number of stored validation records.
        '''
        return len(self._log)

    def __bool__(self) -> bool:
        '''Return True if there are validation issues.
        '''
        return bool(self._log)

    @property
    def validated(self) -> bool:
        '''True if there are no validation issues.
        '''
        return not self._log

    @classmethod
    def _get_indent(cls, size: int) -> str:
        '''Return the prefix for a line at depth ``size``: padding, a corner, a rule, and a space. A depth below 1 has no prefix.
        '''
        if size < 1:
            return ''
        padding = ' ' * (cls._WIDTH * (size - 1))
        rule = cls._LINE * (cls._WIDTH - 2)
        return f'{padding}{cls._CORNER}{rule} '

    def to_str(self) -> str:
        '''Return error messages as a formatted string with line breaks and indentation.
        '''
        parts = []
        for v, h, ph, _ in self._log:
            if ph:
                # one line per parent hint, each indented one level deeper
                path = '\n'.join(
                        f'{self._get_indent(depth)}{to_name(pc)}'
                        for depth, pc in enumerate(ph)
                        )
                depth_next = len(ph)
            else:
                path = ''
                depth_next = 1

            if v is ERROR_MESSAGE_TYPE:
                # in this case, do not use the value
                error_msg = f'{h}'
            else:
                error_msg = f'Expected {to_name(h)}, provided {to_name(type(v))} invalid'

            if path:
                parts.append(f'\nIn {path}\n{self._get_indent(depth_next)}{error_msg}')
            else:
                parts.append(f'\n{error_msg}')
        return ''.join(parts)

    def __repr__(self) -> str:
        count = len(self)
        noun = 'error' if count == 1 else 'errors'
        return f'<ClinicResult: {count} {noun}>'
class ClinicError(TypeError):
    '''A TypeError subclass for exposing check errors.
    '''
    def __init__(self, cr: ClinicResult) -> None:
        # the exception message is the formatted report of the result
        super().__init__(cr.to_str())

#-------------------------------------------------------------------------------

class Validator:
    '''Base class of all run-time constraints, deployed in ``Annotated`` generics.
    '''
    __slots__: tp.Tuple[str, ...] = ()

    def _iter_errors(self,
            value: tp.Any,
            hint: tp.Any,
            parent_hints: TParent,
            parent_values: TParent,
            ) -> tp.Iterator[TValidation]:
        '''Yield a validation record for each failure found in ``value``; implemented by subclasses.
        '''
        raise NotImplementedError() #pragma: no cover

    def __repr__(self) -> str:
        # display the value stored in each slot, in slot order
        shown = [to_name(getattr(self, slot)) for slot in self.__slots__]
        return f'{self.__class__.__name__}({", ".join(shown)})'
class Require:
    '''A collection of classes to be used in ``Annotated`` generics to perform run-time data validations.
    '''
    __slots__ = ()

    class Name(Validator):
        '''Validate the name of a container.

        Args:
            name: The name to validate against.
            /
        '''
        __slots__ = ('_name',)

        def __init__(self, name: TLabel, /):
            self._name = name

        def _iter_errors(self,
                value: tp.Any,
                hint: tp.Any,
                parent_hints: TParent,
                parent_values: TParent,
                ) -> tp.Iterator[TValidation]:
            '''Yield one error record if ``value.name`` is not equal to the expected name.
            '''
            # returning anything is an error
            if (n := value.name) != self._name:
                yield (ERROR_MESSAGE_TYPE,
                        f'Expected name {self._name!r}, provided name {n!r}',
                        parent_hints,
                        parent_values,
                        )

    class Len(Validator):
        '''Validate the length of a container.

        Args:
            len: The length to validate against.
            /
        '''
        __slots__ = ('_len',)

        def __init__(self, len: int, /):
            self._len = len

        def _iter_errors(self,
                value: tp.Any,
                hint: tp.Any,
                parent_hints: TParent,
                parent_values: TParent,
                ) -> tp.Iterator[TValidation]:
            '''Yield one error record if ``len(value)`` is not equal to the expected length.
            '''
            if (vl := len(value)) != self._len:
                yield (ERROR_MESSAGE_TYPE,
                        f'Expected length {self._len}, provided length {vl}',
                        parent_hints,
                        parent_values,
                        )

    class Shape(Validator):
        '''Validate the shape of a container.

        Args:
            shape: One value per dimension, where values are either an integer or an `...`, specifying any value for that position. The size of the shape always specifies the dimensionality.
            /
        '''
        __slots__ = ('_shape',)

        @staticmethod
        def _validate_shape_component(ss: TShapeSpecifier) -> TShapeSpecifier:
            '''Return ``ss`` unchanged after confirming every component is `...` or an integer.

            Raises:
                TypeError: if any component is neither `...` nor an integer.
            '''
            for c in ss:
                if c is not ... and not isinstance(c, INT_TYPES):
                    raise TypeError(f'Components must be either `...` or an integer, not {c!r}.')
            return ss

        def __init__(self, /, *shape: TShapeComponent):
            self._shape = self._validate_shape_component(shape)

        def _iter_errors(self,
                value: tp.Any,
                hint: tp.Any,
                parent_hints: TParent,
                parent_values: TParent,
                ) -> tp.Iterator[TValidation]:
            '''Yield one error record if ``value.shape`` differs from the expected shape in dimensionality or in any position not given as `...`.
            '''
            expected = self._shape
            provided = value.shape
            # every position is compared, so specifications of any dimensionality (including empty) are supported
            mismatch = len(expected) != len(provided) or any(
                    e is not ... and e != p for e, p in zip(expected, provided)
                    )
            if mismatch:
                yield (ERROR_MESSAGE_TYPE,
                        f'Expected shape ({self._shape}), provided shape {value.shape}',
                        parent_hints,
                        parent_values,
                        )

    class _LabelsValidator(Validator):
        '''Shared implementation for validators that are given labels, each optionally paired with validator functions in a list.
        '''
        __slots__ = ('_labels',)

        def __repr__(self) -> str:
            msg = []
            for v in self._labels: # type: ignore
                if isinstance(v, list):
                    # a label followed by its validators
                    parts = [to_name(p, func_to_str=repr) for p in v]
                    msg.append(f'[{", ".join(parts)}]')
                else:
                    msg.append(to_name(v, func_to_str=repr))
            return f'{self.__class__.__name__}({", ".join(msg)})'

        @staticmethod
        def _find_parent_frame(parent_values: TParent) -> TFrameAny | None:
            '''Return the nearest ``Frame`` found searching ``parent_values`` from the end, or None.
            '''
            for v in reversed(parent_values):
                if isinstance(v, Frame):
                    return v
            return None

        @staticmethod
        def _split_validators(
                label: tp.Any,
                ) -> tp.Tuple[TLabel | tp.Set[TLabel], tp.Sequence[TValidator]]:
            '''Given an object that might be a label, or a list of label and validators, split into two and return label, validators
            '''
            label_e: TLabel
            label_validators: tp.Sequence[TValidator] | None

            # evaluate that all post-label values are callables?
            if isinstance(label, list):
                label_e = label[0] # must be first
                label_validators = label[1:]
            else:
                label_e = label
                label_validators = ()

            return label_e, label_validators

        @staticmethod
        def _iter_validator_results(*,
                frame: TFrameAny | None,
                labels: IndexBase,
                label: TLabel,
                validators: tp.Sequence[TValidator],
                parent_hints: TParent,
                parent_values: TParent,
                ) -> tp.Iterator[TValidation]:
            '''Call each validator with the ``Frame`` selection at ``label`` and yield one error record per validator that returns a falsy value.

            Raises:
                RuntimeError: if validators are given and no ``Frame`` is available, or ``labels`` is neither the index nor the columns of ``frame``.
            '''
            # be a no-op when no validators are present
            if validators:
                if frame is None:
                    raise RuntimeError('Provided label validators in a context without a discoverable Frame.')

                # NOTE: the same index instance might be on both axis, though this is unlikely
                if labels is frame.index:
                    s = frame.loc[label]
                elif labels is frame.columns:
                    s = frame[label]
                else:
                    raise RuntimeError('Labels associated with an index that is not a member of the parent Frame')

                for validator in validators:
                    if not validator(s):
                        yield (ERROR_MESSAGE_TYPE,
                                f'Validation failed of label {label!r} with {to_name(validator)}',
                                parent_hints,
                                parent_values,
                                )

    class LabelsOrder(_LabelsValidator):
        r'''Validate the ordering of labels.

        Args:
            \*labels: Provide labels as args. Use ... for regions of zero or more undefined labels.
        '''
        __slots__ = ()

        def __init__(self, *labels: tp.Sequence[TLabel]):
            self._labels: tp.Sequence[TLabel | tp.List[TLabel | TValidator]] = labels

        @staticmethod
        def _repr_remainder(
                labels: tp.Sequence[TLabel | tp.List[TLabel | TValidator]],
                ) -> str:
            '''Return a comma-separated display of ``labels`` without a leading or trailing ellipsis.
            '''
            # always drop leading or trailing ellipses; guard each access as slicing may leave nothing
            if labels and labels[0] is ...:
                labels = labels[1:]
            if labels and labels[-1] is ...:
                labels = labels[:-1]
            return ', '.join((repr(l) if l is not ... else '...') for l in labels)

        def _provided_is_expected(self,
                label_e: tp.Any,
                label_p: TLabel,
                iloc_p: int,
                index: IndexBase,
                ) -> bool:
            '''Return True if the expected label equals the provided label, or resolves through ``index`` to the provided position.
            '''
            if label_e == label_p:
                return True
            try:
                # expected is defined in the hint; see if we can use it as a lookup, permitting string to date combination in IndexDate
                return index.loc_to_iloc(label_e) == iloc_p # type: ignore
            except KeyError:
                pass
            return False

        def _iter_errors(self,
                value: tp.Any, # an index object
                hint: tp.Any,
                parent_hints: TParent,
                parent_values: TParent,
                ) -> tp.Iterator[TValidation]:
            '''Walk the provided index against the expected labels, yielding an error record for the first ordering failure and for every failing label validator.
            '''
            if not isinstance(value, IndexBase):
                yield (ERROR_MESSAGE_TYPE,
                        f'Expected {self} to be used on Index or IndexHierarchy, not provided {to_name(type(value))}',
                        parent_hints,
                        parent_values,
                        )
                return

            pf = self._find_parent_frame(parent_values)
            # partial processor, built once as no argument varies per label; calls are deferred until after label evaluation
            iter_validator_results = partial(
                    self._iter_validator_results,
                    frame=pf,
                    labels=value,
                    parent_hints=parent_hints,
                    parent_values=parent_values,
                    )

            pos_e = 0 # position in expected
            len_e = len(self._labels)

            for iloc_p, label_p in enumerate(value): # iterate provided index
                if pos_e >= len_e:
                    yield (ERROR_MESSAGE_TYPE,
                            f'Expected labels exhausted at provided {label_p!r}',
                            parent_hints,
                            parent_values,
                            )
                    break

                label_e, label_validators = self._split_validators(self._labels[pos_e])

                if label_e is not ...:
                    if not self._provided_is_expected(label_e, label_p, iloc_p, value):
                        yield (ERROR_MESSAGE_TYPE,
                                f'Expected {label_e!r}, provided {label_p!r}',
                                parent_hints,
                                parent_values,
                                )
                        break
                    # report every failing validator, not only the first
                    yield from iter_validator_results(
                            label=label_p,
                            validators=label_validators,
                            )
                    pos_e += 1

                # expected is an Ellipses; either find next as match or continue with Ellipses
                elif pos_e + 1 == len_e: # last expected is an ellipses
                    # do not need to look ahead, evaluate validators
                    yield from iter_validator_results(
                            label=label_p,
                            validators=label_validators,
                            )

                else: # pos_e + 1 < len_e: more expected labels available
                    # look ahead to see if the next expected hint matches the current label, if so, we use those validators on this column
                    label_next_e, label_next_validators = self._split_validators(
                            self._labels[pos_e + 1],
                            )
                    if label_next_e is ...:
                        yield (ERROR_MESSAGE_TYPE,
                                'Expected cannot be defined with adjacent ellipses',
                                parent_hints,
                                parent_values,
                                )
                        break

                    if self._provided_is_expected(label_next_e, label_p, iloc_p, value):
                        # NOTE: if current expected is an ellipses, and the current provided label is equal to the next expected, we know we are done with the ellipses region and must compare to the next expected value; we then skip that value for subsequent evaluation
                        yield from iter_validator_results(
                                label=label_next_e, # type: ignore
                                validators=label_next_validators,
                                )
                        pos_e += 2 # skip the compared value, prepare to get next
                    else:
                        # if the lookahead label is not this label, run these validators
                        yield from iter_validator_results(
                                label=label_p,
                                validators=label_validators,
                                )
            else: # no break, exhausted all labels; evaluate final conditions
                # NOTE: must re-split validators to handle all scenarios
                if pos_e + 1 == len_e and self._split_validators(self._labels[pos_e])[0] is ...:
                    pass # expected ending on an ellipses
                elif pos_e < len_e:
                    # if we have unevaluated expected
                    remainder = self._repr_remainder(self._labels[pos_e:])
                    yield (ERROR_MESSAGE_TYPE,
                            f'Expected has unmatched labels {remainder}',
                            parent_hints,
                            parent_values,
                            )

    class LabelsMatch(_LabelsValidator):
        r'''Validate the presence of one or more labels, specified with the value, a pattern, or set of values. Order of labels is not relevant.

        Args:
            \*labels: Provide labels matchers as args. A label matcher can be a label, a set of labels (of which at least one contained label must match), or a compiled regular expression (with which the `search` method is used to determine a match of string labels). Each label matcher provided must find at least one match, otherwise an error is returned.
        '''
        __slots__ = (
                '_match_labels',
                '_match_res',
                '_match_sets',
                '_match_to_validators',
                )

        def __init__(self, *labels: tp.Sequence[TLabelMatchSpecifier]):
            self._labels: tp.Sequence[TLabelMatchSpecifier | tp.List[TLabelMatchSpecifier | TValidator]] = labels

            # matchers are grouped by kind; sets are stored frozen so they can be dictionary keys
            self._match_labels: tp.Set[TLabel] = set()
            self._match_res: tp.List[tp.Pattern[str]] = []
            self._match_sets: tp.List[tp.FrozenSet[TLabel]] = []

            self._match_to_validators: tp.Dict[
                    tp.Union[TLabel, tp.Pattern[str], tp.FrozenSet[TLabel]],
                    tp.Sequence[TValidator]
                    ] = {}

            for l in labels:
                m, validators = self._split_validators(l)
                if isinstance(m, re.Pattern):
                    self._match_res.append(m)
                elif isinstance(m, set):
                    m = frozenset(m)
                    self._match_sets.append(m)
                else:
                    self._match_labels.add(m)
                # validator might be empty tuple
                self._match_to_validators[m] = validators

        def _iter_errors(self,
                value: tp.Any, # an index object
                hint: tp.Any,
                parent_hints: TParent,
                parent_values: TParent,
                ) -> tp.Iterator[TValidation]:
            '''Test every provided label against every matcher, yielding an error record for each failing label validator and for each matcher that found no label.
            '''
            if not isinstance(value, IndexBase):
                yield (ERROR_MESSAGE_TYPE,
                        f'Expected {self} to be used on Index or IndexHierarchy, not provided {to_name(type(value))}',
                        parent_hints,
                        parent_values,
                        )
                return

            # count number of times each condition is met
            score = {k: 0 for k in chain(self._match_labels, self._match_res, self._match_sets)}

            pf = self._find_parent_frame(parent_values)
            iter_validator_results = partial(
                    self._iter_validator_results,
                    frame=pf,
                    labels=value,
                    parent_hints=parent_hints,
                    parent_values=parent_values,
                    )

            # NOTE: a label_p will be tested to match each of the possible three scenarios, as different validators might be assigned to different groups, and all should be tested if validators exist
            for label_p in value: # iterate provided index
                if label_p in self._match_labels:
                    score[label_p] += 1
                    yield from iter_validator_results(
                            label=label_p,
                            validators=self._match_to_validators[label_p],
                            )

                if isinstance(label_p, str): # NOTE: could coerce all values to strings
                    # every pattern is tried: a label may satisfy more than one matcher
                    for match_re in self._match_res:
                        if match_re.search(label_p):
                            score[match_re] += 1
                            yield from iter_validator_results(
                                    label=label_p,
                                    validators=self._match_to_validators[match_re],
                                    )

                for match_set in self._match_sets:
                    if label_p in match_set:
                        score[match_set] += 1
                        yield from iter_validator_results(
                                label=label_p,
                                validators=self._match_to_validators[match_set],
                                )

            for k, count in score.items():
                if count == 0:
                    yield (ERROR_MESSAGE_TYPE,
                            f'Expected label to match {k!r}, no provided match',
                            parent_hints,
                            parent_values,
                            )

    class Apply(Validator):
        '''Apply a function to a container with an arbitrary function. The validation passes if the function returns True (or a truthy value).

        Args:
            func: A function that takes a container and returns a Boolean.
        '''
        __slots__ = ('_func',)

        def __init__(self, func: tp.Callable[..., bool], /):
            self._func: tp.Callable[..., bool] = func

        @staticmethod
        def _prepare_callable(func: tp.Callable[..., bool]) -> str:
            '''Return a display name for ``func``: its ``__name__`` if defined, otherwise its ``repr``.
            '''
            # callables such as functools.partial instances do not define __name__
            name = getattr(func, '__name__', None)
            return name if name is not None else repr(func)

        def _iter_errors(self,
                value: tp.Any,
                hint: tp.Any,
                parent_hints: TParent,
                parent_values: TParent,
                ) -> tp.Iterator[TValidation]:
            '''Yield one error record if the function returns a falsy value when called with ``value``.
            '''
            post = self._func(value)
            if not bool(post):
                yield (ERROR_MESSAGE_TYPE,
                        f'{to_name(type(value))} failed validation with {self._prepare_callable(self._func)}',
                        parent_hints,
                        parent_values,
                        )
#-------------------------------------------------------------------------------

def _classes_to_hint(values: tp.Iterable[tp.Any]) -> tp.Any:
    '''Return the class shared by all of ``values``, or a ``Union`` of the distinct classes found. ``values`` must not be empty.
    '''
    # as classes may not be hashable, we key to string name to from a set; this is imperfect
    ut = {v.__class__.__name__: v.__class__ for v in values}
    if len(ut) == 1:
        return next(iter(ut.values()))
    return tp.Union.__getitem__(tuple(ut.values())) # pyright: ignore

def _value_to_hint(value: tp.Any) -> tp.Any: # tp._GenericAlias
    '''Return a type hint that describes ``value``, descending into the components of containers.
    '''
    if isinstance(value, type):
        return tp.Type[value]

    cls = value.__class__

    if isinstance(value, tuple):
        return cls.__class_getitem__(tuple(_value_to_hint(v) for v in value))

    # virtual Sequence types such as bytes, bytearray, and range cannot be subscripted; they fall through to return their class
    if (isinstance(value, Sequence)
            and not isinstance(value, str)
            and hasattr(cls, '__class_getitem__')
            ):
        if not len(value):
            return cls.__class_getitem__(tp.Any) # type: ignore[attr-defined]
        return cls.__class_getitem__(_classes_to_hint(value)) # type: ignore[attr-defined]

    if isinstance(value, MutableMapping):
        if not len(value):
            return cls.__class_getitem__((tp.Any, tp.Any)) # type: ignore[attr-defined]
        kt = _classes_to_hint(value.keys())
        vt = _classes_to_hint(value.values())
        return cls.__class_getitem__((kt, vt)) # type: ignore[attr-defined]

    # --------------------------------------------------------------------------
    # SF containers

    if isinstance(value, Frame):
        hints = [_value_to_hint(value.index), _value_to_hint(value.columns)]
        hints.extend(dt.type().__class__ for dt in value._blocks._iter_dtypes())
        return cls.__class_getitem__(tuple(hints)) # type: ignore

    if isinstance(value, Series):
        return cls[_value_to_hint(value.index), value.dtype.type().__class__] # type: ignore

    # must come before index
    if isinstance(value, IndexDatetime):
        return cls

    if isinstance(value, Index):
        return cls[value.dtype.type().__class__] # type: ignore

    if isinstance(value, IndexHierarchy):
        hints = list(_value_to_hint(value.index_at_depth(i)) for i in range(value.depth))
        return cls.__class_getitem__(tuple(hints)) # type: ignore

    if isinstance(value, (Bus, Yarn)):
        return cls[_value_to_hint(value.index)] # type: ignore

    if isinstance(value, np.dtype):
        return np.dtype.__class_getitem__(value.type().__class__)

    if isinstance(value, np.ndarray):
        return cls.__class_getitem__(_value_to_hint(value.dtype))

    return cls

#-------------------------------------------------------------------------------
# handlers for getting components out of generics

def iter_sequence_checks(
        value: tp.Any,
        hint: tp.Any,
        parent_hints: TParent,
        parent_values: TParent,
        ) -> tp.Iterable[TValidation]:
    '''Yield a check of every element of ``value`` against the single argument of ``hint``.
    '''
    [h_component] = tp.get_args(hint)
    pv_next = parent_values + (value,)
    for v in value:
        yield v, h_component, parent_hints, pv_next

def iter_tuple_checks(
        value: tp.Any,
        hint: tp.Any,
        parent_hints: TParent,
        parent_values: TParent,
        ) -> tp.Iterable[TValidation]:
    '''Yield a check of each element of ``value`` against the arguments of a tuple ``hint``; a trailing `...` applies one hint to all elements, otherwise elements and hints are paired by position and a length mismatch is reported.
    '''
    h_components = tp.get_args(hint)
    if h_components == ((),):
        # some Python versions report the arguments of the empty-tuple hint ``tuple[()]`` as ``((),)``
        h_components = ()

    # the empty-tuple hint has no arguments, so only index when there are some
    if h_components and h_components[-1] is ...:
        if len(h_components) != 2 or h_components[0] is ...:
            yield ERROR_MESSAGE_TYPE, 'Invalid ellipses usage', parent_hints, parent_values
        else:
            # support any number of values using the same hint
            h = h_components[0]
            pv_next = parent_values + (value,)
            for v in value:
                yield v, h, parent_hints, pv_next
    else:
        if (h_len := len(h_components)) != len(value):
            msg = f'Expected tuple length of {h_len}, provided tuple length of {len(value)}'
            yield ERROR_MESSAGE_TYPE, msg, parent_hints, parent_values
        # elements that can be paired with a hint are still checked after a length mismatch
        pv_next = parent_values + (value,)
        for v, h in zip(value, h_components):
            yield v, h, parent_hints, pv_next

def iter_mapping_checks(
        value: tp.Any,
        hint: tp.Any,
        parent_hints: TParent,
        parent_values: TParent,
        ) -> tp.Iterable[TValidation]:
    '''Yield a check of every key against the first argument of ``hint``, then of every value against the second.
    '''
    [h_keys, h_values] = tp.get_args(hint)
    pv_next = parent_values + (value,)

    for v, h in chain(
            zip(value.keys(), repeat(h_keys)),
            zip(value.values(), repeat(h_values)),
            ):
        yield v, h, parent_hints, pv_next

def iter_typeddict_checks(
        value: tp.Any,
        hint: tp.Any,
        parent_hints: TParent,
        parent_values: TParent,
        ) -> tp.Iterable[TValidation]:
    '''Yield a check of each provided key of ``value`` against its hint, an error for each required key that is missing, and an error listing keys that the hint does not define.
    '''
    # hint is a typedict class; returns dict of key name and value
    hints = tp.get_type_hints(hint, include_extras=True)

    total = hint.__total__
    pv_next = parent_values + (value,)

    for k, hint_key in hints.items(): # iterating hints retains order
        k_found = k in value

        if not total or (is_generic(hint_key) and tp.get_origin(hint_key) is tp.NotRequired):
            required = False
        else:
            required = True

        if k_found:
            yield (value[k],
                    hint_key,
                    parent_hints + (f'Key {k!r}',),
                    pv_next,
                    )
        elif not k_found and required:
            yield (ERROR_MESSAGE_TYPE,
                    'Expected key not provided',
                    parent_hints + (f'Key {k!r}',),
                    pv_next,
                    )

    # get over-specified keys
    if keys_os := value.keys() - hints.keys():
        yield (ERROR_MESSAGE_TYPE,
                f"Keys provided not expected: {', '.join(f'{k!r}' for k in sorted(keys_os))}",
                parent_hints,
                pv_next,
                )

# NOTE: For SF containers, we create an instance of dtype.type() so as to not modify h_generic, as it might be Union or other generic that cannot be wrapped in a tp.Type. This returns a "sample" instance of the type that can be used for testing. Caching this value does not seem a benefit.
def iter_series_checks(
        value: tp.Any,
        hint: tp.Any,
        parent_hints: TParent,
        parent_values: TParent,
        ) -> tp.Iterable[TValidation]:
    '''Yield check records for a Series: its index against the first hint component, and a sample instance of its dtype's scalar type against the second.
    '''
    h_index, h_generic = tp.get_args(hint) # there must be two
    pv_next = parent_values + (value,)

    yield value.index, h_index, parent_hints, pv_next
    yield value.dtype.type(), h_generic, parent_hints, pv_next

def iter_index_checks(
        value: tp.Any,
        hint: tp.Any,
        parent_hints: TParent,
        parent_values: TParent,
        ) -> tp.Iterable[TValidation]:
    '''Yield a check record pairing a sample instance of the Index dtype's scalar type with the single hint component.
    '''
    [h_generic] = tp.get_args(hint)
    pv_next = parent_values + (value,)

    yield value.dtype.type(), h_generic, parent_hints, pv_next

def iter_index_hierarchy_checks(
        value: tp.Any,
        hint: tp.Any,
        parent_hints: TParent,
        parent_values: TParent,
        ) -> tp.Iterable[TValidation]:
    '''Yield check records pairing each depth of an IndexHierarchy with the hint components.

    Three forms of hint are handled: a lone ``Unpack`` (all depths checked as one tuple), no ``Unpack`` (depths paired positionally, count must match), and one ``Unpack`` among other components (depths before and after are paired positionally, the depths in between are checked as a tuple). Count mismatches are reported with ``ERROR_MESSAGE_TYPE`` records.
    '''
    h_generics = tp.get_args(hint)
    h_len = len(h_generics)
    pv_next = parent_values + (value,)

    unpack_pos = -1
    for i, h in enumerate(h_generics):
        # must get_origin to id unpack; origin of simple types is None
        if is_unpack(tp.get_origin(h), h):
            unpack_pos = i
            break

    if h_len == 1 and unpack_pos == 0:
        [h_tuple] = get_args_unpack(h_generics[0])
        # to support usage of Ellipses, treat this as a hint of a corresponding tuple of types
        yield (tuple(value.index_at_depth(i) for i in range(value.depth)),
                h_tuple,
                parent_hints,
                pv_next,
                )
    else:
        col_count = value.shape[1]
        if unpack_pos == -1: # no unpack
            if h_len != col_count:
                # if no unpack and lengths are not equal
                yield (ERROR_MESSAGE_TYPE,
                        f'Expected IndexHierarchy has {h_len} depth, provided IndexHierarchy has {col_count} depth',
                        parent_hints,
                        pv_next,
                        )
            else:
                yield from zip(
                        (value.index_at_depth(i) for i in range(value.depth)),
                        h_generics,
                        repeat(parent_hints),
                        repeat(pv_next),
                        )
        else: # unpack found
            if h_len > col_count + 1:
                # if 1 unpack, there cannot be more than width h_generics + 1 for the Unpack
                yield (ERROR_MESSAGE_TYPE,
                        f'Expected IndexHierarchy has {h_len - 1} depth (excluding Unpack), provided IndexHierarchy has {col_count} depth',
                        parent_hints,
                        pv_next,
                        )
            else:
                # number of hints after the Unpack, and the first depth position after the unpacked region
                h_types_post = h_len - unpack_pos - 1
                depth_post_unpack = col_count - h_types_post

                indexes = tuple(value.index_at_depth(i) for i in range(value.depth))
                index_pre = indexes[:unpack_pos]
                index_unpack = indexes[unpack_pos: depth_post_unpack]
                index_post = indexes[depth_post_unpack:]

                h_pre = h_generics[:unpack_pos]
                h_unpack = h_generics[unpack_pos]
                h_post = h_generics[unpack_pos + 1:]

                # given the guard above, the pre and post segments always have as many depths as hints
                for col_pos, (index, h) in enumerate(zip(index_pre, h_pre)):
                    # the hierarchy itself is included in the parent values, as for all other records
                    yield (index,
                            h,
                            parent_hints + (f'Depth {col_pos}',),
                            pv_next,
                            )

                if index_unpack: # we may not have depths to compare if fewer
                    [h_tuple] = get_args_unpack(h_unpack)
                    yield (tuple(index_unpack),
                            h_tuple,
                            parent_hints + (f'Depths {unpack_pos} to {depth_post_unpack - 1}',),
                            pv_next,
                            )

                for col_pos, (index, h) in enumerate(
                        zip(index_post, h_post),
                        start=depth_post_unpack,
                        ):
                    yield (index,
                            h,
                            parent_hints + (f'Depth {col_pos}',),
                            pv_next,
                            )

def iter_frame_checks(
        value: tp.Any,
        hint: tp.Any,
        parent_hints: TParent,
        parent_values: TParent,
        ) -> tp.Iterable[TValidation]:
    '''Yield check records for a Frame: its index, its columns, and a sample instance of each column dtype's scalar type against the remaining hint components.

    The dtype components follow the same three forms as for IndexHierarchy: a lone ``Unpack``, no ``Unpack`` (count must match), or one ``Unpack`` among other components. Count mismatches are reported with ``ERROR_MESSAGE_TYPE`` records.
    '''
    # NOTE: at runtime TypeVarTuple defaults do not return anything
    h_index, h_columns, *h_types = tp.get_args(hint)
    h_types_len = len(h_types)
    pv_next = parent_values + (value,)

    yield value.index, h_index, parent_hints, pv_next
    yield value.columns, h_columns, parent_hints, pv_next

    unpack_pos = -1
    for i, h in enumerate(h_types):
        # must get_origin to id unpack; origin of simple types is None
        if is_unpack(tp.get_origin(h), h):
            unpack_pos = i
            break

    if h_types_len == 1 and unpack_pos == 0:
        [h_tuple] = get_args_unpack(h_types[0])
        # to support usage of Ellipses, treat this as a hint of a corresponding tuple of types
        yield (tuple(d.type() for d in value._blocks._iter_dtypes()),
                h_tuple,
                parent_hints,
                pv_next,
                )
    else:
        col_count = value.shape[1]
        if unpack_pos == -1: # no unpack
            if h_types_len != col_count:
                # if no unpack and lengths are not equal
                yield (ERROR_MESSAGE_TYPE,
                        f'Expected Frame has {h_types_len} dtype, provided Frame has {col_count} dtype',
                        parent_hints,
                        pv_next,
                        )
            else:
                for dt, h in zip(value._blocks._iter_dtypes(), h_types):
                    yield dt.type(), h, parent_hints, pv_next
        else: # unpack found
            if h_types_len > col_count + 1:
                # if 1 unpack, there cannot be more than width h_types + 1 for the Unpack
                yield (ERROR_MESSAGE_TYPE,
                        f'Expected Frame has {h_types_len - 1} dtype (excluding Unpack), provided Frame has {col_count} dtype',
                        parent_hints,
                        pv_next,
                        )
            else:
                # in terms of column position, we need to find the first column position after the unpack
                # if unpack pos is 0, and 5 types, 4 are post
                # if unpack pos is 3, and 5 types, 1 is post
                h_types_post = h_types_len - unpack_pos - 1
                col_post_unpack = col_count - h_types_post

                dts = tuple(value._blocks._iter_dtypes())
                dt_pre = dts[:unpack_pos]
                dt_unpack = dts[unpack_pos: col_post_unpack]
                dt_post = dts[col_post_unpack:]

                h_pre = h_types[:unpack_pos]
                h_unpack = h_types[unpack_pos]
                h_post = h_types[unpack_pos + 1:]

                # given the guard above, the pre and post segments always have as many dtypes as hints
                for col_pos, (dt, h) in enumerate(zip(dt_pre, h_pre)):
                    yield (dt.type(),
                            h,
                            parent_hints + (f'Field {col_pos}',),
                            pv_next,
                            )

                if dt_unpack: # we may not have dtypes to compare if fewer
                    [h_tuple] = get_args_unpack(h_unpack)
                    yield (tuple(d.type() for d in dt_unpack),
                            h_tuple,
                            parent_hints + (f'Fields {unpack_pos} to {col_post_unpack - 1}',),
                            pv_next,
                            )

                for col_pos, (dt, h) in enumerate(
                        zip(dt_post, h_post),
                        start=col_post_unpack,
                        ):
                    yield (dt.type(),
                            h,
                            parent_hints + (f'Field {col_pos}',),
                            pv_next,
                            )

def iter_bus_checks(
        value: tp.Any,
        hint: tp.Any,
        parent_hints: TParent,
        parent_values: TParent,
        ) -> tp.Iterable[TValidation]:
    '''Yield a check record pairing the index of a Bus with the single hint component.
    '''
    [h_index] = tp.get_args(hint) # there must be one
    pv_next = parent_values + (value,)
    yield value.index, h_index, parent_hints, pv_next

def iter_yarn_checks(
        value: tp.Any,
        hint: tp.Any,
        parent_hints: TParent,
        parent_values: TParent,
        ) -> tp.Iterable[TValidation]:
    '''Yield a check record pairing the index of a Yarn with the single hint component.
    '''
    [h_index] = tp.get_args(hint) # there must be one
    pv_next = parent_values + (value,)
    yield value.index, h_index, parent_hints, pv_next

def iter_ndarray_checks(
        value: tp.Any,
        hint: tp.Any,
        parent_hints: TParent,
        parent_values: TParent,
        ) -> tp.Iterable[TValidation]:
    '''Yield a check record pairing the dtype of an array with the dtype component of the hint.
    '''
    # the hint has two components; the first (shape) is not validated here
    _, h_dtype = tp.get_args(hint)
    pv_next = parent_values + (value,)
    yield value.dtype, h_dtype, parent_hints, pv_next

def iter_dtype_checks(
        value: tp.Any,
        hint: tp.Any,
        parent_hints: TParent,
        parent_values: TParent,
        ) -> tp.Iterable[TValidation]:
    '''Yield a check record pairing a sample instance of the dtype's scalar type with the single hint component.
    '''
    [h_generic] = tp.get_args(hint)
    pv_next = parent_values + (value,)
    yield value.type(), h_generic, parent_hints, pv_next

def iter_np_generic_checks(
        value: tp.Any,
        hint: tp.Any,
        parent_hints: TParent,
        parent_values: TParent,
        ) -> tp.Iterable[TValidation]:
    '''Yield one check record per hint component, each paired with the same NumPy scalar value.
    '''
    # we have already confirmed that value is an instance of the origin type
    h_components = tp.get_args(hint)
    pv_next = parent_values + (value,)
    # there are two components when we have a complexfloating
    for h_component in h_components:
        yield value, h_component, parent_hints, pv_next

# Mapping from a bit count to the Literal hint used to validate it.
BIT_TO_LITERAL = {
    8: tp.Literal[8],
    16: tp.Literal[16],
    32: tp.Literal[32],
    64: tp.Literal[64],
    80: tp.Literal[80],
    96: tp.Literal[96],
    128: tp.Literal[128],
    256: tp.Literal[256],
    # 512: tp.Literal[512], # not defined in numpy/_typing/__init__.py
}

def iter_np_nbit_checks(
        value: tp.Any,
        hint: tp.Any,
        parent_hints: TParent,
        parent_values: TParent,
        ) -> tp.Iterable[TValidation]:
    '''Yield a check record comparing the bit width of a NumPy scalar with the bit width named by an ``NBitBase`` subclass hint.

    A value that is not an ``np.generic`` is reported with an ``ERROR_MESSAGE_TYPE`` record. The expected width is taken from the decimal digits in the hint's class name.
    '''
    if not isinstance(value, np.generic):
        pv_next = parent_values + (value,)
        yield (ERROR_MESSAGE_TYPE,
                f'Expected {hint.__name__}, provided value is not an np.generic',
                parent_hints,
                pv_next,
                )
    else:
        v_bits = value.dtype.itemsize * 8
        pv_next = parent_values + (v_bits,)
        # NumPy uses __init_subclass__ to limit class names to those with the bit number in the name
        h_bits = int(''.join(c for c in hint.__name__ if c.isdecimal()))
        if value.dtype.kind == DTYPE_COMPLEX_KIND:
            # complex is represented with two hints, each half of the whole itemsize; adjust value bits and let each side check independently
            yield v_bits // 2, BIT_TO_LITERAL[h_bits], parent_hints, pv_next
        else:
            yield v_bits, BIT_TO_LITERAL[h_bits], parent_hints, pv_next

#-------------------------------------------------------------------------------

class TypeVarState:
    '''State associated with one ``TypeVar`` over a series of checks: the first value bound to it and, when the bound is a union, the progressively specialized bound.
    '''
    __slots__ = (
            '_var',
            '_value',
            '_bound',
            '_bound_unset',
            '_is_bound_union',
            )

    def __init__(self, var: tp.Any, value: tp.Any):
        '''Provide `TypeVar` and the first-encountered value for that `TypeVar`
        '''
        self._var = var
        self._value = value
        # as bounds might have unions that need specialization, we store the current bound type here and update it if needed
        self._bound = var.__bound__ # might be None
        if self._bound is not None and is_union(self._bound):
            self._is_bound_union = True
            # positions of union components not yet replaced by a value-derived hint
            self._bound_unset = set(range(len(tp.get_args(self._bound))))
        else:
            # NOTE: `_bound_unset` is deliberately left unset; it is only read when the bound is a union
            self._is_bound_union = False

    @property
    def is_bound_union(self) -> bool:
        '''True if the `TypeVar` has a bound and that bound is a union.
        '''
        return self._is_bound_union

    def _specialize_bound_union(self,
            value: tp.Any,
            ) -> None:
        '''If `hint` is a Union, given value that is assigned to the type var, find value in the Union and replace that hint with the hint of the value.
        '''
        components = list(tp.get_args(self._bound))
        # NOTE: this reference to a set is mutated inplace
        for i, hint in enumerate(components):
            if i in self._bound_unset and _check(value, hint).validated:
                components[i] = _value_to_hint(value)
                self._bound_unset.discard(i)
        self._bound = tp.Union.__getitem__(tuple(components)) # pyright: ignore

    @property
    def constraints(self) -> tp.Any:
        '''The ``__constraints__`` of the `TypeVar`.
        '''
        return self._var.__constraints__

    def get_hint(self, value: tp.Any) -> tp.Any:
        '''Return a hint derived from the stored value. If this is a bound union, the value is used to specialize the union.
        '''
        if self.is_bound_union:
            self._specialize_bound_union(value)
            return self._bound
        return _value_to_hint(self._value) # this could be stored on init

class TypeVarRegistry:
    '''Registry mapping each encountered ``TypeVar`` to its ``TypeVarState``.
    '''
    __slots__ = (
            '_id_to_var',
            )

    _id_to_var: tp.Dict[tp.TypeVar, TypeVarState]

    def __init__(self) -> None:
        self._id_to_var = dict()

    def iter_checks(self,
            value: tp.Any,
            var: tp.TypeVar,
            parent_hints: TParent,
            parent_values: TParent,
            ) -> tp.Iterator[TValidation]:
        '''Yield check records for a value assigned to a ``TypeVar``.

        On first encounter the value is stored and, if the `TypeVar` has constraints, checked against them. In all cases the value is then checked against the hint returned by the stored state.
        '''
        pv_next = parent_values + (value,)

        if var not in self._id_to_var:
            tvs = TypeVarState(var, value)
            self._id_to_var[var] = tvs
            if hints := tvs.constraints:
                # with constraints we select one option and use it for the life of the TypeVar; on the first value, check that the value meets the constraints (recast as a union); subsequent checks will be based on the stored value
                yield value, tp.Union.__getitem__(hints), parent_hints, pv_next # pyright: ignore
        else:
            tvs = self._id_to_var[var]

        yield value, tvs.get_hint(value), parent_hints, pv_next

#-------------------------------------------------------------------------------

def _check(
        value: tp.Any,
        hint: tp.Any,
        tvr: tp.Optional[TypeVarRegistry] = None,
        parent_hints: TParent = (),
        parent_values: TParent = (),
        fail_fast: bool = False,
        ) -> ClinicResult:
    '''Check a value against a hint, returning a ``ClinicResult`` built from the log of all failures found.

    Checks are processed from a queue; handlers for generic containers add component checks to the queue, or error records (first element ``ERROR_MESSAGE_TYPE``) directly to the error log.

    Args:
        tvr: Registry holding ``TypeVar`` state; if None, ``TypeVar`` hints are ignored.
        parent_hints: Hints (and labels) enclosing this check, used in error records.
        parent_values: Values enclosing this check, used in error records.
        fail_fast: If True, return as soon as the error log is non-empty.
    '''
    # Check queue: queue all checks
    q = deque(((value, hint, parent_hints, parent_values),))
    # Error log: any entry is considered an error
    e_log: tp.List[TValidation] = []

    def tee_error_or_check(records: tp.Iterable[TValidation]) -> None:
        '''Route error records to the error log and all other records to the check queue.
        '''
        for record in records:
            if record[0] is ERROR_MESSAGE_TYPE:
                e_log.append(record)
            else:
                q.append(record)

    while q:
        if fail_fast and e_log:
            return ClinicResult(e_log)

        v, h, ph, pv = q.popleft()
        # an ERROR_MESSAGE_TYPE should only be used as a place holder in error logs, not queued checks
        assert v is not ERROR_MESSAGE_TYPE

        if h is tp.Any:
            continue

        ph_next = ph + (h,)

        if is_union(h):
            # NOTE: must check union types first as tp.Union matches as generic type
            u_log: tp.List[TValidation] = []
            for c_hint in tp.get_args(h): # get components
                # handling one pair at a time with a secondary call will allow nested types in the union to be evaluated on their own
                c_log = _check(v, c_hint, tvr, ph_next, pv, fail_fast)
                if not c_log: # no error found, can exit
                    break
                u_log.extend(c_log)
            else: # no breaks, so no matches within union
                e_log.extend(u_log)
        elif isinstance(h, Validator):
            e_log.extend(h._iter_errors(v, h, ph_next, pv))
        elif isinstance(h, tp.TypeVar):
            if tvr is not None:
                tee_error_or_check(tvr.iter_checks(v, h, ph_next, pv))
            else: # ignore typevar
                continue
        elif is_generic(h):
            origin = tp.get_origin(h)
            if origin is type: # a tp.Type[x] generic
                [t] = tp.get_args(h)
                try: # the v should be a subclass of t
                    t_check = issubclass(v, t)
                except TypeError:
                    # v is not a class, or t cannot be used with issubclass
                    t_check = False
                if t_check:
                    continue
                e_log.append((v, h, ph, pv))
            elif origin == tp.Annotated: # NOTE: cannot use `is` due backwards compat
                h_type, *h_annotations = tp.get_args(h)
                # perform the un-annotated check
                q.append((v, h_type, ph_next, pv))
                # only Validator annotations are checked; other metadata is ignored
                for h_annotation in h_annotations:
                    if isinstance(h_annotation, Validator):
                        q.append((v, h_annotation, ph_next, pv))
            elif origin == tp.Literal: # NOTE: cannot use `is` due backwards compat
                l_log: tp.List[TValidation] = []
                for l_hint in tp.get_args(h): # get components
                    c_log = _check(v, l_hint, tvr, ph_next, pv, fail_fast)
                    if not c_log: # no error found, can exit
                        break
                    l_log.extend(c_log)
                else: # no breaks, so no matches within literal
                    e_log.extend(l_log)
            elif origin == tp.Required or origin == tp.NotRequired:
                # semantics handled elsewhere, just unpack
                [h_type] = tp.get_args(h)
                q.append((v, h_type, ph_next, pv))
            else:
                # NOTE: many generic containers require recursing into component values. It is in these functions below that parent_values is updated and yielded back into the queue. There are many other cases where parent_values does not need to be updated (and for efficiency is not).
                if not isinstance(v, origin):
                    e_log.append((v, origin, ph, pv))
                    continue

                # NOTE: the order of these tests determines which handler is used when a value matches more than one
                if isinstance(v, tuple):
                    tee_error_or_check(iter_tuple_checks(v, h, ph_next, pv))
                elif isinstance(v, Sequence):
                    tee_error_or_check(iter_sequence_checks(v, h, ph_next, pv))
                elif isinstance(v, MutableMapping):
                    tee_error_or_check(iter_mapping_checks(v, h, ph_next, pv))
                elif isinstance(v, Index):
                    tee_error_or_check(iter_index_checks(v, h, ph_next, pv))
                elif isinstance(v, IndexHierarchy):
                    tee_error_or_check(iter_index_hierarchy_checks(v, h, ph_next, pv))
                elif isinstance(v, Series):
                    tee_error_or_check(iter_series_checks(v, h, ph_next, pv))
                elif isinstance(v, Frame):
                    tee_error_or_check(iter_frame_checks(v, h, ph_next, pv))
                elif isinstance(v, Bus):
                    tee_error_or_check(iter_bus_checks(v, h, ph_next, pv))
                elif isinstance(v, Yarn):
                    tee_error_or_check(iter_yarn_checks(v, h, ph_next, pv))
                elif isinstance(v, np.ndarray):
                    tee_error_or_check(iter_ndarray_checks(v, h, ph_next, pv))
                elif isinstance(v, np.dtype):
                    tee_error_or_check(iter_dtype_checks(v, h, ph_next, pv))
                elif isinstance(v, np.generic):
                    tee_error_or_check(iter_np_generic_checks(v, h, ph_next, pv))
                else:
                    raise NotImplementedError(f'no handling for generic {origin}') #pragma: no cover
        elif tp.is_typeddict(h):
            tee_error_or_check(iter_typeddict_checks(v, h, ph_next, pv))
        elif not isinstance(h, type):
            # h is a value from a literal
            # must check type: https://peps.python.org/pep-0586/#equivalence-of-two-literals
            if type(v) != type(h) or v != h: # pylint: disable=C0123
                e_log.append((v, h, ph, pv))
        # h is a class
        elif issubclass(h, NBitBase):
            tee_error_or_check(iter_np_nbit_checks(v, h, ph_next, pv))
        else:
            # h is non-generic type, must continue if valid
            if v.__class__ is bool:
                # a bool value is only accepted by a bool hint; it falls through to an error for any other hint
                if h is bool:
                    continue
            elif h is np.object_ and v is None:
                # when comparing object dtypes our test value is None, which is not an instance of np.object_
                continue
            # general case
            elif isinstance(v, h):
                continue
            e_log.append((v, h, ph, pv))

    return ClinicResult(e_log)

#-------------------------------------------------------------------------------
# public interfaces
class TypeClinic:
    '''Wrap (almost) any object to derive its type hint (or a string form of that hint), or to test the object against a caller-supplied hint.
    '''
    __slots__ = ('_value',)

    _INTERFACE = (
        'to_hint',
        'check',
        'warn',
        '__call__',
        '__repr__',
        )

    def __init__(self, value: tp.Any, /):
        self._value = value

    def to_hint(self) -> tp.Any:
        '''Return the type hint (the type and/or generic aliases necessary) that represents the object given at initialization.
        '''
        # NOTE: this can cache as value assumed immutable
        hint = _value_to_hint(self._value)
        return hint

    def __repr__(self) -> str:
        '''Return a compact string form of the type hint that represents the object given at initialization.
        '''
        hint = self.to_hint()
        return to_name(hint)

    def check(self,
            hint: tp.Any,
            /,
            *,
            fail_fast: bool = False,
            ) -> None:
        '''Test the object against a hint (a type and/or generic alias); if any error is found, raise a ``ClinicError`` describing the result.

        Args:
            fail_fast: If True, return on first failure. If False, all failures are discovered and reported.
        '''
        result = self(hint, fail_fast=fail_fast)
        if result:
            raise ClinicError(result)

    def warn(self,
            hint: tp.Any,
            /,
            *,
            fail_fast: bool = False,
            category: tp.Type[Warning] = UserWarning,
            ) -> None:
        '''Test the object against a hint (a type and/or generic alias); if any error is found, issue a warning describing the result.

        Args:
            fail_fast: If True, return on first failure. If False, all failures are discovered and reported.
            category: The ``Warning`` subclass to be used for issuing the warning.
        '''
        result = self(hint, fail_fast=fail_fast)
        if result:
            warnings.warn(result.to_str(), category)

    def __call__(self,
            hint: tp.Any,
            /,
            *,
            fail_fast: bool = False,
            ) -> ClinicResult:
        '''Test the object against a hint (a type and/or generic alias) and return a ``ClinicResult`` describing the result.

        Args:
            fail_fast: If True, return on first failure. If False, all failures are discovered and reported.
        '''
        # a fresh registry per call keeps TypeVar state scoped to this one check
        return _check(
                self._value,
                hint,
                TypeVarRegistry(),
                fail_fast=fail_fast,
                )
class ErrorAction(Enum):
    '''What to do when a check of an argument or return value finds an error.
    '''
    RAISE = 0
    WARN = 1
    RETURN = 2

def _check_interface(
        func: tp.Callable[..., tp.Any],
        args: tp.Any,
        kwargs: tp.Any,
        fail_fast: bool,
        error_action: ErrorAction,
        category: tp.Type[Warning] = UserWarning,
        ) -> tp.Any:
    '''Check the annotated arguments of a call, call the function, then check the annotated return value.

    On a failed check, ``error_action`` selects the response: raise a ``ClinicError``, issue a warning of ``category`` and continue, or return the failing ``ClinicResult`` in place of the function result. With the return action, an argument failure is returned before the function is called.
    '''
    # include_extras insures that Annotated generics are returned
    hints = tp.get_type_hints(func, include_extras=True)

    sig_bound = Signature.from_callable(func).bind(*args, **kwargs)
    sig_bound.apply_defaults()
    sig_str = to_signature(sig_bound, hints)

    args_context = (f'args of {sig_str}',)
    parent_values = (func,)

    # NOTE: we create one TV registry per check, so state associated with typevars will be bound by the context of one function
    tvr = TypeVarRegistry()

    for name, arg in sig_bound.arguments.items():
        h_arg = hints.get(name, None)
        if not h_arg: # parameters without a hint are not checked
            continue
        result = _check(arg,
                h_arg,
                tvr,
                args_context + (f'In arg {name}',),
                parent_values,
                fail_fast=fail_fast,
                )
        if not result:
            continue
        if error_action is ErrorAction.RAISE:
            raise ClinicError(result)
        if error_action is ErrorAction.WARN:
            warnings.warn(result.to_str(), category)
        elif error_action is ErrorAction.RETURN:
            return result

    post = func(*args, **kwargs)

    h_return = hints.get('return', None)
    if not h_return: # no return hint: nothing more to check
        return post

    result = _check(post,
            h_return,
            tvr,
            (f'return of {sig_str}',),
            parent_values,
            fail_fast=fail_fast,
            )
    if result:
        if error_action is ErrorAction.RAISE:
            raise ClinicError(result)
        if error_action is ErrorAction.WARN:
            warnings.warn(result.to_str(), category)
        elif error_action is ErrorAction.RETURN:
            return result

    return post

TVFunc = tp.TypeVar('TVFunc', bound=tp.Callable[..., tp.Any])
class CallGuard:
    '''Decorators that apply run-time type checking and data validation to function calls.
    '''

    @tp.overload
    @staticmethod
    def check(func: TVFunc, /) -> TVFunc: ...

    @tp.overload
    @staticmethod
    def check(*, fail_fast: bool) -> tp.Callable[[TVFunc], TVFunc]: ...

    @tp.overload
    @staticmethod
    def check(func: None, /, *, fail_fast: bool) -> tp.Callable[[TVFunc], TVFunc]: ...

    @staticmethod
    def check(
            func: TVFunc | None = None,
            /,
            *,
            fail_fast: bool = False,
            ) -> tp.Any:
        '''Decorate a function so that each call checks arguments and return value against the function's type annotations, including type hints and ``Require``-provided validators. Raises ``ClinicError`` on failure.

        Usable directly (``@CallGuard.check``) or with arguments (``@CallGuard.check(fail_fast=True)``).
        '''
        def decorator(func: TVFunc) -> TVFunc:
            @wraps(func)
            def wrapper(*args: tp.Any, **kwargs: tp.Any) -> tp.Any:
                return _check_interface(
                        func,
                        args,
                        kwargs,
                        fail_fast=fail_fast,
                        error_action=ErrorAction.RAISE,
                        )
            return tp.cast(TVFunc, wrapper)

        # without a function, return the decorator itself for later application
        return decorator if func is None else decorator(func)

    @tp.overload
    @staticmethod
    def warn(func: TVFunc, /) -> TVFunc: ...

    @tp.overload
    @staticmethod
    def warn(*, fail_fast: bool, category: tp.Type[Warning]) -> tp.Callable[[TVFunc], TVFunc]: ...

    @tp.overload
    @staticmethod
    def warn(func: None, /, *, fail_fast: bool, category: tp.Type[Warning]) -> tp.Callable[[TVFunc], TVFunc]: ...

    @staticmethod
    def warn(
            func: TVFunc | None = None,
            /,
            *,
            fail_fast: bool = False,
            category: tp.Type[Warning] = UserWarning,
            ) -> tp.Any:
        '''Decorate a function so that each call checks arguments and return value against the function's type annotations, including type hints and ``Require``-provided validators. Issues a warning of ``category`` on failure.

        Usable directly (``@CallGuard.warn``) or with arguments (``@CallGuard.warn(fail_fast=True, category=...)``).
        '''
        def decorator(func: TVFunc) -> TVFunc:
            @wraps(func)
            def wrapper(*args: tp.Any, **kwargs: tp.Any) -> tp.Any:
                return _check_interface(
                        func,
                        args,
                        kwargs,
                        fail_fast=fail_fast,
                        error_action=ErrorAction.WARN,
                        category=category,
                        )
            return tp.cast(TVFunc, wrapper)

        # without a function, return the decorator itself for later application
        return decorator if func is None else decorator(func)