Coverage for muutils/misc/classes.py: 78%
23 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
1from __future__ import annotations
3from typing import (
4 Iterable,
5 Any,
6 Protocol,
7 ClassVar,
8 runtime_checkable,
9)
11from muutils.misc.sequence import flatten
def is_abstract(cls: type) -> bool:
    """Return whether `cls` is an abstract class.

    A class counts as abstract only when it exposes a non-empty
    `__abstractmethods__` attribute.

    # Parameters
     - `cls`: The class to inspect

    # Returns
    `False` for an ordinary class (no `__abstractmethods__` attribute) and for
    a concrete implementation of an abstract class (empty
    `__abstractmethods__`); `True` otherwise.
    """
    # The empty-tuple default stands in for ordinary classes, which lack the attribute.
    return len(getattr(cls, "__abstractmethods__", ())) > 0
def get_all_subclasses(class_: type, include_self: bool = False) -> set[type]:
    """
    Returns a set containing all child classes in the subclass graph of `class_`.
    I.e., includes subclasses of subclasses, etc.

    # Parameters
     - `class_`: Superclass
     - `include_self`: Whether to include `class_` itself in the returned set

    # Returns
    The set of every direct and indirect subclass of `class_`
    (plus `class_` when `include_self` is true).

    # Development
    The subclass graph is walked iteratively with a visited set, so a class
    reachable through several parents (diamond inheritance) is expanded only once.
    """
    found: set[type] = set()
    # `type.__subclasses__(cls)` is used instead of `cls.__subclasses__()` so that
    # metaclasses encountered in the graph (where the attribute is an unbound
    # method descriptor) do not raise `TypeError`.
    pending: list[type] = list(type.__subclasses__(class_))
    while pending:
        current = pending.pop()
        if current in found:
            continue  # already expanded via another parent
        found.add(current)
        pending.extend(type.__subclasses__(current))
    if include_self:
        found.add(class_)
    return found
def isinstance_by_type_name(o: object, type_name: str) -> bool:
    """Behaves like stdlib `isinstance` except it accepts a string representation of the type rather than the type itself.
    This is a hacky function intended to circumvent the need to import a type into a module.
    It is susceptible to type name collisions.

    # Parameters
     - `o`: Object (not the type itself) whose type to interrogate
     - `type_name`: The string returned by `type_.__name__`.
       Generic types are not supported, only types that would appear in `type_.__mro__`.

    # Returns
    `True` if any class in `type(o).__mro__` has `__name__` equal to `type_name`.
    """
    # Stops at the first matching class instead of building a set of every name.
    return any(klass.__name__ == type_name for klass in type(o).__mro__)
64# dataclass magic
65# --------------------------------------------------------------------------------
@runtime_checkable
class IsDataclass(Protocol):
    """Generic structural type for any dataclass instance.

    An object satisfies this protocol when it exposes a `__dataclass_fields__`
    attribute. The protocol is decorated with `@runtime_checkable`, so it can
    be used with `isinstance`.

    Reference:
    https://stackoverflow.com/questions/54668000/type-hint-for-an-instance-of-a-non-specific-dataclass
    """

    __dataclass_fields__: ClassVar[dict[str, Any]]
def get_hashable_eq_attrs(dc: IsDataclass) -> tuple[Any, ...]:
    """Returns a tuple of all fields used for equality comparison, including the type of the dataclass itself.
    The type is included to preserve the unequal equality behavior of instances of different dataclasses whose fields are identical.
    Essentially used to generate a hashable dataclass representation for equality comparison even if it's not frozen.

    # Parameters
     - `dc`: A dataclass instance

    # Returns
    The values of every field declared with `compare=True`, in field order,
    followed by `type(dc)` as the final element.

    # Raises
     - `TypeError`: if `dc` is not a dataclass (raised by `dataclasses.fields`)
    """
    # Local import: the file's top-level import block does not bring in `dataclasses`.
    from dataclasses import fields

    # `fields()` yields only real fields. The raw `__dataclass_fields__` mapping
    # also holds `ClassVar`/`InitVar` pseudo-fields, which take no part in
    # dataclass equality and may not exist as attributes on the instance.
    return (
        *(getattr(dc, fld.name) for fld in fields(dc) if fld.compare),
        type(dc),
    )
def dataclass_set_equals(
    coll1: Iterable[IsDataclass], coll2: Iterable[IsDataclass]
) -> bool:
    """Compares 2 collections of dataclass instances as if they were sets.
    Duplicates are ignored in the same manner as a set.
    Unfrozen dataclasses can't be placed in sets since they're not hashable.
    Collections of them may be compared using this function.

    # Parameters
     - `coll1`: First collection of dataclass instances
     - `coll2`: Second collection of dataclass instances

    # Returns
    `True` if both collections reduce to the same set of
    `get_hashable_eq_attrs` tuples. Those tuples are stored in sets, so the
    compared field values must themselves be hashable.
    """
    return {get_hashable_eq_attrs(x) for x in coll1} == {
        get_hashable_eq_attrs(y) for y in coll2
    }