chore: add type hints to manimlib.utils

This commit is contained in:
TonyCrane 2022-02-12 23:47:23 +08:00
parent 67f5b10626
commit 6e292daf58
No known key found for this signature in database
GPG key ID: 2313A5058A9C637C
12 changed files with 281 additions and 122 deletions

View file

@ -1,3 +1,7 @@
from __future__ import annotations
from typing import Iterable, Callable, TypeVar
from scipy import linalg
import numpy as np
@ -8,9 +12,9 @@ from manimlib.utils.space_ops import midpoint
from manimlib.logger import log
CLOSED_THRESHOLD = 0.001
T = TypeVar("T")
def bezier(points):
def bezier(points: Iterable) -> Callable[[float], float | Iterable]:
n = len(points) - 1
def result(t):
@ -22,7 +26,11 @@ def bezier(points):
return result
def partial_bezier_points(points, a, b):
def partial_bezier_points(
points: Iterable[np.ndarray],
a: float,
b: float
) -> list[float]:
"""
Given an list of points which define
a bezier curve, and two numbers 0<=a<b<=1,
@ -48,7 +56,11 @@ def partial_bezier_points(points, a, b):
# Shortened version of partial_bezier_points just for quadratics,
# since this is called a fair amount
def partial_quadratic_bezier_points(points, a, b):
def partial_quadratic_bezier_points(
points: Iterable[np.ndarray],
a: float,
b: float
) -> list[float]:
if a == 1:
return 3 * [points[-1]]
@ -65,7 +77,7 @@ def partial_quadratic_bezier_points(points, a, b):
# Linear interpolation variants
def interpolate(start, end, alpha):
def interpolate(start: T, end: T, alpha: float) -> T:
try:
return (1 - alpha) * start + alpha * end
except TypeError:
@ -76,12 +88,22 @@ def interpolate(start, end, alpha):
sys.exit(2)
def set_array_by_interpolation(arr, arr1, arr2, alpha, interp_func=interpolate):
def set_array_by_interpolation(
arr: list[T],
arr1: list[T],
arr2: list[T],
alpha: float,
interp_func: Callable[[T, T, float], T] = interpolate
) -> list[T]:
arr[:] = interp_func(arr1, arr2, alpha)
return arr
def integer_interpolate(start, end, alpha):
def integer_interpolate(
start: T,
end: T,
alpha: float
) -> tuple[int, float]:
"""
alpha is a float between 0 and 1. This returns
an integer between start and end (inclusive) representing
@ -102,22 +124,30 @@ def integer_interpolate(start, end, alpha):
return (value, residue)
def mid(start, end):
def mid(start: T, end: T) -> T:
return (start + end) / 2.0
def inverse_interpolate(start, end, value):
def inverse_interpolate(start: T, end: T, value: T) -> float:
return np.true_divide(value - start, end - start)
def match_interpolate(new_start, new_end, old_start, old_end, old_value):
def match_interpolate(
new_start: T,
new_end: T,
old_start: T,
old_end: T,
old_value: T
) -> T:
return interpolate(
new_start, new_end,
inverse_interpolate(old_start, old_end, old_value)
)
def get_smooth_quadratic_bezier_handle_points(points):
def get_smooth_quadratic_bezier_handle_points(
points: Iterable[np.ndarray]
) -> np.ndarray | list[np.ndarray]:
"""
Figuring out which bezier curves most smoothly connect a sequence of points.
@ -149,7 +179,9 @@ def get_smooth_quadratic_bezier_handle_points(points):
return handles
def get_smooth_cubic_bezier_handle_points(points):
def get_smooth_cubic_bezier_handle_points(
points: Iterable[np.ndarray]
) -> tuple[np.ndarray, np.ndarray]:
points = np.array(points)
num_handles = len(points) - 1
dim = points.shape[1]
@ -207,7 +239,10 @@ def get_smooth_cubic_bezier_handle_points(points):
return handle_pairs[0::2], handle_pairs[1::2]
def diag_to_matrix(l_and_u, diag):
def diag_to_matrix(
l_and_u: tuple[int, int],
diag: np.ndarray
) -> np.ndarray:
"""
Converts array whose rows represent diagonal
entries of a matrix into the matrix itself.
@ -224,13 +259,18 @@ def diag_to_matrix(l_and_u, diag):
return matrix
def is_closed(points):
def is_closed(points: Iterable[np.ndarray]) -> bool:
return np.allclose(points[0], points[-1])
# Given 4 control points for a cubic bezier curve (or arrays of such)
# return control points for 2 quadratics (or 2n quadratics) approximating them.
def get_quadratic_approximation_of_cubic(a0, h0, h1, a1):
def get_quadratic_approximation_of_cubic(
a0: np.ndarray | Iterable[np.ndarray],
h0: np.ndarray | Iterable[np.ndarray],
h1: np.ndarray | Iterable[np.ndarray],
a1: np.ndarray | Iterable[np.ndarray]
) -> np.ndarray:
a0 = np.array(a0, ndmin=2)
h0 = np.array(h0, ndmin=2)
h1 = np.array(h1, ndmin=2)
@ -298,7 +338,9 @@ def get_quadratic_approximation_of_cubic(a0, h0, h1, a1):
return result
def get_smooth_quadratic_bezier_path_through(points):
def get_smooth_quadratic_bezier_path_through(
points: list[np.ndarray]
) -> np.ndarray:
# TODO
h0, h1 = get_smooth_cubic_bezier_handle_points(points)
a0 = points[:-1]

