From 6e292daf58da069bfc33618096259f5aa4e24408 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Sat, 12 Feb 2022 23:47:23 +0800 Subject: [PATCH 01/27] chore: add type hints to manimlib.utils --- manimlib/utils/bezier.py | 74 ++++++++++++++---- manimlib/utils/debug.py | 14 +++- manimlib/utils/directories.py | 24 +++--- manimlib/utils/family_ops.py | 15 +++- manimlib/utils/file_ops.py | 29 ++++--- manimlib/utils/images.py | 9 ++- manimlib/utils/init_config.py | 9 ++- manimlib/utils/iterables.py | 44 +++++++---- manimlib/utils/paths.py | 19 +++-- manimlib/utils/rate_functions.py | 37 +++++---- manimlib/utils/sounds.py | 2 +- manimlib/utils/space_ops.py | 127 ++++++++++++++++++++++--------- 12 files changed, 281 insertions(+), 122 deletions(-) diff --git a/manimlib/utils/bezier.py b/manimlib/utils/bezier.py index c6b750e1..b374b13d 100644 --- a/manimlib/utils/bezier.py +++ b/manimlib/utils/bezier.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Iterable, Callable, TypeVar + from scipy import linalg import numpy as np @@ -8,9 +12,9 @@ from manimlib.utils.space_ops import midpoint from manimlib.logger import log CLOSED_THRESHOLD = 0.001 +T = TypeVar("T") - -def bezier(points): +def bezier(points: Iterable) -> Callable[[float], float | Iterable]: n = len(points) - 1 def result(t): @@ -22,7 +26,11 @@ def bezier(points): return result -def partial_bezier_points(points, a, b): +def partial_bezier_points( + points: Iterable[np.ndarray], + a: float, + b: float +) -> list[float]: """ Given an list of points which define a bezier curve, and two numbers 0<=a list[float]: if a == 1: return 3 * [points[-1]] @@ -65,7 +77,7 @@ def partial_quadratic_bezier_points(points, a, b): # Linear interpolation variants -def interpolate(start, end, alpha): +def interpolate(start: T, end: T, alpha: float) -> T: try: return (1 - alpha) * start + alpha * end except TypeError: @@ -76,12 +88,22 @@ def interpolate(start, end, alpha): sys.exit(2) -def set_array_by_interpolation(arr, arr1, arr2, alpha, interp_func=interpolate): +def set_array_by_interpolation( + arr: list[T], + arr1: list[T], + arr2: list[T], + alpha: float, + interp_func: Callable[[T, T, float], T] = interpolate +) -> list[T]: arr[:] = interp_func(arr1, arr2, alpha) return arr -def integer_interpolate(start, end, alpha): +def integer_interpolate( + start: T, + end: T, + alpha: float +) -> tuple[int, float]: """ alpha is a float between 0 and 1. This returns an integer between start and end (inclusive) representing @@ -102,22 +124,30 @@ def integer_interpolate(start, end, alpha): return (value, residue) -def mid(start, end): +def mid(start: T, end: T) -> T: return (start + end) / 2.0 -def inverse_interpolate(start, end, value): +def inverse_interpolate(start: T, end: T, value: T) -> float: return np.true_divide(value - start, end - start) -def match_interpolate(new_start, new_end, old_start, old_end, old_value): +def match_interpolate( + new_start: T, + new_end: T, + old_start: T, + old_end: T, + old_value: T +) -> T: return interpolate( new_start, new_end, inverse_interpolate(old_start, old_end, old_value) ) -def get_smooth_quadratic_bezier_handle_points(points): +def get_smooth_quadratic_bezier_handle_points( + points: Iterable[np.ndarray] +) -> np.ndarray | list[np.ndarray]: """ Figuring out which bezier curves most smoothly connect a sequence of points. @@ -149,7 +179,9 @@ def get_smooth_quadratic_bezier_handle_points(points): return handles -def get_smooth_cubic_bezier_handle_points(points): +def get_smooth_cubic_bezier_handle_points( + points: Iterable[np.ndarray] +) -> tuple[np.ndarray, np.ndarray]: points = np.array(points) num_handles = len(points) - 1 dim = points.shape[1] @@ -207,7 +239,10 @@ def get_smooth_cubic_bezier_handle_points(points): return handle_pairs[0::2], handle_pairs[1::2] -def diag_to_matrix(l_and_u, diag): +def diag_to_matrix( + l_and_u: tuple[int, int], + diag: np.ndarray +) -> np.ndarray: """ Converts array whose rows represent diagonal entries of a matrix into the matrix itself. @@ -224,13 +259,18 @@ def diag_to_matrix(l_and_u, diag): return matrix -def is_closed(points): +def is_closed(points: Iterable[np.ndarray]) -> bool: return np.allclose(points[0], points[-1]) # Given 4 control points for a cubic bezier curve (or arrays of such) # return control points for 2 quadratics (or 2n quadratics) approximating them. -def get_quadratic_approximation_of_cubic(a0, h0, h1, a1): +def get_quadratic_approximation_of_cubic( + a0: np.ndarray | Iterable[np.ndarray], + h0: np.ndarray | Iterable[np.ndarray], + h1: np.ndarray | Iterable[np.ndarray], + a1: np.ndarray | Iterable[np.ndarray] +) -> np.ndarray: a0 = np.array(a0, ndmin=2) h0 = np.array(h0, ndmin=2) h1 = np.array(h1, ndmin=2) @@ -298,7 +338,9 @@ def get_quadratic_approximation_of_cubic(a0, h0, h1, a1): return result -def get_smooth_quadratic_bezier_path_through(points): +def get_smooth_quadratic_bezier_path_through( + points: list[np.ndarray] +) -> np.ndarray: # TODO h0, h1 = get_smooth_cubic_bezier_handle_points(points) a0 = points[:-1] diff --git a/manimlib/utils/debug.py b/manimlib/utils/debug.py index 6d495f89..be3e9527 100644 --- a/manimlib/utils/debug.py +++ b/manimlib/utils/debug.py @@ -1,19 +1,27 @@ +from __future__ import annotations + import time +import numpy as np +from typing import Callable from manimlib.constants import BLACK +from manimlib.mobject.mobject import Mobject from manimlib.mobject.numbers import Integer from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.logger import log -def print_family(mobject, n_tabs=0): +def print_family(mobject: Mobject, n_tabs: int = 0) -> None: """For debugging purposes""" log.debug("\t" * n_tabs + str(mobject) + " " + str(id(mobject))) for submob in mobject.submobjects: print_family(submob, n_tabs + 1) -def index_labels(mobject, label_height=0.15): +def index_labels( + mobject: Mobject | np.ndarray, + label_height: float = 0.15 +) -> VGroup: labels = VGroup() for n, submob in enumerate(mobject): label = Integer(n) @@ -24,7 +32,7 @@ def index_labels(mobject, label_height=0.15): return labels -def get_runtime(func): +def get_runtime(func: Callable) -> float: now = time.time() func() return time.time() - now diff --git a/manimlib/utils/directories.py b/manimlib/utils/directories.py index c82a5368..87970523 100644 --- a/manimlib/utils/directories.py +++ b/manimlib/utils/directories.py @@ -1,48 +1,50 @@ +from __future__ import annotations + import os from manimlib.utils.file_ops import guarantee_existence from manimlib.utils.customization import get_customization -def get_directories(): +def get_directories() -> dict[str, str]: return get_customization()["directories"] -def get_temp_dir(): +def get_temp_dir() -> str: return get_directories()["temporary_storage"] -def get_tex_dir(): +def get_tex_dir() -> str: return guarantee_existence(os.path.join(get_temp_dir(), "Tex")) -def get_text_dir(): +def get_text_dir() -> str: return guarantee_existence(os.path.join(get_temp_dir(), "Text")) -def get_mobject_data_dir(): +def get_mobject_data_dir() -> str: return guarantee_existence(os.path.join(get_temp_dir(), "mobject_data")) -def get_downloads_dir(): +def get_downloads_dir() -> str: return guarantee_existence(os.path.join(get_temp_dir(), "manim_downloads")) -def get_output_dir(): +def get_output_dir() -> str: return guarantee_existence(get_directories()["output"]) -def get_raster_image_dir(): +def get_raster_image_dir() -> str: return get_directories()["raster_images"] -def get_vector_image_dir(): +def get_vector_image_dir() -> str: return get_directories()["vector_images"] -def get_sound_dir(): +def get_sound_dir() -> str: return get_directories()["sounds"] -def get_shader_dir(): +def get_shader_dir() -> str: return get_directories()["shaders"] diff --git a/manimlib/utils/family_ops.py b/manimlib/utils/family_ops.py index e6035d30..1f18614b 100644 --- a/manimlib/utils/family_ops.py +++ b/manimlib/utils/family_ops.py @@ -1,7 +1,15 @@ +from __future__ import annotations + import itertools as it +from typing import Iterable + +from manimlib.mobject.mobject import Mobject -def extract_mobject_family_members(mobject_list, only_those_with_points=False): +def extract_mobject_family_members( + mobject_list: Iterable[Mobject], + only_those_with_points: bool = False +) -> list[Mobject]: result = list(it.chain(*[ mob.get_family() for mob in mobject_list @@ -11,7 +19,10 @@ def extract_mobject_family_members(mobject_list, only_those_with_points=False): return result -def restructure_list_to_exclude_certain_family_members(mobject_list, to_remove): +def restructure_list_to_exclude_certain_family_members( + mobject_list: list[Mobject], + to_remove: list[Mobject] +) -> list[Mobject]: """ Removes anything in to_remove from mobject_list, but in the event that one of the items to be removed is a member of the family of an item in mobject_list, diff --git a/manimlib/utils/file_ops.py b/manimlib/utils/file_ops.py index 19322825..a50366bc 100644 --- a/manimlib/utils/file_ops.py +++ b/manimlib/utils/file_ops.py @@ -1,9 +1,13 @@ +from __future__ import annotations + import os +from typing import Iterable + import numpy as np import validators -def add_extension_if_not_present(file_name, extension): +def add_extension_if_not_present(file_name: str, extension: str) -> str: # This could conceivably be smarter about handling existing differing extensions if(file_name[-len(extension):] != extension): return file_name + extension @@ -11,13 +15,17 @@ def add_extension_if_not_present(file_name, extension): return file_name -def guarantee_existence(path): +def guarantee_existence(path: str) -> str: if not os.path.exists(path): os.makedirs(path) return os.path.abspath(path) -def find_file(file_name, directories=None, extensions=None): +def find_file( + file_name: str, + directories: Iterable[str] | None = None, + extensions: Iterable[str] | None = None +) -> str: # Check if this is a file online first, and if so, download # it to a temporary directory if validators.url(file_name): @@ -47,13 +55,14 @@ def find_file(file_name, directories=None, extensions=None): raise IOError(f"{file_name} not Found") -def get_sorted_integer_files(directory, - min_index=0, - max_index=np.inf, - remove_non_integer_files=False, - remove_indices_greater_than=None, - extension=None, - ): +def get_sorted_integer_files( + directory: str, + min_index: float = 0, + max_index: float = np.inf, + remove_non_integer_files: bool = False, + remove_indices_greater_than: float | None = None, + extension: str | None = None, +) -> list[str]: indexed_files = [] for file in os.listdir(directory): if '.' in file: diff --git a/manimlib/utils/images.py b/manimlib/utils/images.py index e302c3b2..cab0a45f 100644 --- a/manimlib/utils/images.py +++ b/manimlib/utils/images.py @@ -1,12 +1,13 @@ import numpy as np from PIL import Image +from typing import Iterable from manimlib.utils.file_ops import find_file from manimlib.utils.directories import get_raster_image_dir from manimlib.utils.directories import get_vector_image_dir -def get_full_raster_image_path(image_file_name): +def get_full_raster_image_path(image_file_name: str) -> str: return find_file( image_file_name, directories=[get_raster_image_dir()], @@ -14,7 +15,7 @@ def get_full_raster_image_path(image_file_name): ) -def get_full_vector_image_path(image_file_name): +def get_full_vector_image_path(image_file_name: str) -> str: return find_file( image_file_name, directories=[get_vector_image_dir()], @@ -22,7 +23,7 @@ def get_full_vector_image_path(image_file_name): ) -def drag_pixels(frames): +def drag_pixels(frames: Iterable) -> list: curr = frames[0] new_frames = [] for frame in frames: @@ -31,7 +32,7 @@ def drag_pixels(frames): return new_frames -def invert_image(image): +def invert_image(image: Iterable) -> Image: arr = np.array(image) arr = (255 * np.ones(arr.shape)).astype(arr.dtype) - arr return Image.fromarray(arr) diff --git a/manimlib/utils/init_config.py b/manimlib/utils/init_config.py index 9fcdd07e..d4e1ac5d 100644 --- a/manimlib/utils/init_config.py +++ b/manimlib/utils/init_config.py @@ -1,7 +1,8 @@ import os import yaml import inspect -import importlib +import importlib +from typing import Any from rich import box from rich.rule import Rule @@ -10,13 +11,13 @@ from rich.console import Console from rich.prompt import Prompt, Confirm -def get_manim_dir(): +def get_manim_dir() -> str: manimlib_module = importlib.import_module("manimlib") manimlib_dir = os.path.dirname(inspect.getabsfile(manimlib_module)) return os.path.abspath(os.path.join(manimlib_dir, "..")) -def remove_empty_value(dictionary): +def remove_empty_value(dictionary: dict[str, Any]) -> dict[str, Any]: for key in list(dictionary.keys()): if dictionary[key] == "": dictionary.pop(key) @@ -24,7 +25,7 @@ def remove_empty_value(dictionary): remove_empty_value(dictionary[key]) -def init_customization(): +def init_customization() -> None: configuration = { "directories": { "mirror_module_path": False, diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py index d94af506..90b32bdb 100644 --- a/manimlib/utils/iterables.py +++ b/manimlib/utils/iterables.py @@ -1,8 +1,15 @@ +from __future__ import annotations + import itertools as it +from typing import Callable, Iterable, TypeVar + import numpy as np +T = TypeVar("T") +S = TypeVar("S") -def remove_list_redundancies(l): + +def remove_list_redundancies(l: Iterable[T]) -> list[T]: """ Used instead of list(set(l)) to maintain order Keeps the last occurrence of each element @@ -17,7 +24,7 @@ def remove_list_redundancies(l): return reversed_result -def list_update(l1, l2): +def list_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]: """ Used instead of list(set(l1).update(l2)) to maintain order, making sure duplicates are removed from l1, not l2. @@ -25,26 +32,29 @@ def list_update(l1, l2): return [e for e in l1 if e not in l2] + list(l2) -def list_difference_update(l1, l2): +def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]: return [e for e in l1 if e not in l2] -def all_elements_are_instances(iterable, Class): +def all_elements_are_instances(iterable: Iterable, Class: type) -> bool: return all([isinstance(e, Class) for e in iterable]) -def adjacent_n_tuples(objects, n): +def adjacent_n_tuples(objects: Iterable[T], n: int) -> zip[tuple[T, T]]: return zip(*[ [*objects[k:], *objects[:k]] for k in range(n) ]) -def adjacent_pairs(objects): +def adjacent_pairs(objects: Iterable[T]) -> zip[tuple[T, T]]: return adjacent_n_tuples(objects, 2) -def batch_by_property(items, property_func): +def batch_by_property( + items: Iterable[T], + property_func: Callable[[T], S] +) -> list[tuple[T, S]]: """ Takes in a list, and returns a list of tuples, (batch, prop) such that all items in a batch have the same output when @@ -71,7 +81,7 @@ def batch_by_property(items, property_func): return batch_prop_pairs -def listify(obj): +def listify(obj) -> list: if isinstance(obj, str): return [obj] try: @@ -80,13 +90,13 @@ def listify(obj): return [obj] -def resize_array(nparray, length): +def resize_array(nparray: np.ndarray, length: int) -> np.ndarray: if len(nparray) == length: return nparray return np.resize(nparray, (length, *nparray.shape[1:])) -def resize_preserving_order(nparray, length): +def resize_preserving_order(nparray: np.ndarray, length: int) -> np.ndarray: if len(nparray) == 0: return np.zeros((length, *nparray.shape[1:])) if len(nparray) == length: @@ -95,7 +105,7 @@ def resize_preserving_order(nparray, length): return nparray[indices] -def resize_with_interpolation(nparray, length): +def resize_with_interpolation(nparray: np.ndarray, length: int) -> np.ndarray: if len(nparray) == length: return nparray if length == 0: @@ -108,7 +118,10 @@ def resize_with_interpolation(nparray, length): ]) -def make_even(iterable_1, iterable_2): +def make_even( + iterable_1: Iterable[T], + iterable_2: Iterable[S] +) -> tuple[list[T], list[S]]: len1 = len(iterable_1) len2 = len(iterable_2) if len1 == len2: @@ -120,7 +133,10 @@ def make_even(iterable_1, iterable_2): ) -def make_even_by_cycling(iterable_1, iterable_2): +def make_even_by_cycling( + iterable_1: Iterable[T], + iterable_2: Iterable[S] +) -> tuple[list[T], list[S]]: length = max(len(iterable_1), len(iterable_2)) cycle1 = it.cycle(iterable_1) cycle2 = it.cycle(iterable_2) @@ -130,7 +146,7 @@ def make_even_by_cycling(iterable_1, iterable_2): ) -def remove_nones(sequence): +def remove_nones(sequence: Iterable) -> list: return [x for x in sequence if x] diff --git a/manimlib/utils/paths.py b/manimlib/utils/paths.py index b13af223..3bbf092d 100644 --- a/manimlib/utils/paths.py +++ b/manimlib/utils/paths.py @@ -1,5 +1,7 @@ -import numpy as np import math +from typing import Callable + +import numpy as np from manimlib.constants import OUT from manimlib.utils.bezier import interpolate @@ -9,7 +11,11 @@ from manimlib.utils.space_ops import rotation_matrix_transpose STRAIGHT_PATH_THRESHOLD = 0.01 -def straight_path(start_points, end_points, alpha): +def straight_path( + start_points: np.ndarray, + end_points: np.ndarray, + alpha: float +) -> np.ndarray: """ Same function as interpolate, but renamed to reflect intent of being used to determine how a set of points move @@ -19,7 +25,10 @@ def straight_path(start_points, end_points, alpha): return interpolate(start_points, end_points, alpha) -def path_along_arc(arc_angle, axis=OUT): +def path_along_arc( + arc_angle: float, + axis: np.ndarray = OUT +) -> Callable[[np.ndarray, np.ndarray, float], np.ndarray]: """ If vect is vector from start to end, [vect[:,1], -vect[:,0]] is perpendicular to vect in the left direction. @@ -41,9 +50,9 @@ def path_along_arc(arc_angle, axis=OUT): return path -def clockwise_path(): +def clockwise_path() -> Callable[[np.ndarray, np.ndarray, float], np.ndarray]: return path_along_arc(-np.pi) -def counterclockwise_path(): +def counterclockwise_path() -> Callable[[np.ndarray, np.ndarray, float], np.ndarray]: return path_along_arc(np.pi) diff --git a/manimlib/utils/rate_functions.py b/manimlib/utils/rate_functions.py index 6f51bda0..79057734 100644 --- a/manimlib/utils/rate_functions.py +++ b/manimlib/utils/rate_functions.py @@ -1,44 +1,46 @@ +from typing import Callable + import numpy as np from manimlib.utils.bezier import bezier -def linear(t): +def linear(t: float) -> float: return t -def smooth(t): +def smooth(t: float) -> float: # Zero first and second derivatives at t=0 and t=1. # Equivalent to bezier([0, 0, 0, 1, 1, 1]) s = 1 - t return (t**3) * (10 * s * s + 5 * s * t + t * t) -def rush_into(t): +def rush_into(t: float) -> float: return 2 * smooth(0.5 * t) -def rush_from(t): +def rush_from(t: float) -> float: return 2 * smooth(0.5 * (t + 1)) - 1 -def slow_into(t): +def slow_into(t: float) -> float: return np.sqrt(1 - (1 - t) * (1 - t)) -def double_smooth(t): +def double_smooth(t: float) -> float: if t < 0.5: return 0.5 * smooth(2 * t) else: return 0.5 * (1 + smooth(2 * t - 1)) -def there_and_back(t): +def there_and_back(t: float) -> float: new_t = 2 * t if t < 0.5 else 2 * (1 - t) return smooth(new_t) -def there_and_back_with_pause(t, pause_ratio=1. / 3): +def there_and_back_with_pause(t: float, pause_ratio: float = 1. / 3) -> float: a = 1. / pause_ratio if t < 0.5 - pause_ratio / 2: return smooth(a * t) @@ -48,21 +50,28 @@ def there_and_back_with_pause(t, pause_ratio=1. / 3): return smooth(a - a * t) -def running_start(t, pull_factor=-0.5): +def running_start(t: float, pull_factor: float = -0.5) -> float: return bezier([0, 0, pull_factor, pull_factor, 1, 1, 1])(t) -def not_quite_there(func=smooth, proportion=0.7): +def not_quite_there( + func: Callable[[float], float] = smooth, + proportion: float = 0.7 +) -> Callable[[float], float]: def result(t): return proportion * func(t) return result -def wiggle(t, wiggles=2): +def wiggle(t: float, wiggles: float = 2) -> float: return there_and_back(t) * np.sin(wiggles * np.pi * t) -def squish_rate_func(func, a=0.4, b=0.6): +def squish_rate_func( + func: Callable[[float], float], + a: float = 0.4, + b: float = 0.6 +) -> Callable[[float], float]: def result(t): if a == b: return a @@ -81,11 +90,11 @@ def squish_rate_func(func, a=0.4, b=0.6): # "lingering", different from squish_rate_func's default params -def lingering(t): +def lingering(t: float) -> float: return squish_rate_func(lambda t: t, 0, 0.8)(t) -def exponential_decay(t, half_life=0.1): +def exponential_decay(t: float, half_life: float = 0.1) -> float: # The half-life should be rather small to minimize # the cut-off error at the end return 1 - np.exp(-t / half_life) diff --git a/manimlib/utils/sounds.py b/manimlib/utils/sounds.py index b73f9c33..79501284 100644 --- a/manimlib/utils/sounds.py +++ b/manimlib/utils/sounds.py @@ -2,7 +2,7 @@ from manimlib.utils.file_ops import find_file from manimlib.utils.directories import get_sound_dir -def get_full_sound_file_path(sound_file_name): +def get_full_sound_file_path(sound_file_name) -> str: return find_file( sound_file_name, directories=[get_sound_dir()], diff --git a/manimlib/utils/space_ops.py b/manimlib/utils/space_ops.py index 9c5e84d2..622a63e1 100644 --- a/manimlib/utils/space_ops.py +++ b/manimlib/utils/space_ops.py @@ -1,7 +1,11 @@ -import numpy as np +from __future__ import annotations + +import math import operator as op from functools import reduce -import math +from typing import Callable, Iterable, Sequence + +import numpy as np from mapbox_earcut import triangulate_float32 as earcut from manimlib.constants import RIGHT @@ -13,7 +17,7 @@ from manimlib.utils.iterables import adjacent_pairs from manimlib.utils.simple_functions import clip -def cross(v1, v2): +def cross(v1: np.ndarray, v2: np.ndarray) -> list[np.ndarray]: return [ v1[1] * v2[2] - v1[2] * v2[1], v1[2] * v2[0] - v1[0] * v2[2], @@ -21,7 +25,7 @@ def cross(v1, v2): ] -def get_norm(vect): +def get_norm(vect: np.ndarray) -> np.flaoting: return sum((x**2 for x in vect))**0.5 @@ -29,7 +33,7 @@ def get_norm(vect): # TODO, implement quaternion type -def quaternion_mult(*quats): +def quaternion_mult(*quats: Sequence[float]) -> list[float]: if len(quats) == 0: return [1, 0, 0, 0] result = quats[0] @@ -45,13 +49,19 @@ def quaternion_mult(*quats): return result -def quaternion_from_angle_axis(angle, axis, axis_normalized=False): +def quaternion_from_angle_axis( + angle: float, + axis: np.ndarray, + axis_normalized: bool = False +) -> list[float]: if not axis_normalized: axis = normalize(axis) return [math.cos(angle / 2), *(math.sin(angle / 2) * axis)] -def angle_axis_from_quaternion(quaternion): +def angle_axis_from_quaternion( + quaternion: Sequence[float] +) -> tuple[float, np.ndarray]: axis = normalize( quaternion[1:], fall_back=[1, 0, 0] @@ -62,14 +72,18 @@ def angle_axis_from_quaternion(quaternion): return angle, axis -def quaternion_conjugate(quaternion): +def quaternion_conjugate(quaternion: Iterable) -> list: result = list(quaternion) for i in range(1, len(result)): result[i] *= -1 return result -def rotate_vector(vector, angle, axis=OUT): +def rotate_vector( + vector: Iterable, + angle: float, + axis: np.ndarray = OUT +) -> np.ndarray | list[float]: if len(vector) == 2: # Use complex numbers...because why not z = complex(*vector) * np.exp(complex(0, angle)) @@ -88,13 +102,13 @@ def rotate_vector(vector, angle, axis=OUT): return result -def thick_diagonal(dim, thickness=2): +def thick_diagonal(dim: int, thickness: int = 2) -> np.ndarray: row_indices = np.arange(dim).repeat(dim).reshape((dim, dim)) col_indices = np.transpose(row_indices) return (np.abs(row_indices - col_indices) < thickness).astype('uint8') -def rotation_matrix_transpose_from_quaternion(quat): +def rotation_matrix_transpose_from_quaternion(quat: Iterable) -> list[list[float]]: quat_inv = quaternion_conjugate(quat) return [ quaternion_mult(quat, [0, *basis], quat_inv)[1:] @@ -106,11 +120,11 @@ def rotation_matrix_transpose_from_quaternion(quat): ] -def rotation_matrix_from_quaternion(quat): +def rotation_matrix_from_quaternion(quat: Iterable) -> np.ndarray: return np.transpose(rotation_matrix_transpose_from_quaternion(quat)) -def rotation_matrix_transpose(angle, axis): +def rotation_matrix_transpose(angle: float, axis: np.ndarray) -> list[list[flaot]]: if axis[0] == 0 and axis[1] == 0: # axis = [0, 0, z] case is common enough it's worth # having a shortcut @@ -126,14 +140,14 @@ def rotation_matrix_transpose(angle, axis): return rotation_matrix_transpose_from_quaternion(quat) -def rotation_matrix(angle, axis): +def rotation_matrix(angle: float, axis: np.ndarray) -> np.ndarray: """ Rotation in R^3 about a specified axis of rotation. """ return np.transpose(rotation_matrix_transpose(angle, axis)) -def rotation_about_z(angle): +def rotation_about_z(angle: float) -> list[list[float]]: return [ [math.cos(angle), -math.sin(angle), 0], [math.sin(angle), math.cos(angle), 0], @@ -141,7 +155,7 @@ def rotation_about_z(angle): ] -def z_to_vector(vector): +def z_to_vector(vector: np.ndarray) -> np.ndarray: """ Returns some matrix in SO(3) which takes the z-axis to the (normalized) vector provided as an argument @@ -156,7 +170,7 @@ def z_to_vector(vector): return rotation_matrix(angle, axis=axis) -def rotation_between_vectors(v1, v2): +def rotation_between_vectors(v1, v2) -> np.ndarray: if np.all(np.isclose(v1, v2)): return np.identity(3) return rotation_matrix( @@ -165,14 +179,14 @@ def rotation_between_vectors(v1, v2): ) -def angle_of_vector(vector): +def angle_of_vector(vector: Sequence[float]) -> float: """ Returns polar coordinate theta when vector is project on xy plane """ return np.angle(complex(*vector[:2])) -def angle_between_vectors(v1, v2): +def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float: """ Returns the angle between two 3D vectors. This angle will always be btw 0 and pi @@ -180,12 +194,15 @@ def angle_between_vectors(v1, v2): return math.acos(clip(np.dot(normalize(v1), normalize(v2)), -1, 1)) -def project_along_vector(point, vector): +def project_along_vector(point: np.ndarray, vector: np.ndarray) -> np.ndarray: matrix = np.identity(3) - np.outer(vector, vector) return np.dot(point, matrix.T) -def normalize(vect, fall_back=None): +def normalize( + vect: np.ndarray, + fall_back: np.ndarray | None = None +) -> np.ndarray: norm = get_norm(vect) if norm > 0: return np.array(vect) / norm @@ -195,7 +212,10 @@ def normalize(vect, fall_back=None): return np.zeros(len(vect)) -def normalize_along_axis(array, axis, fall_back=None): +def normalize_along_axis( + array: np.ndarray, + axis: np.ndarray, +) -> np.ndarray: norms = np.sqrt((array * array).sum(axis)) norms[norms == 0] = 1 buffed_norms = np.repeat(norms, array.shape[axis]).reshape(array.shape) @@ -203,7 +223,11 @@ def normalize_along_axis(array, axis, fall_back=None): return array -def get_unit_normal(v1, v2, tol=1e-6): +def get_unit_normal( + v1: np.ndarray, + v2: np.ndarray, + tol: float=1e-6 +) -> np.ndarray: v1 = normalize(v1) v2 = normalize(v2) cp = cross(v1, v2) @@ -221,7 +245,7 @@ def get_unit_normal(v1, v2, tol=1e-6): ### -def compass_directions(n=4, start_vect=RIGHT): +def compass_directions(n: int = 4, start_vect: np.ndarray = RIGHT) -> np.ndarray: angle = TAU / n return np.array([ rotate_vector(start_vect, k * angle) @@ -229,28 +253,36 @@ def compass_directions(n=4, start_vect=RIGHT): ]) -def complex_to_R3(complex_num): +def complex_to_R3(complex_num: complex) -> np.ndarray: return np.array((complex_num.real, complex_num.imag, 0)) -def R3_to_complex(point): +def R3_to_complex(point: Sequence[float]) -> complex: return complex(*point[:2]) -def complex_func_to_R3_func(complex_func): +def complex_func_to_R3_func( + complex_func: Callable[[complex], complex] +) -> Callable[[np.ndarray], np.ndarray]: return lambda p: complex_to_R3(complex_func(R3_to_complex(p))) -def center_of_mass(points): +def center_of_mass(points: Iterable[Sequence[float]]) -> np.ndarray: points = [np.array(point).astype("float") for point in points] return sum(points) / len(points) -def midpoint(point1, point2): +def midpoint( + point1: Sequence[float], + point2: Sequence[float] +) -> np.ndarray: return center_of_mass([point1, point2]) -def line_intersection(line1, line2): +def line_intersection( + line1: Sequence[Sequence[float]], + line2: Sequence[Sequence[float]] +) -> np.ndarray: """ return intersection point of two lines, each defined with a pair of vectors determining @@ -271,7 +303,13 @@ def line_intersection(line1, line2): return np.array([x, y, 0]) -def find_intersection(p0, v0, p1, v1, threshold=1e-5): +def find_intersection( + p0: Iterable[float], + v0: Iterable[float], + p1: Iterable[float], + v1: Iterable[float], + threshold: float = 1e-5 +) -> np.ndarray: """ Return the intersection of a line passing through p0 in direction v0 with one passing through p1 in direction v1. (Or array of intersections @@ -300,7 +338,11 @@ def find_intersection(p0, v0, p1, v1, threshold=1e-5): return p0 + ratio * v0 -def get_closest_point_on_line(a, b, p): +def get_closest_point_on_line( + a: np.ndarray, + b: np.ndarray, + p: np.ndarray +) -> np.ndarray: """ It returns point x such that x is on line ab and xp is perpendicular to ab. @@ -315,7 +357,7 @@ def get_closest_point_on_line(a, b, p): return ((t * a) + ((1 - t) * b)) -def get_winding_number(points): +def get_winding_number(points: Iterable[float]) -> float: total_angle = 0 for p1, p2 in adjacent_pairs(points): d_angle = angle_of_vector(p2) - angle_of_vector(p1) @@ -326,14 +368,18 @@ def get_winding_number(points): ## -def cross2d(a, b): +def cross2d(a: np.ndarray, b: np.ndarray) -> np.ndarray: if len(a.shape) == 2: return a[:, 0] * b[:, 1] - a[:, 1] * b[:, 0] else: return a[0] * b[1] - b[0] * a[1] -def tri_area(a, b, c): +def tri_area( + a: Sequence[float], + b: Sequence[float], + c: Sequence[float] +) -> float: return 0.5 * abs( a[0] * (b[1] - c[1]) + b[0] * (c[1] - a[1]) + @@ -341,7 +387,12 @@ def tri_area(a, b, c): ) -def is_inside_triangle(p, a, b, c): +def is_inside_triangle( + p: np.ndarray, + a: np.ndarray, + b: np.ndarray, + c: np.ndarray +) -> bool: """ Test if point p is inside triangle abc """ @@ -353,12 +404,12 @@ def is_inside_triangle(p, a, b, c): return np.all(crosses > 0) or np.all(crosses < 0) -def norm_squared(v): +def norm_squared(v: Sequence[float]) -> float: return v[0] * v[0] + v[1] * v[1] + v[2] * v[2] # TODO, fails for polygons drawn over themselves -def earclip_triangulation(verts, ring_ends): +def earclip_triangulation(verts: np.ndarray, ring_ends: list[int]) -> list: """ Returns a list of indices giving a triangulation of a polygon, potentially with holes From 35025631eb01b980e2d11581f1493584cb9eed31 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Sun, 13 Feb 2022 12:56:03 +0800 Subject: [PATCH 02/27] chore: fix type hint of bezier --- manimlib/utils/bezier.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/manimlib/utils/bezier.py b/manimlib/utils/bezier.py index b374b13d..744ef9f1 100644 --- a/manimlib/utils/bezier.py +++ b/manimlib/utils/bezier.py @@ -14,7 +14,9 @@ from manimlib.logger import log CLOSED_THRESHOLD = 0.001 T = TypeVar("T") -def bezier(points: Iterable) -> Callable[[float], float | Iterable]: +def bezier( + points: Iterable[float | np.ndarray] +) -> Callable[[float], float | np.ndarray]: n = len(points) - 1 def result(t): From e78113373a4717b84f80368a6b362f84cb7e0c65 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Sun, 13 Feb 2022 15:11:35 +0800 Subject: [PATCH 03/27] chore: add type hints to manimlib.mobject.mobject --- manimlib/mobject/mobject.py | 545 ++++++++++++++++++++++-------------- 1 file changed, 339 insertions(+), 206 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index c49a65b7..a1426ba7 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -1,12 +1,16 @@ -import copy -import itertools as it -import random -import sys -import moderngl -from functools import wraps -from collections.abc import Iterable +from __future__ import annotations +import sys +import copy +import random +import itertools as it +from functools import wraps +from typing import Iterable, TypeVar, Callable, Union, Sequence + +import colour +import moderngl import numpy as np +import numpy.typing as npt from manimlib.constants import * from manimlib.utils.color import color_gradient @@ -35,6 +39,13 @@ from manimlib.event_handler.event_listner import EventListner from manimlib.event_handler.event_type import EventType +Self = TypeVar("Self", bound="Mobject") +TimeBasedUpdater = Callable[["Mobject", float], None] +NonTimeUpdater = Callable[["Mobject"], None] +Updater = Union[TimeBasedUpdater, NonTimeUpdater] +Color = Union[str, colour.Color, Sequence[float]] + + class Mobject(object): """ Mathematical Object @@ -66,11 +77,11 @@ class Mobject(object): def __init__(self, **kwargs): digest_config(self, kwargs) - self.submobjects = [] - self.parents = [] - self.family = [self] - self.locked_data_keys = set() - self.needs_new_bounding_box = True + self.submobjects: list["Mobject"] = [] + self.parents: list["Mobject"] = [] + self.family: list["Mobject"] = [self] + self.locked_data_keys: set[str] = set() + self.needs_new_bounding_box: bool = True self.init_data() self.init_uniforms() @@ -86,23 +97,23 @@ class Mobject(object): def __str__(self): return self.__class__.__name__ - def __add__(self, other: 'Mobject') -> 'Mobject': + def __add__(self, other: "Mobject") -> "Mobject": assert(isinstance(other, Mobject)) return self.get_group_class()(self, other) - def __mul__(self, other: 'int') -> 'Mobject': + def __mul__(self, other: int) -> "Mobject": assert(isinstance(other, int)) return self.replicate(other) def init_data(self): - self.data = { + self.data: dict[str, np.ndarray] = { "points": np.zeros((0, 3)), "bounding_box": np.zeros((3, 3)), "rgbas": np.zeros((1, 4)), } def init_uniforms(self): - self.uniforms = { + self.uniforms: dict[str, float] = { "is_fixed_in_frame": float(self.is_fixed_in_frame), "gloss": self.gloss, "shadow": self.shadow, @@ -116,12 +127,12 @@ class Mobject(object): # Typically implemented in subclass, unlpess purposefully left blank pass - def set_data(self, data): + def set_data(self, data: dict): for key in data: self.data[key] = data[key].copy() return self - def set_uniforms(self, uniforms): + def set_uniforms(self, uniforms: dict): for key in uniforms: self.uniforms[key] = uniforms[key] # Copy? return self @@ -133,13 +144,17 @@ class Mobject(object): # Only these methods should directly affect points - def resize_points(self, new_length, resize_func=resize_array): + def resize_points( + self, + new_length: int, + resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array + ): if new_length != len(self.data["points"]): self.data["points"] = resize_func(self.data["points"], new_length) self.refresh_bounding_box() return self - def set_points(self, points): + def set_points(self, points: npt.ArrayLike): if len(points) == len(self.data["points"]): self.data["points"][:] = points elif isinstance(points, np.ndarray): @@ -149,7 +164,7 @@ class Mobject(object): self.refresh_bounding_box() return self - def append_points(self, new_points): + def append_points(self, new_points: npt.ArrayLike): self.data["points"] = np.vstack([self.data["points"], new_points]) self.refresh_bounding_box() return self @@ -161,7 +176,13 @@ class Mobject(object): self.refresh_unit_normal() return self - def apply_points_function(self, func, about_point=None, about_edge=ORIGIN, works_on_bounding_box=False): + def apply_points_function( + self, + func: Callable[[np.ndarray], np.ndarray], + about_point: np.ndarray = None, + about_edge: np.ndarray = ORIGIN, + works_on_bounding_box: bool = False + ): if about_point is None and about_edge is not None: about_point = self.get_bounding_box_point(about_edge) @@ -187,35 +208,35 @@ class Mobject(object): # Others related to points - def match_points(self, mobject): + def match_points(self, mobject: "Mobject"): self.set_points(mobject.get_points()) return self - def get_points(self): + def get_points(self) -> np.ndarray: return self.data["points"] - def clear_points(self): + def clear_points(self) -> None: self.resize_points(0) - def get_num_points(self): + def get_num_points(self) -> int: return len(self.data["points"]) - def get_all_points(self): + def get_all_points(self) -> np.ndarray: if self.submobjects: return np.vstack([sm.get_points() for sm in self.get_family()]) else: return self.get_points() - def has_points(self): + def has_points(self) -> bool: return self.get_num_points() > 0 - def get_bounding_box(self): + def get_bounding_box(self) -> np.ndarray: if self.needs_new_bounding_box: self.data["bounding_box"] = self.compute_bounding_box() self.needs_new_bounding_box = False return self.data["bounding_box"] - def compute_bounding_box(self): + def compute_bounding_box(self) -> np.ndarray: all_points = np.vstack([ self.get_points(), *( @@ -233,7 +254,11 @@ class Mobject(object): mids = (mins + maxs) / 2 return np.array([mins, mids, maxs]) - def refresh_bounding_box(self, recurse_down=False, recurse_up=True): + def refresh_bounding_box( + self, + recurse_down: bool = False, + recurse_up: bool = True + ): for mob in self.get_family(recurse_down): mob.needs_new_bounding_box = True if recurse_up: @@ -241,7 +266,11 @@ class Mobject(object): parent.refresh_bounding_box() return self - def is_point_touching(self, point, buff=MED_SMALL_BUFF): + def is_point_touching( + self, + point: np.ndarray, + buff: float = MED_SMALL_BUFF + ) -> bool: bb = self.get_bounding_box() mins = (bb[0] - buff) maxs = (bb[2] + buff) @@ -273,7 +302,7 @@ class Mobject(object): parent.assemble_family() return self - def get_family(self, recurse=True): + def get_family(self, recurse: bool = True): if recurse: return self.family else: @@ -282,7 +311,7 @@ class Mobject(object): def family_members_with_points(self): return [m for m in self.get_family() if m.has_points()] - def add(self, *mobjects): + def add(self, *mobjects: "Mobject"): if self in mobjects: raise Exception("Mobject cannot contain self") for mobject in mobjects: @@ -293,7 +322,7 @@ class Mobject(object): self.assemble_family() return self - def remove(self, *mobjects): + def remove(self, *mobjects: "Mobject"): for mobject in mobjects: if mobject in self.submobjects: self.submobjects.remove(mobject) @@ -302,11 +331,11 @@ class Mobject(object): self.assemble_family() return self - def add_to_back(self, *mobjects): + def add_to_back(self, *mobjects: "Mobject"): self.set_submobjects(list_update(mobjects, self.submobjects)) return self - def replace_submobject(self, index, new_submob): + def replace_submobject(self, index: int, new_submob: "Mobject"): old_submob = self.submobjects[index] if self in old_submob.parents: old_submob.parents.remove(self) @@ -314,12 +343,12 @@ class Mobject(object): self.assemble_family() return self - def insert_submobject(self, index, new_submob): + def insert_submobject(self, index: int, new_submob: "Mobject"): self.submobjects.insert(index, new_submob) self.assemble_family() return self - def set_submobjects(self, submobject_list): + def set_submobjects(self, submobject_list: list["Mobject"]): self.remove(*self.submobjects) self.add(*submobject_list) return self @@ -335,22 +364,31 @@ class Mobject(object): # Submobject organization - def arrange(self, direction=RIGHT, center=True, **kwargs): + def arrange( + self, + direction: np.ndarray = RIGHT, + center: bool = True, + **kwargs + ): for m1, m2 in zip(self.submobjects, self.submobjects[1:]): m2.next_to(m1, direction, **kwargs) if center: self.center() return self - def arrange_in_grid(self, n_rows=None, n_cols=None, - buff=None, - h_buff=None, - v_buff=None, - buff_ratio=None, - h_buff_ratio=0.5, - v_buff_ratio=0.5, - aligned_edge=ORIGIN, - fill_rows_first=True): + def arrange_in_grid( + self, + n_rows: int | None = None, + n_cols: int | None = None, + buff: float | None = None, + h_buff: float | None = None, + v_buff: float | None = None, + buff_ratio: float | None = None, + h_buff_ratio: float =0.5, + v_buff_ratio: float = 0.5, + aligned_edge: np.ndarray = ORIGIN, + fill_rows_first: bool = True + ): submobs = self.submobjects if n_rows is None and n_cols is None: n_rows = int(np.sqrt(len(submobs))) @@ -384,12 +422,12 @@ class Mobject(object): self.center() return self - def replicate(self, n): + def replicate(self, n: int) -> Group: return self.get_group_class()( *(self.copy() for x in range(n)) ) - def get_grid(self, n_rows, n_cols, height=None, **kwargs): + def get_grid(self, n_rows: int, n_cols: int, height: float | None = None, **kwargs): """ Returns a new mobject containing multiple copies of this one arranged in a grid @@ -400,7 +438,11 @@ class Mobject(object): grid.set_height(height) return grid - def sort(self, point_to_num_func=lambda p: p[0], submob_func=None): + def sort( + self, + point_to_num_func: Callable[[np.ndarray], float] = lambda p: p[0], + submob_func: Callable[["Mobject"]] | None = None + ): if submob_func is not None: self.submobjects.sort(key=submob_func) else: @@ -408,7 +450,7 @@ class Mobject(object): self.assemble_family() return self - def shuffle(self, recurse=False): + def shuffle(self, recurse: bool = False): if recurse: for submob in self.submobjects: submob.shuffle(recurse=True) @@ -461,7 +503,7 @@ class Mobject(object): self.parents = parents return result - def generate_target(self, use_deepcopy=False): + def generate_target(self, use_deepcopy: bool = False): self.target = None # Prevent exponential explosion if use_deepcopy: self.target = self.deepcopy() @@ -469,7 +511,7 @@ class Mobject(object): self.target = self.copy() return self.target - def save_state(self, use_deepcopy=False): + def save_state(self, use_deepcopy: bool = False): if hasattr(self, "saved_state"): # Prevent exponential growth of data self.saved_state = None @@ -488,12 +530,12 @@ class Mobject(object): # Updating def init_updaters(self): - self.time_based_updaters = [] - self.non_time_updaters = [] - self.has_updaters = False - self.updating_suspended = False + self.time_based_updaters: list[TimeBasedUpdater] = [] + self.non_time_updaters: list[NonTimeUpdater] = [] + self.has_updaters: bool = False + self.updating_suspended: bool = False - def update(self, dt=0, recurse=True): + def update(self, dt: float = 0, recurse: bool = True): if not self.has_updaters or self.updating_suspended: return self for updater in self.time_based_updaters: @@ -505,19 +547,24 @@ class Mobject(object): submob.update(dt, recurse) return self - def get_time_based_updaters(self): + def get_time_based_updaters(self) -> list[TimeBasedUpdater]: return self.time_based_updaters - def has_time_based_updater(self): + def has_time_based_updater(self) -> bool: return len(self.time_based_updaters) > 0 - def get_updaters(self): + def get_updaters(self) -> list[Updater]: return self.time_based_updaters + self.non_time_updaters - def get_family_updaters(self): + def get_family_updaters(self) -> list[Updater]: return list(it.chain(*[sm.get_updaters() for sm in self.get_family()])) - def add_updater(self, update_function, index=None, call_updater=True): + def add_updater( + self, + update_function: Updater, + index: int | None = None, + call_updater: bool = True + ): if "dt" in get_parameters(update_function): updater_list = self.time_based_updaters else: @@ -533,14 +580,14 @@ class Mobject(object): self.update(dt=0) return self - def remove_updater(self, update_function): + def remove_updater(self, update_function: Updater): for updater_list in [self.time_based_updaters, self.non_time_updaters]: while update_function in updater_list: updater_list.remove(update_function) self.refresh_has_updater_status() return self - def clear_updaters(self, recurse=True): + def clear_updaters(self, recurse: bool = True): self.time_based_updaters = [] self.non_time_updaters = [] self.refresh_has_updater_status() @@ -549,20 +596,20 @@ class Mobject(object): submob.clear_updaters() return self - def match_updaters(self, mobject): + def match_updaters(self, mobject: "Mobject"): self.clear_updaters() for updater in mobject.get_updaters(): self.add_updater(updater) return self - def suspend_updating(self, recurse=True): + def suspend_updating(self, recurse: bool = True): self.updating_suspended = True if recurse: for submob in self.submobjects: submob.suspend_updating(recurse) return self - def resume_updating(self, recurse=True, call_updater=True): + def resume_updating(self, recurse: bool = True, call_updater: bool = True): self.updating_suspended = False if recurse: for submob in self.submobjects: @@ -579,7 +626,7 @@ class Mobject(object): # Transforming operations - def shift(self, vector): + def shift(self, vector: np.ndarray): self.apply_points_function( lambda points: points + vector, about_edge=None, @@ -587,7 +634,13 @@ class Mobject(object): ) return self - def scale(self, scale_factor, min_scale_factor=1e-8, about_point=None, about_edge=ORIGIN): + def scale( + self, + scale_factor: float | npt.ArrayLike, + min_scale_factor: float = 1e-8, + about_point: np.ndarray | None = None, + about_edge: np.ndarray = ORIGIN + ): """ Default behavior is to scale about the center of the mobject. The argument about_edge can be a vector, indicating which side of @@ -597,7 +650,7 @@ class Mobject(object): Otherwise, if about_point is given a value, scaling is done with respect to that point. """ - if isinstance(scale_factor, Iterable): + if isinstance(scale_factor, npt.ArrayLike): scale_factor = np.array(scale_factor).clip(min=min_scale_factor) else: scale_factor = max(scale_factor, min_scale_factor) @@ -616,28 +669,35 @@ class Mobject(object): # any other changes when the size gets altered pass - def stretch(self, factor, dim, **kwargs): + def stretch(self, factor: float, dim: int, **kwargs): def func(points): points[:, dim] *= factor return points self.apply_points_function(func, works_on_bounding_box=True, **kwargs) return self - def rotate_about_origin(self, angle, axis=OUT): + def rotate_about_origin(self, angle: float, axis: np.ndarray = OUT): return self.rotate(angle, axis, about_point=ORIGIN) - def rotate(self, angle, axis=OUT, **kwargs): + def rotate( + self, + angle: float, + axis: np.ndarray = OUT, + about_point: np.ndarray | None = None, + **kwargs + ): rot_matrix_T = rotation_matrix_transpose(angle, axis) self.apply_points_function( lambda points: np.dot(points, rot_matrix_T), + about_point, **kwargs ) return self - def flip(self, axis=UP, **kwargs): + def flip(self, axis: np.ndarray = UP, **kwargs): return self.rotate(TAU / 2, axis, **kwargs) - def apply_function(self, function, **kwargs): + def apply_function(self, function: Callable[[np.ndarray], np.ndarray], **kwargs): # Default to applying matrix about the origin, not mobjects center if len(kwargs) == 0: kwargs["about_point"] = ORIGIN @@ -647,16 +707,19 @@ class Mobject(object): ) return self - def apply_function_to_position(self, function): + def apply_function_to_position(self, function: Callable[[np.ndarray], np.ndarray]): self.move_to(function(self.get_center())) return self - def apply_function_to_submobject_positions(self, function): + def apply_function_to_submobject_positions( + self, + function: Callable[[np.ndarray], np.ndarray] + ): for submob in self.submobjects: submob.apply_function_to_position(function) return self - def apply_matrix(self, matrix, **kwargs): + def apply_matrix(self, matrix: npt.ArrayLike, **kwargs): # Default to applying matrix about the origin, not mobjects center if ("about_point" not in kwargs) and ("about_edge" not in kwargs): kwargs["about_point"] = ORIGIN @@ -669,7 +732,7 @@ class Mobject(object): ) return self - def apply_complex_function(self, function, **kwargs): + def apply_complex_function(self, function: Callable[[complex], complex], **kwargs): def R3_func(point): x, y, z = point xy_complex = function(complex(x, y)) @@ -678,9 +741,14 @@ class Mobject(object): xy_complex.imag, z ] - return self.apply_function(R3_func) + return self.apply_function(R3_func, **kwargs) - def wag(self, direction=RIGHT, axis=DOWN, wag_factor=1.0): + def wag( + self, + direction: np.ndarray = RIGHT, + axis: np.ndarray = DOWN, + wag_factor: float = 1.0 + ): for mob in self.family_members_with_points(): alphas = np.dot(mob.get_points(), np.transpose(axis)) alphas -= min(alphas) @@ -698,7 +766,11 @@ class Mobject(object): self.shift(-self.get_center()) return self - def align_on_border(self, direction, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def align_on_border( + self, + direction: np.ndarray, + buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER + ): """ Direction just needs to be a vector pointing towards side or corner in the 2d plane. @@ -710,20 +782,30 @@ class Mobject(object): self.shift(shift_val) return self - def to_corner(self, corner=LEFT + DOWN, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def to_corner( + self, + corner: np.ndarray = LEFT + DOWN, + buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER + ): return self.align_on_border(corner, buff) - def to_edge(self, edge=LEFT, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def to_edge( + self, + edge: np.ndarray = LEFT, + buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER + ): return self.align_on_border(edge, buff) - def next_to(self, mobject_or_point, - direction=RIGHT, - buff=DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, - aligned_edge=ORIGIN, - submobject_to_align=None, - index_of_submobject_to_align=None, - coor_mask=np.array([1, 1, 1]), - ): + def next_to( + self, + mobject_or_point: "Mobject" | np.ndarray, + direction: np.ndarray = RIGHT, + buff: float = DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, + aligned_edge: np.ndarray = ORIGIN, + submobject_to_align: "Mobject" | None = None, + index_of_submobject_to_align: int | slice | None = None, + coor_mask: np.ndarray = np.array([1, 1, 1]), + ): if isinstance(mobject_or_point, Mobject): mob = mobject_or_point if index_of_submobject_to_align is not None: @@ -767,14 +849,14 @@ class Mobject(object): return True return False - def stretch_about_point(self, factor, dim, point): + def stretch_about_point(self, factor: float, dim: int, point: np.ndarray): return self.stretch(factor, dim, about_point=point) - def stretch_in_place(self, factor, dim): + def stretch_in_place(self, factor: float, dim: int): # Now redundant with stretch return self.stretch(factor, dim) - def rescale_to_fit(self, length, dim, stretch=False, **kwargs): + def rescale_to_fit(self, length: float, dim: int, stretch: bool = False, **kwargs): old_length = self.length_over_dim(dim) if old_length == 0: return self @@ -784,63 +866,67 @@ class Mobject(object): self.scale(length / old_length, **kwargs) return self - def stretch_to_fit_width(self, width, **kwargs): + def stretch_to_fit_width(self, width: float, **kwargs): return self.rescale_to_fit(width, 0, stretch=True, **kwargs) - def stretch_to_fit_height(self, height, **kwargs): + def stretch_to_fit_height(self, height: float, **kwargs): return self.rescale_to_fit(height, 1, stretch=True, **kwargs) - def stretch_to_fit_depth(self, depth, **kwargs): + def stretch_to_fit_depth(self, depth: float, **kwargs): return self.rescale_to_fit(depth, 2, stretch=True, **kwargs) - def set_width(self, width, stretch=False, **kwargs): + def set_width(self, width: float, stretch: bool = False, **kwargs): return self.rescale_to_fit(width, 0, stretch=stretch, **kwargs) - def set_height(self, height, stretch=False, **kwargs): + def set_height(self, height: float, stretch: bool = False, **kwargs): return self.rescale_to_fit(height, 1, stretch=stretch, **kwargs) - def set_depth(self, depth, stretch=False, **kwargs): + def set_depth(self, depth: float, stretch: bool = False, **kwargs): return self.rescale_to_fit(depth, 2, stretch=stretch, **kwargs) - def set_max_width(self, max_width, **kwargs): + def set_max_width(self, max_width: float, **kwargs): if self.get_width() > max_width: self.set_width(max_width, **kwargs) return self - def set_max_height(self, max_height, **kwargs): + def set_max_height(self, max_height: float, **kwargs): if self.get_height() > max_height: self.set_height(max_height, **kwargs) return self - def set_max_depth(self, max_depth, **kwargs): + def set_max_depth(self, max_depth: float, **kwargs): if self.get_depth() > max_depth: self.set_depth(max_depth, **kwargs) return self - def set_coord(self, value, dim, direction=ORIGIN): + def set_coord(self, value: float, dim: int, direction: np.ndarray = ORIGIN): curr = self.get_coord(dim, direction) shift_vect = np.zeros(self.dim) shift_vect[dim] = value - curr self.shift(shift_vect) return self - def set_x(self, x, direction=ORIGIN): + def set_x(self, x: float, direction: np.ndarray = ORIGIN): return self.set_coord(x, 0, direction) - def set_y(self, y, direction=ORIGIN): + def set_y(self, y: float, direction: np.ndarray = ORIGIN): return self.set_coord(y, 1, direction) - def set_z(self, z, direction=ORIGIN): + def set_z(self, z: float, direction: np.ndarray = ORIGIN): return self.set_coord(z, 2, direction) - def space_out_submobjects(self, factor=1.5, **kwargs): + def space_out_submobjects(self, factor: float = 1.5, **kwargs): self.scale(factor, **kwargs) for submob in self.submobjects: submob.scale(1. / factor) return self - def move_to(self, point_or_mobject, aligned_edge=ORIGIN, - coor_mask=np.array([1, 1, 1])): + def move_to( + self, + point_or_mobject: "Mobject" | np.ndarray, + aligned_edge: np.ndarray = ORIGIN, + coor_mask: np.ndarray = np.array([1, 1, 1]) + ): if isinstance(point_or_mobject, Mobject): target = point_or_mobject.get_bounding_box_point(aligned_edge) else: @@ -849,7 +935,7 @@ class Mobject(object): self.shift((target - point_to_align) * coor_mask) return self - def replace(self, mobject, dim_to_match=0, stretch=False): + def replace(self, mobject: "Mobject", dim_to_match: int = 0, stretch: bool = False): if not mobject.get_num_points() and not mobject.submobjects: self.scale(0) return self @@ -865,16 +951,19 @@ class Mobject(object): self.shift(mobject.get_center() - self.get_center()) return self - def surround(self, mobject, - dim_to_match=0, - stretch=False, - buff=MED_SMALL_BUFF): + def surround( + self, + mobject: "Mobject", + dim_to_match: int = 0, + stretch: bool = False, + buff: float = MED_SMALL_BUFF + ): self.replace(mobject, dim_to_match, stretch) length = mobject.length_over_dim(dim_to_match) self.scale((length + buff) / length) return self - def put_start_and_end_on(self, start, end): + def put_start_and_end_on(self, start: np.ndarray, end: np.ndarray): curr_start, curr_end = self.get_start_and_end() curr_vect = curr_end - curr_start if np.all(curr_vect == 0): @@ -896,12 +985,21 @@ class Mobject(object): # Color functions - def set_rgba_array(self, rgba_array, name="rgbas", recurse=False): + def set_rgba_array( + self, + rgba_array: npt.ArrayLike, + name: str = "rgbas", + recurse: bool = False + ): for mob in self.get_family(recurse): mob.data[name] = np.array(rgba_array) return self - def set_color_by_rgba_func(self, func, recurse=True): + def set_color_by_rgba_func( + self, + func: Callable[[np.ndarray], Sequence[float]], + recurse: bool = True + ): """ Func should take in a point in R3 and output an rgba value """ @@ -910,7 +1008,12 @@ class Mobject(object): mob.set_rgba_array(rgba_array) return self - def set_color_by_rgb_func(self, func, opacity=1, recurse=True): + def set_color_by_rgb_func( + self, + func: Callable[[np.ndarray], Sequence[float]], + opacity: float = 1, + recurse: bool = True + ): """ Func should take in a point in R3 and output an rgb value """ @@ -919,7 +1022,13 @@ class Mobject(object): mob.set_rgba_array(rgba_array) return self - def set_rgba_array_by_color(self, color=None, opacity=None, name="rgbas", recurse=True): + def set_rgba_array_by_color( + self, + color: Color | None = None, + opacity: float | None = None, + name: str = "rgbas", + recurse: bool = True + ): if color is not None: rgbs = np.array([color_to_rgb(c) for c in listify(color)]) if opacity is not None: @@ -947,7 +1056,7 @@ class Mobject(object): mob.data[name] = rgbas.copy() return self - def set_color(self, color, opacity=None, recurse=True): + def set_color(self, color: Color, opacity: float | None = None, recurse: bool = True): self.set_rgba_array_by_color(color, opacity, recurse=False) # Recurse to submobjects differently from how set_rgba_array_by_color # in case they implement set_color differently @@ -956,24 +1065,24 @@ class Mobject(object): submob.set_color(color, recurse=True) return self - def set_opacity(self, opacity, recurse=True): + def set_opacity(self, opacity: float, recurse: bool = True): self.set_rgba_array_by_color(color=None, opacity=opacity, recurse=False) if recurse: for submob in self.submobjects: submob.set_opacity(opacity, recurse=True) return self - def get_color(self): + def get_color(self) -> str: return rgb_to_hex(self.data["rgbas"][0, :3]) - def get_opacity(self): + def get_opacity(self) -> float: return self.data["rgbas"][0, 3] - def set_color_by_gradient(self, *colors): + def set_color_by_gradient(self, *colors: Color): self.set_submobject_colors_by_gradient(*colors) return self - def set_submobject_colors_by_gradient(self, *colors): + def set_submobject_colors_by_gradient(self, *colors: Color): if len(colors) == 0: raise Exception("Need at least one color") elif len(colors) == 1: @@ -987,36 +1096,41 @@ class Mobject(object): mob.set_color(color) return self - def fade(self, darkness=0.5, recurse=True): + def fade(self, darkness: float = 0.5, recurse: bool = True): self.set_opacity(1.0 - darkness, recurse=recurse) - def get_reflectiveness(self): + def get_reflectiveness(self) -> float: return self.uniforms["reflectiveness"] - def set_reflectiveness(self, reflectiveness, recurse=True): + def set_reflectiveness(self, reflectiveness: float, recurse: bool = True): for mob in self.get_family(recurse): mob.uniforms["reflectiveness"] = reflectiveness return self - def get_shadow(self): + def get_shadow(self) -> float: return self.uniforms["shadow"] - def set_shadow(self, shadow, recurse=True): + def set_shadow(self, shadow: float, recurse: bool = True): for mob in self.get_family(recurse): mob.uniforms["shadow"] = shadow return self - def get_gloss(self): + def get_gloss(self) -> float: return self.uniforms["gloss"] - def set_gloss(self, gloss, recurse=True): + def set_gloss(self, gloss: float, recurse: bool = True): for mob in self.get_family(recurse): mob.uniforms["gloss"] = gloss return self # Background rectangle - def add_background_rectangle(self, color=None, opacity=0.75, **kwargs): + def add_background_rectangle( + self, + color: Color | None = None, + opacity: float = 0.75, + **kwargs + ): # TODO, this does not behave well when the mobject has points, # since it gets displayed on top from manimlib.mobject.shape_matchers import BackgroundRectangle @@ -1040,7 +1154,7 @@ class Mobject(object): # Getters - def get_bounding_box_point(self, direction): + def get_bounding_box_point(self, direction: np.ndarray) -> np.ndarray: bb = self.get_bounding_box() indices = (np.sign(direction) + 1).astype(int) return np.array([ @@ -1048,19 +1162,19 @@ class Mobject(object): for i in range(3) ]) - def get_edge_center(self, direction): + def get_edge_center(self, direction: np.ndarray) -> np.ndarray: return self.get_bounding_box_point(direction) - def get_corner(self, direction): + def get_corner(self, direction: np.ndarray) -> np.ndarray: return self.get_bounding_box_point(direction) - def get_center(self): + def get_center(self) -> np.ndarray: return self.get_bounding_box()[1] - def get_center_of_mass(self): + def get_center_of_mass(self) -> np.ndarray: return self.get_all_points().mean(0) - def get_boundary_point(self, direction): + def get_boundary_point(self, direction: np.ndarray) -> np.ndarray: all_points = self.get_all_points() boundary_directions = all_points - self.get_center() norms = np.linalg.norm(boundary_directions, axis=1) @@ -1068,7 +1182,7 @@ class Mobject(object): index = np.argmax(np.dot(boundary_directions, np.array(direction).T)) return all_points[index] - def get_continuous_bounding_box_point(self, direction): + def get_continuous_bounding_box_point(self, direction: np.ndarray) -> np.ndarray: dl, center, ur = self.get_bounding_box() corner_vect = (ur - center) return center + direction / np.max(np.abs(np.true_divide( @@ -1077,66 +1191,66 @@ class Mobject(object): where=((corner_vect) != 0) ))) - def get_top(self): + def get_top(self) -> np.ndarray: return self.get_edge_center(UP) - def get_bottom(self): + def get_bottom(self) -> np.ndarray: return self.get_edge_center(DOWN) - def get_right(self): + def get_right(self) -> np.ndarray: return self.get_edge_center(RIGHT) - def get_left(self): + def get_left(self) -> np.ndarray: return self.get_edge_center(LEFT) - def get_zenith(self): + def get_zenith(self) -> np.ndarray: return self.get_edge_center(OUT) - def get_nadir(self): + def get_nadir(self) -> np.ndarray: return self.get_edge_center(IN) - def length_over_dim(self, dim): + def length_over_dim(self, dim: int) -> float: bb = self.get_bounding_box() return abs((bb[2] - bb[0])[dim]) - def get_width(self): + def get_width(self) -> float: return self.length_over_dim(0) - def get_height(self): + def get_height(self) -> float: return self.length_over_dim(1) - def get_depth(self): + def get_depth(self) -> float: return self.length_over_dim(2) - def get_coord(self, dim, direction=ORIGIN): + def get_coord(self, dim: int, direction: np.ndarray = ORIGIN) -> float: """ Meant to generalize get_x, get_y, get_z """ return self.get_bounding_box_point(direction)[dim] - def get_x(self, direction=ORIGIN): + def get_x(self, direction=ORIGIN) -> float: return self.get_coord(0, direction) - def get_y(self, direction=ORIGIN): + def get_y(self, direction=ORIGIN) -> float: return self.get_coord(1, direction) - def get_z(self, direction=ORIGIN): + def get_z(self, direction=ORIGIN) -> float: return self.get_coord(2, direction) - def get_start(self): + def get_start(self) -> np.ndarray: self.throw_error_if_no_points() return self.get_points()[0].copy() - def get_end(self): + def get_end(self) -> np.ndarray: self.throw_error_if_no_points() return self.get_points()[-1].copy() - def get_start_and_end(self): + def get_start_and_end(self) -> tuple(np.ndarray, np.ndarray): self.throw_error_if_no_points() points = self.get_points() return (points[0].copy(), points[-1].copy()) - def point_from_proportion(self, alpha): + def point_from_proportion(self, alpha: float) -> np.ndarray: points = self.get_points() i, subalpha = integer_interpolate(0, len(points) - 1, alpha) return interpolate(points[i], points[i + 1], subalpha) @@ -1145,7 +1259,7 @@ class Mobject(object): """Abbreviation fo point_from_proportion""" return self.point_from_proportion(alpha) - def get_pieces(self, n_pieces): + def get_pieces(self, n_pieces: int) -> Group: template = self.copy() template.set_submobjects([]) alphas = np.linspace(0, 1, n_pieces + 1) @@ -1163,41 +1277,45 @@ class Mobject(object): # Match other mobject properties - def match_color(self, mobject): + def match_color(self, mobject: "Mobject"): return self.set_color(mobject.get_color()) - def match_dim_size(self, mobject, dim, **kwargs): + def match_dim_size(self, mobject: "Mobject", dim: int, **kwargs): return self.rescale_to_fit( mobject.length_over_dim(dim), dim, **kwargs ) - def match_width(self, mobject, **kwargs): + def match_width(self, mobject: "Mobject", **kwargs): return self.match_dim_size(mobject, 0, **kwargs) - def match_height(self, mobject, **kwargs): + def match_height(self, mobject: "Mobject", **kwargs): return self.match_dim_size(mobject, 1, **kwargs) - def match_depth(self, mobject, **kwargs): + def match_depth(self, mobject: "Mobject", **kwargs): return self.match_dim_size(mobject, 2, **kwargs) - def match_coord(self, mobject, dim, direction=ORIGIN): + def match_coord(self, mobject: "Mobject", dim: int, direction: np.ndarray = ORIGIN): return self.set_coord( mobject.get_coord(dim, direction), dim=dim, direction=direction, ) - def match_x(self, mobject, direction=ORIGIN): + def match_x(self, mobject: "Mobject", direction: np.ndarray = ORIGIN): return self.match_coord(mobject, 0, direction) - def match_y(self, mobject, direction=ORIGIN): + def match_y(self, mobject: "Mobject", direction: np.ndarray = ORIGIN): return self.match_coord(mobject, 1, direction) - def match_z(self, mobject, direction=ORIGIN): + def match_z(self, mobject: "Mobject", direction: np.ndarray = ORIGIN): return self.match_coord(mobject, 2, direction) - def align_to(self, mobject_or_point, direction=ORIGIN): + def align_to( + self, + mobject_or_point: "Mobject" | np.ndarray, + direction: np.ndarray = ORIGIN + ): """ Examples: mob1.align_to(mob2, UP) moves mob1 vertically so that its @@ -1222,11 +1340,11 @@ class Mobject(object): # Alignment - def align_data_and_family(self, mobject): + def align_data_and_family(self, mobject: "Mobject") -> None: self.align_family(mobject) self.align_data(mobject) - def align_data(self, mobject): + def align_data(self, mobject: "Mobject") -> None: # In case any data arrays get resized when aligned to shader data self.refresh_shader_data() for mob1, mob2 in zip(self.get_family(), mobject.get_family()): @@ -1243,13 +1361,13 @@ class Mobject(object): elif len(arr1) > len(arr2): mob2.data[key] = resize_preserving_order(arr2, len(arr1)) - def align_points(self, mobject): + def align_points(self, mobject: "Mobject"): max_len = max(self.get_num_points(), mobject.get_num_points()) for mob in (self, mobject): mob.resize_points(max_len, resize_func=resize_preserving_order) return self - def align_family(self, mobject): + def align_family(self, mobject: "Mobject"): mob1 = self mob2 = mobject n1 = len(mob1) @@ -1269,7 +1387,7 @@ class Mobject(object): self.add(copy) return self - def add_n_more_submobjects(self, n): + def add_n_more_submobjects(self, n: int): if n == 0: return self @@ -1304,7 +1422,13 @@ class Mobject(object): # Interpolate - def interpolate(self, mobject1, mobject2, alpha, path_func=straight_path): + def interpolate( + self, + mobject1: "Mobject", + mobject2: "Mobject", + alpha: float, + path_func: Callable[[np.ndarray, np.ndarray, float], np.ndarray] = straight_path + ): for key in self.data: if key in self.locked_data_keys: continue @@ -1340,7 +1464,7 @@ class Mobject(object): """ pass # To implement in subclass - def become(self, mobject): + def become(self, mobject: "Mobject"): """ Edit all data and submobjects to be idential to another mobject @@ -1354,7 +1478,7 @@ class Mobject(object): # Locking data - def lock_data(self, keys): + def lock_data(self, keys: Iterable[str]): """ To speed up some animations, particularly transformations, it can be handy to acknowledge which pieces of data @@ -1368,7 +1492,7 @@ class Mobject(object): self.refresh_shader_data() self.locked_data_keys = set(keys) - def lock_matching_data(self, mobject1, mobject2): + def lock_matching_data(self, mobject1: "Mobject", mobject2: "Mobject"): for sm, sm1, sm2 in zip(self.get_family(), mobject1.get_family(), mobject2.get_family()): keys = sm.data.keys() & sm1.data.keys() & sm2.data.keys() sm.lock_data(list(filter( @@ -1416,7 +1540,7 @@ class Mobject(object): # Shader code manipulation - def replace_shader_code(self, old, new): + def replace_shader_code(self, old: str, new: str): # TODO, will this work with VMobject structure, given # that it does not simpler return shader_wrappers of # family? @@ -1424,7 +1548,7 @@ class Mobject(object): wrapper.replace_code(old, new) return self - def set_color_by_code(self, glsl_code): + def set_color_by_code(self, glsl_code: str): """ Takes a snippet of code and inserts it into a context which has the following variables: @@ -1437,9 +1561,13 @@ class Mobject(object): ) return self - def set_color_by_xyz_func(self, glsl_snippet, - min_value=-5.0, max_value=5.0, - colormap="viridis"): + def set_color_by_xyz_func( + self, + glsl_snippet: str, + min_value: float = -5.0, + max_value: float = 5.0, + colormap: str = "viridis" + ): """ Pass in a glsl expression in terms of x, y and z which returns a float. @@ -1484,7 +1612,7 @@ class Mobject(object): self.shader_wrapper.depth_test = self.depth_test return self.shader_wrapper - def get_shader_wrapper_list(self): + def get_shader_wrapper_list(self) -> list[ShaderWrapper]: shader_wrappers = it.chain( [self.get_shader_wrapper()], *[sm.get_shader_wrapper_list() for sm in self.submobjects] @@ -1501,7 +1629,7 @@ class Mobject(object): result.append(shader_wrapper) return result - def check_data_alignment(self, array, data_key): + def check_data_alignment(self, array: Iterable, data_key: str): # Makes sure that self.data[key] can be broadcast into # the given array, meaning its length has to be either 1 # or the length of the array @@ -1512,14 +1640,19 @@ class Mobject(object): ) return self - def get_resized_shader_data_array(self, length): + def get_resized_shader_data_array(self, length: int) -> np.ndarray: # If possible, try to populate an existing array, rather # than recreating it each frame if len(self.shader_data) != length: self.shader_data = resize_array(self.shader_data, length) return self.shader_data - def read_data_to_shader(self, shader_data, shader_data_key, data_key): + def read_data_to_shader( + self, + shader_data: np.ndarray, + shader_data_key: str, + data_key: str + ): if data_key in self.locked_data_keys: return self.check_data_alignment(shader_data, data_key) @@ -1551,22 +1684,22 @@ class Mobject(object): """ def init_event_listners(self): - self.event_listners = [] + self.event_listners: list[EventListner] = [] - def add_event_listner(self, event_type, event_callback): + def add_event_listner(self, event_type: EventType, event_callback: Callable): event_listner = EventListner(self, event_type, event_callback) self.event_listners.append(event_listner) EVENT_DISPATCHER.add_listner(event_listner) return self - def remove_event_listner(self, event_type, event_callback): + def remove_event_listner(self, event_type: EventType, event_callback: Callable): event_listner = EventListner(self, event_type, event_callback) while event_listner in self.event_listners: self.event_listners.remove(event_listner) EVENT_DISPATCHER.remove_listner(event_listner) return self - def clear_event_listners(self, recurse=True): + def clear_event_listners(self, recurse: bool = True): self.event_listners = [] if recurse: for submob in self.submobjects: @@ -1638,13 +1771,13 @@ class Mobject(object): class Group(Mobject): - def __init__(self, *mobjects, **kwargs): + def __init__(self, *mobjects: "Mobject", **kwargs): if not all([isinstance(m, Mobject) for m in mobjects]): raise Exception("All submobjects must be of type Mobject") Mobject.__init__(self, **kwargs) self.add(*mobjects) - def __add__(self, other: 'Mobject' or 'Group'): + def __add__(self, other: "Mobject" | "Group"): assert(isinstance(other, Mobject)) return self.add(other) @@ -1655,35 +1788,35 @@ class Point(Mobject): "artificial_height": 1e-6, } - def __init__(self, location=ORIGIN, **kwargs): + def __init__(self, location: npt.ArrayLike = ORIGIN, **kwargs): Mobject.__init__(self, **kwargs) self.set_location(location) - def get_width(self): + def get_width(self) -> float: return self.artificial_width - def get_height(self): + def get_height(self) -> float: return self.artificial_height - def get_location(self): + def get_location(self) -> np.ndarray: return self.get_points()[0].copy() - def get_bounding_box_point(self, *args, **kwargs): + def get_bounding_box_point(self, *args, **kwargs) -> np.ndarray: return self.get_location() - def set_location(self, new_loc): + def set_location(self, new_loc: npt.ArrayLike): self.set_points(np.array(new_loc, ndmin=2, dtype=float)) class _AnimationBuilder: - def __init__(self, mobject): + def __init__(self, mobject: Mobject): self.mobject = mobject self.overridden_animation = None self.mobject.generate_target() self.is_chaining = False self.methods = [] - def __getattr__(self, method_name): + def __getattr__(self, method_name: str): method = getattr(self.mobject.target, method_name) self.methods.append(method) has_overridden_animation = hasattr(method, "_override_animate") From 7f8216bb095123e02fe806eb613c54954c6f4d51 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Sun, 13 Feb 2022 15:18:04 +0800 Subject: [PATCH 04/27] chore: replace some iterable with npt.ArrayLike --- manimlib/utils/bezier.py | 21 +++++++++++---------- manimlib/utils/iterables.py | 6 +++--- manimlib/utils/space_ops.py | 13 +++++++------ 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/manimlib/utils/bezier.py b/manimlib/utils/bezier.py index 744ef9f1..ddca0671 100644 --- a/manimlib/utils/bezier.py +++ b/manimlib/utils/bezier.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Iterable, Callable, TypeVar +from typing import Iterable, Callable, TypeVar, Sequence from scipy import linalg import numpy as np +import numpy.typing as npt from manimlib.utils.simple_functions import choose from manimlib.utils.space_ops import find_intersection @@ -29,7 +30,7 @@ def bezier( def partial_bezier_points( - points: Iterable[np.ndarray], + points: Sequence[np.ndarray], a: float, b: float ) -> list[float]: @@ -59,7 +60,7 @@ def partial_bezier_points( # Shortened version of partial_bezier_points just for quadratics, # since this is called a fair amount def partial_quadratic_bezier_points( - points: Iterable[np.ndarray], + points: Sequence[np.ndarray], a: float, b: float ) -> list[float]: @@ -148,7 +149,7 @@ def match_interpolate( def get_smooth_quadratic_bezier_handle_points( - points: Iterable[np.ndarray] + points: Sequence[np.ndarray] ) -> np.ndarray | list[np.ndarray]: """ Figuring out which bezier curves most smoothly connect a sequence of points. @@ -182,7 +183,7 @@ def get_smooth_quadratic_bezier_handle_points( def get_smooth_cubic_bezier_handle_points( - points: Iterable[np.ndarray] + points: npt.ArrayLike ) -> tuple[np.ndarray, np.ndarray]: points = np.array(points) num_handles = len(points) - 1 @@ -261,17 +262,17 @@ def diag_to_matrix( return matrix -def is_closed(points: Iterable[np.ndarray]) -> bool: +def is_closed(points: Sequence[np.ndarray]) -> bool: return np.allclose(points[0], points[-1]) # Given 4 control points for a cubic bezier curve (or arrays of such) # return control points for 2 quadratics (or 2n quadratics) approximating them. def get_quadratic_approximation_of_cubic( - a0: np.ndarray | Iterable[np.ndarray], - h0: np.ndarray | Iterable[np.ndarray], - h1: np.ndarray | Iterable[np.ndarray], - a1: np.ndarray | Iterable[np.ndarray] + a0: npt.ArrayLike, + h0: npt.ArrayLike, + h1: npt.ArrayLike, + a1: npt.ArrayLike ) -> np.ndarray: a0 = np.array(a0, ndmin=2) h0 = np.array(h0, ndmin=2) diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py index 90b32bdb..30be9c32 100644 --- a/manimlib/utils/iterables.py +++ b/manimlib/utils/iterables.py @@ -1,7 +1,7 @@ from __future__ import annotations import itertools as it -from typing import Callable, Iterable, TypeVar +from typing import Callable, Iterable, Sequence, TypeVar import numpy as np @@ -119,8 +119,8 @@ def resize_with_interpolation(nparray: np.ndarray, length: int) -> np.ndarray: def make_even( - iterable_1: Iterable[T], - iterable_2: Iterable[S] + iterable_1: Sequence[T], + iterable_2: Sequence[S] ) -> tuple[list[T], list[S]]: len1 = len(iterable_1) len2 = len(iterable_2) diff --git a/manimlib/utils/space_ops.py b/manimlib/utils/space_ops.py index 622a63e1..210ceae7 100644 --- a/manimlib/utils/space_ops.py +++ b/manimlib/utils/space_ops.py @@ -6,6 +6,7 @@ from functools import reduce from typing import Callable, Iterable, Sequence import numpy as np +import numpy.typing as npt from mapbox_earcut import triangulate_float32 as earcut from manimlib.constants import RIGHT @@ -124,7 +125,7 @@ def rotation_matrix_from_quaternion(quat: Iterable) -> np.ndarray: return np.transpose(rotation_matrix_transpose_from_quaternion(quat)) -def rotation_matrix_transpose(angle: float, axis: np.ndarray) -> list[list[flaot]]: +def rotation_matrix_transpose(angle: float, axis: np.ndarray) -> list[list[float]]: if axis[0] == 0 and axis[1] == 0: # axis = [0, 0, z] case is common enough it's worth # having a shortcut @@ -267,7 +268,7 @@ def complex_func_to_R3_func( return lambda p: complex_to_R3(complex_func(R3_to_complex(p))) -def center_of_mass(points: Iterable[Sequence[float]]) -> np.ndarray: +def center_of_mass(points: Iterable[npt.ArrayLike]) -> np.ndarray: points = [np.array(point).astype("float") for point in points] return sum(points) / len(points) @@ -304,10 +305,10 @@ def line_intersection( def find_intersection( - p0: Iterable[float], - v0: Iterable[float], - p1: Iterable[float], - v1: Iterable[float], + p0: npt.ArrayLike, + v0: npt.ArrayLike, + p1: npt.ArrayLike, + v1: npt.ArrayLike, threshold: float = 1e-5 ) -> np.ndarray: """ From 19187ead06deeff0a71e2393e8e82280becb2983 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Sun, 13 Feb 2022 18:56:50 +0800 Subject: [PATCH 05/27] chore: add type hints to manimlib.mobject.types --- manimlib/mobject/types/dot_cloud.py | 54 ++-- manimlib/mobject/types/image_mobject.py | 17 +- manimlib/mobject/types/point_cloud_mobject.py | 44 ++- manimlib/mobject/types/surface.py | 83 ++++-- manimlib/mobject/types/vectorized_mobject.py | 272 +++++++++++------- 5 files changed, 306 insertions(+), 164 deletions(-) diff --git a/manimlib/mobject/types/dot_cloud.py b/manimlib/mobject/types/dot_cloud.py index 03511ecb..d44bdc25 100644 --- a/manimlib/mobject/types/dot_cloud.py +++ b/manimlib/mobject/types/dot_cloud.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import numpy as np +import numpy.typing as npt import moderngl from manimlib.constants import GREY_C @@ -29,27 +32,31 @@ class DotCloud(PMobject): ], } - def __init__(self, points=None, **kwargs): + def __init__(self, points: npt.ArrayLike = None, **kwargs): super().__init__(**kwargs) if points is not None: self.set_points(points) - def init_data(self): + def init_data(self) -> None: super().init_data() self.data["radii"] = np.zeros((1, 1)) self.set_radius(self.radius) - def init_uniforms(self): + def init_uniforms(self) -> None: super().init_uniforms() self.uniforms["glow_factor"] = self.glow_factor - def to_grid(self, n_rows, n_cols, n_layers=1, - buff_ratio=None, - h_buff_ratio=1.0, - v_buff_ratio=1.0, - d_buff_ratio=1.0, - height=DEFAULT_GRID_HEIGHT, - ): + def to_grid( + self, + n_rows: int, + n_cols: int, + n_layers: int = 1, + buff_ratio: float | None = None, + h_buff_ratio: float = 1.0, + v_buff_ratio: float = 1.0, + d_buff_ratio: float = 1.0, + height: float = DEFAULT_GRID_HEIGHT, + ): n_points = n_rows * n_cols * n_layers points = np.repeat(range(n_points), 3, axis=0).reshape((n_points, 3)) points[:, 0] = points[:, 0] % n_cols @@ -74,50 +81,55 @@ class DotCloud(PMobject): self.center() return self - def set_radii(self, radii): + def set_radii(self, radii: npt.ArrayLike): n_points = len(self.get_points()) radii = np.array(radii).reshape((len(radii), 1)) self.data["radii"] = resize_preserving_order(radii, n_points) self.refresh_bounding_box() return self - def get_radii(self): + def get_radii(self) -> np.ndarray: return self.data["radii"] - def set_radius(self, radius): + def set_radius(self, radius: float): self.data["radii"][:] = radius self.refresh_bounding_box() return self - def get_radius(self): + def get_radius(self) -> float: return self.get_radii().max() - def set_glow_factor(self, glow_factor): + def set_glow_factor(self, glow_factor: float) -> None: self.uniforms["glow_factor"] = glow_factor - def get_glow_factor(self): + def get_glow_factor(self) -> float: return self.uniforms["glow_factor"] - def compute_bounding_box(self): + def compute_bounding_box(self) -> np.ndarray: bb = super().compute_bounding_box() radius = self.get_radius() bb[0] += np.full((3,), -radius) bb[2] += np.full((3,), radius) return bb - def scale(self, scale_factor, scale_radii=True, **kwargs): + def scale( + self, + scale_factor: float | npt.ArrayLike, + scale_radii: bool = True, + **kwargs + ): super().scale(scale_factor, **kwargs) if scale_radii: self.set_radii(scale_factor * self.get_radii()) return self - def make_3d(self, reflectiveness=0.5, shadow=0.2): + def make_3d(self, reflectiveness: float = 0.5, shadow: float = 0.2): self.set_reflectiveness(reflectiveness) self.set_shadow(shadow) self.apply_depth_test() return self - def get_shader_data(self): + def get_shader_data(self) -> np.ndarray: shader_data = super().get_shader_data() self.read_data_to_shader(shader_data, "radius", "radii") self.read_data_to_shader(shader_data, "color", "rgbas") @@ -125,7 +137,7 @@ class DotCloud(PMobject): class TrueDot(DotCloud): - def __init__(self, center=ORIGIN, **kwargs): + def __init__(self, center: np.ndarray = ORIGIN, **kwargs): super().__init__(points=[center], **kwargs) diff --git a/manimlib/mobject/types/image_mobject.py b/manimlib/mobject/types/image_mobject.py index 334b389d..d3f11f2b 100644 --- a/manimlib/mobject/types/image_mobject.py +++ b/manimlib/mobject/types/image_mobject.py @@ -1,5 +1,6 @@ -import numpy as np +from __future__ import annotations +import numpy as np from PIL import Image from manimlib.constants import * @@ -21,33 +22,33 @@ class ImageMobject(Mobject): ] } - def __init__(self, filename, **kwargs): + def __init__(self, filename: str, **kwargs): self.set_image_path(get_full_raster_image_path(filename)) super().__init__(**kwargs) - def set_image_path(self, path): + def set_image_path(self, path: str) -> None: self.path = path self.image = Image.open(path) self.texture_paths = {"Texture": path} - def init_data(self): + def init_data(self) -> None: self.data = { "points": np.array([UL, DL, UR, DR]), "im_coords": np.array([(0, 0), (0, 1), (1, 0), (1, 1)]), "opacity": np.array([[self.opacity]], dtype=np.float32), } - def init_points(self): + def init_points(self) -> None: size = self.image.size self.set_width(2 * size[0] / size[1], stretch=True) self.set_height(self.height) - def set_opacity(self, opacity, recurse=True): + def set_opacity(self, opacity: float, recurse: bool = True): for mob in self.get_family(recurse): mob.data["opacity"] = np.array([[o] for o in listify(opacity)]) return self - def point_to_rgb(self, point): + def point_to_rgb(self, point: np.ndarray) -> np.ndarray: x0, y0 = self.get_corner(UL)[:2] x1, y1 = self.get_corner(DR)[:2] x_alpha = inverse_interpolate(x0, x1, point[0]) @@ -63,7 +64,7 @@ class ImageMobject(Mobject): )) return np.array(rgb) / 255 - def get_shader_data(self): + def get_shader_data(self) -> np.ndarray: shader_data = super().get_shader_data() self.read_data_to_shader(shader_data, "im_coords", "im_coords") self.read_data_to_shader(shader_data, "opacity", "opacity") diff --git a/manimlib/mobject/types/point_cloud_mobject.py b/manimlib/mobject/types/point_cloud_mobject.py index 28ccee7e..2af3e191 100644 --- a/manimlib/mobject/types/point_cloud_mobject.py +++ b/manimlib/mobject/types/point_cloud_mobject.py @@ -1,3 +1,10 @@ +from __future__ import annotations + +from typing import Callable, Sequence, Union + +import colour +import numpy.typing as npt + from manimlib.constants import * from manimlib.mobject.mobject import Mobject from manimlib.utils.color import color_gradient @@ -6,26 +13,39 @@ from manimlib.utils.iterables import resize_with_interpolation from manimlib.utils.iterables import resize_array +Color = Union[str, colour.Color, Sequence[float]] + + class PMobject(Mobject): CONFIG = { "opacity": 1.0, } - def resize_points(self, size, resize_func=resize_array): + def resize_points( + self, + size: int, + resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array + ): # TODO for key in self.data: if key == "bounding_box": continue if len(self.data[key]) != size: - self.data[key] = resize_array(self.data[key], size) + self.data[key] = resize_func(self.data[key], size) return self - def set_points(self, points): + def set_points(self, points: npt.ArrayLike): super().set_points(points) self.resize_points(len(points)) return self - def add_points(self, points, rgbas=None, color=None, opacity=None): + def add_points( + self, + points: npt.ArrayLike, + rgbas: np.ndarray | None = None, + color: Color | None = None, + opacity: float | None = None + ): """ points must be a Nx3 numpy array, as must rgbas if it is not None """ @@ -44,20 +64,20 @@ class PMobject(Mobject): self.data["rgbas"][-len(new_rgbas):] = new_rgbas return self - def set_color_by_gradient(self, *colors): + def set_color_by_gradient(self, *colors: Color): self.data["rgbas"] = np.array(list(map( color_to_rgba, color_gradient(colors, self.get_num_points()) ))) return self - def match_colors(self, pmobject): + def match_colors(self, pmobject: "PMobject"): self.data["rgbas"][:] = resize_with_interpolation( pmobject.data["rgbas"], self.get_num_points() ) return self - def filter_out(self, condition): + def filter_out(self, condition: Callable[[np.ndarray], bool]): for mob in self.family_members_with_points(): to_keep = ~np.apply_along_axis(condition, 1, mob.get_points()) for key in mob.data: @@ -66,7 +86,7 @@ class PMobject(Mobject): mob.data[key] = mob.data[key][to_keep] return self - def sort_points(self, function=lambda p: p[0]): + def sort_points(self, function: Callable[[np.ndarray]] = lambda p: p[0]): """ function is any map from R^3 to R """ @@ -86,11 +106,11 @@ class PMobject(Mobject): ]) return self - def point_from_proportion(self, alpha): + def point_from_proportion(self, alpha: float) -> np.ndarray: index = alpha * (self.get_num_points() - 1) return self.get_points()[int(index)] - def pointwise_become_partial(self, pmobject, a, b): + def pointwise_become_partial(self, pmobject: "PMobject", a: float, b: float): lower_index = int(a * pmobject.get_num_points()) upper_index = int(b * pmobject.get_num_points()) for key in self.data: @@ -101,7 +121,7 @@ class PMobject(Mobject): class PGroup(PMobject): - def __init__(self, *pmobs, **kwargs): + def __init__(self, *pmobs: PMobject, **kwargs): if not all([isinstance(m, PMobject) for m in pmobs]): raise Exception("All submobjects must be of type PMobject") super().__init__(*pmobs, **kwargs) @@ -112,6 +132,6 @@ class Point(PMobject): "color": BLACK, } - def __init__(self, location=ORIGIN, **kwargs): + def __init__(self, location: np.ndarray = ORIGIN, **kwargs): super().__init__(**kwargs) self.add_points([location]) diff --git a/manimlib/mobject/types/surface.py b/manimlib/mobject/types/surface.py index 1160c1ae..a8b4fd5c 100644 --- a/manimlib/mobject/types/surface.py +++ b/manimlib/mobject/types/surface.py @@ -1,7 +1,13 @@ -import numpy as np +from __future__ import annotations + +from typing import Iterable, Callable + import moderngl +import numpy as np +import numpy.typing as npt from manimlib.constants import * +from manimlib.camera.camera import Camera from manimlib.mobject.mobject import Mobject from manimlib.utils.bezier import integer_interpolate from manimlib.utils.bezier import interpolate @@ -42,7 +48,7 @@ class Surface(Mobject): super().__init__(**kwargs) self.compute_triangle_indices() - def uv_func(self, u, v): + def uv_func(self, u: float, v: float) -> tuple[float, float, float]: # To be implemented in subclasses return (u, v, 0.0) @@ -85,15 +91,17 @@ class Surface(Mobject): indices[5::6] = index_grid[+1:, +1:].flatten() # Bottom right self.triangle_indices = indices - def get_triangle_indices(self): + def get_triangle_indices(self) -> np.ndarray: return self.triangle_indices - def get_surface_points_and_nudged_points(self): + def get_surface_points_and_nudged_points( + self + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: points = self.get_points() k = len(points) // 3 return points[:k], points[k:2 * k], points[2 * k:] - def get_unit_normals(self): + def get_unit_normals(self) -> np.ndarray: s_points, du_points, dv_points = self.get_surface_points_and_nudged_points() normals = np.cross( (du_points - s_points) / self.epsilon, @@ -101,7 +109,13 @@ class Surface(Mobject): ) return normalize_along_axis(normals, 1) - def pointwise_become_partial(self, smobject, a, b, axis=None): + def pointwise_become_partial( + self, + smobject: "Surface", + a: float, + b: float, + axis: np.ndarray | None = None + ): assert(isinstance(smobject, Surface)) if axis is None: axis = self.prefered_creation_axis @@ -116,7 +130,14 @@ class Surface(Mobject): ])) return self - def get_partial_points_array(self, points, a, b, resolution, axis): + def get_partial_points_array( + self, + points: np.ndarray, + a: float, + b: float, + resolution: npt.ArrayLike, + axis: int + ) -> np.ndarray: if len(points) == 0: return points nu, nv = resolution[:2] @@ -149,7 +170,7 @@ class Surface(Mobject): ).reshape(shape) return points.reshape((nu * nv, *resolution[2:])) - def sort_faces_back_to_front(self, vect=OUT): + def sort_faces_back_to_front(self, vect: np.ndarray = OUT): tri_is = self.triangle_indices indices = list(range(len(tri_is) // 3)) points = self.get_points() @@ -162,13 +183,13 @@ class Surface(Mobject): tri_is[k::3] = tri_is[k::3][indices] return self - def always_sort_to_camera(self, camera): + def always_sort_to_camera(self, camera: Camera): self.add_updater(lambda m: m.sort_faces_back_to_front( camera.get_location() - self.get_center() )) # For shaders - def get_shader_data(self): + def get_shader_data(self) -> np.ndarray: s_points, du_points, dv_points = self.get_surface_points_and_nudged_points() shader_data = self.get_resized_shader_data_array(len(s_points)) if "points" not in self.locked_data_keys: @@ -178,16 +199,22 @@ class Surface(Mobject): self.fill_in_shader_color_info(shader_data) return shader_data - def fill_in_shader_color_info(self, shader_data): + def fill_in_shader_color_info(self, shader_data: np.ndarray) -> np.ndarray: self.read_data_to_shader(shader_data, "color", "rgbas") return shader_data - def get_shader_vert_indices(self): + def get_shader_vert_indices(self) -> np.ndarray: return self.get_triangle_indices() class ParametricSurface(Surface): - def __init__(self, uv_func, u_range=(0, 1), v_range=(0, 1), **kwargs): + def __init__( + self, + uv_func: Callable[[float, float], Iterable[float]], + u_range: tuple[float, float] = (0, 1), + v_range: tuple[float, float] = (0, 1), + **kwargs + ): self.passed_uv_func = uv_func super().__init__(u_range=u_range, v_range=v_range, **kwargs) @@ -200,7 +227,7 @@ class SGroup(Surface): "resolution": (0, 0), } - def __init__(self, *parametric_surfaces, **kwargs): + def __init__(self, *parametric_surfaces: Surface, **kwargs): super().__init__(uv_func=None, **kwargs) self.add(*parametric_surfaces) @@ -220,7 +247,13 @@ class TexturedSurface(Surface): ] } - def __init__(self, uv_surface, image_file, dark_image_file=None, **kwargs): + def __init__( + self, + uv_surface: Surface, + image_file: str, + dark_image_file: str | None = None, + **kwargs + ): if not isinstance(uv_surface, Surface): raise Exception("uv_surface must be of type Surface") # Set texture information @@ -236,10 +269,10 @@ class TexturedSurface(Surface): self.uv_surface = uv_surface self.uv_func = uv_surface.uv_func - self.u_range = uv_surface.u_range - self.v_range = uv_surface.v_range - self.resolution = uv_surface.resolution - self.gloss = self.uv_surface.gloss + self.u_range: tuple[float, float] = uv_surface.u_range + self.v_range: tuple[float, float] = uv_surface.v_range + self.resolution: tuple[float, float] = uv_surface.resolution + self.gloss: float = self.uv_surface.gloss super().__init__(**kwargs) def init_data(self): @@ -263,12 +296,18 @@ class TexturedSurface(Surface): def init_colors(self): self.data["opacity"] = np.array([self.uv_surface.data["rgbas"][:, 3]]) - def set_opacity(self, opacity, recurse=True): + def set_opacity(self, opacity: float, recurse: bool = True): for mob in self.get_family(recurse): mob.data["opacity"] = np.array([[o] for o in listify(opacity)]) return self - def pointwise_become_partial(self, tsmobject, a, b, axis=1): + def pointwise_become_partial( + self, + tsmobject: "TexturedSurface", + a: float, + b: float, + axis: int = 1 + ): super().pointwise_become_partial(tsmobject, a, b, axis) im_coords = self.data["im_coords"] im_coords[:] = tsmobject.data["im_coords"] @@ -280,7 +319,7 @@ class TexturedSurface(Surface): ) return self - def fill_in_shader_color_info(self, shader_data): + def fill_in_shader_color_info(self, shader_data: np.ndarray) -> np.ndarray: self.read_data_to_shader(shader_data, "opacity", "opacity") self.read_data_to_shader(shader_data, "im_coords", "im_coords") return shader_data diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 3c7a4326..a1a6c29f 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -1,8 +1,13 @@ -import itertools as it -import operator as op -import moderngl +from __future__ import annotations +import operator as op +import itertools as it from functools import reduce, wraps +from typing import Iterable, Sequence, Callable, Union + +import colour +import moderngl +import numpy.typing as npt from manimlib.constants import * from manimlib.mobject.mobject import Mobject @@ -29,6 +34,9 @@ from manimlib.utils.space_ops import z_to_vector from manimlib.shader_wrapper import ShaderWrapper +Color = Union[str, colour.Color, Sequence[float]] + + class VMobject(Mobject): CONFIG = { "fill_color": None, @@ -105,7 +113,12 @@ class VMobject(Mobject): self.set_flat_stroke(self.flat_stroke) return self - def set_rgba_array(self, rgba_array, name=None, recurse=False): + def set_rgba_array( + self, + rgba_array: npt.ArrayLike, + name: str = None, + recurse: bool = False + ): if name is None: names = ["fill_rgba", "stroke_rgba"] else: @@ -115,11 +128,23 @@ class VMobject(Mobject): super().set_rgba_array(rgba_array, name, recurse) return self - def set_fill(self, color=None, opacity=None, recurse=True): + def set_fill( + self, + color: Color | None = None, + opacity: float | None = None, + recurse: bool = True + ): self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse) return self - def set_stroke(self, color=None, width=None, opacity=None, background=None, recurse=True): + def set_stroke( + self, + color: Color | None = None, + width: float | npt.ArrayLike | None = None, + opacity: float | None = None, + background: bool | None = None, + recurse: bool = True + ): self.set_rgba_array_by_color(color, opacity, 'stroke_rgba', recurse) if width is not None: @@ -135,29 +160,36 @@ class VMobject(Mobject): mob.draw_stroke_behind_fill = background return self - def set_backstroke(self, color=BLACK, width=3, background=True): + def set_backstroke( + self, + color: Color = BLACK, + width: float | npt.ArrayLike = 3, + background: bool = True + ): self.set_stroke(color, width, background=background) return self - def align_stroke_width_data_to_points(self, recurse=True): + def align_stroke_width_data_to_points(self, recurse: bool = True) -> None: for mob in self.get_family(recurse): mob.data["stroke_width"] = resize_with_interpolation( mob.data["stroke_width"], len(mob.get_points()) ) - def set_style(self, - fill_color=None, - fill_opacity=None, - fill_rgba=None, - stroke_color=None, - stroke_opacity=None, - stroke_rgba=None, - stroke_width=None, - stroke_background=True, - reflectiveness=None, - gloss=None, - shadow=None, - recurse=True): + def set_style( + self, + fill_color: Color | None = None, + fill_opacity: float | None = None, + fill_rgba: npt.ArrayLike | None = None, + stroke_color: Color | None = None, + stroke_opacity: float | None = None, + stroke_rgba: npt.ArrayLike | None = None, + stroke_width: float | npt.ArrayLike | None = None, + stroke_background: bool = True, + reflectiveness: float | None = None, + gloss: float | None = None, + shadow: float | None = None, + recurse: bool = True + ): if fill_rgba is not None: self.data['fill_rgba'] = resize_with_interpolation(fill_rgba, len(fill_rgba)) else: @@ -201,7 +233,7 @@ class VMobject(Mobject): "shadow": self.get_shadow(), } - def match_style(self, vmobject, recurse=True): + def match_style(self, vmobject: "VMobject", recurse: bool = True): self.set_style(**vmobject.get_style(), recurse=False) if recurse: # Does its best to match up submobject lists, and @@ -215,17 +247,17 @@ class VMobject(Mobject): sm1.match_style(sm2) return self - def set_color(self, color, recurse=True): + def set_color(self, color: Color, recurse: bool = True): self.set_fill(color, recurse=recurse) self.set_stroke(color, recurse=recurse) return self - def set_opacity(self, opacity, recurse=True): + def set_opacity(self, opacity: float, recurse: bool = True): self.set_fill(opacity=opacity, recurse=recurse) self.set_stroke(opacity=opacity, recurse=recurse) return self - def fade(self, darkness=0.5, recurse=True): + def fade(self, darkness: float = 0.5, recurse: bool = True): mobs = self.get_family() if recurse else [self] for mob in mobs: factor = 1.0 - darkness @@ -239,78 +271,83 @@ class VMobject(Mobject): ) return self - def get_fill_colors(self): + def get_fill_colors(self) -> list[str]: return [ rgb_to_hex(rgba[:3]) for rgba in self.data['fill_rgba'] ] - def get_fill_opacities(self): + def get_fill_opacities(self) -> np.ndarray: return self.data['fill_rgba'][:, 3] - def get_stroke_colors(self): + def get_stroke_colors(self) -> list[str]: return [ rgb_to_hex(rgba[:3]) for rgba in self.data['stroke_rgba'] ] - def get_stroke_opacities(self): + def get_stroke_opacities(self) -> np.ndarray: return self.data['stroke_rgba'][:, 3] - def get_stroke_widths(self): + def get_stroke_widths(self) -> np.ndarray: return self.data['stroke_width'][:, 0] # TODO, it's weird for these to return the first of various lists # rather than the full information - def get_fill_color(self): + def get_fill_color(self) -> str: """ If there are multiple colors (for gradient) this returns the first one """ return self.get_fill_colors()[0] - def get_fill_opacity(self): + def get_fill_opacity(self) -> float: """ If there are multiple opacities, this returns the first """ return self.get_fill_opacities()[0] - def get_stroke_color(self): + def get_stroke_color(self) -> str: return self.get_stroke_colors()[0] - def get_stroke_width(self): + def get_stroke_width(self) -> float | np.ndarray: return self.get_stroke_widths()[0] - def get_stroke_opacity(self): + def get_stroke_opacity(self) -> float: return self.get_stroke_opacities()[0] - def get_color(self): + def get_color(self) -> str: if self.has_fill(): return self.get_fill_color() return self.get_stroke_color() - def has_stroke(self): + def has_stroke(self) -> bool: return self.get_stroke_widths().any() and self.get_stroke_opacities().any() - def has_fill(self): + def has_fill(self) -> bool: return any(self.get_fill_opacities()) - def get_opacity(self): + def get_opacity(self) -> float: if self.has_fill(): return self.get_fill_opacity() return self.get_stroke_opacity() - def set_flat_stroke(self, flat_stroke=True, recurse=True): + def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True): for mob in self.get_family(recurse): mob.flat_stroke = flat_stroke return self - def get_flat_stroke(self): + def get_flat_stroke(self) -> bool: return self.flat_stroke # Points - def set_anchors_and_handles(self, anchors1, handles, anchors2): + def set_anchors_and_handles( + self, + anchors1: np.ndarray, + handles: np.ndarray, + anchors2: np.ndarray + ): assert(len(anchors1) == len(handles) == len(anchors2)) nppc = self.n_points_per_curve new_points = np.zeros((nppc * len(anchors1), self.dim)) @@ -320,16 +357,27 @@ class VMobject(Mobject): self.set_points(new_points) return self - def start_new_path(self, point): + def start_new_path(self, point: np.ndarray): assert(self.get_num_points() % self.n_points_per_curve == 0) self.append_points([point]) return self - def add_cubic_bezier_curve(self, anchor1, handle1, handle2, anchor2): + def add_cubic_bezier_curve( + self, + anchor1: npt.ArrayLike, + handle1: npt.ArrayLike, + handle2: npt.ArrayLike, + anchor2: npt.ArrayLike + ): new_points = get_quadratic_approximation_of_cubic(anchor1, handle1, handle2, anchor2) self.append_points(new_points) - def add_cubic_bezier_curve_to(self, handle1, handle2, anchor): + def add_cubic_bezier_curve_to( + self, + handle1: npt.ArrayLike, + handle2: npt.ArrayLike, + anchor: npt.ArrayLike + ): """ Add cubic bezier curve to the path. """ @@ -342,14 +390,14 @@ class VMobject(Mobject): else: self.append_points(quadratic_approx) - def add_quadratic_bezier_curve_to(self, handle, anchor): + def add_quadratic_bezier_curve_to(self, handle: np.ndarray, anchor: np.ndarray): self.throw_error_if_no_points() if self.has_new_path_started(): self.append_points([handle, anchor]) else: self.append_points([self.get_last_point(), handle, anchor]) - def add_line_to(self, point): + def add_line_to(self, point: np.ndarray): end = self.get_points()[-1] alphas = np.linspace(0, 1, self.n_points_per_curve) if self.long_lines: @@ -371,7 +419,7 @@ class VMobject(Mobject): self.append_points(points) return self - def add_smooth_curve_to(self, point): + def add_smooth_curve_to(self, point: np.ndarray): if self.has_new_path_started(): self.add_line_to(point) else: @@ -380,7 +428,7 @@ class VMobject(Mobject): self.add_quadratic_bezier_curve_to(new_handle, point) return self - def add_smooth_cubic_curve_to(self, handle, point): + def add_smooth_cubic_curve_to(self, handle: np.ndarray, point: np.ndarray): self.throw_error_if_no_points() if self.get_num_points() == 1: new_handle = self.get_points()[-1] @@ -388,13 +436,13 @@ class VMobject(Mobject): new_handle = self.get_reflection_of_last_handle() self.add_cubic_bezier_curve_to(new_handle, handle, point) - def has_new_path_started(self): + def has_new_path_started(self) -> bool: return self.get_num_points() % self.n_points_per_curve == 1 - def get_last_point(self): + def get_last_point(self) -> np.ndarray: return self.get_points()[-1] - def get_reflection_of_last_handle(self): + def get_reflection_of_last_handle(self) -> np.ndarray: points = self.get_points() return 2 * points[-1] - points[-2] @@ -402,12 +450,16 @@ class VMobject(Mobject): if not self.is_closed(): self.add_line_to(self.get_subpaths()[-1][0]) - def is_closed(self): + def is_closed(self) -> bool: return self.consider_points_equals( self.get_points()[0], self.get_points()[-1] ) - def subdivide_sharp_curves(self, angle_threshold=30 * DEGREES, recurse=True): + def subdivide_sharp_curves( + self, + angle_threshold: float = 30 * DEGREES, + recurse: bool = True + ): vmobs = [vm for vm in self.get_family(recurse) if vm.has_points()] for vmob in vmobs: new_points = [] @@ -425,12 +477,12 @@ class VMobject(Mobject): vmob.set_points(np.vstack(new_points)) return self - def add_points_as_corners(self, points): + def add_points_as_corners(self, points: Iterable[np.ndarray]): for point in points: self.add_line_to(point) return points - def set_points_as_corners(self, points): + def set_points_as_corners(self, points: Iterable[np.ndarray]): nppc = self.n_points_per_curve points = np.array(points) self.set_anchors_and_handles(*[ @@ -439,7 +491,11 @@ class VMobject(Mobject): ]) return self - def set_points_smoothly(self, points, true_smooth=False): + def set_points_smoothly( + self, + points: Iterable[np.ndarray], + true_smooth: bool = False + ): self.set_points_as_corners(points) if true_smooth: self.make_smooth() @@ -447,7 +503,7 @@ class VMobject(Mobject): self.make_approximately_smooth() return self - def change_anchor_mode(self, mode): + def change_anchor_mode(self, mode: str): assert(mode in ("jagged", "approx_smooth", "true_smooth")) nppc = self.n_points_per_curve for submob in self.family_members_with_points(): @@ -492,12 +548,12 @@ class VMobject(Mobject): self.change_anchor_mode("jagged") return self - def add_subpath(self, points): + def add_subpath(self, points: Iterable[np.ndarray]): assert(len(points) % self.n_points_per_curve == 0) self.append_points(points) return self - def append_vectorized_mobject(self, vectorized_mobject): + def append_vectorized_mobject(self, vectorized_mobject: "VMobject"): new_points = list(vectorized_mobject.get_points()) if self.has_new_path_started(): @@ -508,11 +564,11 @@ class VMobject(Mobject): return self # - def consider_points_equals(self, p0, p1): + def consider_points_equals(self, p0: np.ndarray, p1: np.ndarray) -> bool: return get_norm(p1 - p0) < self.tolerance_for_point_equality # Information about the curve - def get_bezier_tuples_from_points(self, points): + def get_bezier_tuples_from_points(self, points: Sequence[np.ndarray]): nppc = self.n_points_per_curve remainder = len(points) % nppc points = points[:len(points) - remainder] @@ -524,7 +580,10 @@ class VMobject(Mobject): def get_bezier_tuples(self): return self.get_bezier_tuples_from_points(self.get_points()) - def get_subpaths_from_points(self, points): + def get_subpaths_from_points( + self, + points: Sequence[np.ndarray] + ) -> list[Sequence[np.ndarray]]: nppc = self.n_points_per_curve diffs = points[nppc - 1:-1:nppc] - points[nppc::nppc] splits = (diffs * diffs).sum(1) > self.tolerance_for_point_equality @@ -541,28 +600,28 @@ class VMobject(Mobject): if (i2 - i1) >= nppc ] - def get_subpaths(self): + def get_subpaths(self) -> list[Sequence[np.ndarray]]: return self.get_subpaths_from_points(self.get_points()) - def get_nth_curve_points(self, n): + def get_nth_curve_points(self, n: int) -> np.ndarray: assert(n < self.get_num_curves()) nppc = self.n_points_per_curve return self.get_points()[nppc * n:nppc * (n + 1)] - def get_nth_curve_function(self, n): + def get_nth_curve_function(self, n: int) -> Callable[[float], np.ndarray]: return bezier(self.get_nth_curve_points(n)) - def get_num_curves(self): + def get_num_curves(self) -> int: return self.get_num_points() // self.n_points_per_curve - def quick_point_from_proportion(self, alpha): + def quick_point_from_proportion(self, alpha: float) -> np.ndarray: # Assumes all curves have the same length, so is inaccurate num_curves = self.get_num_curves() n, residue = integer_interpolate(0, num_curves, alpha) curve_func = self.get_nth_curve_function(n) return curve_func(residue) - def point_from_proportion(self, alpha): + def point_from_proportion(self, alpha: float) -> np.ndarray: if alpha <= 0: return self.get_start() elif alpha >= 1: @@ -584,7 +643,7 @@ class VMobject(Mobject): residue = inverse_interpolate(partials[i - 1] / full, partials[i] / full, alpha) return self.get_nth_curve_function(i - 1)(residue) - def get_anchors_and_handles(self): + def get_anchors_and_handles(self) -> list[np.ndarray]: """ returns anchors1, handles, anchors2, where (anchors1[i], handles[i], anchors2[i]) @@ -598,14 +657,14 @@ class VMobject(Mobject): for i in range(nppc) ] - def get_start_anchors(self): + def get_start_anchors(self) -> np.ndarray: return self.get_points()[0::self.n_points_per_curve] - def get_end_anchors(self): + def get_end_anchors(self) -> np.ndarray: nppc = self.n_points_per_curve return self.get_points()[nppc - 1::nppc] - def get_anchors(self): + def get_anchors(self) -> np.ndarray: points = self.get_points() if len(points) == 1: return points @@ -614,7 +673,7 @@ class VMobject(Mobject): self.get_end_anchors(), )))) - def get_points_without_null_curves(self, atol=1e-9): + def get_points_without_null_curves(self, atol: float=1e-9) -> np.ndarray: nppc = self.n_points_per_curve points = self.get_points() distinct_curves = reduce(op.or_, [ @@ -623,7 +682,7 @@ class VMobject(Mobject): ]) return points[distinct_curves.repeat(nppc)] - def get_arc_length(self, n_sample_points=None): + def get_arc_length(self, n_sample_points: int | None = None) -> float: if n_sample_points is None: n_sample_points = 4 * self.get_num_curves() + 1 points = np.array([ @@ -634,7 +693,7 @@ class VMobject(Mobject): norms = np.array([get_norm(d) for d in diffs]) return norms.sum() - def get_area_vector(self): + def get_area_vector(self) -> np.ndarray: # Returns a vector whose length is the area bound by # the polygon formed by the anchor points, pointing # in a direction perpendicular to the polygon according @@ -654,7 +713,7 @@ class VMobject(Mobject): sum((p0[:, 0] + p1[:, 0]) * (p1[:, 1] - p0[:, 1])), # Add up (x1 + x2)*(y2 - y1) ]) - def get_unit_normal(self, recompute=False): + def get_unit_normal(self, recompute: bool = False) -> np.ndarray: if not recompute: return self.data["unit_normal"][0] @@ -680,7 +739,7 @@ class VMobject(Mobject): return self # Alignment - def align_points(self, vmobject): + def align_points(self, vmobject: "VMobject"): if self.get_num_points() == len(vmobject.get_points()): return @@ -723,7 +782,7 @@ class VMobject(Mobject): vmobject.set_points(np.vstack(new_subpaths2)) return self - def insert_n_curves(self, n, recurse=True): + def insert_n_curves(self, n: int, recurse: bool = True): for mob in self.get_family(recurse): if mob.get_num_curves() > 0: new_points = mob.insert_n_curves_to_point_list(n, mob.get_points()) @@ -733,7 +792,7 @@ class VMobject(Mobject): mob.set_points(new_points) return self - def insert_n_curves_to_point_list(self, n, points): + def insert_n_curves_to_point_list(self, n: int, points: np.ndarray): nppc = self.n_points_per_curve if len(points) == 1: return np.repeat(points, nppc * n, 0) @@ -766,7 +825,13 @@ class VMobject(Mobject): new_points += partial_quadratic_bezier_points(group, a1, a2) return np.vstack(new_points) - def interpolate(self, mobject1, mobject2, alpha, *args, **kwargs): + def interpolate( + self, + mobject1: "VMobject", + mobject2: "VMobject", + alpha: float, + *args, **kwargs + ): super().interpolate(mobject1, mobject2, alpha, *args, **kwargs) if self.has_fill(): tri1 = mobject1.get_triangulation() @@ -775,7 +840,7 @@ class VMobject(Mobject): self.refresh_triangulation() return self - def pointwise_become_partial(self, vmobject, a, b): + def pointwise_become_partial(self, vmobject: "VMobject", a: float, b: float): assert(isinstance(vmobject, VMobject)) if a <= 0 and b >= 1: self.become(vmobject) @@ -817,7 +882,7 @@ class VMobject(Mobject): self.set_points(new_points) return self - def get_subcurve(self, a, b): + def get_subcurve(self, a: float, b: float) -> "VMobject": vmob = self.copy() vmob.pointwise_become_partial(self, a, b) return vmob @@ -829,7 +894,7 @@ class VMobject(Mobject): mob.needs_new_triangulation = True return self - def get_triangulation(self, normal_vector=None): + def get_triangulation(self, normal_vector: np.ndarray | None = None): # Figure out how to triangulate the interior to know # how to send the points as to the vertex shader. # First triangles come directly from the points @@ -898,25 +963,30 @@ class VMobject(Mobject): return wrapper @triggers_refreshed_triangulation - def set_points(self, points): + def set_points(self, points: npt.ArrayLike): super().set_points(points) return self @triggers_refreshed_triangulation - def set_data(self, data): + def set_data(self, data: dict): super().set_data(data) return self # TODO, how to be smart about tangents here? @triggers_refreshed_triangulation - def apply_function(self, function, make_smooth=False, **kwargs): + def apply_function( + self, + function: Callable[[np.ndarray], np.ndarray], + make_smooth: bool = False, + **kwargs + ): super().apply_function(function, **kwargs) if self.make_smooth_after_applying_functions or make_smooth: self.make_approximately_smooth() return self - def flip(self, *args, **kwargs): - super().flip(*args, **kwargs) + def flip(self, axis: np.ndarray = UP, **kwargs): + super().flip(axis, **kwargs) self.refresh_unit_normal() self.refresh_triangulation() return self @@ -942,20 +1012,20 @@ class VMobject(Mobject): wrapper.refresh_id() return self - def get_fill_shader_wrapper(self): + def get_fill_shader_wrapper(self) -> ShaderWrapper: self.fill_shader_wrapper.vert_data = self.get_fill_shader_data() self.fill_shader_wrapper.vert_indices = self.get_fill_shader_vert_indices() self.fill_shader_wrapper.uniforms = self.get_shader_uniforms() self.fill_shader_wrapper.depth_test = self.depth_test return self.fill_shader_wrapper - def get_stroke_shader_wrapper(self): + def get_stroke_shader_wrapper(self) -> ShaderWrapper: self.stroke_shader_wrapper.vert_data = self.get_stroke_shader_data() self.stroke_shader_wrapper.uniforms = self.get_stroke_uniforms() self.stroke_shader_wrapper.depth_test = self.depth_test return self.stroke_shader_wrapper - def get_shader_wrapper_list(self): + def get_shader_wrapper_list(self) -> list[ShaderWrapper]: # Build up data lists fill_shader_wrappers = [] stroke_shader_wrappers = [] @@ -984,13 +1054,13 @@ class VMobject(Mobject): result.append(wrapper) return result - def get_stroke_uniforms(self): + def get_stroke_uniforms(self) -> dict[str, float]: result = dict(super().get_shader_uniforms()) result["joint_type"] = JOINT_TYPE_MAP[self.joint_type] result["flat_stroke"] = float(self.flat_stroke) return result - def get_stroke_shader_data(self): + def get_stroke_shader_data(self) -> np.ndarray: points = self.get_points() if len(self.stroke_data) != len(points): self.stroke_data = resize_array(self.stroke_data, len(points)) @@ -1009,7 +1079,7 @@ class VMobject(Mobject): return self.stroke_data - def get_fill_shader_data(self): + def get_fill_shader_data(self) -> np.ndarray: points = self.get_points() if len(self.fill_data) != len(points): self.fill_data = resize_array(self.fill_data, len(points)) @@ -1025,18 +1095,18 @@ class VMobject(Mobject): self.get_fill_shader_data() self.get_stroke_shader_data() - def get_fill_shader_vert_indices(self): + def get_fill_shader_vert_indices(self) -> np.ndarray: return self.get_triangulation() class VGroup(VMobject): - def __init__(self, *vmobjects, **kwargs): + def __init__(self, *vmobjects: VMobject, **kwargs): if not all([isinstance(m, VMobject) for m in vmobjects]): raise Exception("All submobjects must be of type VMobject") super().__init__(**kwargs) self.add(*vmobjects) - def __add__(self: 'VGroup', other: 'VMobject' or 'VGroup'): + def __add__(self, other: VMobject | "VGroup"): assert(isinstance(other, VMobject)) return self.add(other) @@ -1050,14 +1120,14 @@ class VectorizedPoint(Point, VMobject): "artificial_height": 0.01, } - def __init__(self, location=ORIGIN, **kwargs): + def __init__(self, location: np.ndarray = ORIGIN, **kwargs): Point.__init__(self, **kwargs) VMobject.__init__(self, **kwargs) self.set_points(np.array([location])) class CurvesAsSubmobjects(VGroup): - def __init__(self, vmobject, **kwargs): + def __init__(self, vmobject: VMobject, **kwargs): super().__init__(**kwargs) for tup in vmobject.get_bezier_tuples(): part = VMobject() @@ -1073,7 +1143,7 @@ class DashedVMobject(VMobject): "color": WHITE } - def __init__(self, vmobject, **kwargs): + def __init__(self, vmobject: VMobject, **kwargs): super().__init__(**kwargs) num_dashes = self.num_dashes ps_ratio = self.positive_space_ratio From 992e61ddf24b67ac6198dd4a9ee55f8698d422bd Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Sun, 13 Feb 2022 19:02:28 +0800 Subject: [PATCH 06/27] style: rename Color type to ManimColor --- manimlib/mobject/mobject.py | 12 ++++++------ manimlib/mobject/types/point_cloud_mobject.py | 6 +++--- manimlib/mobject/types/vectorized_mobject.py | 14 +++++++------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index a1426ba7..0a13fbc9 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -43,7 +43,7 @@ Self = TypeVar("Self", bound="Mobject") TimeBasedUpdater = Callable[["Mobject", float], None] NonTimeUpdater = Callable[["Mobject"], None] Updater = Union[TimeBasedUpdater, NonTimeUpdater] -Color = Union[str, colour.Color, Sequence[float]] +ManimColor = Union[str, colour.Color, Sequence[float]] class Mobject(object): @@ -1024,7 +1024,7 @@ class Mobject(object): def set_rgba_array_by_color( self, - color: Color | None = None, + color: ManimColor | None = None, opacity: float | None = None, name: str = "rgbas", recurse: bool = True @@ -1056,7 +1056,7 @@ class Mobject(object): mob.data[name] = rgbas.copy() return self - def set_color(self, color: Color, opacity: float | None = None, recurse: bool = True): + def set_color(self, color: ManimColor, opacity: float | None = None, recurse: bool = True): self.set_rgba_array_by_color(color, opacity, recurse=False) # Recurse to submobjects differently from how set_rgba_array_by_color # in case they implement set_color differently @@ -1078,11 +1078,11 @@ class Mobject(object): def get_opacity(self) -> float: return self.data["rgbas"][0, 3] - def set_color_by_gradient(self, *colors: Color): + def set_color_by_gradient(self, *colors: ManimColor): self.set_submobject_colors_by_gradient(*colors) return self - def set_submobject_colors_by_gradient(self, *colors: Color): + def set_submobject_colors_by_gradient(self, *colors: ManimColor): if len(colors) == 0: raise Exception("Need at least one color") elif len(colors) == 1: @@ -1127,7 +1127,7 @@ class Mobject(object): def add_background_rectangle( self, - color: Color | None = None, + color: ManimColor | None = None, opacity: float = 0.75, **kwargs ): diff --git a/manimlib/mobject/types/point_cloud_mobject.py b/manimlib/mobject/types/point_cloud_mobject.py index 2af3e191..ffde8198 100644 --- a/manimlib/mobject/types/point_cloud_mobject.py +++ b/manimlib/mobject/types/point_cloud_mobject.py @@ -13,7 +13,7 @@ from manimlib.utils.iterables import resize_with_interpolation from manimlib.utils.iterables import resize_array -Color = Union[str, colour.Color, Sequence[float]] +ManimColor = Union[str, colour.Color, Sequence[float]] class PMobject(Mobject): @@ -43,7 +43,7 @@ class PMobject(Mobject): self, points: npt.ArrayLike, rgbas: np.ndarray | None = None, - color: Color | None = None, + color: ManimColor | None = None, opacity: float | None = None ): """ @@ -64,7 +64,7 @@ class PMobject(Mobject): self.data["rgbas"][-len(new_rgbas):] = new_rgbas return self - def set_color_by_gradient(self, *colors: Color): + def set_color_by_gradient(self, *colors: ManimColor): self.data["rgbas"] = np.array(list(map( color_to_rgba, color_gradient(colors, self.get_num_points()) diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index a1a6c29f..8b6d9db1 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -34,7 +34,7 @@ from manimlib.utils.space_ops import z_to_vector from manimlib.shader_wrapper import ShaderWrapper -Color = Union[str, colour.Color, Sequence[float]] +ManimColor = Union[str, colour.Color, Sequence[float]] class VMobject(Mobject): @@ -130,7 +130,7 @@ class VMobject(Mobject): def set_fill( self, - color: Color | None = None, + color: ManimColor | None = None, opacity: float | None = None, recurse: bool = True ): @@ -139,7 +139,7 @@ class VMobject(Mobject): def set_stroke( self, - color: Color | None = None, + color: ManimColor | None = None, width: float | npt.ArrayLike | None = None, opacity: float | None = None, background: bool | None = None, @@ -162,7 +162,7 @@ class VMobject(Mobject): def set_backstroke( self, - color: Color = BLACK, + color: ManimColor = BLACK, width: float | npt.ArrayLike = 3, background: bool = True ): @@ -177,10 +177,10 @@ class VMobject(Mobject): def set_style( self, - fill_color: Color | None = None, + fill_color: ManimColor | None = None, fill_opacity: float | None = None, fill_rgba: npt.ArrayLike | None = None, - stroke_color: Color | None = None, + stroke_color: ManimColor | None = None, stroke_opacity: float | None = None, stroke_rgba: npt.ArrayLike | None = None, stroke_width: float | npt.ArrayLike | None = None, @@ -247,7 +247,7 @@ class VMobject(Mobject): sm1.match_style(sm2) return self - def set_color(self, color: Color, recurse: bool = True): + def set_color(self, color: ManimColor, recurse: bool = True): self.set_fill(color, recurse=recurse) self.set_stroke(color, recurse=recurse) return self From 1064e2bb3019990c83dbc0820ef4bd2c0b9fee47 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Sun, 13 Feb 2022 19:32:53 +0800 Subject: [PATCH 07/27] chore: add type hints to manimlib.camera --- manimlib/camera/camera.py | 175 ++++++++++++++++++++++---------------- 1 file changed, 104 insertions(+), 71 deletions(-) diff --git a/manimlib/camera/camera.py b/manimlib/camera/camera.py index 18adce91..33e77d91 100644 --- a/manimlib/camera/camera.py +++ b/manimlib/camera/camera.py @@ -1,15 +1,18 @@ -import moderngl -import math -from colour import Color -import OpenGL.GL as gl +from __future__ import annotations -from PIL import Image -import numpy as np +import math import itertools as it +import moderngl +import numpy as np +from PIL import Image +import OpenGL.GL as gl +from colour import Color + from manimlib.constants import * from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Point +from manimlib.shader_wrapper import ShaderWrapper from manimlib.utils.config_ops import digest_config from manimlib.utils.simple_functions import fdiv from manimlib.utils.simple_functions import clip @@ -29,12 +32,12 @@ class CameraFrame(Mobject): "focal_distance": 2, } - def init_data(self): + def init_data(self) -> None: super().init_data() self.data["euler_angles"] = np.array(self.euler_angles, dtype=float) self.refresh_rotation_matrix() - def init_points(self): + def init_points(self) -> None: self.set_points([ORIGIN, LEFT, RIGHT, DOWN, UP]) self.set_width(self.frame_shape[0], stretch=True) self.set_height(self.frame_shape[1], stretch=True) @@ -47,13 +50,13 @@ class CameraFrame(Mobject): self.set_euler_angles(0, 0, 0) return self - def get_euler_angles(self): + def get_euler_angles(self) -> np.ndarray: return self.data["euler_angles"] - def get_inverse_camera_rotation_matrix(self): + def get_inverse_camera_rotation_matrix(self) -> list[list[float]]: return self.inverse_camera_rotation_matrix - def refresh_rotation_matrix(self): + def refresh_rotation_matrix(self) -> None: # Rotate based on camera orientation theta, phi, gamma = self.get_euler_angles() quat = quaternion_mult( @@ -63,7 +66,7 @@ class CameraFrame(Mobject): ) self.inverse_camera_rotation_matrix = rotation_matrix_transpose_from_quaternion(quat) - def rotate(self, angle, axis=OUT, **kwargs): + def rotate(self, angle: float, axis: np.ndarray = OUT, **kwargs): curr_rot_T = self.get_inverse_camera_rotation_matrix() added_rot_T = rotation_matrix_transpose(angle, axis) new_rot_T = np.dot(curr_rot_T, added_rot_T) @@ -78,7 +81,13 @@ class CameraFrame(Mobject): self.set_euler_angles(theta, phi, gamma) return self - def set_euler_angles(self, theta=None, phi=None, gamma=None, units=RADIANS): + def set_euler_angles( + self, + theta: float | None = None, + phi: float | None = None, + gamma: float | None = None, + units: float = RADIANS + ): if theta is not None: self.data["euler_angles"][0] = theta * units if phi is not None: @@ -88,7 +97,12 @@ class CameraFrame(Mobject): self.refresh_rotation_matrix() return self - def reorient(self, theta_degrees=None, phi_degrees=None, gamma_degrees=None): + def reorient( + self, + theta_degrees: float | None = None, + phi_degrees: float | None = None, + gamma_degrees: float | None = None, + ): """ Shortcut for set_euler_angles, defaulting to taking in angles in degrees @@ -96,60 +110,60 @@ class CameraFrame(Mobject): self.set_euler_angles(theta_degrees, phi_degrees, gamma_degrees, units=DEGREES) return self - def set_theta(self, theta): + def set_theta(self, theta: float): return self.set_euler_angles(theta=theta) - def set_phi(self, phi): + def set_phi(self, phi: float): return self.set_euler_angles(phi=phi) - def set_gamma(self, gamma): + def set_gamma(self, gamma: float): return self.set_euler_angles(gamma=gamma) - def increment_theta(self, dtheta): + def increment_theta(self, dtheta: float): self.data["euler_angles"][0] += dtheta self.refresh_rotation_matrix() return self - def increment_phi(self, dphi): + def increment_phi(self, dphi: float): phi = self.data["euler_angles"][1] new_phi = clip(phi + dphi, 0, PI) self.data["euler_angles"][1] = new_phi self.refresh_rotation_matrix() return self - def increment_gamma(self, dgamma): + def increment_gamma(self, dgamma: float): self.data["euler_angles"][2] += dgamma self.refresh_rotation_matrix() return self - def get_theta(self): + def get_theta(self) -> float: return self.data["euler_angles"][0] - def get_phi(self): + def get_phi(self) -> float: return self.data["euler_angles"][1] - def get_gamma(self): + def get_gamma(self) -> float: return self.data["euler_angles"][2] - def get_shape(self): + def get_shape(self) -> tuple[float, float]: return (self.get_width(), self.get_height()) - def get_center(self): + def get_center(self) -> np.ndarray: # Assumes first point is at the center return self.get_points()[0] - def get_width(self): + def get_width(self) -> float: points = self.get_points() return points[2, 0] - points[1, 0] - def get_height(self): + def get_height(self) -> float: points = self.get_points() return points[4, 1] - points[3, 1] - def get_focal_distance(self): + def get_focal_distance(self) -> float: return self.focal_distance * self.get_height() - def get_implied_camera_location(self): + def get_implied_camera_location(self) -> tuple[float, float, float]: theta, phi, gamma = self.get_euler_angles() dist = self.get_focal_distance() x, y, z = self.get_center() @@ -190,10 +204,10 @@ class Camera(object): "samples": 0, } - def __init__(self, ctx=None, **kwargs): + def __init__(self, ctx: moderngl.Context | None = None, **kwargs): digest_config(self, kwargs, locals()) - self.rgb_max_val = np.iinfo(self.pixel_array_dtype).max - self.background_rgba = [ + self.rgb_max_val: float = np.iinfo(self.pixel_array_dtype).max + self.background_rgba: list[float] = [ *Color(self.background_color).get_rgb(), self.background_opacity ] @@ -205,10 +219,10 @@ class Camera(object): self.refresh_perspective_uniforms() self.static_mobject_to_render_group_list = {} - def init_frame(self): + def init_frame(self) -> None: self.frame = CameraFrame(**self.frame_config) - def init_context(self, ctx=None): + def init_context(self, ctx: moderngl.Context | None = None) -> None: if ctx is None: ctx = moderngl.create_standalone_context() fbo = self.get_fbo(ctx, 0) @@ -223,7 +237,7 @@ class Camera(object): fbo_msaa.use() self.fbo_msaa = fbo_msaa - def set_ctx_blending(self, enable=True): + def set_ctx_blending(self, enable: bool = True) -> None: if enable: self.ctx.enable(moderngl.BLEND) else: @@ -233,17 +247,21 @@ class Camera(object): # moderngl.ONE, moderngl.ONE ) - def set_ctx_depth_test(self, enable=True): + def set_ctx_depth_test(self, enable: bool = True) -> None: if enable: self.ctx.enable(moderngl.DEPTH_TEST) else: self.ctx.disable(moderngl.DEPTH_TEST) - def init_light_source(self): + def init_light_source(self) -> None: self.light_source = Point(self.light_source_position) # Methods associated with the frame buffer - def get_fbo(self, ctx, samples=0): + def get_fbo( + self, + ctx: moderngl.Context, + samples: int = 0 + ) -> moderngl.Framebuffer: pw = self.pixel_width ph = self.pixel_height return ctx.framebuffer( @@ -258,16 +276,16 @@ class Camera(object): ) ) - def clear(self): + def clear(self) -> None: self.fbo.clear(*self.background_rgba) self.fbo_msaa.clear(*self.background_rgba) - def reset_pixel_shape(self, new_width, new_height): + def reset_pixel_shape(self, new_width: int, new_height: int) -> None: self.pixel_width = new_width self.pixel_height = new_height self.refresh_perspective_uniforms() - def get_raw_fbo_data(self, dtype='f1'): + def get_raw_fbo_data(self, dtype: str = 'f1') -> bytes: # Copy blocks from the fbo_msaa to the drawn fbo using Blit pw, ph = (self.pixel_width, self.pixel_height) gl.glBindFramebuffer(gl.GL_READ_FRAMEBUFFER, self.fbo_msaa.glo) @@ -279,7 +297,7 @@ class Camera(object): dtype=dtype, ) - def get_image(self, pixel_array=None): + def get_image(self) -> Image: return Image.frombytes( 'RGBA', self.get_pixel_shape(), @@ -287,7 +305,7 @@ class Camera(object): 'raw', 'RGBA', 0, -1 ) - def get_pixel_array(self): + def get_pixel_array(self) -> np.ndarray: raw = self.get_raw_fbo_data(dtype='f4') flat_arr = np.frombuffer(raw, dtype='f4') arr = flat_arr.reshape([*self.fbo.size, self.n_channels]) @@ -295,7 +313,7 @@ class Camera(object): return (self.rgb_max_val * arr).astype(self.pixel_array_dtype) # Needed? - def get_texture(self): + def get_texture(self) -> moderngl.Texture: texture = self.ctx.texture( size=self.fbo.size, components=4, @@ -305,32 +323,32 @@ class Camera(object): return texture # Getting camera attributes - def get_pixel_shape(self): + def get_pixel_shape(self) -> tuple[int, int]: return self.fbo.viewport[2:4] # return (self.pixel_width, self.pixel_height) - def get_pixel_width(self): + def get_pixel_width(self) -> int: return self.get_pixel_shape()[0] - def get_pixel_height(self): + def get_pixel_height(self) -> int: return self.get_pixel_shape()[1] - def get_frame_height(self): + def get_frame_height(self) -> float: return self.frame.get_height() - def get_frame_width(self): + def get_frame_width(self) -> float: return self.frame.get_width() - def get_frame_shape(self): + def get_frame_shape(self) -> tuple[float, float]: return (self.get_frame_width(), self.get_frame_height()) - def get_frame_center(self): + def get_frame_center(self) -> np.ndarray: return self.frame.get_center() - def get_location(self): + def get_location(self) -> tuple[float, float, float]: return self.frame.get_implied_camera_location() - def resize_frame_shape(self, fixed_dimension=0): + def resize_frame_shape(self, fixed_dimension: bool = False) -> None: """ Changes frame_shape to match the aspect ratio of the pixels, where fixed_dimension determines @@ -342,7 +360,7 @@ class Camera(object): frame_height = self.get_frame_height() frame_width = self.get_frame_width() aspect_ratio = fdiv(pixel_width, pixel_height) - if fixed_dimension == 0: + if not fixed_dimension: frame_height = frame_width / aspect_ratio else: frame_width = aspect_ratio * frame_height @@ -350,13 +368,13 @@ class Camera(object): self.frame.set_width(frame_width) # Rendering - def capture(self, *mobjects, **kwargs): + def capture(self, *mobjects: Mobject, **kwargs) -> None: self.refresh_perspective_uniforms() for mobject in mobjects: for render_group in self.get_render_group_list(mobject): self.render(render_group) - def render(self, render_group): + def render(self, render_group: dict[str]) -> None: shader_wrapper = render_group["shader_wrapper"] shader_program = render_group["prog"] self.set_shader_uniforms(shader_program, shader_wrapper) @@ -365,13 +383,17 @@ class Camera(object): if render_group["single_use"]: self.release_render_group(render_group) - def get_render_group_list(self, mobject): + def get_render_group_list(self, mobject: Mobject) -> list[dict[str]] | map[dict[str]]: try: return self.static_mobject_to_render_group_list[id(mobject)] except KeyError: return map(self.get_render_group, mobject.get_shader_wrapper_list()) - def get_render_group(self, shader_wrapper, single_use=True): + def get_render_group( + self, + shader_wrapper: ShaderWrapper, + single_use: bool = True + ) -> dict[str]: # Data buffers vbo = self.ctx.buffer(shader_wrapper.vert_data.tobytes()) if shader_wrapper.vert_indices is None: @@ -399,12 +421,12 @@ class Camera(object): "single_use": single_use, } - def release_render_group(self, render_group): + def release_render_group(self, render_group: dict[str]) -> None: for key in ["vbo", "ibo", "vao"]: if render_group[key] is not None: render_group[key].release() - def set_mobjects_as_static(self, *mobjects): + def set_mobjects_as_static(self, *mobjects: Mobject) -> None: # Creates buffer and array objects holding each mobjects shader data for mob in mobjects: self.static_mobject_to_render_group_list[id(mob)] = [ @@ -412,18 +434,23 @@ class Camera(object): for sw in mob.get_shader_wrapper_list() ] - def release_static_mobjects(self): + def release_static_mobjects(self) -> None: for rg_list in self.static_mobject_to_render_group_list.values(): for render_group in rg_list: self.release_render_group(render_group) self.static_mobject_to_render_group_list = {} # Shaders - def init_shaders(self): + def init_shaders(self) -> None: # Initialize with the null id going to None - self.id_to_shader_program = {"": None} + self.id_to_shader_program: dict[ + int | str, tuple[moderngl.Program, str] | None + ] = {"": None} - def get_shader_program(self, shader_wrapper): + def get_shader_program( + self, + shader_wrapper: ShaderWrapper + ) -> tuple[moderngl.Program, str]: sid = shader_wrapper.get_program_id() if sid not in self.id_to_shader_program: # Create shader program for the first time, then cache @@ -433,7 +460,11 @@ class Camera(object): self.id_to_shader_program[sid] = (program, vert_format) return self.id_to_shader_program[sid] - def set_shader_uniforms(self, shader, shader_wrapper): + def set_shader_uniforms( + self, + shader: moderngl.Program, + shader_wrapper: ShaderWrapper + ) -> None: for name, path in shader_wrapper.texture_paths.items(): tid = self.get_texture_id(path) shader[name].value = tid @@ -445,7 +476,7 @@ class Camera(object): except KeyError: pass - def refresh_perspective_uniforms(self): + def refresh_perspective_uniforms(self) -> None: frame = self.frame pw, ph = self.get_pixel_shape() fw, fh = frame.get_shape() @@ -470,11 +501,13 @@ class Camera(object): "focal_distance": frame.get_focal_distance(), } - def init_textures(self): - self.n_textures = 0 - self.path_to_texture = {} + def init_textures(self) -> None: + self.n_textures: int = 0 + self.path_to_texture: dict[ + str, tuple[int, moderngl.Texture] + ] = {} - def get_texture_id(self, path): + def get_texture_id(self, path: str) -> int: if path not in self.path_to_texture: if self.n_textures == 15: # I have no clue why this is needed self.n_textures += 1 @@ -490,7 +523,7 @@ class Camera(object): self.path_to_texture[path] = (tid, texture) return self.path_to_texture[path][0] - def release_texture(self, path): + def release_texture(self, path: str): tid_and_texture = self.path_to_texture.pop(path, None) if tid_and_texture: tid_and_texture[1].release() From 9a8aee481df574955c74788b0588329e4bc53853 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Sun, 13 Feb 2022 20:03:05 +0800 Subject: [PATCH 08/27] chore: add type hints to manimlib.event_handler --- manimlib/event_handler/event_dispatcher.py | 25 ++++++++++++---------- manimlib/event_handler/event_listner.py | 15 ++++++++++++- manimlib/mobject/mobject.py | 12 +++++++++-- 3 files changed, 38 insertions(+), 14 deletions(-) diff --git a/manimlib/event_handler/event_dispatcher.py b/manimlib/event_handler/event_dispatcher.py index a760d9ee..69cd07cd 100644 --- a/manimlib/event_handler/event_dispatcher.py +++ b/manimlib/event_handler/event_dispatcher.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from manimlib.event_handler.event_type import EventType @@ -6,21 +8,23 @@ from manimlib.event_handler.event_listner import EventListner class EventDispatcher(object): def __init__(self): - self.event_listners = { + self.event_listners: dict[ + EventType, list[EventListner] + ] = { event_type: [] for event_type in EventType } self.mouse_point = np.array((0., 0., 0.)) self.mouse_drag_point = np.array((0., 0., 0.)) - self.pressed_keys = set() - self.draggable_object_listners = [] + self.pressed_keys: set[str] = set() + self.draggable_object_listners: list[EventListner] = [] - def add_listner(self, event_listner): + def add_listner(self, event_listner: EventListner): assert(isinstance(event_listner, EventListner)) self.event_listners[event_listner.event_type].append(event_listner) return self - def remove_listner(self, event_listner): + def remove_listner(self, event_listner: EventListner): assert(isinstance(event_listner, EventListner)) try: while event_listner in self.event_listners[event_listner.event_type]: @@ -30,8 +34,7 @@ class EventDispatcher(object): pass return self - def dispatch(self, event_type, **event_data): - + def dispatch(self, event_type: EventType, **event_data): if event_type == EventType.MouseMotionEvent: self.mouse_point = event_data["point"] elif event_type == EventType.MouseDragEvent: @@ -74,16 +77,16 @@ class EventDispatcher(object): return propagate_event - def get_listners_count(self): + def get_listners_count(self) -> int: return sum([len(value) for key, value in self.event_listners.items()]) - def get_mouse_point(self): + def get_mouse_point(self) -> np.ndarray: return self.mouse_point - def get_mouse_drag_point(self): + def get_mouse_drag_point(self) -> np.ndarray: return self.mouse_drag_point - def is_key_pressed(self, symbol): + def is_key_pressed(self, symbol) -> bool: return (symbol in self.pressed_keys) __iadd__ = add_listner diff --git a/manimlib/event_handler/event_listner.py b/manimlib/event_handler/event_listner.py index 2f8663f7..5784c497 100644 --- a/manimlib/event_handler/event_listner.py +++ b/manimlib/event_handler/event_listner.py @@ -1,5 +1,18 @@ +from __future__ import annotations + +from typing import Callable + +from manimlib.mobject.mobject import Mobject +from manimlib.event_handler.event_type import EventType + + class EventListner(object): - def __init__(self, mobject, event_type, event_callback): + def __init__( + self, + mobject: Mobject, + event_type: EventType, + event_callback: Callable[[Mobject, dict[str]]] + ): self.mobject = mobject self.event_type = event_type self.callback = event_callback diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 0a13fbc9..935c08a6 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -1686,13 +1686,21 @@ class Mobject(object): def init_event_listners(self): self.event_listners: list[EventListner] = [] - def add_event_listner(self, event_type: EventType, event_callback: Callable): + def add_event_listner( + self, + event_type: EventType, + event_callback: Callable[[Mobject, dict[str]]] + ): event_listner = EventListner(self, event_type, event_callback) self.event_listners.append(event_listner) EVENT_DISPATCHER.add_listner(event_listner) return self - def remove_event_listner(self, event_type: EventType, event_callback: Callable): + def remove_event_listner( + self, + event_type: EventType, + event_callback: Callable[[Mobject, dict[str]]] + ): event_listner = EventListner(self, event_type, event_callback) while event_listner in self.event_listners: self.event_listners.remove(event_listner) From 960463d143f873306af55e3c9445c2b20f80da25 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Sun, 13 Feb 2022 20:47:04 +0800 Subject: [PATCH 09/27] docs: remove support for python 3.6 --- README.md | 2 +- docs/source/getting_started/installation.rst | 2 +- setup.cfg | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 4c6afec9..c22c675a 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Note, there are two versions of manim. This repository began as a personal proj > > **Note**: To install manim directly through pip, please pay attention to the name of the installed package. This repository is ManimGL of 3b1b. The package name is `manimgl` instead of `manim` or `manimlib`. Please use `pip install manimgl` to install the version in this repository. -Manim runs on Python 3.6 or higher (Python 3.8 is recommended). +Manim runs on Python 3.7 or higher. System requirements are [FFmpeg](https://ffmpeg.org/), [OpenGL](https://www.opengl.org/) and [LaTeX](https://www.latex-project.org) (optional, if you want to use LaTeX). For Linux, [Pango](https://pango.gnome.org) along with its development headers are required. See instruction [here](https://github.com/ManimCommunity/ManimPango#building). diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index 6600cc1e..b6b8b531 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -1,7 +1,7 @@ Installation ============ -Manim runs on Python 3.6 or higher (Python 3.8 is recommended). +Manim runs on Python 3.7 or higher. System requirements are: diff --git a/setup.cfg b/setup.cfg index 4b65f258..5e8d053a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,6 @@ classifiers = Topic :: Scientific/Engineering Topic :: Multimedia :: Video Topic :: Multimedia :: Graphics - Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 From 7fb6f352c4a21e34d149e1d13b2d965707693bc3 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Mon, 14 Feb 2022 20:02:24 +0800 Subject: [PATCH 10/27] fix: fix some bugs caused by type hints and imports --- manimlib/event_handler/event_listner.py | 15 +-------------- manimlib/mobject/mobject.py | 2 +- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/manimlib/event_handler/event_listner.py b/manimlib/event_handler/event_listner.py index 5784c497..2f8663f7 100644 --- a/manimlib/event_handler/event_listner.py +++ b/manimlib/event_handler/event_listner.py @@ -1,18 +1,5 @@ -from __future__ import annotations - -from typing import Callable - -from manimlib.mobject.mobject import Mobject -from manimlib.event_handler.event_type import EventType - - class EventListner(object): - def __init__( - self, - mobject: Mobject, - event_type: EventType, - event_callback: Callable[[Mobject, dict[str]]] - ): + def __init__(self, mobject, event_type, event_callback): self.mobject = mobject self.event_type = event_type self.callback = event_callback diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 90acf22f..f409edb6 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -650,7 +650,7 @@ class Mobject(object): Otherwise, if about_point is given a value, scaling is done with respect to that point. """ - if isinstance(scale_factor, npt.ArrayLike): + if isinstance(scale_factor, Iterable): scale_factor = np.array(scale_factor).clip(min=min_scale_factor) else: scale_factor = max(scale_factor, min_scale_factor) From be5de32d70f37696660618ab002b83961b196ac0 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Mon, 14 Feb 2022 21:22:18 +0800 Subject: [PATCH 11/27] chore: add type hints to manimlib.scene --- manimlib/camera/camera.py | 2 +- manimlib/event_handler/event_dispatcher.py | 4 +- manimlib/scene/scene.py | 211 ++++++++++++++------- manimlib/scene/scene_file_writer.py | 90 +++++---- 4 files changed, 198 insertions(+), 109 deletions(-) diff --git a/manimlib/camera/camera.py b/manimlib/camera/camera.py index 33e77d91..c433dbdc 100644 --- a/manimlib/camera/camera.py +++ b/manimlib/camera/camera.py @@ -297,7 +297,7 @@ class Camera(object): dtype=dtype, ) - def get_image(self) -> Image: + def get_image(self) -> Image.Image: return Image.frombytes( 'RGBA', self.get_pixel_shape(), diff --git a/manimlib/event_handler/event_dispatcher.py b/manimlib/event_handler/event_dispatcher.py index 69cd07cd..34eb55eb 100644 --- a/manimlib/event_handler/event_dispatcher.py +++ b/manimlib/event_handler/event_dispatcher.py @@ -16,7 +16,7 @@ class EventDispatcher(object): } self.mouse_point = np.array((0., 0., 0.)) self.mouse_drag_point = np.array((0., 0., 0.)) - self.pressed_keys: set[str] = set() + self.pressed_keys: set[int] = set() self.draggable_object_listners: list[EventListner] = [] def add_listner(self, event_listner: EventListner): @@ -86,7 +86,7 @@ class EventDispatcher(object): def get_mouse_drag_point(self) -> np.ndarray: return self.mouse_drag_point - def is_key_pressed(self, symbol) -> bool: + def is_key_pressed(self, symbol: int) -> bool: return (symbol in self.pressed_keys) __iadd__ = add_listner diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index 514e2b9f..25f88a84 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -1,12 +1,16 @@ -import inspect +from __future__ import annotations + +import time import random +import inspect import platform import itertools as it from functools import wraps +from typing import Iterable, Callable from tqdm import tqdm as ProgressDisplay import numpy as np -import time +import numpy.typing as npt from manimlib.animation.animation import prepare_animation from manimlib.animation.transform import MoveToTarget @@ -22,6 +26,11 @@ from manimlib.event_handler.event_type import EventType from manimlib.event_handler import EVENT_DISPATCHER from manimlib.logger import log +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from PIL.Image import Image + from manimlib.animation.animation import Animation + class Scene(object): CONFIG = { @@ -50,13 +59,13 @@ class Scene(object): else: self.window = None - self.camera = self.camera_class(**self.camera_config) + self.camera: Camera = self.camera_class(**self.camera_config) self.file_writer = SceneFileWriter(self, **self.file_writer_config) - self.mobjects = [self.camera.frame] - self.num_plays = 0 - self.time = 0 - self.skip_time = 0 - self.original_skipping_status = self.skip_animations + self.mobjects: list[Mobject] = [self.camera.frame] + self.num_plays: int = 0 + self.time: float = 0 + self.skip_time: float = 0 + self.original_skipping_status: bool = self.skip_animations if self.start_at_animation_number is not None: self.skip_animations = True @@ -70,9 +79,9 @@ class Scene(object): random.seed(self.random_seed) np.random.seed(self.random_seed) - def run(self): - self.virtual_animation_start_time = 0 - self.real_animation_start_time = time.time() + def run(self) -> None: + self.virtual_animation_start_time: float = 0 + self.real_animation_start_time: float = time.time() self.file_writer.begin() self.setup() @@ -82,7 +91,7 @@ class Scene(object): pass self.tear_down() - def setup(self): + def setup(self) -> None: """ This is meant to be implement by any scenes which are comonly subclassed, and have some common setup @@ -90,18 +99,18 @@ class Scene(object): """ pass - def construct(self): + def construct(self) -> None: # Where all the animation happens # To be implemented in subclasses pass - def tear_down(self): + def tear_down(self) -> None: self.stop_skipping() self.file_writer.finish() if self.window and self.linger_after_completion: self.interact() - def interact(self): + def interact(self) -> None: # If there is a window, enter a loop # which updates the frame while under # the hood calling the pyglet event loop @@ -116,7 +125,7 @@ class Scene(object): if self.quit_interaction: self.unlock_mobject_data() - def embed(self, close_scene_on_exit=True): + def embed(self, close_scene_on_exit: bool = True) -> None: if not self.preview: # If the scene is just being # written, ignore embed calls @@ -145,18 +154,18 @@ class Scene(object): if close_scene_on_exit: raise EndSceneEarlyException() - def __str__(self): + def __str__(self) -> str: return self.__class__.__name__ # Only these methods should touch the camera - def get_image(self): + def get_image(self) -> Image: return self.camera.get_image() - def show(self): + def show(self) -> None: self.update_frame(ignore_skipping=True) self.get_image().show() - def update_frame(self, dt=0, ignore_skipping=False): + def update_frame(self, dt: float = 0, ignore_skipping: bool = False) -> None: self.increment_time(dt) self.update_mobjects(dt) if self.skip_animations and not ignore_skipping: @@ -174,22 +183,22 @@ class Scene(object): if rt < vt: self.update_frame(0) - def emit_frame(self): + def emit_frame(self) -> None: if not self.skip_animations: self.file_writer.write_frame(self.camera) # Related to updating - def update_mobjects(self, dt): + def update_mobjects(self, dt: float) -> None: for mobject in self.mobjects: mobject.update(dt) - def should_update_mobjects(self): + def should_update_mobjects(self) -> bool: return self.always_update_mobjects or any([ len(mob.get_family_updaters()) > 0 for mob in self.mobjects ]) - def has_time_based_updaters(self): + def has_time_based_updaters(self) -> bool: return any([ sm.has_time_based_updater() for mob in self.mobjects() @@ -197,14 +206,14 @@ class Scene(object): ]) # Related to time - def get_time(self): + def get_time(self) -> float: return self.time - def increment_time(self, dt): + def increment_time(self, dt: float) -> None: self.time += dt # Related to internal mobject organization - def get_top_level_mobjects(self): + def get_top_level_mobjects(self) -> list[Mobject]: # Return only those which are not in the family # of another mobject from the scene mobjects = self.get_mobjects() @@ -218,10 +227,10 @@ class Scene(object): return num_families == 1 return list(filter(is_top_level, mobjects)) - def get_mobject_family_members(self): + def get_mobject_family_members(self) -> list[Mobject]: return extract_mobject_family_members(self.mobjects) - def add(self, *new_mobjects): + def add(self, *new_mobjects: Mobject): """ Mobjects will be displayed, from background to foreground in the order with which they are added. @@ -230,7 +239,7 @@ class Scene(object): self.mobjects += new_mobjects return self - def add_mobjects_among(self, values): + def add_mobjects_among(self, values: Iterable): """ This is meant mostly for quick prototyping, e.g. to add all mobjects defined up to a point, @@ -242,17 +251,17 @@ class Scene(object): )) return self - def remove(self, *mobjects_to_remove): + def remove(self, *mobjects_to_remove: Mobject): self.mobjects = restructure_list_to_exclude_certain_family_members( self.mobjects, mobjects_to_remove ) return self - def bring_to_front(self, *mobjects): + def bring_to_front(self, *mobjects: Mobject): self.add(*mobjects) return self - def bring_to_back(self, *mobjects): + def bring_to_back(self, *mobjects: Mobject): self.remove(*mobjects) self.mobjects = list(mobjects) + self.mobjects return self @@ -261,13 +270,18 @@ class Scene(object): self.mobjects = [] return self - def get_mobjects(self): + def get_mobjects(self) -> list[Mobject]: return list(self.mobjects) - def get_mobject_copies(self): + def get_mobject_copies(self) -> list[Mobject]: return [m.copy() for m in self.mobjects] - def point_to_mobject(self, point, search_set=None, buff=0): + def point_to_mobject( + self, + point: np.ndarray, + search_set: Iterable[Mobject] | None = None, + buff: float = 0 + ) -> Mobject | None: """ E.g. if clicking on the scene, this returns the top layer mobject under a given point @@ -280,7 +294,7 @@ class Scene(object): return None # Related to skipping - def update_skipping_status(self): + def update_skipping_status(self) -> None: if self.start_at_animation_number is not None: if self.num_plays == self.start_at_animation_number: self.skip_time = self.time @@ -290,12 +304,18 @@ class Scene(object): if self.num_plays >= self.end_at_animation_number: raise EndSceneEarlyException() - def stop_skipping(self): + def stop_skipping(self) -> None: self.virtual_animation_start_time = self.time self.skip_animations = False # Methods associated with running animations - def get_time_progression(self, run_time, n_iterations=None, desc="", override_skip_animations=False): + def get_time_progression( + self, + run_time: float, + n_iterations: int | None = None, + desc: str = "", + override_skip_animations: bool = False + ) -> list[float] | np.ndarray | ProgressDisplay: if self.skip_animations and not override_skip_animations: return [run_time] else: @@ -314,10 +334,13 @@ class Scene(object): desc=desc, ) - def get_run_time(self, animations): + def get_run_time(self, animations: Iterable[Animation]) -> float: return np.max([animation.run_time for animation in animations]) - def get_animation_time_progression(self, animations): + def get_animation_time_progression( + self, + animations: Iterable[Animation] + ) -> list[float] | np.ndarray | ProgressDisplay: run_time = self.get_run_time(animations) description = f"{self.num_plays} {animations[0]}" if len(animations) > 1: @@ -325,14 +348,18 @@ class Scene(object): time_progression = self.get_time_progression(run_time, desc=description) return time_progression - def get_wait_time_progression(self, duration, stop_condition=None): + def get_wait_time_progression( + self, + duration: float, + stop_condition: Callable[[], bool] | None = None + ) -> list[float] | np.ndarray | ProgressDisplay: kw = {"desc": f"{self.num_plays} Waiting"} if stop_condition is not None: kw["n_iterations"] = -1 # So it doesn't show % progress kw["override_skip_animations"] = True return self.get_time_progression(duration, **kw) - def anims_from_play_args(self, *args, **kwargs): + def anims_from_play_args(self, *args, **kwargs) -> list[Animation]: """ Each arg can either be an animation, or a mobject method followed by that methods arguments (and potentially follow @@ -422,7 +449,7 @@ class Scene(object): self.num_plays += 1 return wrapper - def lock_static_mobject_data(self, *animations): + def lock_static_mobject_data(self, *animations: Animation) -> None: movers = list(it.chain(*[ anim.mobject.get_family() for anim in animations @@ -432,7 +459,7 @@ class Scene(object): continue self.camera.set_mobjects_as_static(mobject) - def unlock_mobject_data(self): + def unlock_mobject_data(self) -> None: self.camera.release_static_mobjects() def refresh_locked_data(self): @@ -440,7 +467,7 @@ class Scene(object): self.lock_static_mobject_data() return self - def begin_animations(self, animations): + def begin_animations(self, animations: Iterable[Animation]) -> None: for animation in animations: animation.begin() # Anything animated that's not already in the @@ -451,7 +478,7 @@ class Scene(object): if animation.mobject not in self.mobjects: self.add(animation.mobject) - def progress_through_animations(self, animations): + def progress_through_animations(self, animations: Iterable[Animation]) -> None: last_t = 0 for t in self.get_animation_time_progression(animations): dt = t - last_t @@ -463,7 +490,7 @@ class Scene(object): self.update_frame(dt) self.emit_frame() - def finish_animations(self, animations): + def finish_animations(self, animations: Iterable[Animation]) -> None: for animation in animations: animation.finish() animation.clean_up_from_scene(self) @@ -473,7 +500,7 @@ class Scene(object): self.update_mobjects(0) @handle_play_like_call - def play(self, *args, **kwargs): + def play(self, *args, **kwargs) -> None: if len(args) == 0: log.warning("Called Scene.play with no animations") return @@ -485,11 +512,13 @@ class Scene(object): self.unlock_mobject_data() @handle_play_like_call - def wait(self, - duration=DEFAULT_WAIT_TIME, - stop_condition=None, - note=None, - ignore_presenter_mode=False): + def wait( + self, + duration: float = DEFAULT_WAIT_TIME, + stop_condition: Callable[[], bool] = None, + note: str = None, + ignore_presenter_mode: bool = False + ): if note: log.info(note) self.update_mobjects(dt=0) # Any problems with this? @@ -512,7 +541,11 @@ class Scene(object): self.unlock_mobject_data() return self - def wait_until(self, stop_condition, max_time=60): + def wait_until( + self, + stop_condition: Callable[[], bool], + max_time: float = 60 + ): self.wait(max_time, stop_condition=stop_condition) def force_skipping(self): @@ -525,14 +558,20 @@ class Scene(object): self.skip_animations = self.original_skipping_status return self - def add_sound(self, sound_file, time_offset=0, gain=None, **kwargs): + def add_sound( + self, + sound_file: str, + time_offset: float = 0, + gain: float | None = None, + gain_to_background: float | None = None + ): if self.skip_animations: return time = self.get_time() + time_offset - self.file_writer.add_sound(sound_file, time, gain, **kwargs) + self.file_writer.add_sound(sound_file, time, gain, gain_to_background) # Helpers for interactive development - def save_state(self): + def save_state(self) -> None: self.saved_state = { "mobjects": self.mobjects, "mobject_states": [ @@ -541,7 +580,7 @@ class Scene(object): ], } - def restore(self): + def restore(self) -> None: if not hasattr(self, "saved_state"): raise Exception("Trying to restore scene without having saved") mobjects = self.saved_state["mobjects"] @@ -552,7 +591,11 @@ class Scene(object): # Event handling - def on_mouse_motion(self, point, d_point): + def on_mouse_motion( + self, + point: np.ndarray, + d_point: np.ndarray + ) -> None: self.mouse_point.move_to(point) event_data = {"point": point, "d_point": d_point} @@ -572,7 +615,13 @@ class Scene(object): shift = np.dot(np.transpose(transform), shift) frame.shift(shift) - def on_mouse_drag(self, point, d_point, buttons, modifiers): + def on_mouse_drag( + self, + point: np.ndarray, + d_point: np.ndarray, + buttons: int, + modifiers: int + ) -> None: self.mouse_drag_point.move_to(point) event_data = {"point": point, "d_point": d_point, "buttons": buttons, "modifiers": modifiers} @@ -580,19 +629,33 @@ class Scene(object): if propagate_event is not None and propagate_event is False: return - def on_mouse_press(self, point, button, mods): + def on_mouse_press( + self, + point: np.ndarray, + button: int, + mods: int + ) -> None: event_data = {"point": point, "button": button, "mods": mods} propagate_event = EVENT_DISPATCHER.dispatch(EventType.MousePressEvent, **event_data) if propagate_event is not None and propagate_event is False: return - def on_mouse_release(self, point, button, mods): + def on_mouse_release( + self, + point: np.ndarray, + button: int, + mods: int + ) -> None: event_data = {"point": point, "button": button, "mods": mods} propagate_event = EVENT_DISPATCHER.dispatch(EventType.MouseReleaseEvent, **event_data) if propagate_event is not None and propagate_event is False: return - def on_mouse_scroll(self, point, offset): + def on_mouse_scroll( + self, + point: np.ndarray, + offset: np.ndarray + ) -> None: event_data = {"point": point, "offset": offset} propagate_event = EVENT_DISPATCHER.dispatch(EventType.MouseScrollEvent, **event_data) if propagate_event is not None and propagate_event is False: @@ -607,13 +670,21 @@ class Scene(object): shift = np.dot(np.transpose(transform), offset) frame.shift(-20.0 * shift) - def on_key_release(self, symbol, modifiers): + def on_key_release( + self, + symbol: int, + modifiers: int + ) -> None: event_data = {"symbol": symbol, "modifiers": modifiers} propagate_event = EVENT_DISPATCHER.dispatch(EventType.KeyReleaseEvent, **event_data) if propagate_event is not None and propagate_event is False: return - def on_key_press(self, symbol, modifiers): + def on_key_press( + self, + symbol: int, + modifiers: int + ) -> None: try: char = chr(symbol) except OverflowError: @@ -634,16 +705,16 @@ class Scene(object): elif char == "e": self.embed(close_scene_on_exit=False) - def on_resize(self, width: int, height: int): + def on_resize(self, width: int, height: int) -> None: self.camera.reset_pixel_shape(width, height) - def on_show(self): + def on_show(self) -> None: pass - def on_hide(self): + def on_hide(self) -> None: pass - def on_close(self): + def on_close(self) -> None: pass diff --git a/manimlib/scene/scene_file_writer.py b/manimlib/scene/scene_file_writer.py index 297e96d0..d3d4ee29 100644 --- a/manimlib/scene/scene_file_writer.py +++ b/manimlib/scene/scene_file_writer.py @@ -1,10 +1,13 @@ -import numpy as np -from pydub import AudioSegment -import shutil -import subprocess as sp +from __future__ import annotations + import os import sys +import shutil import platform +import subprocess as sp + +import numpy as np +from pydub import AudioSegment from tqdm import tqdm as ProgressDisplay from manimlib.constants import FFMPEG_BIN @@ -15,6 +18,12 @@ from manimlib.utils.file_ops import get_sorted_integer_files from manimlib.utils.sounds import get_full_sound_file_path from manimlib.logger import log +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from manimlib.scene.scene import Scene + from manimlib.camera.camera import Camera + from PIL.Image import Image + class SceneFileWriter(object): CONFIG = { @@ -42,14 +51,14 @@ class SceneFileWriter(object): def __init__(self, scene, **kwargs): digest_config(self, kwargs) - self.scene = scene - self.writing_process = None - self.has_progress_display = False + self.scene: Scene = scene + self.writing_process: sp.Popen | None = None + self.has_progress_display: bool = False self.init_output_directories() self.init_audio() # Output directories and files - def init_output_directories(self): + def init_output_directories(self) -> None: out_dir = self.output_directory if self.mirror_module_path: module_dir = self.get_default_module_directory() @@ -69,13 +78,13 @@ class SceneFileWriter(object): movie_dir, "partial_movie_files", scene_name, )) - def get_default_module_directory(self): + def get_default_module_directory(self) -> str: path, _ = os.path.splitext(self.input_file_path) if path.startswith("_"): path = path[1:] return path - def get_default_scene_name(self): + def get_default_scene_name(self) -> str: name = str(self.scene) saan = self.scene.start_at_animation_number eaan = self.scene.end_at_animation_number @@ -85,7 +94,7 @@ class SceneFileWriter(object): name += f"_{eaan}" return name - def get_resolution_directory(self): + def get_resolution_directory(self) -> str: pixel_height = self.scene.camera.pixel_height frame_rate = self.scene.camera.frame_rate return "{}p{}".format( @@ -93,10 +102,10 @@ class SceneFileWriter(object): ) # Directory getters - def get_image_file_path(self): + def get_image_file_path(self) -> str: return self.image_file_path - def get_next_partial_movie_path(self): + def get_next_partial_movie_path(self) -> str: result = os.path.join( self.partial_movie_directory, "{:05}{}".format( @@ -106,19 +115,22 @@ class SceneFileWriter(object): ) return result - def get_movie_file_path(self): + def get_movie_file_path(self) -> str: return self.movie_file_path # Sound - def init_audio(self): - self.includes_sound = False + def init_audio(self) -> None: + self.includes_sound: bool = False - def create_audio_segment(self): + def create_audio_segment(self) -> None: self.audio_segment = AudioSegment.silent() - def add_audio_segment(self, new_segment, - time=None, - gain_to_background=None): + def add_audio_segment( + self, + new_segment: AudioSegment, + time: float | None = None, + gain_to_background: float | None = None + ) -> None: if not self.includes_sound: self.includes_sound = True self.create_audio_segment() @@ -142,27 +154,33 @@ class SceneFileWriter(object): gain_during_overlay=gain_to_background, ) - def add_sound(self, sound_file, time=None, gain=None, **kwargs): + def add_sound( + self, + sound_file: str, + time: float | None = None, + gain: float | None = None, + gain_to_background: float | None = None + ) -> None: file_path = get_full_sound_file_path(sound_file) new_segment = AudioSegment.from_file(file_path) if gain: new_segment = new_segment.apply_gain(gain) - self.add_audio_segment(new_segment, time, **kwargs) + self.add_audio_segment(new_segment, time, gain_to_background) # Writers - def begin(self): + def begin(self) -> None: if not self.break_into_partial_movies and self.write_to_movie: self.open_movie_pipe(self.get_movie_file_path()) - def begin_animation(self): + def begin_animation(self) -> None: if self.break_into_partial_movies and self.write_to_movie: self.open_movie_pipe(self.get_next_partial_movie_path()) - def end_animation(self): + def end_animation(self) -> None: if self.break_into_partial_movies and self.write_to_movie: self.close_movie_pipe() - def finish(self): + def finish(self) -> None: if self.write_to_movie: if self.break_into_partial_movies: self.combine_movie_files() @@ -177,7 +195,7 @@ class SceneFileWriter(object): if self.should_open_file(): self.open_file() - def open_movie_pipe(self, file_path): + def open_movie_pipe(self, file_path: str) -> None: stem, ext = os.path.splitext(file_path) self.final_file_path = file_path self.temp_file_path = stem + "_temp" + ext @@ -223,7 +241,7 @@ class SceneFileWriter(object): ) self.has_progress_display = True - def set_progress_display_subdescription(self, sub_desc): + def set_progress_display_subdescription(self, sub_desc: str) -> None: desc_len = self.progress_description_len file = os.path.split(self.get_movie_file_path())[1] full_desc = f"Rendering {file} ({sub_desc})" @@ -233,14 +251,14 @@ class SceneFileWriter(object): full_desc += " " * (desc_len - len(full_desc)) self.progress_display.set_description(full_desc) - def write_frame(self, camera): + def write_frame(self, camera: Camera) -> None: if self.write_to_movie: raw_bytes = camera.get_raw_fbo_data() self.writing_process.stdin.write(raw_bytes) if self.has_progress_display: self.progress_display.update() - def close_movie_pipe(self): + def close_movie_pipe(self) -> None: self.writing_process.stdin.close() self.writing_process.wait() self.writing_process.terminate() @@ -248,7 +266,7 @@ class SceneFileWriter(object): self.progress_display.close() shutil.move(self.temp_file_path, self.final_file_path) - def combine_movie_files(self): + def combine_movie_files(self) -> None: kwargs = { "remove_non_integer_files": True, "extension": self.movie_file_extension, @@ -296,7 +314,7 @@ class SceneFileWriter(object): combine_process = sp.Popen(commands) combine_process.wait() - def add_sound_to_video(self): + def add_sound_to_video(self) -> None: movie_file_path = self.get_movie_file_path() stem, ext = os.path.splitext(movie_file_path) sound_file_path = stem + ".wav" @@ -327,22 +345,22 @@ class SceneFileWriter(object): shutil.move(temp_file_path, movie_file_path) os.remove(sound_file_path) - def save_final_image(self, image): + def save_final_image(self, image: Image) -> None: file_path = self.get_image_file_path() image.save(file_path) self.print_file_ready_message(file_path) - def print_file_ready_message(self, file_path): + def print_file_ready_message(self, file_path: str) -> None: if not self.quiet: log.info(f"File ready at {file_path}") - def should_open_file(self): + def should_open_file(self) -> bool: return any([ self.show_file_location_upon_completion, self.open_file_upon_completion, ]) - def open_file(self): + def open_file(self) -> None: if self.quiet: curr_stdout = sys.stdout sys.stdout = open(os.devnull, "w") From 62cab9feaf050444722ac551ae7ba6fd48e3bcb4 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Mon, 14 Feb 2022 21:25:46 +0800 Subject: [PATCH 12/27] chore: re-add type hint for EventListener --- manimlib/event_handler/event_listner.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/manimlib/event_handler/event_listner.py b/manimlib/event_handler/event_listner.py index 2f8663f7..4552cf8c 100644 --- a/manimlib/event_handler/event_listner.py +++ b/manimlib/event_handler/event_listner.py @@ -1,5 +1,18 @@ +from __future__ import annotations + +from typing import Callable, TYPE_CHECKING + +if TYPE_CHECKING: + from manimlib.mobject.mobject import Mobject + from manimlib.event_handler.event_type import EventType + class EventListner(object): - def __init__(self, mobject, event_type, event_callback): + def __init__( + self, + mobject: Mobject, + event_type: EventType, + event_callback: Callable[[Mobject, dict[str]]] + ): self.mobject = mobject self.event_type = event_type self.callback = event_callback From 66caf0c1adf6bf1f46426d8f2e3c70e1adb9a57a Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Mon, 14 Feb 2022 21:34:56 +0800 Subject: [PATCH 13/27] chore: only import some classes when type checking --- manimlib/camera/camera.py | 5 ++++- manimlib/mobject/types/surface.py | 5 ++++- manimlib/utils/debug.py | 5 ++++- manimlib/utils/family_ops.py | 4 +++- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/manimlib/camera/camera.py b/manimlib/camera/camera.py index c433dbdc..94b09b37 100644 --- a/manimlib/camera/camera.py +++ b/manimlib/camera/camera.py @@ -12,7 +12,6 @@ from colour import Color from manimlib.constants import * from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Point -from manimlib.shader_wrapper import ShaderWrapper from manimlib.utils.config_ops import digest_config from manimlib.utils.simple_functions import fdiv from manimlib.utils.simple_functions import clip @@ -22,6 +21,10 @@ from manimlib.utils.space_ops import rotation_matrix_transpose from manimlib.utils.space_ops import quaternion_from_angle_axis from manimlib.utils.space_ops import quaternion_mult +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from manimlib.shader_wrapper import ShaderWrapper + class CameraFrame(Mobject): CONFIG = { diff --git a/manimlib/mobject/types/surface.py b/manimlib/mobject/types/surface.py index a8b4fd5c..01eb912f 100644 --- a/manimlib/mobject/types/surface.py +++ b/manimlib/mobject/types/surface.py @@ -7,7 +7,6 @@ import numpy as np import numpy.typing as npt from manimlib.constants import * -from manimlib.camera.camera import Camera from manimlib.mobject.mobject import Mobject from manimlib.utils.bezier import integer_interpolate from manimlib.utils.bezier import interpolate @@ -15,6 +14,10 @@ from manimlib.utils.images import get_full_raster_image_path from manimlib.utils.iterables import listify from manimlib.utils.space_ops import normalize_along_axis +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from manimlib.camera.camera import Camera + class Surface(Mobject): CONFIG = { diff --git a/manimlib/utils/debug.py b/manimlib/utils/debug.py index be3e9527..308d4906 100644 --- a/manimlib/utils/debug.py +++ b/manimlib/utils/debug.py @@ -5,11 +5,14 @@ import numpy as np from typing import Callable from manimlib.constants import BLACK -from manimlib.mobject.mobject import Mobject from manimlib.mobject.numbers import Integer from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.logger import log +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from manimlib.mobject.mobject import Mobject + def print_family(mobject: Mobject, n_tabs: int = 0) -> None: """For debugging purposes""" diff --git a/manimlib/utils/family_ops.py b/manimlib/utils/family_ops.py index 1f18614b..3218fab6 100644 --- a/manimlib/utils/family_ops.py +++ b/manimlib/utils/family_ops.py @@ -3,7 +3,9 @@ from __future__ import annotations import itertools as it from typing import Iterable -from manimlib.mobject.mobject import Mobject +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from manimlib.mobject.mobject import Mobject def extract_mobject_family_members( From 9bdcc8b63567d8d30cc0394f0aa2e97cc2626812 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Mon, 14 Feb 2022 21:41:45 +0800 Subject: [PATCH 14/27] style: remove quotes of annotations according to PEP 563 --- manimlib/mobject/mobject.py | 79 +++++++++---------- manimlib/mobject/types/point_cloud_mobject.py | 4 +- manimlib/mobject/types/vectorized_mobject.py | 16 ++-- 3 files changed, 49 insertions(+), 50 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index f409edb6..32080d42 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -39,7 +39,6 @@ from manimlib.event_handler.event_listner import EventListner from manimlib.event_handler.event_type import EventType -Self = TypeVar("Self", bound="Mobject") TimeBasedUpdater = Callable[["Mobject", float], None] NonTimeUpdater = Callable[["Mobject"], None] Updater = Union[TimeBasedUpdater, NonTimeUpdater] @@ -77,9 +76,9 @@ class Mobject(object): def __init__(self, **kwargs): digest_config(self, kwargs) - self.submobjects: list["Mobject"] = [] - self.parents: list["Mobject"] = [] - self.family: list["Mobject"] = [self] + self.submobjects: list[Mobject] = [] + self.parents: list[Mobject] = [] + self.family: list[Mobject] = [self] self.locked_data_keys: set[str] = set() self.needs_new_bounding_box: bool = True @@ -97,11 +96,11 @@ class Mobject(object): def __str__(self): return self.__class__.__name__ - def __add__(self, other: "Mobject") -> "Mobject": + def __add__(self, other: Mobject) -> Mobject: assert(isinstance(other, Mobject)) return self.get_group_class()(self, other) - def __mul__(self, other: int) -> "Mobject": + def __mul__(self, other: int) -> Mobject: assert(isinstance(other, int)) return self.replicate(other) @@ -208,7 +207,7 @@ class Mobject(object): # Others related to points - def match_points(self, mobject: "Mobject"): + def match_points(self, mobject: Mobject): self.set_points(mobject.get_points()) return self @@ -311,7 +310,7 @@ class Mobject(object): def family_members_with_points(self): return [m for m in self.get_family() if m.has_points()] - def add(self, *mobjects: "Mobject"): + def add(self, *mobjects: Mobject): if self in mobjects: raise Exception("Mobject cannot contain self") for mobject in mobjects: @@ -322,7 +321,7 @@ class Mobject(object): self.assemble_family() return self - def remove(self, *mobjects: "Mobject"): + def remove(self, *mobjects: Mobject): for mobject in mobjects: if mobject in self.submobjects: self.submobjects.remove(mobject) @@ -331,11 +330,11 @@ class Mobject(object): self.assemble_family() return self - def add_to_back(self, *mobjects: "Mobject"): + def add_to_back(self, *mobjects: Mobject): self.set_submobjects(list_update(mobjects, self.submobjects)) return self - def replace_submobject(self, index: int, new_submob: "Mobject"): + def replace_submobject(self, index: int, new_submob: Mobject): old_submob = self.submobjects[index] if self in old_submob.parents: old_submob.parents.remove(self) @@ -343,12 +342,12 @@ class Mobject(object): self.assemble_family() return self - def insert_submobject(self, index: int, new_submob: "Mobject"): + def insert_submobject(self, index: int, new_submob: Mobject): self.submobjects.insert(index, new_submob) self.assemble_family() return self - def set_submobjects(self, submobject_list: list["Mobject"]): + def set_submobjects(self, submobject_list: list[Mobject]): self.remove(*self.submobjects) self.add(*submobject_list) return self @@ -441,7 +440,7 @@ class Mobject(object): def sort( self, point_to_num_func: Callable[[np.ndarray], float] = lambda p: p[0], - submob_func: Callable[["Mobject"]] | None = None + submob_func: Callable[[Mobject]] | None = None ): if submob_func is not None: self.submobjects.sort(key=submob_func) @@ -596,7 +595,7 @@ class Mobject(object): submob.clear_updaters() return self - def match_updaters(self, mobject: "Mobject"): + def match_updaters(self, mobject: Mobject): self.clear_updaters() for updater in mobject.get_updaters(): self.add_updater(updater) @@ -798,11 +797,11 @@ class Mobject(object): def next_to( self, - mobject_or_point: "Mobject" | np.ndarray, + mobject_or_point: Mobject | np.ndarray, direction: np.ndarray = RIGHT, buff: float = DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, aligned_edge: np.ndarray = ORIGIN, - submobject_to_align: "Mobject" | None = None, + submobject_to_align: Mobject | None = None, index_of_submobject_to_align: int | slice | None = None, coor_mask: np.ndarray = np.array([1, 1, 1]), ): @@ -938,7 +937,7 @@ class Mobject(object): def move_to( self, - point_or_mobject: "Mobject" | np.ndarray, + point_or_mobject: Mobject | np.ndarray, aligned_edge: np.ndarray = ORIGIN, coor_mask: np.ndarray = np.array([1, 1, 1]) ): @@ -950,7 +949,7 @@ class Mobject(object): self.shift((target - point_to_align) * coor_mask) return self - def replace(self, mobject: "Mobject", dim_to_match: int = 0, stretch: bool = False): + def replace(self, mobject: Mobject, dim_to_match: int = 0, stretch: bool = False): if not mobject.get_num_points() and not mobject.submobjects: self.scale(0) return self @@ -968,7 +967,7 @@ class Mobject(object): def surround( self, - mobject: "Mobject", + mobject: Mobject, dim_to_match: int = 0, stretch: bool = False, buff: float = MED_SMALL_BUFF @@ -1292,27 +1291,27 @@ class Mobject(object): # Match other mobject properties - def match_color(self, mobject: "Mobject"): + def match_color(self, mobject: Mobject): return self.set_color(mobject.get_color()) - def match_dim_size(self, mobject: "Mobject", dim: int, **kwargs): + def match_dim_size(self, mobject: Mobject, dim: int, **kwargs): return self.rescale_to_fit( mobject.length_over_dim(dim), dim, **kwargs ) - def match_width(self, mobject: "Mobject", **kwargs): + def match_width(self, mobject: Mobject, **kwargs): return self.match_dim_size(mobject, 0, **kwargs) - def match_height(self, mobject: "Mobject", **kwargs): + def match_height(self, mobject: Mobject, **kwargs): return self.match_dim_size(mobject, 1, **kwargs) - def match_depth(self, mobject: "Mobject", **kwargs): + def match_depth(self, mobject: Mobject, **kwargs): return self.match_dim_size(mobject, 2, **kwargs) def match_coord( self, - mobject_or_point: "Mobject" | np.ndarray, + mobject_or_point: Mobject | np.ndarray, dim: int, direction: np.ndarray = ORIGIN ): @@ -1324,28 +1323,28 @@ class Mobject(object): def match_x( self, - mobject_or_point: "Mobject" | np.ndarray, + mobject_or_point: Mobject | np.ndarray, direction: np.ndarray = ORIGIN ): return self.match_coord(mobject_or_point, 0, direction) def match_y( self, - mobject_or_point: "Mobject" | np.ndarray, + mobject_or_point: Mobject | np.ndarray, direction: np.ndarray = ORIGIN ): return self.match_coord(mobject_or_point, 1, direction) def match_z( self, - mobject_or_point: "Mobject" | np.ndarray, + mobject_or_point: Mobject | np.ndarray, direction: np.ndarray = ORIGIN ): return self.match_coord(mobject_or_point, 2, direction) def align_to( self, - mobject_or_point: "Mobject" | np.ndarray, + mobject_or_point: Mobject | np.ndarray, direction: np.ndarray = ORIGIN ): """ @@ -1372,11 +1371,11 @@ class Mobject(object): # Alignment - def align_data_and_family(self, mobject: "Mobject") -> None: + def align_data_and_family(self, mobject: Mobject) -> None: self.align_family(mobject) self.align_data(mobject) - def align_data(self, mobject: "Mobject") -> None: + def align_data(self, mobject: Mobject) -> None: # In case any data arrays get resized when aligned to shader data self.refresh_shader_data() for mob1, mob2 in zip(self.get_family(), mobject.get_family()): @@ -1393,13 +1392,13 @@ class Mobject(object): elif len(arr1) > len(arr2): mob2.data[key] = resize_preserving_order(arr2, len(arr1)) - def align_points(self, mobject: "Mobject"): + def align_points(self, mobject: Mobject): max_len = max(self.get_num_points(), mobject.get_num_points()) for mob in (self, mobject): mob.resize_points(max_len, resize_func=resize_preserving_order) return self - def align_family(self, mobject: "Mobject"): + def align_family(self, mobject: Mobject): mob1 = self mob2 = mobject n1 = len(mob1) @@ -1456,8 +1455,8 @@ class Mobject(object): def interpolate( self, - mobject1: "Mobject", - mobject2: "Mobject", + mobject1: Mobject, + mobject2: Mobject, alpha: float, path_func: Callable[[np.ndarray, np.ndarray, float], np.ndarray] = straight_path ): @@ -1496,7 +1495,7 @@ class Mobject(object): """ pass # To implement in subclass - def become(self, mobject: "Mobject"): + def become(self, mobject: Mobject): """ Edit all data and submobjects to be idential to another mobject @@ -1524,7 +1523,7 @@ class Mobject(object): self.refresh_shader_data() self.locked_data_keys = set(keys) - def lock_matching_data(self, mobject1: "Mobject", mobject2: "Mobject"): + def lock_matching_data(self, mobject1: Mobject, mobject2: Mobject): for sm, sm1, sm2 in zip(self.get_family(), mobject1.get_family(), mobject2.get_family()): keys = sm.data.keys() & sm1.data.keys() & sm2.data.keys() sm.lock_data(list(filter( @@ -1811,13 +1810,13 @@ class Mobject(object): class Group(Mobject): - def __init__(self, *mobjects: "Mobject", **kwargs): + def __init__(self, *mobjects: Mobject, **kwargs): if not all([isinstance(m, Mobject) for m in mobjects]): raise Exception("All submobjects must be of type Mobject") Mobject.__init__(self, **kwargs) self.add(*mobjects) - def __add__(self, other: "Mobject" | "Group"): + def __add__(self, other: Mobject | Group): assert(isinstance(other, Mobject)) return self.add(other) diff --git a/manimlib/mobject/types/point_cloud_mobject.py b/manimlib/mobject/types/point_cloud_mobject.py index 5cab597a..5de41173 100644 --- a/manimlib/mobject/types/point_cloud_mobject.py +++ b/manimlib/mobject/types/point_cloud_mobject.py @@ -77,7 +77,7 @@ class PMobject(Mobject): ))) return self - def match_colors(self, pmobject: "PMobject"): + def match_colors(self, pmobject: PMobject): self.data["rgbas"][:] = resize_with_interpolation( pmobject.data["rgbas"], self.get_num_points() ) @@ -116,7 +116,7 @@ class PMobject(Mobject): index = alpha * (self.get_num_points() - 1) return self.get_points()[int(index)] - def pointwise_become_partial(self, pmobject: "PMobject", a: float, b: float): + def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float): lower_index = int(a * pmobject.get_num_points()) upper_index = int(b * pmobject.get_num_points()) for key in self.data: diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 8b6d9db1..3d27a07f 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -233,7 +233,7 @@ class VMobject(Mobject): "shadow": self.get_shadow(), } - def match_style(self, vmobject: "VMobject", recurse: bool = True): + def match_style(self, vmobject: VMobject, recurse: bool = True): self.set_style(**vmobject.get_style(), recurse=False) if recurse: # Does its best to match up submobject lists, and @@ -553,7 +553,7 @@ class VMobject(Mobject): self.append_points(points) return self - def append_vectorized_mobject(self, vectorized_mobject: "VMobject"): + def append_vectorized_mobject(self, vectorized_mobject: VMobject): new_points = list(vectorized_mobject.get_points()) if self.has_new_path_started(): @@ -739,7 +739,7 @@ class VMobject(Mobject): return self # Alignment - def align_points(self, vmobject: "VMobject"): + def align_points(self, vmobject: VMobject): if self.get_num_points() == len(vmobject.get_points()): return @@ -827,8 +827,8 @@ class VMobject(Mobject): def interpolate( self, - mobject1: "VMobject", - mobject2: "VMobject", + mobject1: VMobject, + mobject2: VMobject, alpha: float, *args, **kwargs ): @@ -840,7 +840,7 @@ class VMobject(Mobject): self.refresh_triangulation() return self - def pointwise_become_partial(self, vmobject: "VMobject", a: float, b: float): + def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float): assert(isinstance(vmobject, VMobject)) if a <= 0 and b >= 1: self.become(vmobject) @@ -882,7 +882,7 @@ class VMobject(Mobject): self.set_points(new_points) return self - def get_subcurve(self, a: float, b: float) -> "VMobject": + def get_subcurve(self, a: float, b: float) -> VMobject: vmob = self.copy() vmob.pointwise_become_partial(self, a, b) return vmob @@ -1106,7 +1106,7 @@ class VGroup(VMobject): super().__init__(**kwargs) self.add(*vmobjects) - def __add__(self, other: VMobject | "VGroup"): + def __add__(self, other: VMobject | VGroup): assert(isinstance(other, VMobject)) return self.add(other) From 61c70b426c389b0907a54032bd6ab0949ce92ff0 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Mon, 14 Feb 2022 21:43:22 +0800 Subject: [PATCH 15/27] remove unnecessary import --- manimlib/mobject/mobject.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 32080d42..6506dda4 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -5,7 +5,7 @@ import copy import random import itertools as it from functools import wraps -from typing import Iterable, TypeVar, Callable, Union, Sequence +from typing import Iterable, Callable, Union, Sequence import colour import moderngl From 773e013af90da9c85e6d8f5969036da0da1d4eff Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Mon, 14 Feb 2022 22:55:41 +0800 Subject: [PATCH 16/27] chore: add type hints to manimlib.mobject.svg --- manimlib/mobject/svg/brace.py | 58 +++++++++---- manimlib/mobject/svg/mtex_mobject.py | 117 ++++++++++++++++----------- manimlib/mobject/svg/svg_mobject.py | 54 +++++++------ manimlib/mobject/svg/tex_mobject.py | 65 ++++++++++----- manimlib/mobject/svg/text_mobject.py | 60 ++++++++------ 5 files changed, 226 insertions(+), 128 deletions(-) diff --git a/manimlib/mobject/svg/brace.py b/manimlib/mobject/svg/brace.py index 31217a28..f9d96cec 100644 --- a/manimlib/mobject/svg/brace.py +++ b/manimlib/mobject/svg/brace.py @@ -1,11 +1,15 @@ -import numpy as np +from __future__ import annotations + import math import copy +from typing import Iterable + +import numpy as np -from manimlib.animation.composition import AnimationGroup from manimlib.constants import * from manimlib.animation.fading import FadeIn from manimlib.animation.growing import GrowFromCenter +from manimlib.animation.composition import AnimationGroup from manimlib.mobject.svg.tex_mobject import Tex from manimlib.mobject.svg.tex_mobject import SingleStringTex from manimlib.mobject.svg.tex_mobject import TexText @@ -14,6 +18,10 @@ from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config from manimlib.utils.space_ops import get_norm +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from manimlib.mobject.mobject import Mobject + from manimlib.animation.animation import Animation class Brace(SingleStringTex): CONFIG = { @@ -21,7 +29,12 @@ class Brace(SingleStringTex): "tex_string": r"\underbrace{\qquad}" } - def __init__(self, mobject, direction=DOWN, **kwargs): + def __init__( + self, + mobject: Mobject, + direction: np.ndarray = DOWN, + **kwargs + ): digest_config(self, kwargs, locals()) angle = -math.atan2(*direction[:2]) + PI mobject.rotate(-angle, about_point=ORIGIN) @@ -36,7 +49,7 @@ class Brace(SingleStringTex): for mob in mobject, self: mob.rotate(angle, about_point=ORIGIN) - def set_initial_width(self, width): + def set_initial_width(self, width: float): width_diff = width - self.get_width() if width_diff > 0: for tip, rect, vect in [(self[0], self[1], RIGHT), (self[5], self[4], LEFT)]: @@ -49,7 +62,12 @@ class Brace(SingleStringTex): self.set_width(width, stretch=True) return self - def put_at_tip(self, mob, use_next_to=True, **kwargs): + def put_at_tip( + self, + mob: Mobject, + use_next_to: bool = True, + **kwargs + ): if use_next_to: mob.next_to( self.get_tip(), @@ -63,24 +81,24 @@ class Brace(SingleStringTex): mob.shift(self.get_direction() * shift_distance) return self - def get_text(self, text, **kwargs): + def get_text(self, text: str, **kwargs) -> Text: buff = kwargs.pop("buff", SMALL_BUFF) text_mob = Text(text, **kwargs) self.put_at_tip(text_mob, buff=buff) return text_mob - def get_tex(self, *tex, **kwargs): + def get_tex(self, *tex: str, **kwargs) -> Tex: tex_mob = Tex(*tex) self.put_at_tip(tex_mob, **kwargs) return tex_mob - def get_tip(self): + def get_tip(self) -> np.ndarray: # Very specific to the LaTeX representation # of a brace, but it's the only way I can think # of to get the tip regardless of orientation. return self.get_all_points()[self.tip_point_index] - def get_direction(self): + def get_direction(self) -> np.ndarray: vect = self.get_tip() - self.get_center() return vect / get_norm(vect) @@ -92,14 +110,20 @@ class BraceLabel(VMobject): "label_buff": DEFAULT_MOBJECT_TO_MOBJECT_BUFFER } - def __init__(self, obj, text, brace_direction=DOWN, **kwargs): + def __init__( + self, + obj: list[VMobject] | Mobject, + text: Iterable[str] | str, + brace_direction: np.ndarray = DOWN, + **kwargs + ) -> None: VMobject.__init__(self, **kwargs) self.brace_direction = brace_direction if isinstance(obj, list): obj = VMobject(*obj) self.brace = Brace(obj, brace_direction, **kwargs) - if isinstance(text, tuple) or isinstance(text, list): + if isinstance(text, Iterable): self.label = self.label_constructor(*text, **kwargs) else: self.label = self.label_constructor(str(text)) @@ -109,10 +133,14 @@ class BraceLabel(VMobject): self.brace.put_at_tip(self.label, buff=self.label_buff) self.set_submobjects([self.brace, self.label]) - def creation_anim(self, label_anim=FadeIn, brace_anim=GrowFromCenter): + def creation_anim( + self, + label_anim: Animation = FadeIn, + brace_anim: Animation=GrowFromCenter + ) -> AnimationGroup: return AnimationGroup(brace_anim(self.brace), label_anim(self.label)) - def shift_brace(self, obj, **kwargs): + def shift_brace(self, obj: list[VMobject] | Mobject, **kwargs): if isinstance(obj, list): obj = VMobject(*obj) self.brace = Brace(obj, self.brace_direction, **kwargs) @@ -120,7 +148,7 @@ class BraceLabel(VMobject): self.submobjects[0] = self.brace return self - def change_label(self, *text, **kwargs): + def change_label(self, *text: str, **kwargs): self.label = self.label_constructor(*text, **kwargs) if self.label_scale != 1: self.label.scale(self.label_scale) @@ -129,7 +157,7 @@ class BraceLabel(VMobject): self.submobjects[1] = self.label return self - def change_brace_label(self, obj, *text): + def change_brace_label(self, obj: list[VMobject] | Mobject, *text: str): self.shift_brace(obj) self.change_label(*text) return self diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 84c0cbf5..c7b1438c 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,6 +1,10 @@ -import itertools as it +from __future__ import annotations + import re +import colour +import itertools as it from types import MethodType +from typing import Iterable, Union, Sequence from manimlib.constants import BLACK from manimlib.mobject.svg.svg_mobject import SVGMobject @@ -14,14 +18,16 @@ from manimlib.utils.tex_file_writing import get_tex_config from manimlib.utils.tex_file_writing import display_during_execution from manimlib.logger import log +ManimColor = Union[str, colour.Color, Sequence[float]] + SCALE_FACTOR_PER_FONT_POINT = 0.001 -TEX_HASH_TO_MOB_MAP = {} +TEX_HASH_TO_MOB_MAP: dict[int, VGroup] = {} -def _get_neighbouring_pairs(iterable): +def _get_neighbouring_pairs(iterable: Iterable) -> list: return list(adjacent_pairs(iterable))[:-1] @@ -38,17 +44,19 @@ class _TexSVG(SVGMobject): class _TexParser(object): - def __init__(self, tex_string, additional_substrings): + def __init__(self, tex_string: str, additional_substrings: str): self.tex_string = tex_string self.whitespace_indices = self.get_whitespace_indices() self.backslash_indices = self.get_backslash_indices() self.script_indices = self.get_script_indices() self.brace_indices_dict = self.get_brace_indices_dict() - self.tex_span_list = [] - self.script_span_to_char_dict = {} - self.script_span_to_tex_span_dict = {} - self.neighbouring_script_span_pairs = [] - self.specified_substrings = [] + self.tex_span_list: list[tuple[int, int]] = [] + self.script_span_to_char_dict: dict[tuple[int, int], str] = {} + self.script_span_to_tex_span_dict: dict[ + tuple[int, int], tuple[int, int] + ] = {} + self.neighbouring_script_span_pairs: list[tuple[int, int]] = [] + self.specified_substrings: list[str] = [] self.add_tex_span((0, len(tex_string))) self.break_up_by_scripts() self.break_up_by_double_braces() @@ -59,17 +67,17 @@ class _TexParser(object): ) self.containing_labels_dict = self.get_containing_labels_dict() - def add_tex_span(self, tex_span): + def add_tex_span(self, tex_span: tuple[int, int]) -> None: if tex_span not in self.tex_span_list: self.tex_span_list.append(tex_span) - def get_whitespace_indices(self): + def get_whitespace_indices(self) -> list[int]: return [ match_obj.start() for match_obj in re.finditer(r"\s", self.tex_string) ] - def get_backslash_indices(self): + def get_backslash_indices(self) -> list[int]: # Newlines (`\\`) don't count. return [ match_obj.end() - 1 @@ -77,19 +85,19 @@ class _TexParser(object): if len(match_obj.group()) % 2 == 1 ] - def filter_out_escaped_characters(self, indices): + def filter_out_escaped_characters(self, indices) -> list[int]: return list(filter( lambda index: index - 1 not in self.backslash_indices, indices )) - def get_script_indices(self): + def get_script_indices(self) -> list[int]: return self.filter_out_escaped_characters([ match_obj.start() for match_obj in re.finditer(r"[_^]", self.tex_string) ]) - def get_brace_indices_dict(self): + def get_brace_indices_dict(self) -> dict[int, int]: tex_string = self.tex_string indices = self.filter_out_escaped_characters([ match_obj.start() @@ -105,7 +113,7 @@ class _TexParser(object): result[left_brace_index] = index return result - def break_up_by_scripts(self): + def break_up_by_scripts(self) -> None: # Match subscripts & superscripts. tex_string = self.tex_string whitespace_indices = self.whitespace_indices @@ -154,7 +162,7 @@ class _TexParser(object): if span_0[1] == span_1[0]: self.neighbouring_script_span_pairs.append((span_0, span_1)) - def break_up_by_double_braces(self): + def break_up_by_double_braces(self) -> None: # Match paired double braces (`{{...}}`). tex_string = self.tex_string reversed_indices_dict = dict( @@ -178,7 +186,10 @@ class _TexParser(object): self.specified_substrings.append(tex_string[slice(*tex_span)]) skip = True - def break_up_by_additional_substrings(self, additional_substrings): + def break_up_by_additional_substrings( + self, + additional_substrings: Iterable[str] + ) -> None: stripped_substrings = sorted(remove_list_redundancies([ string.strip() for string in additional_substrings @@ -208,7 +219,7 @@ class _TexParser(object): continue self.add_tex_span((span_begin, span_end)) - def get_containing_labels_dict(self): + def get_containing_labels_dict(self) -> dict[tuple[int, int], list[int]]: tex_span_list = self.tex_span_list result = { tex_span: [] @@ -233,7 +244,7 @@ class _TexParser(object): raise ValueError return result - def get_labelled_tex_string(self): + def get_labelled_tex_string(self) -> str: indices, _, flags, labels = zip(*sorted([ (*tex_span[::(1, -1)[flag]], flag, label) for label, tex_span in enumerate(self.tex_span_list) @@ -251,7 +262,7 @@ class _TexParser(object): return "".join(it.chain(*zip(command_pieces, string_pieces))) @staticmethod - def get_color_command(label): + def get_color_command(label: int) -> str: rg, b = divmod(label, 256) r, g = divmod(rg, 256) return "".join([ @@ -261,7 +272,7 @@ class _TexParser(object): "}" ]) - def get_sorted_submob_indices(self, submob_labels): + def get_sorted_submob_indices(self, submob_labels: Iterable[int]) -> list[int]: def script_span_to_submob_range(script_span): tex_span = self.script_span_to_tex_span_dict[script_span] submob_indices = [ @@ -295,7 +306,7 @@ class _TexParser(object): ] return result - def get_submob_tex_strings(self, submob_labels): + def get_submob_tex_strings(self, submob_labels: Iterable[int]) -> list[str]: ordered_tex_spans = [ self.tex_span_list[label] for label in submob_labels ] @@ -356,7 +367,10 @@ class _TexParser(object): ])) return result - def find_span_components_of_custom_span(self, custom_span): + def find_span_components_of_custom_span( + self, + custom_span: tuple[int, int] + ) -> list[tuple[int, int]] | None: skipped_indices = sorted(it.chain( self.whitespace_indices, self.script_indices @@ -384,16 +398,19 @@ class _TexParser(object): span_begin = next_begin return result - def get_containing_labels_by_tex_spans(self, tex_spans): + def get_containing_labels_by_tex_spans( + self, + tex_spans: Iterable[tuple[int, int]] + ) -> list[int]: return remove_list_redundancies(list(it.chain(*[ self.containing_labels_dict[tex_span] for tex_span in tex_spans ]))) - def get_specified_substrings(self): + def get_specified_substrings(self) -> list[str]: return self.specified_substrings - def get_isolated_substrings(self): + def get_isolated_substrings(self) -> list[str]: return remove_list_redundancies([ self.tex_string[slice(*tex_span)] for tex_span in self.tex_span_list @@ -412,7 +429,7 @@ class MTex(VMobject): "use_plain_tex": False, } - def __init__(self, tex_string, **kwargs): + def __init__(self, tex_string: str, **kwargs): super().__init__(**kwargs) tex_string = tex_string.strip() # Prevent from passing an empty string. @@ -431,12 +448,12 @@ class MTex(VMobject): self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) @staticmethod - def color_to_label(color): + def color_to_label(color: ManimColor) -> list: r, g, b = color_to_int_rgb(color) rg = r * 256 + g return rg * 256 + b - def generate_mobject(self): + def generate_mobject(self) -> VGroup: labelled_tex_string = self.__parser.get_labelled_tex_string() labelled_tex_content = self.get_tex_file_content(labelled_tex_string) hash_val = hash((labelled_tex_content, self.use_plain_tex)) @@ -471,7 +488,7 @@ class MTex(VMobject): TEX_HASH_TO_MOB_MAP[hash_val] = mob return mob - def get_tex_file_content(self, tex_string): + def get_tex_file_content(self, tex_string: str) -> str: if self.tex_environment: tex_string = "\n".join([ f"\\begin{{{self.tex_environment}}}", @@ -483,7 +500,7 @@ class MTex(VMobject): return tex_string @staticmethod - def tex_content_to_glyphs(tex_content): + def tex_content_to_glyphs(tex_content: str) -> _TexSVG: tex_config = get_tex_config() full_tex = tex_config["tex_body"].replace( tex_config["text_to_replace"], @@ -492,7 +509,11 @@ class MTex(VMobject): filename = tex_to_svg_file(full_tex) return _TexSVG(filename) - def build_mobject(self, svg_glyphs, glyph_labels): + def build_mobject( + self, + svg_glyphs: _TexSVG | None, + glyph_labels: Iterable[int] + ) -> VGroup: if not svg_glyphs: return VGroup() @@ -530,14 +551,17 @@ class MTex(VMobject): submob.get_tex = MethodType(lambda inst: inst.tex_string, submob) return VGroup(*rearranged_submobjects) - def get_part_by_tex_spans(self, tex_spans): + def get_part_by_tex_spans( + self, + tex_spans: Iterable[tuple[int, int]] + ) -> VGroup: labels = self.__parser.get_containing_labels_by_tex_spans(tex_spans) return VGroup(*filter( lambda submob: submob.submob_label in labels, self.submobjects )) - def get_part_by_custom_span(self, custom_span): + def get_part_by_custom_span(self, custom_span: tuple[int, int]) -> VGroup: tex_spans = self.__parser.find_span_components_of_custom_span( custom_span ) @@ -546,7 +570,7 @@ class MTex(VMobject): raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"") return self.get_part_by_tex_spans(tex_spans) - def get_parts_by_tex(self, tex): + def get_parts_by_tex(self, tex: str) -> VGroup: return VGroup(*[ self.get_part_by_custom_span(match_obj.span()) for match_obj in re.finditer( @@ -554,20 +578,23 @@ class MTex(VMobject): ) ]) - def get_part_by_tex(self, tex, index=0): + def get_part_by_tex(self, tex: str, index: int = 0) -> VGroup: all_parts = self.get_parts_by_tex(tex) return all_parts[index] - def set_color_by_tex(self, tex, color): + def set_color_by_tex(self, tex: str, color: ManimColor): self.get_parts_by_tex(tex).set_color(color) return self - def set_color_by_tex_to_color_map(self, tex_to_color_map): + def set_color_by_tex_to_color_map( + self, + tex_to_color_map: dict[str, ManimColor] + ): for tex, color in tex_to_color_map.items(): self.set_color_by_tex(tex, color) return self - def indices_of_part(self, part): + def indices_of_part(self, part: Iterable[VGroup]) -> list[int]: indices = [ index for index, submob in enumerate(self.submobjects) if submob in part @@ -576,23 +603,23 @@ class MTex(VMobject): raise ValueError("Failed to find part in tex") return indices - def indices_of_part_by_tex(self, tex, index=0): + def indices_of_part_by_tex(self, tex: str, index: int = 0) -> list[int]: part = self.get_part_by_tex(tex, index=index) return self.indices_of_part(part) - def get_tex(self): + def get_tex(self) -> str: return self.tex_string - def get_submob_tex(self): + def get_submob_tex(self) -> list[str]: return [ submob.get_tex() for submob in self.submobjects ] - def get_specified_substrings(self): + def get_specified_substrings(self) -> list[str]: return self.__parser.get_specified_substrings() - def get_isolated_substrings(self): + def get_isolated_substrings(self) -> list[str]: return self.__parser.get_isolated_substrings() diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index fd79dffa..ac3c66b5 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import os import re import hashlib import itertools as it +from typing import Callable import svgelements as se import numpy as np @@ -20,7 +23,7 @@ from manimlib.utils.images import get_full_vector_image_path from manimlib.logger import log -def _convert_point_to_3d(x, y): +def _convert_point_to_3d(x: float, y: float) -> np.ndarray: return np.array([x, y, 0.0]) @@ -41,7 +44,7 @@ class SVGMobject(VMobject): "path_string_config": {} } - def __init__(self, file_name=None, **kwargs): + def __init__(self, file_name: str | None = None, **kwargs): digest_config(self, kwargs) self.file_name = file_name or self.file_name if file_name is None: @@ -51,7 +54,7 @@ class SVGMobject(VMobject): super().__init__(**kwargs) self.move_into_position() - def move_into_position(self): + def move_into_position(self) -> None: if self.should_center: self.center() if self.height is not None: @@ -59,7 +62,7 @@ class SVGMobject(VMobject): if self.width is not None: self.set_width(self.width) - def init_colors(self): + def init_colors(self) -> None: # Remove fill_color, fill_opacity, # stroke_width, stroke_color, stroke_opacity # as each submobject may have those values specified in svg file @@ -68,7 +71,7 @@ class SVGMobject(VMobject): self.set_flat_stroke(self.flat_stroke) return self - def init_points(self): + def init_points(self) -> None: with open(self.file_path, "r") as svg_file: svg_string = svg_file.read() @@ -96,7 +99,7 @@ class SVGMobject(VMobject): self.flip(RIGHT) # Flip y self.scale(0.75) - def modify_svg_file(self, svg_string): + def modify_svg_file(self, svg_string: str) -> str: # svgelements cannot handle em, ex units # Convert them using 1em = 16px, 1ex = 0.5em = 8px def convert_unit(match_obj): @@ -127,7 +130,7 @@ class SVGMobject(VMobject): return result - def generate_context_values_from_config(self): + def generate_context_values_from_config(self) -> dict[str]: result = {} if self.stroke_width is not None: result["stroke-width"] = self.stroke_width @@ -145,7 +148,7 @@ class SVGMobject(VMobject): result["stroke-opacity"] = self.stroke_opacity return result - def get_mobjects_from(self, shape): + def get_mobjects_from(self, shape) -> list[VMobject]: if isinstance(shape, se.Group): return list(it.chain(*( self.get_mobjects_from(child) @@ -161,7 +164,7 @@ class SVGMobject(VMobject): return [mob] @staticmethod - def handle_transform(mob, matrix): + def handle_transform(mob: VMobject, matrix: se.Matrix) -> VMobject: mat = np.array([ [matrix.a, matrix.c], [matrix.b, matrix.d] @@ -171,8 +174,10 @@ class SVGMobject(VMobject): mob.shift(vec) return mob - def get_mobject_from(self, shape): - shape_class_to_func_map = { + def get_mobject_from(self, shape: se.Shape | se.Text) -> VMobject | None: + shape_class_to_func_map: dict[ + type, Callable[[se.Shape | se.Text], VMobject] + ] = { se.Path: self.path_to_mobject, se.SimpleLine: self.line_to_mobject, se.Rect: self.rect_to_mobject, @@ -194,7 +199,10 @@ class SVGMobject(VMobject): return None @staticmethod - def apply_style_to_mobject(mob, shape): + def apply_style_to_mobject( + mob: VMobject, + shape: se.Shape | se.Text + ) -> VMobject: mob.set_style( stroke_width=shape.stroke_width, stroke_color=shape.stroke.hex, @@ -204,16 +212,16 @@ class SVGMobject(VMobject): ) return mob - def path_to_mobject(self, path): + def path_to_mobject(self, path: se.Path) -> VMobjectFromSVGPath: return VMobjectFromSVGPath(path, **self.path_string_config) - def line_to_mobject(self, line): + def line_to_mobject(self, line: se.Line) -> Line: return Line( start=_convert_point_to_3d(line.x1, line.y1), end=_convert_point_to_3d(line.x2, line.y2) ) - def rect_to_mobject(self, rect): + def rect_to_mobject(self, rect: se.Rect) -> Rectangle | RoundedRectangle: if rect.rx == 0 or rect.ry == 0: mob = Rectangle( width=rect.width, @@ -232,7 +240,7 @@ class SVGMobject(VMobject): )) return mob - def circle_to_mobject(self, circle): + def circle_to_mobject(self, circle: se.Circle) -> Circle: # svgelements supports `rx` & `ry` but `r` mob = Circle(radius=circle.rx) mob.shift(_convert_point_to_3d( @@ -240,7 +248,7 @@ class SVGMobject(VMobject): )) return mob - def ellipse_to_mobject(self, ellipse): + def ellipse_to_mobject(self, ellipse: se.Ellipse) -> Circle: mob = Circle(radius=ellipse.rx) mob.stretch_to_fit_height(2 * ellipse.ry) mob.shift(_convert_point_to_3d( @@ -248,21 +256,21 @@ class SVGMobject(VMobject): )) return mob - def polygon_to_mobject(self, polygon): + def polygon_to_mobject(self, polygon: se.Polygon) -> Polygon: points = [ _convert_point_to_3d(*point) for point in polygon ] return Polygon(*points) - def polyline_to_mobject(self, polyline): + def polyline_to_mobject(self, polyline: se.Polyline) -> Polyline: points = [ _convert_point_to_3d(*point) for point in polyline ] return Polyline(*points) - def text_to_mobject(self, text): + def text_to_mobject(self, text: se.Text): pass @@ -273,13 +281,13 @@ class VMobjectFromSVGPath(VMobject): "should_remove_null_curves": False, } - def __init__(self, path_obj, **kwargs): + def __init__(self, path_obj: se.Path, **kwargs): # Get rid of arcs path_obj.approximate_arcs_with_quads() self.path_obj = path_obj super().__init__(**kwargs) - def init_points(self): + def init_points(self) -> None: # After a given svg_path has been converted into points, the result # will be saved to a file so that future calls for the same path # don't need to retrace the same computation. @@ -305,7 +313,7 @@ class VMobjectFromSVGPath(VMobject): np.save(points_filepath, self.get_points()) np.save(tris_filepath, self.get_triangulation()) - def handle_commands(self): + def handle_commands(self) -> None: segment_class_to_func_map = { se.Move: (self.start_new_path, ("end",)), se.Close: (self.close_path, ()), diff --git a/manimlib/mobject/svg/tex_mobject.py b/manimlib/mobject/svg/tex_mobject.py index c81a781b..0e765c48 100644 --- a/manimlib/mobject/svg/tex_mobject.py +++ b/manimlib/mobject/svg/tex_mobject.py @@ -1,5 +1,9 @@ +from __future__ import annotations + +from typing import Iterable, Sequence, Union from functools import reduce import operator as op +import colour import re from manimlib.constants import * @@ -13,10 +17,13 @@ from manimlib.utils.tex_file_writing import get_tex_config from manimlib.utils.tex_file_writing import display_during_execution +ManimColor = Union[str, colour.Color, Sequence[float]] + SCALE_FACTOR_PER_FONT_POINT = 0.001 - -tex_string_with_color_to_mob_map = {} +tex_string_with_color_to_mob_map: dict[ + tuple[ManimColor, str], SVGMobject +] = {} class SingleStringTex(VMobject): @@ -31,7 +38,7 @@ class SingleStringTex(VMobject): "math_mode": True, } - def __init__(self, tex_string, **kwargs): + def __init__(self, tex_string: str, **kwargs): super().__init__(**kwargs) assert(isinstance(tex_string, str)) self.tex_string = tex_string @@ -66,7 +73,7 @@ class SingleStringTex(VMobject): self.set_flat_stroke(self.flat_stroke) return self - def get_tex_file_body(self, tex_string): + def get_tex_file_body(self, tex_string: str) -> str: new_tex = self.get_modified_expression(tex_string) if self.math_mode: new_tex = "\\begin{align*}\n" + new_tex + "\n\\end{align*}" @@ -79,10 +86,10 @@ class SingleStringTex(VMobject): new_tex ) - def get_modified_expression(self, tex_string): + def get_modified_expression(self, tex_string: str) -> str: return self.modify_special_strings(tex_string.strip()) - def modify_special_strings(self, tex): + def modify_special_strings(self, tex: str) -> str: tex = tex.strip() should_add_filler = reduce(op.or_, [ # Fraction line needs something to be over @@ -134,7 +141,7 @@ class SingleStringTex(VMobject): tex = "" return tex - def balance_braces(self, tex): + def balance_braces(self, tex: str) -> str: """ Makes Tex resiliant to unmatched braces """ @@ -154,7 +161,7 @@ class SingleStringTex(VMobject): tex += num_unclosed_brackets * "}" return tex - def get_tex(self): + def get_tex(self) -> str: return self.tex_string def organize_submobjects_left_to_right(self): @@ -169,7 +176,7 @@ class Tex(SingleStringTex): "tex_to_color_map": {}, } - def __init__(self, *tex_strings, **kwargs): + def __init__(self, *tex_strings: str, **kwargs): digest_config(self, kwargs) self.tex_strings = self.break_up_tex_strings(tex_strings) full_string = self.arg_separator.join(self.tex_strings) @@ -180,7 +187,7 @@ class Tex(SingleStringTex): if self.organize_left_to_right: self.organize_submobjects_left_to_right() - def break_up_tex_strings(self, tex_strings): + def break_up_tex_strings(self, tex_strings: Iterable[str]) -> Iterable[str]: # Separate out any strings specified in the isolate # or tex_to_color_map lists. substrings_to_isolate = [*self.isolate, *self.tex_to_color_map.keys()] @@ -228,7 +235,12 @@ class Tex(SingleStringTex): self.set_submobjects(new_submobjects) return self - def get_parts_by_tex(self, tex, substring=True, case_sensitive=True): + def get_parts_by_tex( + self, + tex: str, + substring: bool = True, + case_sensitive: bool = True + ) -> VGroup: def test(tex1, tex2): if not case_sensitive: tex1 = tex1.lower() @@ -243,27 +255,36 @@ class Tex(SingleStringTex): self.submobjects )) - def get_part_by_tex(self, tex, **kwargs): + def get_part_by_tex(self, tex: str, **kwargs) -> SingleStringTex | None: all_parts = self.get_parts_by_tex(tex, **kwargs) return all_parts[0] if all_parts else None - def set_color_by_tex(self, tex, color, **kwargs): + def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs): self.get_parts_by_tex(tex, **kwargs).set_color(color) return self - def set_color_by_tex_to_color_map(self, tex_to_color_map, **kwargs): + def set_color_by_tex_to_color_map( + self, + tex_to_color_map: dict[str, ManimColor], + **kwargs + ): for tex, color in list(tex_to_color_map.items()): self.set_color_by_tex(tex, color, **kwargs) return self - def index_of_part(self, part, start=0): + def index_of_part(self, part: SingleStringTex, start: int = 0) -> int: return self.submobjects.index(part, start) - def index_of_part_by_tex(self, tex, start=0, **kwargs): + def index_of_part_by_tex(self, tex: str, start: int = 0, **kwargs) -> int: part = self.get_part_by_tex(tex, **kwargs) return self.index_of_part(part, start) - def slice_by_tex(self, start_tex=None, stop_tex=None, **kwargs): + def slice_by_tex( + self, + start_tex: str | None = None, + stop_tex: str | None = None, + **kwargs + ) -> VGroup: if start_tex is None: start_index = 0 else: @@ -275,10 +296,10 @@ class Tex(SingleStringTex): stop_index = self.index_of_part_by_tex(stop_tex, start=start_index, **kwargs) return self[start_index:stop_index] - def sort_alphabetically(self): + def sort_alphabetically(self) -> None: self.submobjects.sort(key=lambda m: m.get_tex()) - def set_bstroke(self, color=BLACK, width=4): + def set_bstroke(self, color: ManimColor = BLACK, width: float = 4): self.set_stroke(color, width, background=True) return self @@ -297,7 +318,7 @@ class BulletedList(TexText): "alignment": "", } - def __init__(self, *items, **kwargs): + def __init__(self, *items: str, **kwargs): line_separated_items = [s + "\\\\" for s in items] TexText.__init__(self, *line_separated_items, **kwargs) for part in self: @@ -310,7 +331,7 @@ class BulletedList(TexText): buff=self.buff ) - def fade_all_but(self, index_or_string, opacity=0.5): + def fade_all_but(self, index_or_string: int | str, opacity: float = 0.5) -> None: arg = index_or_string if isinstance(arg, str): part = self.get_part_by_tex(arg) @@ -348,7 +369,7 @@ class Title(TexText): "underline_buff": MED_SMALL_BUFF, } - def __init__(self, *text_parts, **kwargs): + def __init__(self, *text_parts: str, **kwargs): TexText.__init__(self, *text_parts, **kwargs) self.scale(self.scale_factor) self.to_edge(UP) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index e412f08d..668cddbf 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -1,17 +1,19 @@ -import hashlib +from __future__ import annotations + import os import re import io -import typing -import xml.etree.ElementTree as ET +import hashlib import functools +from pathlib import Path +import xml.etree.ElementTree as ET +from contextlib import contextmanager +from typing import Iterable, Sequence, Union + import pygments import pygments.lexers import pygments.styles -from contextlib import contextmanager -from pathlib import Path - import manimpango from manimlib.logger import log from manimlib.constants import * @@ -23,6 +25,12 @@ from manimlib.utils.customization import get_customization from manimlib.utils.directories import get_downloads_dir, get_text_dir from manimpango import PangoUtils, TextSetting, MarkupUtils +from typing import TYPE_CHECKING +if TYPE_CHECKING: + import colour + from manimlib.mobject.types.vectorized_mobject import VMobject + ManimColor = Union[str, colour.Color, Sequence[float]] + TEXT_MOB_SCALE_FACTOR = 0.0076 DEFAULT_LINE_SPACING_SCALE = 0.6 @@ -50,7 +58,7 @@ class Text(SVGMobject): "disable_ligatures": True, } - def __init__(self, text, **kwargs): + def __init__(self, text: str, **kwargs): self.full2short(kwargs) digest_config(self, kwargs) if self.size: @@ -60,9 +68,9 @@ class Text(SVGMobject): ) self.font_size = self.size if self.lsh == -1: - self.lsh = self.font_size + self.font_size * DEFAULT_LINE_SPACING_SCALE + self.lsh: float = self.font_size + self.font_size * DEFAULT_LINE_SPACING_SCALE else: - self.lsh = self.font_size + self.font_size * self.lsh + self.lsh: float = self.font_size + self.font_size * self.lsh text_without_tabs = text if text.find('\t') != -1: text_without_tabs = text.replace('\t', ' ' * self.tab_width) @@ -87,14 +95,14 @@ class Text(SVGMobject): if self.height is None: self.scale(TEXT_MOB_SCALE_FACTOR) - def remove_empty_path(self, file_name): + def remove_empty_path(self, file_name: str) -> None: with open(file_name, 'r') as fpr: content = fpr.read() content = re.sub(r'', '', content) with open(file_name, 'w') as fpw: fpw.write(content) - def apply_space_chars(self): + def apply_space_chars(self) -> None: submobs = self.submobjects.copy() for char_index in range(len(self.text)): if self.text[char_index] in [" ", "\t", "\n"]: @@ -103,7 +111,7 @@ class Text(SVGMobject): submobs.insert(char_index, space) self.set_submobjects(submobs) - def find_indexes(self, word): + def find_indexes(self, word: str) -> list[tuple[int, int]]: m = re.match(r'\[([0-9\-]{0,}):([0-9\-]{0,})\]', word) if m: start = int(m.group(1)) if m.group(1) != '' else 0 @@ -119,20 +127,20 @@ class Text(SVGMobject): index = self.text.find(word, index + len(word)) return indexes - def get_parts_by_text(self, word): + def get_parts_by_text(self, word: str) -> VGroup: return VGroup(*( self[i:j] for i, j in self.find_indexes(word) )) - def get_part_by_text(self, word): + def get_part_by_text(self, word: str) -> VMobject | None: parts = self.get_parts_by_text(word) if len(parts) > 0: return parts[0] else: return None - def full2short(self, config): + def full2short(self, config: dict[str]) -> None: for kwargs in [config, self.CONFIG]: if kwargs.__contains__('line_spacing_height'): kwargs['lsh'] = kwargs.pop('line_spacing_height') @@ -147,19 +155,25 @@ class Text(SVGMobject): if kwargs.__contains__('text2weight'): kwargs['t2w'] = kwargs.pop('text2weight') - def set_color_by_t2c(self, t2c=None): + def set_color_by_t2c( + self, + t2c: dict[str, ManimColor] | None = None + ) -> None: t2c = t2c if t2c else self.t2c for word, color in t2c.items(): for start, end in self.find_indexes(word): self[start:end].set_color(color) - def set_color_by_t2g(self, t2g=None): + def set_color_by_t2g( + self, + t2g: dict[str, Iterable[ManimColor]] | None = None + ) -> None: t2g = t2g if t2g else self.t2g for word, gradient in t2g.items(): for start, end in self.find_indexes(word): self[start:end].set_color_by_gradient(*gradient) - def text2hash(self): + def text2hash(self) -> str: settings = self.font + self.slant + self.weight settings += str(self.t2f) + str(self.t2s) + str(self.t2w) settings += str(self.lsh) + str(self.font_size) @@ -168,7 +182,7 @@ class Text(SVGMobject): hasher.update(id_str.encode()) return hasher.hexdigest()[:16] - def text2settings(self): + def text2settings(self) -> list[TextSetting]: """ Substrings specified in t2f, t2s, t2w can occupy each other. For each category of style, a stack following first-in-last-out is constructed, @@ -227,7 +241,7 @@ class Text(SVGMobject): del self.line_num return settings - def text2svg(self): + def text2svg(self) -> str: # anti-aliasing size = self.font_size lsh = self.lsh @@ -503,7 +517,7 @@ class Code(Text): "char_width": None } - def __init__(self, code, **kwargs): + def __init__(self, code: str, **kwargs): self.full2short(kwargs) digest_config(self, kwargs) code = code.lstrip("\n") # avoid mismatches of character indices @@ -536,7 +550,7 @@ class Code(Text): if self.char_width is not None: self.set_monospace(self.char_width) - def set_monospace(self, char_width): + def set_monospace(self, char_width: float) -> None: current_char_index = 0 for i, char in enumerate(self.text): if char == "\n": @@ -548,7 +562,7 @@ class Code(Text): @contextmanager -def register_font(font_file: typing.Union[str, Path]): +def register_font(font_file: str | Path): """Temporarily add a font file to Pango's search path. This searches for the font_file at various places. The order it searches it described below. 1. Absolute path. From 3744844efa64697b2204df48f99d7b8b16d9fd7f Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Tue, 15 Feb 2022 11:35:22 +0800 Subject: [PATCH 17/27] fix: fix type hint of set_array_by_interpolation --- manimlib/utils/bezier.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/manimlib/utils/bezier.py b/manimlib/utils/bezier.py index c21b7350..fa33cc55 100644 --- a/manimlib/utils/bezier.py +++ b/manimlib/utils/bezier.py @@ -92,12 +92,12 @@ def interpolate(start: T, end: T, alpha: float) -> T: def set_array_by_interpolation( - arr: list[T], - arr1: list[T], - arr2: list[T], + arr: np.ndarray, + arr1: np.ndarray, + arr2: np.ndarray, alpha: float, - interp_func: Callable[[T, T, float], T] = interpolate -) -> list[T]: + interp_func: Callable[[np.ndarray, np.ndarray, float], np.ndarray] = interpolate +) -> np.ndarray: arr[:] = interp_func(arr1, arr2, alpha) return arr From 4c16bfc2c02a621f19751df3d6057e5be32174c9 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Tue, 15 Feb 2022 14:37:15 +0800 Subject: [PATCH 18/27] chore: add type hints to manimlib.mobject --- manimlib/mobject/boolean_ops.py | 17 +- manimlib/mobject/changing.py | 37 +++- manimlib/mobject/coordinate_systems.py | 264 ++++++++++++++++------- manimlib/mobject/functions.py | 30 ++- manimlib/mobject/geometry.py | 240 ++++++++++++++------- manimlib/mobject/interactive.py | 87 ++++---- manimlib/mobject/matrix.py | 51 +++-- manimlib/mobject/mobject_update_utils.py | 30 ++- manimlib/mobject/number_line.py | 49 +++-- manimlib/mobject/numbers.py | 33 +-- manimlib/mobject/probability.py | 89 +++++--- manimlib/mobject/shape_matchers.py | 35 +-- manimlib/mobject/three_dimensions.py | 42 ++-- manimlib/mobject/value_tracker.py | 17 +- manimlib/mobject/vector_field.py | 77 +++++-- 15 files changed, 737 insertions(+), 361 deletions(-) diff --git a/manimlib/mobject/boolean_ops.py b/manimlib/mobject/boolean_ops.py index 90d205ce..0d5d05b9 100644 --- a/manimlib/mobject/boolean_ops.py +++ b/manimlib/mobject/boolean_ops.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import pathops @@ -7,7 +9,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject # Boolean operations between 2D mobjects # Borrowed from from https://github.com/ManimCommunity/manim/ -def _convert_vmobject_to_skia_path(vmobject): +def _convert_vmobject_to_skia_path(vmobject: VMobject) -> pathops.Path: path = pathops.Path() subpaths = vmobject.get_subpaths_from_points(vmobject.get_all_points()) for subpath in subpaths: @@ -21,7 +23,10 @@ def _convert_vmobject_to_skia_path(vmobject): return path -def _convert_skia_path_to_vmobject(path, vmobject): +def _convert_skia_path_to_vmobject( + path: pathops.Path, + vmobject: VMobject +) -> VMobject: PathVerb = pathops.PathVerb current_path_start = np.array([0.0, 0.0, 0.0]) for path_verb, points in path: @@ -45,7 +50,7 @@ def _convert_skia_path_to_vmobject(path, vmobject): class Union(VMobject): - def __init__(self, *vmobjects, **kwargs): + def __init__(self, *vmobjects: VMobject, **kwargs): if len(vmobjects) < 2: raise ValueError("At least 2 mobjects needed for Union.") super().__init__(**kwargs) @@ -59,7 +64,7 @@ class Union(VMobject): class Difference(VMobject): - def __init__(self, subject, clip, **kwargs): + def __init__(self, subject: VMobject, clip: VMobject, **kwargs): super().__init__(**kwargs) outpen = pathops.Path() pathops.difference( @@ -71,7 +76,7 @@ class Difference(VMobject): class Intersection(VMobject): - def __init__(self, *vmobjects, **kwargs): + def __init__(self, *vmobjects: VMobject, **kwargs): if len(vmobjects) < 2: raise ValueError("At least 2 mobjects needed for Intersection.") super().__init__(**kwargs) @@ -94,7 +99,7 @@ class Intersection(VMobject): class Exclusion(VMobject): - def __init__(self, *vmobjects, **kwargs): + def __init__(self, *vmobjects: VMobject, **kwargs): if len(vmobjects) < 2: raise ValueError("At least 2 mobjects needed for Exclusion.") super().__init__(**kwargs) diff --git a/manimlib/mobject/changing.py b/manimlib/mobject/changing.py index febe4acb..76d92bab 100644 --- a/manimlib/mobject/changing.py +++ b/manimlib/mobject/changing.py @@ -1,4 +1,9 @@ +from __future__ import annotations + +from typing import Callable + import numpy as np + from manimlib.constants import BLUE_D from manimlib.constants import BLUE_B from manimlib.constants import BLUE_E @@ -20,10 +25,10 @@ class AnimatedBoundary(VGroup): "fade_rate_func": smooth, } - def __init__(self, vmobject, **kwargs): + def __init__(self, vmobject: VMobject, **kwargs): super().__init__(**kwargs) - self.vmobject = vmobject - self.boundary_copies = [ + self.vmobject: VMobject = vmobject + self.boundary_copies: list[VMobject] = [ vmobject.copy().set_style( stroke_width=0, fill_opacity=0 @@ -31,12 +36,12 @@ class AnimatedBoundary(VGroup): for x in range(2) ] self.add(*self.boundary_copies) - self.total_time = 0 + self.total_time: float = 0 self.add_updater( lambda m, dt: self.update_boundary_copies(dt) ) - def update_boundary_copies(self, dt): + def update_boundary_copies(self, dt: float) -> None: # Not actual time, but something which passes at # an altered rate to make the implementation below # cleaner @@ -67,7 +72,13 @@ class AnimatedBoundary(VGroup): self.total_time += dt - def full_family_become_partial(self, mob1, mob2, a, b): + def full_family_become_partial( + self, + mob1: VMobject, + mob2: VMobject, + a: float, + b: float + ): family1 = mob1.family_members_with_points() family2 = mob2.family_members_with_points() for sm1, sm2 in zip(family1, family2): @@ -84,14 +95,14 @@ class TracedPath(VMobject): "time_per_anchor": 1 / 15, } - def __init__(self, traced_point_func, **kwargs): + def __init__(self, traced_point_func: Callable[[], np.ndarray], **kwargs): super().__init__(**kwargs) self.traced_point_func = traced_point_func - self.time = 0 - self.traced_points = [] + self.time: float = 0 + self.traced_points: list[np.ndarray] = [] self.add_updater(lambda m, dt: m.update_path(dt)) - def update_path(self, dt): + def update_path(self, dt: float): if dt == 0: return self point = self.traced_point_func().copy() @@ -133,7 +144,11 @@ class TracingTail(TracedPath): "time_traced": 1.0, } - def __init__(self, mobject_or_func, **kwargs): + def __init__( + self, + mobject_or_func: Mobject | Callable[[], np.ndarray], + **kwargs + ): if isinstance(mobject_or_func, Mobject): func = mobject_or_func.get_center else: diff --git a/manimlib/mobject/coordinate_systems.py b/manimlib/mobject/coordinate_systems.py index 3ad01086..80355e08 100644 --- a/manimlib/mobject/coordinate_systems.py +++ b/manimlib/mobject/coordinate_systems.py @@ -1,6 +1,10 @@ -from abc import abstractmethod -import numpy as np +from __future__ import annotations + import numbers +from abc import abstractmethod +from typing import Type, TypeVar, Union, Callable, Iterable, Sequence + +import numpy as np from manimlib.constants import * from manimlib.mobject.functions import ParametricCurve @@ -18,6 +22,14 @@ from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import rotate_vector +from typing import TYPE_CHECKING +if TYPE_CHECKING: + import colour + from manimlib.mobject.mobject import Mobject + T = TypeVar("T", bound=Mobject) + ManimColor = Union[str, colour.Color, Sequence[float]] + + EPSILON = 1e-8 @@ -39,56 +51,77 @@ class CoordinateSystem(): self.x_range = np.array(self.default_x_range) self.y_range = np.array(self.default_y_range) - def coords_to_point(self, *coords): + @abstractmethod + def coords_to_point(self, *coords: float) -> np.ndarray: raise Exception("Not implemented") - def point_to_coords(self, point): + @abstractmethod + def point_to_coords(self, point: np.ndarray) -> tuple[float, ...]: raise Exception("Not implemented") - def c2p(self, *coords): + def c2p(self, *coords: float): """Abbreviation for coords_to_point""" return self.coords_to_point(*coords) - def p2c(self, point): + def p2c(self, point: np.ndarray): """Abbreviation for point_to_coords""" return self.point_to_coords(point) - def get_origin(self): + def get_origin(self) -> np.ndarray: return self.c2p(*[0] * self.dimension) @abstractmethod - def get_axes(self): + def get_axes(self) -> VGroup: raise Exception("Not implemented") @abstractmethod - def get_all_ranges(self): + def get_all_ranges(self) -> list[np.ndarray]: raise Exception("Not implemented") - def get_axis(self, index): + def get_axis(self, index: int) -> NumberLine: return self.get_axes()[index] - def get_x_axis(self): + def get_x_axis(self) -> NumberLine: return self.get_axis(0) - def get_y_axis(self): + def get_y_axis(self) -> NumberLine: return self.get_axis(1) - def get_z_axis(self): + def get_z_axis(self) -> NumberLine: return self.get_axis(2) - def get_x_axis_label(self, label_tex, edge=RIGHT, direction=DL, **kwargs): + def get_x_axis_label( + self, + label_tex: str, + edge: np.ndarray = RIGHT, + direction: np.ndarray = DL, + **kwargs + ) -> Tex: return self.get_axis_label( label_tex, self.get_x_axis(), edge, direction, **kwargs ) - def get_y_axis_label(self, label_tex, edge=UP, direction=DR, **kwargs): + def get_y_axis_label( + self, + label_tex: str, + edge: np.ndarray = UP, + direction: np.ndarray = DR, + **kwargs + ) -> Tex: return self.get_axis_label( label_tex, self.get_y_axis(), edge, direction, **kwargs ) - def get_axis_label(self, label_tex, axis, edge, direction, buff=MED_SMALL_BUFF): + def get_axis_label( + self, + label_tex: str, + axis: np.ndarray, + edge: np.ndarray, + direction: np.ndarray, + buff: float = MED_SMALL_BUFF + ) -> Tex: label = Tex(label_tex) label.next_to( axis.get_edge_center(edge), direction, @@ -97,30 +130,43 @@ class CoordinateSystem(): label.shift_onto_screen(buff=MED_SMALL_BUFF) return label - def get_axis_labels(self, x_label_tex="x", y_label_tex="y"): + def get_axis_labels( + self, + x_label_tex: str = "x", + y_label_tex: str = "y" + ) -> VGroup: self.axis_labels = VGroup( self.get_x_axis_label(x_label_tex), self.get_y_axis_label(y_label_tex), ) return self.axis_labels - def get_line_from_axis_to_point(self, index, point, - line_func=DashedLine, - color=GREY_A, - stroke_width=2): + def get_line_from_axis_to_point( + self, + index: int, + point: np.ndarray, + line_func: Type[T] = DashedLine, + color: ManimColor = GREY_A, + stroke_width: float = 2 + ) -> T: axis = self.get_axis(index) line = line_func(axis.get_projection(point), point) line.set_stroke(color, stroke_width) return line - def get_v_line(self, point, **kwargs): + def get_v_line(self, point: np.ndarray, **kwargs): return self.get_line_from_axis_to_point(0, point, **kwargs) - def get_h_line(self, point, **kwargs): + def get_h_line(self, point: np.ndarray, **kwargs): return self.get_line_from_axis_to_point(1, point, **kwargs) # Useful for graphing - def get_graph(self, function, x_range=None, **kwargs): + def get_graph( + self, + function: Callable[[float], float], + x_range: Sequence[float] | None = None, + **kwargs + ) -> ParametricCurve: t_range = np.array(self.x_range, dtype=float) if x_range is not None: t_range[:len(x_range)] = x_range @@ -139,7 +185,11 @@ class CoordinateSystem(): graph.x_range = x_range return graph - def get_parametric_curve(self, function, **kwargs): + def get_parametric_curve( + self, + function: Callable[[float], np.ndarray], + **kwargs + ) -> ParametricCurve: dim = self.dimension graph = ParametricCurve( lambda t: self.coords_to_point(*function(t)[:dim]), @@ -148,7 +198,11 @@ class CoordinateSystem(): graph.underlying_function = function return graph - def input_to_graph_point(self, x, graph): + def input_to_graph_point( + self, + x: float, + graph: ParametricCurve + ) -> np.ndarray | None: if hasattr(graph, "underlying_function"): return self.coords_to_point(x, graph.underlying_function(x)) else: @@ -165,19 +219,21 @@ class CoordinateSystem(): else: return None - def i2gp(self, x, graph): + def i2gp(self, x: float, graph: ParametricCurve) -> np.ndarray | None: """ Alias for input_to_graph_point """ return self.input_to_graph_point(x, graph) - def get_graph_label(self, - graph, - label="f(x)", - x=None, - direction=RIGHT, - buff=MED_SMALL_BUFF, - color=None): + def get_graph_label( + self, + graph: ParametricCurve, + label: str | Mobject = "f(x)", + x: float | None = None, + direction: np.ndarray = RIGHT, + buff: float = MED_SMALL_BUFF, + color: ManimColor | None = None + ) -> Tex | Mobject: if isinstance(label, str): label = Tex(label) if color is None: @@ -204,38 +260,56 @@ class CoordinateSystem(): label.shift_onto_screen() return label - def get_v_line_to_graph(self, x, graph, **kwargs): + def get_v_line_to_graph(self, x: float, graph: ParametricCurve, **kwargs): return self.get_v_line(self.i2gp(x, graph), **kwargs) - def get_h_line_to_graph(self, x, graph, **kwargs): + def get_h_line_to_graph(self, x: float, graph: ParametricCurve, **kwargs): return self.get_h_line(self.i2gp(x, graph), **kwargs) # For calculus - def angle_of_tangent(self, x, graph, dx=EPSILON): + def angle_of_tangent( + self, + x: float, + graph: ParametricCurve, + dx: float = EPSILON + ) -> float: p0 = self.input_to_graph_point(x, graph) p1 = self.input_to_graph_point(x + dx, graph) return angle_of_vector(p1 - p0) - def slope_of_tangent(self, x, graph, **kwargs): + def slope_of_tangent( + self, + x: float, + graph: ParametricCurve, + **kwargs + ) -> float: return np.tan(self.angle_of_tangent(x, graph, **kwargs)) - def get_tangent_line(self, x, graph, length=5, line_func=Line): + def get_tangent_line( + self, + x: float, + graph: ParametricCurve, + length: float = 5, + line_func: Type[T] = Line + ) -> T: line = line_func(LEFT, RIGHT) line.set_width(length) line.rotate(self.angle_of_tangent(x, graph)) line.move_to(self.input_to_graph_point(x, graph)) return line - def get_riemann_rectangles(self, - graph, - x_range=None, - dx=None, - input_sample_type="left", - stroke_width=1, - stroke_color=BLACK, - fill_opacity=1, - colors=(BLUE, GREEN), - show_signed_area=True): + def get_riemann_rectangles( + self, + graph: ParametricCurve, + x_range: Sequence[float] = None, + dx: float | None = None, + input_sample_type: str = "left", + stroke_width: float = 1, + stroke_color: ManimColor = BLACK, + fill_opacity: float = 1, + colors: Iterable[ManimColor] = (BLUE, GREEN), + show_signed_area: bool = True + ) -> VGroup: if x_range is None: x_range = self.x_range[:2] if dx is None: @@ -288,10 +362,12 @@ class Axes(VGroup, CoordinateSystem): "width": FRAME_WIDTH - 2, } - def __init__(self, - x_range=None, - y_range=None, - **kwargs): + def __init__( + self, + x_range: Sequence[float] | None = None, + y_range: Sequence[float] | None = None, + **kwargs + ): CoordinateSystem.__init__(self, **kwargs) VGroup.__init__(self, **kwargs) @@ -314,36 +390,43 @@ class Axes(VGroup, CoordinateSystem): self.add(*self.axes) self.center() - def create_axis(self, range_terms, axis_config, length): + def create_axis( + self, + range_terms: Sequence[float], + axis_config: dict[str], + length: float + ) -> NumberLine: new_config = merge_dicts_recursively(self.axis_config, axis_config) new_config["width"] = length axis = NumberLine(range_terms, **new_config) axis.shift(-axis.n2p(0)) return axis - def coords_to_point(self, *coords): + def coords_to_point(self, *coords: float) -> np.ndarray: origin = self.x_axis.number_to_point(0) result = origin.copy() for axis, coord in zip(self.get_axes(), coords): result += (axis.number_to_point(coord) - origin) return result - def point_to_coords(self, point): + def point_to_coords(self, point: np.ndarray) -> tuple[float, ...]: return tuple([ axis.point_to_number(point) for axis in self.get_axes() ]) - def get_axes(self): + def get_axes(self) -> VGroup: return self.axes - def get_all_ranges(self): + def get_all_ranges(self) -> list[Sequence[float]]: return [self.x_range, self.y_range] - def add_coordinate_labels(self, - x_values=None, - y_values=None, - **kwargs): + def add_coordinate_labels( + self, + x_values: Iterable[float] | None = None, + y_values: Iterable[float] | None = None, + **kwargs + ) -> VGroup: axes = self.get_axes() self.coordinate_labels = VGroup() for axis, values in zip(axes, [x_values, y_values]): @@ -367,7 +450,13 @@ class ThreeDAxes(Axes): "gloss": 0.5, } - def __init__(self, x_range=None, y_range=None, z_range=None, **kwargs): + def __init__( + self, + x_range: Sequence[float] | None = None, + y_range: Sequence[float] | None = None, + z_range: Sequence[float] | None = None, + **kwargs + ): Axes.__init__(self, x_range, y_range, **kwargs) if z_range is not None: self.z_range[:len(z_range)] = z_range @@ -390,7 +479,7 @@ class ThreeDAxes(Axes): for axis in self.axes: axis.insert_n_curves(self.num_axis_pieces - 1) - def get_all_ranges(self): + def get_all_ranges(self) -> list[Sequence[float]]: return [self.x_range, self.y_range, self.z_range] @@ -420,11 +509,16 @@ class NumberPlane(Axes): "make_smooth_after_applying_functions": True, } - def __init__(self, x_range=None, y_range=None, **kwargs): + def __init__( + self, + x_range: Sequence[float] | None = None, + y_range: Sequence[float] | None = None, + **kwargs + ): super().__init__(x_range, y_range, **kwargs) self.init_background_lines() - def init_background_lines(self): + def init_background_lines(self) -> None: if self.faded_line_style is None: style = dict(self.background_line_style) # For anything numerical, like stroke_width @@ -442,7 +536,7 @@ class NumberPlane(Axes): self.background_lines, ) - def get_lines(self): + def get_lines(self) -> tuple[VGroup, VGroup]: x_axis = self.get_x_axis() y_axis = self.get_y_axis() @@ -452,7 +546,11 @@ class NumberPlane(Axes): lines2 = VGroup(*x_lines2, *y_lines2) return lines1, lines2 - def get_lines_parallel_to_axis(self, axis1, axis2): + def get_lines_parallel_to_axis( + self, + axis1: NumberLine, + axis2: NumberLine + ) -> tuple[VGroup, VGroup]: freq = axis2.x_step ratio = self.faded_line_ratio line = Line(axis1.get_start(), axis1.get_end()) @@ -471,20 +569,20 @@ class NumberPlane(Axes): lines2.add(new_line) return lines1, lines2 - def get_x_unit_size(self): + def get_x_unit_size(self) -> float: return self.get_x_axis().get_unit_size() - def get_y_unit_size(self): + def get_y_unit_size(self) -> list: return self.get_x_axis().get_unit_size() - def get_axes(self): + def get_axes(self) -> VGroup: return self.axes - def get_vector(self, coords, **kwargs): + def get_vector(self, coords: Iterable[float], **kwargs) -> Arrow: kwargs["buff"] = 0 return Arrow(self.c2p(0, 0), self.c2p(*coords), **kwargs) - def prepare_for_nonlinear_transform(self, num_inserted_curves=50): + def prepare_for_nonlinear_transform(self, num_inserted_curves: int = 50): for mob in self.family_members_with_points(): num_curves = mob.get_num_curves() if num_inserted_curves > num_curves: @@ -499,27 +597,35 @@ class ComplexPlane(NumberPlane): "line_frequency": 1, } - def number_to_point(self, number): + def number_to_point(self, number: complex | float) -> np.ndarray: number = complex(number) return self.coords_to_point(number.real, number.imag) - def n2p(self, number): + def n2p(self, number: complex | float) -> np.ndarray: return self.number_to_point(number) - def point_to_number(self, point): + def point_to_number(self, point: np.ndarray) -> complex: x, y = self.point_to_coords(point) return complex(x, y) - def p2n(self, point): + def p2n(self, point: np.ndarray) -> complex: return self.point_to_number(point) - def get_default_coordinate_values(self, skip_first=True): + def get_default_coordinate_values( + self, + skip_first: bool = True + ) -> list[complex]: x_numbers = self.get_x_axis().get_tick_range()[1:] y_numbers = self.get_y_axis().get_tick_range()[1:] y_numbers = [complex(0, y) for y in y_numbers if y != 0] return [*x_numbers, *y_numbers] - def add_coordinate_labels(self, numbers=None, skip_first=True, **kwargs): + def add_coordinate_labels( + self, + numbers: list[complex] | None = None, + skip_first: bool = True, + **kwargs + ): if numbers is None: numbers = self.get_default_coordinate_values(skip_first) diff --git a/manimlib/mobject/functions.py b/manimlib/mobject/functions.py index 3677a119..2ded8b08 100644 --- a/manimlib/mobject/functions.py +++ b/manimlib/mobject/functions.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Callable, Sequence + from isosurfaces import plot_isoline from manimlib.constants import * @@ -14,7 +18,12 @@ class ParametricCurve(VMobject): "use_smoothing": True, } - def __init__(self, t_func, t_range=None, **kwargs): + def __init__( + self, + t_func: Callable[[float], np.ndarray], + t_range: Sequence[float] | None = None, + **kwargs + ): digest_config(self, kwargs) if t_range is not None: self.t_range[:len(t_range)] = t_range @@ -27,7 +36,7 @@ class ParametricCurve(VMobject): self.t_func = t_func VMobject.__init__(self, **kwargs) - def get_point_from_function(self, t): + def get_point_from_function(self, t: float) -> np.ndarray: return self.t_func(t) def init_points(self): @@ -55,7 +64,12 @@ class FunctionGraph(ParametricCurve): "x_range": [-8, 8, 0.25], } - def __init__(self, function, x_range=None, **kwargs): + def __init__( + self, + function: Callable[[float], float], + x_range: Sequence[float] | None = None, + **kwargs + ): digest_config(self, kwargs) self.function = function @@ -67,10 +81,10 @@ class FunctionGraph(ParametricCurve): super().__init__(parametric_function, self.x_range, **kwargs) - def get_function(self): + def get_function(self) -> Callable[[float], float]: return self.function - def get_point_from_function(self, x): + def get_point_from_function(self, x: float) -> np.ndarray: return self.t_func(x) @@ -83,7 +97,11 @@ class ImplicitFunction(VMobject): "use_smoothing": True } - def __init__(self, func, x_range=None, y_range=None, **kwargs): + def __init__( + self, + func: Callable[[float, float], float], + **kwargs + ): digest_config(self, kwargs) self.function = func super().__init__(**kwargs) diff --git a/manimlib/mobject/geometry.py b/manimlib/mobject/geometry.py index 14a15b27..b856238c 100644 --- a/manimlib/mobject/geometry.py +++ b/manimlib/mobject/geometry.py @@ -1,6 +1,11 @@ -import numpy as np +from __future__ import annotations + import math import numbers +from typing import Sequence, Union + +import colour +import numpy as np from manimlib.constants import * from manimlib.mobject.mobject import Mobject @@ -21,6 +26,8 @@ from manimlib.utils.space_ops import normalize from manimlib.utils.space_ops import rotate_vector from manimlib.utils.space_ops import rotation_matrix_transpose +ManimColor = Union[str, colour.Color, Sequence[float]] + DEFAULT_DOT_RADIUS = 0.08 DEFAULT_SMALL_DOT_RADIUS = 0.04 @@ -58,7 +65,7 @@ class TipableVMobject(VMobject): } # Adding, Creating, Modifying tips - def add_tip(self, at_start=False, **kwargs): + def add_tip(self, at_start: bool = False, **kwargs): """ Adds a tip to the TipableVMobject instance, recognising that the endpoints might need to be switched if it's @@ -71,7 +78,7 @@ class TipableVMobject(VMobject): self.add(tip) return self - def create_tip(self, at_start=False, **kwargs): + def create_tip(self, at_start: bool = False, **kwargs) -> ArrowTip: """ Stylises the tip, positions it spacially, and returns the newly instantiated tip to the caller. @@ -80,7 +87,7 @@ class TipableVMobject(VMobject): self.position_tip(tip, at_start) return tip - def get_unpositioned_tip(self, **kwargs): + def get_unpositioned_tip(self, **kwargs) -> ArrowTip: """ Returns a tip that has been stylistically configured, but has not yet been given a position in space. @@ -90,7 +97,7 @@ class TipableVMobject(VMobject): config.update(kwargs) return ArrowTip(**config) - def position_tip(self, tip, at_start=False): + def position_tip(self, tip: ArrowTip, at_start: bool = False) -> ArrowTip: # Last two control points, defining both # the end, and the tangency direction if at_start: @@ -103,7 +110,7 @@ class TipableVMobject(VMobject): tip.shift(anchor - tip.get_tip_point()) return tip - def reset_endpoints_based_on_tip(self, tip, at_start): + def reset_endpoints_based_on_tip(self, tip: ArrowTip, at_start: bool): if self.get_length() == 0: # Zero length, put_start_and_end_on wouldn't # work @@ -118,7 +125,7 @@ class TipableVMobject(VMobject): self.put_start_and_end_on(start, end) return self - def asign_tip_attr(self, tip, at_start): + def asign_tip_attr(self, tip: ArrowTip, at_start: bool): if at_start: self.start_tip = tip else: @@ -126,14 +133,14 @@ class TipableVMobject(VMobject): return self # Checking for tips - def has_tip(self): + def has_tip(self) -> bool: return hasattr(self, "tip") and self.tip in self - def has_start_tip(self): + def has_start_tip(self) -> bool: return hasattr(self, "start_tip") and self.start_tip in self # Getters - def pop_tips(self): + def pop_tips(self) -> VGroup: start, end = self.get_start_and_end() result = VGroup() if self.has_tip(): @@ -145,7 +152,7 @@ class TipableVMobject(VMobject): self.put_start_and_end_on(start, end) return result - def get_tips(self): + def get_tips(self) -> VGroup: """ Returns a VGroup (collection of VMobjects) containing the TipableVMObject instance's tips. @@ -157,7 +164,7 @@ class TipableVMobject(VMobject): result.add(self.start_tip) return result - def get_tip(self): + def get_tip(self) -> ArrowTip: """Returns the TipableVMobject instance's (first) tip, otherwise throws an exception.""" tips = self.get_tips() @@ -166,28 +173,28 @@ class TipableVMobject(VMobject): else: return tips[0] - def get_default_tip_length(self): + def get_default_tip_length(self) -> float: return self.tip_length - def get_first_handle(self): + def get_first_handle(self) -> np.ndarray: return self.get_points()[1] - def get_last_handle(self): + def get_last_handle(self) -> np.ndarray: return self.get_points()[-2] - def get_end(self): + def get_end(self) -> np.ndarray: if self.has_tip(): return self.tip.get_start() else: return VMobject.get_end(self) - def get_start(self): + def get_start(self) -> np.ndarray: if self.has_start_tip(): return self.start_tip.get_start() else: return VMobject.get_start(self) - def get_length(self): + def get_length(self) -> float: start, end = self.get_start_and_end() return get_norm(start - end) @@ -200,12 +207,17 @@ class Arc(TipableVMobject): "arc_center": ORIGIN, } - def __init__(self, start_angle=0, angle=TAU / 4, **kwargs): + def __init__( + self, + start_angle: float = 0, + angle: float = TAU / 4, + **kwargs + ): self.start_angle = start_angle self.angle = angle VMobject.__init__(self, **kwargs) - def init_points(self): + def init_points(self) -> None: self.set_points(Arc.create_quadratic_bezier_points( angle=self.angle, start_angle=self.start_angle, @@ -215,7 +227,11 @@ class Arc(TipableVMobject): self.shift(self.arc_center) @staticmethod - def create_quadratic_bezier_points(angle, start_angle=0, n_components=8): + def create_quadratic_bezier_points( + angle: float, + start_angle: float = 0, + n_components: int = 8 + ) -> np.ndarray: samples = np.array([ [np.cos(a), np.sin(a), 0] for a in np.linspace( @@ -233,7 +249,7 @@ class Arc(TipableVMobject): points[2::3] = samples[2::2] return points - def get_arc_center(self): + def get_arc_center(self) -> np.ndarray: """ Looks at the normals to the first two anchors, and finds their intersection points @@ -248,21 +264,27 @@ class Arc(TipableVMobject): n2 = rotate_vector(t2, TAU / 4) return find_intersection(a1, n1, a2, n2) - def get_start_angle(self): + def get_start_angle(self) -> float: angle = angle_of_vector(self.get_start() - self.get_arc_center()) return angle % TAU - def get_stop_angle(self): + def get_stop_angle(self) -> float: angle = angle_of_vector(self.get_end() - self.get_arc_center()) return angle % TAU - def move_arc_center_to(self, point): + def move_arc_center_to(self, point: np.ndarray): self.shift(point - self.get_arc_center()) return self class ArcBetweenPoints(Arc): - def __init__(self, start, end, angle=TAU / 4, **kwargs): + def __init__( + self, + start: np.ndarray, + end: np.ndarray, + angle: float = TAU / 4, + **kwargs + ): super().__init__(angle=angle, **kwargs) if angle == 0: self.set_points_as_corners([LEFT, RIGHT]) @@ -270,13 +292,23 @@ class ArcBetweenPoints(Arc): class CurvedArrow(ArcBetweenPoints): - def __init__(self, start_point, end_point, **kwargs): + def __init__( + self, + start_point: np.ndarray, + end_point: np.ndarray, + **kwargs + ): ArcBetweenPoints.__init__(self, start_point, end_point, **kwargs) self.add_tip() class CurvedDoubleArrow(CurvedArrow): - def __init__(self, start_point, end_point, **kwargs): + def __init__( + self, + start_point: np.ndarray, + end_point: np.ndarray, + **kwargs + ): CurvedArrow.__init__(self, start_point, end_point, **kwargs) self.add_tip(at_start=True) @@ -291,7 +323,13 @@ class Circle(Arc): def __init__(self, **kwargs): Arc.__init__(self, 0, TAU, **kwargs) - def surround(self, mobject, dim_to_match=0, stretch=False, buff=MED_SMALL_BUFF): + def surround( + self, + mobject: Mobject, + dim_to_match: int = 0, + stretch: bool = False, + buff: float = MED_SMALL_BUFF + ): # Ignores dim_to_match and stretch; result will always be a circle # TODO: Perhaps create an ellipse class to handle singele-dimension stretching @@ -299,13 +337,13 @@ class Circle(Arc): self.stretch((self.get_width() + 2 * buff) / self.get_width(), 0) self.stretch((self.get_height() + 2 * buff) / self.get_height(), 1) - def point_at_angle(self, angle): + def point_at_angle(self, angle: float) -> np.ndarray: start_angle = self.get_start_angle() return self.point_from_proportion( (angle - start_angle) / TAU ) - def get_radius(self): + def get_radius(self) -> float: return get_norm(self.get_start() - self.get_center()) @@ -317,7 +355,7 @@ class Dot(Circle): "color": WHITE } - def __init__(self, point=ORIGIN, **kwargs): + def __init__(self, point: np.ndarray = ORIGIN, **kwargs): super().__init__(arc_center=point, **kwargs) @@ -401,15 +439,26 @@ class Line(TipableVMobject): "path_arc": 0, } - def __init__(self, start=LEFT, end=RIGHT, **kwargs): + def __init__( + self, + start: np.ndarray = LEFT, + end: np.ndarray = RIGHT, + **kwargs + ): digest_config(self, kwargs) self.set_start_and_end_attrs(start, end) super().__init__(**kwargs) - def init_points(self): + def init_points(self) -> None: self.set_points_by_ends(self.start, self.end, self.buff, self.path_arc) - def set_points_by_ends(self, start, end, buff=0, path_arc=0): + def set_points_by_ends( + self, + start: np.ndarray, + end: np.ndarray, + buff: float = 0, + path_arc: float = 0 + ): vect = end - start dist = get_norm(vect) if np.isclose(dist, 0): @@ -438,11 +487,11 @@ class Line(TipableVMobject): self.set_points_as_corners([start, end]) return self - def set_path_arc(self, new_value): + def set_path_arc(self, new_value: float) -> None: self.path_arc = new_value self.init_points() - def set_start_and_end_attrs(self, start, end): + def set_start_and_end_attrs(self, start: np.ndarray, end: np.ndarray): # If either start or end are Mobjects, this # gives their centers rough_start = self.pointify(start) @@ -454,7 +503,11 @@ class Line(TipableVMobject): self.start = self.pointify(start, vect) self.end = self.pointify(end, -vect) - def pointify(self, mob_or_point, direction=None): + def pointify( + self, + mob_or_point: Mobject | np.ndarray, + direction: np.ndarray | None = None + ) -> np.ndarray: """ Take an argument passed into Line (or subclass) and turn it into a 3d point. @@ -471,7 +524,7 @@ class Line(TipableVMobject): result[:len(point)] = point return result - def put_start_and_end_on(self, start, end): + def put_start_and_end_on(self, start: np.ndarray, end: np.ndarray): curr_start, curr_end = self.get_start_and_end() if np.isclose(curr_start, curr_end).all(): # Handle null lines more gracefully @@ -479,16 +532,16 @@ class Line(TipableVMobject): return self return super().put_start_and_end_on(start, end) - def get_vector(self): + def get_vector(self) -> np.ndarray: return self.get_end() - self.get_start() - def get_unit_vector(self): + def get_unit_vector(self) -> np.ndarray: return normalize(self.get_vector()) - def get_angle(self): + def get_angle(self) -> float: return angle_of_vector(self.get_vector()) - def get_projection(self, point): + def get_projection(self, point: np.ndarray) -> np.ndarray: """ Return projection of a point onto the line """ @@ -496,10 +549,10 @@ class Line(TipableVMobject): start = self.get_start() return start + np.dot(point - start, unit_vect) * unit_vect - def get_slope(self): + def get_slope(self) -> float: return np.tan(self.get_angle()) - def set_angle(self, angle, about_point=None): + def set_angle(self, angle: float, about_point: np.ndarray | None = None): if about_point is None: about_point = self.get_start() self.rotate( @@ -508,7 +561,7 @@ class Line(TipableVMobject): ) return self - def set_length(self, length, **kwargs): + def set_length(self, length: float, **kwargs): self.scale(length / self.get_length(), **kwargs) return self @@ -532,35 +585,35 @@ class DashedLine(Line): self.clear_points() self.add(*dashes) - def calculate_num_dashes(self, positive_space_ratio): + def calculate_num_dashes(self, positive_space_ratio: float) -> int: try: full_length = self.dash_length / positive_space_ratio return int(np.ceil(self.get_length() / full_length)) except ZeroDivisionError: return 1 - def calculate_positive_space_ratio(self): + def calculate_positive_space_ratio(self) -> float: return fdiv( self.dash_length, self.dash_length + self.dash_spacing, ) - def get_start(self): + def get_start(self) -> np.ndarray: if len(self.submobjects) > 0: return self.submobjects[0].get_start() else: return Line.get_start(self) - def get_end(self): + def get_end(self) -> np.ndarray: if len(self.submobjects) > 0: return self.submobjects[-1].get_end() else: return Line.get_end(self) - def get_first_handle(self): + def get_first_handle(self) -> np.ndarray: return self.submobjects[0].get_points()[1] - def get_last_handle(self): + def get_last_handle(self) -> np.ndarray: return self.submobjects[-1].get_points()[-2] @@ -570,7 +623,7 @@ class TangentLine(Line): "d_alpha": 1e-6 } - def __init__(self, vmob, alpha, **kwargs): + def __init__(self, vmob: VMobject, alpha: float, **kwargs): digest_config(self, kwargs) da = self.d_alpha a1 = clip(alpha - da, 0, 1) @@ -603,16 +656,22 @@ class Arrow(Line): "buff": 0.25, } - def set_points_by_ends(self, start, end, buff=0, path_arc=0): + def set_points_by_ends( + self, + start: np.ndarray, + end: np.ndarray, + buff: float = 0, + path_arc: float = 0 + ): super().set_points_by_ends(start, end, buff, path_arc) self.insert_tip_anchor() return self - def init_colors(self): + def init_colors(self) -> None: super().init_colors() self.create_tip_with_stroke_width() - def get_arc_length(self): + def get_arc_length(self) -> float: # Push up into Line? arc_len = get_norm(self.get_vector()) if self.path_arc > 0: @@ -655,14 +714,19 @@ class Arrow(Line): self.create_tip_with_stroke_width() return self - def set_stroke(self, color=None, width=None, *args, **kwargs): + def set_stroke( + self, + color: ManimColor | None = None, + width: float | None = None, + *args, **kwargs + ): super().set_stroke(color=color, width=width, *args, **kwargs) if isinstance(width, numbers.Number): self.max_stroke_width = width self.reset_tip() return self - def _handle_scale_side_effects(self, scale_factor): + def _handle_scale_side_effects(self, scale_factor: float): return self.reset_tip() @@ -679,7 +743,13 @@ class FillArrow(Line): "max_width_to_length_ratio": 0.1, } - def set_points_by_ends(self, start, end, buff=0, path_arc=0): + def set_points_by_ends( + self, + start: np.ndarray, + end: np.ndarray, + buff: float = 0, + path_arc: float = 0 + ) -> None: # Find the right tip length and thickness vect = end - start length = max(get_norm(vect), 1e-8) @@ -748,15 +818,15 @@ class FillArrow(Line): ) return self - def get_start(self): + def get_start(self) -> np.ndarray: nppc = self.n_points_per_curve points = self.get_points() return (points[0] + points[-nppc]) / 2 - def get_end(self): + def get_end(self) -> np.ndarray: return self.get_points()[self.tip_index] - def put_start_and_end_on(self, start, end): + def put_start_and_end_on(self, start: np.ndarray, end: np.ndarray): self.set_points_by_ends(start, end, buff=0, path_arc=self.path_arc) return self @@ -765,12 +835,12 @@ class FillArrow(Line): self.reset_points_around_ends() return self - def set_thickness(self, thickness): + def set_thickness(self, thickness: float): self.thickness = thickness self.reset_points_around_ends() return self - def set_path_arc(self, path_arc): + def set_path_arc(self, path_arc: float): self.path_arc = path_arc self.reset_points_around_ends() return self @@ -781,7 +851,7 @@ class Vector(Arrow): "buff": 0, } - def __init__(self, direction=RIGHT, **kwargs): + def __init__(self, direction: np.ndarray = RIGHT, **kwargs): if len(direction) == 2: direction = np.hstack([direction, 0]) super().__init__(ORIGIN, direction, **kwargs) @@ -794,24 +864,31 @@ class DoubleArrow(Arrow): class CubicBezier(VMobject): - def __init__(self, a0, h0, h1, a1, **kwargs): + def __init__( + self, + a0: np.ndarray, + h0: np.ndarray, + h1: np.ndarray, + a1: np.ndarray, + **kwargs + ): VMobject.__init__(self, **kwargs) self.add_cubic_bezier_curve(a0, h0, h1, a1) class Polygon(VMobject): - def __init__(self, *vertices, **kwargs): + def __init__(self, *vertices: np.ndarray, **kwargs): self.vertices = vertices super().__init__(**kwargs) - def init_points(self): + def init_points(self) -> None: verts = self.vertices self.set_points_as_corners([*verts, verts[0]]) - def get_vertices(self): + def get_vertices(self) -> list[np.ndarray]: return self.get_start_anchors() - def round_corners(self, radius=0.5): + def round_corners(self, radius: float = 0.5): vertices = self.get_vertices() arcs = [] for v1, v2, v3 in adjacent_n_tuples(vertices, 3): @@ -850,7 +927,7 @@ class Polygon(VMobject): class Polyline(Polygon): - def init_points(self): + def init_points(self) -> None: self.set_points_as_corners(self.vertices) @@ -859,7 +936,7 @@ class RegularPolygon(Polygon): "start_angle": None, } - def __init__(self, n=6, **kwargs): + def __init__(self, n: int = 6, **kwargs): digest_config(self, kwargs, locals()) if self.start_angle is None: # 0 for odd, 90 for even @@ -898,19 +975,19 @@ class ArrowTip(Triangle): self.data["points"] = Dot().set_width(h).get_points() self.rotate(self.angle) - def get_base(self): + def get_base(self) -> np.ndarray: return self.point_from_proportion(0.5) - def get_tip_point(self): + def get_tip_point(self) -> np.ndarray: return self.get_points()[0] - def get_vector(self): + def get_vector(self) -> np.ndarray: return self.get_tip_point() - self.get_base() - def get_angle(self): + def get_angle(self) -> float: return angle_of_vector(self.get_vector()) - def get_length(self): + def get_length(self) -> float: return get_norm(self.get_vector()) @@ -923,7 +1000,12 @@ class Rectangle(Polygon): "close_new_points": True, } - def __init__(self, width=None, height=None, **kwargs): + def __init__( + self, + width: float | None = None, + height: float | None = None, + **kwargs + ): Polygon.__init__(self, UR, UL, DL, DR, **kwargs) if width is None: @@ -936,7 +1018,7 @@ class Rectangle(Polygon): class Square(Rectangle): - def __init__(self, side_length=2.0, **kwargs): + def __init__(self, side_length: float = 2.0, **kwargs): self.side_length = side_length super().__init__(side_length, side_length, **kwargs) diff --git a/manimlib/mobject/interactive.py b/manimlib/mobject/interactive.py index 617449be..b50425ef 100644 --- a/manimlib/mobject/interactive.py +++ b/manimlib/mobject/interactive.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Callable + import numpy as np from pyglet.window import key as PygletWindowKeys @@ -21,8 +25,7 @@ class MotionMobject(Mobject): """ You could hold and drag this object to any position """ - - def __init__(self, mobject, **kwargs): + def __init__(self, mobject: Mobject, **kwargs): super().__init__(**kwargs) assert(isinstance(mobject, Mobject)) self.mobject = mobject @@ -31,7 +34,7 @@ class MotionMobject(Mobject): self.mobject.add_updater(lambda mob: None) self.add(mobject) - def mob_on_mouse_drag(self, mob, event_data): + def mob_on_mouse_drag(self, mob: Mobject, event_data: dict[str, np.ndarray]) -> bool: mob.move_to(event_data["point"]) return False @@ -43,7 +46,7 @@ class Button(Mobject): The on_click method takes mobject as argument like updater """ - def __init__(self, mobject, on_click, **kwargs): + def __init__(self, mobject: Mobject, on_click: Callable[[Mobject]], **kwargs): super().__init__(**kwargs) assert(isinstance(mobject, Mobject)) self.on_click = on_click @@ -51,7 +54,7 @@ class Button(Mobject): self.mobject.add_mouse_press_listner(self.mob_on_mouse_press) self.add(self.mobject) - def mob_on_mouse_press(self, mob, event_data): + def mob_on_mouse_press(self, mob: Mobject, event_data) -> bool: self.on_click(mob) return False @@ -59,7 +62,7 @@ class Button(Mobject): # Controls class ControlMobject(ValueTracker): - def __init__(self, value, *mobjects, **kwargs): + def __init__(self, value: float, *mobjects: Mobject, **kwargs): super().__init__(value=value, **kwargs) self.add(*mobjects) @@ -67,7 +70,7 @@ class ControlMobject(ValueTracker): self.add_updater(lambda mob: None) self.fix_in_frame() - def set_value(self, value): + def set_value(self, value: float): self.assert_value(value) self.set_value_anim(value) return ValueTracker.set_value(self, value) @@ -93,25 +96,25 @@ class EnableDisableButton(ControlMobject): "disable_color": RED } - def __init__(self, value=True, **kwargs): + def __init__(self, value: bool = True, **kwargs): digest_config(self, kwargs) self.box = Rectangle(**self.rect_kwargs) super().__init__(value, self.box, **kwargs) self.add_mouse_press_listner(self.on_mouse_press) - def assert_value(self, value): + def assert_value(self, value: bool) -> None: assert(isinstance(value, bool)) - def set_value_anim(self, value): + def set_value_anim(self, value: bool) -> None: if value: self.box.set_fill(self.enable_color) else: self.box.set_fill(self.disable_color) - def toggle_value(self): + def toggle_value(self) -> None: super().set_value(not self.get_value()) - def on_mouse_press(self, mob, event_data): + def on_mouse_press(self, mob: Mobject, event_data) -> bool: mob.toggle_value() return False @@ -136,32 +139,32 @@ class Checkbox(ControlMobject): "box_content_buff": SMALL_BUFF } - def __init__(self, value=True, **kwargs): + def __init__(self, value: bool = True, **kwargs): digest_config(self, kwargs) self.box = Rectangle(**self.rect_kwargs) self.box_content = self.get_checkmark() if value else self.get_cross() super().__init__(value, self.box, self.box_content, **kwargs) self.add_mouse_press_listner(self.on_mouse_press) - def assert_value(self, value): + def assert_value(self, value: bool) -> None: assert(isinstance(value, bool)) - def toggle_value(self): + def toggle_value(self) -> None: super().set_value(not self.get_value()) - def set_value_anim(self, value): + def set_value_anim(self, value: bool) -> None: if value: self.box_content.become(self.get_checkmark()) else: self.box_content.become(self.get_cross()) - def on_mouse_press(self, mob, event_data): + def on_mouse_press(self, mob: Mobject, event_data) -> None: mob.toggle_value() return False # Helper methods - def get_checkmark(self): + def get_checkmark(self) -> VGroup: checkmark = VGroup( Line(UP / 2 + 2 * LEFT, DOWN + LEFT, **self.checkmark_kwargs), Line(DOWN + LEFT, UP + RIGHT, **self.checkmark_kwargs) @@ -173,7 +176,7 @@ class Checkbox(ControlMobject): checkmark.move_to(self.box) return checkmark - def get_cross(self): + def get_cross(self) -> VGroup: cross = VGroup( Line(UP + LEFT, DOWN + RIGHT, **self.cross_kwargs), Line(UP + RIGHT, DOWN + LEFT, **self.cross_kwargs) @@ -206,7 +209,7 @@ class LinearNumberSlider(ControlMobject): } } - def __init__(self, value=0, **kwargs): + def __init__(self, value: float = 0, **kwargs): digest_config(self, kwargs) self.bar = RoundedRectangle(**self.rounded_rect_kwargs) self.slider = Circle(**self.circle_kwargs) @@ -219,22 +222,22 @@ class LinearNumberSlider(ControlMobject): self.slider.add_mouse_drag_listner(self.slider_on_mouse_drag) - super().__init__(value, self.bar, self.slider, self.slider_axis, ** kwargs) + super().__init__(value, self.bar, self.slider, self.slider_axis, **kwargs) - def assert_value(self, value): + def assert_value(self, value: float) -> None: assert(self.min_value <= value <= self.max_value) - def set_value_anim(self, value): + def set_value_anim(self, value: float) -> None: prop = (value - self.min_value) / (self.max_value - self.min_value) self.slider.move_to(self.slider_axis.point_from_proportion(prop)) - def slider_on_mouse_drag(self, mob, event_data): + def slider_on_mouse_drag(self, mob, event_data: dict[str, np.ndarray]) -> bool: self.set_value(self.get_value_from_point(event_data["point"])) return False # Helper Methods - def get_value_from_point(self, point): + def get_value_from_point(self, point: np.ndarray) -> float: start, end = self.slider_axis.get_start_and_end() point_on_line = get_closest_point_on_line(start, end, point) prop = get_norm(point_on_line - start) / get_norm(end - start) @@ -300,7 +303,7 @@ class ColorSliders(Group): self.arrange(DOWN) - def get_background(self): + def get_background(self) -> VGroup: single_square_len = self.background_grid_kwargs["single_square_len"] colors = self.background_grid_kwargs["colors"] width = self.rect_kwargs["width"] @@ -322,24 +325,24 @@ class ColorSliders(Group): return grid - def set_value(self, r, g, b, a): + def set_value(self, r: float, g: float, b: float, a: float): self.r_slider.set_value(r) self.g_slider.set_value(g) self.b_slider.set_value(b) self.a_slider.set_value(a) - def get_value(self): + def get_value(self) -> np.ndarary: r = self.r_slider.get_value() / 255 g = self.g_slider.get_value() / 255 b = self.b_slider.get_value() / 255 alpha = self.a_slider.get_value() return color_to_rgba(rgb_to_color((r, g, b)), alpha=alpha) - def get_picked_color(self): + def get_picked_color(self) -> str: rgba = self.get_value() return rgb_to_hex(rgba[:3]) - def get_picked_opacity(self): + def get_picked_opacity(self) -> float: rgba = self.get_value() return rgba[3] @@ -363,7 +366,7 @@ class Textbox(ControlMobject): "deactive_color": RED, } - def __init__(self, value="", **kwargs): + def __init__(self, value: str = "", **kwargs): digest_config(self, kwargs) self.isActive = self.isInitiallyActive self.box = Rectangle(**self.box_kwargs) @@ -374,10 +377,10 @@ class Textbox(ControlMobject): self.active_anim(self.isActive) self.add_key_press_listner(self.on_key_press) - def set_value_anim(self, value): + def set_value_anim(self, value: str) -> None: self.update_text(value) - def update_text(self, value): + def update_text(self, value: str) -> None: text = self.text self.remove(text) text.__init__(value, **self.text_kwargs) @@ -389,18 +392,18 @@ class Textbox(ControlMobject): text.fix_in_frame() self.add(text) - def active_anim(self, isActive): + def active_anim(self, isActive: bool) -> None: if isActive: self.box.set_stroke(self.active_color) else: self.box.set_stroke(self.deactive_color) - def box_on_mouse_press(self, mob, event_data): + def box_on_mouse_press(self, mob, event_data) -> bool: self.isActive = not self.isActive self.active_anim(self.isActive) return False - def on_key_press(self, mob, event_data): + def on_key_press(self, mob: Mobject, event_data: dict[str, int]) -> bool | None: symbol = event_data["symbol"] modifiers = event_data["modifiers"] char = chr(symbol) @@ -443,7 +446,7 @@ class ControlPanel(Group): } } - def __init__(self, *controls, **kwargs): + def __init__(self, *controls: ControlMobject, **kwargs): digest_config(self, kwargs) self.panel = Rectangle(**self.panel_kwargs) @@ -472,7 +475,7 @@ class ControlPanel(Group): self.move_panel_and_controls_to_panel_opener() self.fix_in_frame() - def move_panel_and_controls_to_panel_opener(self): + def move_panel_and_controls_to_panel_opener(self) -> None: self.panel.next_to( self.panel_opener_rect, direction=UP, @@ -488,11 +491,11 @@ class ControlPanel(Group): self.controls.set_x(controls_old_x) - def add_controls(self, *new_controls): + def add_controls(self, *new_controls: ControlMobject) -> None: self.controls.add(*new_controls) self.move_panel_and_controls_to_panel_opener() - def remove_controls(self, *controls_to_remove): + def remove_controls(self, *controls_to_remove: ControlMobject) -> None: self.controls.remove(*controls_to_remove) self.move_panel_and_controls_to_panel_opener() @@ -510,13 +513,13 @@ class ControlPanel(Group): self.move_panel_and_controls_to_panel_opener() return self - def panel_opener_on_mouse_drag(self, mob, event_data): + def panel_opener_on_mouse_drag(self, mob, event_data: dict[str, np.ndarray]) -> bool: point = event_data["point"] self.panel_opener.match_y(Dot(point)) self.move_panel_and_controls_to_panel_opener() return False - def panel_on_mouse_scroll(self, mob, event_data): + def panel_on_mouse_scroll(self, mob, event_data: dict[str, np.ndarray]) -> bool: offset = event_data["offset"] factor = 10 * offset[1] self.controls.set_y(self.controls.get_y() + factor) diff --git a/manimlib/mobject/matrix.py b/manimlib/mobject/matrix.py index 18a22b20..3133d992 100644 --- a/manimlib/mobject/matrix.py +++ b/manimlib/mobject/matrix.py @@ -1,5 +1,10 @@ -import numpy as np +from __future__ import annotations + import itertools as it +from typing import Union, Sequence + +import numpy as np +import numpy.typing as npt from manimlib.constants import * from manimlib.mobject.numbers import DecimalNumber @@ -10,10 +15,17 @@ from manimlib.mobject.svg.tex_mobject import TexText from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject +from typing import TYPE_CHECKING +if TYPE_CHECKING: + import colour + from manimlib.mobject.mobject import Mobject + ManimColor = Union[str, colour.Color, Sequence[float]] + + VECTOR_LABEL_SCALE_FACTOR = 0.8 -def matrix_to_tex_string(matrix): +def matrix_to_tex_string(matrix: npt.ArrayLike) -> str: matrix = np.array(matrix).astype("str") if matrix.ndim == 1: matrix = matrix.reshape((matrix.size, 1)) @@ -27,12 +39,16 @@ def matrix_to_tex_string(matrix): return prefix + " \\\\ ".join(rows) + suffix -def matrix_to_mobject(matrix): +def matrix_to_mobject(matrix: npt.ArrayLike) -> Tex: return Tex(matrix_to_tex_string(matrix)) -def vector_coordinate_label(vector_mob, integer_labels=True, - n_dim=2, color=WHITE): +def vector_coordinate_label( + vector_mob: VMobject, + integer_labels: bool = True, + n_dim: int = 2, + color: ManimColor = WHITE +) -> Matrix: vect = np.array(vector_mob.get_end()) if integer_labels: vect = np.round(vect).astype(int) @@ -66,7 +82,7 @@ class Matrix(VMobject): "element_alignment_corner": DOWN, } - def __init__(self, matrix, **kwargs): + def __init__(self, matrix: npt.ArrayLike, **kwargs): """ Matrix can either include numbers, tex_strings, or mobjects @@ -87,7 +103,7 @@ class Matrix(VMobject): if self.include_background_rectangle: self.add_background_rectangle() - def matrix_to_mob_matrix(self, matrix): + def matrix_to_mob_matrix(self, matrix: npt.ArrayLike) -> list[list[Mobject]]: return [ [ self.element_to_mobject(item, **self.element_to_mobject_config) @@ -96,7 +112,7 @@ class Matrix(VMobject): for row in matrix ] - def organize_mob_matrix(self, matrix): + def organize_mob_matrix(self, matrix: npt.ArrayLike): for i, row in enumerate(matrix): for j, elem in enumerate(row): mob = matrix[i][j] @@ -126,19 +142,19 @@ class Matrix(VMobject): self.brackets = VGroup(l_bracket, r_bracket) return self - def get_columns(self): + def get_columns(self) -> VGroup: return VGroup(*[ VGroup(*[row[i] for row in self.mob_matrix]) for i in range(len(self.mob_matrix[0])) ]) - def get_rows(self): + def get_rows(self) -> VGroup: return VGroup(*[ VGroup(*row) for row in self.mob_matrix ]) - def set_column_colors(self, *colors): + def set_column_colors(self, *colors: ManimColor): columns = self.get_columns() for color, column in zip(colors, columns): column.set_color(color) @@ -149,13 +165,13 @@ class Matrix(VMobject): mob.add_background_rectangle() return self - def get_mob_matrix(self): + def get_mob_matrix(self) -> list[list[Mobject]]: return self.mob_matrix - def get_entries(self): + def get_entries(self) -> VGroup: return self.elements - def get_brackets(self): + def get_brackets(self) -> VGroup: return self.brackets @@ -179,7 +195,12 @@ class MobjectMatrix(Matrix): } -def get_det_text(matrix, determinant=None, background_rect=False, initial_scale_factor=2): +def get_det_text( + matrix: Matrix, + determinant: int | str | None = None, + background_rect: bool = False, + initial_scale_factor: int = 2 +) -> VGroup: parens = Tex("(", ")") parens.scale(initial_scale_factor) parens.stretch_to_fit_height(matrix.get_height()) diff --git a/manimlib/mobject/mobject_update_utils.py b/manimlib/mobject/mobject_update_utils.py index 152ffd99..1a698888 100644 --- a/manimlib/mobject/mobject_update_utils.py +++ b/manimlib/mobject/mobject_update_utils.py @@ -1,10 +1,18 @@ +from __future__ import annotations + import inspect +from typing import Callable from manimlib.constants import DEGREES from manimlib.constants import RIGHT from manimlib.mobject.mobject import Mobject from manimlib.utils.simple_functions import clip +from typing import TYPE_CHECKING +if TYPE_CHECKING: + import numpy as np + from manimlib.animation.animation import Animation + def assert_is_mobject_method(method): assert(inspect.ismethod(method)) @@ -41,27 +49,39 @@ def f_always(method, *arg_generators, **kwargs): return mobject -def always_redraw(func, *args, **kwargs): +def always_redraw(func: Callable[..., Mobject], *args, **kwargs) -> Mobject: mob = func(*args, **kwargs) mob.add_updater(lambda m: mob.become(func(*args, **kwargs))) return mob -def always_shift(mobject, direction=RIGHT, rate=0.1): +def always_shift( + mobject: Mobject, + direction: np.ndarray = RIGHT, + rate: float = 0.1 +) -> Mobject: mobject.add_updater( lambda m, dt: m.shift(dt * rate * direction) ) return mobject -def always_rotate(mobject, rate=20 * DEGREES, **kwargs): +def always_rotate( + mobject: Mobject, + rate: float = 20 * DEGREES, + **kwargs +) -> Mobject: mobject.add_updater( lambda m, dt: m.rotate(dt * rate, **kwargs) ) return mobject -def turn_animation_into_updater(animation, cycle=False, **kwargs): +def turn_animation_into_updater( + animation: Animation, + cycle: bool = False, + **kwargs +) -> Mobject: """ Add an updater to the animation's mobject which applies the interpolation and update functions of the animation @@ -94,7 +114,7 @@ def turn_animation_into_updater(animation, cycle=False, **kwargs): return mobject -def cycle_animation(animation, **kwargs): +def cycle_animation(animation: Animation, **kwargs) -> Mobject: return turn_animation_into_updater( animation, cycle=True, **kwargs ) diff --git a/manimlib/mobject/number_line.py b/manimlib/mobject/number_line.py index cb9a04fa..f4e2a496 100644 --- a/manimlib/mobject/number_line.py +++ b/manimlib/mobject/number_line.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Iterable, Sequence + from manimlib.constants import * from manimlib.mobject.geometry import Line from manimlib.mobject.numbers import DecimalNumber @@ -38,7 +42,7 @@ class NumberLine(Line): "numbers_to_exclude": None } - def __init__(self, x_range=None, **kwargs): + def __init__(self, x_range: Sequence[float] | None = None, **kwargs): digest_config(self, kwargs) if x_range is None: x_range = self.x_range @@ -48,9 +52,9 @@ class NumberLine(Line): x_min, x_max, x_step = x_range # A lot of old scenes pass in x_min or x_max explicitly, # so this is just here to keep those workin - self.x_min = kwargs.get("x_min", x_min) - self.x_max = kwargs.get("x_max", x_max) - self.x_step = kwargs.get("x_step", x_step) + self.x_min: float = kwargs.get("x_min", x_min) + self.x_max: float = kwargs.get("x_max", x_max) + self.x_step: float = kwargs.get("x_step", x_step) super().__init__(self.x_min * RIGHT, self.x_max * RIGHT, **kwargs) if self.width: @@ -71,14 +75,14 @@ class NumberLine(Line): if self.include_numbers: self.add_numbers(excluding=self.numbers_to_exclude) - def get_tick_range(self): + def get_tick_range(self) -> np.ndarray: if self.include_tip: x_max = self.x_max else: x_max = self.x_max + self.x_step return np.arange(self.x_min, x_max, self.x_step) - def add_ticks(self): + def add_ticks(self) -> None: ticks = VGroup() for x in self.get_tick_range(): size = self.tick_size @@ -88,7 +92,7 @@ class NumberLine(Line): self.add(ticks) self.ticks = ticks - def get_tick(self, x, size=None): + def get_tick(self, x: float, size: float | None = None) -> Line: if size is None: size = self.tick_size result = Line(size * DOWN, size * UP) @@ -97,14 +101,14 @@ class NumberLine(Line): result.match_style(self) return result - def get_tick_marks(self): + def get_tick_marks(self) -> VGroup: return self.ticks - def number_to_point(self, number): + def number_to_point(self, number: float) -> np.ndarray: alpha = float(number - self.x_min) / (self.x_max - self.x_min) return interpolate(self.get_start(), self.get_end(), alpha) - def point_to_number(self, point): + def point_to_number(self, point: np.ndarray) -> float: points = self.get_points() start = points[0] end = points[-1] @@ -115,21 +119,24 @@ class NumberLine(Line): ) return interpolate(self.x_min, self.x_max, proportion) - def n2p(self, number): + def n2p(self, number: float) -> np.ndarray: """Abbreviation for number_to_point""" return self.number_to_point(number) - def p2n(self, point): + def p2n(self, point: np.ndarray) -> float: """Abbreviation for point_to_number""" return self.point_to_number(point) - def get_unit_size(self): + def get_unit_size(self) -> float: return self.get_length() / (self.x_max - self.x_min) - def get_number_mobject(self, x, - direction=None, - buff=None, - **number_config): + def get_number_mobject( + self, + x: float, + direction: np.ndarray | None = None, + buff: float | None = None, + **number_config + ) -> DecimalNumber: number_config = merge_dicts_recursively( self.decimal_number_config, number_config ) @@ -149,7 +156,13 @@ class NumberLine(Line): num_mob.shift(num_mob[0].get_width() * LEFT / 2) return num_mob - def add_numbers(self, x_values=None, excluding=None, font_size=24, **kwargs): + def add_numbers( + self, + x_values: Iterable[float] | None = None, + excluding: Iterable[float] | None =None, + font_size: int = 24, + **kwargs + ) -> VGroup: if x_values is None: x_values = self.get_tick_range() diff --git a/manimlib/mobject/numbers.py b/manimlib/mobject/numbers.py index 7a8d05ba..79167439 100644 --- a/manimlib/mobject/numbers.py +++ b/manimlib/mobject/numbers.py @@ -1,10 +1,15 @@ +from __future__ import annotations + +from typing import TypeVar, Type + from manimlib.constants import * from manimlib.mobject.svg.tex_mobject import SingleStringTex from manimlib.mobject.svg.text_mobject import Text from manimlib.mobject.types.vectorized_mobject import VMobject +T = TypeVar("T", bound=VMobject) -string_to_mob_map = {} +string_to_mob_map: dict[str, VMobject] = {} class DecimalNumber(VMobject): @@ -22,12 +27,12 @@ class DecimalNumber(VMobject): "font_size": 48, } - def __init__(self, number=0, **kwargs): + def __init__(self, number: float | complex = 0, **kwargs): super().__init__(**kwargs) self.set_submobjects_from_number(number) self.init_colors() - def set_submobjects_from_number(self, number): + def set_submobjects_from_number(self, number: float | complex) -> None: self.number = number self.set_submobjects([]) @@ -62,7 +67,7 @@ class DecimalNumber(VMobject): if self.include_background_rectangle: self.add_background_rectangle() - def get_num_string(self, number): + def get_num_string(self, number: float | complex) -> str: if isinstance(number, complex): formatter = self.get_complex_formatter() else: @@ -78,21 +83,21 @@ class DecimalNumber(VMobject): num_string = num_string.replace("-", "–") return num_string - def init_data(self): + def init_data(self) -> None: super().init_data() self.data["font_size"] = np.array([self.font_size], dtype=float) - def get_font_size(self): + def get_font_size(self) -> float: return self.data["font_size"][0] - def string_to_mob(self, string, mob_class=Text): + def string_to_mob(self, string: str, mob_class: Type[T] = Text) -> T: if string not in string_to_mob_map: string_to_mob_map[string] = mob_class(string, font_size=1) mob = string_to_mob_map[string].copy() mob.scale(self.get_font_size()) return mob - def get_formatter(self, **kwargs): + def get_formatter(self, **kwargs) -> str: """ Configuration is based first off instance attributes, but overwritten by any kew word argument. Relevant @@ -121,14 +126,14 @@ class DecimalNumber(VMobject): "}", ]) - def get_complex_formatter(self, **kwargs): + def get_complex_formatter(self, **kwargs) -> str: return "".join([ self.get_formatter(field_name="0.real"), self.get_formatter(field_name="0.imag", include_sign=True), "i" ]) - def set_value(self, number): + def set_value(self, number: float | complex): move_to_point = self.get_edge_center(self.edge_to_fix) old_submobjects = list(self.submobjects) self.set_submobjects_from_number(number) @@ -137,13 +142,13 @@ class DecimalNumber(VMobject): sm1.match_style(sm2) return self - def _handle_scale_side_effects(self, scale_factor): + def _handle_scale_side_effects(self, scale_factor: float) -> None: self.data["font_size"] *= scale_factor - def get_value(self): + def get_value(self) -> float | complex: return self.number - def increment_value(self, delta_t=1): + def increment_value(self, delta_t: float | complex = 1) -> None: self.set_value(self.get_value() + delta_t) @@ -152,5 +157,5 @@ class Integer(DecimalNumber): "num_decimal_places": 0, } - def get_value(self): + def get_value(self) -> int: return int(np.round(super().get_value())) diff --git a/manimlib/mobject/probability.py b/manimlib/mobject/probability.py index dd3dfd31..9f4bdeab 100644 --- a/manimlib/mobject/probability.py +++ b/manimlib/mobject/probability.py @@ -1,3 +1,8 @@ +from __future__ import annotations + +from typing import Iterable, Union, Sequence +import colour + from manimlib.constants import * from manimlib.mobject.geometry import Line from manimlib.mobject.geometry import Rectangle @@ -9,6 +14,8 @@ from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_gradient from manimlib.utils.iterables import listify +ManimColor = Union[str, colour.Color, Sequence[float]] + EPSILON = 0.0001 @@ -24,7 +31,11 @@ class SampleSpace(Rectangle): "default_label_scale_val": 1, } - def add_title(self, title="Sample space", buff=MED_SMALL_BUFF): + def add_title( + self, + title: str = "Sample space", + buff: float = MED_SMALL_BUFF + ) -> None: # TODO, should this really exist in SampleSpaceScene title_mob = TexText(title) if title_mob.get_width() > self.get_width(): @@ -33,17 +44,23 @@ class SampleSpace(Rectangle): self.title = title_mob self.add(title_mob) - def add_label(self, label): + def add_label(self, label: str) -> None: self.label = label - def complete_p_list(self, p_list): + def complete_p_list(self, p_list: list[float]) -> list[float]: new_p_list = listify(p_list) remainder = 1.0 - sum(new_p_list) if abs(remainder) > EPSILON: new_p_list.append(remainder) return new_p_list - def get_division_along_dimension(self, p_list, dim, colors, vect): + def get_division_along_dimension( + self, + p_list: list[float], + dim: int, + colors: Iterable[ManimColor], + vect: np.ndarray + ) -> VGroup: p_list = self.complete_p_list(p_list) colors = color_gradient(colors, len(p_list)) @@ -60,38 +77,41 @@ class SampleSpace(Rectangle): return parts def get_horizontal_division( - self, p_list, - colors=[GREEN_E, BLUE_E], - vect=DOWN - ): + self, + p_list: list[float], + colors: Iterable[ManimColor] = [GREEN_E, BLUE_E], + vect: np.ndarray = DOWN + ) -> VGroup: return self.get_division_along_dimension(p_list, 1, colors, vect) def get_vertical_division( - self, p_list, - colors=[MAROON_B, YELLOW], - vect=RIGHT - ): + self, + p_list: list[float], + colors: Iterable[ManimColor] = [MAROON_B, YELLOW], + vect: np.ndarray = RIGHT + ) -> VGroup: return self.get_division_along_dimension(p_list, 0, colors, vect) - def divide_horizontally(self, *args, **kwargs): + def divide_horizontally(self, *args, **kwargs) -> None: self.horizontal_parts = self.get_horizontal_division(*args, **kwargs) self.add(self.horizontal_parts) - def divide_vertically(self, *args, **kwargs): + def divide_vertically(self, *args, **kwargs) -> None: self.vertical_parts = self.get_vertical_division(*args, **kwargs) self.add(self.vertical_parts) def get_subdivision_braces_and_labels( - self, parts, labels, direction, - buff=SMALL_BUFF, - min_num_quads=1 - ): + self, + parts: VGroup, + labels: str, + direction: np.ndarray, + buff: float = SMALL_BUFF, + ) -> VGroup: label_mobs = VGroup() braces = VGroup() for label, part in zip(labels, parts): brace = Brace( part, direction, - min_num_quads=min_num_quads, buff=buff ) if isinstance(label, Mobject): @@ -112,22 +132,35 @@ class SampleSpace(Rectangle): } return VGroup(parts.braces, parts.labels) - def get_side_braces_and_labels(self, labels, direction=LEFT, **kwargs): + def get_side_braces_and_labels( + self, + labels: str, + direction: np.ndarray = LEFT, + **kwargs + ) -> VGroup: assert(hasattr(self, "horizontal_parts")) parts = self.horizontal_parts return self.get_subdivision_braces_and_labels(parts, labels, direction, **kwargs) - def get_top_braces_and_labels(self, labels, **kwargs): + def get_top_braces_and_labels( + self, + labels: str, + **kwargs + ) -> VGroup: assert(hasattr(self, "vertical_parts")) parts = self.vertical_parts return self.get_subdivision_braces_and_labels(parts, labels, UP, **kwargs) - def get_bottom_braces_and_labels(self, labels, **kwargs): + def get_bottom_braces_and_labels( + self, + labels: str, + **kwargs + ) -> VGroup: assert(hasattr(self, "vertical_parts")) parts = self.vertical_parts return self.get_subdivision_braces_and_labels(parts, labels, DOWN, **kwargs) - def add_braces_and_labels(self): + def add_braces_and_labels(self) -> None: for attr in "horizontal_parts", "vertical_parts": if not hasattr(self, attr): continue @@ -136,7 +169,7 @@ class SampleSpace(Rectangle): if hasattr(parts, subattr): self.add(getattr(parts, subattr)) - def __getitem__(self, index): + def __getitem__(self, index: int | slice) -> VGroup: if hasattr(self, "horizontal_parts"): return self.horizontal_parts[index] elif hasattr(self, "vertical_parts"): @@ -162,7 +195,7 @@ class BarChart(VGroup): "bar_label_scale_val": 0.75, } - def __init__(self, values, **kwargs): + def __init__(self, values: Iterable[float], **kwargs): VGroup.__init__(self, **kwargs) if self.max_value is None: self.max_value = max(values) @@ -172,7 +205,7 @@ class BarChart(VGroup): self.add_bars(values) self.center() - def add_axes(self): + def add_axes(self) -> None: x_axis = Line(self.tick_width * LEFT / 2, self.width * RIGHT) y_axis = Line(MED_LARGE_BUFF * DOWN, self.height * UP) y_ticks = VGroup() @@ -209,7 +242,7 @@ class BarChart(VGroup): self.y_axis_labels = labels self.add(labels) - def add_bars(self, values): + def add_bars(self, values: Iterable[float]) -> None: buff = float(self.width) / (2 * len(values)) bars = VGroup() for i, value in enumerate(values): @@ -234,7 +267,7 @@ class BarChart(VGroup): self.bars = bars self.bar_labels = bar_labels - def change_bar_values(self, values): + def change_bar_values(self, values: Iterable[float]) -> None: for bar, value in zip(self.bars, values): bar_bottom = bar.get_bottom() bar.stretch_to_fit_height( diff --git a/manimlib/mobject/shape_matchers.py b/manimlib/mobject/shape_matchers.py index 8a279c6c..62c93e34 100644 --- a/manimlib/mobject/shape_matchers.py +++ b/manimlib/mobject/shape_matchers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from manimlib.constants import * from manimlib.mobject.geometry import Line from manimlib.mobject.geometry import Rectangle @@ -7,6 +9,12 @@ from manimlib.utils.color import Color from manimlib.utils.customization import get_customization from manimlib.utils.config_ops import digest_config +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Union, Sequence + from manimlib.mobject.mobject import Mobject + ManimColor = Union[str, Color, Sequence[float]] + class SurroundingRectangle(Rectangle): CONFIG = { @@ -14,7 +22,7 @@ class SurroundingRectangle(Rectangle): "buff": SMALL_BUFF, } - def __init__(self, mobject, **kwargs): + def __init__(self, mobject: Mobject, **kwargs): digest_config(self, kwargs) kwargs["width"] = mobject.get_width() + 2 * self.buff kwargs["height"] = mobject.get_height() + 2 * self.buff @@ -30,23 +38,24 @@ class BackgroundRectangle(SurroundingRectangle): "buff": 0 } - def __init__(self, mobject, color=None, **kwargs): + def __init__(self, mobject: Mobject, color: ManimColor = None, **kwargs): if color is None: color = get_customization()['style']['background_color'] SurroundingRectangle.__init__(self, mobject, color=color, **kwargs) self.original_fill_opacity = self.fill_opacity - def pointwise_become_partial(self, mobject, a, b): + def pointwise_become_partial(self, mobject: Mobject, a: float, b: float): self.set_fill(opacity=b * self.original_fill_opacity) return self - def set_style_data(self, - stroke_color=None, - stroke_width=None, - fill_color=None, - fill_opacity=None, - family=True - ): + def set_style_data( + self, + stroke_color: ManimColor | None = None, + stroke_width: float | None = None, + fill_color: ManimColor | None = None, + fill_opacity: float | None = None, + family: bool = True + ): # Unchangeable style, except for fill_opacity VMobject.set_style_data( self, @@ -57,7 +66,7 @@ class BackgroundRectangle(SurroundingRectangle): ) return self - def get_fill_color(self): + def get_fill_color(self) -> Color: return Color(self.color) @@ -67,7 +76,7 @@ class Cross(VGroup): "stroke_width": [0, 6, 0], } - def __init__(self, mobject, **kwargs): + def __init__(self, mobject: Mobject, **kwargs): super().__init__( Line(UL, DR), Line(UR, DL), @@ -82,7 +91,7 @@ class Underline(Line): "buff": SMALL_BUFF, } - def __init__(self, mobject, **kwargs): + def __init__(self, mobject: Mobject, **kwargs): super().__init__(LEFT, RIGHT, **kwargs) self.match_width(mobject) self.next_to(mobject, DOWN, buff=self.buff) diff --git a/manimlib/mobject/three_dimensions.py b/manimlib/mobject/three_dimensions.py index 93083ba7..5d39fe7a 100644 --- a/manimlib/mobject/three_dimensions.py +++ b/manimlib/mobject/three_dimensions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import math from manimlib.constants import * @@ -23,13 +25,13 @@ class SurfaceMesh(VGroup): "flat_stroke": False, } - def __init__(self, uv_surface, **kwargs): + def __init__(self, uv_surface: Surface, **kwargs): if not isinstance(uv_surface, Surface): raise Exception("uv_surface must be of type Surface") self.uv_surface = uv_surface super().__init__(**kwargs) - def init_points(self): + def init_points(self) -> None: uv_surface = self.uv_surface full_nu, full_nv = uv_surface.resolution @@ -75,7 +77,7 @@ class Sphere(Surface): "v_range": (0, PI), } - def uv_func(self, u, v): + def uv_func(self, u: float, v: float) -> np.ndarray: return self.radius * np.array([ np.cos(u) * np.sin(v), np.sin(u) * np.sin(v), @@ -91,7 +93,7 @@ class Torus(Surface): "r2": 1, } - def uv_func(self, u, v): + def uv_func(self, u: float, v: float) -> np.ndarray: P = np.array([math.cos(u), math.sin(u), 0]) return (self.r1 - self.r2 * math.cos(v)) * P - math.sin(v) * OUT @@ -113,8 +115,8 @@ class Cylinder(Surface): self.apply_matrix(z_to_vector(self.axis)) return self - def uv_func(self, u, v): - return [np.cos(u), np.sin(u), v] + def uv_func(self, u: float, v: float) -> np.ndarray: + return np.array([np.cos(u), np.sin(u), v]) class Line3D(Cylinder): @@ -123,7 +125,7 @@ class Line3D(Cylinder): "resolution": (21, 25) } - def __init__(self, start, end, **kwargs): + def __init__(self, start: np.ndarray, end: np.ndarray, **kwargs): digest_config(self, kwargs) axis = end - start super().__init__( @@ -142,16 +144,16 @@ class Disk3D(Surface): "resolution": (2, 25), } - def init_points(self): + def init_points(self) -> None: super().init_points() self.scale(self.radius) - def uv_func(self, u, v): - return [ + def uv_func(self, u: float, v: float) -> np.ndarray: + return np.array([ u * np.cos(v), u * np.sin(v), 0 - ] + ]) class Square3D(Surface): @@ -162,12 +164,12 @@ class Square3D(Surface): "resolution": (2, 2), } - def init_points(self): + def init_points(self) -> None: super().init_points() self.scale(self.side_length / 2) - def uv_func(self, u, v): - return [u, v, 0] + def uv_func(self, u: float, v: float) -> np.ndarray: + return np.array([u, v, 0]) class Cube(SGroup): @@ -180,7 +182,7 @@ class Cube(SGroup): "square_class": Square3D, } - def init_points(self): + def init_points(self) -> None: face = Square3D( resolution=self.square_resolution, side_length=self.side_length, @@ -188,7 +190,7 @@ class Cube(SGroup): self.add(*self.square_to_cube_faces(face)) @staticmethod - def square_to_cube_faces(square): + def square_to_cube_faces(square: Square3D) -> list[Square3D]: radius = square.get_height() / 2 square.move_to(radius * OUT) result = [square] @@ -199,7 +201,7 @@ class Cube(SGroup): result.append(square.copy().rotate(PI, RIGHT, about_point=ORIGIN)) return result - def _get_face(self): + def _get_face(self) -> Square3D: return Square3D(resolution=self.square_resolution) @@ -212,7 +214,7 @@ class VCube(VGroup): "shadow": 0.5, } - def __init__(self, side_length=2, **kwargs): + def __init__(self, side_length: int = 2, **kwargs): super().__init__(**kwargs) face = Square(side_length=side_length) face.get_triangulation() @@ -233,7 +235,7 @@ class Dodecahedron(VGroup): "depth_test": True, } - def init_points(self): + def init_points(self) -> None: # Star by creating two of the pentagons, meeting # back to back on the positive x-axis phi = (1 + math.sqrt(5)) / 2 @@ -274,7 +276,7 @@ class Prism(Cube): "dimensions": [3, 2, 1] } - def init_points(self): + def init_points(self) -> None: Cube.init_points(self) for dim, value in enumerate(self.dimensions): self.rescale_to_fit(value, dim, stretch=True) diff --git a/manimlib/mobject/value_tracker.py b/manimlib/mobject/value_tracker.py index 2232f6c7..40c61d2e 100644 --- a/manimlib/mobject/value_tracker.py +++ b/manimlib/mobject/value_tracker.py @@ -1,3 +1,6 @@ +from asyncio import futures +from __future__ import annotations + import numpy as np from manimlib.mobject.mobject import Mobject @@ -15,11 +18,11 @@ class ValueTracker(Mobject): "value_type": np.float64, } - def __init__(self, value=0, **kwargs): + def __init__(self, value: float | complex = 0, **kwargs): self.value = value super().__init__(**kwargs) - def init_data(self): + def init_data(self) -> None: super().init_data() self.data["value"] = np.array( listify(self.value), @@ -27,17 +30,17 @@ class ValueTracker(Mobject): dtype=self.value_type, ) - def get_value(self): + def get_value(self) -> float | complex: result = self.data["value"][0, :] if len(result) == 1: return result[0] return result - def set_value(self, value): + def set_value(self, value: float | complex): self.data["value"][0, :] = value return self - def increment_value(self, d_value): + def increment_value(self, d_value: float | complex) -> None: self.set_value(self.get_value() + d_value) @@ -48,10 +51,10 @@ class ExponentialValueTracker(ValueTracker): behaves """ - def get_value(self): + def get_value(self) -> float | complex: return np.exp(ValueTracker.get_value(self)) - def set_value(self, value): + def set_value(self, value: float | complex): return ValueTracker.set_value(self, np.log(value)) diff --git a/manimlib/mobject/vector_field.py b/manimlib/mobject/vector_field.py index b7f69096..1300c2c5 100644 --- a/manimlib/mobject/vector_field.py +++ b/manimlib/mobject/vector_field.py @@ -1,9 +1,13 @@ -import numpy as np +from __future__ import annotations + import itertools as it import random +from typing import Sequence, TypeVar, Callable, Iterable + +import numpy as np +import numpy.typing as npt from manimlib.constants import * - from manimlib.animation.composition import AnimationGroup from manimlib.animation.indication import VShowPassingFlash from manimlib.mobject.geometry import Arrow @@ -18,8 +22,18 @@ from manimlib.utils.rate_functions import linear from manimlib.utils.simple_functions import sigmoid from manimlib.utils.space_ops import get_norm +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from manimlib.mobject.mobject import Mobject + from manimlib.mobject.coordinate_systems import CoordinateSystem + T = TypeVar("T") -def get_vectorized_rgb_gradient_function(min_value, max_value, color_map): + +def get_vectorized_rgb_gradient_function( + min_value: T, + max_value: T, + color_map: str +) -> Callable[[npt.ArrayLike], np.ndarray]: rgbs = np.array(get_colormap_list(color_map)) def func(values): @@ -37,12 +51,19 @@ def get_vectorized_rgb_gradient_function(min_value, max_value, color_map): return func -def get_rgb_gradient_function(min_value, max_value, color_map): +def get_rgb_gradient_function( + min_value: T, + max_value: T, + color_map: str +) -> Callable[[T], np.ndarray]: vectorized_func = get_vectorized_rgb_gradient_function(min_value, max_value, color_map) return lambda value: vectorized_func([value])[0] -def move_along_vector_field(mobject, func): +def move_along_vector_field( + mobject: Mobject, + func: Callable[[np.ndarray], np.ndarray] +) -> Mobject: mobject.add_updater( lambda m, dt: m.shift( func(m.get_center()) * dt @@ -51,7 +72,10 @@ def move_along_vector_field(mobject, func): return mobject -def move_submobjects_along_vector_field(mobject, func): +def move_submobjects_along_vector_field( + mobject: Mobject, + func: Callable[[np.ndarray], np.ndarray] +) -> Mobject: def apply_nudge(mob, dt): for submob in mob: x, y = submob.get_center()[:2] @@ -62,7 +86,11 @@ def move_submobjects_along_vector_field(mobject, func): return mobject -def move_points_along_vector_field(mobject, func, coordinate_system): +def move_points_along_vector_field( + mobject: Mobject, + func: Callable[[float, float], Iterable[float]], + coordinate_system: CoordinateSystem +) -> Mobject: cs = coordinate_system origin = cs.get_origin() @@ -74,7 +102,10 @@ def move_points_along_vector_field(mobject, func, coordinate_system): return mobject -def get_sample_points_from_coordinate_system(coordinate_system, step_multiple): +def get_sample_points_from_coordinate_system( + coordinate_system: CoordinateSystem, + step_multiple: float +) -> it.product[tuple[np.ndarray, ...]]: ranges = [] for range_args in coordinate_system.get_all_ranges(): _min, _max, step = range_args @@ -96,7 +127,12 @@ class VectorField(VGroup): "vector_config": {}, } - def __init__(self, func, coordinate_system, **kwargs): + def __init__( + self, + func: Callable[[float, float], Sequence[float]], + coordinate_system: CoordinateSystem, + **kwargs + ): super().__init__(**kwargs) self.func = func self.coordinate_system = coordinate_system @@ -112,7 +148,7 @@ class VectorField(VGroup): for coords in samples )) - def get_vector(self, coords, **kwargs): + def get_vector(self, coords: Iterable[float], **kwargs) -> Arrow: vector_config = merge_dicts_recursively( self.vector_config, kwargs @@ -157,19 +193,24 @@ class StreamLines(VGroup): "color_map": "3b1b_colormap", } - def __init__(self, func, coordinate_system, **kwargs): + def __init__( + self, + func: Callable[[float, float], Sequence[float]], + coordinate_system: CoordinateSystem, + **kwargs + ): super().__init__(**kwargs) self.func = func self.coordinate_system = coordinate_system self.draw_lines() self.init_style() - def point_func(self, point): + def point_func(self, point: np.ndarray) -> np.ndarray: in_coords = self.coordinate_system.p2c(point) out_coords = self.func(*in_coords) return self.coordinate_system.c2p(*out_coords) - def draw_lines(self): + def draw_lines(self) -> None: lines = [] origin = self.coordinate_system.get_origin() for point in self.get_start_points(): @@ -194,7 +235,7 @@ class StreamLines(VGroup): lines.append(line) self.set_submobjects(lines) - def get_start_points(self): + def get_start_points(self) -> np.ndarray: cs = self.coordinate_system sample_coords = get_sample_points_from_coordinate_system( cs, self.step_multiple, @@ -210,7 +251,7 @@ class StreamLines(VGroup): for coords in sample_coords ]) - def init_style(self): + def init_style(self) -> None: if self.color_by_magnitude: values_to_rgbs = get_vectorized_rgb_gradient_function( *self.magnitude_range, self.color_map, @@ -247,7 +288,7 @@ class AnimatedStreamLines(VGroup): }, } - def __init__(self, stream_lines, **kwargs): + def __init__(self, stream_lines: StreamLines, **kwargs): super().__init__(**kwargs) self.stream_lines = stream_lines for line in stream_lines: @@ -262,7 +303,7 @@ class AnimatedStreamLines(VGroup): self.add_updater(lambda m, dt: m.update(dt)) - def update(self, dt): + def update(self, dt: float) -> None: stream_lines = self.stream_lines for line in stream_lines: line.time += dt @@ -278,7 +319,7 @@ class ShowPassingFlashWithThinningStrokeWidth(AnimationGroup): "remover": True } - def __init__(self, vmobject, **kwargs): + def __init__(self, vmobject: VMobject, **kwargs): digest_config(self, kwargs) max_stroke_width = vmobject.get_stroke_width() max_time_width = kwargs.pop("time_width", self.time_width) From db71ed1ae9dedd11e67a40ebaa153ed66414adc7 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Tue, 15 Feb 2022 14:38:55 +0800 Subject: [PATCH 19/27] fix: fix type hint of remove_empty_value --- manimlib/utils/init_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manimlib/utils/init_config.py b/manimlib/utils/init_config.py index d4e1ac5d..cb0a1787 100644 --- a/manimlib/utils/init_config.py +++ b/manimlib/utils/init_config.py @@ -17,7 +17,7 @@ def get_manim_dir() -> str: return os.path.abspath(os.path.join(manimlib_dir, "..")) -def remove_empty_value(dictionary: dict[str, Any]) -> dict[str, Any]: +def remove_empty_value(dictionary: dict[str, Any]) -> None: for key in list(dictionary.keys()): if dictionary[key] == "": dictionary.pop(key) From 91ffdeb2d40a4b13839b44b7c4d4869eab664899 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Tue, 15 Feb 2022 14:49:02 +0800 Subject: [PATCH 20/27] chore: add type hints to manimlib.shader_wrapper --- manimlib/shader_wrapper.py | 55 +++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/manimlib/shader_wrapper.py b/manimlib/shader_wrapper.py index 87f1ccc6..bd32a7ae 100644 --- a/manimlib/shader_wrapper.py +++ b/manimlib/shader_wrapper.py @@ -1,8 +1,12 @@ +from __future__ import annotations + import os import re +import copy +from typing import Iterable + import moderngl import numpy as np -import copy from manimlib.utils.directories import get_shader_dir from manimlib.utils.file_ops import find_file @@ -15,15 +19,16 @@ from manimlib.utils.file_ops import find_file class ShaderWrapper(object): - def __init__(self, - vert_data=None, - vert_indices=None, - shader_folder=None, - uniforms=None, # A dictionary mapping names of uniform variables - texture_paths=None, # A dictionary mapping names to filepaths for textures. - depth_test=False, - render_primitive=moderngl.TRIANGLE_STRIP, - ): + def __init__( + self, + vert_data: np.ndarray | None = None, + vert_indices: np.ndarray | None = None, + shader_folder: str | None = None, + uniforms: dict[str, float] | None = None, # A dictionary mapping names of uniform variables + texture_paths: dict[str, str] | None = None, # A dictionary mapping names to filepaths for textures. + depth_test: bool = False, + render_primitive: int = moderngl.TRIANGLE_STRIP, + ): self.vert_data = vert_data self.vert_indices = vert_indices self.vert_attributes = vert_data.dtype.names @@ -46,20 +51,20 @@ class ShaderWrapper(object): result.texture_paths = dict(self.texture_paths) return result - def is_valid(self): + def is_valid(self) -> bool: return all([ self.vert_data is not None, self.program_code["vertex_shader"] is not None, self.program_code["fragment_shader"] is not None, ]) - def get_id(self): + def get_id(self) -> str: return self.id - def get_program_id(self): + def get_program_id(self) -> int: return self.program_id - def create_id(self): + def create_id(self) -> str: # A unique id for a shader return "|".join(map(str, [ self.program_id, @@ -69,32 +74,32 @@ class ShaderWrapper(object): self.render_primitive, ])) - def refresh_id(self): + def refresh_id(self) -> None: self.program_id = self.create_program_id() self.id = self.create_id() - def create_program_id(self): + def create_program_id(self) -> int: return hash("".join(( self.program_code[f"{name}_shader"] or "" for name in ("vertex", "geometry", "fragment") ))) - def init_program_code(self): - def get_code(name): + def init_program_code(self) -> None: + def get_code(name: str) -> str | None: return get_shader_code_from_file( os.path.join(self.shader_folder, f"{name}.glsl") ) - self.program_code = { + self.program_code: dict[str, str | None] = { "vertex_shader": get_code("vert"), "geometry_shader": get_code("geom"), "fragment_shader": get_code("frag"), } - def get_program_code(self): + def get_program_code(self) -> dict[str, str | None]: return self.program_code - def replace_code(self, old, new): + def replace_code(self, old: str, new: str) -> None: code_map = self.program_code for (name, code) in code_map.items(): if code_map[name] is None: @@ -102,7 +107,7 @@ class ShaderWrapper(object): code_map[name] = re.sub(old, new, code_map[name]) self.refresh_id() - def combine_with(self, *shader_wrappers): + def combine_with(self, *shader_wrappers: ShaderWrapper): # Assume they are of the same type if len(shader_wrappers) == 0: return @@ -122,10 +127,10 @@ class ShaderWrapper(object): # For caching -filename_to_code_map = {} +filename_to_code_map: dict[str, str] = {} -def get_shader_code_from_file(filename): +def get_shader_code_from_file(filename: str) -> str | None: if not filename: return None if filename in filename_to_code_map: @@ -157,7 +162,7 @@ def get_shader_code_from_file(filename): return result -def get_colormap_code(rgb_list): +def get_colormap_code(rgb_list: Iterable[float]) -> str: data = ",".join( "vec3({}, {}, {})".format(*rgb) for rgb in rgb_list From f085e6c2dd642c47ad994238576df6a7adc7d26c Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Tue, 15 Feb 2022 14:55:35 +0800 Subject: [PATCH 21/27] chore: add type hints to manimlib.window --- manimlib/window.py | 46 +++++++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/manimlib/window.py b/manimlib/window.py index 63d0ffd1..5785283b 100644 --- a/manimlib/window.py +++ b/manimlib/window.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import moderngl_window as mglw from moderngl_window.context.pyglet.window import Window as PygletWindow @@ -7,6 +9,10 @@ from screeninfo import get_monitors from manimlib.utils.config_ops import digest_config from manimlib.utils.customization import get_customization +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from manimlib.scene.scene import Scene + class Window(PygletWindow): fullscreen = False @@ -15,7 +21,12 @@ class Window(PygletWindow): vsync = True cursor = True - def __init__(self, scene, size=(1280, 720), **kwargs): + def __init__( + self, + scene: Scene, + size: tuple[int, int] = (1280, 720), + **kwargs + ): super().__init__(size=size) digest_config(self, kwargs) @@ -37,7 +48,7 @@ class Window(PygletWindow): self.position = initial_position self.position = initial_position - def find_initial_position(self, size): + def find_initial_position(self, size: tuple[int, int]) -> tuple[int, int]: custom_position = get_customization()["window_position"] monitors = get_monitors() mon_index = get_customization()["window_monitor"] @@ -59,7 +70,12 @@ class Window(PygletWindow): ) # Delegate event handling to scene - def pixel_coords_to_space_coords(self, px, py, relative=False): + def pixel_coords_to_space_coords( + self, + px: int, + py: int, + relative: bool = False + ) -> np.ndarray: pw, ph = self.size fw, fh = self.scene.camera.get_frame_shape() fc = self.scene.camera.get_frame_center() @@ -72,59 +88,59 @@ class Window(PygletWindow): 0 ]) - def on_mouse_motion(self, x, y, dx, dy): + def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> None: super().on_mouse_motion(x, y, dx, dy) point = self.pixel_coords_to_space_coords(x, y) d_point = self.pixel_coords_to_space_coords(dx, dy, relative=True) self.scene.on_mouse_motion(point, d_point) - def on_mouse_drag(self, x, y, dx, dy, buttons, modifiers): + def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> None: super().on_mouse_drag(x, y, dx, dy, buttons, modifiers) point = self.pixel_coords_to_space_coords(x, y) d_point = self.pixel_coords_to_space_coords(dx, dy, relative=True) self.scene.on_mouse_drag(point, d_point, buttons, modifiers) - def on_mouse_press(self, x: int, y: int, button, mods): + def on_mouse_press(self, x: int, y: int, button: int, mods: int) -> None: super().on_mouse_press(x, y, button, mods) point = self.pixel_coords_to_space_coords(x, y) self.scene.on_mouse_press(point, button, mods) - def on_mouse_release(self, x: int, y: int, button, mods): + def on_mouse_release(self, x: int, y: int, button: int, mods: int) -> None: super().on_mouse_release(x, y, button, mods) point = self.pixel_coords_to_space_coords(x, y) self.scene.on_mouse_release(point, button, mods) - def on_mouse_scroll(self, x, y, x_offset: float, y_offset: float): + def on_mouse_scroll(self, x: int, y: int, x_offset: float, y_offset: float) -> None: super().on_mouse_scroll(x, y, x_offset, y_offset) point = self.pixel_coords_to_space_coords(x, y) offset = self.pixel_coords_to_space_coords(x_offset, y_offset, relative=True) self.scene.on_mouse_scroll(point, offset) - def on_key_press(self, symbol, modifiers): + def on_key_press(self, symbol: int, modifiers: int) -> None: self.pressed_keys.add(symbol) # Modifiers? super().on_key_press(symbol, modifiers) self.scene.on_key_press(symbol, modifiers) - def on_key_release(self, symbol, modifiers): + def on_key_release(self, symbol: int, modifiers: int) -> None: self.pressed_keys.difference_update({symbol}) # Modifiers? super().on_key_release(symbol, modifiers) self.scene.on_key_release(symbol, modifiers) - def on_resize(self, width: int, height: int): + def on_resize(self, width: int, height: int) -> None: super().on_resize(width, height) self.scene.on_resize(width, height) - def on_show(self): + def on_show(self) -> None: super().on_show() self.scene.on_show() - def on_hide(self): + def on_hide(self) -> None: super().on_hide() self.scene.on_hide() - def on_close(self): + def on_close(self) -> None: super().on_close() self.scene.on_close() - def is_key_pressed(self, symbol): + def is_key_pressed(self, symbol: int) -> bool: return (symbol in self.pressed_keys) From d19e0cb9ab827e7d25d823dca2e2884f9385bfc0 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Tue, 15 Feb 2022 14:56:00 +0800 Subject: [PATCH 22/27] fix: remove import before future --- manimlib/mobject/value_tracker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/manimlib/mobject/value_tracker.py b/manimlib/mobject/value_tracker.py index 40c61d2e..0ff5f73c 100644 --- a/manimlib/mobject/value_tracker.py +++ b/manimlib/mobject/value_tracker.py @@ -1,4 +1,3 @@ -from asyncio import futures from __future__ import annotations import numpy as np From 41c4023986f26f6a2f2a7f27935ec879abc61fcf Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Tue, 15 Feb 2022 18:39:45 +0800 Subject: [PATCH 23/27] chore: add type hints to manimlib.animation --- manimlib/animation/animation.py | 61 ++++++--- manimlib/animation/composition.py | 42 ++++-- manimlib/animation/creation.py | 66 +++++---- manimlib/animation/fading.py | 56 +++++--- manimlib/animation/growing.py | 23 ++-- manimlib/animation/indication.py | 83 ++++++++---- manimlib/animation/movement.py | 48 +++++-- manimlib/animation/numbers.py | 27 +++- manimlib/animation/rotation.py | 25 +++- manimlib/animation/specialized.py | 6 +- manimlib/animation/transform.py | 128 +++++++++++++----- .../animation/transform_matching_parts.py | 33 +++-- manimlib/animation/update.py | 27 +++- 13 files changed, 443 insertions(+), 182 deletions(-) diff --git a/manimlib/animation/animation.py b/manimlib/animation/animation.py index ffefabce..d1225cd3 100644 --- a/manimlib/animation/animation.py +++ b/manimlib/animation/animation.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from copy import deepcopy +from typing import Callable from manimlib.mobject.mobject import _AnimationBuilder from manimlib.mobject.mobject import Mobject @@ -6,6 +9,10 @@ from manimlib.utils.config_ops import digest_config from manimlib.utils.rate_functions import smooth from manimlib.utils.simple_functions import clip +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from manimlib.scene.scene import Scene + DEFAULT_ANIMATION_RUN_TIME = 1.0 DEFAULT_ANIMATION_LAG_RATIO = 0 @@ -29,17 +36,17 @@ class Animation(object): "suspend_mobject_updating": True, } - def __init__(self, mobject, **kwargs): + def __init__(self, mobject: Mobject, **kwargs): assert(isinstance(mobject, Mobject)) digest_config(self, kwargs) self.mobject = mobject - def __str__(self): + def __str__(self) -> str: if self.name: return self.name return self.__class__.__name__ + str(self.mobject) - def begin(self): + def begin(self) -> None: # This is called right as an animation is being # played. As much initialization as possible, # especially any mobject copying, should live in @@ -56,32 +63,32 @@ class Animation(object): self.families = list(self.get_all_families_zipped()) self.interpolate(0) - def finish(self): + def finish(self) -> None: self.interpolate(self.final_alpha_value) if self.suspend_mobject_updating: self.mobject.resume_updating() - def clean_up_from_scene(self, scene): + def clean_up_from_scene(self, scene: Scene) -> None: if self.is_remover(): scene.remove(self.mobject) - def create_starting_mobject(self): + def create_starting_mobject(self) -> Mobject: # Keep track of where the mobject starts return self.mobject.copy() - def get_all_mobjects(self): + def get_all_mobjects(self) -> tuple[Mobject, Mobject]: """ Ordering must match the ording of arguments to interpolate_submobject """ return self.mobject, self.starting_mobject - def get_all_families_zipped(self): + def get_all_families_zipped(self) -> zip[tuple[Mobject]]: return zip(*[ mob.get_family() for mob in self.get_all_mobjects() ]) - def update_mobjects(self, dt): + def update_mobjects(self, dt: float) -> None: """ Updates things like starting_mobject, and (for Transforms) target_mobject. Note, since typically @@ -92,7 +99,7 @@ class Animation(object): for mob in self.get_all_mobjects_to_update(): mob.update(dt) - def get_all_mobjects_to_update(self): + def get_all_mobjects_to_update(self) -> list[Mobject]: # The surrounding scene typically handles # updating of self.mobject. Besides, in # most cases its updating is suspended anyway @@ -109,27 +116,37 @@ class Animation(object): return self # Methods for interpolation, the mean of an Animation - def interpolate(self, alpha): + def interpolate(self, alpha: float) -> None: alpha = clip(alpha, 0, 1) self.interpolate_mobject(self.rate_func(alpha)) - def update(self, alpha): + def update(self, alpha: float) -> None: """ This method shouldn't exist, but it's here to keep many old scenes from breaking """ self.interpolate(alpha) - def interpolate_mobject(self, alpha): + def interpolate_mobject(self, alpha: float) -> None: for i, mobs in enumerate(self.families): sub_alpha = self.get_sub_alpha(alpha, i, len(self.families)) self.interpolate_submobject(*mobs, sub_alpha) - def interpolate_submobject(self, submobject, starting_sumobject, alpha): + def interpolate_submobject( + self, + submobject: Mobject, + starting_submobject: Mobject, + alpha: float + ): # Typically ipmlemented by subclass pass - def get_sub_alpha(self, alpha, index, num_submobjects): + def get_sub_alpha( + self, + alpha: float, + index: int, + num_submobjects: int + ) -> float: # TODO, make this more understanable, and/or combine # its functionality with AnimationGroup's method # build_animations_with_timings @@ -140,29 +157,29 @@ class Animation(object): return clip((value - lower), 0, 1) # Getters and setters - def set_run_time(self, run_time): + def set_run_time(self, run_time: float): self.run_time = run_time return self - def get_run_time(self): + def get_run_time(self) -> float: return self.run_time - def set_rate_func(self, rate_func): + def set_rate_func(self, rate_func: Callable[[float], float]): self.rate_func = rate_func return self - def get_rate_func(self): + def get_rate_func(self) -> Callable[[float], float]: return self.rate_func - def set_name(self, name): + def set_name(self, name: str): self.name = name return self - def is_remover(self): + def is_remover(self) -> bool: return self.remover -def prepare_animation(anim): +def prepare_animation(anim: Animation | _AnimationBuilder): if isinstance(anim, _AnimationBuilder): return anim.build() diff --git a/manimlib/animation/composition.py b/manimlib/animation/composition.py index ba175dce..78a7b024 100644 --- a/manimlib/animation/composition.py +++ b/manimlib/animation/composition.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import numpy as np +from typing import Callable from manimlib.animation.animation import Animation, prepare_animation from manimlib.mobject.mobject import Group @@ -9,6 +12,11 @@ from manimlib.utils.iterables import remove_list_redundancies from manimlib.utils.rate_functions import linear from manimlib.utils.simple_functions import clip +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from manimlib.scene.scene import Scene + from manimlib.mobject.mobject import Mobject + DEFAULT_LAGGED_START_LAG_RATIO = 0.05 @@ -27,7 +35,7 @@ class AnimationGroup(Animation): "group": None, } - def __init__(self, *animations, **kwargs): + def __init__(self, *animations: Animation, **kwargs): digest_config(self, kwargs) self.animations = [prepare_animation(anim) for anim in animations] if self.group is None: @@ -37,27 +45,27 @@ class AnimationGroup(Animation): self.init_run_time() Animation.__init__(self, self.group, **kwargs) - def get_all_mobjects(self): + def get_all_mobjects(self) -> Group: return self.group - def begin(self): + def begin(self) -> None: for anim in self.animations: anim.begin() # self.init_run_time() - def finish(self): + def finish(self) -> None: for anim in self.animations: anim.finish() - def clean_up_from_scene(self, scene): + def clean_up_from_scene(self, scene: Scene) -> None: for anim in self.animations: anim.clean_up_from_scene(scene) - def update_mobjects(self, dt): + def update_mobjects(self, dt: float) -> None: for anim in self.animations: anim.update_mobjects(dt) - def init_run_time(self): + def init_run_time(self) -> None: self.build_animations_with_timings() if self.anims_with_timings: self.max_end_time = np.max([ @@ -68,7 +76,7 @@ class AnimationGroup(Animation): if self.run_time is None: self.run_time = self.max_end_time - def build_animations_with_timings(self): + def build_animations_with_timings(self) -> None: """ Creates a list of triplets of the form (anim, start_time, end_time) @@ -87,7 +95,7 @@ class AnimationGroup(Animation): start_time, end_time, self.lag_ratio ) - def interpolate(self, alpha): + def interpolate(self, alpha: float) -> None: # Note, if the run_time of AnimationGroup has been # set to something other than its default, these # times might not correspond to actual times, @@ -111,19 +119,19 @@ class Succession(AnimationGroup): "lag_ratio": 1, } - def begin(self): + def begin(self) -> None: assert(len(self.animations) > 0) self.init_run_time() self.active_animation = self.animations[0] self.active_animation.begin() - def finish(self): + def finish(self) -> None: self.active_animation.finish() - def update_mobjects(self, dt): + def update_mobjects(self, dt: float) -> None: self.active_animation.update_mobjects(dt) - def interpolate(self, alpha): + def interpolate(self, alpha: float) -> None: index, subalpha = integer_interpolate( 0, len(self.animations), alpha ) @@ -146,7 +154,13 @@ class LaggedStartMap(LaggedStart): "run_time": 2, } - def __init__(self, AnimationClass, mobject, arg_creator=None, **kwargs): + def __init__( + self, + AnimationClass: type, + mobject: Mobject, + arg_creator: Callable[[Mobject], tuple] | None = None, + **kwargs + ): args_list = [] for submob in mobject: if arg_creator: diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 5a6a04f9..c4f52149 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -1,3 +1,10 @@ +from __future__ import annotations + +import itertools as it +from abc import abstractmethod + +import numpy as np + from manimlib.animation.animation import Animation from manimlib.animation.composition import Succession from manimlib.mobject.types.vectorized_mobject import VMobject @@ -7,8 +14,9 @@ from manimlib.utils.rate_functions import linear from manimlib.utils.rate_functions import double_smooth from manimlib.utils.rate_functions import smooth -import numpy as np -import itertools as it +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from manimlib.mobject.mobject import Group class ShowPartial(Animation): @@ -19,21 +27,27 @@ class ShowPartial(Animation): "should_match_start": False, } - def begin(self): + def begin(self) -> None: super().begin() if not self.should_match_start: self.mobject.lock_matching_data(self.mobject, self.starting_mobject) - def finish(self): + def finish(self) -> None: super().finish() self.mobject.unlock_data() - def interpolate_submobject(self, submob, start_submob, alpha): + def interpolate_submobject( + self, + submob: VMobject, + start_submob: VMobject, + alpha: float + ) -> None: submob.pointwise_become_partial( start_submob, *self.get_bounds(alpha) ) - def get_bounds(self, alpha): + @abstractmethod + def get_bounds(self, alpha: float) -> tuple[float, float]: raise Exception("Not Implemented") @@ -42,7 +56,7 @@ class ShowCreation(ShowPartial): "lag_ratio": 1, } - def get_bounds(self, alpha): + def get_bounds(self, alpha: float) -> tuple[float, float]: return (0, alpha) @@ -64,7 +78,7 @@ class DrawBorderThenFill(Animation): "fill_animation_config": {}, } - def __init__(self, vmobject, **kwargs): + def __init__(self, vmobject: VMobject, **kwargs): assert(isinstance(vmobject, VMobject)) self.sm_to_index = dict([ (hash(sm), 0) @@ -72,7 +86,7 @@ class DrawBorderThenFill(Animation): ]) super().__init__(vmobject, **kwargs) - def begin(self): + def begin(self) -> None: # Trigger triangulation calculation for submob in self.mobject.get_family(): submob.get_triangulation() @@ -82,11 +96,11 @@ class DrawBorderThenFill(Animation): self.mobject.match_style(self.outline) self.mobject.lock_matching_data(self.mobject, self.outline) - def finish(self): + def finish(self) -> None: super().finish() self.mobject.unlock_data() - def get_outline(self): + def get_outline(self) -> VMobject: outline = self.mobject.copy() outline.set_fill(opacity=0) for sm in outline.get_family(): @@ -96,17 +110,23 @@ class DrawBorderThenFill(Animation): ) return outline - def get_stroke_color(self, vmobject): + def get_stroke_color(self, vmobject: VMobject) -> str: if self.stroke_color: return self.stroke_color elif vmobject.get_stroke_width() > 0: return vmobject.get_stroke_color() return vmobject.get_color() - def get_all_mobjects(self): + def get_all_mobjects(self) -> list[VMobject]: return [*super().get_all_mobjects(), self.outline] - def interpolate_submobject(self, submob, start, outline, alpha): + def interpolate_submobject( + self, + submob: VMobject, + start: VMobject, + outline: VMobject, + alpha: float + ) -> None: index, subalpha = integer_interpolate(0, 2, alpha) if index == 1 and self.sm_to_index[hash(submob)] == 0: @@ -133,13 +153,13 @@ class Write(DrawBorderThenFill): "rate_func": linear, } - def __init__(self, mobject, **kwargs): + def __init__(self, vmobject: VMobject, **kwargs): digest_config(self, kwargs) - self.set_default_config_from_length(mobject) - super().__init__(mobject, **kwargs) + self.set_default_config_from_length(vmobject) + super().__init__(vmobject, **kwargs) - def set_default_config_from_length(self, mobject): - length = len(mobject.family_members_with_points()) + def set_default_config_from_length(self, vmobject: VMobject) -> None: + length = len(vmobject.family_members_with_points()) if self.run_time is None: if length < 15: self.run_time = 1 @@ -155,16 +175,16 @@ class ShowIncreasingSubsets(Animation): "int_func": np.round, } - def __init__(self, group, **kwargs): + def __init__(self, group: Group, **kwargs): self.all_submobs = list(group.submobjects) super().__init__(group, **kwargs) - def interpolate_mobject(self, alpha): + def interpolate_mobject(self, alpha: float) -> None: n_submobs = len(self.all_submobs) index = int(self.int_func(alpha * n_submobs)) self.update_submobject_list(index) - def update_submobject_list(self, index): + def update_submobject_list(self, index: int) -> None: self.mobject.set_submobjects(self.all_submobs[:index]) @@ -173,7 +193,7 @@ class ShowSubmobjectsOneByOne(ShowIncreasingSubsets): "int_func": np.ceil, } - def update_submobject_list(self, index): + def update_submobject_list(self, index: int) -> None: # N = len(self.all_submobs) if index == 0: self.mobject.set_submobjects([]) diff --git a/manimlib/animation/fading.py b/manimlib/animation/fading.py index 2263b00d..2d246c55 100644 --- a/manimlib/animation/fading.py +++ b/manimlib/animation/fading.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from manimlib.animation.animation import Animation @@ -7,6 +9,12 @@ from manimlib.constants import ORIGIN from manimlib.utils.bezier import interpolate from manimlib.utils.rate_functions import there_and_back +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from manimlib.scene.scene import Scene + from manimlib.mobject.mobject import Mobject + from manimlib.mobject.types.vectorized_mobject import VMobject + DEFAULT_FADE_LAG_RATIO = 0 @@ -16,7 +24,13 @@ class Fade(Transform): "lag_ratio": DEFAULT_FADE_LAG_RATIO, } - def __init__(self, mobject, shift=ORIGIN, scale=1, **kwargs): + def __init__( + self, + mobject: Mobject, + shift: np.ndarray = ORIGIN, + scale: float = 1, + **kwargs + ): self.shift_vect = shift self.scale_factor = scale super().__init__(mobject, **kwargs) @@ -27,10 +41,10 @@ class FadeIn(Fade): "lag_ratio": DEFAULT_FADE_LAG_RATIO, } - def create_target(self): + def create_target(self) -> Mobject: return self.mobject - def create_starting_mobject(self): + def create_starting_mobject(self) -> Mobject: start = super().create_starting_mobject() start.set_opacity(0) start.scale(1.0 / self.scale_factor) @@ -45,7 +59,7 @@ class FadeOut(Fade): "final_alpha_value": 0, } - def create_target(self): + def create_target(self) -> Mobject: result = self.mobject.copy() result.set_opacity(0) result.shift(self.shift_vect) @@ -54,7 +68,7 @@ class FadeOut(Fade): class FadeInFromPoint(FadeIn): - def __init__(self, mobject, point, **kwargs): + def __init__(self, mobject: Mobject, point: np.ndarray, **kwargs): super().__init__( mobject, shift=mobject.get_center() - point, @@ -64,7 +78,7 @@ class FadeInFromPoint(FadeIn): class FadeOutToPoint(FadeOut): - def __init__(self, mobject, point, **kwargs): + def __init__(self, mobject: Mobject, point: np.ndarray, **kwargs): super().__init__( mobject, shift=point - mobject.get_center(), @@ -79,7 +93,7 @@ class FadeTransform(Transform): "dim_to_match": 1, } - def __init__(self, mobject, target_mobject, **kwargs): + def __init__(self, mobject: Mobject, target_mobject: Mobject, **kwargs): self.to_add_on_completion = target_mobject mobject.save_state() super().__init__( @@ -87,7 +101,7 @@ class FadeTransform(Transform): **kwargs ) - def begin(self): + def begin(self) -> None: self.ending_mobject = self.mobject.copy() Animation.begin(self) # Both 'start' and 'end' consists of the source and target mobjects. @@ -97,21 +111,21 @@ class FadeTransform(Transform): for m0, m1 in ((start[1], start[0]), (end[0], end[1])): self.ghost_to(m0, m1) - def ghost_to(self, source, target): + def ghost_to(self, source: Mobject, target: Mobject) -> None: source.replace(target, stretch=self.stretch, dim_to_match=self.dim_to_match) source.set_opacity(0) - def get_all_mobjects(self): + def get_all_mobjects(self) -> list[Mobject]: return [ self.mobject, self.starting_mobject, self.ending_mobject, ] - def get_all_families_zipped(self): + def get_all_families_zipped(self) -> zip[tuple[Mobject]]: return Animation.get_all_families_zipped(self) - def clean_up_from_scene(self, scene): + def clean_up_from_scene(self, scene: Scene) -> None: Animation.clean_up_from_scene(self, scene) scene.remove(self.mobject) self.mobject[0].restore() @@ -119,11 +133,11 @@ class FadeTransform(Transform): class FadeTransformPieces(FadeTransform): - def begin(self): + def begin(self) -> None: self.mobject[0].align_family(self.mobject[1]) super().begin() - def ghost_to(self, source, target): + def ghost_to(self, source: Mobject, target: Mobject) -> None: for sm0, sm1 in zip(source.get_family(), target.get_family()): super().ghost_to(sm0, sm1) @@ -136,7 +150,12 @@ class VFadeIn(Animation): "suspend_mobject_updating": False, } - def interpolate_submobject(self, submob, start, alpha): + def interpolate_submobject( + self, + submob: VMobject, + start: VMobject, + alpha: float + ) -> None: submob.set_stroke( opacity=interpolate(0, start.get_stroke_opacity(), alpha) ) @@ -152,7 +171,12 @@ class VFadeOut(VFadeIn): "final_alpha_value": 0, } - def interpolate_submobject(self, submob, start, alpha): + def interpolate_submobject( + self, + submob: VMobject, + start: VMobject, + alpha: float + ) -> None: super().interpolate_submobject(submob, start, 1 - alpha) diff --git a/manimlib/animation/growing.py b/manimlib/animation/growing.py index a4dcfeba..5e0d0eda 100644 --- a/manimlib/animation/growing.py +++ b/manimlib/animation/growing.py @@ -1,6 +1,13 @@ -from manimlib.animation.transform import Transform -# from manimlib.utils.paths import counterclockwise_path +from __future__ import annotations + from manimlib.constants import PI +from manimlib.animation.transform import Transform + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + import numpy as np + from manimlib.mobject.mobject import Mobject + from manimlib.mobject.geometry import Arrow class GrowFromPoint(Transform): @@ -8,14 +15,14 @@ class GrowFromPoint(Transform): "point_color": None, } - def __init__(self, mobject, point, **kwargs): + def __init__(self, mobject: Mobject, point: np.ndarray, **kwargs): self.point = point super().__init__(mobject, **kwargs) - def create_target(self): + def create_target(self) -> Mobject: return self.mobject - def create_starting_mobject(self): + def create_starting_mobject(self) -> Mobject: start = super().create_starting_mobject() start.scale(0) start.move_to(self.point) @@ -25,19 +32,19 @@ class GrowFromPoint(Transform): class GrowFromCenter(GrowFromPoint): - def __init__(self, mobject, **kwargs): + def __init__(self, mobject: Mobject, **kwargs): point = mobject.get_center() super().__init__(mobject, point, **kwargs) class GrowFromEdge(GrowFromPoint): - def __init__(self, mobject, edge, **kwargs): + def __init__(self, mobject: Mobject, edge: np.ndarray, **kwargs): point = mobject.get_bounding_box_point(edge) super().__init__(mobject, point, **kwargs) class GrowArrow(GrowFromPoint): - def __init__(self, arrow, **kwargs): + def __init__(self, arrow: Arrow, **kwargs): point = arrow.get_start() super().__init__(arrow, point, **kwargs) diff --git a/manimlib/animation/indication.py b/manimlib/animation/indication.py index f99de961..7a463aaa 100644 --- a/manimlib/animation/indication.py +++ b/manimlib/animation/indication.py @@ -1,5 +1,9 @@ -import numpy as np +from __future__ import annotations + import math +from typing import Union, Sequence + +import numpy as np from manimlib.constants import * from manimlib.animation.animation import Animation @@ -10,7 +14,7 @@ from manimlib.animation.creation import ShowCreation from manimlib.animation.creation import ShowPartial from manimlib.animation.fading import FadeOut from manimlib.animation.fading import FadeIn -from manimlib.animation.transform import Transform +from manimlib.animation.transform import ManimColor, Transform from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.geometry import Circle from manimlib.mobject.geometry import Dot @@ -25,6 +29,12 @@ from manimlib.utils.rate_functions import wiggle from manimlib.utils.rate_functions import smooth from manimlib.utils.rate_functions import squish_rate_func +from typing import TYPE_CHECKING +if TYPE_CHECKING: + import colour + from manimlib.mobject.mobject import Mobject + ManimColor = Union[str, colour.Color, Sequence[float]] + class FocusOn(Transform): CONFIG = { @@ -34,13 +44,13 @@ class FocusOn(Transform): "remover": True, } - def __init__(self, focus_point, **kwargs): + def __init__(self, focus_point: np.ndarray, **kwargs): self.focus_point = focus_point # Initialize with blank mobject, while create_target # and create_starting_mobject handle the meat super().__init__(VMobject(), **kwargs) - def create_target(self): + def create_target(self) -> Dot: little_dot = Dot(radius=0) little_dot.set_fill(self.color, opacity=self.opacity) little_dot.add_updater( @@ -48,7 +58,7 @@ class FocusOn(Transform): ) return little_dot - def create_starting_mobject(self): + def create_starting_mobject(self) -> Dot: return Dot( radius=FRAME_X_RADIUS + FRAME_Y_RADIUS, stroke_width=0, @@ -64,7 +74,7 @@ class Indicate(Transform): "color": YELLOW, } - def create_target(self): + def create_target(self) -> Mobject: target = self.mobject.copy() target.scale(self.scale_factor) target.set_color(self.color) @@ -80,7 +90,12 @@ class Flash(AnimationGroup): "run_time": 1, } - def __init__(self, point, color=YELLOW, **kwargs): + def __init__( + self, + point: np.ndarray, + color: ManimColor = YELLOW, + **kwargs + ): self.point = point self.color = color digest_config(self, kwargs) @@ -92,7 +107,7 @@ class Flash(AnimationGroup): **kwargs, ) - def create_lines(self): + def create_lines(self) -> VGroup: lines = VGroup() for angle in np.arange(0, TAU, TAU / self.num_lines): line = Line(ORIGIN, self.line_length * RIGHT) @@ -106,7 +121,7 @@ class Flash(AnimationGroup): lines.add_updater(lambda l: l.move_to(self.point)) return lines - def create_line_anims(self): + def create_line_anims(self) -> list[Animation]: return [ ShowCreationThenDestruction(line) for line in self.lines @@ -122,17 +137,17 @@ class CircleIndicate(Indicate): }, } - def __init__(self, mobject, **kwargs): + def __init__(self, mobject: Mobject, **kwargs): digest_config(self, kwargs) circle = self.get_circle(mobject) super().__init__(circle, **kwargs) - def get_circle(self, mobject): + def get_circle(self, mobject: Mobject) -> Circle: circle = Circle(**self.circle_config) circle.add_updater(lambda c: c.surround(mobject)) return circle - def interpolate_mobject(self, alpha): + def interpolate_mobject(self, alpha: float) -> None: super().interpolate_mobject(alpha) self.mobject.set_stroke(opacity=alpha) @@ -143,7 +158,7 @@ class ShowPassingFlash(ShowPartial): "remover": True, } - def get_bounds(self, alpha): + def get_bounds(self, alpha: float) -> tuple[float, float]: tw = self.time_width upper = interpolate(0, 1 + tw, alpha) lower = upper - tw @@ -151,7 +166,7 @@ class ShowPassingFlash(ShowPartial): lower = max(lower, 0) return (lower, upper) - def finish(self): + def finish(self) -> None: super().finish() for submob, start in self.get_all_families_zipped(): submob.pointwise_become_partial(start, 0, 1) @@ -164,7 +179,7 @@ class VShowPassingFlash(Animation): "remover": True, } - def begin(self): + def begin(self) -> None: self.mobject.align_stroke_width_data_to_points() # Compute an array of stroke widths for each submobject # which tapers out at either end @@ -184,7 +199,12 @@ class VShowPassingFlash(Animation): self.submob_to_anchor_widths[hash(sm)] = anchor_widths * taper_array super().begin() - def interpolate_submobject(self, submobject, starting_sumobject, alpha): + def interpolate_submobject( + self, + submobject: VMobject, + starting_sumobject: None, + alpha: float + ) -> None: anchor_widths = self.submob_to_anchor_widths[hash(submobject)] # Create a gaussian such that 3 sigmas out on either side # will equals time_width @@ -206,7 +226,7 @@ class VShowPassingFlash(Animation): new_widths[1::3] = (new_widths[0::3] + new_widths[2::3]) / 2 submobject.set_stroke(width=new_widths) - def finish(self): + def finish(self) -> None: super().finish() for submob, start in self.get_all_families_zipped(): submob.match_style(start) @@ -221,7 +241,7 @@ class FlashAround(VShowPassingFlash): "n_inserted_curves": 20, } - def __init__(self, mobject, **kwargs): + def __init__(self, mobject: Mobject, **kwargs): digest_config(self, kwargs) path = self.get_path(mobject) if mobject.is_fixed_in_frame: @@ -231,12 +251,12 @@ class FlashAround(VShowPassingFlash): path.set_stroke(self.color, self.stroke_width) super().__init__(path, **kwargs) - def get_path(self, mobject): + def get_path(self, mobject: Mobject) -> SurroundingRectangle: return SurroundingRectangle(mobject, buff=self.buff) class FlashUnder(FlashAround): - def get_path(self, mobject): + def get_path(self, mobject: Mobject) -> Underline: return Underline(mobject, buff=self.buff) @@ -252,7 +272,7 @@ class ShowCreationThenFadeOut(Succession): "remover": True, } - def __init__(self, mobject, **kwargs): + def __init__(self, mobject: Mobject, **kwargs): super().__init__( ShowCreation(mobject), FadeOut(mobject), @@ -269,7 +289,7 @@ class AnimationOnSurroundingRectangle(AnimationGroup): "rect_animation": Animation } - def __init__(self, mobject, **kwargs): + def __init__(self, mobject: Mobject, **kwargs): digest_config(self, kwargs) if "surrounding_rectangle_config" in kwargs: kwargs.pop("surrounding_rectangle_config") @@ -282,7 +302,7 @@ class AnimationOnSurroundingRectangle(AnimationGroup): self.rect_animation(rect, **kwargs), ) - def get_rect(self): + def get_rect(self) -> SurroundingRectangle: return SurroundingRectangle( self.mobject_to_surround, **self.surrounding_rectangle_config @@ -314,7 +334,7 @@ class ApplyWave(Homotopy): "run_time": 1, } - def __init__(self, mobject, **kwargs): + def __init__(self, mobject: Mobject, **kwargs): digest_config(self, kwargs, locals()) left_x = mobject.get_left()[0] right_x = mobject.get_right()[0] @@ -339,15 +359,20 @@ class WiggleOutThenIn(Animation): "rotate_about_point": None, } - def get_scale_about_point(self): + def get_scale_about_point(self) -> np.ndarray: if self.scale_about_point is None: return self.mobject.get_center() - def get_rotate_about_point(self): + def get_rotate_about_point(self) -> np.ndarray: if self.rotate_about_point is None: return self.mobject.get_center() - def interpolate_submobject(self, submobject, starting_sumobject, alpha): + def interpolate_submobject( + self, + submobject: Mobject, + starting_sumobject: Mobject, + alpha: float + ) -> None: submobject.match_points(starting_sumobject) submobject.scale( interpolate(1, self.scale_value, there_and_back(alpha)), @@ -364,7 +389,7 @@ class TurnInsideOut(Transform): "path_arc": TAU / 4, } - def create_target(self): + def create_target(self) -> Mobject: return self.mobject.copy().reverse_points() @@ -373,7 +398,7 @@ class FlashyFadeIn(AnimationGroup): "fade_lag": 0, } - def __init__(self, vmobject, stroke_width=2, **kwargs): + def __init__(self, vmobject: VMobject, stroke_width: float = 2, **kwargs): digest_config(self, kwargs) outline = vmobject.copy() outline.set_fill(opacity=0) diff --git a/manimlib/animation/movement.py b/manimlib/animation/movement.py index d1cea65e..2d157a05 100644 --- a/manimlib/animation/movement.py +++ b/manimlib/animation/movement.py @@ -1,6 +1,15 @@ +from __future__ import annotations + +from typing import Callable, Sequence + from manimlib.animation.animation import Animation from manimlib.utils.rate_functions import linear +from typing import TYPE_CHECKING +if TYPE_CHECKING: + import numpy as np + from manimlib.mobject.mobject import Mobject + class Homotopy(Animation): CONFIG = { @@ -8,7 +17,12 @@ class Homotopy(Animation): "apply_function_kwargs": {}, } - def __init__(self, homotopy, mobject, **kwargs): + def __init__( + self, + homotopy: Callable[[float, float, float, float], Sequence[float]], + mobject: Mobject, + **kwargs + ): """ Homotopy is a function from (x, y, z, t) to (x', y', z') @@ -16,10 +30,18 @@ class Homotopy(Animation): self.homotopy = homotopy super().__init__(mobject, **kwargs) - def function_at_time_t(self, t): + def function_at_time_t( + self, + t: float + ) -> Callable[[np.ndarray], Sequence[float]]: return lambda p: self.homotopy(*p, t) - def interpolate_submobject(self, submob, start, alpha): + def interpolate_submobject( + self, + submob: Mobject, + start: Mobject, + alpha: float + ) -> None: submob.match_points(start) submob.apply_function( self.function_at_time_t(alpha), @@ -34,7 +56,12 @@ class SmoothedVectorizedHomotopy(Homotopy): class ComplexHomotopy(Homotopy): - def __init__(self, complex_homotopy, mobject, **kwargs): + def __init__( + self, + complex_homotopy: Callable[[complex, float], Sequence[float]], + mobject: Mobject, + **kwargs + ): """ Given a function form (z, t) -> w, where z and w are complex numbers and t is time, this animates @@ -53,11 +80,16 @@ class PhaseFlow(Animation): "suspend_mobject_updating": False, } - def __init__(self, function, mobject, **kwargs): + def __init__( + self, + function: Callable[[np.ndarray], np.ndarray], + mobject: Mobject, + **kwargs + ): self.function = function super().__init__(mobject, **kwargs) - def interpolate_mobject(self, alpha): + def interpolate_mobject(self, alpha: float) -> None: if hasattr(self, "last_alpha"): dt = self.virtual_time * (alpha - self.last_alpha) self.mobject.apply_function( @@ -71,10 +103,10 @@ class MoveAlongPath(Animation): "suspend_mobject_updating": False, } - def __init__(self, mobject, path, **kwargs): + def __init__(self, mobject: Mobject, path: Mobject, **kwargs): self.path = path super().__init__(mobject, **kwargs) - def interpolate_mobject(self, alpha): + def interpolate_mobject(self, alpha: float) -> None: point = self.path.point_from_proportion(alpha) self.mobject.move_to(point) diff --git a/manimlib/animation/numbers.py b/manimlib/animation/numbers.py index 1cbd2489..0a992b39 100644 --- a/manimlib/animation/numbers.py +++ b/manimlib/animation/numbers.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Callable + from manimlib.animation.animation import Animation from manimlib.mobject.numbers import DecimalNumber from manimlib.utils.bezier import interpolate @@ -8,19 +12,29 @@ class ChangingDecimal(Animation): "suspend_mobject_updating": False, } - def __init__(self, decimal_mob, number_update_func, **kwargs): + def __init__( + self, + decimal_mob: DecimalNumber, + number_update_func: Callable[[float], float], + **kwargs + ): assert(isinstance(decimal_mob, DecimalNumber)) self.number_update_func = number_update_func super().__init__(decimal_mob, **kwargs) - def interpolate_mobject(self, alpha): + def interpolate_mobject(self, alpha: float) -> None: self.mobject.set_value( self.number_update_func(alpha) ) class ChangeDecimalToValue(ChangingDecimal): - def __init__(self, decimal_mob, target_number, **kwargs): + def __init__( + self, + decimal_mob: DecimalNumber, + target_number: float | complex, + **kwargs + ): start_number = decimal_mob.number super().__init__( decimal_mob, @@ -30,7 +44,12 @@ class ChangeDecimalToValue(ChangingDecimal): class CountInFrom(ChangingDecimal): - def __init__(self, decimal_mob, source_number=0, **kwargs): + def __init__( + self, + decimal_mob: DecimalNumber, + source_number: float | complex = 0, + **kwargs + ): start_number = decimal_mob.number super().__init__( decimal_mob, diff --git a/manimlib/animation/rotation.py b/manimlib/animation/rotation.py index a37fb0b6..8c9b7ddf 100644 --- a/manimlib/animation/rotation.py +++ b/manimlib/animation/rotation.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from manimlib.animation.animation import Animation from manimlib.constants import OUT from manimlib.constants import PI @@ -6,6 +8,11 @@ from manimlib.constants import ORIGIN from manimlib.utils.rate_functions import linear from manimlib.utils.rate_functions import smooth +from typing import TYPE_CHECKING +if TYPE_CHECKING: + import numpy as np + from manimlib.mobject.mobject import Mobject + class Rotating(Animation): CONFIG = { @@ -18,12 +25,18 @@ class Rotating(Animation): "suspend_mobject_updating": False, } - def __init__(self, mobject, angle=TAU, axis=OUT, **kwargs): + def __init__( + self, + mobject: Mobject, + angle: float = TAU, + axis: np.ndarray = OUT, + **kwargs + ): self.angle = angle self.axis = axis super().__init__(mobject, **kwargs) - def interpolate_mobject(self, alpha): + def interpolate_mobject(self, alpha: float) -> None: for sm1, sm2 in self.get_all_families_zipped(): sm1.set_points(sm2.get_points()) self.mobject.rotate( @@ -41,5 +54,11 @@ class Rotate(Rotating): "about_edge": ORIGIN, } - def __init__(self, mobject, angle=PI, axis=OUT, **kwargs): + def __init__( + self, + mobject: Mobject, + angle: float = PI, + axis: np.ndarray = OUT, + **kwargs + ): super().__init__(mobject, angle, axis, **kwargs) diff --git a/manimlib/animation/specialized.py b/manimlib/animation/specialized.py index 389eb93c..376e37e2 100644 --- a/manimlib/animation/specialized.py +++ b/manimlib/animation/specialized.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +import numpy as np + from manimlib.animation.composition import LaggedStart from manimlib.animation.transform import Restore from manimlib.constants import WHITE @@ -19,7 +23,7 @@ class Broadcast(LaggedStart): "run_time": 3, } - def __init__(self, focal_point, **kwargs): + def __init__(self, focal_point: np.ndarray, **kwargs): digest_config(self, kwargs) circles = VGroup() for x in range(self.n_circles): diff --git a/manimlib/animation/transform.py b/manimlib/animation/transform.py index 75160af7..76c36978 100644 --- a/manimlib/animation/transform.py +++ b/manimlib/animation/transform.py @@ -1,6 +1,10 @@ +from __future__ import annotations + import inspect +from typing import Callable, Union, Sequence import numpy as np +import numpy.typing as npt from manimlib.animation.animation import Animation from manimlib.constants import DEFAULT_POINTWISE_FUNCTION_RUN_TIME @@ -14,6 +18,12 @@ from manimlib.utils.paths import straight_path from manimlib.utils.rate_functions import smooth from manimlib.utils.rate_functions import squish_rate_func +from typing import TYPE_CHECKING +if TYPE_CHECKING: + import colour + from manimlib.scene.scene import Scene + ManimColor = Union[str, colour.Color, Sequence[float]] + class Transform(Animation): CONFIG = { @@ -23,12 +33,17 @@ class Transform(Animation): "replace_mobject_with_target_in_scene": False, } - def __init__(self, mobject, target_mobject=None, **kwargs): + def __init__( + self, + mobject: Mobject, + target_mobject: Mobject | None = None, + **kwargs + ): super().__init__(mobject, **kwargs) self.target_mobject = target_mobject self.init_path_func() - def init_path_func(self): + def init_path_func(self) -> None: if self.path_func is not None: return elif self.path_arc == 0: @@ -39,7 +54,7 @@ class Transform(Animation): self.path_arc_axis, ) - def begin(self): + def begin(self) -> None: self.target_mobject = self.create_target() self.check_target_mobject_validity() # Use a copy of target_mobject for the align_data_and_family @@ -54,28 +69,28 @@ class Transform(Animation): self.target_copy, ) - def finish(self): + def finish(self) -> None: super().finish() self.mobject.unlock_data() - def create_target(self): + def create_target(self) -> Mobject: # Has no meaningful effect here, but may be useful # in subclasses return self.target_mobject - def check_target_mobject_validity(self): + def check_target_mobject_validity(self) -> None: if self.target_mobject is None: raise Exception( f"{self.__class__.__name__}.create_target not properly implemented" ) - def clean_up_from_scene(self, scene): + def clean_up_from_scene(self, scene: Scene) -> None: super().clean_up_from_scene(scene) if self.replace_mobject_with_target_in_scene: scene.remove(self.mobject) scene.add(self.target_mobject) - def update_config(self, **kwargs): + def update_config(self, **kwargs) -> None: Animation.update_config(self, **kwargs) if "path_arc" in kwargs: self.path_func = path_along_arc( @@ -83,7 +98,7 @@ class Transform(Animation): kwargs.get("path_arc_axis", OUT) ) - def get_all_mobjects(self): + def get_all_mobjects(self) -> list[Mobject]: return [ self.mobject, self.starting_mobject, @@ -91,7 +106,7 @@ class Transform(Animation): self.target_copy, ] - def get_all_families_zipped(self): + def get_all_families_zipped(self) -> zip[tuple[Mobject]]: return zip(*[ mob.get_family() for mob in [ @@ -101,7 +116,13 @@ class Transform(Animation): ] ]) - def interpolate_submobject(self, submob, start, target_copy, alpha): + def interpolate_submobject( + self, + submob: Mobject, + start: Mobject, + target_copy: Mobject, + alpha: float + ): submob.interpolate(start, target_copy, alpha, self.path_func) return self @@ -117,10 +138,10 @@ class TransformFromCopy(Transform): Performs a reversed Transform """ - def __init__(self, mobject, target_mobject, **kwargs): + def __init__(self, mobject: Mobject, target_mobject: Mobject, **kwargs): super().__init__(target_mobject, mobject, **kwargs) - def interpolate(self, alpha): + def interpolate(self, alpha: float) -> None: super().interpolate(1 - alpha) @@ -137,11 +158,11 @@ class CounterclockwiseTransform(Transform): class MoveToTarget(Transform): - def __init__(self, mobject, **kwargs): + def __init__(self, mobject: Mobject, **kwargs): self.check_validity_of_input(mobject) super().__init__(mobject, mobject.target, **kwargs) - def check_validity_of_input(self, mobject): + def check_validity_of_input(self, mobject: Mobject) -> None: if not hasattr(mobject, "target"): raise Exception( "MoveToTarget called on mobject" @@ -150,13 +171,13 @@ class MoveToTarget(Transform): class _MethodAnimation(MoveToTarget): - def __init__(self, mobject, methods): + def __init__(self, mobject: Mobject, methods: Callable): self.methods = methods super().__init__(mobject) class ApplyMethod(Transform): - def __init__(self, method, *args, **kwargs): + def __init__(self, method: Callable, *args, **kwargs): """ method is a method of Mobject, *args are arguments for that method. Key word arguments should be passed in @@ -170,7 +191,7 @@ class ApplyMethod(Transform): self.method_args = args super().__init__(method.__self__, **kwargs) - def check_validity_of_input(self, method): + def check_validity_of_input(self, method: Callable) -> None: if not inspect.ismethod(method): raise Exception( "Whoops, looks like you accidentally invoked " @@ -178,7 +199,7 @@ class ApplyMethod(Transform): ) assert(isinstance(method.__self__, Mobject)) - def create_target(self): + def create_target(self) -> Mobject: method = self.method # Make sure it's a list so that args.pop() works args = list(self.method_args) @@ -197,16 +218,26 @@ class ApplyPointwiseFunction(ApplyMethod): "run_time": DEFAULT_POINTWISE_FUNCTION_RUN_TIME } - def __init__(self, function, mobject, **kwargs): + def __init__( + self, + function: Callable[[np.ndarray], np.ndarray], + mobject: Mobject, + **kwargs + ): super().__init__(mobject.apply_function, function, **kwargs) class ApplyPointwiseFunctionToCenter(ApplyPointwiseFunction): - def __init__(self, function, mobject, **kwargs): + def __init__( + self, + function: Callable[[np.ndarray], np.ndarray], + mobject: Mobject, + **kwargs + ): self.function = function super().__init__(mobject.move_to, **kwargs) - def begin(self): + def begin(self) -> None: self.method_args = [ self.function(self.mobject.get_center()) ] @@ -214,31 +245,46 @@ class ApplyPointwiseFunctionToCenter(ApplyPointwiseFunction): class FadeToColor(ApplyMethod): - def __init__(self, mobject, color, **kwargs): + def __init__( + self, + mobject: Mobject, + color: ManimColor, + **kwargs + ): super().__init__(mobject.set_color, color, **kwargs) class ScaleInPlace(ApplyMethod): - def __init__(self, mobject, scale_factor, **kwargs): + def __init__( + self, + mobject: Mobject, + scale_factor: npt.ArrayLike, + **kwargs + ): super().__init__(mobject.scale, scale_factor, **kwargs) class ShrinkToCenter(ScaleInPlace): - def __init__(self, mobject, **kwargs): + def __init__(self, mobject: Mobject, **kwargs): super().__init__(mobject, 0, **kwargs) class Restore(ApplyMethod): - def __init__(self, mobject, **kwargs): + def __init__(self, mobject: Mobject, **kwargs): super().__init__(mobject.restore, **kwargs) class ApplyFunction(Transform): - def __init__(self, function, mobject, **kwargs): + def __init__( + self, + function: Callable[[Mobject], Mobject], + mobject: Mobject, + **kwargs + ): self.function = function super().__init__(mobject, **kwargs) - def create_target(self): + def create_target(self) -> Mobject: target = self.function(self.mobject.copy()) if not isinstance(target, Mobject): raise Exception("Functions passed to ApplyFunction must return object of type Mobject") @@ -246,7 +292,12 @@ class ApplyFunction(Transform): class ApplyMatrix(ApplyPointwiseFunction): - def __init__(self, matrix, mobject, **kwargs): + def __init__( + self, + matrix: npt.ArrayLike, + mobject: Mobject, + **kwargs + ): matrix = self.initialize_matrix(matrix) def func(p): @@ -254,7 +305,7 @@ class ApplyMatrix(ApplyPointwiseFunction): super().__init__(func, mobject, **kwargs) - def initialize_matrix(self, matrix): + def initialize_matrix(self, matrix: npt.ArrayLike) -> np.ndarray: matrix = np.array(matrix) if matrix.shape == (2, 2): new_matrix = np.identity(3) @@ -266,12 +317,17 @@ class ApplyMatrix(ApplyPointwiseFunction): class ApplyComplexFunction(ApplyMethod): - def __init__(self, function, mobject, **kwargs): + def __init__( + self, + function: Callable[[complex], complex], + mobject: Mobject, + **kwargs + ): self.function = function method = mobject.apply_complex_function super().__init__(method, function, **kwargs) - def init_path_func(self): + def init_path_func(self) -> None: func1 = self.function(complex(1)) self.path_arc = np.log(func1).imag super().init_path_func() @@ -284,11 +340,11 @@ class CyclicReplace(Transform): "path_arc": 90 * DEGREES, } - def __init__(self, *mobjects, **kwargs): + def __init__(self, *mobjects: Mobject, **kwargs): self.group = Group(*mobjects) super().__init__(self.group, **kwargs) - def create_target(self): + def create_target(self) -> Mobject: target = self.group.copy() cycled_targets = [target[-1], *target[:-1]] for m1, m2 in zip(cycled_targets, self.group): @@ -306,7 +362,7 @@ class TransformAnimations(Transform): "rate_func": squish_rate_func(smooth) } - def __init__(self, start_anim, end_anim, **kwargs): + def __init__(self, start_anim: Animation, end_anim: Animation, **kwargs): digest_config(self, kwargs, locals()) if "run_time" in kwargs: self.run_time = kwargs.pop("run_time") @@ -327,7 +383,7 @@ class TransformAnimations(Transform): start_anim.mobject = self.starting_mobject end_anim.mobject = self.target_mobject - def interpolate(self, alpha): + def interpolate(self, alpha: float) -> None: self.start_anim.interpolate(alpha) self.end_anim.interpolate(alpha) Transform.interpolate(self, alpha) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index ce396404..c1dd2da2 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -1,13 +1,15 @@ -import numpy as np +from __future__ import annotations + import itertools as it +import numpy as np + from manimlib.animation.composition import AnimationGroup from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.fading import FadeInFromPoint from manimlib.animation.fading import FadeOutToPoint from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import Transform - from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Group from manimlib.mobject.svg.mtex_mobject import MTex @@ -16,6 +18,11 @@ from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import remove_list_redundancies +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from manimlib.scene.scene import Scene + from manimlib.mobject.svg.tex_mobject import Tex, SingleStringTex + class TransformMatchingParts(AnimationGroup): CONFIG = { @@ -26,7 +33,7 @@ class TransformMatchingParts(AnimationGroup): "key_map": dict(), } - def __init__(self, mobject, target_mobject, **kwargs): + def __init__(self, mobject: Mobject, target_mobject: Mobject, **kwargs): digest_config(self, kwargs) assert(isinstance(mobject, self.mobject_type)) assert(isinstance(target_mobject, self.mobject_type)) @@ -83,8 +90,8 @@ class TransformMatchingParts(AnimationGroup): self.to_remove = mobject self.to_add = target_mobject - def get_shape_map(self, mobject): - shape_map = {} + def get_shape_map(self, mobject: Mobject) -> dict[int, VGroup]: + shape_map: dict[int, VGroup] = {} for sm in self.get_mobject_parts(mobject): key = self.get_mobject_key(sm) if key not in shape_map: @@ -92,7 +99,7 @@ class TransformMatchingParts(AnimationGroup): shape_map[key].add(sm) return shape_map - def clean_up_from_scene(self, scene): + def clean_up_from_scene(self, scene: Scene) -> None: for anim in self.animations: anim.update(0) scene.remove(self.mobject) @@ -100,12 +107,12 @@ class TransformMatchingParts(AnimationGroup): scene.add(self.to_add) @staticmethod - def get_mobject_parts(mobject): + def get_mobject_parts(mobject: Mobject) -> Mobject: # To be implemented in subclass return mobject @staticmethod - def get_mobject_key(mobject): + def get_mobject_key(mobject: Mobject) -> int: # To be implemented in subclass return hash(mobject) @@ -117,11 +124,11 @@ class TransformMatchingShapes(TransformMatchingParts): } @staticmethod - def get_mobject_parts(mobject): + def get_mobject_parts(mobject: VMobject) -> list[VMobject]: return mobject.family_members_with_points() @staticmethod - def get_mobject_key(mobject): + def get_mobject_key(mobject: VMobject) -> int: mobject.save_state() mobject.center() mobject.set_height(1) @@ -137,11 +144,11 @@ class TransformMatchingTex(TransformMatchingParts): } @staticmethod - def get_mobject_parts(mobject): + def get_mobject_parts(mobject: Tex) -> list[SingleStringTex]: return mobject.submobjects @staticmethod - def get_mobject_key(mobject): + def get_mobject_key(mobject: Tex) -> str: return mobject.get_tex() @@ -150,7 +157,7 @@ class TransformMatchingMTex(AnimationGroup): "key_map": dict(), } - def __init__(self, source_mobject, target_mobject, **kwargs): + def __init__(self, source_mobject: MTex, target_mobject: MTex, **kwargs): digest_config(self, kwargs) assert isinstance(source_mobject, MTex) assert isinstance(target_mobject, MTex) diff --git a/manimlib/animation/update.py b/manimlib/animation/update.py index 57856a2f..564c56cf 100644 --- a/manimlib/animation/update.py +++ b/manimlib/animation/update.py @@ -1,7 +1,14 @@ +from __future__ import annotations + import operator as op +from typing import Callable from manimlib.animation.animation import Animation +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from manimlib.mobject.mobject import Mobject + class UpdateFromFunc(Animation): """ @@ -13,21 +20,31 @@ class UpdateFromFunc(Animation): "suspend_mobject_updating": False, } - def __init__(self, mobject, update_function, **kwargs): + def __init__( + self, + mobject: Mobject, + update_function: Callable[[Mobject]], + **kwargs + ): self.update_function = update_function super().__init__(mobject, **kwargs) - def interpolate_mobject(self, alpha): + def interpolate_mobject(self, alpha: float) -> None: self.update_function(self.mobject) class UpdateFromAlphaFunc(UpdateFromFunc): - def interpolate_mobject(self, alpha): + def interpolate_mobject(self, alpha: float) -> None: self.update_function(self.mobject, alpha) class MaintainPositionRelativeTo(Animation): - def __init__(self, mobject, tracked_mobject, **kwargs): + def __init__( + self, + mobject: Mobject, + tracked_mobject: Mobject, + **kwargs + ): self.tracked_mobject = tracked_mobject self.diff = op.sub( mobject.get_center(), @@ -35,7 +52,7 @@ class MaintainPositionRelativeTo(Animation): ) super().__init__(mobject, **kwargs) - def interpolate_mobject(self, alpha): + def interpolate_mobject(self, alpha: float) -> None: target = self.tracked_mobject.get_center() location = self.mobject.get_center() self.mobject.shift(target - location + self.diff) From 854f7cd2bf80b1484133b32a8816d200eb627a9d Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Tue, 15 Feb 2022 18:47:17 +0800 Subject: [PATCH 24/27] fix: remove type alias import in indication.py --- manimlib/animation/indication.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manimlib/animation/indication.py b/manimlib/animation/indication.py index 7a463aaa..b38dc9e9 100644 --- a/manimlib/animation/indication.py +++ b/manimlib/animation/indication.py @@ -14,7 +14,7 @@ from manimlib.animation.creation import ShowCreation from manimlib.animation.creation import ShowPartial from manimlib.animation.fading import FadeOut from manimlib.animation.fading import FadeIn -from manimlib.animation.transform import ManimColor, Transform +from manimlib.animation.transform import Transform from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.geometry import Circle from manimlib.mobject.geometry import Dot From 0e4d4155a3df5c589ea462ca91fa7ce10ddd1e48 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Tue, 15 Feb 2022 20:23:59 +0800 Subject: [PATCH 25/27] workflow: only build wheels for python 3.7+ --- .github/workflows/publish.yml | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 5e81dd94..04d651fc 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -8,6 +8,11 @@ jobs: deploy: runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python: ["py37", "py38", "py39", "py310"] + steps: - uses: actions/checkout@v2 @@ -20,11 +25,13 @@ jobs: run: | python -m pip install --upgrade pip pip install setuptools wheel twine build - - - name: Build and publish + + - name: Build wheels + run: python setup.py bdist_wheel --python-tag ${{ matrix.python }} + + - name: Upload wheels env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} run: | - python -m build twine upload dist/* \ No newline at end of file From 05bee011d27666cd47a46c169a78a100b78158d6 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Wed, 16 Feb 2022 20:37:07 +0800 Subject: [PATCH 26/27] chore: update type hint of SVGMobject --- manimlib/mobject/svg/svg_mobject.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index 1bda6eae..0a185802 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -169,9 +169,9 @@ class SVGMobject(VMobject): mob.shift(vec) return mob - def get_mobject_from(self, shape: se.Shape | se.Text) -> VMobject | None: + def get_mobject_from(self, shape: se.GraphicObject) -> VMobject | None: shape_class_to_func_map: dict[ - type, Callable[[se.Shape | se.Text], VMobject] + type, Callable[[se.GraphicObject], VMobject] ] = { se.Path: self.path_to_mobject, se.SimpleLine: self.line_to_mobject, @@ -196,7 +196,7 @@ class SVGMobject(VMobject): @staticmethod def apply_style_to_mobject( mob: VMobject, - shape: se.Shape | se.Text + shape: se.GraphicObject ) -> VMobject: mob.set_style( stroke_width=shape.stroke_width, @@ -216,7 +216,7 @@ class SVGMobject(VMobject): end=_convert_point_to_3d(line.x2, line.y2) ) - def rect_to_mobject(self, rect: se.Rect) -> Rectangle | RoundedRectangle: + def rect_to_mobject(self, rect: se.Rect) -> Rectangle: if rect.rx == 0 or rect.ry == 0: mob = Rectangle( width=rect.width, From 4fbe948b63f40ff8e03bcbda9d9f0f0a30590e9f Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Wed, 16 Feb 2022 21:08:25 +0800 Subject: [PATCH 27/27] style: insert an empty line after import --- manimlib/animation/animation.py | 1 + manimlib/animation/composition.py | 1 + manimlib/animation/creation.py | 1 + manimlib/animation/fading.py | 1 + manimlib/animation/growing.py | 1 + manimlib/animation/indication.py | 1 + manimlib/animation/movement.py | 1 + manimlib/animation/rotation.py | 1 + manimlib/animation/transform.py | 1 + manimlib/animation/transform_matching_parts.py | 1 + manimlib/animation/update.py | 1 + manimlib/camera/camera.py | 1 + manimlib/mobject/coordinate_systems.py | 1 + manimlib/mobject/matrix.py | 1 + manimlib/mobject/mobject_update_utils.py | 1 + manimlib/mobject/shape_matchers.py | 1 + manimlib/mobject/svg/brace.py | 1 + manimlib/mobject/svg/text_mobject.py | 1 + manimlib/mobject/types/surface.py | 1 + manimlib/mobject/vector_field.py | 1 + manimlib/scene/scene.py | 1 + manimlib/scene/scene_file_writer.py | 1 + manimlib/utils/debug.py | 1 + manimlib/utils/family_ops.py | 1 + manimlib/window.py | 1 + 25 files changed, 25 insertions(+) diff --git a/manimlib/animation/animation.py b/manimlib/animation/animation.py index d1225cd3..8ec26de2 100644 --- a/manimlib/animation/animation.py +++ b/manimlib/animation/animation.py @@ -10,6 +10,7 @@ from manimlib.utils.rate_functions import smooth from manimlib.utils.simple_functions import clip from typing import TYPE_CHECKING + if TYPE_CHECKING: from manimlib.scene.scene import Scene diff --git a/manimlib/animation/composition.py b/manimlib/animation/composition.py index 78a7b024..f282bc9c 100644 --- a/manimlib/animation/composition.py +++ b/manimlib/animation/composition.py @@ -13,6 +13,7 @@ from manimlib.utils.rate_functions import linear from manimlib.utils.simple_functions import clip from typing import TYPE_CHECKING + if TYPE_CHECKING: from manimlib.scene.scene import Scene from manimlib.mobject.mobject import Mobject diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index c4f52149..00588b46 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -15,6 +15,7 @@ from manimlib.utils.rate_functions import double_smooth from manimlib.utils.rate_functions import smooth from typing import TYPE_CHECKING + if TYPE_CHECKING: from manimlib.mobject.mobject import Group diff --git a/manimlib/animation/fading.py b/manimlib/animation/fading.py index 2d246c55..39c149f0 100644 --- a/manimlib/animation/fading.py +++ b/manimlib/animation/fading.py @@ -10,6 +10,7 @@ from manimlib.utils.bezier import interpolate from manimlib.utils.rate_functions import there_and_back from typing import TYPE_CHECKING + if TYPE_CHECKING: from manimlib.scene.scene import Scene from manimlib.mobject.mobject import Mobject diff --git a/manimlib/animation/growing.py b/manimlib/animation/growing.py index 5e0d0eda..1b3c3cd7 100644 --- a/manimlib/animation/growing.py +++ b/manimlib/animation/growing.py @@ -4,6 +4,7 @@ from manimlib.constants import PI from manimlib.animation.transform import Transform from typing import TYPE_CHECKING + if TYPE_CHECKING: import numpy as np from manimlib.mobject.mobject import Mobject diff --git a/manimlib/animation/indication.py b/manimlib/animation/indication.py index b38dc9e9..5f210773 100644 --- a/manimlib/animation/indication.py +++ b/manimlib/animation/indication.py @@ -30,6 +30,7 @@ from manimlib.utils.rate_functions import smooth from manimlib.utils.rate_functions import squish_rate_func from typing import TYPE_CHECKING + if TYPE_CHECKING: import colour from manimlib.mobject.mobject import Mobject diff --git a/manimlib/animation/movement.py b/manimlib/animation/movement.py index 2d157a05..78cbbee8 100644 --- a/manimlib/animation/movement.py +++ b/manimlib/animation/movement.py @@ -6,6 +6,7 @@ from manimlib.animation.animation import Animation from manimlib.utils.rate_functions import linear from typing import TYPE_CHECKING + if TYPE_CHECKING: import numpy as np from manimlib.mobject.mobject import Mobject diff --git a/manimlib/animation/rotation.py b/manimlib/animation/rotation.py index 8c9b7ddf..7993c3cf 100644 --- a/manimlib/animation/rotation.py +++ b/manimlib/animation/rotation.py @@ -9,6 +9,7 @@ from manimlib.utils.rate_functions import linear from manimlib.utils.rate_functions import smooth from typing import TYPE_CHECKING + if TYPE_CHECKING: import numpy as np from manimlib.mobject.mobject import Mobject diff --git a/manimlib/animation/transform.py b/manimlib/animation/transform.py index 76c36978..a426b21e 100644 --- a/manimlib/animation/transform.py +++ b/manimlib/animation/transform.py @@ -19,6 +19,7 @@ from manimlib.utils.rate_functions import smooth from manimlib.utils.rate_functions import squish_rate_func from typing import TYPE_CHECKING + if TYPE_CHECKING: import colour from manimlib.scene.scene import Scene diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index c1dd2da2..90ffa76f 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -19,6 +19,7 @@ from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import remove_list_redundancies from typing import TYPE_CHECKING + if TYPE_CHECKING: from manimlib.scene.scene import Scene from manimlib.mobject.svg.tex_mobject import Tex, SingleStringTex diff --git a/manimlib/animation/update.py b/manimlib/animation/update.py index 564c56cf..43fafa42 100644 --- a/manimlib/animation/update.py +++ b/manimlib/animation/update.py @@ -6,6 +6,7 @@ from typing import Callable from manimlib.animation.animation import Animation from typing import TYPE_CHECKING + if TYPE_CHECKING: from manimlib.mobject.mobject import Mobject diff --git a/manimlib/camera/camera.py b/manimlib/camera/camera.py index 94b09b37..59b43f7b 100644 --- a/manimlib/camera/camera.py +++ b/manimlib/camera/camera.py @@ -22,6 +22,7 @@ from manimlib.utils.space_ops import quaternion_from_angle_axis from manimlib.utils.space_ops import quaternion_mult from typing import TYPE_CHECKING + if TYPE_CHECKING: from manimlib.shader_wrapper import ShaderWrapper diff --git a/manimlib/mobject/coordinate_systems.py b/manimlib/mobject/coordinate_systems.py index 80355e08..709764be 100644 --- a/manimlib/mobject/coordinate_systems.py +++ b/manimlib/mobject/coordinate_systems.py @@ -23,6 +23,7 @@ from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import rotate_vector from typing import TYPE_CHECKING + if TYPE_CHECKING: import colour from manimlib.mobject.mobject import Mobject diff --git a/manimlib/mobject/matrix.py b/manimlib/mobject/matrix.py index 3133d992..b53cba0a 100644 --- a/manimlib/mobject/matrix.py +++ b/manimlib/mobject/matrix.py @@ -16,6 +16,7 @@ from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from typing import TYPE_CHECKING + if TYPE_CHECKING: import colour from manimlib.mobject.mobject import Mobject diff --git a/manimlib/mobject/mobject_update_utils.py b/manimlib/mobject/mobject_update_utils.py index 1a698888..b32a0ff3 100644 --- a/manimlib/mobject/mobject_update_utils.py +++ b/manimlib/mobject/mobject_update_utils.py @@ -9,6 +9,7 @@ from manimlib.mobject.mobject import Mobject from manimlib.utils.simple_functions import clip from typing import TYPE_CHECKING + if TYPE_CHECKING: import numpy as np from manimlib.animation.animation import Animation diff --git a/manimlib/mobject/shape_matchers.py b/manimlib/mobject/shape_matchers.py index 62c93e34..a1ffe5fd 100644 --- a/manimlib/mobject/shape_matchers.py +++ b/manimlib/mobject/shape_matchers.py @@ -10,6 +10,7 @@ from manimlib.utils.customization import get_customization from manimlib.utils.config_ops import digest_config from typing import TYPE_CHECKING + if TYPE_CHECKING: from typing import Union, Sequence from manimlib.mobject.mobject import Mobject diff --git a/manimlib/mobject/svg/brace.py b/manimlib/mobject/svg/brace.py index f9d96cec..659b0604 100644 --- a/manimlib/mobject/svg/brace.py +++ b/manimlib/mobject/svg/brace.py @@ -19,6 +19,7 @@ from manimlib.utils.config_ops import digest_config from manimlib.utils.space_ops import get_norm from typing import TYPE_CHECKING + if TYPE_CHECKING: from manimlib.mobject.mobject import Mobject from manimlib.animation.animation import Animation diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index bd67696b..dcc37c59 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -26,6 +26,7 @@ from manimlib.utils.directories import get_downloads_dir, get_text_dir from manimpango import PangoUtils, TextSetting, MarkupUtils from typing import TYPE_CHECKING + if TYPE_CHECKING: import colour from manimlib.mobject.types.vectorized_mobject import VMobject diff --git a/manimlib/mobject/types/surface.py b/manimlib/mobject/types/surface.py index 01eb912f..cc3e32d5 100644 --- a/manimlib/mobject/types/surface.py +++ b/manimlib/mobject/types/surface.py @@ -15,6 +15,7 @@ from manimlib.utils.iterables import listify from manimlib.utils.space_ops import normalize_along_axis from typing import TYPE_CHECKING + if TYPE_CHECKING: from manimlib.camera.camera import Camera diff --git a/manimlib/mobject/vector_field.py b/manimlib/mobject/vector_field.py index 1300c2c5..b17b55de 100644 --- a/manimlib/mobject/vector_field.py +++ b/manimlib/mobject/vector_field.py @@ -23,6 +23,7 @@ from manimlib.utils.simple_functions import sigmoid from manimlib.utils.space_ops import get_norm from typing import TYPE_CHECKING + if TYPE_CHECKING: from manimlib.mobject.mobject import Mobject from manimlib.mobject.coordinate_systems import CoordinateSystem diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index 71cda435..ed9233d5 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -27,6 +27,7 @@ from manimlib.event_handler import EVENT_DISPATCHER from manimlib.logger import log from typing import TYPE_CHECKING + if TYPE_CHECKING: from PIL.Image import Image from manimlib.animation.animation import Animation diff --git a/manimlib/scene/scene_file_writer.py b/manimlib/scene/scene_file_writer.py index d3d4ee29..cb948ab5 100644 --- a/manimlib/scene/scene_file_writer.py +++ b/manimlib/scene/scene_file_writer.py @@ -19,6 +19,7 @@ from manimlib.utils.sounds import get_full_sound_file_path from manimlib.logger import log from typing import TYPE_CHECKING + if TYPE_CHECKING: from manimlib.scene.scene import Scene from manimlib.camera.camera import Camera diff --git a/manimlib/utils/debug.py b/manimlib/utils/debug.py index 308d4906..29aa6a3c 100644 --- a/manimlib/utils/debug.py +++ b/manimlib/utils/debug.py @@ -10,6 +10,7 @@ from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.logger import log from typing import TYPE_CHECKING + if TYPE_CHECKING: from manimlib.mobject.mobject import Mobject diff --git a/manimlib/utils/family_ops.py b/manimlib/utils/family_ops.py index 3218fab6..fc1a8b93 100644 --- a/manimlib/utils/family_ops.py +++ b/manimlib/utils/family_ops.py @@ -4,6 +4,7 @@ import itertools as it from typing import Iterable from typing import TYPE_CHECKING + if TYPE_CHECKING: from manimlib.mobject.mobject import Mobject diff --git a/manimlib/window.py b/manimlib/window.py index 5785283b..0d9d3a47 100644 --- a/manimlib/window.py +++ b/manimlib/window.py @@ -10,6 +10,7 @@ from manimlib.utils.config_ops import digest_config from manimlib.utils.customization import get_customization from typing import TYPE_CHECKING + if TYPE_CHECKING: from manimlib.scene.scene import Scene