mirror of
https://github.com/3b1b/manim.git
synced 2025-04-13 09:47:07 +00:00
173 lines
4.7 KiB
Python
173 lines
4.7 KiB
Python
from __future__ import annotations
|
|
|
|
from colour import Color
|
|
|
|
import numpy as np
|
|
import random
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from typing import Callable, Iterable, Sequence, TypeVar
|
|
|
|
T = TypeVar("T")
|
|
S = TypeVar("S")
|
|
|
|
|
|
def remove_list_redundancies(lst: Sequence[T]) -> list[T]:
    """
    Drop repeated elements from lst without disturbing the order
    of what remains.
    When an element occurs several times, its last occurrence is
    the one that survives.
    """
    # Walking the sequence backwards makes the dict record each
    # element at the position of its final occurrence
    seen_from_end = dict.fromkeys(reversed(lst))
    deduped = list(seen_from_end)
    # Undo the backwards walk to restore the original ordering
    deduped.reverse()
    return deduped
|
|
|
|
|
|
def list_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]:
    """
    Concatenate l1 and l2 and strip out duplicate elements.
    Used instead of list(set(l1).update(l2)) to maintain order,
    making sure duplicates are removed from l1, not l2.
    """
    combined = list(l1)
    combined.extend(l2)
    return remove_list_redundancies(combined)
|
|
|
|
|
|
def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]:
    """
    Return the elements of l1, in their original order, that do
    not appear in l2.

    l2 may be any iterable.  A one-shot iterator (e.g. a generator)
    is copied into a list first: a membership test consumes such an
    iterator, so every later element of l1 would otherwise be
    checked against an already exhausted l2 and wrongly kept.
    """
    # Objects defining __contains__ (list, tuple, set, dict, range...)
    # can be probed repeatedly, so they are used as they are, which
    # also keeps whatever fast lookup they provide
    excluded = l2 if hasattr(l2, "__contains__") else list(l2)
    return [e for e in l1 if e not in excluded]
|
|
|
|
|
|
def adjacent_n_tuples(objects: Sequence[T], n: int) -> zip[tuple[T, ...]]:
    """
    Zip together n copies of objects, where the k-th copy has its
    first k elements moved to the end.
    For n no larger than len(objects), the i-th tuple is objects[i]
    followed by the next n - 1 elements, wrapping around to the
    start of the sequence.
    """
    rotations = []
    for shift in range(n):
        rotations.append([*objects[shift:], *objects[:shift]])
    return zip(*rotations)
|
|
|
|
|
|
def adjacent_pairs(objects: Sequence[T]) -> zip[tuple[T, T]]:
    """
    Pair every element of objects with the element after it.
    This is adjacent_n_tuples with n fixed to 2, so the last
    element is paired with the first.
    """
    return adjacent_n_tuples(objects, n=2)
|
|
|
|
|
|
def batch_by_property(
    items: Iterable[T],
    property_func: Callable[[T], S]
) -> list[tuple[list[T], S]]:
    """
    Takes in an iterable, and returns a list of tuples, (batch, prop)
    such that all items in a batch have the same output when
    put into property_func, and such that chaining all these
    batches together would give the original list (i.e. order is
    preserved)

    Each batch is a list of consecutive items.  Only neighbouring
    items are grouped, so the same prop value can show up in more
    than one tuple if its items are separated by others.
    property_func is called exactly once per item.
    """
    batch_prop_pairs = []
    curr_batch = []
    # None doubles as the "nothing seen yet" marker.  A leading item
    # whose prop is None still works: it compares equal, and is
    # appended to the (empty) current batch in the else branch
    curr_prop = None
    for item in items:
        prop = property_func(item)
        if prop != curr_prop:
            # The run has ended, so store it unless it is the empty
            # placeholder from before the first item
            if curr_batch:
                batch_prop_pairs.append((curr_batch, curr_prop))
            # Start a new run with this item
            curr_prop = prop
            curr_batch = [item]
        else:
            curr_batch.append(item)
    # Store the final run, which the loop never gets to close
    if curr_batch:
        batch_prop_pairs.append((curr_batch, curr_prop))
    return batch_prop_pairs
|
|
|
|
|
|
def listify(obj: object) -> list:
    """
    Coerce obj into a list.
    A string is kept whole as a single element rather than being
    split into characters.  Any other object is passed to list(),
    and if that raises TypeError the object is wrapped in a
    one-element list instead.
    """
    if not isinstance(obj, str):
        try:
            return list(obj)
        except TypeError:
            # list() rejected it, so fall through and wrap it
            pass
    return [obj]
|
|
|
|
|
|
def shuffled(iterable: Iterable) -> list:
    """
    Return a new list holding the elements of iterable in a random
    order, as produced by random.shuffle.
    The elements are copied into a fresh list first, so the input
    itself is not reordered.
    """
    result = [*iterable]
    random.shuffle(result)
    return result
|
|
|
|
|
|
def resize_array(nparray: np.ndarray, length: int) -> np.ndarray:
    """
    Return nparray resized along its first axis to the given length,
    using np.resize and keeping the remaining dimensions as they are.
    If the length already matches, the very same array object is
    returned rather than a copy.
    """
    if len(nparray) != length:
        trailing_dims = nparray.shape[1:]
        return np.resize(nparray, (length, *trailing_dims))
    return nparray
