Coverage for muutils/group_equiv.py: 100%
29 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2025-04-04 03:33 -0600
"""group items by treating `eq_func` as generating an equivalence relation

`eq_func` itself need not be transitive: items linked through a chain of
pairwise matches are placed in the same group
"""

from __future__ import annotations

from itertools import chain
from typing import Callable, Sequence, TypeVar

# generic item type: the type of the items passed to `group_by_equivalence`
# and of the items in the groups it returns
T = TypeVar("T")
def group_by_equivalence(
    items_in: Sequence[T],
    eq_func: Callable[[T, T], bool],
) -> list[list[T]]:
    """group items into the classes of the equivalence relation generated by `eq_func`

    `eq_func` need not be transitive: if `eq_func(a, b)` and `eq_func(b, c)` are true
    but `eq_func(a, c)` is false, `a`, `b` and `c` still end up in the same class.
    equivalently, the returned groups are the connected components of the graph
    whose edges are the pairs of items for which `eq_func` returns true

    lists are used throughout, so items need not be hashable and duplicates are kept

    # Arguments
    - `items_in: Sequence[T]`
        the items to group, processed in order. not modified
    - `eq_func: Callable[[T, T], bool]`
        returns true if two items belong in the same class. need not be transitive.
        it is only ever called as `eq_func(later_item, earlier_item)` (by position in
        `items_in`), at most once per pair, so it should be symmetric

    # Returns
    - `list[list[T]]`
        the classes. an item that matches no earlier item starts a new class at the
        end of the list; one that matches exactly one class is appended to it. an
        item that matches several classes merges them: their contents are
        concatenated in their current order, the item is appended, and the merged
        class is moved to the end of the output
    """
    output: list[list[T]] = list()

    for x in items_in:
        # indices (ascending) of every existing class holding an item related to `x`
        found_classes: list[int] = [
            i for i, c in enumerate(output) if any(eq_func(x, y) for y in c)
        ]

        if not found_classes:
            # unrelated to everything seen so far: start a new class
            output.append([x])
        elif len(found_classes) == 1:
            output[found_classes[0]].append(x)
        else:
            # `x` bridges several classes, so by transitivity they collapse into one.
            # untouched classes keep their relative order; the merged class goes last
            found_set: set[int] = set(found_classes)
            merged: list[T] = list(
                chain.from_iterable(output[i] for i in found_classes)
            )
            merged.append(x)
            output = [c for i, c in enumerate(output) if i not in found_set]
            output.append(merged)

    return output