mirror of
https://github.com/3b1b/manim.git
synced 2025-11-13 15:47:49 +00:00
chore: add type hints to manimlib.utils
This commit is contained in:
parent
67f5b10626
commit
6e292daf58
12 changed files with 281 additions and 122 deletions
|
|
@ -1,3 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, Callable, TypeVar
|
||||
|
||||
from scipy import linalg
|
||||
import numpy as np
|
||||
|
||||
|
|
@ -8,9 +12,9 @@ from manimlib.utils.space_ops import midpoint
|
|||
from manimlib.logger import log
|
||||
|
||||
CLOSED_THRESHOLD = 0.001
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def bezier(points):
|
||||
def bezier(points: Iterable) -> Callable[[float], float | Iterable]:
|
||||
n = len(points) - 1
|
||||
|
||||
def result(t):
|
||||
|
|
@ -22,7 +26,11 @@ def bezier(points):
|
|||
return result
|
||||
|
||||
|
||||
def partial_bezier_points(points, a, b):
|
||||
def partial_bezier_points(
|
||||
points: Iterable[np.ndarray],
|
||||
a: float,
|
||||
b: float
|
||||
) -> list[float]:
|
||||
"""
|
||||
Given an list of points which define
|
||||
a bezier curve, and two numbers 0<=a<b<=1,
|
||||
|
|
@ -48,7 +56,11 @@ def partial_bezier_points(points, a, b):
|
|||
|
||||
# Shortened version of partial_bezier_points just for quadratics,
|
||||
# since this is called a fair amount
|
||||
def partial_quadratic_bezier_points(points, a, b):
|
||||
def partial_quadratic_bezier_points(
|
||||
points: Iterable[np.ndarray],
|
||||
a: float,
|
||||
b: float
|
||||
) -> list[float]:
|
||||
if a == 1:
|
||||
return 3 * [points[-1]]
|
||||
|
||||
|
|
@ -65,7 +77,7 @@ def partial_quadratic_bezier_points(points, a, b):
|
|||
|
||||
# Linear interpolation variants
|
||||
|
||||
def interpolate(start, end, alpha):
|
||||
def interpolate(start: T, end: T, alpha: float) -> T:
|
||||
try:
|
||||
return (1 - alpha) * start + alpha * end
|
||||
except TypeError:
|
||||
|
|
@ -76,12 +88,22 @@ def interpolate(start, end, alpha):
|
|||
sys.exit(2)
|
||||
|
||||
|
||||
def set_array_by_interpolation(arr, arr1, arr2, alpha, interp_func=interpolate):
|
||||
def set_array_by_interpolation(
|
||||
arr: list[T],
|
||||
arr1: list[T],
|
||||
arr2: list[T],
|
||||
alpha: float,
|
||||
interp_func: Callable[[T, T, float], T] = interpolate
|
||||
) -> list[T]:
|
||||
arr[:] = interp_func(arr1, arr2, alpha)
|
||||
return arr
|
||||
|
||||
|
||||
def integer_interpolate(start, end, alpha):
|
||||
def integer_interpolate(
|
||||
start: T,
|
||||
end: T,
|
||||
alpha: float
|
||||
) -> tuple[int, float]:
|
||||
"""
|
||||
alpha is a float between 0 and 1. This returns
|
||||
an integer between start and end (inclusive) representing
|
||||
|
|
@ -102,22 +124,30 @@ def integer_interpolate(start, end, alpha):
|
|||
return (value, residue)
|
||||
|
||||
|
||||
def mid(start, end):
|
||||
def mid(start: T, end: T) -> T:
|
||||
return (start + end) / 2.0
|
||||
|
||||
|
||||
def inverse_interpolate(start, end, value):
|
||||
def inverse_interpolate(start: T, end: T, value: T) -> float:
|
||||
return np.true_divide(value - start, end - start)
|
||||
|
||||
|
||||
def match_interpolate(new_start, new_end, old_start, old_end, old_value):
|
||||
def match_interpolate(
|
||||
new_start: T,
|
||||
new_end: T,
|
||||
old_start: T,
|
||||
old_end: T,
|
||||
old_value: T
|
||||
) -> T:
|
||||
return interpolate(
|
||||
new_start, new_end,
|
||||
inverse_interpolate(old_start, old_end, old_value)
|
||||
)
|
||||
|
||||
|
||||
def get_smooth_quadratic_bezier_handle_points(points):
|
||||
def get_smooth_quadratic_bezier_handle_points(
|
||||
points: Iterable[np.ndarray]
|
||||
) -> np.ndarray | list[np.ndarray]:
|
||||
"""
|
||||
Figuring out which bezier curves most smoothly connect a sequence of points.
|
||||
|
||||
|
|
@ -149,7 +179,9 @@ def get_smooth_quadratic_bezier_handle_points(points):
|
|||
return handles
|
||||
|
||||
|
||||
def get_smooth_cubic_bezier_handle_points(points):
|
||||
def get_smooth_cubic_bezier_handle_points(
|
||||
points: Iterable[np.ndarray]
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
points = np.array(points)
|
||||
num_handles = len(points) - 1
|
||||
dim = points.shape[1]
|
||||
|
|
@ -207,7 +239,10 @@ def get_smooth_cubic_bezier_handle_points(points):
|
|||
return handle_pairs[0::2], handle_pairs[1::2]
|
||||
|
||||
|
||||
def diag_to_matrix(l_and_u, diag):
|
||||
def diag_to_matrix(
|
||||
l_and_u: tuple[int, int],
|
||||
diag: np.ndarray
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Converts array whose rows represent diagonal
|
||||
entries of a matrix into the matrix itself.
|
||||
|
|
@ -224,13 +259,18 @@ def diag_to_matrix(l_and_u, diag):
|
|||
return matrix
|
||||
|
||||
|
||||
def is_closed(points):
|
||||
def is_closed(points: Iterable[np.ndarray]) -> bool:
|
||||
return np.allclose(points[0], points[-1])
|
||||
|
||||
|
||||
# Given 4 control points for a cubic bezier curve (or arrays of such)
|
||||
# return control points for 2 quadratics (or 2n quadratics) approximating them.
|
||||
def get_quadratic_approximation_of_cubic(a0, h0, h1, a1):
|
||||
def get_quadratic_approximation_of_cubic(
|
||||
a0: np.ndarray | Iterable[np.ndarray],
|
||||
h0: np.ndarray | Iterable[np.ndarray],
|
||||
h1: np.ndarray | Iterable[np.ndarray],
|
||||
a1: np.ndarray | Iterable[np.ndarray]
|
||||
) -> np.ndarray:
|
||||
a0 = np.array(a0, ndmin=2)
|
||||
h0 = np.array(h0, ndmin=2)
|
||||
h1 = np.array(h1, ndmin=2)
|
||||
|
|
@ -298,7 +338,9 @@ def get_quadratic_approximation_of_cubic(a0, h0, h1, a1):
|
|||
return result
|
||||
|
||||
|
||||
def get_smooth_quadratic_bezier_path_through(points):
|
||||
def get_smooth_quadratic_bezier_path_through(
|
||||
points: list[np.ndarray]
|
||||
) -> np.ndarray:
|
||||
# TODO
|
||||
h0, h1 = get_smooth_cubic_bezier_handle_points(points)
|
||||
a0 = points[:-1]
|
||||
|
|
|
|||
|
|
@ -1,19 +1,27 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
from typing import Callable
|
||||
|
||||
from manimlib.constants import BLACK
|
||||
from manimlib.mobject.mobject import Mobject
|
||||
from manimlib.mobject.numbers import Integer
|
||||
from manimlib.mobject.types.vectorized_mobject import VGroup
|
||||
from manimlib.logger import log
|
||||
|
||||
|
||||
def print_family(mobject, n_tabs=0):
|
||||
def print_family(mobject: Mobject, n_tabs: int = 0) -> None:
|
||||
"""For debugging purposes"""
|
||||
log.debug("\t" * n_tabs + str(mobject) + " " + str(id(mobject)))
|
||||
for submob in mobject.submobjects:
|
||||
print_family(submob, n_tabs + 1)
|
||||
|
||||
|
||||
def index_labels(mobject, label_height=0.15):
|
||||
def index_labels(
|
||||
mobject: Mobject | np.ndarray,
|
||||
label_height: float = 0.15
|
||||
) -> VGroup:
|
||||
labels = VGroup()
|
||||
for n, submob in enumerate(mobject):
|
||||
label = Integer(n)
|
||||
|
|
@ -24,7 +32,7 @@ def index_labels(mobject, label_height=0.15):
|
|||
return labels
|
||||
|
||||
|
||||
def get_runtime(func):
|
||||
def get_runtime(func: Callable) -> float:
|
||||
now = time.time()
|
||||
func()
|
||||
return time.time() - now
|
||||
|
|
|
|||
|
|
@ -1,48 +1,50 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from manimlib.utils.file_ops import guarantee_existence
|
||||
from manimlib.utils.customization import get_customization
|
||||
|
||||
|
||||
def get_directories():
|
||||
def get_directories() -> dict[str, str]:
|
||||
return get_customization()["directories"]
|
||||
|
||||
|
||||
def get_temp_dir():
|
||||
def get_temp_dir() -> str:
|
||||
return get_directories()["temporary_storage"]
|
||||
|
||||
|
||||
def get_tex_dir():
|
||||
def get_tex_dir() -> str:
|
||||
return guarantee_existence(os.path.join(get_temp_dir(), "Tex"))
|
||||
|
||||
|
||||
def get_text_dir():
|
||||
def get_text_dir() -> str:
|
||||
return guarantee_existence(os.path.join(get_temp_dir(), "Text"))
|
||||
|
||||
|
||||
def get_mobject_data_dir():
|
||||
def get_mobject_data_dir() -> str:
|
||||
return guarantee_existence(os.path.join(get_temp_dir(), "mobject_data"))
|
||||
|
||||
|
||||
def get_downloads_dir():
|
||||
def get_downloads_dir() -> str:
|
||||
return guarantee_existence(os.path.join(get_temp_dir(), "manim_downloads"))
|
||||
|
||||
|
||||
def get_output_dir():
|
||||
def get_output_dir() -> str:
|
||||
return guarantee_existence(get_directories()["output"])
|
||||
|
||||
|
||||
def get_raster_image_dir():
|
||||
def get_raster_image_dir() -> str:
|
||||
return get_directories()["raster_images"]
|
||||
|
||||
|
||||
def get_vector_image_dir():
|
||||
def get_vector_image_dir() -> str:
|
||||
return get_directories()["vector_images"]
|
||||
|
||||
|
||||
def get_sound_dir():
|
||||
def get_sound_dir() -> str:
|
||||
return get_directories()["sounds"]
|
||||
|
||||
|
||||
def get_shader_dir():
|
||||
def get_shader_dir() -> str:
|
||||
return get_directories()["shaders"]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import itertools as it
|
||||
from typing import Iterable
|
||||
|
||||
from manimlib.mobject.mobject import Mobject
|
||||
|
||||
|
||||
def extract_mobject_family_members(mobject_list, only_those_with_points=False):
|
||||
def extract_mobject_family_members(
|
||||
mobject_list: Iterable[Mobject],
|
||||
only_those_with_points: bool = False
|
||||
) -> list[Mobject]:
|
||||
result = list(it.chain(*[
|
||||
mob.get_family()
|
||||
for mob in mobject_list
|
||||
|
|
@ -11,7 +19,10 @@ def extract_mobject_family_members(mobject_list, only_those_with_points=False):
|
|||
return result
|
||||
|
||||
|
||||
def restructure_list_to_exclude_certain_family_members(mobject_list, to_remove):
|
||||
def restructure_list_to_exclude_certain_family_members(
|
||||
mobject_list: list[Mobject],
|
||||
to_remove: list[Mobject]
|
||||
) -> list[Mobject]:
|
||||
"""
|
||||
Removes anything in to_remove from mobject_list, but in the event that one of
|
||||
the items to be removed is a member of the family of an item in mobject_list,
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
import validators
|
||||
|
||||
|
||||
def add_extension_if_not_present(file_name, extension):
|
||||
def add_extension_if_not_present(file_name: str, extension: str) -> str:
|
||||
# This could conceivably be smarter about handling existing differing extensions
|
||||
if(file_name[-len(extension):] != extension):
|
||||
return file_name + extension
|
||||
|
|
@ -11,13 +15,17 @@ def add_extension_if_not_present(file_name, extension):
|
|||
return file_name
|
||||
|
||||
|
||||
def guarantee_existence(path):
|
||||
def guarantee_existence(path: str) -> str:
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
return os.path.abspath(path)
|
||||
|
||||
|
||||
def find_file(file_name, directories=None, extensions=None):
|
||||
def find_file(
|
||||
file_name: str,
|
||||
directories: Iterable[str] | None = None,
|
||||
extensions: Iterable[str] | None = None
|
||||
) -> str:
|
||||
# Check if this is a file online first, and if so, download
|
||||
# it to a temporary directory
|
||||
if validators.url(file_name):
|
||||
|
|
@ -47,13 +55,14 @@ def find_file(file_name, directories=None, extensions=None):
|
|||
raise IOError(f"{file_name} not Found")
|
||||
|
||||
|
||||
def get_sorted_integer_files(directory,
|
||||
min_index=0,
|
||||
max_index=np.inf,
|
||||
remove_non_integer_files=False,
|
||||
remove_indices_greater_than=None,
|
||||
extension=None,
|
||||
):
|
||||
def get_sorted_integer_files(
|
||||
directory: str,
|
||||
min_index: float = 0,
|
||||
max_index: float = np.inf,
|
||||
remove_non_integer_files: bool = False,
|
||||
remove_indices_greater_than: float | None = None,
|
||||
extension: str | None = None,
|
||||
) -> list[str]:
|
||||
indexed_files = []
|
||||
for file in os.listdir(directory):
|
||||
if '.' in file:
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
from typing import Iterable
|
||||
|
||||
from manimlib.utils.file_ops import find_file
|
||||
from manimlib.utils.directories import get_raster_image_dir
|
||||
from manimlib.utils.directories import get_vector_image_dir
|
||||
|
||||
|
||||
def get_full_raster_image_path(image_file_name):
|
||||
def get_full_raster_image_path(image_file_name: str) -> str:
|
||||
return find_file(
|
||||
image_file_name,
|
||||
directories=[get_raster_image_dir()],
|
||||
|
|
@ -14,7 +15,7 @@ def get_full_raster_image_path(image_file_name):
|
|||
)
|
||||
|
||||
|
||||
def get_full_vector_image_path(image_file_name):
|
||||
def get_full_vector_image_path(image_file_name: str) -> str:
|
||||
return find_file(
|
||||
image_file_name,
|
||||
directories=[get_vector_image_dir()],
|
||||
|
|
@ -22,7 +23,7 @@ def get_full_vector_image_path(image_file_name):
|
|||
)
|
||||
|
||||
|
||||
def drag_pixels(frames):
|
||||
def drag_pixels(frames: Iterable) -> list:
|
||||
curr = frames[0]
|
||||
new_frames = []
|
||||
for frame in frames:
|
||||
|
|
@ -31,7 +32,7 @@ def drag_pixels(frames):
|
|||
return new_frames
|
||||
|
||||
|
||||
def invert_image(image):
|
||||
def invert_image(image: Iterable) -> Image:
|
||||
arr = np.array(image)
|
||||
arr = (255 * np.ones(arr.shape)).astype(arr.dtype) - arr
|
||||
return Image.fromarray(arr)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import os
|
||||
import yaml
|
||||
import inspect
|
||||
import importlib
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
from rich import box
|
||||
from rich.rule import Rule
|
||||
|
|
@ -10,13 +11,13 @@ from rich.console import Console
|
|||
from rich.prompt import Prompt, Confirm
|
||||
|
||||
|
||||
def get_manim_dir():
|
||||
def get_manim_dir() -> str:
|
||||
manimlib_module = importlib.import_module("manimlib")
|
||||
manimlib_dir = os.path.dirname(inspect.getabsfile(manimlib_module))
|
||||
return os.path.abspath(os.path.join(manimlib_dir, ".."))
|
||||
|
||||
|
||||
def remove_empty_value(dictionary):
|
||||
def remove_empty_value(dictionary: dict[str, Any]) -> dict[str, Any]:
|
||||
for key in list(dictionary.keys()):
|
||||
if dictionary[key] == "":
|
||||
dictionary.pop(key)
|
||||
|
|
@ -24,7 +25,7 @@ def remove_empty_value(dictionary):
|
|||
remove_empty_value(dictionary[key])
|
||||
|
||||
|
||||
def init_customization():
|
||||
def init_customization() -> None:
|
||||
configuration = {
|
||||
"directories": {
|
||||
"mirror_module_path": False,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import itertools as it
|
||||
from typing import Callable, Iterable, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
|
||||
def remove_list_redundancies(l):
|
||||
|
||||
def remove_list_redundancies(l: Iterable[T]) -> list[T]:
|
||||
"""
|
||||
Used instead of list(set(l)) to maintain order
|
||||
Keeps the last occurrence of each element
|
||||
|
|
@ -17,7 +24,7 @@ def remove_list_redundancies(l):
|
|||
return reversed_result
|
||||
|
||||
|
||||
def list_update(l1, l2):
|
||||
def list_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]:
|
||||
"""
|
||||
Used instead of list(set(l1).update(l2)) to maintain order,
|
||||
making sure duplicates are removed from l1, not l2.
|
||||
|
|
@ -25,26 +32,29 @@ def list_update(l1, l2):
|
|||
return [e for e in l1 if e not in l2] + list(l2)
|
||||
|
||||
|
||||
def list_difference_update(l1, l2):
|
||||
def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]:
|
||||
return [e for e in l1 if e not in l2]
|
||||
|
||||
|
||||
def all_elements_are_instances(iterable, Class):
|
||||
def all_elements_are_instances(iterable: Iterable, Class: type) -> bool:
|
||||
return all([isinstance(e, Class) for e in iterable])
|
||||
|
||||
|
||||
def adjacent_n_tuples(objects, n):
|
||||
def adjacent_n_tuples(objects: Iterable[T], n: int) -> zip[tuple[T, T]]:
|
||||
return zip(*[
|
||||
[*objects[k:], *objects[:k]]
|
||||
for k in range(n)
|
||||
])
|
||||
|
||||
|
||||
def adjacent_pairs(objects):
|
||||
def adjacent_pairs(objects: Iterable[T]) -> zip[tuple[T, T]]:
|
||||
return adjacent_n_tuples(objects, 2)
|
||||
|
||||
|
||||
def batch_by_property(items, property_func):
|
||||
def batch_by_property(
|
||||
items: Iterable[T],
|
||||
property_func: Callable[[T], S]
|
||||
) -> list[tuple[T, S]]:
|
||||
"""
|
||||
Takes in a list, and returns a list of tuples, (batch, prop)
|
||||
such that all items in a batch have the same output when
|
||||
|
|
@ -71,7 +81,7 @@ def batch_by_property(items, property_func):
|
|||
return batch_prop_pairs
|
||||
|
||||
|
||||
def listify(obj):
|
||||
def listify(obj) -> list:
|
||||
if isinstance(obj, str):
|
||||
return [obj]
|
||||
try:
|
||||
|
|
@ -80,13 +90,13 @@ def listify(obj):
|
|||
return [obj]
|
||||
|
||||
|
||||
def resize_array(nparray, length):
|
||||
def resize_array(nparray: np.ndarray, length: int) -> np.ndarray:
|
||||
if len(nparray) == length:
|
||||
return nparray
|
||||
return np.resize(nparray, (length, *nparray.shape[1:]))
|
||||
|
||||
|
||||
def resize_preserving_order(nparray, length):
|
||||
def resize_preserving_order(nparray: np.ndarray, length: int) -> np.ndarray:
|
||||
if len(nparray) == 0:
|
||||
return np.zeros((length, *nparray.shape[1:]))
|
||||
if len(nparray) == length:
|
||||
|
|
@ -95,7 +105,7 @@ def resize_preserving_order(nparray, length):
|
|||
return nparray[indices]
|
||||
|
||||
|
||||
def resize_with_interpolation(nparray, length):
|
||||
def resize_with_interpolation(nparray: np.ndarray, length: int) -> np.ndarray:
|
||||
if len(nparray) == length:
|
||||
return nparray
|
||||
if length == 0:
|
||||
|
|
@ -108,7 +118,10 @@ def resize_with_interpolation(nparray, length):
|
|||
])
|
||||
|
||||
|
||||
def make_even(iterable_1, iterable_2):
|
||||
def make_even(
|
||||
iterable_1: Iterable[T],
|
||||
iterable_2: Iterable[S]
|
||||
) -> tuple[list[T], list[S]]:
|
||||
len1 = len(iterable_1)
|
||||
len2 = len(iterable_2)
|
||||
if len1 == len2:
|
||||
|
|
@ -120,7 +133,10 @@ def make_even(iterable_1, iterable_2):
|
|||
)
|
||||
|
||||
|
||||
def make_even_by_cycling(iterable_1, iterable_2):
|
||||
def make_even_by_cycling(
|
||||
iterable_1: Iterable[T],
|
||||
iterable_2: Iterable[S]
|
||||
) -> tuple[list[T], list[S]]:
|
||||
length = max(len(iterable_1), len(iterable_2))
|
||||
cycle1 = it.cycle(iterable_1)
|
||||
cycle2 = it.cycle(iterable_2)
|
||||
|
|
@ -130,7 +146,7 @@ def make_even_by_cycling(iterable_1, iterable_2):
|
|||
)
|
||||
|
||||
|
||||
def remove_nones(sequence):
|
||||
def remove_nones(sequence: Iterable) -> list:
|
||||
return [x for x in sequence if x]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import numpy as np
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from manimlib.constants import OUT
|
||||
from manimlib.utils.bezier import interpolate
|
||||
|
|
@ -9,7 +11,11 @@ from manimlib.utils.space_ops import rotation_matrix_transpose
|
|||
STRAIGHT_PATH_THRESHOLD = 0.01
|
||||
|
||||
|
||||
def straight_path(start_points, end_points, alpha):
|
||||
def straight_path(
|
||||
start_points: np.ndarray,
|
||||
end_points: np.ndarray,
|
||||
alpha: float
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Same function as interpolate, but renamed to reflect
|
||||
intent of being used to determine how a set of points move
|
||||
|
|
@ -19,7 +25,10 @@ def straight_path(start_points, end_points, alpha):
|
|||
return interpolate(start_points, end_points, alpha)
|
||||
|
||||
|
||||
def path_along_arc(arc_angle, axis=OUT):
|
||||
def path_along_arc(
|
||||
arc_angle: float,
|
||||
axis: np.ndarray = OUT
|
||||
) -> Callable[[np.ndarray, np.ndarray, float], np.ndarray]:
|
||||
"""
|
||||
If vect is vector from start to end, [vect[:,1], -vect[:,0]] is
|
||||
perpendicular to vect in the left direction.
|
||||
|
|
@ -41,9 +50,9 @@ def path_along_arc(arc_angle, axis=OUT):
|
|||
return path
|
||||
|
||||
|
||||
def clockwise_path():
|
||||
def clockwise_path() -> Callable[[np.ndarray, np.ndarray, float], np.ndarray]:
|
||||
return path_along_arc(-np.pi)
|
||||
|
||||
|
||||
def counterclockwise_path():
|
||||
def counterclockwise_path() -> Callable[[np.ndarray, np.ndarray, float], np.ndarray]:
|
||||
return path_along_arc(np.pi)
|
||||
|
|
|
|||
|
|
@ -1,44 +1,46 @@
|
|||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from manimlib.utils.bezier import bezier
|
||||
|
||||
|
||||
def linear(t):
|
||||
def linear(t: float) -> float:
|
||||
return t
|
||||
|
||||
|
||||
def smooth(t):
|
||||
def smooth(t: float) -> float:
|
||||
# Zero first and second derivatives at t=0 and t=1.
|
||||
# Equivalent to bezier([0, 0, 0, 1, 1, 1])
|
||||
s = 1 - t
|
||||
return (t**3) * (10 * s * s + 5 * s * t + t * t)
|
||||
|
||||
|
||||
def rush_into(t):
|
||||
def rush_into(t: float) -> float:
|
||||
return 2 * smooth(0.5 * t)
|
||||
|
||||
|
||||
def rush_from(t):
|
||||
def rush_from(t: float) -> float:
|
||||
return 2 * smooth(0.5 * (t + 1)) - 1
|
||||
|
||||
|
||||
def slow_into(t):
|
||||
def slow_into(t: float) -> float:
|
||||
return np.sqrt(1 - (1 - t) * (1 - t))
|
||||
|
||||
|
||||
def double_smooth(t):
|
||||
def double_smooth(t: float) -> float:
|
||||
if t < 0.5:
|
||||
return 0.5 * smooth(2 * t)
|
||||
else:
|
||||
return 0.5 * (1 + smooth(2 * t - 1))
|
||||
|
||||
|
||||
def there_and_back(t):
|
||||
def there_and_back(t: float) -> float:
|
||||
new_t = 2 * t if t < 0.5 else 2 * (1 - t)
|
||||
return smooth(new_t)
|
||||
|
||||
|
||||
def there_and_back_with_pause(t, pause_ratio=1. / 3):
|
||||
def there_and_back_with_pause(t: float, pause_ratio: float = 1. / 3) -> float:
|
||||
a = 1. / pause_ratio
|
||||
if t < 0.5 - pause_ratio / 2:
|
||||
return smooth(a * t)
|
||||
|
|
@ -48,21 +50,28 @@ def there_and_back_with_pause(t, pause_ratio=1. / 3):
|
|||
return smooth(a - a * t)
|
||||
|
||||
|
||||
def running_start(t, pull_factor=-0.5):
|
||||
def running_start(t: float, pull_factor: float = -0.5) -> float:
|
||||
return bezier([0, 0, pull_factor, pull_factor, 1, 1, 1])(t)
|
||||
|
||||
|
||||
def not_quite_there(func=smooth, proportion=0.7):
|
||||
def not_quite_there(
|
||||
func: Callable[[float], float] = smooth,
|
||||
proportion: float = 0.7
|
||||
) -> Callable[[float], float]:
|
||||
def result(t):
|
||||
return proportion * func(t)
|
||||
return result
|
||||
|
||||
|
||||
def wiggle(t, wiggles=2):
|
||||
def wiggle(t: float, wiggles: float = 2) -> float:
|
||||
return there_and_back(t) * np.sin(wiggles * np.pi * t)
|
||||
|
||||
|
||||
def squish_rate_func(func, a=0.4, b=0.6):
|
||||
def squish_rate_func(
|
||||
func: Callable[[float], float],
|
||||
a: float = 0.4,
|
||||
b: float = 0.6
|
||||
) -> Callable[[float], float]:
|
||||
def result(t):
|
||||
if a == b:
|
||||
return a
|
||||
|
|
@ -81,11 +90,11 @@ def squish_rate_func(func, a=0.4, b=0.6):
|
|||
# "lingering", different from squish_rate_func's default params
|
||||
|
||||
|
||||
def lingering(t):
|
||||
def lingering(t: float) -> float:
|
||||
return squish_rate_func(lambda t: t, 0, 0.8)(t)
|
||||
|
||||
|
||||
def exponential_decay(t, half_life=0.1):
|
||||
def exponential_decay(t: float, half_life: float = 0.1) -> float:
|
||||
# The half-life should be rather small to minimize
|
||||
# the cut-off error at the end
|
||||
return 1 - np.exp(-t / half_life)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from manimlib.utils.file_ops import find_file
|
|||
from manimlib.utils.directories import get_sound_dir
|
||||
|
||||
|
||||
def get_full_sound_file_path(sound_file_name):
|
||||
def get_full_sound_file_path(sound_file_name) -> str:
|
||||
return find_file(
|
||||
sound_file_name,
|
||||
directories=[get_sound_dir()],
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
import numpy as np
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import operator as op
|
||||
from functools import reduce
|
||||
import math
|
||||
from typing import Callable, Iterable, Sequence
|
||||
|
||||
import numpy as np
|
||||
from mapbox_earcut import triangulate_float32 as earcut
|
||||
|
||||
from manimlib.constants import RIGHT
|
||||
|
|
@ -13,7 +17,7 @@ from manimlib.utils.iterables import adjacent_pairs
|
|||
from manimlib.utils.simple_functions import clip
|
||||
|
||||
|
||||
def cross(v1, v2):
|
||||
def cross(v1: np.ndarray, v2: np.ndarray) -> list[np.ndarray]:
|
||||
return [
|
||||
v1[1] * v2[2] - v1[2] * v2[1],
|
||||
v1[2] * v2[0] - v1[0] * v2[2],
|
||||
|
|
@ -21,7 +25,7 @@ def cross(v1, v2):
|
|||
]
|
||||
|
||||
|
||||
def get_norm(vect):
|
||||
def get_norm(vect: np.ndarray) -> np.flaoting:
|
||||
return sum((x**2 for x in vect))**0.5
|
||||
|
||||
|
||||
|
|
@ -29,7 +33,7 @@ def get_norm(vect):
|
|||
# TODO, implement quaternion type
|
||||
|
||||
|
||||
def quaternion_mult(*quats):
|
||||
def quaternion_mult(*quats: Sequence[float]) -> list[float]:
|
||||
if len(quats) == 0:
|
||||
return [1, 0, 0, 0]
|
||||
result = quats[0]
|
||||
|
|
@ -45,13 +49,19 @@ def quaternion_mult(*quats):
|
|||
return result
|
||||
|
||||
|
||||
def quaternion_from_angle_axis(angle, axis, axis_normalized=False):
|
||||
def quaternion_from_angle_axis(
|
||||
angle: float,
|
||||
axis: np.ndarray,
|
||||
axis_normalized: bool = False
|
||||
) -> list[float]:
|
||||
if not axis_normalized:
|
||||
axis = normalize(axis)
|
||||
return [math.cos(angle / 2), *(math.sin(angle / 2) * axis)]
|
||||
|
||||
|
||||
def angle_axis_from_quaternion(quaternion):
|
||||
def angle_axis_from_quaternion(
|
||||
quaternion: Sequence[float]
|
||||
) -> tuple[float, np.ndarray]:
|
||||
axis = normalize(
|
||||
quaternion[1:],
|
||||
fall_back=[1, 0, 0]
|
||||
|
|
@ -62,14 +72,18 @@ def angle_axis_from_quaternion(quaternion):
|
|||
return angle, axis
|
||||
|
||||
|
||||
def quaternion_conjugate(quaternion):
|
||||
def quaternion_conjugate(quaternion: Iterable) -> list:
|
||||
result = list(quaternion)
|
||||
for i in range(1, len(result)):
|
||||
result[i] *= -1
|
||||
return result
|
||||
|
||||
|
||||
def rotate_vector(vector, angle, axis=OUT):
|
||||
def rotate_vector(
|
||||
vector: Iterable,
|
||||
angle: float,
|
||||
axis: np.ndarray = OUT
|
||||
) -> np.ndarray | list[float]:
|
||||
if len(vector) == 2:
|
||||
# Use complex numbers...because why not
|
||||
z = complex(*vector) * np.exp(complex(0, angle))
|
||||
|
|
@ -88,13 +102,13 @@ def rotate_vector(vector, angle, axis=OUT):
|
|||
return result
|
||||
|
||||
|
||||
def thick_diagonal(dim, thickness=2):
|
||||
def thick_diagonal(dim: int, thickness: int = 2) -> np.ndarray:
|
||||
row_indices = np.arange(dim).repeat(dim).reshape((dim, dim))
|
||||
col_indices = np.transpose(row_indices)
|
||||
return (np.abs(row_indices - col_indices) < thickness).astype('uint8')
|
||||
|
||||
|
||||
def rotation_matrix_transpose_from_quaternion(quat):
|
||||
def rotation_matrix_transpose_from_quaternion(quat: Iterable) -> list[list[float]]:
|
||||
quat_inv = quaternion_conjugate(quat)
|
||||
return [
|
||||
quaternion_mult(quat, [0, *basis], quat_inv)[1:]
|
||||
|
|
@ -106,11 +120,11 @@ def rotation_matrix_transpose_from_quaternion(quat):
|
|||
]
|
||||
|
||||
|
||||
def rotation_matrix_from_quaternion(quat):
|
||||
def rotation_matrix_from_quaternion(quat: Iterable) -> np.ndarray:
|
||||
return np.transpose(rotation_matrix_transpose_from_quaternion(quat))
|
||||
|
||||
|
||||
def rotation_matrix_transpose(angle, axis):
|
||||
def rotation_matrix_transpose(angle: float, axis: np.ndarray) -> list[list[flaot]]:
|
||||
if axis[0] == 0 and axis[1] == 0:
|
||||
# axis = [0, 0, z] case is common enough it's worth
|
||||
# having a shortcut
|
||||
|
|
@ -126,14 +140,14 @@ def rotation_matrix_transpose(angle, axis):
|
|||
return rotation_matrix_transpose_from_quaternion(quat)
|
||||
|
||||
|
||||
def rotation_matrix(angle, axis):
|
||||
def rotation_matrix(angle: float, axis: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Rotation in R^3 about a specified axis of rotation.
|
||||
"""
|
||||
return np.transpose(rotation_matrix_transpose(angle, axis))
|
||||
|
||||
|
||||
def rotation_about_z(angle):
|
||||
def rotation_about_z(angle: float) -> list[list[float]]:
|
||||
return [
|
||||
[math.cos(angle), -math.sin(angle), 0],
|
||||
[math.sin(angle), math.cos(angle), 0],
|
||||
|
|
@ -141,7 +155,7 @@ def rotation_about_z(angle):
|
|||
]
|
||||
|
||||
|
||||
def z_to_vector(vector):
|
||||
def z_to_vector(vector: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Returns some matrix in SO(3) which takes the z-axis to the
|
||||
(normalized) vector provided as an argument
|
||||
|
|
@ -156,7 +170,7 @@ def z_to_vector(vector):
|
|||
return rotation_matrix(angle, axis=axis)
|
||||
|
||||
|
||||
def rotation_between_vectors(v1, v2):
|
||||
def rotation_between_vectors(v1, v2) -> np.ndarray:
|
||||
if np.all(np.isclose(v1, v2)):
|
||||
return np.identity(3)
|
||||
return rotation_matrix(
|
||||
|
|
@ -165,14 +179,14 @@ def rotation_between_vectors(v1, v2):
|
|||
)
|
||||
|
||||
|
||||
def angle_of_vector(vector):
|
||||
def angle_of_vector(vector: Sequence[float]) -> float:
|
||||
"""
|
||||
Returns polar coordinate theta when vector is project on xy plane
|
||||
"""
|
||||
return np.angle(complex(*vector[:2]))
|
||||
|
||||
|
||||
def angle_between_vectors(v1, v2):
|
||||
def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float:
|
||||
"""
|
||||
Returns the angle between two 3D vectors.
|
||||
This angle will always be btw 0 and pi
|
||||
|
|
@ -180,12 +194,15 @@ def angle_between_vectors(v1, v2):
|
|||
return math.acos(clip(np.dot(normalize(v1), normalize(v2)), -1, 1))
|
||||
|
||||
|
||||
def project_along_vector(point, vector):
|
||||
def project_along_vector(point: np.ndarray, vector: np.ndarray) -> np.ndarray:
|
||||
matrix = np.identity(3) - np.outer(vector, vector)
|
||||
return np.dot(point, matrix.T)
|
||||
|
||||
|
||||
def normalize(vect, fall_back=None):
|
||||
def normalize(
|
||||
vect: np.ndarray,
|
||||
fall_back: np.ndarray | None = None
|
||||
) -> np.ndarray:
|
||||
norm = get_norm(vect)
|
||||
if norm > 0:
|
||||
return np.array(vect) / norm
|
||||
|
|
@ -195,7 +212,10 @@ def normalize(vect, fall_back=None):
|
|||
return np.zeros(len(vect))
|
||||
|
||||
|
||||
def normalize_along_axis(array, axis, fall_back=None):
|
||||
def normalize_along_axis(
|
||||
array: np.ndarray,
|
||||
axis: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
norms = np.sqrt((array * array).sum(axis))
|
||||
norms[norms == 0] = 1
|
||||
buffed_norms = np.repeat(norms, array.shape[axis]).reshape(array.shape)
|
||||
|
|
@ -203,7 +223,11 @@ def normalize_along_axis(array, axis, fall_back=None):
|
|||
return array
|
||||
|
||||
|
||||
def get_unit_normal(v1, v2, tol=1e-6):
|
||||
def get_unit_normal(
|
||||
v1: np.ndarray,
|
||||
v2: np.ndarray,
|
||||
tol: float=1e-6
|
||||
) -> np.ndarray:
|
||||
v1 = normalize(v1)
|
||||
v2 = normalize(v2)
|
||||
cp = cross(v1, v2)
|
||||
|
|
@ -221,7 +245,7 @@ def get_unit_normal(v1, v2, tol=1e-6):
|
|||
###
|
||||
|
||||
|
||||
def compass_directions(n=4, start_vect=RIGHT):
|
||||
def compass_directions(n: int = 4, start_vect: np.ndarray = RIGHT) -> np.ndarray:
|
||||
angle = TAU / n
|
||||
return np.array([
|
||||
rotate_vector(start_vect, k * angle)
|
||||
|
|
@ -229,28 +253,36 @@ def compass_directions(n=4, start_vect=RIGHT):
|
|||
])
|
||||
|
||||
|
||||
def complex_to_R3(complex_num):
|
||||
def complex_to_R3(complex_num: complex) -> np.ndarray:
|
||||
return np.array((complex_num.real, complex_num.imag, 0))
|
||||
|
||||
|
||||
def R3_to_complex(point):
|
||||
def R3_to_complex(point: Sequence[float]) -> complex:
|
||||
return complex(*point[:2])
|
||||
|
||||
|
||||
def complex_func_to_R3_func(complex_func):
|
||||
def complex_func_to_R3_func(
|
||||
complex_func: Callable[[complex], complex]
|
||||
) -> Callable[[np.ndarray], np.ndarray]:
|
||||
return lambda p: complex_to_R3(complex_func(R3_to_complex(p)))
|
||||
|
||||
|
||||
def center_of_mass(points):
|
||||
def center_of_mass(points: Iterable[Sequence[float]]) -> np.ndarray:
|
||||
points = [np.array(point).astype("float") for point in points]
|
||||
return sum(points) / len(points)
|
||||
|
||||
|
||||
def midpoint(point1, point2):
|
||||
def midpoint(
|
||||
point1: Sequence[float],
|
||||
point2: Sequence[float]
|
||||
) -> np.ndarray:
|
||||
return center_of_mass([point1, point2])
|
||||
|
||||
|
||||
def line_intersection(line1, line2):
|
||||
def line_intersection(
|
||||
line1: Sequence[Sequence[float]],
|
||||
line2: Sequence[Sequence[float]]
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
return intersection point of two lines,
|
||||
each defined with a pair of vectors determining
|
||||
|
|
@ -271,7 +303,13 @@ def line_intersection(line1, line2):
|
|||
return np.array([x, y, 0])
|
||||
|
||||
|
||||
def find_intersection(p0, v0, p1, v1, threshold=1e-5):
|
||||
def find_intersection(
|
||||
p0: Iterable[float],
|
||||
v0: Iterable[float],
|
||||
p1: Iterable[float],
|
||||
v1: Iterable[float],
|
||||
threshold: float = 1e-5
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Return the intersection of a line passing through p0 in direction v0
|
||||
with one passing through p1 in direction v1. (Or array of intersections
|
||||
|
|
@ -300,7 +338,11 @@ def find_intersection(p0, v0, p1, v1, threshold=1e-5):
|
|||
return p0 + ratio * v0
|
||||
|
||||
|
||||
def get_closest_point_on_line(a, b, p):
|
||||
def get_closest_point_on_line(
|
||||
a: np.ndarray,
|
||||
b: np.ndarray,
|
||||
p: np.ndarray
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
It returns point x such that
|
||||
x is on line ab and xp is perpendicular to ab.
|
||||
|
|
@ -315,7 +357,7 @@ def get_closest_point_on_line(a, b, p):
|
|||
return ((t * a) + ((1 - t) * b))
|
||||
|
||||
|
||||
def get_winding_number(points):
|
||||
def get_winding_number(points: Iterable[float]) -> float:
|
||||
total_angle = 0
|
||||
for p1, p2 in adjacent_pairs(points):
|
||||
d_angle = angle_of_vector(p2) - angle_of_vector(p1)
|
||||
|
|
@ -326,14 +368,18 @@ def get_winding_number(points):
|
|||
|
||||
##
|
||||
|
||||
def cross2d(a, b):
|
||||
def cross2d(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||
if len(a.shape) == 2:
|
||||
return a[:, 0] * b[:, 1] - a[:, 1] * b[:, 0]
|
||||
else:
|
||||
return a[0] * b[1] - b[0] * a[1]
|
||||
|
||||
|
||||
def tri_area(a, b, c):
|
||||
def tri_area(
|
||||
a: Sequence[float],
|
||||
b: Sequence[float],
|
||||
c: Sequence[float]
|
||||
) -> float:
|
||||
return 0.5 * abs(
|
||||
a[0] * (b[1] - c[1]) +
|
||||
b[0] * (c[1] - a[1]) +
|
||||
|
|
@ -341,7 +387,12 @@ def tri_area(a, b, c):
|
|||
)
|
||||
|
||||
|
||||
def is_inside_triangle(p, a, b, c):
|
||||
def is_inside_triangle(
|
||||
p: np.ndarray,
|
||||
a: np.ndarray,
|
||||
b: np.ndarray,
|
||||
c: np.ndarray
|
||||
) -> bool:
|
||||
"""
|
||||
Test if point p is inside triangle abc
|
||||
"""
|
||||
|
|
@ -353,12 +404,12 @@ def is_inside_triangle(p, a, b, c):
|
|||
return np.all(crosses > 0) or np.all(crosses < 0)
|
||||
|
||||
|
||||
def norm_squared(v):
|
||||
def norm_squared(v: Sequence[float]) -> float:
|
||||
return v[0] * v[0] + v[1] * v[1] + v[2] * v[2]
|
||||
|
||||
|
||||
# TODO, fails for polygons drawn over themselves
|
||||
def earclip_triangulation(verts, ring_ends):
|
||||
def earclip_triangulation(verts: np.ndarray, ring_ends: list[int]) -> list:
|
||||
"""
|
||||
Returns a list of indices giving a triangulation
|
||||
of a polygon, potentially with holes
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue