chore: add type hints to manimlib.utils

This commit is contained in:
TonyCrane 2022-02-12 23:47:23 +08:00
parent 67f5b10626
commit 6e292daf58
No known key found for this signature in database
GPG key ID: 2313A5058A9C637C
12 changed files with 281 additions and 122 deletions

View file

@ -1,3 +1,7 @@
from __future__ import annotations
from typing import Iterable, Callable, TypeVar
from scipy import linalg from scipy import linalg
import numpy as np import numpy as np
@ -8,9 +12,9 @@ from manimlib.utils.space_ops import midpoint
from manimlib.logger import log from manimlib.logger import log
CLOSED_THRESHOLD = 0.001 CLOSED_THRESHOLD = 0.001
T = TypeVar("T")
def bezier(points: Iterable) -> Callable[[float], float | Iterable]:
def bezier(points):
n = len(points) - 1 n = len(points) - 1
def result(t): def result(t):
@ -22,7 +26,11 @@ def bezier(points):
return result return result
def partial_bezier_points(points, a, b): def partial_bezier_points(
points: Iterable[np.ndarray],
a: float,
b: float
) -> list[float]:
""" """
Given an list of points which define Given an list of points which define
a bezier curve, and two numbers 0<=a<b<=1, a bezier curve, and two numbers 0<=a<b<=1,
@ -48,7 +56,11 @@ def partial_bezier_points(points, a, b):
# Shortened version of partial_bezier_points just for quadratics, # Shortened version of partial_bezier_points just for quadratics,
# since this is called a fair amount # since this is called a fair amount
def partial_quadratic_bezier_points(points, a, b): def partial_quadratic_bezier_points(
points: Iterable[np.ndarray],
a: float,
b: float
) -> list[float]:
if a == 1: if a == 1:
return 3 * [points[-1]] return 3 * [points[-1]]
@ -65,7 +77,7 @@ def partial_quadratic_bezier_points(points, a, b):
# Linear interpolation variants # Linear interpolation variants
def interpolate(start, end, alpha): def interpolate(start: T, end: T, alpha: float) -> T:
try: try:
return (1 - alpha) * start + alpha * end return (1 - alpha) * start + alpha * end
except TypeError: except TypeError:
@ -76,12 +88,22 @@ def interpolate(start, end, alpha):
sys.exit(2) sys.exit(2)
def set_array_by_interpolation(arr, arr1, arr2, alpha, interp_func=interpolate): def set_array_by_interpolation(
arr: list[T],
arr1: list[T],
arr2: list[T],
alpha: float,
interp_func: Callable[[T, T, float], T] = interpolate
) -> list[T]:
arr[:] = interp_func(arr1, arr2, alpha) arr[:] = interp_func(arr1, arr2, alpha)
return arr return arr
def integer_interpolate(start, end, alpha): def integer_interpolate(
start: T,
end: T,
alpha: float
) -> tuple[int, float]:
""" """
alpha is a float between 0 and 1. This returns alpha is a float between 0 and 1. This returns
an integer between start and end (inclusive) representing an integer between start and end (inclusive) representing
@ -102,22 +124,30 @@ def integer_interpolate(start, end, alpha):
return (value, residue) return (value, residue)
def mid(start, end): def mid(start: T, end: T) -> T:
return (start + end) / 2.0 return (start + end) / 2.0
def inverse_interpolate(start, end, value): def inverse_interpolate(start: T, end: T, value: T) -> float:
return np.true_divide(value - start, end - start) return np.true_divide(value - start, end - start)
def match_interpolate(new_start, new_end, old_start, old_end, old_value): def match_interpolate(
new_start: T,
new_end: T,
old_start: T,
old_end: T,
old_value: T
) -> T:
return interpolate( return interpolate(
new_start, new_end, new_start, new_end,
inverse_interpolate(old_start, old_end, old_value) inverse_interpolate(old_start, old_end, old_value)
) )
def get_smooth_quadratic_bezier_handle_points(points): def get_smooth_quadratic_bezier_handle_points(
points: Iterable[np.ndarray]
) -> np.ndarray | list[np.ndarray]:
""" """
Figuring out which bezier curves most smoothly connect a sequence of points. Figuring out which bezier curves most smoothly connect a sequence of points.
@ -149,7 +179,9 @@ def get_smooth_quadratic_bezier_handle_points(points):
return handles return handles
def get_smooth_cubic_bezier_handle_points(points): def get_smooth_cubic_bezier_handle_points(
points: Iterable[np.ndarray]
) -> tuple[np.ndarray, np.ndarray]:
points = np.array(points) points = np.array(points)
num_handles = len(points) - 1 num_handles = len(points) - 1
dim = points.shape[1] dim = points.shape[1]
@ -207,7 +239,10 @@ def get_smooth_cubic_bezier_handle_points(points):
return handle_pairs[0::2], handle_pairs[1::2] return handle_pairs[0::2], handle_pairs[1::2]
def diag_to_matrix(l_and_u, diag): def diag_to_matrix(
l_and_u: tuple[int, int],
diag: np.ndarray
) -> np.ndarray:
""" """
Converts array whose rows represent diagonal Converts array whose rows represent diagonal
entries of a matrix into the matrix itself. entries of a matrix into the matrix itself.
@ -224,13 +259,18 @@ def diag_to_matrix(l_and_u, diag):
return matrix return matrix
def is_closed(points): def is_closed(points: Iterable[np.ndarray]) -> bool:
return np.allclose(points[0], points[-1]) return np.allclose(points[0], points[-1])
# Given 4 control points for a cubic bezier curve (or arrays of such) # Given 4 control points for a cubic bezier curve (or arrays of such)
# return control points for 2 quadratics (or 2n quadratics) approximating them. # return control points for 2 quadratics (or 2n quadratics) approximating them.
def get_quadratic_approximation_of_cubic(a0, h0, h1, a1): def get_quadratic_approximation_of_cubic(
a0: np.ndarray | Iterable[np.ndarray],
h0: np.ndarray | Iterable[np.ndarray],
h1: np.ndarray | Iterable[np.ndarray],
a1: np.ndarray | Iterable[np.ndarray]
) -> np.ndarray:
a0 = np.array(a0, ndmin=2) a0 = np.array(a0, ndmin=2)
h0 = np.array(h0, ndmin=2) h0 = np.array(h0, ndmin=2)
h1 = np.array(h1, ndmin=2) h1 = np.array(h1, ndmin=2)
@ -298,7 +338,9 @@ def get_quadratic_approximation_of_cubic(a0, h0, h1, a1):
return result return result
def get_smooth_quadratic_bezier_path_through(points): def get_smooth_quadratic_bezier_path_through(
points: list[np.ndarray]
) -> np.ndarray:
# TODO # TODO
h0, h1 = get_smooth_cubic_bezier_handle_points(points) h0, h1 = get_smooth_cubic_bezier_handle_points(points)
a0 = points[:-1] a0 = points[:-1]

View file

@ -1,19 +1,27 @@
from __future__ import annotations
import time import time
import numpy as np
from typing import Callable
from manimlib.constants import BLACK from manimlib.constants import BLACK
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.numbers import Integer from manimlib.mobject.numbers import Integer
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.logger import log from manimlib.logger import log
def print_family(mobject, n_tabs=0): def print_family(mobject: Mobject, n_tabs: int = 0) -> None:
"""For debugging purposes""" """For debugging purposes"""
log.debug("\t" * n_tabs + str(mobject) + " " + str(id(mobject))) log.debug("\t" * n_tabs + str(mobject) + " " + str(id(mobject)))
for submob in mobject.submobjects: for submob in mobject.submobjects:
print_family(submob, n_tabs + 1) print_family(submob, n_tabs + 1)
def index_labels(mobject, label_height=0.15): def index_labels(
mobject: Mobject | np.ndarray,
label_height: float = 0.15
) -> VGroup:
labels = VGroup() labels = VGroup()
for n, submob in enumerate(mobject): for n, submob in enumerate(mobject):
label = Integer(n) label = Integer(n)
@ -24,7 +32,7 @@ def index_labels(mobject, label_height=0.15):
return labels return labels
def get_runtime(func): def get_runtime(func: Callable) -> float:
now = time.time() now = time.time()
func() func()
return time.time() - now return time.time() - now

View file

@ -1,48 +1,50 @@
from __future__ import annotations
import os import os
from manimlib.utils.file_ops import guarantee_existence from manimlib.utils.file_ops import guarantee_existence
from manimlib.utils.customization import get_customization from manimlib.utils.customization import get_customization
def get_directories(): def get_directories() -> dict[str, str]:
return get_customization()["directories"] return get_customization()["directories"]
def get_temp_dir(): def get_temp_dir() -> str:
return get_directories()["temporary_storage"] return get_directories()["temporary_storage"]
def get_tex_dir(): def get_tex_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "Tex")) return guarantee_existence(os.path.join(get_temp_dir(), "Tex"))
def get_text_dir(): def get_text_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "Text")) return guarantee_existence(os.path.join(get_temp_dir(), "Text"))
def get_mobject_data_dir(): def get_mobject_data_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "mobject_data")) return guarantee_existence(os.path.join(get_temp_dir(), "mobject_data"))
def get_downloads_dir(): def get_downloads_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "manim_downloads")) return guarantee_existence(os.path.join(get_temp_dir(), "manim_downloads"))
def get_output_dir(): def get_output_dir() -> str:
return guarantee_existence(get_directories()["output"]) return guarantee_existence(get_directories()["output"])
def get_raster_image_dir(): def get_raster_image_dir() -> str:
return get_directories()["raster_images"] return get_directories()["raster_images"]
def get_vector_image_dir(): def get_vector_image_dir() -> str:
return get_directories()["vector_images"] return get_directories()["vector_images"]
def get_sound_dir(): def get_sound_dir() -> str:
return get_directories()["sounds"] return get_directories()["sounds"]
def get_shader_dir(): def get_shader_dir() -> str:
return get_directories()["shaders"] return get_directories()["shaders"]

View file

