Coverage for muutils/group_equiv.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-04-04 03:33 -0600

1"group items by assuming that `eq_func` defines an equivalence relation" 

2 

3from __future__ import annotations 

4 

5from itertools import chain 

6from typing import Callable, Sequence, TypeVar 

7 

# generic type of the items grouped by `group_by_equivalence`
T = TypeVar("T")

9 

10 

def group_by_equivalence(
    items_in: Sequence[T],
    eq_func: Callable[[T, T], bool],
) -> list[list[T]]:
    """group items into the classes of the equivalence relation generated by `eq_func`

    `eq_func` need not be transitive: if `eq_func(a, b)` and `eq_func(b, c)` are
    true but `eq_func(a, c)` is false, `a`, `b`, and `c` still end up in the same
    class. In graph terms, the returned classes are the connected components of
    the graph whose edges are the pairs for which `eq_func` returns true.

    note that lists are used to avoid the need for hashable items, and to allow for duplicates

    # Arguments
     - `items_in: Sequence[T]`
        the items to group, processed in order
     - `eq_func: Callable[[T, T], bool]`
        returns true if two items are equivalent. need not be transitive.
        it is only ever called as `eq_func(new_item, earlier_item)`, where
        `earlier_item` precedes `new_item` in `items_in`, so it should be
        symmetric if the grouping is to be independent of input order.
        at most `n * (n - 1) / 2` calls are made for `n` items.

    # Returns
     - `list[list[T]]`
        the classes; every input item appears in exactly one of them.
        when an item links several existing classes, those classes are
        concatenated (in their current order) followed by the item, and the
        merged class is moved to the end of the returned list.
    """
    output: list[list[T]] = []

    for x in items_in:
        # indices of every existing class containing at least one item related to `x`.
        # all classes are checked (no early exit) so that every class `x` links gets merged
        found_classes: list[int] = [
            i for i, c in enumerate(output) if any(eq_func(x, y) for y in c)
        ]

        if not found_classes:
            # no class found, start a new one
            output.append([x])
        elif len(found_classes) == 1:
            # exactly one class found, add to it
            output[found_classes[0]].append(x)
        else:
            # `x` links several classes: keep the untouched classes in order, then
            # append one merged class (linked classes in index order, followed by `x`)
            merge_set: set[int] = set(found_classes)
            output_new: list[list[T]] = [
                c for i, c in enumerate(output) if i not in merge_set
            ]
            merged: list[T] = list(
                chain.from_iterable(output[i] for i in found_classes)
            )
            merged.append(x)
            output_new.append(merged)
            output = output_new

    return output