View file

@ -1,19 +1,27 @@
from __future__ import annotations
import time
import numpy as np
from typing import Callable
from manimlib.constants import BLACK
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.numbers import Integer
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.logger import log
def print_family(mobject, n_tabs=0):
def print_family(mobject: Mobject, n_tabs: int = 0) -> None:
"""For debugging purposes"""
log.debug("\t" * n_tabs + str(mobject) + " " + str(id(mobject)))
for submob in mobject.submobjects:
print_family(submob, n_tabs + 1)
def index_labels(mobject, label_height=0.15):
def index_labels(
mobject: Mobject | np.ndarray,
label_height: float = 0.15
) -> VGroup:
labels = VGroup()
for n, submob in enumerate(mobject):
label = Integer(n)
@ -24,7 +32,7 @@ def index_labels(mobject, label_height=0.15):
return labels
def get_runtime(func):
def get_runtime(func: Callable) -> float:
now = time.time()
func()
return time.time() - now

View file

@ -1,48 +1,50 @@
from __future__ import annotations
import os
from manimlib.utils.file_ops import guarantee_existence
from manimlib.utils.customization import get_customization
def get_directories():
def get_directories() -> dict[str, str]:
return get_customization()["directories"]
def get_temp_dir():
def get_temp_dir() -> str:
return get_directories()["temporary_storage"]
def get_tex_dir():
def get_tex_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "Tex"))
def get_text_dir():
def get_text_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "Text"))
def get_mobject_data_dir():
def get_mobject_data_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "mobject_data"))
def get_downloads_dir():
def get_downloads_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "manim_downloads"))
def get_output_dir():
def get_output_dir() -> str:
return guarantee_existence(get_directories()["output"])
def get_raster_image_dir():
def get_raster_image_dir() -> str:
return get_directories()["raster_images"]
def get_vector_image_dir():
def get_vector_image_dir() -> str:
return get_directories()["vector_images"]
def get_sound_dir():
def get_sound_dir() -> str:
return get_directories()["sounds"]
def get_shader_dir():
def get_shader_dir() -> str:
return get_directories()["shaders"]

View file

