Coverage for muutils/misc/classes.py: 78%

23 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1from __future__ import annotations 

2 

3from typing import ( 

4 Iterable, 

5 Any, 

6 Protocol, 

7 ClassVar, 

8 runtime_checkable, 

9) 

10 

11from muutils.misc.sequence import flatten 

12 

13 

def is_abstract(cls: type) -> bool:
    """
    Returns if a class is abstract.

    A class counts as abstract when it exposes a non-empty `__abstractmethods__`
    collection. Ordinary classes (no such attribute) and concrete implementations
    of abstract bases (an empty collection) are not abstract.
    """
    try:
        pending_abstract = cls.__abstractmethods__
    except AttributeError:
        # an ordinary class with no ABC bookkeeping at all
        return False
    # empty -> concrete implementation; non-empty -> still abstract
    return len(pending_abstract) > 0

24 

25 

def get_all_subclasses(class_: type, include_self=False) -> set[type]:
    """
    Returns a set containing all child classes in the subclass graph of `class_`.
    I.e., includes subclasses of subclasses, etc.

    # Parameters
    - `include_self`: Whether to include `class_` itself in the returned set
    - `class_`: Superclass (may also be a metaclass such as a subclass of `type`)

    # Returns
    A set of every direct and indirect subclass of `class_`, plus `class_`
    itself when `include_self` is true.

    # Development
    The subclass graph is walked iteratively with a visited set. Classes reachable
    through several paths (diamond inheritance) are expanded only once, and deep
    hierarchies cannot hit the recursion limit.
    """
    found: set[type] = set()
    pending: list[type] = [class_]
    while pending:
        current = pending.pop()
        # `type.__subclasses__(cls)` works for every class, including metaclasses.
        # `cls.__subclasses__()` resolves to the unbound descriptor when `cls` is
        # `type` or a subclass of it, and then raises TypeError.
        for sub in type.__subclasses__(current):
            if sub not in found:
                found.add(sub)
                pending.append(sub)
    if include_self:
        found.add(class_)
    return found

49 

50 

def isinstance_by_type_name(o: object, type_name: str):
    """Like stdlib `isinstance`, but the type is given by its name instead of the type object.

    This is a deliberately hacky helper that avoids importing a type into a module
    just to test against it. Because only names are compared, two distinct types
    that share a `__name__` are indistinguishable.

    # Parameters
    `o`: Object (not the type itself) whose type to interrogate
    `type_name`: The string returned by `type_.__name__`.
        Generic types are not supported, only types that would appear in `type_.__mro__`.

    # Returns
    True if any class in `type(o).__mro__` has `__name__ == type_name`, else False.
    """
    for ancestor in type(o).__mro__:
        if ancestor.__name__ == type_name:
            return True
    return False

62 

63 

64# dataclass magic 

65# -------------------------------------------------------------------------------- 

66 

67 

@runtime_checkable
class IsDataclass(Protocol):
    """Structural type that matches an instance of any dataclass.

    An object satisfies this protocol when it has a `__dataclass_fields__`
    attribute, which the `dataclasses` machinery places on every dataclass.
    Being `runtime_checkable`, it can be used with `isinstance`.

    Reference:
    https://stackoverflow.com/questions/54668000/type-hint-for-an-instance-of-a-non-specific-dataclass
    """

    __dataclass_fields__: ClassVar[dict[str, Any]]

73 

74 

def get_hashable_eq_attrs(dc: IsDataclass) -> tuple[Any, ...]:
    """Returns a tuple of all fields used for equality comparison, including the type of the dataclass itself.

    The type is appended last. This preserves the unequal equality behavior of
    instances of different dataclasses whose compared fields are identical.
    The result is essentially a hashable representation of the dataclass for
    equality comparison, even when the dataclass is not frozen.

    Only real fields with `compare=True` contribute, in field-definition order.
    ClassVar and InitVar pseudo-fields are excluded.

    # Raises
    - `TypeError` if `dc` is not a dataclass (raised by `dataclasses.fields`).
    """
    # Local import keeps the module's top-level import block unchanged.
    from dataclasses import fields

    # `__dataclass_fields__` also contains ClassVar/InitVar pseudo-fields (with
    # compare=True by default). `dataclasses.fields` returns only real fields, so
    # the pseudo-fields are dropped. An InitVar is not stored on the instance by
    # the generated __init__, so reading it via getattr could raise AttributeError.
    return (
        *(getattr(dc, fld.name) for fld in fields(dc) if fld.compare),
        type(dc),
    )

84 

85 

def dataclass_set_equals(
    coll1: Iterable[IsDataclass], coll2: Iterable[IsDataclass]
) -> bool:
    """Compares 2 collections of dataclass instances as if they were sets.

    Duplicates are ignored, exactly as with a set. Unfrozen dataclasses are not
    hashable and so cannot be put in a set directly. Instead, each instance is
    mapped to its `get_hashable_eq_attrs` key, and the two key sets are compared.
    """
    left_keys = set(map(get_hashable_eq_attrs, coll1))
    right_keys = set(map(get_hashable_eq_attrs, coll2))
    return left_keys == right_keys

97 }