@ -1,7 +1,15 @@
from __future__ import annotations
import itertools as it import itertools as it
from typing import Iterable
from manimlib.mobject.mobject import Mobject
def extract_mobject_family_members(mobject_list, only_those_with_points=False): def extract_mobject_family_members(
mobject_list: Iterable[Mobject],
only_those_with_points: bool = False
) -> list[Mobject]:
result = list(it.chain(*[ result = list(it.chain(*[
mob.get_family() mob.get_family()
for mob in mobject_list for mob in mobject_list
@ -11,7 +19,10 @@ def extract_mobject_family_members(mobject_list, only_those_with_points=False):
return result return result
def restructure_list_to_exclude_certain_family_members(mobject_list, to_remove): def restructure_list_to_exclude_certain_family_members(
mobject_list: list[Mobject],
to_remove: list[Mobject]
) -> list[Mobject]:
""" """
Removes anything in to_remove from mobject_list, but in the event that one of Removes anything in to_remove from mobject_list, but in the event that one of
the items to be removed is a member of the family of an item in mobject_list, the items to be removed is a member of the family of an item in mobject_list,

View file

@ -1,9 +1,13 @@
from __future__ import annotations
import os import os
from typing import Iterable
import numpy as np import numpy as np
import validators import validators
def add_extension_if_not_present(file_name, extension): def add_extension_if_not_present(file_name: str, extension: str) -> str:
# This could conceivably be smarter about handling existing differing extensions # This could conceivably be smarter about handling existing differing extensions
if(file_name[-len(extension):] != extension): if(file_name[-len(extension):] != extension):
return file_name + extension return file_name + extension
@ -11,13 +15,17 @@ def add_extension_if_not_present(file_name, extension):
return file_name return file_name
def guarantee_existence(path): def guarantee_existence(path: str) -> str:
if not os.path.exists(path): if not os.path.exists(path):
os.makedirs(path) os.makedirs(path)
return os.path.abspath(path) return os.path.abspath(path)
def find_file(file_name, directories=None, extensions=None): def find_file(
file_name: str,
directories: Iterable[str] | None = None,
extensions: Iterable[str] | None = None
) -> str:
# Check if this is a file online first, and if so, download # Check if this is a file online first, and if so, download
# it to a temporary directory # it to a temporary directory
if validators.url(file_name): if validators.url(file_name):
@ -47,13 +55,14 @@ def find_file(file_name, directories=None, extensions=None):
raise IOError(f"{file_name} not Found") raise IOError(f"{file_name} not Found")
def get_sorted_integer_files(directory, def get_sorted_integer_files(
min_index=0, directory: str,
max_index=np.inf, min_index: float = 0,
remove_non_integer_files=False, max_index: float = np.inf,
remove_indices_greater_than=None, remove_non_integer_files: bool = False,
extension=None, remove_indices_greater_than: float | None = None,
): extension: str | None = None,
) -> list[str]:
indexed_files = [] indexed_files = []
for file in os.listdir(directory): for file in os.listdir(directory):
if '.' in file: if '.' in file:

View file

@ -1,12 +1,13 @@
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from typing import Iterable
from manimlib.utils.file_ops import find_file from manimlib.utils.file_ops import find_file
from manimlib.utils.directories import get_raster_image_dir from manimlib.utils.directories import get_raster_image_dir
from manimlib.utils.directories import get_vector_image_dir from manimlib.utils.directories import get_vector_image_dir
def get_full_raster_image_path(image_file_name): def get_full_raster_image_path(image_file_name: str) -> str:
return find_file( return find_file(
image_file_name, image_file_name,
directories=[get_raster_image_dir()], directories=[get_raster_image_dir()],
@ -14,7 +15,7 @@ def get_full_raster_image_path(image_file_name):
) )
def get_full_vector_image_path(image_file_name): def get_full_vector_image_path(image_file_name: str) -> str:
return find_file( return find_file(
image_file_name, image_file_name,
directories=[get_vector_image_dir()], directories=[get_vector_image_dir()],
@ -22,7 +23,7 @@ def get_full_vector_image_path(image_file_name):
) )
def drag_pixels(frames): def drag_pixels(frames: Iterable) -> list:
curr = frames[0] curr = frames[0]
new_frames = [] new_frames = []
for frame in frames: for frame in frames:
@ -31,7 +32,7 @@ def drag_pixels(frames):
return new_frames return new_frames
def invert_image(image): def invert_image(image: Iterable) -> Image:
arr = np.array(image) arr = np.array(image)
arr = (255 * np.ones(arr.shape)).astype(arr.dtype) - arr arr = (255 * np.ones(arr.shape)).astype(arr.dtype) - arr
return Image.fromarray(arr) return Image.fromarray(arr)

View file

@ -2,6 +2,7 @@ import os
import yaml import yaml
import inspect import inspect
import importlib import importlib
from typing import Any
from rich import box from rich import box
from rich.rule import Rule from rich.rule import Rule
@ -10,13 +11,13 @@ from rich.console import Console
from rich.prompt import Prompt, Confirm from rich.prompt import Prompt, Confirm
def get_manim_dir(): def get_manim_dir() -> str:
manimlib_module = importlib.import_module("manimlib") manimlib_module = importlib.import_module("manimlib")
manimlib_dir = os.path.dirname(inspect.getabsfile(manimlib_module)) manimlib_dir = os.path.dirname(inspect.getabsfile(manimlib_module))
return os.path.abspath(os.path.join(manimlib_dir, "..")) return os.path.abspath(os.path.join(manimlib_dir, ".."))
def remove_empty_value(dictionary): def remove_empty_value(dictionary: dict[str, Any]) -> dict[str, Any]:
for key in list(dictionary.keys()): for key in list(dictionary.keys()):
if dictionary[key] == "": if dictionary[key] == "":
dictionary.pop(key) dictionary.pop(key)
@ -24,7 +25,7 @@ def remove_empty_value(dictionary):
remove_empty_value(dictionary[key]) remove_empty_value(dictionary[key])
def init_customization(): def init_customization() -> None:
configuration = { configuration = {
"directories": { "directories": {
"mirror_module_path": False, "mirror_module_path": False,

View file

@ -1,8 +1,15 @@
from __future__ import annotations
import itertools as it import itertools as it
from typing import Callable, Iterable, TypeVar
import numpy as np import numpy as np
T = TypeVar("T")
S = TypeVar("S")
def remove_list_redundancies(l):
def remove_list_redundancies(l: Iterable[T]) -> list[T]:
""" """
Used instead of list(set(l)) to maintain order Used instead of list(set(l)) to maintain order
Keeps the last occurrence of each element Keeps the last occurrence of each element
@ -17,7 +24,7 @@ def remove_list_redundancies(l):
return reversed_result return reversed_result
def list_update(l1, l2): def list_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]:
""" """
Used instead of list(set(l1).update(l2)) to maintain order, Used instead of list(set(l1).update(l2)) to maintain order,
making sure duplicates are removed from l1, not l2. making sure duplicates are removed from l1, not l2.
@ -25,26 +32,29 @@ def list_update(l1, l2):
return [e for e in l1 if e not in l2] + list(l2) return [e for e in l1 if e not in l2] + list(l2)
def list_difference_update(l1, l2): def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]:
return [e for e in l1 if e not in l2] return [e for e in l1 if e not in l2]
def all_elements_are_instances(iterable, Class): def all_elements_are_instances(iterable: Iterable, Class: type) -> bool:
return all([isinstance(e, Class) for e in iterable]) return all([isinstance(e, Class) for e in iterable])
def adjacent_n_tuples(objects, n): def adjacent_n_tuples(objects: Iterable[T], n: int) -> zip[tuple[T, T]]:
return zip(*[ return zip(*[
[*objects[k:], *objects[:k]] [*objects[k:], *objects[:k]]
for k in range(n) for k in range(n)
]) ])
def adjacent_pairs(objects): def adjacent_pairs(objects: Iterable[T]) -> zip[tuple[T, T]]:
return adjacent_n_tuples(objects, 2) return adjacent_n_tuples(objects, 2)
def batch_by_property(items, property_func): def batch_by_property(
items: Iterable[T],
property_func: Callable[[T], S]
) -> list[tuple[T, S]]:
""" """
Takes in a list, and returns a list of tuples, (batch, prop) Takes in a list, and returns a list of tuples, (batch, prop)
such that all items in a batch have the same output when such that all items in a batch have the same output when
@ -71,7 +81,7 @@ def batch_by_property(items, property_func):
return batch_prop_pairs return batch_prop_pairs
def listify(obj): def listify(obj) -> list:
if isinstance(obj, str): if isinstance(obj, str):
return [obj] return [obj]
try: try:
@ -80,13 +90,13 @@ def listify(obj):
return [obj] return [obj]
def resize_array(nparray, length): def resize_array(nparray: np.ndarray, length: int) -> np.ndarray:
if len(nparray) == length: if len(nparray) == length:
return nparray return nparray
return np.resize(nparray, (length, *nparray.shape[1:])) return np.resize(nparray, (length, *nparray.shape[1:]))
def resize_preserving_order(nparray, length): def resize_preserving_order(nparray: np.ndarray, length: int) -> np.ndarray:
if len(nparray) == 0: if len(nparray) == 0:
return np.zeros((length, *nparray.shape[1:])) return np.zeros((length, *nparray.shape[1:]))
if len(nparray) == length: if len(nparray) == length:
@ -95,7 +105,7 @@ def resize_preserving_order(nparray, length):
return nparray[indices] return nparray[indices]
def resize_with_interpolation(nparray, length): def resize_with_interpolation(nparray: np.ndarray, length: int) -> np.ndarray:
if len(nparray) == length: if len(nparray) == length:
return nparray return nparray
if length == 0: if length == 0:
@ -108,7 +118,10 @@ def resize_with_interpolation(nparray, length):
]) ])
def make_even(iterable_1, iterable_2): def make_even(
iterable_1: Iterable[T],
iterable_2: Iterable[S]
) -> tuple[list[T], list[S]]:
len1 = len(iterable_1) len1 = len(iterable_1)
len2 = len(iterable_2) len2 = len(iterable_2)
if len1 == len2: if len1 == len2:
@ -120,7 +133,10 @@ def make_even(iterable_1, iterable_2):
) )
def make_even_by_cycling(iterable_1, iterable_2): def make_even_by_cycling(
iterable_1: Iterable[T],
iterable_2: Iterable[S]
) -> tuple[list[T], list[S]]:
length = max(len(iterable_1), len(iterable_2)) length = max(len(iterable_1), len(iterable_2))
cycle1 = it.cycle(iterable_1) cycle1 = it.cycle(iterable_1)
cycle2 = it.cycle(iterable_2) cycle2 = it.cycle(iterable_2)
@ -130,7 +146,7 @@ def make_even_by_cycling(iterable_1, iterable_2):
) )
def remove_nones(sequence): def remove_nones(sequence: Iterable) -> list:
return [x for x in sequence if x] return [x for x in sequence if x]