@ -1,7 +1,15 @@
from __future__ import annotations
import itertools as it
from typing import Iterable
from manimlib.mobject.mobject import Mobject
def extract_mobject_family_members(mobject_list, only_those_with_points=False):
def extract_mobject_family_members(
mobject_list: Iterable[Mobject],
only_those_with_points: bool = False
) -> list[Mobject]:
result = list(it.chain(*[
mob.get_family()
for mob in mobject_list
@ -11,7 +19,10 @@ def extract_mobject_family_members(mobject_list, only_those_with_points=False):
return result
def restructure_list_to_exclude_certain_family_members(mobject_list, to_remove):
def restructure_list_to_exclude_certain_family_members(
mobject_list: list[Mobject],
to_remove: list[Mobject]
) -> list[Mobject]:
"""
Removes anything in to_remove from mobject_list, but in the event that one of
the items to be removed is a member of the family of an item in mobject_list,

View file

@ -1,9 +1,13 @@
from __future__ import annotations
import os
from typing import Iterable
import numpy as np
import validators
def add_extension_if_not_present(file_name, extension):
def add_extension_if_not_present(file_name: str, extension: str) -> str:
# This could conceivably be smarter about handling existing differing extensions
if(file_name[-len(extension):] != extension):
return file_name + extension
@ -11,13 +15,17 @@ def add_extension_if_not_present(file_name, extension):
return file_name
def guarantee_existence(path):
def guarantee_existence(path: str) -> str:
if not os.path.exists(path):
os.makedirs(path)
return os.path.abspath(path)
def find_file(file_name, directories=None, extensions=None):
def find_file(
file_name: str,
directories: Iterable[str] | None = None,
extensions: Iterable[str] | None = None
) -> str:
# Check if this is a file online first, and if so, download
# it to a temporary directory
if validators.url(file_name):
@ -47,13 +55,14 @@ def find_file(file_name, directories=None, extensions=None):
raise IOError(f"{file_name} not Found")
def get_sorted_integer_files(directory,
min_index=0,
max_index=np.inf,
remove_non_integer_files=False,
remove_indices_greater_than=None,
extension=None,
):
def get_sorted_integer_files(
directory: str,
min_index: float = 0,
max_index: float = np.inf,
remove_non_integer_files: bool = False,
remove_indices_greater_than: float | None = None,
extension: str | None = None,
) -> list[str]:
indexed_files = []
for file in os.listdir(directory):
if '.' in file:

View file

@ -1,12 +1,13 @@
import numpy as np
from PIL import Image
from typing import Iterable
from manimlib.utils.file_ops import find_file
from manimlib.utils.directories import get_raster_image_dir
from manimlib.utils.directories import get_vector_image_dir
def get_full_raster_image_path(image_file_name):
def get_full_raster_image_path(image_file_name: str) -> str:
return find_file(
image_file_name,
directories=[get_raster_image_dir()],
@ -14,7 +15,7 @@ def get_full_raster_image_path(image_file_name):
)
def get_full_vector_image_path(image_file_name):
def get_full_vector_image_path(image_file_name: str) -> str:
return find_file(
image_file_name,
directories=[get_vector_image_dir()],
@ -22,7 +23,7 @@ def get_full_vector_image_path(image_file_name):
)
def drag_pixels(frames):
def drag_pixels(frames: Iterable) -> list:
curr = frames[0]
new_frames = []
for frame in frames:
@ -31,7 +32,7 @@ def drag_pixels(frames):
return new_frames
def invert_image(image):
def invert_image(image: Iterable) -> Image:
arr = np.array(image)
arr = (255 * np.ones(arr.shape)).astype(arr.dtype) - arr
return Image.fromarray(arr)

View file

@ -1,7 +1,8 @@
import os
import yaml
import inspect
import importlib
import importlib
from typing import Any
from rich import box
from rich.rule import Rule
@ -10,13 +11,13 @@ from rich.console import Console
from rich.prompt import Prompt, Confirm
def get_manim_dir():
def get_manim_dir() -> str:
manimlib_module = importlib.import_module("manimlib")
manimlib_dir = os.path.dirname(inspect.getabsfile(manimlib_module))
return os.path.abspath(os.path.join(manimlib_dir, ".."))
def remove_empty_value(dictionary):
def remove_empty_value(dictionary: dict[str, Any]) -> dict[str, Any]:
for key in list(dictionary.keys()):
if dictionary[key] == "":
dictionary.pop(key)
@ -24,7 +25,7 @@ def remove_empty_value(dictionary):
remove_empty_value(dictionary[key])
def init_customization():
def init_customization() -> None:
configuration = {
"directories": {
"mirror_module_path": False,

View file

@ -1,8 +1,15 @@
from __future__ import annotations
import itertools as it
from typing import Callable, Iterable, TypeVar
import numpy as np
T = TypeVar("T")
S = TypeVar("S")
def remove_list_redundancies(l):
def remove_list_redundancies(l: Iterable[T]) -> list[T]:
"""
Used instead of list(set(l)) to maintain order
Keeps the last occurrence of each element
@ -17,7 +24,7 @@ def remove_list_redundancies(l):
return reversed_result
def list_update(l1, l2):
def list_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]:
"""
Used instead of list(set(l1).update(l2)) to maintain order,
making sure duplicates are removed from l1, not l2.
@ -25,26 +32,29 @@ def list_update(l1, l2):
return [e for e in l1 if e not in l2] + list(l2)
def list_difference_update(l1, l2):
def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]:
return [e for e in l1 if e not in l2]
def all_elements_are_instances(iterable, Class):
def all_elements_are_instances(iterable: Iterable, Class: type) -> bool:
return all([isinstance(e, Class) for e in iterable])
def adjacent_n_tuples(objects, n):
def adjacent_n_tuples(objects: Iterable[T], n: int) -> zip[tuple[T, T]]:
return zip(*[
[*objects[k:], *objects[:k]]
for k in range(n)
])
def adjacent_pairs(objects):
def adjacent_pairs(objects: Iterable[T]) -> zip[tuple[T, T]]:
return adjacent_n_tuples(objects, 2)
def batch_by_property(items, property_func):
def batch_by_property(
items: Iterable[T],
property_func: Callable[[T], S]
) -> list[tuple[T, S]]:
"""
Takes in a list, and returns a list of tuples, (batch, prop)
such that all items in a batch have the same output when
@ -71,7 +81,7 @@ def batch_by_property(items, property_func):
return batch_prop_pairs
def listify(obj):
def listify(obj) -> list:
if isinstance(obj, str):
return [obj]
try:
@ -80,13 +90,13 @@ def listify(obj):
return [obj]
def resize_array(nparray, length):
def resize_array(nparray: np.ndarray, length: int) -> np.ndarray:
if len(nparray) == length:
return nparray
return np.resize(nparray, (length, *nparray.shape[1:]))
def resize_preserving_order(nparray, length):
def resize_preserving_order(nparray: np.ndarray, length: int) -> np.ndarray:
if len(nparray) == 0:
return np.zeros((length, *nparray.shape[1:]))
if len(nparray) == length:
@ -95,7 +105,7 @@ def resize_preserving_order(nparray, length):
return nparray[indices]
def resize_with_interpolation(nparray, length):
def resize_with_interpolation(nparray: np.ndarray, length: int) -> np.ndarray:
if len(nparray) == length:
return nparray
if length == 0:
@ -108,7 +118,10 @@ def resize_with_interpolation(nparray, length):
])
def make_even(iterable_1, iterable_2):
def make_even(
iterable_1: Iterable[T],
iterable_2: Iterable[S]
) -> tuple[list[T], list[S]]:
len1 = len(iterable_1)
len2 = len(iterable_2)
if len1 == len2:
@ -120,7 +133,10 @@ def make_even(iterable_1, iterable_2):
)
def make_even_by_cycling(iterable_1, iterable_2):
def make_even_by_cycling(
iterable_1: Iterable[T],
iterable_2: Iterable[S]
) -> tuple[list[T], list[S]]:
length = max(len(iterable_1), len(iterable_2))
cycle1 = it.cycle(iterable_1)
cycle2 = it.cycle(iterable_2)
@ -130,7 +146,7 @@ def make_even_by_cycling(iterable_1, iterable_2):
)
def remove_nones(sequence):
def remove_nones(sequence: Iterable) -> list:
return [x for x in sequence if x]

View file

@ -1,5 +1,7 @@
import numpy as np
import math
from typing import Callable
import numpy as np
from manimlib.constants import OUT
from manimlib.utils.bezier import interpolate
@ -9,7 +11,11 @@ from manimlib.utils.space_ops import rotation_matrix_transpose
STRAIGHT_PATH_THRESHOLD = 0.01
def straight_path(start_points, end_points, alpha):
def straight_path(
start_points: np.ndarray,
end_points: np.ndarray,
alpha: float
) -> np.ndarray:
"""
Same function as interpolate, but renamed to reflect
intent of being used to determine how a set of points move
@ -19,7 +25,10 @@ def straight_path(start_points, end_points, alpha):
return interpolate(start_points, end_points, alpha)
def path_along_arc(arc_angle, axis=OUT):
def path_along_arc(
arc_angle: float,
axis: np.ndarray = OUT
) -> Callable[[np.ndarray, np.ndarray, float], np.ndarray]:
"""
If vect is vector from start to end, [vect[:,1], -vect[:,0]] is
perpendicular to vect in the left direction.
@ -41,9 +50,9 @@ def path_along_arc(arc_angle, axis=OUT):
return path
def clockwise_path():
def clockwise_path() -> Callable[[np.ndarray, np.ndarray, float], np.ndarray]:
return path_along_arc(-np.pi)
def counterclockwise_path():
def counterclockwise_path() -> Callable[[np.ndarray, np.ndarray, float], np.ndarray]:
return path_along_arc(np.pi)

View file

@ -1,44 +1,46 @@
from typing import Callable
import numpy as np
from manimlib.utils.bezier import bezier
def linear(t):
def linear(t: float) -> float:
return t
def smooth(t):
def smooth(t: float) -> float:
# Zero first and second derivatives at t=0 and t=1.
# Equivalent to bezier([0, 0, 0, 1, 1, 1])
s = 1 - t
return (t**3) * (10 * s * s + 5 * s * t + t * t)
def rush_into(t):
def rush_into(t: float) -> float:
return 2 * smooth(0.5 * t)
def rush_from(t):
def rush_from(t: float) -> float:
return 2 * smooth(0.5 * (t + 1)) - 1
def slow_into(t):
def slow_into(t: float) -> float:
return np.sqrt(1 - (1 - t) * (1 - t))
def double_smooth(t):
def double_smooth(t: float) -> float:
if t < 0.5:
return 0.5 * smooth(2 * t)
else:
return 0.5 * (1 + smooth(2 * t - 1))
def there_and_back(t):
def there_and_back(t: float) -> float:
new_t = 2 * t if t < 0.5 else 2 * (1 - t)
return smooth(new_t)
def there_and_back_with_pause(t, pause_ratio=1. / 3):
def there_and_back_with_pause(t: float, pause_ratio: float = 1. / 3) -> float:
a = 1. / pause_ratio
if t < 0.5 - pause_ratio / 2:
return smooth(a * t)
@ -48,21 +50,28 @@ def there_and_back_with_pause(t, pause_ratio=1. / 3):
return smooth(a - a * t)
def running_start(t, pull_factor=-0.5):
def running_start(t: float, pull_factor: float = -0.5) -> float:
return bezier([0, 0, pull_factor, pull_factor, 1, 1, 1])(t)
def not_quite_there(func=smooth, proportion=0.7):
def not_quite_there(
func: Callable[[float], float] = smooth,
proportion: float = 0.7
) -> Callable[[float], float]:
def result(t):
return proportion * func(t)
return result
def wiggle(t, wiggles=2):
def wiggle(t: float, wiggles: float = 2) -> float:
return there_and_back(t) * np.sin(wiggles * np.pi * t)
def squish_rate_func(func, a=0.4, b=0.6):
def squish_rate_func(
func: Callable[[float], float],
a: float = 0.4,
b: float = 0.6
) -> Callable[[float], float]:
def result(t):
if a == b:
return a
@ -81,11 +90,11 @@ def squish_rate_func(func, a=0.4, b=0.6):
# "lingering", different from squish_rate_func's default params
def lingering(t):
def lingering(t: float) -> float:
return squish_rate_func(lambda t: t, 0, 0.8)(t)
def exponential_decay(t, half_life=0.1):
def exponential_decay(t: float, half_life: float = 0.1) -> float:
# The half-life should be rather small to minimize
# the cut-off error at the end
return 1 - np.exp(-t / half_life)