|
|
|
|
|
|
def resize_preserving_order(nparray: np.ndarray, length: int) -> np.ndarray:
    """
    Resize nparray along its first axis to the given length by
    repeating or dropping entries, never reordering them.

    Output row i is input row (i * len(nparray)) // length, so the
    selected indices are non-decreasing.  If the length already
    matches (and the array is non-empty), the very same array object
    is returned rather than a copy.

    An empty input has nothing to pick from, so it is handed to
    np.resize instead.  The trailing dimensions are passed along so
    that the result has shape (length, *nparray.shape[1:]), the same
    as in the non-empty case.
    """
    if len(nparray) == 0:
        # Keep the trailing dimensions; a bare `length` here would
        # collapse e.g. a (0, 3) array into shape (length,)
        return np.resize(nparray, (length, *nparray.shape[1:]))
    if len(nparray) == length:
        return nparray
    # Integer arithmetic maps each output slot back to a source row
    indices = np.arange(length) * len(nparray) // length
    return nparray[indices]
|
|
|
|
|
|
def resize_with_interpolation(nparray: np.ndarray, length: int) -> np.ndarray:
    """
    Resize nparray along its first axis to the given length by
    linearly interpolating between neighbouring entries.

    If the length already matches, the same array object is returned.
    An array with a single entry, or one that array_is_constant
    reports as constant, is resized by repeating its first entry.
    A target length of 0 otherwise gives an empty array of zeros
    with the same trailing dimensions.
    """
    source_len = len(nparray)
    if source_len == length:
        return nparray
    if source_len == 1 or array_is_constant(nparray):
        return nparray[:1].repeat(length, axis=0)
    if length == 0:
        return np.zeros((0, *nparray.shape[1:]))

    samples = []
    # Fractional positions spread evenly over the source index range
    for ci in np.linspace(0, source_len - 1, length):
        lh = int(ci)
        rh = int(np.ceil(ci))
        # Fractional part: how far the position sits from lh towards rh
        a = ci % 1
        samples.append((1 - a) * nparray[lh] + a * nparray[rh])
    return np.array(samples)
|
|
|
|
|
|
def make_even(
    iterable_1: Sequence[T],
    iterable_2: Sequence[S]
) -> tuple[Sequence[T], Sequence[S]]:
    """
    Bring two sequences to a common length, namely the larger of
    their two lengths.
    If the lengths already agree, the inputs are returned as they
    are.  Otherwise two new lists are returned, where entry n of
    each is taken from index (n * original_length) // new_length of
    the corresponding input, so elements are repeated in order.
    """
    len1, len2 = len(iterable_1), len(iterable_2)
    if len1 == len2:
        return iterable_1, iterable_2
    new_len = max(len1, len2)

    def stretch(seq, size):
        # Map each of the new_len slots back onto an index of seq
        return [seq[(n * size) // new_len] for n in range(new_len)]

    return stretch(iterable_1, len1), stretch(iterable_2, len2)
|
|
|
|
|
|
def arrays_match(arr1: np.ndarray, arr2: np.ndarray) -> bool:
    """
    Tell whether two arrays have the same shape and compare equal
    in every entry.
    The shapes are checked first, so arrays of different shapes are
    never compared elementwise.
    """
    if arr1.shape != arr2.shape:
        return False
    return (arr1 == arr2).all()
|
|
|
|
|
|
def array_is_constant(arr: np.ndarray) -> bool:
    """
    Tell whether every entry along the first axis of arr equals
    the first entry.
    An empty array is reported as not constant.
    """
    if len(arr) == 0:
        return False
    first = arr[0]
    return (arr == first).all()
|
|
|
|
|
|
def cartesian_product(*arrays: np.ndarray):
    """
    Build every combination that takes one entry from each of the
    given 1D arrays.
    The result has one row per combination and one column per input
    array, with the last input varying fastest.  Its dtype is
    np.result_type of the inputs.

    Copied from https://stackoverflow.com/a/11146645
    """
    n_arrays = len(arrays)
    grid_shape = [len(a) for a in arrays]
    out = np.empty([*grid_shape, n_arrays], dtype=np.result_type(*arrays))
    # np.ix_ reshapes each input so that it broadcasts along its own
    # axis of the grid; each one fills one slot of the last axis
    for axis, mesh in enumerate(np.ix_(*arrays)):
        out[..., axis] = mesh
    return out.reshape(-1, n_arrays)
|
|
|
|
|
|
def hash_obj(obj: object) -> int:
    """
    Compute a hash for obj, recursing into dicts, sets, tuples and
    lists so that these containers can be hashed even where the
    builtin hash would refuse them.

    Dict items and set elements are hashed individually and the
    hashes sorted, so the result does not depend on iteration order.
    Tuples and lists are hashed element by element, in order.
    A Color is hashed through its get_rgb() value, and anything else
    goes straight to the builtin hash.
    """
    if isinstance(obj, dict):
        pair_hashes = [
            (hash_obj(k), hash_obj(v)) for k, v in obj.items()
        ]
        pair_hashes.sort()
        return hash(tuple(pair_hashes))

    if isinstance(obj, set):
        return hash(tuple(sorted(map(hash_obj, obj))))

    if isinstance(obj, (tuple, list)):
        return hash(tuple(map(hash_obj, obj)))

    if isinstance(obj, Color):
        return hash(obj.get_rgb())

    return hash(obj)
|