View file

@ -1,5 +1,7 @@
import numpy as np
import math import math
from typing import Callable
import numpy as np
from manimlib.constants import OUT from manimlib.constants import OUT
from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import interpolate
@ -9,7 +11,11 @@ from manimlib.utils.space_ops import rotation_matrix_transpose
STRAIGHT_PATH_THRESHOLD = 0.01 STRAIGHT_PATH_THRESHOLD = 0.01
def straight_path(start_points, end_points, alpha): def straight_path(
start_points: np.ndarray,
end_points: np.ndarray,
alpha: float
) -> np.ndarray:
""" """
Same function as interpolate, but renamed to reflect Same function as interpolate, but renamed to reflect
intent of being used to determine how a set of points move intent of being used to determine how a set of points move
@ -19,7 +25,10 @@ def straight_path(start_points, end_points, alpha):
return interpolate(start_points, end_points, alpha) return interpolate(start_points, end_points, alpha)
def path_along_arc(arc_angle, axis=OUT): def path_along_arc(
arc_angle: float,
axis: np.ndarray = OUT
) -> Callable[[np.ndarray, np.ndarray, float], np.ndarray]:
""" """
If vect is vector from start to end, [vect[:,1], -vect[:,0]] is If vect is vector from start to end, [vect[:,1], -vect[:,0]] is
perpendicular to vect in the left direction. perpendicular to vect in the left direction.
@ -41,9 +50,9 @@ def path_along_arc(arc_angle, axis=OUT):
return path return path
def clockwise_path(): def clockwise_path() -> Callable[[np.ndarray, np.ndarray, float], np.ndarray]:
return path_along_arc(-np.pi) return path_along_arc(-np.pi)
def counterclockwise_path(): def counterclockwise_path() -> Callable[[np.ndarray, np.ndarray, float], np.ndarray]:
return path_along_arc(np.pi) return path_along_arc(np.pi)

