diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py index 3b7dbfba..90cd4c4b 100644 --- a/manimlib/utils/iterables.py +++ b/manimlib/utils/iterables.py @@ -136,6 +136,18 @@ def array_is_constant(arr: np.ndarray) -> bool: return len(arr) > 0 and (arr == arr[0]).all() +def cartesian_product(*arrays: np.ndarray): + """ + Copied from https://stackoverflow.com/a/11146645 + """ + la = len(arrays) + dtype = np.result_type(*arrays) + arr = np.empty([len(a) for a in arrays] + [la], dtype=dtype) + for i, a in enumerate(np.ix_(*arrays)): + arr[..., i] = a + return arr.reshape(-1, la) + + def hash_obj(obj: object) -> int: if isinstance(obj, dict): return hash(tuple(sorted([