View file

@ -2,7 +2,7 @@ from manimlib.utils.file_ops import find_file
from manimlib.utils.directories import get_sound_dir
def get_full_sound_file_path(sound_file_name):
def get_full_sound_file_path(sound_file_name) -> str:
return find_file(
sound_file_name,
directories=[get_sound_dir()],

View file

@ -1,7 +1,11 @@
import numpy as np
from __future__ import annotations
import math
import operator as op
from functools import reduce
import math
from typing import Callable, Iterable, Sequence
import numpy as np
from mapbox_earcut import triangulate_float32 as earcut
from manimlib.constants import RIGHT
@ -13,7 +17,7 @@ from manimlib.utils.iterables import adjacent_pairs
from manimlib.utils.simple_functions import clip
def cross(v1, v2):
def cross(v1: np.ndarray, v2: np.ndarray) -> list[np.ndarray]:
return [
v1[1] * v2[2] - v1[2] * v2[1],
v1[2] * v2[0] - v1[0] * v2[2],
@ -21,7 +25,7 @@ def cross(v1, v2):
]
def get_norm(vect):
def get_norm(vect: np.ndarray) -> np.flaoting:
return sum((x**2 for x in vect))**0.5
@ -29,7 +33,7 @@ def get_norm(vect):
# TODO, implement quaternion type
def quaternion_mult(*quats):
def quaternion_mult(*quats: Sequence[float]) -> list[float]:
if len(quats) == 0:
return [1, 0, 0, 0]
result = quats[0]
@ -45,13 +49,19 @@ def quaternion_mult(*quats):
return result
def quaternion_from_angle_axis(angle, axis, axis_normalized=False):
def quaternion_from_angle_axis(
angle: float,
axis: np.ndarray,
axis_normalized: bool = False
) -> list[float]:
if not axis_normalized:
axis = normalize(axis)
return [math.cos(angle / 2), *(math.sin(angle / 2) * axis)]
def angle_axis_from_quaternion(quaternion):
def angle_axis_from_quaternion(
quaternion: Sequence[float]
) -> tuple[float, np.ndarray]:
axis = normalize(
quaternion[1:],
fall_back=[1, 0, 0]
@ -62,14 +72,18 @@ def angle_axis_from_quaternion(quaternion):
return angle, axis
def quaternion_conjugate(quaternion):
def quaternion_conjugate(quaternion: Iterable) -> list:
result = list(quaternion)
for i in range(1, len(result)):
result[i] *= -1
return result
def rotate_vector(vector, angle, axis=OUT):
def rotate_vector(
vector: Iterable,
angle: float,
axis: np.ndarray = OUT
) -> np.ndarray | list[float]:
if len(vector) == 2:
# Use complex numbers...because why not
z = complex(*vector) * np.exp(complex(0, angle))
@ -88,13 +102,13 @@ def rotate_vector(vector, angle, axis=OUT):
return result
def thick_diagonal(dim, thickness=2):
def thick_diagonal(dim: int, thickness: int = 2) -> np.ndarray:
row_indices = np.arange(dim).repeat(dim).reshape((dim, dim))
col_indices = np.transpose(row_indices)
return (np.abs(row_indices - col_indices) < thickness).astype('uint8')
def rotation_matrix_transpose_from_quaternion(quat):
def rotation_matrix_transpose_from_quaternion(quat: Iterable) -> list[list[float]]:
quat_inv = quaternion_conjugate(quat)
return [
quaternion_mult(quat, [0, *basis], quat_inv)[1:]
@ -106,11 +120,11 @@ def rotation_matrix_transpose_from_quaternion(quat):
]
def rotation_matrix_from_quaternion(quat):
def rotation_matrix_from_quaternion(quat: Iterable) -> np.ndarray:
return np.transpose(rotation_matrix_transpose_from_quaternion(quat))
def rotation_matrix_transpose(angle, axis):
def rotation_matrix_transpose(angle: float, axis: np.ndarray) -> list[list[flaot]]:
if axis[0] == 0 and axis[1] == 0:
# axis = [0, 0, z] case is common enough it's worth
# having a shortcut
@ -126,14 +140,14 @@ def rotation_matrix_transpose(angle, axis):
return rotation_matrix_transpose_from_quaternion(quat)
def rotation_matrix(angle, axis):
def rotation_matrix(angle: float, axis: np.ndarray) -> np.ndarray:
"""
Rotation in R^3 about a specified axis of rotation.
"""
return np.transpose(rotation_matrix_transpose(angle, axis))
def rotation_about_z(angle):
def rotation_about_z(angle: float) -> list[list[float]]:
return [
[math.cos(angle), -math.sin(angle), 0],
[math.sin(angle), math.cos(angle), 0],
@ -141,7 +155,7 @@ def rotation_about_z(angle):
]
def z_to_vector(vector):
def z_to_vector(vector: np.ndarray) -> np.ndarray:
"""
Returns some matrix in SO(3) which takes the z-axis to the
(normalized) vector provided as an argument
@ -156,7 +170,7 @@ def z_to_vector(vector):
return rotation_matrix(angle, axis=axis)
def rotation_between_vectors(v1, v2):
def rotation_between_vectors(v1, v2) -> np.ndarray:
if np.all(np.isclose(v1, v2)):
return np.identity(3)
return rotation_matrix(
@ -165,14 +179,14 @@ def rotation_between_vectors(v1, v2):
)
def angle_of_vector(vector):
def angle_of_vector(vector: Sequence[float]) -> float:
"""
Returns polar coordinate theta when vector is project on xy plane
"""
return np.angle(complex(*vector[:2]))
def angle_between_vectors(v1, v2):
def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float:
"""
Returns the angle between two 3D vectors.
This angle will always be btw 0 and pi
@ -180,12 +194,15 @@ def angle_between_vectors(v1, v2):
return math.acos(clip(np.dot(normalize(v1), normalize(v2)), -1, 1))
def project_along_vector(point, vector):
def project_along_vector(point: np.ndarray, vector: np.ndarray) -> np.ndarray:
matrix = np.identity(3) - np.outer(vector, vector)
return np.dot(point, matrix.T)
def normalize(vect, fall_back=None):
def normalize(
vect: np.ndarray,
fall_back: np.ndarray | None = None
) -> np.ndarray:
norm = get_norm(vect)
if norm > 0:
return np.array(vect) / norm
@ -195,7 +212,10 @@ def normalize(vect, fall_back=None):
return np.zeros(len(vect))
def normalize_along_axis(array, axis, fall_back=None):
def normalize_along_axis(
array: np.ndarray,
axis: np.ndarray,
) -> np.ndarray:
norms = np.sqrt((array * array).sum(axis))
norms[norms == 0] = 1
buffed_norms = np.repeat(norms, array.shape[axis]).reshape(array.shape)
@ -203,7 +223,11 @@ def normalize_along_axis(array, axis, fall_back=None):
return array
def get_unit_normal(v1, v2, tol=1e-6):
def get_unit_normal(
v1: np.ndarray,
v2: np.ndarray,
tol: float=1e-6
) -> np.ndarray:
v1 = normalize(v1)
v2 = normalize(v2)
cp = cross(v1, v2)
@ -221,7 +245,7 @@ def get_unit_normal(v1, v2, tol=1e-6):
###
def compass_directions(n=4, start_vect=RIGHT):
def compass_directions(n: int = 4, start_vect: np.ndarray = RIGHT) -> np.ndarray:
angle = TAU / n
return np.array([
rotate_vector(start_vect, k * angle)
@ -229,28 +253,36 @@ def compass_directions(n=4, start_vect=RIGHT):
])
def complex_to_R3(complex_num):
def complex_to_R3(complex_num: complex) -> np.ndarray:
return np.array((complex_num.real, complex_num.imag, 0))
def R3_to_complex(point):
def R3_to_complex(point: Sequence[float]) -> complex:
return complex(*point[:2])
def complex_func_to_R3_func(complex_func):
def complex_func_to_R3_func(
complex_func: Callable[[complex], complex]
) -> Callable[[np.ndarray], np.ndarray]:
return lambda p: complex_to_R3(complex_func(R3_to_complex(p)))
def center_of_mass(points):
def center_of_mass(points: Iterable[Sequence[float]]) -> np.ndarray:
points = [np.array(point).astype("float") for point in points]
return sum(points) / len(points)
def midpoint(point1, point2):
def midpoint(
point1: Sequence[float],
point2: Sequence[float]
) -> np.ndarray:
return center_of_mass([point1, point2])
def line_intersection(line1, line2):
def line_intersection(
line1: Sequence[Sequence[float]],
line2: Sequence[Sequence[float]]
) -> np.ndarray:
"""
return intersection point of two lines,
each defined with a pair of vectors determining
@ -271,7 +303,13 @@ def line_intersection(line1, line2):
return np.array([x, y, 0])
def find_intersection(p0, v0, p1, v1, threshold=1e-5):
def find_intersection(
p0: Iterable[float],
v0: Iterable[float],
p1: Iterable[float],
v1: Iterable[float],
threshold: float = 1e-5
) -> np.ndarray:
"""
Return the intersection of a line passing through p0 in direction v0
with one passing through p1 in direction v1. (Or array of intersections
@ -300,7 +338,11 @@ def find_intersection(p0, v0, p1, v1, threshold=1e-5):
return p0 + ratio * v0
def get_closest_point_on_line(a, b, p):
def get_closest_point_on_line(
a: np.ndarray,
b: np.ndarray,
p: np.ndarray
) -> np.ndarray:
"""
It returns point x such that
x is on line ab and xp is perpendicular to ab.
@ -315,7 +357,7 @@ def get_closest_point_on_line(a, b, p):
return ((t * a) + ((1 - t) * b))
def get_winding_number(points):
def get_winding_number(points: Iterable[float]) -> float:
total_angle = 0
for p1, p2 in adjacent_pairs(points):
d_angle = angle_of_vector(p2) - angle_of_vector(p1)
@ -326,14 +368,18 @@ def get_winding_number(points):
##
def cross2d(a, b):
def cross2d(a: np.ndarray, b: np.ndarray) -> np.ndarray:
if len(a.shape) == 2:
return a[:, 0] * b[:, 1] - a[:, 1] * b[:, 0]
else:
return a[0] * b[1] - b[0] * a[1]
def tri_area(a, b, c):
def tri_area(
a: Sequence[float],
b: Sequence[float],
c: Sequence[float]
) -> float:
return 0.5 * abs(
a[0] * (b[1] - c[1]) +
b[0] * (c[1] - a[1]) +
@ -341,7 +387,12 @@ def tri_area(a, b, c):
)
def is_inside_triangle(p, a, b, c):
def is_inside_triangle(
p: np.ndarray,
a: np.ndarray,
b: np.ndarray,
c: np.ndarray
) -> bool:
"""
Test if point p is inside triangle abc
"""
@ -353,12 +404,12 @@ def is_inside_triangle(p, a, b, c):
return np.all(crosses > 0) or np.all(crosses < 0)
def norm_squared(v):
def norm_squared(v: Sequence[float]) -> float:
return v[0] * v[0] + v[1] * v[1] + v[2] * v[2]
# TODO, fails for polygons drawn over themselves
def earclip_triangulation(verts, ring_ends):
def earclip_triangulation(verts: np.ndarray, ring_ends: list[int]) -> list:
"""
Returns a list of indices giving a triangulation
of a polygon, potentially with holes