View file

@ -1,44 +1,46 @@
from typing import Callable
import numpy as np import numpy as np
from manimlib.utils.bezier import bezier from manimlib.utils.bezier import bezier
def linear(t): def linear(t: float) -> float:
return t return t
def smooth(t): def smooth(t: float) -> float:
# Zero first and second derivatives at t=0 and t=1. # Zero first and second derivatives at t=0 and t=1.
# Equivalent to bezier([0, 0, 0, 1, 1, 1]) # Equivalent to bezier([0, 0, 0, 1, 1, 1])
s = 1 - t s = 1 - t
return (t**3) * (10 * s * s + 5 * s * t + t * t) return (t**3) * (10 * s * s + 5 * s * t + t * t)
def rush_into(t): def rush_into(t: float) -> float:
return 2 * smooth(0.5 * t) return 2 * smooth(0.5 * t)
def rush_from(t): def rush_from(t: float) -> float:
return 2 * smooth(0.5 * (t + 1)) - 1 return 2 * smooth(0.5 * (t + 1)) - 1
def slow_into(t): def slow_into(t: float) -> float:
return np.sqrt(1 - (1 - t) * (1 - t)) return np.sqrt(1 - (1 - t) * (1 - t))
def double_smooth(t): def double_smooth(t: float) -> float:
if t < 0.5: if t < 0.5:
return 0.5 * smooth(2 * t) return 0.5 * smooth(2 * t)
else: else:
return 0.5 * (1 + smooth(2 * t - 1)) return 0.5 * (1 + smooth(2 * t - 1))
def there_and_back(t): def there_and_back(t: float) -> float:
new_t = 2 * t if t < 0.5 else 2 * (1 - t) new_t = 2 * t if t < 0.5 else 2 * (1 - t)
return smooth(new_t) return smooth(new_t)
def there_and_back_with_pause(t, pause_ratio=1. / 3): def there_and_back_with_pause(t: float, pause_ratio: float = 1. / 3) -> float:
a = 1. / pause_ratio a = 1. / pause_ratio
if t < 0.5 - pause_ratio / 2: if t < 0.5 - pause_ratio / 2:
return smooth(a * t) return smooth(a * t)
@ -48,21 +50,28 @@ def there_and_back_with_pause(t, pause_ratio=1. / 3):
return smooth(a - a * t) return smooth(a - a * t)
def running_start(t, pull_factor=-0.5): def running_start(t: float, pull_factor: float = -0.5) -> float:
return bezier([0, 0, pull_factor, pull_factor, 1, 1, 1])(t) return bezier([0, 0, pull_factor, pull_factor, 1, 1, 1])(t)
def not_quite_there(func=smooth, proportion=0.7): def not_quite_there(
func: Callable[[float], float] = smooth,
proportion: float = 0.7
) -> Callable[[float], float]:
def result(t): def result(t):
return proportion * func(t) return proportion * func(t)
return result return result
def wiggle(t, wiggles=2): def wiggle(t: float, wiggles: float = 2) -> float:
return there_and_back(t) * np.sin(wiggles * np.pi * t) return there_and_back(t) * np.sin(wiggles * np.pi * t)
def squish_rate_func(func, a=0.4, b=0.6): def squish_rate_func(
func: Callable[[float], float],
a: float = 0.4,
b: float = 0.6
) -> Callable[[float], float]:
def result(t): def result(t):
if a == b: if a == b:
return a return a
@ -81,11 +90,11 @@ def squish_rate_func(func, a=0.4, b=0.6):
# "lingering", different from squish_rate_func's default params # "lingering", different from squish_rate_func's default params
def lingering(t): def lingering(t: float) -> float:
return squish_rate_func(lambda t: t, 0, 0.8)(t) return squish_rate_func(lambda t: t, 0, 0.8)(t)
def exponential_decay(t, half_life=0.1): def exponential_decay(t: float, half_life: float = 0.1) -> float:
# The half-life should be rather small to minimize # The half-life should be rather small to minimize
# the cut-off error at the end # the cut-off error at the end
return 1 - np.exp(-t / half_life) return 1 - np.exp(-t / half_life)

View file

@ -2,7 +2,7 @@ from manimlib.utils.file_ops import find_file
from manimlib.utils.directories import get_sound_dir from manimlib.utils.directories import get_sound_dir
def get_full_sound_file_path(sound_file_name): def get_full_sound_file_path(sound_file_name) -> str:
return find_file( return find_file(
sound_file_name, sound_file_name,
directories=[get_sound_dir()], directories=[get_sound_dir()],

View file

@ -1,7 +1,11 @@
import numpy as np from __future__ import annotations
import math
import operator as op import operator as op
from functools import reduce from functools import reduce
import math from typing import Callable, Iterable, Sequence
import numpy as np
from mapbox_earcut import triangulate_float32 as earcut from mapbox_earcut import triangulate_float32 as earcut
from manimlib.constants import RIGHT from manimlib.constants import RIGHT
@ -13,7 +17,7 @@ from manimlib.utils.iterables import adjacent_pairs
from manimlib.utils.simple_functions import clip from manimlib.utils.simple_functions import clip
def cross(v1, v2): def cross(v1: np.ndarray, v2: np.ndarray) -> list[np.ndarray]:
return [ return [
v1[1] * v2[2] - v1[2] * v2[1], v1[1] * v2[2] - v1[2] * v2[1],
v1[2] * v2[0] - v1[0] * v2[2], v1[2] * v2[0] - v1[0] * v2[2],
@ -21,7 +25,7 @@ def cross(v1, v2):
] ]
def get_norm(vect): def get_norm(vect: np.ndarray) -> np.flaoting:
return sum((x**2 for x in vect))**0.5 return sum((x**2 for x in vect))**0.5
@ -29,7 +33,7 @@ def get_norm(vect):
# TODO, implement quaternion type # TODO, implement quaternion type
def quaternion_mult(*quats): def quaternion_mult(*quats: Sequence[float]) -> list[float]:
if len(quats) == 0: if len(quats) == 0:
return [1, 0, 0, 0] return [1, 0, 0, 0]
result = quats[0] result = quats[0]
@ -45,13 +49,19 @@ def quaternion_mult(*quats):
return result return result
def quaternion_from_angle_axis(angle, axis, axis_normalized=False): def quaternion_from_angle_axis(
angle: float,
axis: np.ndarray,
axis_normalized: bool = False
) -> list[float]:
if not axis_normalized: if not axis_normalized:
axis = normalize(axis) axis = normalize(axis)
return [math.cos(angle / 2), *(math.sin(angle / 2) * axis)] return [math.cos(angle / 2), *(math.sin(angle / 2) * axis)]
def angle_axis_from_quaternion(quaternion): def angle_axis_from_quaternion(
quaternion: Sequence[float]
) -> tuple[float, np.ndarray]:
axis = normalize( axis = normalize(
quaternion[1:], quaternion[1:],
fall_back=[1, 0, 0] fall_back=[1, 0, 0]
@ -62,14 +72,18 @@ def angle_axis_from_quaternion(quaternion):
return angle, axis return angle, axis
def quaternion_conjugate(quaternion): def quaternion_conjugate(quaternion: Iterable) -> list:
result = list(quaternion) result = list(quaternion)
for i in range(1, len(result)): for i in range(1, len(result)):
result[i] *= -1 result[i] *= -1
return result return result
def rotate_vector(vector, angle, axis=OUT): def rotate_vector(
vector: Iterable,
angle: float,
axis: np.ndarray = OUT
) -> np.ndarray | list[float]:
if len(vector) == 2: if len(vector) == 2:
# Use complex numbers...because why not # Use complex numbers...because why not
z = complex(*vector) * np.exp(complex(0, angle)) z = complex(*vector) * np.exp(complex(0, angle))
@ -88,13 +102,13 @@ def rotate_vector(vector, angle, axis=OUT):
return result return result
def thick_diagonal(dim, thickness=2): def thick_diagonal(dim: int, thickness: int = 2) -> np.ndarray:
row_indices = np.arange(dim).repeat(dim).reshape((dim, dim)) row_indices = np.arange(dim).repeat(dim).reshape((dim, dim))
col_indices = np.transpose(row_indices) col_indices = np.transpose(row_indices)
return (np.abs(row_indices - col_indices) < thickness).astype('uint8') return (np.abs(row_indices - col_indices) < thickness).astype('uint8')
def rotation_matrix_transpose_from_quaternion(quat): def rotation_matrix_transpose_from_quaternion(quat: Iterable) -> list[list[float]]:
quat_inv = quaternion_conjugate(quat) quat_inv = quaternion_conjugate(quat)
return [ return [
quaternion_mult(quat, [0, *basis], quat_inv)[1:] quaternion_mult(quat, [0, *basis], quat_inv)[1:]
@ -106,11 +120,11 @@ def rotation_matrix_transpose_from_quaternion(quat):
] ]
def rotation_matrix_from_quaternion(quat): def rotation_matrix_from_quaternion(quat: Iterable) -> np.ndarray:
return np.transpose(rotation_matrix_transpose_from_quaternion(quat)) return np.transpose(rotation_matrix_transpose_from_quaternion(quat))
def rotation_matrix_transpose(angle, axis): def rotation_matrix_transpose(angle: float, axis: np.ndarray) -> list[list[flaot]]:
if axis[0] == 0 and axis[1] == 0: if axis[0] == 0 and axis[1] == 0:
# axis = [0, 0, z] case is common enough it's worth # axis = [0, 0, z] case is common enough it's worth
# having a shortcut # having a shortcut
@ -126,14 +140,14 @@ def rotation_matrix_transpose(angle, axis):
return rotation_matrix_transpose_from_quaternion(quat) return rotation_matrix_transpose_from_quaternion(quat)
def rotation_matrix(angle, axis): def rotation_matrix(angle: float, axis: np.ndarray) -> np.ndarray:
""" """
Rotation in R^3 about a specified axis of rotation. Rotation in R^3 about a specified axis of rotation.
""" """
return np.transpose(rotation_matrix_transpose(angle, axis)) return np.transpose(rotation_matrix_transpose(angle, axis))
def rotation_about_z(angle): def rotation_about_z(angle: float) -> list[list[float]]:
return [ return [
[math.cos(angle), -math.sin(angle), 0], [math.cos(angle), -math.sin(angle), 0],
[math.sin(angle), math.cos(angle), 0], [math.sin(angle), math.cos(angle), 0],
@ -141,7 +155,7 @@ def rotation_about_z(angle):
] ]
def z_to_vector(vector): def z_to_vector(vector: np.ndarray) -> np.ndarray:
""" """
Returns some matrix in SO(3) which takes the z-axis to the Returns some matrix in SO(3) which takes the z-axis to the
(normalized) vector provided as an argument (normalized) vector provided as an argument
@ -156,7 +170,7 @@ def z_to_vector(vector):
return rotation_matrix(angle, axis=axis) return rotation_matrix(angle, axis=axis)
def rotation_between_vectors(v1, v2): def rotation_between_vectors(v1, v2) -> np.ndarray:
if np.all(np.isclose(v1, v2)): if np.all(np.isclose(v1, v2)):
return np.identity(3) return np.identity(3)
return rotation_matrix( return rotation_matrix(
@ -165,14 +179,14 @@ def rotation_between_vectors(v1, v2):
) )
def angle_of_vector(vector): def angle_of_vector(vector: Sequence[float]) -> float:
""" """
Returns polar coordinate theta when vector is project on xy plane Returns polar coordinate theta when vector is project on xy plane
""" """
return np.angle(complex(*vector[:2])) return np.angle(complex(*vector[:2]))
def angle_between_vectors(v1, v2): def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float:
""" """
Returns the angle between two 3D vectors. Returns the angle between two 3D vectors.
This angle will always be btw 0 and pi This angle will always be btw 0 and pi
@ -180,12 +194,15 @@ def angle_between_vectors(v1, v2):
return math.acos(clip(np.dot(normalize(v1), normalize(v2)), -1, 1)) return math.acos(clip(np.dot(normalize(v1), normalize(v2)), -1, 1))
def project_along_vector(point, vector): def project_along_vector(point: np.ndarray, vector: np.ndarray) -> np.ndarray:
matrix = np.identity(3) - np.outer(vector, vector) matrix = np.identity(3) - np.outer(vector, vector)
return np.dot(point, matrix.T) return np.dot(point, matrix.T)
def normalize(vect, fall_back=None): def normalize(
vect: np.ndarray,
fall_back: np.ndarray | None = None
) -> np.ndarray:
norm = get_norm(vect) norm = get_norm(vect)
if norm > 0: if norm > 0:
return np.array(vect) / norm return np.array(vect) / norm
@ -195,7 +212,10 @@ def normalize(vect, fall_back=None):
return np.zeros(len(vect)) return np.zeros(len(vect))
def normalize_along_axis(array, axis, fall_back=None): def normalize_along_axis(
array: np.ndarray,
axis: np.ndarray,
) -> np.ndarray:
norms = np.sqrt((array * array).sum(axis)) norms = np.sqrt((array * array).sum(axis))
norms[norms == 0] = 1 norms[norms == 0] = 1
buffed_norms = np.repeat(norms, array.shape[axis]).reshape(array.shape) buffed_norms = np.repeat(norms, array.shape[axis]).reshape(array.shape)
@ -203,7 +223,11 @@ def normalize_along_axis(array, axis, fall_back=None):
return array return array
def get_unit_normal(v1, v2, tol=1e-6): def get_unit_normal(
v1: np.ndarray,
v2: np.ndarray,
tol: float=1e-6
) -> np.ndarray:
v1 = normalize(v1) v1 = normalize(v1)
v2 = normalize(v2) v2 = normalize(v2)
cp = cross(v1, v2) cp = cross(v1, v2)
@ -221,7 +245,7 @@ def get_unit_normal(v1, v2, tol=1e-6):
### ###
def compass_directions(n=4, start_vect=RIGHT): def compass_directions(n: int = 4, start_vect: np.ndarray = RIGHT) -> np.ndarray:
angle = TAU / n angle = TAU / n
return np.array([ return np.array([
rotate_vector(start_vect, k * angle) rotate_vector(start_vect, k * angle)
@ -229,28 +253,36 @@ def compass_directions(n=4, start_vect=RIGHT):
]) ])
def complex_to_R3(complex_num): def complex_to_R3(complex_num: complex) -> np.ndarray:
return np.array((complex_num.real, complex_num.imag, 0)) return np.array((complex_num.real, complex_num.imag, 0))
def R3_to_complex(point): def R3_to_complex(point: Sequence[float]) -> complex:
return complex(*point[:2]) return complex(*point[:2])
def complex_func_to_R3_func(complex_func): def complex_func_to_R3_func(
complex_func: Callable[[complex], complex]
) -> Callable[[np.ndarray], np.ndarray]:
return lambda p: complex_to_R3(complex_func(R3_to_complex(p))) return lambda p: complex_to_R3(complex_func(R3_to_complex(p)))
def center_of_mass(points): def center_of_mass(points: Iterable[Sequence[float]]) -> np.ndarray:
points = [np.array(point).astype("float") for point in points] points = [np.array(point).astype("float") for point in points]
return sum(points) / len(points) return sum(points) / len(points)
def midpoint(point1, point2): def midpoint(
point1: Sequence[float],
point2: Sequence[float]
) -> np.ndarray:
return center_of_mass([point1, point2]) return center_of_mass([point1, point2])
def line_intersection(line1, line2): def line_intersection(
line1: Sequence[Sequence[float]],
line2: Sequence[Sequence[float]]
) -> np.ndarray:
""" """
return intersection point of two lines, return intersection point of two lines,
each defined with a pair of vectors determining each defined with a pair of vectors determining
@ -271,7 +303,13 @@ def line_intersection(line1, line2):
return np.array([x, y, 0]) return np.array([x, y, 0])
def find_intersection(p0, v0, p1, v1, threshold=1e-5): def find_intersection(
p0: Iterable[float],
v0: Iterable[float],
p1: Iterable[float],
v1: Iterable[float],
threshold: float = 1e-5
) -> np.ndarray:
""" """
Return the intersection of a line passing through p0 in direction v0 Return the intersection of a line passing through p0 in direction v0
with one passing through p1 in direction v1. (Or array of intersections with one passing through p1 in direction v1. (Or array of intersections
@ -300,7 +338,11 @@ def find_intersection(p0, v0, p1, v1, threshold=1e-5):
return p0 + ratio * v0 return p0 + ratio * v0
def get_closest_point_on_line(a, b, p): def get_closest_point_on_line(
a: np.ndarray,
b: np.ndarray,
p: np.ndarray
) -> np.ndarray:
""" """
It returns point x such that It returns point x such that
x is on line ab and xp is perpendicular to ab. x is on line ab and xp is perpendicular to ab.
@ -315,7 +357,7 @@ def get_closest_point_on_line(a, b, p):
return ((t * a) + ((1 - t) * b)) return ((t * a) + ((1 - t) * b))
def get_winding_number(points): def get_winding_number(points: Iterable[float]) -> float:
total_angle = 0 total_angle = 0
for p1, p2 in adjacent_pairs(points): for p1, p2 in adjacent_pairs(points):
d_angle = angle_of_vector(p2) - angle_of_vector(p1) d_angle = angle_of_vector(p2) - angle_of_vector(p1)
@ -326,14 +368,18 @@ def get_winding_number(points):
## ##
def cross2d(a, b): def cross2d(a: np.ndarray, b: np.ndarray) -> np.ndarray:
if len(a.shape) == 2: if len(a.shape) == 2:
return a[:, 0] * b[:, 1] - a[:, 1] * b[:, 0] return a[:, 0] * b[:, 1] - a[:, 1] * b[:, 0]
else: else:
return a[0] * b[1] - b[0] * a[1] return a[0] * b[1] - b[0] * a[1]
def tri_area(a, b, c): def tri_area(
a: Sequence[float],
b: Sequence[float],
c: Sequence[float]
) -> float:
return 0.5 * abs( return 0.5 * abs(
a[0] * (b[1] - c[1]) + a[0] * (b[1] - c[1]) +
b[0] * (c[1] - a[1]) + b[0] * (c[1] - a[1]) +
@ -341,7 +387,12 @@ def tri_area(a, b, c):
) )
def is_inside_triangle(p, a, b, c): def is_inside_triangle(
p: np.ndarray,
a: np.ndarray,
b: np.ndarray,
c: np.ndarray
) -> bool:
""" """
Test if point p is inside triangle abc Test if point p is inside triangle abc
""" """
@ -353,12 +404,12 @@ def is_inside_triangle(p, a, b, c):
return np.all(crosses > 0) or np.all(crosses < 0) return np.all(crosses > 0) or np.all(crosses < 0)
def norm_squared(v): def norm_squared(v: Sequence[float]) -> float:
return v[0] * v[0] + v[1] * v[1] + v[2] * v[2] return v[0] * v[0] + v[1] * v[1] + v[2] * v[2]
# TODO, fails for polygons drawn over themselves # TODO, fails for polygons drawn over themselves
def earclip_triangulation(verts, ring_ends): def earclip_triangulation(verts: np.ndarray, ring_ends: list[int]) -> list:
""" """
Returns a list of indices giving a triangulation Returns a list of indices giving a triangulation
of a polygon, potentially with holes of a polygon, potentially with holes