Merge pull request #1736 from TonyCrane/master

Add type hints according to PEP 484 and PEP 604
This commit is contained in:
Grant Sanderson 2022-03-22 11:05:10 -07:00 committed by GitHub
commit 0c8b333a42
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
62 changed files with 2804 additions and 1381 deletions

View file

@ -11,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python: ["py36", "py37", "py38", "py39", "py310"]
python: ["py37", "py38", "py39", "py310"]
steps:
- uses: actions/checkout@v2

View file

@ -19,7 +19,7 @@ Note, there are two versions of manim. This repository began as a personal proj
>
> **Note**: To install manim directly through pip, please pay attention to the name of the installed package. This repository is ManimGL of 3b1b. The package name is `manimgl` instead of `manim` or `manimlib`. Please use `pip install manimgl` to install the version in this repository.
Manim runs on Python 3.6 or higher (Python 3.8 is recommended).
Manim runs on Python 3.7 or higher.
System requirements are [FFmpeg](https://ffmpeg.org/), [OpenGL](https://www.opengl.org/) and [LaTeX](https://www.latex-project.org) (optional, if you want to use LaTeX).
For Linux, [Pango](https://pango.gnome.org) along with its development headers are required. See instruction [here](https://github.com/ManimCommunity/ManimPango#building).

View file

@ -1,7 +1,7 @@
Installation
============
Manim runs on Python 3.6 or higher (Python 3.8 is recommended).
Manim runs on Python 3.7 or higher.
System requirements are

View file

@ -1,4 +1,7 @@
from __future__ import annotations
from copy import deepcopy
from typing import Callable
from manimlib.mobject.mobject import _AnimationBuilder
from manimlib.mobject.mobject import Mobject
@ -6,6 +9,11 @@ from manimlib.utils.config_ops import digest_config
from manimlib.utils.rate_functions import smooth
from manimlib.utils.simple_functions import clip
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.scene.scene import Scene
DEFAULT_ANIMATION_RUN_TIME = 1.0
DEFAULT_ANIMATION_LAG_RATIO = 0
@ -29,17 +37,17 @@ class Animation(object):
"suspend_mobject_updating": True,
}
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
assert(isinstance(mobject, Mobject))
digest_config(self, kwargs)
self.mobject = mobject
def __str__(self):
def __str__(self) -> str:
if self.name:
return self.name
return self.__class__.__name__ + str(self.mobject)
def begin(self):
def begin(self) -> None:
# This is called right as an animation is being
# played. As much initialization as possible,
# especially any mobject copying, should live in
@ -56,32 +64,32 @@ class Animation(object):
self.families = list(self.get_all_families_zipped())
self.interpolate(0)
def finish(self):
def finish(self) -> None:
self.interpolate(self.final_alpha_value)
if self.suspend_mobject_updating:
self.mobject.resume_updating()
def clean_up_from_scene(self, scene):
def clean_up_from_scene(self, scene: Scene) -> None:
if self.is_remover():
scene.remove(self.mobject)
def create_starting_mobject(self):
def create_starting_mobject(self) -> Mobject:
# Keep track of where the mobject starts
return self.mobject.copy()
def get_all_mobjects(self):
def get_all_mobjects(self) -> tuple[Mobject, Mobject]:
"""
Ordering must match the ording of arguments to interpolate_submobject
"""
return self.mobject, self.starting_mobject
def get_all_families_zipped(self):
def get_all_families_zipped(self) -> zip[tuple[Mobject]]:
return zip(*[
mob.get_family()
for mob in self.get_all_mobjects()
])
def update_mobjects(self, dt):
def update_mobjects(self, dt: float) -> None:
"""
Updates things like starting_mobject, and (for
Transforms) target_mobject. Note, since typically
@ -92,7 +100,7 @@ class Animation(object):
for mob in self.get_all_mobjects_to_update():
mob.update(dt)
def get_all_mobjects_to_update(self):
def get_all_mobjects_to_update(self) -> list[Mobject]:
# The surrounding scene typically handles
# updating of self.mobject. Besides, in
# most cases its updating is suspended anyway
@ -109,27 +117,37 @@ class Animation(object):
return self
# Methods for interpolation, the mean of an Animation
def interpolate(self, alpha):
def interpolate(self, alpha: float) -> None:
alpha = clip(alpha, 0, 1)
self.interpolate_mobject(self.rate_func(alpha))
def update(self, alpha):
def update(self, alpha: float) -> None:
"""
This method shouldn't exist, but it's here to
keep many old scenes from breaking
"""
self.interpolate(alpha)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
for i, mobs in enumerate(self.families):
sub_alpha = self.get_sub_alpha(alpha, i, len(self.families))
self.interpolate_submobject(*mobs, sub_alpha)
def interpolate_submobject(self, submobject, starting_sumobject, alpha):
def interpolate_submobject(
self,
submobject: Mobject,
starting_submobject: Mobject,
alpha: float
):
# Typically ipmlemented by subclass
pass
def get_sub_alpha(self, alpha, index, num_submobjects):
def get_sub_alpha(
self,
alpha: float,
index: int,
num_submobjects: int
) -> float:
# TODO, make this more understanable, and/or combine
# its functionality with AnimationGroup's method
# build_animations_with_timings
@ -140,29 +158,29 @@ class Animation(object):
return clip((value - lower), 0, 1)
# Getters and setters
def set_run_time(self, run_time):
def set_run_time(self, run_time: float):
self.run_time = run_time
return self
def get_run_time(self):
def get_run_time(self) -> float:
return self.run_time
def set_rate_func(self, rate_func):
def set_rate_func(self, rate_func: Callable[[float], float]):
self.rate_func = rate_func
return self
def get_rate_func(self):
def get_rate_func(self) -> Callable[[float], float]:
return self.rate_func
def set_name(self, name):
def set_name(self, name: str):
self.name = name
return self
def is_remover(self):
def is_remover(self) -> bool:
return self.remover
def prepare_animation(anim):
def prepare_animation(anim: Animation | _AnimationBuilder):
if isinstance(anim, _AnimationBuilder):
return anim.build()

View file

@ -1,4 +1,7 @@
from __future__ import annotations
import numpy as np
from typing import Callable
from manimlib.animation.animation import Animation, prepare_animation
from manimlib.mobject.mobject import Group
@ -9,6 +12,12 @@ from manimlib.utils.iterables import remove_list_redundancies
from manimlib.utils.rate_functions import linear
from manimlib.utils.simple_functions import clip
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.scene.scene import Scene
from manimlib.mobject.mobject import Mobject
DEFAULT_LAGGED_START_LAG_RATIO = 0.05
@ -27,7 +36,7 @@ class AnimationGroup(Animation):
"group": None,
}
def __init__(self, *animations, **kwargs):
def __init__(self, *animations: Animation, **kwargs):
digest_config(self, kwargs)
self.animations = [prepare_animation(anim) for anim in animations]
if self.group is None:
@ -37,27 +46,27 @@ class AnimationGroup(Animation):
self.init_run_time()
Animation.__init__(self, self.group, **kwargs)
def get_all_mobjects(self):
def get_all_mobjects(self) -> Group:
return self.group
def begin(self):
def begin(self) -> None:
for anim in self.animations:
anim.begin()
# self.init_run_time()
def finish(self):
def finish(self) -> None:
for anim in self.animations:
anim.finish()
def clean_up_from_scene(self, scene):
def clean_up_from_scene(self, scene: Scene) -> None:
for anim in self.animations:
anim.clean_up_from_scene(scene)
def update_mobjects(self, dt):
def update_mobjects(self, dt: float) -> None:
for anim in self.animations:
anim.update_mobjects(dt)
def init_run_time(self):
def init_run_time(self) -> None:
self.build_animations_with_timings()
if self.anims_with_timings:
self.max_end_time = np.max([
@ -68,7 +77,7 @@ class AnimationGroup(Animation):
if self.run_time is None:
self.run_time = self.max_end_time
def build_animations_with_timings(self):
def build_animations_with_timings(self) -> None:
"""
Creates a list of triplets of the form
(anim, start_time, end_time)
@ -87,7 +96,7 @@ class AnimationGroup(Animation):
start_time, end_time, self.lag_ratio
)
def interpolate(self, alpha):
def interpolate(self, alpha: float) -> None:
# Note, if the run_time of AnimationGroup has been
# set to something other than its default, these
# times might not correspond to actual times,
@ -111,19 +120,19 @@ class Succession(AnimationGroup):
"lag_ratio": 1,
}
def begin(self):
def begin(self) -> None:
assert(len(self.animations) > 0)
self.init_run_time()
self.active_animation = self.animations[0]
self.active_animation.begin()
def finish(self):
def finish(self) -> None:
self.active_animation.finish()
def update_mobjects(self, dt):
def update_mobjects(self, dt: float) -> None:
self.active_animation.update_mobjects(dt)
def interpolate(self, alpha):
def interpolate(self, alpha: float) -> None:
index, subalpha = integer_interpolate(
0, len(self.animations), alpha
)
@ -146,7 +155,13 @@ class LaggedStartMap(LaggedStart):
"run_time": 2,
}
def __init__(self, AnimationClass, mobject, arg_creator=None, **kwargs):
def __init__(
self,
AnimationClass: type,
mobject: Mobject,
arg_creator: Callable[[Mobject], tuple] | None = None,
**kwargs
):
args_list = []
for submob in mobject:
if arg_creator:

View file

@ -1,3 +1,10 @@
from __future__ import annotations
import itertools as it
from abc import abstractmethod
import numpy as np
from manimlib.animation.animation import Animation
from manimlib.animation.composition import Succession
from manimlib.mobject.types.vectorized_mobject import VMobject
@ -7,8 +14,10 @@ from manimlib.utils.rate_functions import linear
from manimlib.utils.rate_functions import double_smooth
from manimlib.utils.rate_functions import smooth
import numpy as np
import itertools as it
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.mobject import Group
class ShowPartial(Animation):
@ -19,21 +28,27 @@ class ShowPartial(Animation):
"should_match_start": False,
}
def begin(self):
def begin(self) -> None:
super().begin()
if not self.should_match_start:
self.mobject.lock_matching_data(self.mobject, self.starting_mobject)
def finish(self):
def finish(self) -> None:
super().finish()
self.mobject.unlock_data()
def interpolate_submobject(self, submob, start_submob, alpha):
def interpolate_submobject(
self,
submob: VMobject,
start_submob: VMobject,
alpha: float
) -> None:
submob.pointwise_become_partial(
start_submob, *self.get_bounds(alpha)
)
def get_bounds(self, alpha):
@abstractmethod
def get_bounds(self, alpha: float) -> tuple[float, float]:
raise Exception("Not Implemented")
@ -42,7 +57,7 @@ class ShowCreation(ShowPartial):
"lag_ratio": 1,
}
def get_bounds(self, alpha):
def get_bounds(self, alpha: float) -> tuple[float, float]:
return (0, alpha)
@ -64,7 +79,7 @@ class DrawBorderThenFill(Animation):
"fill_animation_config": {},
}
def __init__(self, vmobject, **kwargs):
def __init__(self, vmobject: VMobject, **kwargs):
assert(isinstance(vmobject, VMobject))
self.sm_to_index = dict([
(hash(sm), 0)
@ -72,7 +87,7 @@ class DrawBorderThenFill(Animation):
])
super().__init__(vmobject, **kwargs)
def begin(self):
def begin(self) -> None:
# Trigger triangulation calculation
for submob in self.mobject.get_family():
submob.get_triangulation()
@ -82,11 +97,11 @@ class DrawBorderThenFill(Animation):
self.mobject.match_style(self.outline)
self.mobject.lock_matching_data(self.mobject, self.outline)
def finish(self):
def finish(self) -> None:
super().finish()
self.mobject.unlock_data()
def get_outline(self):
def get_outline(self) -> VMobject:
outline = self.mobject.copy()
outline.set_fill(opacity=0)
for sm in outline.get_family():
@ -96,17 +111,23 @@ class DrawBorderThenFill(Animation):
)
return outline
def get_stroke_color(self, vmobject):
def get_stroke_color(self, vmobject: VMobject) -> str:
if self.stroke_color:
return self.stroke_color
elif vmobject.get_stroke_width() > 0:
return vmobject.get_stroke_color()
return vmobject.get_color()
def get_all_mobjects(self):
def get_all_mobjects(self) -> list[VMobject]:
return [*super().get_all_mobjects(), self.outline]
def interpolate_submobject(self, submob, start, outline, alpha):
def interpolate_submobject(
self,
submob: VMobject,
start: VMobject,
outline: VMobject,
alpha: float
) -> None:
index, subalpha = integer_interpolate(0, 2, alpha)
if index == 1 and self.sm_to_index[hash(submob)] == 0:
@ -133,13 +154,13 @@ class Write(DrawBorderThenFill):
"rate_func": linear,
}
def __init__(self, mobject, **kwargs):
def __init__(self, vmobject: VMobject, **kwargs):
digest_config(self, kwargs)
self.set_default_config_from_length(mobject)
super().__init__(mobject, **kwargs)
self.set_default_config_from_length(vmobject)
super().__init__(vmobject, **kwargs)
def set_default_config_from_length(self, mobject):
length = len(mobject.family_members_with_points())
def set_default_config_from_length(self, vmobject: VMobject) -> None:
length = len(vmobject.family_members_with_points())
if self.run_time is None:
if length < 15:
self.run_time = 1
@ -155,16 +176,16 @@ class ShowIncreasingSubsets(Animation):
"int_func": np.round,
}
def __init__(self, group, **kwargs):
def __init__(self, group: Group, **kwargs):
self.all_submobs = list(group.submobjects)
super().__init__(group, **kwargs)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
n_submobs = len(self.all_submobs)
index = int(self.int_func(alpha * n_submobs))
self.update_submobject_list(index)
def update_submobject_list(self, index):
def update_submobject_list(self, index: int) -> None:
self.mobject.set_submobjects(self.all_submobs[:index])
@ -173,7 +194,7 @@ class ShowSubmobjectsOneByOne(ShowIncreasingSubsets):
"int_func": np.ceil,
}
def update_submobject_list(self, index):
def update_submobject_list(self, index: int) -> None:
# N = len(self.all_submobs)
if index == 0:
self.mobject.set_submobjects([])

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import numpy as np
from manimlib.animation.animation import Animation
@ -7,6 +9,13 @@ from manimlib.constants import ORIGIN
from manimlib.utils.bezier import interpolate
from manimlib.utils.rate_functions import there_and_back
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.scene.scene import Scene
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.types.vectorized_mobject import VMobject
DEFAULT_FADE_LAG_RATIO = 0
@ -16,7 +25,13 @@ class Fade(Transform):
"lag_ratio": DEFAULT_FADE_LAG_RATIO,
}
def __init__(self, mobject, shift=ORIGIN, scale=1, **kwargs):
def __init__(
self,
mobject: Mobject,
shift: np.ndarray = ORIGIN,
scale: float = 1,
**kwargs
):
self.shift_vect = shift
self.scale_factor = scale
super().__init__(mobject, **kwargs)
@ -27,10 +42,10 @@ class FadeIn(Fade):
"lag_ratio": DEFAULT_FADE_LAG_RATIO,
}
def create_target(self):
def create_target(self) -> Mobject:
return self.mobject
def create_starting_mobject(self):
def create_starting_mobject(self) -> Mobject:
start = super().create_starting_mobject()
start.set_opacity(0)
start.scale(1.0 / self.scale_factor)
@ -45,7 +60,7 @@ class FadeOut(Fade):
"final_alpha_value": 0,
}
def create_target(self):
def create_target(self) -> Mobject:
result = self.mobject.copy()
result.set_opacity(0)
result.shift(self.shift_vect)
@ -54,7 +69,7 @@ class FadeOut(Fade):
class FadeInFromPoint(FadeIn):
def __init__(self, mobject, point, **kwargs):
def __init__(self, mobject: Mobject, point: np.ndarray, **kwargs):
super().__init__(
mobject,
shift=mobject.get_center() - point,
@ -64,7 +79,7 @@ class FadeInFromPoint(FadeIn):
class FadeOutToPoint(FadeOut):
def __init__(self, mobject, point, **kwargs):
def __init__(self, mobject: Mobject, point: np.ndarray, **kwargs):
super().__init__(
mobject,
shift=point - mobject.get_center(),
@ -79,7 +94,7 @@ class FadeTransform(Transform):
"dim_to_match": 1,
}
def __init__(self, mobject, target_mobject, **kwargs):
def __init__(self, mobject: Mobject, target_mobject: Mobject, **kwargs):
self.to_add_on_completion = target_mobject
mobject.save_state()
super().__init__(
@ -87,7 +102,7 @@ class FadeTransform(Transform):
**kwargs
)
def begin(self):
def begin(self) -> None:
self.ending_mobject = self.mobject.copy()
Animation.begin(self)
# Both 'start' and 'end' consists of the source and target mobjects.
@ -97,21 +112,21 @@ class FadeTransform(Transform):
for m0, m1 in ((start[1], start[0]), (end[0], end[1])):
self.ghost_to(m0, m1)
def ghost_to(self, source, target):
def ghost_to(self, source: Mobject, target: Mobject) -> None:
source.replace(target, stretch=self.stretch, dim_to_match=self.dim_to_match)
source.set_opacity(0)
def get_all_mobjects(self):
def get_all_mobjects(self) -> list[Mobject]:
return [
self.mobject,
self.starting_mobject,
self.ending_mobject,
]
def get_all_families_zipped(self):
def get_all_families_zipped(self) -> zip[tuple[Mobject]]:
return Animation.get_all_families_zipped(self)
def clean_up_from_scene(self, scene):
def clean_up_from_scene(self, scene: Scene) -> None:
Animation.clean_up_from_scene(self, scene)
scene.remove(self.mobject)
self.mobject[0].restore()
@ -119,11 +134,11 @@ class FadeTransform(Transform):
class FadeTransformPieces(FadeTransform):
def begin(self):
def begin(self) -> None:
self.mobject[0].align_family(self.mobject[1])
super().begin()
def ghost_to(self, source, target):
def ghost_to(self, source: Mobject, target: Mobject) -> None:
for sm0, sm1 in zip(source.get_family(), target.get_family()):
super().ghost_to(sm0, sm1)
@ -136,7 +151,12 @@ class VFadeIn(Animation):
"suspend_mobject_updating": False,
}
def interpolate_submobject(self, submob, start, alpha):
def interpolate_submobject(
self,
submob: VMobject,
start: VMobject,
alpha: float
) -> None:
submob.set_stroke(
opacity=interpolate(0, start.get_stroke_opacity(), alpha)
)
@ -152,7 +172,12 @@ class VFadeOut(VFadeIn):
"final_alpha_value": 0,
}
def interpolate_submobject(self, submob, start, alpha):
def interpolate_submobject(
self,
submob: VMobject,
start: VMobject,
alpha: float
) -> None:
super().interpolate_submobject(submob, start, 1 - alpha)

View file

@ -1,6 +1,14 @@
from manimlib.animation.transform import Transform
# from manimlib.utils.paths import counterclockwise_path
from __future__ import annotations
from manimlib.constants import PI
from manimlib.animation.transform import Transform
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy as np
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.geometry import Arrow
class GrowFromPoint(Transform):
@ -8,14 +16,14 @@ class GrowFromPoint(Transform):
"point_color": None,
}
def __init__(self, mobject, point, **kwargs):
def __init__(self, mobject: Mobject, point: np.ndarray, **kwargs):
self.point = point
super().__init__(mobject, **kwargs)
def create_target(self):
def create_target(self) -> Mobject:
return self.mobject
def create_starting_mobject(self):
def create_starting_mobject(self) -> Mobject:
start = super().create_starting_mobject()
start.scale(0)
start.move_to(self.point)
@ -25,19 +33,19 @@ class GrowFromPoint(Transform):
class GrowFromCenter(GrowFromPoint):
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
point = mobject.get_center()
super().__init__(mobject, point, **kwargs)
class GrowFromEdge(GrowFromPoint):
def __init__(self, mobject, edge, **kwargs):
def __init__(self, mobject: Mobject, edge: np.ndarray, **kwargs):
point = mobject.get_bounding_box_point(edge)
super().__init__(mobject, point, **kwargs)
class GrowArrow(GrowFromPoint):
def __init__(self, arrow, **kwargs):
def __init__(self, arrow: Arrow, **kwargs):
point = arrow.get_start()
super().__init__(arrow, point, **kwargs)

View file

@ -1,5 +1,9 @@
import numpy as np
from __future__ import annotations
import math
from typing import Union, Sequence
import numpy as np
from manimlib.constants import *
from manimlib.animation.animation import Animation
@ -25,6 +29,13 @@ from manimlib.utils.rate_functions import wiggle
from manimlib.utils.rate_functions import smooth
from manimlib.utils.rate_functions import squish_rate_func
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import colour
from manimlib.mobject.mobject import Mobject
ManimColor = Union[str, colour.Color, Sequence[float]]
class FocusOn(Transform):
CONFIG = {
@ -34,13 +45,13 @@ class FocusOn(Transform):
"remover": True,
}
def __init__(self, focus_point, **kwargs):
def __init__(self, focus_point: np.ndarray, **kwargs):
self.focus_point = focus_point
# Initialize with blank mobject, while create_target
# and create_starting_mobject handle the meat
super().__init__(VMobject(), **kwargs)
def create_target(self):
def create_target(self) -> Dot:
little_dot = Dot(radius=0)
little_dot.set_fill(self.color, opacity=self.opacity)
little_dot.add_updater(
@ -48,7 +59,7 @@ class FocusOn(Transform):
)
return little_dot
def create_starting_mobject(self):
def create_starting_mobject(self) -> Dot:
return Dot(
radius=FRAME_X_RADIUS + FRAME_Y_RADIUS,
stroke_width=0,
@ -64,7 +75,7 @@ class Indicate(Transform):
"color": YELLOW,
}
def create_target(self):
def create_target(self) -> Mobject:
target = self.mobject.copy()
target.scale(self.scale_factor)
target.set_color(self.color)
@ -80,7 +91,12 @@ class Flash(AnimationGroup):
"run_time": 1,
}
def __init__(self, point, color=YELLOW, **kwargs):
def __init__(
self,
point: np.ndarray,
color: ManimColor = YELLOW,
**kwargs
):
self.point = point
self.color = color
digest_config(self, kwargs)
@ -92,7 +108,7 @@ class Flash(AnimationGroup):
**kwargs,
)
def create_lines(self):
def create_lines(self) -> VGroup:
lines = VGroup()
for angle in np.arange(0, TAU, TAU / self.num_lines):
line = Line(ORIGIN, self.line_length * RIGHT)
@ -106,7 +122,7 @@ class Flash(AnimationGroup):
lines.add_updater(lambda l: l.move_to(self.point))
return lines
def create_line_anims(self):
def create_line_anims(self) -> list[Animation]:
return [
ShowCreationThenDestruction(line)
for line in self.lines
@ -122,17 +138,17 @@ class CircleIndicate(Indicate):
},
}
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
digest_config(self, kwargs)
circle = self.get_circle(mobject)
super().__init__(circle, **kwargs)
def get_circle(self, mobject):
def get_circle(self, mobject: Mobject) -> Circle:
circle = Circle(**self.circle_config)
circle.add_updater(lambda c: c.surround(mobject))
return circle
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
super().interpolate_mobject(alpha)
self.mobject.set_stroke(opacity=alpha)
@ -143,7 +159,7 @@ class ShowPassingFlash(ShowPartial):
"remover": True,
}
def get_bounds(self, alpha):
def get_bounds(self, alpha: float) -> tuple[float, float]:
tw = self.time_width
upper = interpolate(0, 1 + tw, alpha)
lower = upper - tw
@ -151,7 +167,7 @@ class ShowPassingFlash(ShowPartial):
lower = max(lower, 0)
return (lower, upper)
def finish(self):
def finish(self) -> None:
super().finish()
for submob, start in self.get_all_families_zipped():
submob.pointwise_become_partial(start, 0, 1)
@ -164,7 +180,7 @@ class VShowPassingFlash(Animation):
"remover": True,
}
def begin(self):
def begin(self) -> None:
self.mobject.align_stroke_width_data_to_points()
# Compute an array of stroke widths for each submobject
# which tapers out at either end
@ -184,7 +200,12 @@ class VShowPassingFlash(Animation):
self.submob_to_anchor_widths[hash(sm)] = anchor_widths * taper_array
super().begin()
def interpolate_submobject(self, submobject, starting_sumobject, alpha):
def interpolate_submobject(
self,
submobject: VMobject,
starting_sumobject: None,
alpha: float
) -> None:
anchor_widths = self.submob_to_anchor_widths[hash(submobject)]
# Create a gaussian such that 3 sigmas out on either side
# will equals time_width
@ -206,7 +227,7 @@ class VShowPassingFlash(Animation):
new_widths[1::3] = (new_widths[0::3] + new_widths[2::3]) / 2
submobject.set_stroke(width=new_widths)
def finish(self):
def finish(self) -> None:
super().finish()
for submob, start in self.get_all_families_zipped():
submob.match_style(start)
@ -221,7 +242,7 @@ class FlashAround(VShowPassingFlash):
"n_inserted_curves": 20,
}
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
digest_config(self, kwargs)
path = self.get_path(mobject)
if mobject.is_fixed_in_frame:
@ -231,12 +252,12 @@ class FlashAround(VShowPassingFlash):
path.set_stroke(self.color, self.stroke_width)
super().__init__(path, **kwargs)
def get_path(self, mobject):
def get_path(self, mobject: Mobject) -> SurroundingRectangle:
return SurroundingRectangle(mobject, buff=self.buff)
class FlashUnder(FlashAround):
def get_path(self, mobject):
def get_path(self, mobject: Mobject) -> Underline:
return Underline(mobject, buff=self.buff)
@ -252,7 +273,7 @@ class ShowCreationThenFadeOut(Succession):
"remover": True,
}
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
super().__init__(
ShowCreation(mobject),
FadeOut(mobject),
@ -269,7 +290,7 @@ class AnimationOnSurroundingRectangle(AnimationGroup):
"rect_animation": Animation
}
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
digest_config(self, kwargs)
if "surrounding_rectangle_config" in kwargs:
kwargs.pop("surrounding_rectangle_config")
@ -282,7 +303,7 @@ class AnimationOnSurroundingRectangle(AnimationGroup):
self.rect_animation(rect, **kwargs),
)
def get_rect(self):
def get_rect(self) -> SurroundingRectangle:
return SurroundingRectangle(
self.mobject_to_surround,
**self.surrounding_rectangle_config
@ -314,7 +335,7 @@ class ApplyWave(Homotopy):
"run_time": 1,
}
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
digest_config(self, kwargs, locals())
left_x = mobject.get_left()[0]
right_x = mobject.get_right()[0]
@ -339,15 +360,20 @@ class WiggleOutThenIn(Animation):
"rotate_about_point": None,
}
def get_scale_about_point(self):
def get_scale_about_point(self) -> np.ndarray:
if self.scale_about_point is None:
return self.mobject.get_center()
def get_rotate_about_point(self):
def get_rotate_about_point(self) -> np.ndarray:
if self.rotate_about_point is None:
return self.mobject.get_center()
def interpolate_submobject(self, submobject, starting_sumobject, alpha):
def interpolate_submobject(
self,
submobject: Mobject,
starting_sumobject: Mobject,
alpha: float
) -> None:
submobject.match_points(starting_sumobject)
submobject.scale(
interpolate(1, self.scale_value, there_and_back(alpha)),
@ -364,7 +390,7 @@ class TurnInsideOut(Transform):
"path_arc": TAU / 4,
}
def create_target(self):
def create_target(self) -> Mobject:
return self.mobject.copy().reverse_points()
@ -373,7 +399,7 @@ class FlashyFadeIn(AnimationGroup):
"fade_lag": 0,
}
def __init__(self, vmobject, stroke_width=2, **kwargs):
def __init__(self, vmobject: VMobject, stroke_width: float = 2, **kwargs):
digest_config(self, kwargs)
outline = vmobject.copy()
outline.set_fill(opacity=0)

View file

@ -1,6 +1,16 @@
from __future__ import annotations
from typing import Callable, Sequence
from manimlib.animation.animation import Animation
from manimlib.utils.rate_functions import linear
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy as np
from manimlib.mobject.mobject import Mobject
class Homotopy(Animation):
CONFIG = {
@ -8,7 +18,12 @@ class Homotopy(Animation):
"apply_function_kwargs": {},
}
def __init__(self, homotopy, mobject, **kwargs):
def __init__(
self,
homotopy: Callable[[float, float, float, float], Sequence[float]],
mobject: Mobject,
**kwargs
):
"""
Homotopy is a function from
(x, y, z, t) to (x', y', z')
@ -16,10 +31,18 @@ class Homotopy(Animation):
self.homotopy = homotopy
super().__init__(mobject, **kwargs)
def function_at_time_t(self, t):
def function_at_time_t(
self,
t: float
) -> Callable[[np.ndarray], Sequence[float]]:
return lambda p: self.homotopy(*p, t)
def interpolate_submobject(self, submob, start, alpha):
def interpolate_submobject(
self,
submob: Mobject,
start: Mobject,
alpha: float
) -> None:
submob.match_points(start)
submob.apply_function(
self.function_at_time_t(alpha),
@ -34,7 +57,12 @@ class SmoothedVectorizedHomotopy(Homotopy):
class ComplexHomotopy(Homotopy):
def __init__(self, complex_homotopy, mobject, **kwargs):
def __init__(
self,
complex_homotopy: Callable[[complex, float], Sequence[float]],
mobject: Mobject,
**kwargs
):
"""
Given a function form (z, t) -> w, where z and w
are complex numbers and t is time, this animates
@ -53,11 +81,16 @@ class PhaseFlow(Animation):
"suspend_mobject_updating": False,
}
def __init__(self, function, mobject, **kwargs):
def __init__(
self,
function: Callable[[np.ndarray], np.ndarray],
mobject: Mobject,
**kwargs
):
self.function = function
super().__init__(mobject, **kwargs)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
if hasattr(self, "last_alpha"):
dt = self.virtual_time * (alpha - self.last_alpha)
self.mobject.apply_function(
@ -71,10 +104,10 @@ class MoveAlongPath(Animation):
"suspend_mobject_updating": False,
}
def __init__(self, mobject, path, **kwargs):
def __init__(self, mobject: Mobject, path: Mobject, **kwargs):
self.path = path
super().__init__(mobject, **kwargs)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
point = self.path.point_from_proportion(alpha)
self.mobject.move_to(point)

View file

@ -1,3 +1,7 @@
from __future__ import annotations
from typing import Callable
from manimlib.animation.animation import Animation
from manimlib.mobject.numbers import DecimalNumber
from manimlib.utils.bezier import interpolate
@ -8,19 +12,29 @@ class ChangingDecimal(Animation):
"suspend_mobject_updating": False,
}
def __init__(self, decimal_mob, number_update_func, **kwargs):
def __init__(
self,
decimal_mob: DecimalNumber,
number_update_func: Callable[[float], float],
**kwargs
):
assert(isinstance(decimal_mob, DecimalNumber))
self.number_update_func = number_update_func
super().__init__(decimal_mob, **kwargs)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
self.mobject.set_value(
self.number_update_func(alpha)
)
class ChangeDecimalToValue(ChangingDecimal):
def __init__(self, decimal_mob, target_number, **kwargs):
def __init__(
self,
decimal_mob: DecimalNumber,
target_number: float | complex,
**kwargs
):
start_number = decimal_mob.number
super().__init__(
decimal_mob,
@ -30,7 +44,12 @@ class ChangeDecimalToValue(ChangingDecimal):
class CountInFrom(ChangingDecimal):
def __init__(self, decimal_mob, source_number=0, **kwargs):
def __init__(
self,
decimal_mob: DecimalNumber,
source_number: float | complex = 0,
**kwargs
):
start_number = decimal_mob.number
super().__init__(
decimal_mob,

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from manimlib.animation.animation import Animation
from manimlib.constants import OUT
from manimlib.constants import PI
@ -6,6 +8,12 @@ from manimlib.constants import ORIGIN
from manimlib.utils.rate_functions import linear
from manimlib.utils.rate_functions import smooth
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy as np
from manimlib.mobject.mobject import Mobject
class Rotating(Animation):
CONFIG = {
@ -18,12 +26,18 @@ class Rotating(Animation):
"suspend_mobject_updating": False,
}
def __init__(self, mobject, angle=TAU, axis=OUT, **kwargs):
def __init__(
self,
mobject: Mobject,
angle: float = TAU,
axis: np.ndarray = OUT,
**kwargs
):
self.angle = angle
self.axis = axis
super().__init__(mobject, **kwargs)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
for sm1, sm2 in self.get_all_families_zipped():
sm1.set_points(sm2.get_points())
self.mobject.rotate(
@ -41,5 +55,11 @@ class Rotate(Rotating):
"about_edge": ORIGIN,
}
def __init__(self, mobject, angle=PI, axis=OUT, **kwargs):
def __init__(
self,
mobject: Mobject,
angle: float = PI,
axis: np.ndarray = OUT,
**kwargs
):
super().__init__(mobject, angle, axis, **kwargs)

View file

@ -1,3 +1,7 @@
from __future__ import annotations
import numpy as np
from manimlib.animation.composition import LaggedStart
from manimlib.animation.transform import Restore
from manimlib.constants import WHITE
@ -19,7 +23,7 @@ class Broadcast(LaggedStart):
"run_time": 3,
}
def __init__(self, focal_point, **kwargs):
def __init__(self, focal_point: np.ndarray, **kwargs):
digest_config(self, kwargs)
circles = VGroup()
for x in range(self.n_circles):

View file

@ -1,6 +1,10 @@
from __future__ import annotations
import inspect
from typing import Callable, Union, Sequence
import numpy as np
import numpy.typing as npt
from manimlib.animation.animation import Animation
from manimlib.constants import DEFAULT_POINTWISE_FUNCTION_RUN_TIME
@ -14,6 +18,13 @@ from manimlib.utils.paths import straight_path
from manimlib.utils.rate_functions import smooth
from manimlib.utils.rate_functions import squish_rate_func
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import colour
from manimlib.scene.scene import Scene
ManimColor = Union[str, colour.Color, Sequence[float]]
class Transform(Animation):
CONFIG = {
@ -23,12 +34,17 @@ class Transform(Animation):
"replace_mobject_with_target_in_scene": False,
}
def __init__(self, mobject, target_mobject=None, **kwargs):
def __init__(
self,
mobject: Mobject,
target_mobject: Mobject | None = None,
**kwargs
):
super().__init__(mobject, **kwargs)
self.target_mobject = target_mobject
self.init_path_func()
def init_path_func(self):
def init_path_func(self) -> None:
if self.path_func is not None:
return
elif self.path_arc == 0:
@ -39,7 +55,7 @@ class Transform(Animation):
self.path_arc_axis,
)
def begin(self):
def begin(self) -> None:
self.target_mobject = self.create_target()
self.check_target_mobject_validity()
# Use a copy of target_mobject for the align_data_and_family
@ -54,28 +70,28 @@ class Transform(Animation):
self.target_copy,
)
def finish(self):
def finish(self) -> None:
super().finish()
self.mobject.unlock_data()
def create_target(self):
def create_target(self) -> Mobject:
# Has no meaningful effect here, but may be useful
# in subclasses
return self.target_mobject
def check_target_mobject_validity(self):
def check_target_mobject_validity(self) -> None:
if self.target_mobject is None:
raise Exception(
f"{self.__class__.__name__}.create_target not properly implemented"
)
def clean_up_from_scene(self, scene):
def clean_up_from_scene(self, scene: Scene) -> None:
super().clean_up_from_scene(scene)
if self.replace_mobject_with_target_in_scene:
scene.remove(self.mobject)
scene.add(self.target_mobject)
def update_config(self, **kwargs):
def update_config(self, **kwargs) -> None:
Animation.update_config(self, **kwargs)
if "path_arc" in kwargs:
self.path_func = path_along_arc(
@ -83,7 +99,7 @@ class Transform(Animation):
kwargs.get("path_arc_axis", OUT)
)
def get_all_mobjects(self):
def get_all_mobjects(self) -> list[Mobject]:
return [
self.mobject,
self.starting_mobject,
@ -91,7 +107,7 @@ class Transform(Animation):
self.target_copy,
]
def get_all_families_zipped(self):
def get_all_families_zipped(self) -> zip[tuple[Mobject]]:
return zip(*[
mob.get_family()
for mob in [
@ -101,7 +117,13 @@ class Transform(Animation):
]
])
def interpolate_submobject(self, submob, start, target_copy, alpha):
def interpolate_submobject(
self,
submob: Mobject,
start: Mobject,
target_copy: Mobject,
alpha: float
):
submob.interpolate(start, target_copy, alpha, self.path_func)
return self
@ -117,10 +139,10 @@ class TransformFromCopy(Transform):
Performs a reversed Transform
"""
def __init__(self, mobject, target_mobject, **kwargs):
def __init__(self, mobject: Mobject, target_mobject: Mobject, **kwargs):
super().__init__(target_mobject, mobject, **kwargs)
def interpolate(self, alpha):
def interpolate(self, alpha: float) -> None:
super().interpolate(1 - alpha)
@ -137,11 +159,11 @@ class CounterclockwiseTransform(Transform):
class MoveToTarget(Transform):
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
self.check_validity_of_input(mobject)
super().__init__(mobject, mobject.target, **kwargs)
def check_validity_of_input(self, mobject):
def check_validity_of_input(self, mobject: Mobject) -> None:
if not hasattr(mobject, "target"):
raise Exception(
"MoveToTarget called on mobject"
@ -150,13 +172,13 @@ class MoveToTarget(Transform):
class _MethodAnimation(MoveToTarget):
def __init__(self, mobject, methods):
def __init__(self, mobject: Mobject, methods: Callable):
self.methods = methods
super().__init__(mobject)
class ApplyMethod(Transform):
def __init__(self, method, *args, **kwargs):
def __init__(self, method: Callable, *args, **kwargs):
"""
method is a method of Mobject, *args are arguments for
that method. Key word arguments should be passed in
@ -170,7 +192,7 @@ class ApplyMethod(Transform):
self.method_args = args
super().__init__(method.__self__, **kwargs)
def check_validity_of_input(self, method):
def check_validity_of_input(self, method: Callable) -> None:
if not inspect.ismethod(method):
raise Exception(
"Whoops, looks like you accidentally invoked "
@ -178,7 +200,7 @@ class ApplyMethod(Transform):
)
assert(isinstance(method.__self__, Mobject))
def create_target(self):
def create_target(self) -> Mobject:
method = self.method
# Make sure it's a list so that args.pop() works
args = list(self.method_args)
@ -197,16 +219,26 @@ class ApplyPointwiseFunction(ApplyMethod):
"run_time": DEFAULT_POINTWISE_FUNCTION_RUN_TIME
}
def __init__(self, function, mobject, **kwargs):
def __init__(
self,
function: Callable[[np.ndarray], np.ndarray],
mobject: Mobject,
**kwargs
):
super().__init__(mobject.apply_function, function, **kwargs)
class ApplyPointwiseFunctionToCenter(ApplyPointwiseFunction):
def __init__(self, function, mobject, **kwargs):
def __init__(
self,
function: Callable[[np.ndarray], np.ndarray],
mobject: Mobject,
**kwargs
):
self.function = function
super().__init__(mobject.move_to, **kwargs)
def begin(self):
def begin(self) -> None:
self.method_args = [
self.function(self.mobject.get_center())
]
@ -214,31 +246,46 @@ class ApplyPointwiseFunctionToCenter(ApplyPointwiseFunction):
class FadeToColor(ApplyMethod):
def __init__(self, mobject, color, **kwargs):
def __init__(
self,
mobject: Mobject,
color: ManimColor,
**kwargs
):
super().__init__(mobject.set_color, color, **kwargs)
class ScaleInPlace(ApplyMethod):
def __init__(self, mobject, scale_factor, **kwargs):
def __init__(
self,
mobject: Mobject,
scale_factor: npt.ArrayLike,
**kwargs
):
super().__init__(mobject.scale, scale_factor, **kwargs)
class ShrinkToCenter(ScaleInPlace):
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
super().__init__(mobject, 0, **kwargs)
class Restore(ApplyMethod):
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
super().__init__(mobject.restore, **kwargs)
class ApplyFunction(Transform):
def __init__(self, function, mobject, **kwargs):
def __init__(
self,
function: Callable[[Mobject], Mobject],
mobject: Mobject,
**kwargs
):
self.function = function
super().__init__(mobject, **kwargs)
def create_target(self):
def create_target(self) -> Mobject:
target = self.function(self.mobject.copy())
if not isinstance(target, Mobject):
raise Exception("Functions passed to ApplyFunction must return object of type Mobject")
@ -246,7 +293,12 @@ class ApplyFunction(Transform):
class ApplyMatrix(ApplyPointwiseFunction):
def __init__(self, matrix, mobject, **kwargs):
def __init__(
self,
matrix: npt.ArrayLike,
mobject: Mobject,
**kwargs
):
matrix = self.initialize_matrix(matrix)
def func(p):
@ -254,7 +306,7 @@ class ApplyMatrix(ApplyPointwiseFunction):
super().__init__(func, mobject, **kwargs)
def initialize_matrix(self, matrix):
def initialize_matrix(self, matrix: npt.ArrayLike) -> np.ndarray:
matrix = np.array(matrix)
if matrix.shape == (2, 2):
new_matrix = np.identity(3)
@ -266,12 +318,17 @@ class ApplyMatrix(ApplyPointwiseFunction):
class ApplyComplexFunction(ApplyMethod):
def __init__(self, function, mobject, **kwargs):
def __init__(
self,
function: Callable[[complex], complex],
mobject: Mobject,
**kwargs
):
self.function = function
method = mobject.apply_complex_function
super().__init__(method, function, **kwargs)
def init_path_func(self):
def init_path_func(self) -> None:
func1 = self.function(complex(1))
self.path_arc = np.log(func1).imag
super().init_path_func()
@ -284,11 +341,11 @@ class CyclicReplace(Transform):
"path_arc": 90 * DEGREES,
}
def __init__(self, *mobjects, **kwargs):
def __init__(self, *mobjects: Mobject, **kwargs):
self.group = Group(*mobjects)
super().__init__(self.group, **kwargs)
def create_target(self):
def create_target(self) -> Mobject:
target = self.group.copy()
cycled_targets = [target[-1], *target[:-1]]
for m1, m2 in zip(cycled_targets, self.group):
@ -306,7 +363,7 @@ class TransformAnimations(Transform):
"rate_func": squish_rate_func(smooth)
}
def __init__(self, start_anim, end_anim, **kwargs):
def __init__(self, start_anim: Animation, end_anim: Animation, **kwargs):
digest_config(self, kwargs, locals())
if "run_time" in kwargs:
self.run_time = kwargs.pop("run_time")
@ -327,7 +384,7 @@ class TransformAnimations(Transform):
start_anim.mobject = self.starting_mobject
end_anim.mobject = self.target_mobject
def interpolate(self, alpha):
def interpolate(self, alpha: float) -> None:
self.start_anim.interpolate(alpha)
self.end_anim.interpolate(alpha)
Transform.interpolate(self, alpha)

View file

@ -1,13 +1,15 @@
import numpy as np
from __future__ import annotations
import itertools as it
import numpy as np
from manimlib.animation.composition import AnimationGroup
from manimlib.animation.fading import FadeTransformPieces
from manimlib.animation.fading import FadeInFromPoint
from manimlib.animation.fading import FadeOutToPoint
from manimlib.animation.transform import ReplacementTransform
from manimlib.animation.transform import Transform
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
from manimlib.mobject.svg.mtex_mobject import MTex
@ -16,6 +18,12 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import remove_list_redundancies
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.scene.scene import Scene
from manimlib.mobject.svg.tex_mobject import Tex, SingleStringTex
class TransformMatchingParts(AnimationGroup):
CONFIG = {
@ -26,7 +34,7 @@ class TransformMatchingParts(AnimationGroup):
"key_map": dict(),
}
def __init__(self, mobject, target_mobject, **kwargs):
def __init__(self, mobject: Mobject, target_mobject: Mobject, **kwargs):
digest_config(self, kwargs)
assert(isinstance(mobject, self.mobject_type))
assert(isinstance(target_mobject, self.mobject_type))
@ -83,8 +91,8 @@ class TransformMatchingParts(AnimationGroup):
self.to_remove = mobject
self.to_add = target_mobject
def get_shape_map(self, mobject):
shape_map = {}
def get_shape_map(self, mobject: Mobject) -> dict[int, VGroup]:
shape_map: dict[int, VGroup] = {}
for sm in self.get_mobject_parts(mobject):
key = self.get_mobject_key(sm)
if key not in shape_map:
@ -92,7 +100,7 @@ class TransformMatchingParts(AnimationGroup):
shape_map[key].add(sm)
return shape_map
def clean_up_from_scene(self, scene):
def clean_up_from_scene(self, scene: Scene) -> None:
for anim in self.animations:
anim.update(0)
scene.remove(self.mobject)
@ -100,12 +108,12 @@ class TransformMatchingParts(AnimationGroup):
scene.add(self.to_add)
@staticmethod
def get_mobject_parts(mobject):
def get_mobject_parts(mobject: Mobject) -> Mobject:
# To be implemented in subclass
return mobject
@staticmethod
def get_mobject_key(mobject):
def get_mobject_key(mobject: Mobject) -> int:
# To be implemented in subclass
return hash(mobject)
@ -117,11 +125,11 @@ class TransformMatchingShapes(TransformMatchingParts):
}
@staticmethod
def get_mobject_parts(mobject):
def get_mobject_parts(mobject: VMobject) -> list[VMobject]:
return mobject.family_members_with_points()
@staticmethod
def get_mobject_key(mobject):
def get_mobject_key(mobject: VMobject) -> int:
mobject.save_state()
mobject.center()
mobject.set_height(1)
@ -137,11 +145,11 @@ class TransformMatchingTex(TransformMatchingParts):
}
@staticmethod
def get_mobject_parts(mobject):
def get_mobject_parts(mobject: Tex) -> list[SingleStringTex]:
return mobject.submobjects
@staticmethod
def get_mobject_key(mobject):
def get_mobject_key(mobject: Tex) -> str:
return mobject.get_tex()
@ -150,7 +158,7 @@ class TransformMatchingMTex(AnimationGroup):
"key_map": dict(),
}
def __init__(self, source_mobject, target_mobject, **kwargs):
def __init__(self, source_mobject: MTex, target_mobject: MTex, **kwargs):
digest_config(self, kwargs)
assert isinstance(source_mobject, MTex)
assert isinstance(target_mobject, MTex)

View file

@ -1,7 +1,15 @@
from __future__ import annotations
import operator as op
from typing import Callable
from manimlib.animation.animation import Animation
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.mobject import Mobject
class UpdateFromFunc(Animation):
"""
@ -13,21 +21,31 @@ class UpdateFromFunc(Animation):
"suspend_mobject_updating": False,
}
def __init__(self, mobject, update_function, **kwargs):
def __init__(
self,
mobject: Mobject,
update_function: Callable[[Mobject]],
**kwargs
):
self.update_function = update_function
super().__init__(mobject, **kwargs)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
self.update_function(self.mobject)
class UpdateFromAlphaFunc(UpdateFromFunc):
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
self.update_function(self.mobject, alpha)
class MaintainPositionRelativeTo(Animation):
def __init__(self, mobject, tracked_mobject, **kwargs):
def __init__(
self,
mobject: Mobject,
tracked_mobject: Mobject,
**kwargs
):
self.tracked_mobject = tracked_mobject
self.diff = op.sub(
mobject.get_center(),
@ -35,7 +53,7 @@ class MaintainPositionRelativeTo(Animation):
)
super().__init__(mobject, **kwargs)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
target = self.tracked_mobject.get_center()
location = self.mobject.get_center()
self.mobject.shift(target - location + self.diff)

View file

@ -1,12 +1,14 @@
import moderngl
import math
from colour import Color
import OpenGL.GL as gl
from __future__ import annotations
from PIL import Image
import numpy as np
import math
import itertools as it
import moderngl
import numpy as np
from PIL import Image
import OpenGL.GL as gl
from colour import Color
from manimlib.constants import *
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Point
@ -19,6 +21,11 @@ from manimlib.utils.space_ops import rotation_matrix_transpose
from manimlib.utils.space_ops import quaternion_from_angle_axis
from manimlib.utils.space_ops import quaternion_mult
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.shader_wrapper import ShaderWrapper
class CameraFrame(Mobject):
CONFIG = {
@ -29,12 +36,12 @@ class CameraFrame(Mobject):
"focal_distance": 2,
}
def init_data(self):
def init_data(self) -> None:
super().init_data()
self.data["euler_angles"] = np.array(self.euler_angles, dtype=float)
self.refresh_rotation_matrix()
def init_points(self):
def init_points(self) -> None:
self.set_points([ORIGIN, LEFT, RIGHT, DOWN, UP])
self.set_width(self.frame_shape[0], stretch=True)
self.set_height(self.frame_shape[1], stretch=True)
@ -47,13 +54,13 @@ class CameraFrame(Mobject):
self.set_euler_angles(0, 0, 0)
return self
def get_euler_angles(self):
def get_euler_angles(self) -> np.ndarray:
return self.data["euler_angles"]
def get_inverse_camera_rotation_matrix(self):
def get_inverse_camera_rotation_matrix(self) -> list[list[float]]:
return self.inverse_camera_rotation_matrix
def refresh_rotation_matrix(self):
def refresh_rotation_matrix(self) -> None:
# Rotate based on camera orientation
theta, phi, gamma = self.get_euler_angles()
quat = quaternion_mult(
@ -63,7 +70,7 @@ class CameraFrame(Mobject):
)
self.inverse_camera_rotation_matrix = rotation_matrix_transpose_from_quaternion(quat)
def rotate(self, angle, axis=OUT, **kwargs):
def rotate(self, angle: float, axis: np.ndarray = OUT, **kwargs):
curr_rot_T = self.get_inverse_camera_rotation_matrix()
added_rot_T = rotation_matrix_transpose(angle, axis)
new_rot_T = np.dot(curr_rot_T, added_rot_T)
@ -78,7 +85,13 @@ class CameraFrame(Mobject):
self.set_euler_angles(theta, phi, gamma)
return self
def set_euler_angles(self, theta=None, phi=None, gamma=None, units=RADIANS):
def set_euler_angles(
self,
theta: float | None = None,
phi: float | None = None,
gamma: float | None = None,
units: float = RADIANS
):
if theta is not None:
self.data["euler_angles"][0] = theta * units
if phi is not None:
@ -88,7 +101,12 @@ class CameraFrame(Mobject):
self.refresh_rotation_matrix()
return self
def reorient(self, theta_degrees=None, phi_degrees=None, gamma_degrees=None):
def reorient(
self,
theta_degrees: float | None = None,
phi_degrees: float | None = None,
gamma_degrees: float | None = None,
):
"""
Shortcut for set_euler_angles, defaulting to taking
in angles in degrees
@ -96,60 +114,60 @@ class CameraFrame(Mobject):
self.set_euler_angles(theta_degrees, phi_degrees, gamma_degrees, units=DEGREES)
return self
def set_theta(self, theta):
def set_theta(self, theta: float):
return self.set_euler_angles(theta=theta)
def set_phi(self, phi):
def set_phi(self, phi: float):
return self.set_euler_angles(phi=phi)
def set_gamma(self, gamma):
def set_gamma(self, gamma: float):
return self.set_euler_angles(gamma=gamma)
def increment_theta(self, dtheta):
def increment_theta(self, dtheta: float):
self.data["euler_angles"][0] += dtheta
self.refresh_rotation_matrix()
return self
def increment_phi(self, dphi):
def increment_phi(self, dphi: float):
phi = self.data["euler_angles"][1]
new_phi = clip(phi + dphi, 0, PI)
self.data["euler_angles"][1] = new_phi
self.refresh_rotation_matrix()
return self
def increment_gamma(self, dgamma):
def increment_gamma(self, dgamma: float):
self.data["euler_angles"][2] += dgamma
self.refresh_rotation_matrix()
return self
def get_theta(self):
def get_theta(self) -> float:
return self.data["euler_angles"][0]
def get_phi(self):
def get_phi(self) -> float:
return self.data["euler_angles"][1]
def get_gamma(self):
def get_gamma(self) -> float:
return self.data["euler_angles"][2]
def get_shape(self):
def get_shape(self) -> tuple[float, float]:
return (self.get_width(), self.get_height())
def get_center(self):
def get_center(self) -> np.ndarray:
# Assumes first point is at the center
return self.get_points()[0]
def get_width(self):
def get_width(self) -> float:
points = self.get_points()
return points[2, 0] - points[1, 0]
def get_height(self):
def get_height(self) -> float:
points = self.get_points()
return points[4, 1] - points[3, 1]
def get_focal_distance(self):
def get_focal_distance(self) -> float:
return self.focal_distance * self.get_height()
def get_implied_camera_location(self):
def get_implied_camera_location(self) -> tuple[float, float, float]:
theta, phi, gamma = self.get_euler_angles()
dist = self.get_focal_distance()
x, y, z = self.get_center()
@ -190,10 +208,10 @@ class Camera(object):
"samples": 0,
}
def __init__(self, ctx=None, **kwargs):
def __init__(self, ctx: moderngl.Context | None = None, **kwargs):
digest_config(self, kwargs, locals())
self.rgb_max_val = np.iinfo(self.pixel_array_dtype).max
self.background_rgba = [
self.rgb_max_val: float = np.iinfo(self.pixel_array_dtype).max
self.background_rgba: list[float] = [
*Color(self.background_color).get_rgb(),
self.background_opacity
]
@ -205,10 +223,10 @@ class Camera(object):
self.refresh_perspective_uniforms()
self.static_mobject_to_render_group_list = {}
def init_frame(self):
def init_frame(self) -> None:
self.frame = CameraFrame(**self.frame_config)
def init_context(self, ctx=None):
def init_context(self, ctx: moderngl.Context | None = None) -> None:
if ctx is None:
ctx = moderngl.create_standalone_context()
fbo = self.get_fbo(ctx, 0)
@ -223,7 +241,7 @@ class Camera(object):
fbo_msaa.use()
self.fbo_msaa = fbo_msaa
def set_ctx_blending(self, enable=True):
def set_ctx_blending(self, enable: bool = True) -> None:
if enable:
self.ctx.enable(moderngl.BLEND)
else:
@ -233,17 +251,21 @@ class Camera(object):
# moderngl.ONE, moderngl.ONE
)
def set_ctx_depth_test(self, enable=True):
def set_ctx_depth_test(self, enable: bool = True) -> None:
if enable:
self.ctx.enable(moderngl.DEPTH_TEST)
else:
self.ctx.disable(moderngl.DEPTH_TEST)
def init_light_source(self):
def init_light_source(self) -> None:
self.light_source = Point(self.light_source_position)
# Methods associated with the frame buffer
def get_fbo(self, ctx, samples=0):
def get_fbo(
self,
ctx: moderngl.Context,
samples: int = 0
) -> moderngl.Framebuffer:
pw = self.pixel_width
ph = self.pixel_height
return ctx.framebuffer(
@ -258,16 +280,16 @@ class Camera(object):
)
)
def clear(self):
def clear(self) -> None:
self.fbo.clear(*self.background_rgba)
self.fbo_msaa.clear(*self.background_rgba)
def reset_pixel_shape(self, new_width, new_height):
def reset_pixel_shape(self, new_width: int, new_height: int) -> None:
self.pixel_width = new_width
self.pixel_height = new_height
self.refresh_perspective_uniforms()
def get_raw_fbo_data(self, dtype='f1'):
def get_raw_fbo_data(self, dtype: str = 'f1') -> bytes:
# Copy blocks from the fbo_msaa to the drawn fbo using Blit
pw, ph = (self.pixel_width, self.pixel_height)
gl.glBindFramebuffer(gl.GL_READ_FRAMEBUFFER, self.fbo_msaa.glo)
@ -279,7 +301,7 @@ class Camera(object):
dtype=dtype,
)
def get_image(self, pixel_array=None):
def get_image(self) -> Image.Image:
return Image.frombytes(
'RGBA',
self.get_pixel_shape(),
@ -287,7 +309,7 @@ class Camera(object):
'raw', 'RGBA', 0, -1
)
def get_pixel_array(self):
def get_pixel_array(self) -> np.ndarray:
raw = self.get_raw_fbo_data(dtype='f4')
flat_arr = np.frombuffer(raw, dtype='f4')
arr = flat_arr.reshape([*self.fbo.size, self.n_channels])
@ -295,7 +317,7 @@ class Camera(object):
return (self.rgb_max_val * arr).astype(self.pixel_array_dtype)
# Needed?
def get_texture(self):
def get_texture(self) -> moderngl.Texture:
texture = self.ctx.texture(
size=self.fbo.size,
components=4,
@ -305,32 +327,32 @@ class Camera(object):
return texture
# Getting camera attributes
def get_pixel_shape(self):
def get_pixel_shape(self) -> tuple[int, int]:
return self.fbo.viewport[2:4]
# return (self.pixel_width, self.pixel_height)
def get_pixel_width(self):
def get_pixel_width(self) -> int:
return self.get_pixel_shape()[0]
def get_pixel_height(self):
def get_pixel_height(self) -> int:
return self.get_pixel_shape()[1]
def get_frame_height(self):
def get_frame_height(self) -> float:
return self.frame.get_height()
def get_frame_width(self):
def get_frame_width(self) -> float:
return self.frame.get_width()
def get_frame_shape(self):
def get_frame_shape(self) -> tuple[float, float]:
return (self.get_frame_width(), self.get_frame_height())
def get_frame_center(self):
def get_frame_center(self) -> np.ndarray:
return self.frame.get_center()
def get_location(self):
def get_location(self) -> tuple[float, float, float]:
return self.frame.get_implied_camera_location()
def resize_frame_shape(self, fixed_dimension=0):
def resize_frame_shape(self, fixed_dimension: bool = False) -> None:
"""
Changes frame_shape to match the aspect ratio
of the pixels, where fixed_dimension determines
@ -342,7 +364,7 @@ class Camera(object):
frame_height = self.get_frame_height()
frame_width = self.get_frame_width()
aspect_ratio = fdiv(pixel_width, pixel_height)
if fixed_dimension == 0:
if not fixed_dimension:
frame_height = frame_width / aspect_ratio
else:
frame_width = aspect_ratio * frame_height
@ -350,13 +372,13 @@ class Camera(object):
self.frame.set_width(frame_width)
# Rendering
def capture(self, *mobjects, **kwargs):
def capture(self, *mobjects: Mobject, **kwargs) -> None:
self.refresh_perspective_uniforms()
for mobject in mobjects:
for render_group in self.get_render_group_list(mobject):
self.render(render_group)
def render(self, render_group):
def render(self, render_group: dict[str]) -> None:
shader_wrapper = render_group["shader_wrapper"]
shader_program = render_group["prog"]
self.set_shader_uniforms(shader_program, shader_wrapper)
@ -365,13 +387,17 @@ class Camera(object):
if render_group["single_use"]:
self.release_render_group(render_group)
def get_render_group_list(self, mobject):
def get_render_group_list(self, mobject: Mobject) -> list[dict[str]] | map[dict[str]]:
try:
return self.static_mobject_to_render_group_list[id(mobject)]
except KeyError:
return map(self.get_render_group, mobject.get_shader_wrapper_list())
def get_render_group(self, shader_wrapper, single_use=True):
def get_render_group(
self,
shader_wrapper: ShaderWrapper,
single_use: bool = True
) -> dict[str]:
# Data buffers
vbo = self.ctx.buffer(shader_wrapper.vert_data.tobytes())
if shader_wrapper.vert_indices is None:
@ -399,12 +425,12 @@ class Camera(object):
"single_use": single_use,
}
def release_render_group(self, render_group):
def release_render_group(self, render_group: dict[str]) -> None:
for key in ["vbo", "ibo", "vao"]:
if render_group[key] is not None:
render_group[key].release()
def set_mobjects_as_static(self, *mobjects):
def set_mobjects_as_static(self, *mobjects: Mobject) -> None:
# Creates buffer and array objects holding each mobjects shader data
for mob in mobjects:
self.static_mobject_to_render_group_list[id(mob)] = [
@ -412,18 +438,23 @@ class Camera(object):
for sw in mob.get_shader_wrapper_list()
]
def release_static_mobjects(self):
def release_static_mobjects(self) -> None:
for rg_list in self.static_mobject_to_render_group_list.values():
for render_group in rg_list:
self.release_render_group(render_group)
self.static_mobject_to_render_group_list = {}
# Shaders
def init_shaders(self):
def init_shaders(self) -> None:
# Initialize with the null id going to None
self.id_to_shader_program = {"": None}
self.id_to_shader_program: dict[
int | str, tuple[moderngl.Program, str] | None
] = {"": None}
def get_shader_program(self, shader_wrapper):
def get_shader_program(
self,
shader_wrapper: ShaderWrapper
) -> tuple[moderngl.Program, str]:
sid = shader_wrapper.get_program_id()
if sid not in self.id_to_shader_program:
# Create shader program for the first time, then cache
@ -433,7 +464,11 @@ class Camera(object):
self.id_to_shader_program[sid] = (program, vert_format)
return self.id_to_shader_program[sid]
def set_shader_uniforms(self, shader, shader_wrapper):
def set_shader_uniforms(
self,
shader: moderngl.Program,
shader_wrapper: ShaderWrapper
) -> None:
for name, path in shader_wrapper.texture_paths.items():
tid = self.get_texture_id(path)
shader[name].value = tid
@ -445,7 +480,7 @@ class Camera(object):
except KeyError:
pass
def refresh_perspective_uniforms(self):
def refresh_perspective_uniforms(self) -> None:
frame = self.frame
pw, ph = self.get_pixel_shape()
fw, fh = frame.get_shape()
@ -470,11 +505,13 @@ class Camera(object):
"focal_distance": frame.get_focal_distance(),
}
def init_textures(self):
self.n_textures = 0
self.path_to_texture = {}
def init_textures(self) -> None:
self.n_textures: int = 0
self.path_to_texture: dict[
str, tuple[int, moderngl.Texture]
] = {}
def get_texture_id(self, path):
def get_texture_id(self, path: str) -> int:
if path not in self.path_to_texture:
if self.n_textures == 15: # I have no clue why this is needed
self.n_textures += 1
@ -490,7 +527,7 @@ class Camera(object):
self.path_to_texture[path] = (tid, texture)
return self.path_to_texture[path][0]
def release_texture(self, path):
def release_texture(self, path: str):
tid_and_texture = self.path_to_texture.pop(path, None)
if tid_and_texture:
tid_and_texture[1].release()

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import numpy as np
from manimlib.event_handler.event_type import EventType
@ -6,21 +8,23 @@ from manimlib.event_handler.event_listner import EventListner
class EventDispatcher(object):
def __init__(self):
self.event_listners = {
self.event_listners: dict[
EventType, list[EventListner]
] = {
event_type: []
for event_type in EventType
}
self.mouse_point = np.array((0., 0., 0.))
self.mouse_drag_point = np.array((0., 0., 0.))
self.pressed_keys = set()
self.draggable_object_listners = []
self.pressed_keys: set[int] = set()
self.draggable_object_listners: list[EventListner] = []
def add_listner(self, event_listner):
def add_listner(self, event_listner: EventListner):
assert(isinstance(event_listner, EventListner))
self.event_listners[event_listner.event_type].append(event_listner)
return self
def remove_listner(self, event_listner):
def remove_listner(self, event_listner: EventListner):
assert(isinstance(event_listner, EventListner))
try:
while event_listner in self.event_listners[event_listner.event_type]:
@ -30,8 +34,7 @@ class EventDispatcher(object):
pass
return self
def dispatch(self, event_type, **event_data):
def dispatch(self, event_type: EventType, **event_data):
if event_type == EventType.MouseMotionEvent:
self.mouse_point = event_data["point"]
elif event_type == EventType.MouseDragEvent:
@ -74,16 +77,16 @@ class EventDispatcher(object):
return propagate_event
def get_listners_count(self):
def get_listners_count(self) -> int:
return sum([len(value) for key, value in self.event_listners.items()])
def get_mouse_point(self):
def get_mouse_point(self) -> np.ndarray:
return self.mouse_point
def get_mouse_drag_point(self):
def get_mouse_drag_point(self) -> np.ndarray:
return self.mouse_drag_point
def is_key_pressed(self, symbol):
def is_key_pressed(self, symbol: int) -> bool:
return (symbol in self.pressed_keys)
__iadd__ = add_listner

View file

@ -1,5 +1,18 @@
from __future__ import annotations
from typing import Callable, TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.mobject import Mobject
from manimlib.event_handler.event_type import EventType
class EventListner(object):
def __init__(self, mobject, event_type, event_callback):
def __init__(
self,
mobject: Mobject,
event_type: EventType,
event_callback: Callable[[Mobject, dict[str]]]
):
self.mobject = mobject
self.event_type = event_type
self.callback = event_callback

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import numpy as np
import pathops
@ -7,7 +9,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
# Boolean operations between 2D mobjects
# Borrowed from from https://github.com/ManimCommunity/manim/
def _convert_vmobject_to_skia_path(vmobject):
def _convert_vmobject_to_skia_path(vmobject: VMobject) -> pathops.Path:
path = pathops.Path()
subpaths = vmobject.get_subpaths_from_points(vmobject.get_all_points())
for subpath in subpaths:
@ -21,7 +23,10 @@ def _convert_vmobject_to_skia_path(vmobject):
return path
def _convert_skia_path_to_vmobject(path, vmobject):
def _convert_skia_path_to_vmobject(
path: pathops.Path,
vmobject: VMobject
) -> VMobject:
PathVerb = pathops.PathVerb
current_path_start = np.array([0.0, 0.0, 0.0])
for path_verb, points in path:
@ -45,7 +50,7 @@ def _convert_skia_path_to_vmobject(path, vmobject):
class Union(VMobject):
def __init__(self, *vmobjects, **kwargs):
def __init__(self, *vmobjects: VMobject, **kwargs):
if len(vmobjects) < 2:
raise ValueError("At least 2 mobjects needed for Union.")
super().__init__(**kwargs)
@ -59,7 +64,7 @@ class Union(VMobject):
class Difference(VMobject):
def __init__(self, subject, clip, **kwargs):
def __init__(self, subject: VMobject, clip: VMobject, **kwargs):
super().__init__(**kwargs)
outpen = pathops.Path()
pathops.difference(
@ -71,7 +76,7 @@ class Difference(VMobject):
class Intersection(VMobject):
def __init__(self, *vmobjects, **kwargs):
def __init__(self, *vmobjects: VMobject, **kwargs):
if len(vmobjects) < 2:
raise ValueError("At least 2 mobjects needed for Intersection.")
super().__init__(**kwargs)
@ -94,7 +99,7 @@ class Intersection(VMobject):
class Exclusion(VMobject):
def __init__(self, *vmobjects, **kwargs):
def __init__(self, *vmobjects: VMobject, **kwargs):
if len(vmobjects) < 2:
raise ValueError("At least 2 mobjects needed for Exclusion.")
super().__init__(**kwargs)

View file

@ -1,4 +1,9 @@
from __future__ import annotations
from typing import Callable
import numpy as np
from manimlib.constants import BLUE_D
from manimlib.constants import BLUE_B
from manimlib.constants import BLUE_E
@ -20,10 +25,10 @@ class AnimatedBoundary(VGroup):
"fade_rate_func": smooth,
}
def __init__(self, vmobject, **kwargs):
def __init__(self, vmobject: VMobject, **kwargs):
super().__init__(**kwargs)
self.vmobject = vmobject
self.boundary_copies = [
self.vmobject: VMobject = vmobject
self.boundary_copies: list[VMobject] = [
vmobject.copy().set_style(
stroke_width=0,
fill_opacity=0
@ -31,12 +36,12 @@ class AnimatedBoundary(VGroup):
for x in range(2)
]
self.add(*self.boundary_copies)
self.total_time = 0
self.total_time: float = 0
self.add_updater(
lambda m, dt: self.update_boundary_copies(dt)
)
def update_boundary_copies(self, dt):
def update_boundary_copies(self, dt: float) -> None:
# Not actual time, but something which passes at
# an altered rate to make the implementation below
# cleaner
@ -67,7 +72,13 @@ class AnimatedBoundary(VGroup):
self.total_time += dt
def full_family_become_partial(self, mob1, mob2, a, b):
def full_family_become_partial(
self,
mob1: VMobject,
mob2: VMobject,
a: float,
b: float
):
family1 = mob1.family_members_with_points()
family2 = mob2.family_members_with_points()
for sm1, sm2 in zip(family1, family2):
@ -84,14 +95,14 @@ class TracedPath(VMobject):
"time_per_anchor": 1 / 15,
}
def __init__(self, traced_point_func, **kwargs):
def __init__(self, traced_point_func: Callable[[], np.ndarray], **kwargs):
super().__init__(**kwargs)
self.traced_point_func = traced_point_func
self.time = 0
self.traced_points = []
self.time: float = 0
self.traced_points: list[np.ndarray] = []
self.add_updater(lambda m, dt: m.update_path(dt))
def update_path(self, dt):
def update_path(self, dt: float):
if dt == 0:
return self
point = self.traced_point_func().copy()
@ -133,7 +144,11 @@ class TracingTail(TracedPath):
"time_traced": 1.0,
}
def __init__(self, mobject_or_func, **kwargs):
def __init__(
self,
mobject_or_func: Mobject | Callable[[], np.ndarray],
**kwargs
):
if isinstance(mobject_or_func, Mobject):
func = mobject_or_func.get_center
else:

View file

@ -1,6 +1,10 @@
from abc import abstractmethod
import numpy as np
from __future__ import annotations
import numbers
from abc import abstractmethod
from typing import Type, TypeVar, Union, Callable, Iterable, Sequence
import numpy as np
from manimlib.constants import *
from manimlib.mobject.functions import ParametricCurve
@ -18,6 +22,15 @@ from manimlib.utils.space_ops import angle_of_vector
from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import rotate_vector
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import colour
from manimlib.mobject.mobject import Mobject
T = TypeVar("T", bound=Mobject)
ManimColor = Union[str, colour.Color, Sequence[float]]
EPSILON = 1e-8
@ -39,56 +52,77 @@ class CoordinateSystem():
self.x_range = np.array(self.default_x_range)
self.y_range = np.array(self.default_y_range)
def coords_to_point(self, *coords):
@abstractmethod
def coords_to_point(self, *coords: float) -> np.ndarray:
raise Exception("Not implemented")
def point_to_coords(self, point):
@abstractmethod
def point_to_coords(self, point: np.ndarray) -> tuple[float, ...]:
raise Exception("Not implemented")
def c2p(self, *coords):
def c2p(self, *coords: float):
"""Abbreviation for coords_to_point"""
return self.coords_to_point(*coords)
def p2c(self, point):
def p2c(self, point: np.ndarray):
"""Abbreviation for point_to_coords"""
return self.point_to_coords(point)
def get_origin(self):
def get_origin(self) -> np.ndarray:
return self.c2p(*[0] * self.dimension)
@abstractmethod
def get_axes(self):
def get_axes(self) -> VGroup:
raise Exception("Not implemented")
@abstractmethod
def get_all_ranges(self):
def get_all_ranges(self) -> list[np.ndarray]:
raise Exception("Not implemented")
def get_axis(self, index):
def get_axis(self, index: int) -> NumberLine:
return self.get_axes()[index]
def get_x_axis(self):
def get_x_axis(self) -> NumberLine:
return self.get_axis(0)
def get_y_axis(self):
def get_y_axis(self) -> NumberLine:
return self.get_axis(1)
def get_z_axis(self):
def get_z_axis(self) -> NumberLine:
return self.get_axis(2)
def get_x_axis_label(self, label_tex, edge=RIGHT, direction=DL, **kwargs):
def get_x_axis_label(
self,
label_tex: str,
edge: np.ndarray = RIGHT,
direction: np.ndarray = DL,
**kwargs
) -> Tex:
return self.get_axis_label(
label_tex, self.get_x_axis(),
edge, direction, **kwargs
)
def get_y_axis_label(self, label_tex, edge=UP, direction=DR, **kwargs):
def get_y_axis_label(
self,
label_tex: str,
edge: np.ndarray = UP,
direction: np.ndarray = DR,
**kwargs
) -> Tex:
return self.get_axis_label(
label_tex, self.get_y_axis(),
edge, direction, **kwargs
)
def get_axis_label(self, label_tex, axis, edge, direction, buff=MED_SMALL_BUFF):
def get_axis_label(
self,
label_tex: str,
axis: np.ndarray,
edge: np.ndarray,
direction: np.ndarray,
buff: float = MED_SMALL_BUFF
) -> Tex:
label = Tex(label_tex)
label.next_to(
axis.get_edge_center(edge), direction,
@ -97,30 +131,43 @@ class CoordinateSystem():
label.shift_onto_screen(buff=MED_SMALL_BUFF)
return label
def get_axis_labels(self, x_label_tex="x", y_label_tex="y"):
def get_axis_labels(
self,
x_label_tex: str = "x",
y_label_tex: str = "y"
) -> VGroup:
self.axis_labels = VGroup(
self.get_x_axis_label(x_label_tex),
self.get_y_axis_label(y_label_tex),
)
return self.axis_labels
def get_line_from_axis_to_point(self, index, point,
line_func=DashedLine,
color=GREY_A,
stroke_width=2):
def get_line_from_axis_to_point(
self,
index: int,
point: np.ndarray,
line_func: Type[T] = DashedLine,
color: ManimColor = GREY_A,
stroke_width: float = 2
) -> T:
axis = self.get_axis(index)
line = line_func(axis.get_projection(point), point)
line.set_stroke(color, stroke_width)
return line
def get_v_line(self, point, **kwargs):
def get_v_line(self, point: np.ndarray, **kwargs):
return self.get_line_from_axis_to_point(0, point, **kwargs)
def get_h_line(self, point, **kwargs):
def get_h_line(self, point: np.ndarray, **kwargs):
return self.get_line_from_axis_to_point(1, point, **kwargs)
# Useful for graphing
def get_graph(self, function, x_range=None, **kwargs):
def get_graph(
self,
function: Callable[[float], float],
x_range: Sequence[float] | None = None,
**kwargs
) -> ParametricCurve:
t_range = np.array(self.x_range, dtype=float)
if x_range is not None:
t_range[:len(x_range)] = x_range
@ -139,7 +186,11 @@ class CoordinateSystem():
graph.x_range = x_range
return graph
def get_parametric_curve(self, function, **kwargs):
def get_parametric_curve(
self,
function: Callable[[float], np.ndarray],
**kwargs
) -> ParametricCurve:
dim = self.dimension
graph = ParametricCurve(
lambda t: self.coords_to_point(*function(t)[:dim]),
@ -148,7 +199,11 @@ class CoordinateSystem():
graph.underlying_function = function
return graph
def input_to_graph_point(self, x, graph):
def input_to_graph_point(
self,
x: float,
graph: ParametricCurve
) -> np.ndarray | None:
if hasattr(graph, "underlying_function"):
return self.coords_to_point(x, graph.underlying_function(x))
else:
@ -165,19 +220,21 @@ class CoordinateSystem():
else:
return None
def i2gp(self, x, graph):
def i2gp(self, x: float, graph: ParametricCurve) -> np.ndarray | None:
"""
Alias for input_to_graph_point
"""
return self.input_to_graph_point(x, graph)
def get_graph_label(self,
graph,
label="f(x)",
x=None,
direction=RIGHT,
buff=MED_SMALL_BUFF,
color=None):
def get_graph_label(
self,
graph: ParametricCurve,
label: str | Mobject = "f(x)",
x: float | None = None,
direction: np.ndarray = RIGHT,
buff: float = MED_SMALL_BUFF,
color: ManimColor | None = None
) -> Tex | Mobject:
if isinstance(label, str):
label = Tex(label)
if color is None:
@ -204,39 +261,57 @@ class CoordinateSystem():
label.shift_onto_screen()
return label
def get_v_line_to_graph(self, x, graph, **kwargs):
def get_v_line_to_graph(self, x: float, graph: ParametricCurve, **kwargs):
return self.get_v_line(self.i2gp(x, graph), **kwargs)
def get_h_line_to_graph(self, x, graph, **kwargs):
def get_h_line_to_graph(self, x: float, graph: ParametricCurve, **kwargs):
return self.get_h_line(self.i2gp(x, graph), **kwargs)
# For calculus
def angle_of_tangent(self, x, graph, dx=EPSILON):
def angle_of_tangent(
self,
x: float,
graph: ParametricCurve,
dx: float = EPSILON
) -> float:
p0 = self.input_to_graph_point(x, graph)
p1 = self.input_to_graph_point(x + dx, graph)
return angle_of_vector(p1 - p0)
def slope_of_tangent(self, x, graph, **kwargs):
def slope_of_tangent(
self,
x: float,
graph: ParametricCurve,
**kwargs
) -> float:
return np.tan(self.angle_of_tangent(x, graph, **kwargs))
def get_tangent_line(self, x, graph, length=5, line_func=Line):
def get_tangent_line(
self,
x: float,
graph: ParametricCurve,
length: float = 5,
line_func: Type[T] = Line
) -> T:
line = line_func(LEFT, RIGHT)
line.set_width(length)
line.rotate(self.angle_of_tangent(x, graph))
line.move_to(self.input_to_graph_point(x, graph))
return line
def get_riemann_rectangles(self,
graph,
x_range=None,
dx=None,
input_sample_type="left",
stroke_width=1,
stroke_color=BLACK,
fill_opacity=1,
colors=(BLUE, GREEN),
stroke_background=True,
show_signed_area=True):
def get_riemann_rectangles(
self,
graph: ParametricCurve,
x_range: Sequence[float] = None,
dx: float | None = None,
input_sample_type: str = "left",
stroke_width: float = 1,
stroke_color: ManimColor = BLACK,
fill_opacity: float = 1,
colors: Iterable[ManimColor] = (BLUE, GREEN),
stroke_background: bool = True,
show_signed_area: bool = True
) -> VGroup:
if x_range is None:
x_range = self.x_range[:2]
if dx is None:
@ -291,10 +366,12 @@ class Axes(VGroup, CoordinateSystem):
"width": FRAME_WIDTH - 2,
}
def __init__(self,
x_range=None,
y_range=None,
**kwargs):
def __init__(
self,
x_range: Sequence[float] | None = None,
y_range: Sequence[float] | None = None,
**kwargs
):
CoordinateSystem.__init__(self, **kwargs)
VGroup.__init__(self, **kwargs)
@ -317,36 +394,43 @@ class Axes(VGroup, CoordinateSystem):
self.add(*self.axes)
self.center()
def create_axis(self, range_terms, axis_config, length):
def create_axis(
self,
range_terms: Sequence[float],
axis_config: dict[str],
length: float
) -> NumberLine:
new_config = merge_dicts_recursively(self.axis_config, axis_config)
new_config["width"] = length
axis = NumberLine(range_terms, **new_config)
axis.shift(-axis.n2p(0))
return axis
def coords_to_point(self, *coords):
def coords_to_point(self, *coords: float) -> np.ndarray:
origin = self.x_axis.number_to_point(0)
result = origin.copy()
for axis, coord in zip(self.get_axes(), coords):
result += (axis.number_to_point(coord) - origin)
return result
def point_to_coords(self, point):
def point_to_coords(self, point: np.ndarray) -> tuple[float, ...]:
return tuple([
axis.point_to_number(point)
for axis in self.get_axes()
])
def get_axes(self):
def get_axes(self) -> VGroup:
return self.axes
def get_all_ranges(self):
def get_all_ranges(self) -> list[Sequence[float]]:
return [self.x_range, self.y_range]
def add_coordinate_labels(self,
x_values=None,
y_values=None,
**kwargs):
def add_coordinate_labels(
self,
x_values: Iterable[float] | None = None,
y_values: Iterable[float] | None = None,
**kwargs
) -> VGroup:
axes = self.get_axes()
self.coordinate_labels = VGroup()
for axis, values in zip(axes, [x_values, y_values]):
@ -370,7 +454,13 @@ class ThreeDAxes(Axes):
"gloss": 0.5,
}
def __init__(self, x_range=None, y_range=None, z_range=None, **kwargs):
def __init__(
self,
x_range: Sequence[float] | None = None,
y_range: Sequence[float] | None = None,
z_range: Sequence[float] | None = None,
**kwargs
):
Axes.__init__(self, x_range, y_range, **kwargs)
if z_range is not None:
self.z_range[:len(z_range)] = z_range
@ -393,7 +483,7 @@ class ThreeDAxes(Axes):
for axis in self.axes:
axis.insert_n_curves(self.num_axis_pieces - 1)
def get_all_ranges(self):
def get_all_ranges(self) -> list[Sequence[float]]:
return [self.x_range, self.y_range, self.z_range]
@ -423,11 +513,16 @@ class NumberPlane(Axes):
"make_smooth_after_applying_functions": True,
}
def __init__(self, x_range=None, y_range=None, **kwargs):
def __init__(
self,
x_range: Sequence[float] | None = None,
y_range: Sequence[float] | None = None,
**kwargs
):
super().__init__(x_range, y_range, **kwargs)
self.init_background_lines()
def init_background_lines(self):
def init_background_lines(self) -> None:
if self.faded_line_style is None:
style = dict(self.background_line_style)
# For anything numerical, like stroke_width
@ -445,7 +540,7 @@ class NumberPlane(Axes):
self.background_lines,
)
def get_lines(self):
def get_lines(self) -> tuple[VGroup, VGroup]:
x_axis = self.get_x_axis()
y_axis = self.get_y_axis()
@ -455,7 +550,11 @@ class NumberPlane(Axes):
lines2 = VGroup(*x_lines2, *y_lines2)
return lines1, lines2
def get_lines_parallel_to_axis(self, axis1, axis2):
def get_lines_parallel_to_axis(
self,
axis1: NumberLine,
axis2: NumberLine
) -> tuple[VGroup, VGroup]:
freq = axis2.x_step
ratio = self.faded_line_ratio
line = Line(axis1.get_start(), axis1.get_end())
@ -474,20 +573,20 @@ class NumberPlane(Axes):
lines2.add(new_line)
return lines1, lines2
def get_x_unit_size(self):
def get_x_unit_size(self) -> float:
return self.get_x_axis().get_unit_size()
def get_y_unit_size(self):
def get_y_unit_size(self) -> list:
return self.get_x_axis().get_unit_size()
def get_axes(self):
def get_axes(self) -> VGroup:
return self.axes
def get_vector(self, coords, **kwargs):
def get_vector(self, coords: Iterable[float], **kwargs) -> Arrow:
kwargs["buff"] = 0
return Arrow(self.c2p(0, 0), self.c2p(*coords), **kwargs)
def prepare_for_nonlinear_transform(self, num_inserted_curves=50):
def prepare_for_nonlinear_transform(self, num_inserted_curves: int = 50):
for mob in self.family_members_with_points():
num_curves = mob.get_num_curves()
if num_inserted_curves > num_curves:
@ -502,27 +601,35 @@ class ComplexPlane(NumberPlane):
"line_frequency": 1,
}
def number_to_point(self, number):
def number_to_point(self, number: complex | float) -> np.ndarray:
number = complex(number)
return self.coords_to_point(number.real, number.imag)
def n2p(self, number):
def n2p(self, number: complex | float) -> np.ndarray:
return self.number_to_point(number)
def point_to_number(self, point):
def point_to_number(self, point: np.ndarray) -> complex:
x, y = self.point_to_coords(point)
return complex(x, y)
def p2n(self, point):
def p2n(self, point: np.ndarray) -> complex:
return self.point_to_number(point)
def get_default_coordinate_values(self, skip_first=True):
def get_default_coordinate_values(
self,
skip_first: bool = True
) -> list[complex]:
x_numbers = self.get_x_axis().get_tick_range()[1:]
y_numbers = self.get_y_axis().get_tick_range()[1:]
y_numbers = [complex(0, y) for y in y_numbers if y != 0]
return [*x_numbers, *y_numbers]
def add_coordinate_labels(self, numbers=None, skip_first=True, **kwargs):
def add_coordinate_labels(
self,
numbers: list[complex] | None = None,
skip_first: bool = True,
**kwargs
):
if numbers is None:
numbers = self.get_default_coordinate_values(skip_first)

View file

@ -1,3 +1,7 @@
from __future__ import annotations
from typing import Callable, Sequence
from isosurfaces import plot_isoline
from manimlib.constants import *
@ -14,7 +18,12 @@ class ParametricCurve(VMobject):
"use_smoothing": True,
}
def __init__(self, t_func, t_range=None, **kwargs):
def __init__(
self,
t_func: Callable[[float], np.ndarray],
t_range: Sequence[float] | None = None,
**kwargs
):
digest_config(self, kwargs)
if t_range is not None:
self.t_range[:len(t_range)] = t_range
@ -27,7 +36,7 @@ class ParametricCurve(VMobject):
self.t_func = t_func
VMobject.__init__(self, **kwargs)
def get_point_from_function(self, t):
def get_point_from_function(self, t: float) -> np.ndarray:
return self.t_func(t)
def init_points(self):
@ -67,7 +76,12 @@ class FunctionGraph(ParametricCurve):
"x_range": [-8, 8, 0.25],
}
def __init__(self, function, x_range=None, **kwargs):
def __init__(
self,
function: Callable[[float], float],
x_range: Sequence[float] | None = None,
**kwargs
):
digest_config(self, kwargs)
self.function = function
@ -89,7 +103,11 @@ class ImplicitFunction(VMobject):
"use_smoothing": True
}
def __init__(self, func, x_range=None, y_range=None, **kwargs):
def __init__(
self,
func: Callable[[float, float], float],
**kwargs
):
digest_config(self, kwargs)
self.function = func
super().__init__(**kwargs)

View file

@ -1,6 +1,11 @@
import numpy as np
from __future__ import annotations
import math
import numbers
from typing import Sequence, Union
import colour
import numpy as np
from manimlib.constants import *
from manimlib.mobject.mobject import Mobject
@ -21,6 +26,8 @@ from manimlib.utils.space_ops import normalize
from manimlib.utils.space_ops import rotate_vector
from manimlib.utils.space_ops import rotation_matrix_transpose
ManimColor = Union[str, colour.Color, Sequence[float]]
DEFAULT_DOT_RADIUS = 0.08
DEFAULT_SMALL_DOT_RADIUS = 0.04
@ -58,7 +65,7 @@ class TipableVMobject(VMobject):
}
# Adding, Creating, Modifying tips
def add_tip(self, at_start=False, **kwargs):
def add_tip(self, at_start: bool = False, **kwargs):
"""
Adds a tip to the TipableVMobject instance, recognising
that the endpoints might need to be switched if it's
@ -71,7 +78,7 @@ class TipableVMobject(VMobject):
self.add(tip)
return self
def create_tip(self, at_start=False, **kwargs):
def create_tip(self, at_start: bool = False, **kwargs) -> ArrowTip:
"""
Stylises the tip, positions it spacially, and returns
the newly instantiated tip to the caller.
@ -80,7 +87,7 @@ class TipableVMobject(VMobject):
self.position_tip(tip, at_start)
return tip
def get_unpositioned_tip(self, **kwargs):
def get_unpositioned_tip(self, **kwargs) -> ArrowTip:
"""
Returns a tip that has been stylistically configured,
but has not yet been given a position in space.
@ -90,7 +97,7 @@ class TipableVMobject(VMobject):
config.update(kwargs)
return ArrowTip(**config)
def position_tip(self, tip, at_start=False):
def position_tip(self, tip: ArrowTip, at_start: bool = False) -> ArrowTip:
# Last two control points, defining both
# the end, and the tangency direction
if at_start:
@ -103,7 +110,7 @@ class TipableVMobject(VMobject):
tip.shift(anchor - tip.get_tip_point())
return tip
def reset_endpoints_based_on_tip(self, tip, at_start):
def reset_endpoints_based_on_tip(self, tip: ArrowTip, at_start: bool):
if self.get_length() == 0:
# Zero length, put_start_and_end_on wouldn't
# work
@ -118,7 +125,7 @@ class TipableVMobject(VMobject):
self.put_start_and_end_on(start, end)
return self
def asign_tip_attr(self, tip, at_start):
def asign_tip_attr(self, tip: ArrowTip, at_start: bool):
if at_start:
self.start_tip = tip
else:
@ -126,14 +133,14 @@ class TipableVMobject(VMobject):
return self
# Checking for tips
def has_tip(self):
def has_tip(self) -> bool:
return hasattr(self, "tip") and self.tip in self
def has_start_tip(self):
def has_start_tip(self) -> bool:
return hasattr(self, "start_tip") and self.start_tip in self
# Getters
def pop_tips(self):
def pop_tips(self) -> VGroup:
start, end = self.get_start_and_end()
result = VGroup()
if self.has_tip():
@ -145,7 +152,7 @@ class TipableVMobject(VMobject):
self.put_start_and_end_on(start, end)
return result
def get_tips(self):
def get_tips(self) -> VGroup:
"""
Returns a VGroup (collection of VMobjects) containing
the TipableVMObject instance's tips.
@ -157,7 +164,7 @@ class TipableVMobject(VMobject):
result.add(self.start_tip)
return result
def get_tip(self):
def get_tip(self) -> ArrowTip:
"""Returns the TipableVMobject instance's (first) tip,
otherwise throws an exception."""
tips = self.get_tips()
@ -166,28 +173,28 @@ class TipableVMobject(VMobject):
else:
return tips[0]
def get_default_tip_length(self):
def get_default_tip_length(self) -> float:
return self.tip_length
def get_first_handle(self):
def get_first_handle(self) -> np.ndarray:
return self.get_points()[1]
def get_last_handle(self):
def get_last_handle(self) -> np.ndarray:
return self.get_points()[-2]
def get_end(self):
def get_end(self) -> np.ndarray:
if self.has_tip():
return self.tip.get_start()
else:
return VMobject.get_end(self)
def get_start(self):
def get_start(self) -> np.ndarray:
if self.has_start_tip():
return self.start_tip.get_start()
else:
return VMobject.get_start(self)
def get_length(self):
def get_length(self) -> float:
start, end = self.get_start_and_end()
return get_norm(start - end)
@ -200,12 +207,17 @@ class Arc(TipableVMobject):
"arc_center": ORIGIN,
}
def __init__(self, start_angle=0, angle=TAU / 4, **kwargs):
def __init__(
self,
start_angle: float = 0,
angle: float = TAU / 4,
**kwargs
):
self.start_angle = start_angle
self.angle = angle
VMobject.__init__(self, **kwargs)
def init_points(self):
def init_points(self) -> None:
self.set_points(Arc.create_quadratic_bezier_points(
angle=self.angle,
start_angle=self.start_angle,
@ -215,7 +227,11 @@ class Arc(TipableVMobject):
self.shift(self.arc_center)
@staticmethod
def create_quadratic_bezier_points(angle, start_angle=0, n_components=8):
def create_quadratic_bezier_points(
angle: float,
start_angle: float = 0,
n_components: int = 8
) -> np.ndarray:
samples = np.array([
[np.cos(a), np.sin(a), 0]
for a in np.linspace(
@ -233,7 +249,7 @@ class Arc(TipableVMobject):
points[2::3] = samples[2::2]
return points
def get_arc_center(self):
def get_arc_center(self) -> np.ndarray:
"""
Looks at the normals to the first two
anchors, and finds their intersection points
@ -248,21 +264,27 @@ class Arc(TipableVMobject):
n2 = rotate_vector(t2, TAU / 4)
return find_intersection(a1, n1, a2, n2)
def get_start_angle(self):
def get_start_angle(self) -> float:
angle = angle_of_vector(self.get_start() - self.get_arc_center())
return angle % TAU
def get_stop_angle(self):
def get_stop_angle(self) -> float:
angle = angle_of_vector(self.get_end() - self.get_arc_center())
return angle % TAU
def move_arc_center_to(self, point):
def move_arc_center_to(self, point: np.ndarray):
self.shift(point - self.get_arc_center())
return self
class ArcBetweenPoints(Arc):
def __init__(self, start, end, angle=TAU / 4, **kwargs):
def __init__(
self,
start: np.ndarray,
end: np.ndarray,
angle: float = TAU / 4,
**kwargs
):
super().__init__(angle=angle, **kwargs)
if angle == 0:
self.set_points_as_corners([LEFT, RIGHT])
@ -270,13 +292,23 @@ class ArcBetweenPoints(Arc):
class CurvedArrow(ArcBetweenPoints):
def __init__(self, start_point, end_point, **kwargs):
def __init__(
self,
start_point: np.ndarray,
end_point: np.ndarray,
**kwargs
):
ArcBetweenPoints.__init__(self, start_point, end_point, **kwargs)
self.add_tip()
class CurvedDoubleArrow(CurvedArrow):
def __init__(self, start_point, end_point, **kwargs):
def __init__(
self,
start_point: np.ndarray,
end_point: np.ndarray,
**kwargs
):
CurvedArrow.__init__(self, start_point, end_point, **kwargs)
self.add_tip(at_start=True)
@ -291,7 +323,13 @@ class Circle(Arc):
def __init__(self, **kwargs):
Arc.__init__(self, 0, TAU, **kwargs)
def surround(self, mobject, dim_to_match=0, stretch=False, buff=MED_SMALL_BUFF):
def surround(
self,
mobject: Mobject,
dim_to_match: int = 0,
stretch: bool = False,
buff: float = MED_SMALL_BUFF
):
# Ignores dim_to_match and stretch; result will always be a circle
# TODO: Perhaps create an ellipse class to handle singele-dimension stretching
@ -299,13 +337,13 @@ class Circle(Arc):
self.stretch((self.get_width() + 2 * buff) / self.get_width(), 0)
self.stretch((self.get_height() + 2 * buff) / self.get_height(), 1)
def point_at_angle(self, angle):
def point_at_angle(self, angle: float) -> np.ndarray:
start_angle = self.get_start_angle()
return self.point_from_proportion(
(angle - start_angle) / TAU
)
def get_radius(self):
def get_radius(self) -> float:
return get_norm(self.get_start() - self.get_center())
@ -317,7 +355,7 @@ class Dot(Circle):
"color": WHITE
}
def __init__(self, point=ORIGIN, **kwargs):
def __init__(self, point: np.ndarray = ORIGIN, **kwargs):
super().__init__(arc_center=point, **kwargs)
@ -401,15 +439,26 @@ class Line(TipableVMobject):
"path_arc": 0,
}
def __init__(self, start=LEFT, end=RIGHT, **kwargs):
def __init__(
self,
start: np.ndarray = LEFT,
end: np.ndarray = RIGHT,
**kwargs
):
digest_config(self, kwargs)
self.set_start_and_end_attrs(start, end)
super().__init__(**kwargs)
def init_points(self):
def init_points(self) -> None:
self.set_points_by_ends(self.start, self.end, self.buff, self.path_arc)
def set_points_by_ends(self, start, end, buff=0, path_arc=0):
def set_points_by_ends(
self,
start: np.ndarray,
end: np.ndarray,
buff: float = 0,
path_arc: float = 0
):
vect = end - start
dist = get_norm(vect)
if np.isclose(dist, 0):
@ -438,11 +487,11 @@ class Line(TipableVMobject):
self.set_points_as_corners([start, end])
return self
def set_path_arc(self, new_value):
def set_path_arc(self, new_value: float) -> None:
self.path_arc = new_value
self.init_points()
def set_start_and_end_attrs(self, start, end):
def set_start_and_end_attrs(self, start: np.ndarray, end: np.ndarray):
# If either start or end are Mobjects, this
# gives their centers
rough_start = self.pointify(start)
@ -454,7 +503,11 @@ class Line(TipableVMobject):
self.start = self.pointify(start, vect)
self.end = self.pointify(end, -vect)
def pointify(self, mob_or_point, direction=None):
def pointify(
self,
mob_or_point: Mobject | np.ndarray,
direction: np.ndarray | None = None
) -> np.ndarray:
"""
Take an argument passed into Line (or subclass) and turn
it into a 3d point.
@ -471,7 +524,7 @@ class Line(TipableVMobject):
result[:len(point)] = point
return result
def put_start_and_end_on(self, start, end):
def put_start_and_end_on(self, start: np.ndarray, end: np.ndarray):
curr_start, curr_end = self.get_start_and_end()
if np.isclose(curr_start, curr_end).all():
# Handle null lines more gracefully
@ -479,16 +532,16 @@ class Line(TipableVMobject):
return self
return super().put_start_and_end_on(start, end)
def get_vector(self):
def get_vector(self) -> np.ndarray:
return self.get_end() - self.get_start()
def get_unit_vector(self):
def get_unit_vector(self) -> np.ndarray:
return normalize(self.get_vector())
def get_angle(self):
def get_angle(self) -> float:
return angle_of_vector(self.get_vector())
def get_projection(self, point):
def get_projection(self, point: np.ndarray) -> np.ndarray:
"""
Return projection of a point onto the line
"""
@ -496,10 +549,10 @@ class Line(TipableVMobject):
start = self.get_start()
return start + np.dot(point - start, unit_vect) * unit_vect
def get_slope(self):
def get_slope(self) -> float:
return np.tan(self.get_angle())
def set_angle(self, angle, about_point=None):
def set_angle(self, angle: float, about_point: np.ndarray | None = None):
if about_point is None:
about_point = self.get_start()
self.rotate(
@ -508,7 +561,7 @@ class Line(TipableVMobject):
)
return self
def set_length(self, length, **kwargs):
def set_length(self, length: float, **kwargs):
self.scale(length / self.get_length(), **kwargs)
return self
@ -532,35 +585,35 @@ class DashedLine(Line):
self.clear_points()
self.add(*dashes)
def calculate_num_dashes(self, positive_space_ratio):
def calculate_num_dashes(self, positive_space_ratio: float) -> int:
try:
full_length = self.dash_length / positive_space_ratio
return int(np.ceil(self.get_length() / full_length))
except ZeroDivisionError:
return 1
def calculate_positive_space_ratio(self):
def calculate_positive_space_ratio(self) -> float:
return fdiv(
self.dash_length,
self.dash_length + self.dash_spacing,
)
def get_start(self):
def get_start(self) -> np.ndarray:
if len(self.submobjects) > 0:
return self.submobjects[0].get_start()
else:
return Line.get_start(self)
def get_end(self):
def get_end(self) -> np.ndarray:
if len(self.submobjects) > 0:
return self.submobjects[-1].get_end()
else:
return Line.get_end(self)
def get_first_handle(self):
def get_first_handle(self) -> np.ndarray:
return self.submobjects[0].get_points()[1]
def get_last_handle(self):
def get_last_handle(self) -> np.ndarray:
return self.submobjects[-1].get_points()[-2]
@ -570,7 +623,7 @@ class TangentLine(Line):
"d_alpha": 1e-6
}
def __init__(self, vmob, alpha, **kwargs):
def __init__(self, vmob: VMobject, alpha: float, **kwargs):
digest_config(self, kwargs)
da = self.d_alpha
a1 = clip(alpha - da, 0, 1)
@ -603,16 +656,22 @@ class Arrow(Line):
"buff": 0.25,
}
def set_points_by_ends(self, start, end, buff=0, path_arc=0):
def set_points_by_ends(
self,
start: np.ndarray,
end: np.ndarray,
buff: float = 0,
path_arc: float = 0
):
super().set_points_by_ends(start, end, buff, path_arc)
self.insert_tip_anchor()
return self
def init_colors(self):
def init_colors(self) -> None:
super().init_colors()
self.create_tip_with_stroke_width()
def get_arc_length(self):
def get_arc_length(self) -> float:
# Push up into Line?
arc_len = get_norm(self.get_vector())
if self.path_arc > 0:
@ -655,14 +714,19 @@ class Arrow(Line):
self.create_tip_with_stroke_width()
return self
def set_stroke(self, color=None, width=None, *args, **kwargs):
def set_stroke(
self,
color: ManimColor | None = None,
width: float | None = None,
*args, **kwargs
):
super().set_stroke(color=color, width=width, *args, **kwargs)
if isinstance(width, numbers.Number):
self.max_stroke_width = width
self.reset_tip()
return self
def _handle_scale_side_effects(self, scale_factor):
def _handle_scale_side_effects(self, scale_factor: float):
return self.reset_tip()
@ -679,7 +743,13 @@ class FillArrow(Line):
"max_width_to_length_ratio": 0.1,
}
def set_points_by_ends(self, start, end, buff=0, path_arc=0):
def set_points_by_ends(
self,
start: np.ndarray,
end: np.ndarray,
buff: float = 0,
path_arc: float = 0
) -> None:
# Find the right tip length and thickness
vect = end - start
length = max(get_norm(vect), 1e-8)
@ -748,15 +818,15 @@ class FillArrow(Line):
)
return self
def get_start(self):
def get_start(self) -> np.ndarray:
nppc = self.n_points_per_curve
points = self.get_points()
return (points[0] + points[-nppc]) / 2
def get_end(self):
def get_end(self) -> np.ndarray:
return self.get_points()[self.tip_index]
def put_start_and_end_on(self, start, end):
def put_start_and_end_on(self, start: np.ndarray, end: np.ndarray):
self.set_points_by_ends(start, end, buff=0, path_arc=self.path_arc)
return self
@ -765,12 +835,12 @@ class FillArrow(Line):
self.reset_points_around_ends()
return self
def set_thickness(self, thickness):
def set_thickness(self, thickness: float):
self.thickness = thickness
self.reset_points_around_ends()
return self
def set_path_arc(self, path_arc):
def set_path_arc(self, path_arc: float):
self.path_arc = path_arc
self.reset_points_around_ends()
return self
@ -781,7 +851,7 @@ class Vector(Arrow):
"buff": 0,
}
def __init__(self, direction=RIGHT, **kwargs):
def __init__(self, direction: np.ndarray = RIGHT, **kwargs):
if len(direction) == 2:
direction = np.hstack([direction, 0])
super().__init__(ORIGIN, direction, **kwargs)
@ -794,24 +864,31 @@ class DoubleArrow(Arrow):
class CubicBezier(VMobject):
def __init__(self, a0, h0, h1, a1, **kwargs):
def __init__(
self,
a0: np.ndarray,
h0: np.ndarray,
h1: np.ndarray,
a1: np.ndarray,
**kwargs
):
VMobject.__init__(self, **kwargs)
self.add_cubic_bezier_curve(a0, h0, h1, a1)
class Polygon(VMobject):
def __init__(self, *vertices, **kwargs):
def __init__(self, *vertices: np.ndarray, **kwargs):
self.vertices = vertices
super().__init__(**kwargs)
def init_points(self):
def init_points(self) -> None:
verts = self.vertices
self.set_points_as_corners([*verts, verts[0]])
def get_vertices(self):
def get_vertices(self) -> list[np.ndarray]:
return self.get_start_anchors()
def round_corners(self, radius=0.5):
def round_corners(self, radius: float = 0.5):
vertices = self.get_vertices()
arcs = []
for v1, v2, v3 in adjacent_n_tuples(vertices, 3):
@ -850,7 +927,7 @@ class Polygon(VMobject):
class Polyline(Polygon):
def init_points(self):
def init_points(self) -> None:
self.set_points_as_corners(self.vertices)
@ -859,7 +936,7 @@ class RegularPolygon(Polygon):
"start_angle": None,
}
def __init__(self, n=6, **kwargs):
def __init__(self, n: int = 6, **kwargs):
digest_config(self, kwargs, locals())
if self.start_angle is None:
# 0 for odd, 90 for even
@ -898,19 +975,19 @@ class ArrowTip(Triangle):
self.data["points"] = Dot().set_width(h).get_points()
self.rotate(self.angle)
def get_base(self):
def get_base(self) -> np.ndarray:
return self.point_from_proportion(0.5)
def get_tip_point(self):
def get_tip_point(self) -> np.ndarray:
return self.get_points()[0]
def get_vector(self):
def get_vector(self) -> np.ndarray:
return self.get_tip_point() - self.get_base()
def get_angle(self):
def get_angle(self) -> float:
return angle_of_vector(self.get_vector())
def get_length(self):
def get_length(self) -> float:
return get_norm(self.get_vector())
@ -923,7 +1000,12 @@ class Rectangle(Polygon):
"close_new_points": True,
}
def __init__(self, width=None, height=None, **kwargs):
def __init__(
self,
width: float | None = None,
height: float | None = None,
**kwargs
):
Polygon.__init__(self, UR, UL, DL, DR, **kwargs)
if width is None:
@ -936,7 +1018,7 @@ class Rectangle(Polygon):
class Square(Rectangle):
def __init__(self, side_length=2.0, **kwargs):
def __init__(self, side_length: float = 2.0, **kwargs):
self.side_length = side_length
super().__init__(side_length, side_length, **kwargs)

View file

@ -1,3 +1,7 @@
from __future__ import annotations
from typing import Callable
import numpy as np
from pyglet.window import key as PygletWindowKeys
@ -21,8 +25,7 @@ class MotionMobject(Mobject):
"""
You could hold and drag this object to any position
"""
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
super().__init__(**kwargs)
assert(isinstance(mobject, Mobject))
self.mobject = mobject
@ -31,7 +34,7 @@ class MotionMobject(Mobject):
self.mobject.add_updater(lambda mob: None)
self.add(mobject)
def mob_on_mouse_drag(self, mob, event_data):
def mob_on_mouse_drag(self, mob: Mobject, event_data: dict[str, np.ndarray]) -> bool:
mob.move_to(event_data["point"])
return False
@ -43,7 +46,7 @@ class Button(Mobject):
The on_click method takes mobject as argument like updater
"""
def __init__(self, mobject, on_click, **kwargs):
def __init__(self, mobject: Mobject, on_click: Callable[[Mobject]], **kwargs):
super().__init__(**kwargs)
assert(isinstance(mobject, Mobject))
self.on_click = on_click
@ -51,7 +54,7 @@ class Button(Mobject):
self.mobject.add_mouse_press_listner(self.mob_on_mouse_press)
self.add(self.mobject)
def mob_on_mouse_press(self, mob, event_data):
def mob_on_mouse_press(self, mob: Mobject, event_data) -> bool:
self.on_click(mob)
return False
@ -59,7 +62,7 @@ class Button(Mobject):
# Controls
class ControlMobject(ValueTracker):
def __init__(self, value, *mobjects, **kwargs):
def __init__(self, value: float, *mobjects: Mobject, **kwargs):
super().__init__(value=value, **kwargs)
self.add(*mobjects)
@ -67,7 +70,7 @@ class ControlMobject(ValueTracker):
self.add_updater(lambda mob: None)
self.fix_in_frame()
def set_value(self, value):
def set_value(self, value: float):
self.assert_value(value)
self.set_value_anim(value)
return ValueTracker.set_value(self, value)
@ -93,25 +96,25 @@ class EnableDisableButton(ControlMobject):
"disable_color": RED
}
def __init__(self, value=True, **kwargs):
def __init__(self, value: bool = True, **kwargs):
digest_config(self, kwargs)
self.box = Rectangle(**self.rect_kwargs)
super().__init__(value, self.box, **kwargs)
self.add_mouse_press_listner(self.on_mouse_press)
def assert_value(self, value):
def assert_value(self, value: bool) -> None:
assert(isinstance(value, bool))
def set_value_anim(self, value):
def set_value_anim(self, value: bool) -> None:
if value:
self.box.set_fill(self.enable_color)
else:
self.box.set_fill(self.disable_color)
def toggle_value(self):
def toggle_value(self) -> None:
super().set_value(not self.get_value())
def on_mouse_press(self, mob, event_data):
def on_mouse_press(self, mob: Mobject, event_data) -> bool:
mob.toggle_value()
return False
@ -136,32 +139,32 @@ class Checkbox(ControlMobject):
"box_content_buff": SMALL_BUFF
}
def __init__(self, value=True, **kwargs):
def __init__(self, value: bool = True, **kwargs):
digest_config(self, kwargs)
self.box = Rectangle(**self.rect_kwargs)
self.box_content = self.get_checkmark() if value else self.get_cross()
super().__init__(value, self.box, self.box_content, **kwargs)
self.add_mouse_press_listner(self.on_mouse_press)
def assert_value(self, value):
def assert_value(self, value: bool) -> None:
assert(isinstance(value, bool))
def toggle_value(self):
def toggle_value(self) -> None:
super().set_value(not self.get_value())
def set_value_anim(self, value):
def set_value_anim(self, value: bool) -> None:
if value:
self.box_content.become(self.get_checkmark())
else:
self.box_content.become(self.get_cross())
def on_mouse_press(self, mob, event_data):
def on_mouse_press(self, mob: Mobject, event_data) -> None:
mob.toggle_value()
return False
# Helper methods
def get_checkmark(self):
def get_checkmark(self) -> VGroup:
checkmark = VGroup(
Line(UP / 2 + 2 * LEFT, DOWN + LEFT, **self.checkmark_kwargs),
Line(DOWN + LEFT, UP + RIGHT, **self.checkmark_kwargs)
@ -173,7 +176,7 @@ class Checkbox(ControlMobject):
checkmark.move_to(self.box)
return checkmark
def get_cross(self):
def get_cross(self) -> VGroup:
cross = VGroup(
Line(UP + LEFT, DOWN + RIGHT, **self.cross_kwargs),
Line(UP + RIGHT, DOWN + LEFT, **self.cross_kwargs)
@ -206,7 +209,7 @@ class LinearNumberSlider(ControlMobject):
}
}
def __init__(self, value=0, **kwargs):
def __init__(self, value: float = 0, **kwargs):
digest_config(self, kwargs)
self.bar = RoundedRectangle(**self.rounded_rect_kwargs)
self.slider = Circle(**self.circle_kwargs)
@ -219,22 +222,22 @@ class LinearNumberSlider(ControlMobject):
self.slider.add_mouse_drag_listner(self.slider_on_mouse_drag)
super().__init__(value, self.bar, self.slider, self.slider_axis, ** kwargs)
super().__init__(value, self.bar, self.slider, self.slider_axis, **kwargs)
def assert_value(self, value):
def assert_value(self, value: float) -> None:
assert(self.min_value <= value <= self.max_value)
def set_value_anim(self, value):
def set_value_anim(self, value: float) -> None:
prop = (value - self.min_value) / (self.max_value - self.min_value)
self.slider.move_to(self.slider_axis.point_from_proportion(prop))
def slider_on_mouse_drag(self, mob, event_data):
def slider_on_mouse_drag(self, mob, event_data: dict[str, np.ndarray]) -> bool:
self.set_value(self.get_value_from_point(event_data["point"]))
return False
# Helper Methods
def get_value_from_point(self, point):
def get_value_from_point(self, point: np.ndarray) -> float:
start, end = self.slider_axis.get_start_and_end()
point_on_line = get_closest_point_on_line(start, end, point)
prop = get_norm(point_on_line - start) / get_norm(end - start)
@ -300,7 +303,7 @@ class ColorSliders(Group):
self.arrange(DOWN)
def get_background(self):
def get_background(self) -> VGroup:
single_square_len = self.background_grid_kwargs["single_square_len"]
colors = self.background_grid_kwargs["colors"]
width = self.rect_kwargs["width"]
@ -322,24 +325,24 @@ class ColorSliders(Group):
return grid
def set_value(self, r, g, b, a):
def set_value(self, r: float, g: float, b: float, a: float):
self.r_slider.set_value(r)
self.g_slider.set_value(g)
self.b_slider.set_value(b)
self.a_slider.set_value(a)
def get_value(self):
def get_value(self) -> np.ndarary:
r = self.r_slider.get_value() / 255
g = self.g_slider.get_value() / 255
b = self.b_slider.get_value() / 255
alpha = self.a_slider.get_value()
return color_to_rgba(rgb_to_color((r, g, b)), alpha=alpha)
def get_picked_color(self):
def get_picked_color(self) -> str:
rgba = self.get_value()
return rgb_to_hex(rgba[:3])
def get_picked_opacity(self):
def get_picked_opacity(self) -> float:
rgba = self.get_value()
return rgba[3]
@ -363,7 +366,7 @@ class Textbox(ControlMobject):
"deactive_color": RED,
}
def __init__(self, value="", **kwargs):
def __init__(self, value: str = "", **kwargs):
digest_config(self, kwargs)
self.isActive = self.isInitiallyActive
self.box = Rectangle(**self.box_kwargs)
@ -374,10 +377,10 @@ class Textbox(ControlMobject):
self.active_anim(self.isActive)
self.add_key_press_listner(self.on_key_press)
def set_value_anim(self, value):
def set_value_anim(self, value: str) -> None:
self.update_text(value)
def update_text(self, value):
def update_text(self, value: str) -> None:
text = self.text
self.remove(text)
text.__init__(value, **self.text_kwargs)
@ -389,18 +392,18 @@ class Textbox(ControlMobject):
text.fix_in_frame()
self.add(text)
def active_anim(self, isActive):
def active_anim(self, isActive: bool) -> None:
if isActive:
self.box.set_stroke(self.active_color)
else:
self.box.set_stroke(self.deactive_color)
def box_on_mouse_press(self, mob, event_data):
def box_on_mouse_press(self, mob, event_data) -> bool:
self.isActive = not self.isActive
self.active_anim(self.isActive)
return False
def on_key_press(self, mob, event_data):
def on_key_press(self, mob: Mobject, event_data: dict[str, int]) -> bool | None:
symbol = event_data["symbol"]
modifiers = event_data["modifiers"]
char = chr(symbol)
@ -443,7 +446,7 @@ class ControlPanel(Group):
}
}
def __init__(self, *controls, **kwargs):
def __init__(self, *controls: ControlMobject, **kwargs):
digest_config(self, kwargs)
self.panel = Rectangle(**self.panel_kwargs)
@ -472,7 +475,7 @@ class ControlPanel(Group):
self.move_panel_and_controls_to_panel_opener()
self.fix_in_frame()
def move_panel_and_controls_to_panel_opener(self):
def move_panel_and_controls_to_panel_opener(self) -> None:
self.panel.next_to(
self.panel_opener_rect,
direction=UP,
@ -488,11 +491,11 @@ class ControlPanel(Group):
self.controls.set_x(controls_old_x)
def add_controls(self, *new_controls):
def add_controls(self, *new_controls: ControlMobject) -> None:
self.controls.add(*new_controls)
self.move_panel_and_controls_to_panel_opener()
def remove_controls(self, *controls_to_remove):
def remove_controls(self, *controls_to_remove: ControlMobject) -> None:
self.controls.remove(*controls_to_remove)
self.move_panel_and_controls_to_panel_opener()
@ -510,13 +513,13 @@ class ControlPanel(Group):
self.move_panel_and_controls_to_panel_opener()
return self
def panel_opener_on_mouse_drag(self, mob, event_data):
def panel_opener_on_mouse_drag(self, mob, event_data: dict[str, np.ndarray]) -> bool:
point = event_data["point"]
self.panel_opener.match_y(Dot(point))
self.move_panel_and_controls_to_panel_opener()
return False
def panel_on_mouse_scroll(self, mob, event_data):
def panel_on_mouse_scroll(self, mob, event_data: dict[str, np.ndarray]) -> bool:
offset = event_data["offset"]
factor = 10 * offset[1]
self.controls.set_y(self.controls.get_y() + factor)

View file

@ -1,5 +1,10 @@
import numpy as np
from __future__ import annotations
import itertools as it
from typing import Union, Sequence
import numpy as np
import numpy.typing as npt
from manimlib.constants import *
from manimlib.mobject.numbers import DecimalNumber
@ -10,10 +15,18 @@ from manimlib.mobject.svg.tex_mobject import TexText
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import colour
from manimlib.mobject.mobject import Mobject
ManimColor = Union[str, colour.Color, Sequence[float]]
VECTOR_LABEL_SCALE_FACTOR = 0.8
def matrix_to_tex_string(matrix):
def matrix_to_tex_string(matrix: npt.ArrayLike) -> str:
matrix = np.array(matrix).astype("str")
if matrix.ndim == 1:
matrix = matrix.reshape((matrix.size, 1))
@ -27,12 +40,16 @@ def matrix_to_tex_string(matrix):
return prefix + " \\\\ ".join(rows) + suffix
def matrix_to_mobject(matrix):
def matrix_to_mobject(matrix: npt.ArrayLike) -> Tex:
return Tex(matrix_to_tex_string(matrix))
def vector_coordinate_label(vector_mob, integer_labels=True,
n_dim=2, color=WHITE):
def vector_coordinate_label(
vector_mob: VMobject,
integer_labels: bool = True,
n_dim: int = 2,
color: ManimColor = WHITE
) -> Matrix:
vect = np.array(vector_mob.get_end())
if integer_labels:
vect = np.round(vect).astype(int)
@ -66,7 +83,7 @@ class Matrix(VMobject):
"element_alignment_corner": DOWN,
}
def __init__(self, matrix, **kwargs):
def __init__(self, matrix: npt.ArrayLike, **kwargs):
"""
Matrix can either include numbers, tex_strings,
or mobjects
@ -87,7 +104,7 @@ class Matrix(VMobject):
if self.include_background_rectangle:
self.add_background_rectangle()
def matrix_to_mob_matrix(self, matrix):
def matrix_to_mob_matrix(self, matrix: npt.ArrayLike) -> list[list[Mobject]]:
return [
[
self.element_to_mobject(item, **self.element_to_mobject_config)
@ -96,7 +113,7 @@ class Matrix(VMobject):
for row in matrix
]
def organize_mob_matrix(self, matrix):
def organize_mob_matrix(self, matrix: npt.ArrayLike):
for i, row in enumerate(matrix):
for j, elem in enumerate(row):
mob = matrix[i][j]
@ -126,19 +143,19 @@ class Matrix(VMobject):
self.brackets = VGroup(l_bracket, r_bracket)
return self
def get_columns(self):
def get_columns(self) -> VGroup:
return VGroup(*[
VGroup(*[row[i] for row in self.mob_matrix])
for i in range(len(self.mob_matrix[0]))
])
def get_rows(self):
def get_rows(self) -> VGroup:
return VGroup(*[
VGroup(*row)
for row in self.mob_matrix
])
def set_column_colors(self, *colors):
def set_column_colors(self, *colors: ManimColor):
columns = self.get_columns()
for color, column in zip(colors, columns):
column.set_color(color)
@ -149,13 +166,13 @@ class Matrix(VMobject):
mob.add_background_rectangle()
return self
def get_mob_matrix(self):
def get_mob_matrix(self) -> list[list[Mobject]]:
return self.mob_matrix
def get_entries(self):
def get_entries(self) -> VGroup:
return self.elements
def get_brackets(self):
def get_brackets(self) -> VGroup:
return self.brackets
@ -179,7 +196,12 @@ class MobjectMatrix(Matrix):
}
def get_det_text(matrix, determinant=None, background_rect=False, initial_scale_factor=2):
def get_det_text(
matrix: Matrix,
determinant: int | str | None = None,
background_rect: bool = False,
initial_scale_factor: int = 2
) -> VGroup:
parens = Tex("(", ")")
parens.scale(initial_scale_factor)
parens.stretch_to_fit_height(matrix.get_height())

File diff suppressed because it is too large Load diff

View file

@ -1,10 +1,19 @@
from __future__ import annotations
import inspect
from typing import Callable
from manimlib.constants import DEGREES
from manimlib.constants import RIGHT
from manimlib.mobject.mobject import Mobject
from manimlib.utils.simple_functions import clip
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy as np
from manimlib.animation.animation import Animation
def assert_is_mobject_method(method):
assert(inspect.ismethod(method))
@ -41,27 +50,39 @@ def f_always(method, *arg_generators, **kwargs):
return mobject
def always_redraw(func, *args, **kwargs):
def always_redraw(func: Callable[..., Mobject], *args, **kwargs) -> Mobject:
mob = func(*args, **kwargs)
mob.add_updater(lambda m: mob.become(func(*args, **kwargs)))
return mob
def always_shift(mobject, direction=RIGHT, rate=0.1):
def always_shift(
mobject: Mobject,
direction: np.ndarray = RIGHT,
rate: float = 0.1
) -> Mobject:
mobject.add_updater(
lambda m, dt: m.shift(dt * rate * direction)
)
return mobject
def always_rotate(mobject, rate=20 * DEGREES, **kwargs):
def always_rotate(
mobject: Mobject,
rate: float = 20 * DEGREES,
**kwargs
) -> Mobject:
mobject.add_updater(
lambda m, dt: m.rotate(dt * rate, **kwargs)
)
return mobject
def turn_animation_into_updater(animation, cycle=False, **kwargs):
def turn_animation_into_updater(
animation: Animation,
cycle: bool = False,
**kwargs
) -> Mobject:
"""
Add an updater to the animation's mobject which applies
the interpolation and update functions of the animation
@ -94,7 +115,7 @@ def turn_animation_into_updater(animation, cycle=False, **kwargs):
return mobject
def cycle_animation(animation, **kwargs):
def cycle_animation(animation: Animation, **kwargs) -> Mobject:
return turn_animation_into_updater(
animation, cycle=True, **kwargs
)

View file

@ -1,3 +1,7 @@
from __future__ import annotations
from typing import Iterable, Sequence
from manimlib.constants import *
from manimlib.mobject.geometry import Line
from manimlib.mobject.numbers import DecimalNumber
@ -38,7 +42,7 @@ class NumberLine(Line):
"numbers_to_exclude": None
}
def __init__(self, x_range=None, **kwargs):
def __init__(self, x_range: Sequence[float] | None = None, **kwargs):
digest_config(self, kwargs)
if x_range is None:
x_range = self.x_range
@ -48,9 +52,9 @@ class NumberLine(Line):
x_min, x_max, x_step = x_range
# A lot of old scenes pass in x_min or x_max explicitly,
# so this is just here to keep those workin
self.x_min = kwargs.get("x_min", x_min)
self.x_max = kwargs.get("x_max", x_max)
self.x_step = kwargs.get("x_step", x_step)
self.x_min: float = kwargs.get("x_min", x_min)
self.x_max: float = kwargs.get("x_max", x_max)
self.x_step: float = kwargs.get("x_step", x_step)
super().__init__(self.x_min * RIGHT, self.x_max * RIGHT, **kwargs)
if self.width:
@ -71,14 +75,14 @@ class NumberLine(Line):
if self.include_numbers:
self.add_numbers(excluding=self.numbers_to_exclude)
def get_tick_range(self):
def get_tick_range(self) -> np.ndarray:
if self.include_tip:
x_max = self.x_max
else:
x_max = self.x_max + self.x_step
return np.arange(self.x_min, x_max, self.x_step)
def add_ticks(self):
def add_ticks(self) -> None:
ticks = VGroup()
for x in self.get_tick_range():
size = self.tick_size
@ -88,7 +92,7 @@ class NumberLine(Line):
self.add(ticks)
self.ticks = ticks
def get_tick(self, x, size=None):
def get_tick(self, x: float, size: float | None = None) -> Line:
if size is None:
size = self.tick_size
result = Line(size * DOWN, size * UP)
@ -97,14 +101,14 @@ class NumberLine(Line):
result.match_style(self)
return result
def get_tick_marks(self):
def get_tick_marks(self) -> VGroup:
return self.ticks
def number_to_point(self, number):
def number_to_point(self, number: float) -> np.ndarray:
alpha = float(number - self.x_min) / (self.x_max - self.x_min)
return interpolate(self.get_start(), self.get_end(), alpha)
def point_to_number(self, point):
def point_to_number(self, point: np.ndarray) -> float:
points = self.get_points()
start = points[0]
end = points[-1]
@ -115,21 +119,24 @@ class NumberLine(Line):
)
return interpolate(self.x_min, self.x_max, proportion)
def n2p(self, number):
def n2p(self, number: float) -> np.ndarray:
"""Abbreviation for number_to_point"""
return self.number_to_point(number)
def p2n(self, point):
def p2n(self, point: np.ndarray) -> float:
"""Abbreviation for point_to_number"""
return self.point_to_number(point)
def get_unit_size(self):
def get_unit_size(self) -> float:
return self.get_length() / (self.x_max - self.x_min)
def get_number_mobject(self, x,
direction=None,
buff=None,
**number_config):
def get_number_mobject(
self,
x: float,
direction: np.ndarray | None = None,
buff: float | None = None,
**number_config
) -> DecimalNumber:
number_config = merge_dicts_recursively(
self.decimal_number_config, number_config
)
@ -149,7 +156,13 @@ class NumberLine(Line):
num_mob.shift(num_mob[0].get_width() * LEFT / 2)
return num_mob
def add_numbers(self, x_values=None, excluding=None, font_size=24, **kwargs):
def add_numbers(
self,
x_values: Iterable[float] | None = None,
excluding: Iterable[float] | None =None,
font_size: int = 24,
**kwargs
) -> VGroup:
if x_values is None:
x_values = self.get_tick_range()

View file

@ -1,11 +1,16 @@
from __future__ import annotations
from typing import TypeVar, Type
from manimlib.constants import *
from manimlib.mobject.svg.tex_mobject import SingleStringTex
from manimlib.mobject.svg.text_mobject import Text
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.iterables import hash_obj
T = TypeVar("T", bound=VMobject)
string_to_mob_map = {}
string_to_mob_map: dict[str, VMobject] = {}
class DecimalNumber(VMobject):
@ -24,11 +29,11 @@ class DecimalNumber(VMobject):
"text_config": {} # Do not pass in font_size here
}
def __init__(self, number=0, **kwargs):
def __init__(self, number: float | complex = 0, **kwargs):
super().__init__(**kwargs)
self.set_submobjects_from_number(number)
def set_submobjects_from_number(self, number):
def set_submobjects_from_number(self, number: float | complex) -> None:
self.number = number
self.set_submobjects([])
string_to_mob_ = lambda s: self.string_to_mob(s, **self.text_config)
@ -63,7 +68,7 @@ class DecimalNumber(VMobject):
if self.include_background_rectangle:
self.add_background_rectangle()
def get_num_string(self, number):
def get_num_string(self, number: float | complex) -> str:
if isinstance(number, complex):
formatter = self.get_complex_formatter()
else:
@ -79,21 +84,21 @@ class DecimalNumber(VMobject):
num_string = num_string.replace("-", "")
return num_string
def init_data(self):
def init_data(self) -> None:
super().init_data()
self.data["font_size"] = np.array([self.font_size], dtype=float)
def get_font_size(self):
def get_font_size(self) -> float:
return self.data["font_size"][0]
def string_to_mob(self, string, mob_class=Text, **kwargs):
def string_to_mob(self, string: str, mob_class: Type[T] = Text, **kwargs) -> T:
if (string, hash_obj(kwargs)) not in string_to_mob_map:
string_to_mob_map[(string, hash_obj(kwargs))] = mob_class(string, font_size=1, **kwargs)
mob = string_to_mob_map[(string, hash_obj(kwargs))].copy()
mob.scale(self.get_font_size())
return mob
def get_formatter(self, **kwargs):
def get_formatter(self, **kwargs) -> str:
"""
Configuration is based first off instance attributes,
but overwritten by any kew word argument. Relevant
@ -122,14 +127,14 @@ class DecimalNumber(VMobject):
"}",
])
def get_complex_formatter(self, **kwargs):
def get_complex_formatter(self, **kwargs) -> str:
return "".join([
self.get_formatter(field_name="0.real"),
self.get_formatter(field_name="0.imag", include_sign=True),
"i"
])
def set_value(self, number):
def set_value(self, number: float | complex):
move_to_point = self.get_edge_center(self.edge_to_fix)
old_submobjects = list(self.submobjects)
self.set_submobjects_from_number(number)
@ -138,13 +143,13 @@ class DecimalNumber(VMobject):
sm1.match_style(sm2)
return self
def _handle_scale_side_effects(self, scale_factor):
def _handle_scale_side_effects(self, scale_factor: float) -> None:
self.data["font_size"] *= scale_factor
def get_value(self):
def get_value(self) -> float | complex:
return self.number
def increment_value(self, delta_t=1):
def increment_value(self, delta_t: float | complex = 1) -> None:
self.set_value(self.get_value() + delta_t)
@ -153,5 +158,5 @@ class Integer(DecimalNumber):
"num_decimal_places": 0,
}
def get_value(self):
def get_value(self) -> int:
return int(np.round(super().get_value()))

View file

@ -1,3 +1,8 @@
from __future__ import annotations
from typing import Iterable, Union, Sequence
import colour
from manimlib.constants import *
from manimlib.mobject.geometry import Line
from manimlib.mobject.geometry import Rectangle
@ -9,6 +14,8 @@ from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_gradient
from manimlib.utils.iterables import listify
ManimColor = Union[str, colour.Color, Sequence[float]]
EPSILON = 0.0001
@ -24,7 +31,11 @@ class SampleSpace(Rectangle):
"default_label_scale_val": 1,
}
def add_title(self, title="Sample space", buff=MED_SMALL_BUFF):
def add_title(
self,
title: str = "Sample space",
buff: float = MED_SMALL_BUFF
) -> None:
# TODO, should this really exist in SampleSpaceScene
title_mob = TexText(title)
if title_mob.get_width() > self.get_width():
@ -33,17 +44,23 @@ class SampleSpace(Rectangle):
self.title = title_mob
self.add(title_mob)
def add_label(self, label):
def add_label(self, label: str) -> None:
self.label = label
def complete_p_list(self, p_list):
def complete_p_list(self, p_list: list[float]) -> list[float]:
new_p_list = listify(p_list)
remainder = 1.0 - sum(new_p_list)
if abs(remainder) > EPSILON:
new_p_list.append(remainder)
return new_p_list
def get_division_along_dimension(self, p_list, dim, colors, vect):
def get_division_along_dimension(
self,
p_list: list[float],
dim: int,
colors: Iterable[ManimColor],
vect: np.ndarray
) -> VGroup:
p_list = self.complete_p_list(p_list)
colors = color_gradient(colors, len(p_list))
@ -60,38 +77,41 @@ class SampleSpace(Rectangle):
return parts
def get_horizontal_division(
self, p_list,
colors=[GREEN_E, BLUE_E],
vect=DOWN
):
self,
p_list: list[float],
colors: Iterable[ManimColor] = [GREEN_E, BLUE_E],
vect: np.ndarray = DOWN
) -> VGroup:
return self.get_division_along_dimension(p_list, 1, colors, vect)
def get_vertical_division(
self, p_list,
colors=[MAROON_B, YELLOW],
vect=RIGHT
):
self,
p_list: list[float],
colors: Iterable[ManimColor] = [MAROON_B, YELLOW],
vect: np.ndarray = RIGHT
) -> VGroup:
return self.get_division_along_dimension(p_list, 0, colors, vect)
def divide_horizontally(self, *args, **kwargs):
def divide_horizontally(self, *args, **kwargs) -> None:
self.horizontal_parts = self.get_horizontal_division(*args, **kwargs)
self.add(self.horizontal_parts)
def divide_vertically(self, *args, **kwargs):
def divide_vertically(self, *args, **kwargs) -> None:
self.vertical_parts = self.get_vertical_division(*args, **kwargs)
self.add(self.vertical_parts)
def get_subdivision_braces_and_labels(
self, parts, labels, direction,
buff=SMALL_BUFF,
min_num_quads=1
):
self,
parts: VGroup,
labels: str,
direction: np.ndarray,
buff: float = SMALL_BUFF,
) -> VGroup:
label_mobs = VGroup()
braces = VGroup()
for label, part in zip(labels, parts):
brace = Brace(
part, direction,
min_num_quads=min_num_quads,
buff=buff
)
if isinstance(label, Mobject):
@ -112,22 +132,35 @@ class SampleSpace(Rectangle):
}
return VGroup(parts.braces, parts.labels)
def get_side_braces_and_labels(self, labels, direction=LEFT, **kwargs):
def get_side_braces_and_labels(
self,
labels: str,
direction: np.ndarray = LEFT,
**kwargs
) -> VGroup:
assert(hasattr(self, "horizontal_parts"))
parts = self.horizontal_parts
return self.get_subdivision_braces_and_labels(parts, labels, direction, **kwargs)
def get_top_braces_and_labels(self, labels, **kwargs):
def get_top_braces_and_labels(
self,
labels: str,
**kwargs
) -> VGroup:
assert(hasattr(self, "vertical_parts"))
parts = self.vertical_parts
return self.get_subdivision_braces_and_labels(parts, labels, UP, **kwargs)
def get_bottom_braces_and_labels(self, labels, **kwargs):
def get_bottom_braces_and_labels(
self,
labels: str,
**kwargs
) -> VGroup:
assert(hasattr(self, "vertical_parts"))
parts = self.vertical_parts
return self.get_subdivision_braces_and_labels(parts, labels, DOWN, **kwargs)
def add_braces_and_labels(self):
def add_braces_and_labels(self) -> None:
for attr in "horizontal_parts", "vertical_parts":
if not hasattr(self, attr):
continue
@ -136,7 +169,7 @@ class SampleSpace(Rectangle):
if hasattr(parts, subattr):
self.add(getattr(parts, subattr))
def __getitem__(self, index):
def __getitem__(self, index: int | slice) -> VGroup:
if hasattr(self, "horizontal_parts"):
return self.horizontal_parts[index]
elif hasattr(self, "vertical_parts"):
@ -162,7 +195,7 @@ class BarChart(VGroup):
"bar_label_scale_val": 0.75,
}
def __init__(self, values, **kwargs):
def __init__(self, values: Iterable[float], **kwargs):
VGroup.__init__(self, **kwargs)
if self.max_value is None:
self.max_value = max(values)
@ -172,7 +205,7 @@ class BarChart(VGroup):
self.add_bars(values)
self.center()
def add_axes(self):
def add_axes(self) -> None:
x_axis = Line(self.tick_width * LEFT / 2, self.width * RIGHT)
y_axis = Line(MED_LARGE_BUFF * DOWN, self.height * UP)
y_ticks = VGroup()
@ -209,7 +242,7 @@ class BarChart(VGroup):
self.y_axis_labels = labels
self.add(labels)
def add_bars(self, values):
def add_bars(self, values: Iterable[float]) -> None:
buff = float(self.width) / (2 * len(values))
bars = VGroup()
for i, value in enumerate(values):
@ -234,7 +267,7 @@ class BarChart(VGroup):
self.bars = bars
self.bar_labels = bar_labels
def change_bar_values(self, values):
def change_bar_values(self, values: Iterable[float]) -> None:
for bar, value in zip(self.bars, values):
bar_bottom = bar.get_bottom()
bar.stretch_to_fit_height(

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from manimlib.constants import *
from manimlib.mobject.geometry import Line
from manimlib.mobject.geometry import Rectangle
@ -7,6 +9,13 @@ from manimlib.utils.color import Color
from manimlib.utils.customization import get_customization
from manimlib.utils.config_ops import digest_config
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Union, Sequence
from manimlib.mobject.mobject import Mobject
ManimColor = Union[str, Color, Sequence[float]]
class SurroundingRectangle(Rectangle):
CONFIG = {
@ -14,7 +23,7 @@ class SurroundingRectangle(Rectangle):
"buff": SMALL_BUFF,
}
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
digest_config(self, kwargs)
kwargs["width"] = mobject.get_width() + 2 * self.buff
kwargs["height"] = mobject.get_height() + 2 * self.buff
@ -30,23 +39,24 @@ class BackgroundRectangle(SurroundingRectangle):
"buff": 0
}
def __init__(self, mobject, color=None, **kwargs):
def __init__(self, mobject: Mobject, color: ManimColor = None, **kwargs):
if color is None:
color = get_customization()['style']['background_color']
SurroundingRectangle.__init__(self, mobject, color=color, **kwargs)
self.original_fill_opacity = self.fill_opacity
def pointwise_become_partial(self, mobject, a, b):
def pointwise_become_partial(self, mobject: Mobject, a: float, b: float):
self.set_fill(opacity=b * self.original_fill_opacity)
return self
def set_style_data(self,
stroke_color=None,
stroke_width=None,
fill_color=None,
fill_opacity=None,
family=True
):
def set_style_data(
self,
stroke_color: ManimColor | None = None,
stroke_width: float | None = None,
fill_color: ManimColor | None = None,
fill_opacity: float | None = None,
family: bool = True
):
# Unchangeable style, except for fill_opacity
VMobject.set_style_data(
self,
@ -57,7 +67,7 @@ class BackgroundRectangle(SurroundingRectangle):
)
return self
def get_fill_color(self):
def get_fill_color(self) -> Color:
return Color(self.color)
@ -67,7 +77,7 @@ class Cross(VGroup):
"stroke_width": [0, 6, 0],
}
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
super().__init__(
Line(UL, DR),
Line(UR, DL),
@ -82,7 +92,7 @@ class Underline(Line):
"buff": SMALL_BUFF,
}
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
super().__init__(LEFT, RIGHT, **kwargs)
self.match_width(mobject)
self.next_to(mobject, DOWN, buff=self.buff)

View file

@ -1,11 +1,15 @@
import numpy as np
from __future__ import annotations
import math
import copy
from typing import Iterable
import numpy as np
from manimlib.animation.composition import AnimationGroup
from manimlib.constants import *
from manimlib.animation.fading import FadeIn
from manimlib.animation.growing import GrowFromCenter
from manimlib.animation.composition import AnimationGroup
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.svg.tex_mobject import SingleStringTex
from manimlib.mobject.svg.tex_mobject import TexText
@ -14,6 +18,11 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config
from manimlib.utils.space_ops import get_norm
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.mobject import Mobject
from manimlib.animation.animation import Animation
class Brace(SingleStringTex):
CONFIG = {
@ -21,7 +30,12 @@ class Brace(SingleStringTex):
"tex_string": r"\underbrace{\qquad}"
}
def __init__(self, mobject, direction=DOWN, **kwargs):
def __init__(
self,
mobject: Mobject,
direction: np.ndarray = DOWN,
**kwargs
):
digest_config(self, kwargs, locals())
angle = -math.atan2(*direction[:2]) + PI
mobject.rotate(-angle, about_point=ORIGIN)
@ -36,7 +50,7 @@ class Brace(SingleStringTex):
for mob in mobject, self:
mob.rotate(angle, about_point=ORIGIN)
def set_initial_width(self, width):
def set_initial_width(self, width: float):
width_diff = width - self.get_width()
if width_diff > 0:
for tip, rect, vect in [(self[0], self[1], RIGHT), (self[5], self[4], LEFT)]:
@ -49,7 +63,12 @@ class Brace(SingleStringTex):
self.set_width(width, stretch=True)
return self
def put_at_tip(self, mob, use_next_to=True, **kwargs):
def put_at_tip(
self,
mob: Mobject,
use_next_to: bool = True,
**kwargs
):
if use_next_to:
mob.next_to(
self.get_tip(),
@ -63,24 +82,24 @@ class Brace(SingleStringTex):
mob.shift(self.get_direction() * shift_distance)
return self
def get_text(self, text, **kwargs):
def get_text(self, text: str, **kwargs) -> Text:
buff = kwargs.pop("buff", SMALL_BUFF)
text_mob = Text(text, **kwargs)
self.put_at_tip(text_mob, buff=buff)
return text_mob
def get_tex(self, *tex, **kwargs):
def get_tex(self, *tex: str, **kwargs) -> Tex:
tex_mob = Tex(*tex)
self.put_at_tip(tex_mob, **kwargs)
return tex_mob
def get_tip(self):
def get_tip(self) -> np.ndarray:
# Very specific to the LaTeX representation
# of a brace, but it's the only way I can think
# of to get the tip regardless of orientation.
return self.get_all_points()[self.tip_point_index]
def get_direction(self):
def get_direction(self) -> np.ndarray:
vect = self.get_tip() - self.get_center()
return vect / get_norm(vect)
@ -92,14 +111,20 @@ class BraceLabel(VMobject):
"label_buff": DEFAULT_MOBJECT_TO_MOBJECT_BUFFER
}
def __init__(self, obj, text, brace_direction=DOWN, **kwargs):
def __init__(
self,
obj: list[VMobject] | Mobject,
text: Iterable[str] | str,
brace_direction: np.ndarray = DOWN,
**kwargs
) -> None:
VMobject.__init__(self, **kwargs)
self.brace_direction = brace_direction
if isinstance(obj, list):
obj = VMobject(*obj)
self.brace = Brace(obj, brace_direction, **kwargs)
if isinstance(text, tuple) or isinstance(text, list):
if isinstance(text, Iterable):
self.label = self.label_constructor(*text, **kwargs)
else:
self.label = self.label_constructor(str(text))
@ -109,10 +134,14 @@ class BraceLabel(VMobject):
self.brace.put_at_tip(self.label, buff=self.label_buff)
self.set_submobjects([self.brace, self.label])
def creation_anim(self, label_anim=FadeIn, brace_anim=GrowFromCenter):
def creation_anim(
self,
label_anim: Animation = FadeIn,
brace_anim: Animation=GrowFromCenter
) -> AnimationGroup:
return AnimationGroup(brace_anim(self.brace), label_anim(self.label))
def shift_brace(self, obj, **kwargs):
def shift_brace(self, obj: list[VMobject] | Mobject, **kwargs):
if isinstance(obj, list):
obj = VMobject(*obj)
self.brace = Brace(obj, self.brace_direction, **kwargs)
@ -120,7 +149,7 @@ class BraceLabel(VMobject):
self.submobjects[0] = self.brace
return self
def change_label(self, *text, **kwargs):
def change_label(self, *text: str, **kwargs):
self.label = self.label_constructor(*text, **kwargs)
if self.label_scale != 1:
self.label.scale(self.label_scale)
@ -129,7 +158,7 @@ class BraceLabel(VMobject):
self.submobjects[1] = self.label
return self
def change_brace_label(self, obj, *text):
def change_brace_label(self, obj: list[VMobject] | Mobject, *text: str):
self.shift_brace(obj)
self.change_label(*text)
return self

View file

@ -1,6 +1,10 @@
import itertools as it
from __future__ import annotations
import re
import colour
import itertools as it
from types import MethodType
from typing import Iterable, Union, Sequence
from manimlib.constants import WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject
@ -14,26 +18,30 @@ from manimlib.utils.tex_file_writing import get_tex_config
from manimlib.utils.tex_file_writing import display_during_execution
from manimlib.logger import log
ManimColor = Union[str, colour.Color, Sequence[float]]
SCALE_FACTOR_PER_FONT_POINT = 0.001
def _get_neighbouring_pairs(iterable):
def _get_neighbouring_pairs(iterable: Iterable) -> list:
return list(adjacent_pairs(iterable))[:-1]
class _TexParser(object):
def __init__(self, tex_string, additional_substrings):
def __init__(self, tex_string: str, additional_substrings: str):
self.tex_string = tex_string
self.whitespace_indices = self.get_whitespace_indices()
self.backslash_indices = self.get_backslash_indices()
self.script_indices = self.get_script_indices()
self.brace_indices_dict = self.get_brace_indices_dict()
self.tex_span_list = []
self.script_span_to_char_dict = {}
self.script_span_to_tex_span_dict = {}
self.neighbouring_script_span_pairs = []
self.specified_substrings = []
self.tex_span_list: list[tuple[int, int]] = []
self.script_span_to_char_dict: dict[tuple[int, int], str] = {}
self.script_span_to_tex_span_dict: dict[
tuple[int, int], tuple[int, int]
] = {}
self.neighbouring_script_span_pairs: list[tuple[int, int]] = []
self.specified_substrings: list[str] = []
self.add_tex_span((0, len(tex_string)))
self.break_up_by_scripts()
self.break_up_by_double_braces()
@ -44,17 +52,17 @@ class _TexParser(object):
)
self.containing_labels_dict = self.get_containing_labels_dict()
def add_tex_span(self, tex_span):
def add_tex_span(self, tex_span: tuple[int, int]) -> None:
if tex_span not in self.tex_span_list:
self.tex_span_list.append(tex_span)
def get_whitespace_indices(self):
def get_whitespace_indices(self) -> list[int]:
return [
match_obj.start()
for match_obj in re.finditer(r"\s", self.tex_string)
]
def get_backslash_indices(self):
def get_backslash_indices(self) -> list[int]:
# Newlines (`\\`) don't count.
return [
match_obj.end() - 1
@ -62,19 +70,19 @@ class _TexParser(object):
if len(match_obj.group()) % 2 == 1
]
def filter_out_escaped_characters(self, indices):
def filter_out_escaped_characters(self, indices) -> list[int]:
return list(filter(
lambda index: index - 1 not in self.backslash_indices,
indices
))
def get_script_indices(self):
def get_script_indices(self) -> list[int]:
return self.filter_out_escaped_characters([
match_obj.start()
for match_obj in re.finditer(r"[_^]", self.tex_string)
])
def get_brace_indices_dict(self):
def get_brace_indices_dict(self) -> dict[int, int]:
tex_string = self.tex_string
indices = self.filter_out_escaped_characters([
match_obj.start()
@ -90,7 +98,7 @@ class _TexParser(object):
result[left_brace_index] = index
return result
def break_up_by_scripts(self):
def break_up_by_scripts(self) -> None:
# Match subscripts & superscripts.
tex_string = self.tex_string
whitespace_indices = self.whitespace_indices
@ -139,7 +147,7 @@ class _TexParser(object):
if span_0[1] == span_1[0]:
self.neighbouring_script_span_pairs.append((span_0, span_1))
def break_up_by_double_braces(self):
def break_up_by_double_braces(self) -> None:
# Match paired double braces (`{{...}}`).
tex_string = self.tex_string
reversed_indices_dict = dict(
@ -163,7 +171,10 @@ class _TexParser(object):
self.specified_substrings.append(tex_string[slice(*tex_span)])
skip = True
def break_up_by_additional_substrings(self, additional_substrings):
def break_up_by_additional_substrings(
self,
additional_substrings: Iterable[str]
) -> None:
stripped_substrings = sorted(remove_list_redundancies([
string.strip()
for string in additional_substrings
@ -193,7 +204,7 @@ class _TexParser(object):
continue
self.add_tex_span((span_begin, span_end))
def get_containing_labels_dict(self):
def get_containing_labels_dict(self) -> dict[tuple[int, int], list[int]]:
tex_span_list = self.tex_span_list
result = {
tex_span: []
@ -218,7 +229,7 @@ class _TexParser(object):
raise ValueError
return result
def get_labelled_tex_string(self):
def get_labelled_tex_string(self) -> str:
indices, _, flags, labels = zip(*sorted([
(*tex_span[::(1, -1)[flag]], flag, label)
for label, tex_span in enumerate(self.tex_span_list)
@ -236,7 +247,7 @@ class _TexParser(object):
return "".join(it.chain(*zip(command_pieces, string_pieces)))
@staticmethod
def get_color_command(label):
def get_color_command(label: int) -> str:
rg, b = divmod(label, 256)
r, g = divmod(rg, 256)
return "".join([
@ -246,7 +257,7 @@ class _TexParser(object):
"}"
])
def get_sorted_submob_indices(self, submob_labels):
def get_sorted_submob_indices(self, submob_labels: Iterable[int]) -> list[int]:
def script_span_to_submob_range(script_span):
tex_span = self.script_span_to_tex_span_dict[script_span]
submob_indices = [
@ -280,7 +291,7 @@ class _TexParser(object):
]
return result
def get_submob_tex_strings(self, submob_labels):
def get_submob_tex_strings(self, submob_labels: Iterable[int]) -> list[str]:
ordered_tex_spans = [
self.tex_span_list[label] for label in submob_labels
]
@ -341,7 +352,10 @@ class _TexParser(object):
]))
return result
def find_span_components_of_custom_span(self, custom_span):
def find_span_components_of_custom_span(
self,
custom_span: tuple[int, int]
) -> list[tuple[int, int]] | None:
skipped_indices = sorted(it.chain(
self.whitespace_indices,
self.script_indices
@ -369,16 +383,19 @@ class _TexParser(object):
span_begin = next_begin
return result
def get_containing_labels_by_tex_spans(self, tex_spans):
def get_containing_labels_by_tex_spans(
self,
tex_spans: Iterable[tuple[int, int]]
) -> list[int]:
return remove_list_redundancies(list(it.chain(*[
self.containing_labels_dict[tex_span]
for tex_span in tex_spans
])))
def get_specified_substrings(self):
def get_specified_substrings(self) -> list[str]:
return self.specified_substrings
def get_isolated_substrings(self):
def get_isolated_substrings(self) -> list[str]:
return remove_list_redundancies([
self.tex_string[slice(*tex_span)]
for tex_span in self.tex_span_list
@ -408,7 +425,7 @@ class MTex(_TexSVG):
"use_plain_tex": False,
}
def __init__(self, tex_string, **kwargs):
def __init__(self, tex_string: str, **kwargs):
digest_config(self, kwargs)
tex_string = tex_string.strip()
# Prevent from passing an empty string.
@ -425,7 +442,9 @@ class MTex(_TexSVG):
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
@property
def hash_seed(self):
def hash_seed(
self
) -> tuple[str, dict[str], dict[str, bool], str, list[str], str, str, bool]:
return (
self.__class__.__name__,
self.svg_default,
@ -437,10 +456,10 @@ class MTex(_TexSVG):
self.use_plain_tex
)
def get_file_path(self):
def get_file_path(self) -> str:
return self._get_file_path(self.use_plain_tex)
def _get_file_path(self, use_plain_tex):
def _get_file_path(self, use_plain_tex: bool) -> str:
if use_plain_tex:
tex_string = self.tex_string
else:
@ -451,7 +470,7 @@ class MTex(_TexSVG):
file_path = self.tex_to_svg_file_path(full_tex)
return file_path
def get_tex_file_body(self, tex_string):
def get_tex_file_body(self, tex_string: str) -> str:
if self.tex_environment:
tex_string = "\n".join([
f"\\begin{{{self.tex_environment}}}",
@ -468,10 +487,10 @@ class MTex(_TexSVG):
)
@staticmethod
def tex_to_svg_file_path(tex_file_content):
def tex_to_svg_file_path(tex_file_content: str) -> str:
return tex_to_svg_file(tex_file_content)
def generate_mobject(self):
def generate_mobject(self) -> None:
super().generate_mobject()
if not self.use_plain_tex:
@ -488,12 +507,16 @@ class MTex(_TexSVG):
self.set_submobjects(mob.submobjects)
@staticmethod
def color_to_label(color):
def color_to_label(color: ManimColor) -> int:
r, g, b = color_to_int_rgb(color)
rg = r * 256 + g
return rg * 256 + b
def build_mobject(self, svg_glyphs, glyph_labels):
def build_mobject(
self,
svg_glyphs: _TexSVG | None,
glyph_labels: Iterable[int]
) -> VGroup:
if not svg_glyphs:
return VGroup()
@ -531,14 +554,17 @@ class MTex(_TexSVG):
submob.get_tex = MethodType(lambda inst: inst.tex_string, submob)
return VGroup(*rearranged_submobjects)
def get_part_by_tex_spans(self, tex_spans):
def get_part_by_tex_spans(
self,
tex_spans: Iterable[tuple[int, int]]
) -> VGroup:
labels = self.parser.get_containing_labels_by_tex_spans(tex_spans)
return VGroup(*filter(
lambda submob: submob.submob_label in labels,
self.submobjects
))
def get_part_by_custom_span(self, custom_span):
def get_part_by_custom_span(self, custom_span: tuple[int, int]) -> VGroup:
tex_spans = self.parser.find_span_components_of_custom_span(
custom_span
)
@ -547,7 +573,7 @@ class MTex(_TexSVG):
raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"")
return self.get_part_by_tex_spans(tex_spans)
def get_parts_by_tex(self, tex):
def get_parts_by_tex(self, tex: str) -> VGroup:
return VGroup(*[
self.get_part_by_custom_span(match_obj.span())
for match_obj in re.finditer(
@ -555,20 +581,23 @@ class MTex(_TexSVG):
)
])
def get_part_by_tex(self, tex, index=0):
def get_part_by_tex(self, tex: str, index: int = 0) -> VGroup:
all_parts = self.get_parts_by_tex(tex)
return all_parts[index]
def set_color_by_tex(self, tex, color):
def set_color_by_tex(self, tex: str, color: ManimColor):
self.get_parts_by_tex(tex).set_color(color)
return self
def set_color_by_tex_to_color_map(self, tex_to_color_map):
def set_color_by_tex_to_color_map(
self,
tex_to_color_map: dict[str, ManimColor]
):
for tex, color in tex_to_color_map.items():
self.set_color_by_tex(tex, color)
return self
def indices_of_part(self, part):
def indices_of_part(self, part: Iterable[VGroup]) -> list[int]:
indices = [
index for index, submob in enumerate(self.submobjects)
if submob in part
@ -577,23 +606,23 @@ class MTex(_TexSVG):
raise ValueError("Failed to find part in tex")
return indices
def indices_of_part_by_tex(self, tex, index=0):
def indices_of_part_by_tex(self, tex: str, index: int = 0) -> list[int]:
part = self.get_part_by_tex(tex, index=index)
return self.indices_of_part(part)
def get_tex(self):
def get_tex(self) -> str:
return self.tex_string
def get_submob_tex(self):
def get_submob_tex(self) -> list[str]:
return [
submob.get_tex()
for submob in self.submobjects
]
def get_specified_substrings(self):
def get_specified_substrings(self) -> list[str]:
return self.parser.get_specified_substrings()
def get_isolated_substrings(self):
def get_isolated_substrings(self) -> list[str]:
return self.parser.get_isolated_substrings()

View file

@ -1,6 +1,9 @@
from __future__ import annotations
import os
import hashlib
import itertools as it
from typing import Callable
from xml.etree import ElementTree as ET
import svgelements as se
@ -21,10 +24,10 @@ from manimlib.utils.iterables import hash_obj
from manimlib.logger import log
SVG_HASH_TO_MOB_MAP = {}
SVG_HASH_TO_MOB_MAP: dict[int, VMobject] = {}
def _convert_point_to_3d(x, y):
def _convert_point_to_3d(x: float, y: float) -> np.ndarray:
return np.array([x, y, 0.0])
@ -55,15 +58,14 @@ class SVGMobject(VMobject):
},
"path_string_config": {},
}
def __init__(self, file_name=None, **kwargs):
def __init__(self, file_name: str | None = None, **kwargs):
super().__init__(**kwargs)
self.file_name = file_name or self.file_name
self.init_svg_mobject()
self.init_colors()
self.move_into_position()
def init_svg_mobject(self):
def init_svg_mobject(self) -> None:
hash_val = hash_obj(self.hash_seed)
if hash_val in SVG_HASH_TO_MOB_MAP:
mob = SVG_HASH_TO_MOB_MAP[hash_val].copy()
@ -74,7 +76,7 @@ class SVGMobject(VMobject):
SVG_HASH_TO_MOB_MAP[hash_val] = self.copy()
@property
def hash_seed(self):
def hash_seed(self) -> tuple[str, dict[str], dict[str, bool], str]:
# Returns data which can uniquely represent the result of `init_points`.
# The hashed value of it is stored as a key in `SVG_HASH_TO_MOB_MAP`.
return (
@ -84,7 +86,7 @@ class SVGMobject(VMobject):
self.file_name
)
def generate_mobject(self):
def generate_mobject(self) -> None:
file_path = self.get_file_path()
element_tree = ET.parse(file_path)
new_tree = self.modify_xml_tree(element_tree)
@ -100,12 +102,12 @@ class SVGMobject(VMobject):
self.add(*mobjects)
self.flip(RIGHT) # Flip y
def get_file_path(self):
def get_file_path(self) -> str:
if self.file_name is None:
raise Exception("Must specify file for SVGMobject")
return get_full_vector_image_path(self.file_name)
def modify_xml_tree(self, element_tree):
def modify_xml_tree(self, element_tree: ET.ElementTree) -> ET.ElementTree:
config_style_dict = self.generate_config_style_dict()
style_keys = (
"fill",
@ -127,7 +129,7 @@ class SVGMobject(VMobject):
root_style_node.extend(root)
return ET.ElementTree(new_root)
def generate_config_style_dict(self):
def generate_config_style_dict(self) -> dict[str, str]:
keys_converting_dict = {
"fill": ("color", "fill_color"),
"fill-opacity": ("opacity", "fill_opacity"),
@ -144,7 +146,7 @@ class SVGMobject(VMobject):
result[svg_key] = str(svg_default_dict[style_key])
return result
def get_mobjects_from(self, svg):
def get_mobjects_from(self, svg: se.SVG) -> list[VMobject]:
result = []
for shape in svg.elements():
if isinstance(shape, se.Group):
@ -177,7 +179,45 @@ class SVGMobject(VMobject):
return result
@staticmethod
def apply_style_to_mobject(mob, shape):
def handle_transform(mob: VMobject, matrix: se.Matrix) -> VMobject:
mat = np.array([
[matrix.a, matrix.c],
[matrix.b, matrix.d]
])
vec = np.array([matrix.e, matrix.f, 0.0])
mob.apply_matrix(mat)
mob.shift(vec)
return mob
def get_mobject_from(self, shape: se.GraphicObject) -> VMobject | None:
shape_class_to_func_map: dict[
type, Callable[[se.GraphicObject], VMobject]
] = {
se.Path: self.path_to_mobject,
se.SimpleLine: self.line_to_mobject,
se.Rect: self.rect_to_mobject,
se.Circle: self.circle_to_mobject,
se.Ellipse: self.ellipse_to_mobject,
se.Polygon: self.polygon_to_mobject,
se.Polyline: self.polyline_to_mobject,
# se.Text: self.text_to_mobject, # TODO
}
for shape_class, func in shape_class_to_func_map.items():
if isinstance(shape, shape_class):
mob = func(shape)
self.apply_style_to_mobject(mob, shape)
return mob
shape_class_name = shape.__class__.__name__
if shape_class_name != "SVGElement":
log.warning(f"Unsupported element type: {shape_class_name}")
return None
@staticmethod
def apply_style_to_mobject(
mob: VMobject,
shape: se.GraphicObject
) -> VMobject:
mob.set_style(
stroke_width=shape.stroke_width,
stroke_color=shape.stroke.hex,
@ -198,16 +238,16 @@ class SVGMobject(VMobject):
mob.shift(vec)
return mob
def path_to_mobject(self, path):
def path_to_mobject(self, path: se.Path) -> VMobjectFromSVGPath:
return VMobjectFromSVGPath(path, **self.path_string_config)
def line_to_mobject(self, line):
def line_to_mobject(self, line: se.Line) -> Line:
return Line(
start=_convert_point_to_3d(line.x1, line.y1),
end=_convert_point_to_3d(line.x2, line.y2)
)
def rect_to_mobject(self, rect):
def rect_to_mobject(self, rect: se.Rect) -> Rectangle:
if rect.rx == 0 or rect.ry == 0:
mob = Rectangle(
width=rect.width,
@ -226,7 +266,7 @@ class SVGMobject(VMobject):
))
return mob
def circle_to_mobject(self, circle):
def circle_to_mobject(self, circle: se.Circle) -> Circle:
# svgelements supports `rx` & `ry` but `r`
mob = Circle(radius=circle.rx)
mob.shift(_convert_point_to_3d(
@ -234,7 +274,7 @@ class SVGMobject(VMobject):
))
return mob
def ellipse_to_mobject(self, ellipse):
def ellipse_to_mobject(self, ellipse: se.Ellipse) -> Circle:
mob = Circle(radius=ellipse.rx)
mob.stretch_to_fit_height(2 * ellipse.ry)
mob.shift(_convert_point_to_3d(
@ -242,24 +282,24 @@ class SVGMobject(VMobject):
))
return mob
def polygon_to_mobject(self, polygon):
def polygon_to_mobject(self, polygon: se.Polygon) -> Polygon:
points = [
_convert_point_to_3d(*point)
for point in polygon
]
return Polygon(*points)
def polyline_to_mobject(self, polyline):
def polyline_to_mobject(self, polyline: se.Polyline) -> Polyline:
points = [
_convert_point_to_3d(*point)
for point in polyline
]
return Polyline(*points)
def text_to_mobject(self, text):
def text_to_mobject(self, text: se.Text):
pass
def move_into_position(self):
def move_into_position(self) -> None:
if self.should_center:
self.center()
if self.height is not None:
@ -275,13 +315,13 @@ class VMobjectFromSVGPath(VMobject):
"should_remove_null_curves": False,
}
def __init__(self, path_obj, **kwargs):
def __init__(self, path_obj: se.Path, **kwargs):
# Get rid of arcs
path_obj.approximate_arcs_with_quads()
self.path_obj = path_obj
super().__init__(**kwargs)
def init_points(self):
def init_points(self) -> None:
# After a given svg_path has been converted into points, the result
# will be saved to a file so that future calls for the same path
# don't need to retrace the same computation.
@ -307,7 +347,7 @@ class VMobjectFromSVGPath(VMobject):
np.save(points_filepath, self.get_points())
np.save(tris_filepath, self.get_triangulation())
def handle_commands(self):
def handle_commands(self) -> None:
segment_class_to_func_map = {
se.Move: (self.start_new_path, ("end",)),
se.Close: (self.close_path, ()),

View file

@ -1,5 +1,9 @@
from __future__ import annotations
from typing import Iterable, Sequence, Union
from functools import reduce
import operator as op
import colour
import re
from manimlib.constants import *
@ -11,6 +15,8 @@ from manimlib.utils.tex_file_writing import tex_to_svg_file
from manimlib.utils.tex_file_writing import get_tex_config
from manimlib.utils.tex_file_writing import display_during_execution
ManimColor = Union[str, colour.Color, Sequence[float]]
SCALE_FACTOR_PER_FONT_POINT = 0.001
@ -33,7 +39,7 @@ class SingleStringTex(SVGMobject):
"organize_left_to_right": False,
}
def __init__(self, tex_string, **kwargs):
def __init__(self, tex_string: str, **kwargs):
assert isinstance(tex_string, str)
self.tex_string = tex_string
super().__init__(**kwargs)
@ -44,7 +50,7 @@ class SingleStringTex(SVGMobject):
self.organize_submobjects_left_to_right()
@property
def hash_seed(self):
def hash_seed(self) -> tuple[str, dict[str], dict[str, bool], str, str, bool]:
return (
self.__class__.__name__,
self.svg_default,
@ -54,13 +60,13 @@ class SingleStringTex(SVGMobject):
self.math_mode
)
def get_file_path(self):
def get_file_path(self) -> str:
full_tex = self.get_tex_file_body(self.tex_string)
with display_during_execution(f"Writing \"{self.tex_string}\""):
file_path = tex_to_svg_file(full_tex)
return file_path
def get_tex_file_body(self, tex_string):
def get_tex_file_body(self, tex_string: str) -> str:
new_tex = self.get_modified_expression(tex_string)
if self.math_mode:
new_tex = "\\begin{align*}\n" + new_tex + "\n\\end{align*}"
@ -73,10 +79,10 @@ class SingleStringTex(SVGMobject):
new_tex
)
def get_modified_expression(self, tex_string):
def get_modified_expression(self, tex_string: str) -> str:
return self.modify_special_strings(tex_string.strip())
def modify_special_strings(self, tex):
def modify_special_strings(self, tex: str) -> str:
tex = tex.strip()
should_add_filler = reduce(op.or_, [
# Fraction line needs something to be over
@ -128,7 +134,7 @@ class SingleStringTex(SVGMobject):
tex = ""
return tex
def balance_braces(self, tex):
def balance_braces(self, tex: str) -> str:
"""
Makes Tex resiliant to unmatched braces
"""
@ -148,7 +154,7 @@ class SingleStringTex(SVGMobject):
tex += num_unclosed_brackets * "}"
return tex
def get_tex(self):
def get_tex(self) -> str:
return self.tex_string
def organize_submobjects_left_to_right(self):
@ -163,7 +169,7 @@ class Tex(SingleStringTex):
"tex_to_color_map": {},
}
def __init__(self, *tex_strings, **kwargs):
def __init__(self, *tex_strings: str, **kwargs):
digest_config(self, kwargs)
self.tex_strings = self.break_up_tex_strings(tex_strings)
full_string = self.arg_separator.join(self.tex_strings)
@ -174,7 +180,7 @@ class Tex(SingleStringTex):
if self.organize_left_to_right:
self.organize_submobjects_left_to_right()
def break_up_tex_strings(self, tex_strings):
def break_up_tex_strings(self, tex_strings: Iterable[str]) -> Iterable[str]:
# Separate out any strings specified in the isolate
# or tex_to_color_map lists.
substrings_to_isolate = [*self.isolate, *self.tex_to_color_map.keys()]
@ -222,7 +228,12 @@ class Tex(SingleStringTex):
self.set_submobjects(new_submobjects)
return self
def get_parts_by_tex(self, tex, substring=True, case_sensitive=True):
def get_parts_by_tex(
self,
tex: str,
substring: bool = True,
case_sensitive: bool = True
) -> VGroup:
def test(tex1, tex2):
if not case_sensitive:
tex1 = tex1.lower()
@ -237,27 +248,36 @@ class Tex(SingleStringTex):
self.submobjects
))
def get_part_by_tex(self, tex, **kwargs):
def get_part_by_tex(self, tex: str, **kwargs) -> SingleStringTex | None:
all_parts = self.get_parts_by_tex(tex, **kwargs)
return all_parts[0] if all_parts else None
def set_color_by_tex(self, tex, color, **kwargs):
def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs):
self.get_parts_by_tex(tex, **kwargs).set_color(color)
return self
def set_color_by_tex_to_color_map(self, tex_to_color_map, **kwargs):
def set_color_by_tex_to_color_map(
self,
tex_to_color_map: dict[str, ManimColor],
**kwargs
):
for tex, color in list(tex_to_color_map.items()):
self.set_color_by_tex(tex, color, **kwargs)
return self
def index_of_part(self, part, start=0):
def index_of_part(self, part: SingleStringTex, start: int = 0) -> int:
return self.submobjects.index(part, start)
def index_of_part_by_tex(self, tex, start=0, **kwargs):
def index_of_part_by_tex(self, tex: str, start: int = 0, **kwargs) -> int:
part = self.get_part_by_tex(tex, **kwargs)
return self.index_of_part(part, start)
def slice_by_tex(self, start_tex=None, stop_tex=None, **kwargs):
def slice_by_tex(
self,
start_tex: str | None = None,
stop_tex: str | None = None,
**kwargs
) -> VGroup:
if start_tex is None:
start_index = 0
else:
@ -269,10 +289,10 @@ class Tex(SingleStringTex):
stop_index = self.index_of_part_by_tex(stop_tex, start=start_index, **kwargs)
return self[start_index:stop_index]
def sort_alphabetically(self):
def sort_alphabetically(self) -> None:
self.submobjects.sort(key=lambda m: m.get_tex())
def set_bstroke(self, color=BLACK, width=4):
def set_bstroke(self, color: ManimColor = BLACK, width: float = 4):
self.set_stroke(color, width, background=True)
return self
@ -291,7 +311,7 @@ class BulletedList(TexText):
"alignment": "",
}
def __init__(self, *items, **kwargs):
def __init__(self, *items: str, **kwargs):
line_separated_items = [s + "\\\\" for s in items]
TexText.__init__(self, *line_separated_items, **kwargs)
for part in self:
@ -304,7 +324,7 @@ class BulletedList(TexText):
buff=self.buff
)
def fade_all_but(self, index_or_string, opacity=0.5):
def fade_all_but(self, index_or_string: int | str, opacity: float = 0.5) -> None:
arg = index_or_string
if isinstance(arg, str):
part = self.get_part_by_tex(arg)
@ -342,7 +362,7 @@ class Title(TexText):
"underline_buff": MED_SMALL_BUFF,
}
def __init__(self, *text_parts, **kwargs):
def __init__(self, *text_parts: str, **kwargs):
TexText.__init__(self, *text_parts, **kwargs)
self.scale(self.scale_factor)
self.to_edge(UP)

View file

@ -1,14 +1,21 @@
from __future__ import annotations
import os
import re
import typing
import xml.sax.saxutils as saxutils
from contextlib import contextmanager
import io
import hashlib
import functools
from pathlib import Path
import xml.etree.ElementTree as ET
from contextlib import contextmanager
from typing import Iterable, Sequence, Union
import pygments
import pygments.formatters
import pygments.lexers
import manimpango
import pygments.styles
from manimpango import MarkupUtils
from manimlib.logger import log
@ -22,6 +29,13 @@ from manimlib.utils.directories import get_downloads_dir
from manimlib.utils.directories import get_text_dir
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import colour
from manimlib.mobject.types.vectorized_mobject import VMobject
ManimColor = Union[str, colour.Color, Sequence[float]]
TEXT_MOB_SCALE_FACTOR = 0.0076
DEFAULT_LINE_SPACING_SCALE = 0.6
@ -433,6 +447,7 @@ class MarkupText(Text):
}
class Code(MarkupText):
CONFIG = {
"font": "Consolas",
@ -454,7 +469,7 @@ class Code(MarkupText):
@contextmanager
def register_font(font_file: typing.Union[str, Path]):
def register_font(font_file: str | Path):
"""Temporarily add a font file to Pango's search path.
This searches for the font_file at various places. The order it searches it described below.
1. Absolute path.

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import math
from manimlib.constants import *
@ -23,13 +25,13 @@ class SurfaceMesh(VGroup):
"flat_stroke": False,
}
def __init__(self, uv_surface, **kwargs):
def __init__(self, uv_surface: Surface, **kwargs):
if not isinstance(uv_surface, Surface):
raise Exception("uv_surface must be of type Surface")
self.uv_surface = uv_surface
super().__init__(**kwargs)
def init_points(self):
def init_points(self) -> None:
uv_surface = self.uv_surface
full_nu, full_nv = uv_surface.resolution
@ -75,7 +77,7 @@ class Sphere(Surface):
"v_range": (0, PI),
}
def uv_func(self, u, v):
def uv_func(self, u: float, v: float) -> np.ndarray:
return self.radius * np.array([
np.cos(u) * np.sin(v),
np.sin(u) * np.sin(v),
@ -91,7 +93,7 @@ class Torus(Surface):
"r2": 1,
}
def uv_func(self, u, v):
def uv_func(self, u: float, v: float) -> np.ndarray:
P = np.array([math.cos(u), math.sin(u), 0])
return (self.r1 - self.r2 * math.cos(v)) * P - math.sin(v) * OUT
@ -113,8 +115,8 @@ class Cylinder(Surface):
self.apply_matrix(z_to_vector(self.axis))
return self
def uv_func(self, u, v):
return [np.cos(u), np.sin(u), v]
def uv_func(self, u: float, v: float) -> np.ndarray:
return np.array([np.cos(u), np.sin(u), v])
class Line3D(Cylinder):
@ -123,7 +125,7 @@ class Line3D(Cylinder):
"resolution": (21, 25)
}
def __init__(self, start, end, **kwargs):
def __init__(self, start: np.ndarray, end: np.ndarray, **kwargs):
digest_config(self, kwargs)
axis = end - start
super().__init__(
@ -142,16 +144,16 @@ class Disk3D(Surface):
"resolution": (2, 25),
}
def init_points(self):
def init_points(self) -> None:
super().init_points()
self.scale(self.radius)
def uv_func(self, u, v):
return [
def uv_func(self, u: float, v: float) -> np.ndarray:
return np.array([
u * np.cos(v),
u * np.sin(v),
0
]
])
class Square3D(Surface):
@ -162,12 +164,12 @@ class Square3D(Surface):
"resolution": (2, 2),
}
def init_points(self):
def init_points(self) -> None:
super().init_points()
self.scale(self.side_length / 2)
def uv_func(self, u, v):
return [u, v, 0]
def uv_func(self, u: float, v: float) -> np.ndarray:
return np.array([u, v, 0])
class Cube(SGroup):
@ -180,7 +182,7 @@ class Cube(SGroup):
"square_class": Square3D,
}
def init_points(self):
def init_points(self) -> None:
face = Square3D(
resolution=self.square_resolution,
side_length=self.side_length,
@ -188,7 +190,7 @@ class Cube(SGroup):
self.add(*self.square_to_cube_faces(face))
@staticmethod
def square_to_cube_faces(square):
def square_to_cube_faces(square: Square3D) -> list[Square3D]:
radius = square.get_height() / 2
square.move_to(radius * OUT)
result = [square]
@ -199,7 +201,7 @@ class Cube(SGroup):
result.append(square.copy().rotate(PI, RIGHT, about_point=ORIGIN))
return result
def _get_face(self):
def _get_face(self) -> Square3D:
return Square3D(resolution=self.square_resolution)
@ -212,7 +214,7 @@ class VCube(VGroup):
"shadow": 0.5,
}
def __init__(self, side_length=2, **kwargs):
def __init__(self, side_length: int = 2, **kwargs):
super().__init__(**kwargs)
face = Square(side_length=side_length)
face.get_triangulation()
@ -233,7 +235,7 @@ class Dodecahedron(VGroup):
"depth_test": True,
}
def init_points(self):
def init_points(self) -> None:
# Star by creating two of the pentagons, meeting
# back to back on the positive x-axis
phi = (1 + math.sqrt(5)) / 2
@ -274,7 +276,7 @@ class Prism(Cube):
"dimensions": [3, 2, 1]
}
def init_points(self):
def init_points(self) -> None:
Cube.init_points(self)
for dim, value in enumerate(self.dimensions):
self.rescale_to_fit(value, dim, stretch=True)

View file

@ -1,4 +1,7 @@
from __future__ import annotations
import numpy as np
import numpy.typing as npt
import moderngl
from manimlib.constants import GREY_C
@ -29,27 +32,31 @@ class DotCloud(PMobject):
],
}
def __init__(self, points=None, **kwargs):
def __init__(self, points: npt.ArrayLike = None, **kwargs):
super().__init__(**kwargs)
if points is not None:
self.set_points(points)
def init_data(self):
def init_data(self) -> None:
super().init_data()
self.data["radii"] = np.zeros((1, 1))
self.set_radius(self.radius)
def init_uniforms(self):
def init_uniforms(self) -> None:
super().init_uniforms()
self.uniforms["glow_factor"] = self.glow_factor
def to_grid(self, n_rows, n_cols, n_layers=1,
buff_ratio=None,
h_buff_ratio=1.0,
v_buff_ratio=1.0,
d_buff_ratio=1.0,
height=DEFAULT_GRID_HEIGHT,
):
def to_grid(
self,
n_rows: int,
n_cols: int,
n_layers: int = 1,
buff_ratio: float | None = None,
h_buff_ratio: float = 1.0,
v_buff_ratio: float = 1.0,
d_buff_ratio: float = 1.0,
height: float = DEFAULT_GRID_HEIGHT,
):
n_points = n_rows * n_cols * n_layers
points = np.repeat(range(n_points), 3, axis=0).reshape((n_points, 3))
points[:, 0] = points[:, 0] % n_cols
@ -74,50 +81,55 @@ class DotCloud(PMobject):
self.center()
return self
def set_radii(self, radii):
def set_radii(self, radii: npt.ArrayLike):
n_points = len(self.get_points())
radii = np.array(radii).reshape((len(radii), 1))
self.data["radii"] = resize_preserving_order(radii, n_points)
self.refresh_bounding_box()
return self
def get_radii(self):
def get_radii(self) -> np.ndarray:
return self.data["radii"]
def set_radius(self, radius):
def set_radius(self, radius: float):
self.data["radii"][:] = radius
self.refresh_bounding_box()
return self
def get_radius(self):
def get_radius(self) -> float:
return self.get_radii().max()
def set_glow_factor(self, glow_factor):
def set_glow_factor(self, glow_factor: float) -> None:
self.uniforms["glow_factor"] = glow_factor
def get_glow_factor(self):
def get_glow_factor(self) -> float:
return self.uniforms["glow_factor"]
def compute_bounding_box(self):
def compute_bounding_box(self) -> np.ndarray:
bb = super().compute_bounding_box()
radius = self.get_radius()
bb[0] += np.full((3,), -radius)
bb[2] += np.full((3,), radius)
return bb
def scale(self, scale_factor, scale_radii=True, **kwargs):
def scale(
self,
scale_factor: float | npt.ArrayLike,
scale_radii: bool = True,
**kwargs
):
super().scale(scale_factor, **kwargs)
if scale_radii:
self.set_radii(scale_factor * self.get_radii())
return self
def make_3d(self, reflectiveness=0.5, shadow=0.2):
def make_3d(self, reflectiveness: float = 0.5, shadow: float = 0.2):
self.set_reflectiveness(reflectiveness)
self.set_shadow(shadow)
self.apply_depth_test()
return self
def get_shader_data(self):
def get_shader_data(self) -> np.ndarray:
shader_data = super().get_shader_data()
self.read_data_to_shader(shader_data, "radius", "radii")
self.read_data_to_shader(shader_data, "color", "rgbas")
@ -125,7 +137,7 @@ class DotCloud(PMobject):
class TrueDot(DotCloud):
def __init__(self, center=ORIGIN, **kwargs):
def __init__(self, center: np.ndarray = ORIGIN, **kwargs):
super().__init__(points=[center], **kwargs)

View file

@ -1,5 +1,6 @@
import numpy as np
from __future__ import annotations
import numpy as np
from PIL import Image
from manimlib.constants import *
@ -21,33 +22,33 @@ class ImageMobject(Mobject):
]
}
def __init__(self, filename, **kwargs):
def __init__(self, filename: str, **kwargs):
self.set_image_path(get_full_raster_image_path(filename))
super().__init__(**kwargs)
def set_image_path(self, path):
def set_image_path(self, path: str) -> None:
self.path = path
self.image = Image.open(path)
self.texture_paths = {"Texture": path}
def init_data(self):
def init_data(self) -> None:
self.data = {
"points": np.array([UL, DL, UR, DR]),
"im_coords": np.array([(0, 0), (0, 1), (1, 0), (1, 1)]),
"opacity": np.array([[self.opacity]], dtype=np.float32),
}
def init_points(self):
def init_points(self) -> None:
size = self.image.size
self.set_width(2 * size[0] / size[1], stretch=True)
self.set_height(self.height)
def set_opacity(self, opacity, recurse=True):
def set_opacity(self, opacity: float, recurse: bool = True):
for mob in self.get_family(recurse):
mob.data["opacity"] = np.array([[o] for o in listify(opacity)])
return self
def point_to_rgb(self, point):
def point_to_rgb(self, point: np.ndarray) -> np.ndarray:
x0, y0 = self.get_corner(UL)[:2]
x1, y1 = self.get_corner(DR)[:2]
x_alpha = inverse_interpolate(x0, x1, point[0])
@ -63,7 +64,7 @@ class ImageMobject(Mobject):
))
return np.array(rgb) / 255
def get_shader_data(self):
def get_shader_data(self) -> np.ndarray:
shader_data = super().get_shader_data()
self.read_data_to_shader(shader_data, "im_coords", "im_coords")
self.read_data_to_shader(shader_data, "opacity", "opacity")

View file

@ -1,3 +1,10 @@
from __future__ import annotations
from typing import Callable, Sequence, Union
import colour
import numpy.typing as npt
from manimlib.constants import *
from manimlib.mobject.mobject import Mobject
from manimlib.utils.color import color_gradient
@ -6,28 +13,41 @@ from manimlib.utils.iterables import resize_with_interpolation
from manimlib.utils.iterables import resize_array
ManimColor = Union[str, colour.Color, Sequence[float]]
class PMobject(Mobject):
CONFIG = {
"opacity": 1.0,
}
def resize_points(self, size, resize_func=resize_array):
def resize_points(
self,
size: int,
resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array
):
# TODO
for key in self.data:
if key == "bounding_box":
continue
if len(self.data[key]) != size:
self.data[key] = resize_array(self.data[key], size)
self.data[key] = resize_func(self.data[key], size)
return self
def set_points(self, points):
def set_points(self, points: npt.ArrayLike):
if len(points) == 0:
points = np.zeros((0, 3))
super().set_points(points)
self.resize_points(len(points))
return self
def add_points(self, points, rgbas=None, color=None, opacity=None):
def add_points(
self,
points: npt.ArrayLike,
rgbas: np.ndarray | None = None,
color: ManimColor | None = None,
opacity: float | None = None
):
"""
points must be a Nx3 numpy array, as must rgbas if it is not None
"""
@ -50,20 +70,20 @@ class PMobject(Mobject):
self.add_points([point], rgbas, color, opacity)
return self
def set_color_by_gradient(self, *colors):
def set_color_by_gradient(self, *colors: ManimColor):
self.data["rgbas"] = np.array(list(map(
color_to_rgba,
color_gradient(colors, self.get_num_points())
)))
return self
def match_colors(self, pmobject):
def match_colors(self, pmobject: PMobject):
self.data["rgbas"][:] = resize_with_interpolation(
pmobject.data["rgbas"], self.get_num_points()
)
return self
def filter_out(self, condition):
def filter_out(self, condition: Callable[[np.ndarray], bool]):
for mob in self.family_members_with_points():
to_keep = ~np.apply_along_axis(condition, 1, mob.get_points())
for key in mob.data:
@ -72,7 +92,7 @@ class PMobject(Mobject):
mob.data[key] = mob.data[key][to_keep]
return self
def sort_points(self, function=lambda p: p[0]):
def sort_points(self, function: Callable[[np.ndarray]] = lambda p: p[0]):
"""
function is any map from R^3 to R
"""
@ -92,11 +112,11 @@ class PMobject(Mobject):
])
return self
def point_from_proportion(self, alpha):
def point_from_proportion(self, alpha: float) -> np.ndarray:
index = alpha * (self.get_num_points() - 1)
return self.get_points()[int(index)]
def pointwise_become_partial(self, pmobject, a, b):
def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float):
lower_index = int(a * pmobject.get_num_points())
upper_index = int(b * pmobject.get_num_points())
for key in self.data:
@ -107,7 +127,7 @@ class PMobject(Mobject):
class PGroup(PMobject):
def __init__(self, *pmobs, **kwargs):
def __init__(self, *pmobs: PMobject, **kwargs):
if not all([isinstance(m, PMobject) for m in pmobs]):
raise Exception("All submobjects must be of type PMobject")
super().__init__(*pmobs, **kwargs)
@ -118,6 +138,6 @@ class Point(PMobject):
"color": BLACK,
}
def __init__(self, location=ORIGIN, **kwargs):
def __init__(self, location: np.ndarray = ORIGIN, **kwargs):
super().__init__(**kwargs)
self.add_points([location])

View file

@ -1,5 +1,10 @@
import numpy as np
from __future__ import annotations
from typing import Iterable, Callable
import moderngl
import numpy as np
import numpy.typing as npt
from manimlib.constants import *
from manimlib.mobject.mobject import Mobject
@ -9,6 +14,11 @@ from manimlib.utils.images import get_full_raster_image_path
from manimlib.utils.iterables import listify
from manimlib.utils.space_ops import normalize_along_axis
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.camera.camera import Camera
class Surface(Mobject):
CONFIG = {
@ -42,7 +52,7 @@ class Surface(Mobject):
super().__init__(**kwargs)
self.compute_triangle_indices()
def uv_func(self, u, v):
def uv_func(self, u: float, v: float) -> tuple[float, float, float]:
# To be implemented in subclasses
return (u, v, 0.0)
@ -85,15 +95,17 @@ class Surface(Mobject):
indices[5::6] = index_grid[+1:, +1:].flatten() # Bottom right
self.triangle_indices = indices
def get_triangle_indices(self):
def get_triangle_indices(self) -> np.ndarray:
return self.triangle_indices
def get_surface_points_and_nudged_points(self):
def get_surface_points_and_nudged_points(
self
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
points = self.get_points()
k = len(points) // 3
return points[:k], points[k:2 * k], points[2 * k:]
def get_unit_normals(self):
def get_unit_normals(self) -> np.ndarray:
s_points, du_points, dv_points = self.get_surface_points_and_nudged_points()
normals = np.cross(
(du_points - s_points) / self.epsilon,
@ -101,7 +113,13 @@ class Surface(Mobject):
)
return normalize_along_axis(normals, 1)
def pointwise_become_partial(self, smobject, a, b, axis=None):
def pointwise_become_partial(
self,
smobject: "Surface",
a: float,
b: float,
axis: np.ndarray | None = None
):
assert(isinstance(smobject, Surface))
if axis is None:
axis = self.prefered_creation_axis
@ -116,7 +134,14 @@ class Surface(Mobject):
]))
return self
def get_partial_points_array(self, points, a, b, resolution, axis):
def get_partial_points_array(
self,
points: np.ndarray,
a: float,
b: float,
resolution: npt.ArrayLike,
axis: int
) -> np.ndarray:
if len(points) == 0:
return points
nu, nv = resolution[:2]
@ -149,7 +174,7 @@ class Surface(Mobject):
).reshape(shape)
return points.reshape((nu * nv, *resolution[2:]))
def sort_faces_back_to_front(self, vect=OUT):
def sort_faces_back_to_front(self, vect: np.ndarray = OUT):
tri_is = self.triangle_indices
indices = list(range(len(tri_is) // 3))
points = self.get_points()
@ -162,13 +187,13 @@ class Surface(Mobject):
tri_is[k::3] = tri_is[k::3][indices]
return self
def always_sort_to_camera(self, camera):
def always_sort_to_camera(self, camera: Camera):
self.add_updater(lambda m: m.sort_faces_back_to_front(
camera.get_location() - self.get_center()
))
# For shaders
def get_shader_data(self):
def get_shader_data(self) -> np.ndarray:
s_points, du_points, dv_points = self.get_surface_points_and_nudged_points()
shader_data = self.get_resized_shader_data_array(len(s_points))
if "points" not in self.locked_data_keys:
@ -178,16 +203,22 @@ class Surface(Mobject):
self.fill_in_shader_color_info(shader_data)
return shader_data
def fill_in_shader_color_info(self, shader_data):
def fill_in_shader_color_info(self, shader_data: np.ndarray) -> np.ndarray:
self.read_data_to_shader(shader_data, "color", "rgbas")
return shader_data
def get_shader_vert_indices(self):
def get_shader_vert_indices(self) -> np.ndarray:
return self.get_triangle_indices()
class ParametricSurface(Surface):
def __init__(self, uv_func, u_range=(0, 1), v_range=(0, 1), **kwargs):
def __init__(
self,
uv_func: Callable[[float, float], Iterable[float]],
u_range: tuple[float, float] = (0, 1),
v_range: tuple[float, float] = (0, 1),
**kwargs
):
self.passed_uv_func = uv_func
super().__init__(u_range=u_range, v_range=v_range, **kwargs)
@ -200,7 +231,7 @@ class SGroup(Surface):
"resolution": (0, 0),
}
def __init__(self, *parametric_surfaces, **kwargs):
def __init__(self, *parametric_surfaces: Surface, **kwargs):
super().__init__(uv_func=None, **kwargs)
self.add(*parametric_surfaces)
@ -220,7 +251,13 @@ class TexturedSurface(Surface):
]
}
def __init__(self, uv_surface, image_file, dark_image_file=None, **kwargs):
def __init__(
self,
uv_surface: Surface,
image_file: str,
dark_image_file: str | None = None,
**kwargs
):
if not isinstance(uv_surface, Surface):
raise Exception("uv_surface must be of type Surface")
# Set texture information
@ -236,10 +273,10 @@ class TexturedSurface(Surface):
self.uv_surface = uv_surface
self.uv_func = uv_surface.uv_func
self.u_range = uv_surface.u_range
self.v_range = uv_surface.v_range
self.resolution = uv_surface.resolution
self.gloss = self.uv_surface.gloss
self.u_range: tuple[float, float] = uv_surface.u_range
self.v_range: tuple[float, float] = uv_surface.v_range
self.resolution: tuple[float, float] = uv_surface.resolution
self.gloss: float = self.uv_surface.gloss
super().__init__(**kwargs)
def init_data(self):
@ -263,12 +300,18 @@ class TexturedSurface(Surface):
def init_colors(self):
self.data["opacity"] = np.array([self.uv_surface.data["rgbas"][:, 3]])
def set_opacity(self, opacity, recurse=True):
def set_opacity(self, opacity: float, recurse: bool = True):
for mob in self.get_family(recurse):
mob.data["opacity"] = np.array([[o] for o in listify(opacity)])
return self
def pointwise_become_partial(self, tsmobject, a, b, axis=1):
def pointwise_become_partial(
self,
tsmobject: "TexturedSurface",
a: float,
b: float,
axis: int = 1
):
super().pointwise_become_partial(tsmobject, a, b, axis)
im_coords = self.data["im_coords"]
im_coords[:] = tsmobject.data["im_coords"]
@ -280,7 +323,7 @@ class TexturedSurface(Surface):
)
return self
def fill_in_shader_color_info(self, shader_data):
def fill_in_shader_color_info(self, shader_data: np.ndarray) -> np.ndarray:
self.read_data_to_shader(shader_data, "opacity", "opacity")
self.read_data_to_shader(shader_data, "im_coords", "im_coords")
return shader_data

View file

@ -1,8 +1,13 @@
import itertools as it
import operator as op
import moderngl
from __future__ import annotations
import operator as op
import itertools as it
from functools import reduce, wraps
from typing import Iterable, Sequence, Callable, Union
import colour
import moderngl
import numpy.typing as npt
from manimlib.constants import *
from manimlib.mobject.mobject import Mobject
@ -29,6 +34,9 @@ from manimlib.utils.space_ops import z_to_vector
from manimlib.shader_wrapper import ShaderWrapper
ManimColor = Union[str, colour.Color, Sequence[float]]
class VMobject(Mobject):
CONFIG = {
"fill_color": None,
@ -105,7 +113,12 @@ class VMobject(Mobject):
self.set_flat_stroke(self.flat_stroke)
return self
def set_rgba_array(self, rgba_array, name=None, recurse=False):
def set_rgba_array(
self,
rgba_array: npt.ArrayLike,
name: str = None,
recurse: bool = False
):
if name is None:
names = ["fill_rgba", "stroke_rgba"]
else:
@ -115,11 +128,23 @@ class VMobject(Mobject):
super().set_rgba_array(rgba_array, name, recurse)
return self
def set_fill(self, color=None, opacity=None, recurse=True):
def set_fill(
self,
color: ManimColor | None = None,
opacity: float | None = None,
recurse: bool = True
):
self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse)
return self
def set_stroke(self, color=None, width=None, opacity=None, background=None, recurse=True):
def set_stroke(
self,
color: ManimColor | None = None,
width: float | npt.ArrayLike | None = None,
opacity: float | None = None,
background: bool | None = None,
recurse: bool = True
):
self.set_rgba_array_by_color(color, opacity, 'stroke_rgba', recurse)
if width is not None:
@ -135,29 +160,36 @@ class VMobject(Mobject):
mob.draw_stroke_behind_fill = background
return self
def set_backstroke(self, color=BLACK, width=3, background=True):
def set_backstroke(
self,
color: ManimColor = BLACK,
width: float | npt.ArrayLike = 3,
background: bool = True
):
self.set_stroke(color, width, background=background)
return self
def align_stroke_width_data_to_points(self, recurse=True):
def align_stroke_width_data_to_points(self, recurse: bool = True) -> None:
for mob in self.get_family(recurse):
mob.data["stroke_width"] = resize_with_interpolation(
mob.data["stroke_width"], len(mob.get_points())
)
def set_style(self,
fill_color=None,
fill_opacity=None,
fill_rgba=None,
stroke_color=None,
stroke_opacity=None,
stroke_rgba=None,
stroke_width=None,
stroke_background=True,
reflectiveness=None,
gloss=None,
shadow=None,
recurse=True):
def set_style(
self,
fill_color: ManimColor | None = None,
fill_opacity: float | None = None,
fill_rgba: npt.ArrayLike | None = None,
stroke_color: ManimColor | None = None,
stroke_opacity: float | None = None,
stroke_rgba: npt.ArrayLike | None = None,
stroke_width: float | npt.ArrayLike | None = None,
stroke_background: bool = True,
reflectiveness: float | None = None,
gloss: float | None = None,
shadow: float | None = None,
recurse: bool = True
):
if fill_rgba is not None:
self.data['fill_rgba'] = resize_with_interpolation(fill_rgba, len(fill_rgba))
else:
@ -201,7 +233,7 @@ class VMobject(Mobject):
"shadow": self.get_shadow(),
}
def match_style(self, vmobject, recurse=True):
def match_style(self, vmobject: VMobject, recurse: bool = True):
self.set_style(**vmobject.get_style(), recurse=False)
if recurse:
# Does its best to match up submobject lists, and
@ -215,17 +247,17 @@ class VMobject(Mobject):
sm1.match_style(sm2)
return self
def set_color(self, color, recurse=True):
def set_color(self, color: ManimColor, recurse: bool = True):
self.set_fill(color, recurse=recurse)
self.set_stroke(color, recurse=recurse)
return self
def set_opacity(self, opacity, recurse=True):
def set_opacity(self, opacity: float, recurse: bool = True):
self.set_fill(opacity=opacity, recurse=recurse)
self.set_stroke(opacity=opacity, recurse=recurse)
return self
def fade(self, darkness=0.5, recurse=True):
def fade(self, darkness: float = 0.5, recurse: bool = True):
mobs = self.get_family() if recurse else [self]
for mob in mobs:
factor = 1.0 - darkness
@ -239,78 +271,83 @@ class VMobject(Mobject):
)
return self
def get_fill_colors(self):
def get_fill_colors(self) -> list[str]:
return [
rgb_to_hex(rgba[:3])
for rgba in self.data['fill_rgba']
]
def get_fill_opacities(self):
def get_fill_opacities(self) -> np.ndarray:
return self.data['fill_rgba'][:, 3]
def get_stroke_colors(self):
def get_stroke_colors(self) -> list[str]:
return [
rgb_to_hex(rgba[:3])
for rgba in self.data['stroke_rgba']
]
def get_stroke_opacities(self):
def get_stroke_opacities(self) -> np.ndarray:
return self.data['stroke_rgba'][:, 3]
def get_stroke_widths(self):
def get_stroke_widths(self) -> np.ndarray:
return self.data['stroke_width'][:, 0]
# TODO, it's weird for these to return the first of various lists
# rather than the full information
def get_fill_color(self):
def get_fill_color(self) -> str:
"""
If there are multiple colors (for gradient)
this returns the first one
"""
return self.get_fill_colors()[0]
def get_fill_opacity(self):
def get_fill_opacity(self) -> float:
"""
If there are multiple opacities, this returns the
first
"""
return self.get_fill_opacities()[0]
def get_stroke_color(self):
def get_stroke_color(self) -> str:
return self.get_stroke_colors()[0]
def get_stroke_width(self):
def get_stroke_width(self) -> float | np.ndarray:
return self.get_stroke_widths()[0]
def get_stroke_opacity(self):
def get_stroke_opacity(self) -> float:
return self.get_stroke_opacities()[0]
def get_color(self):
def get_color(self) -> str:
if self.has_fill():
return self.get_fill_color()
return self.get_stroke_color()
def has_stroke(self):
def has_stroke(self) -> bool:
return self.get_stroke_widths().any() and self.get_stroke_opacities().any()
def has_fill(self):
def has_fill(self) -> bool:
return any(self.get_fill_opacities())
def get_opacity(self):
def get_opacity(self) -> float:
if self.has_fill():
return self.get_fill_opacity()
return self.get_stroke_opacity()
def set_flat_stroke(self, flat_stroke=True, recurse=True):
def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True):
for mob in self.get_family(recurse):
mob.flat_stroke = flat_stroke
return self
def get_flat_stroke(self):
def get_flat_stroke(self) -> bool:
return self.flat_stroke
# Points
def set_anchors_and_handles(self, anchors1, handles, anchors2):
def set_anchors_and_handles(
self,
anchors1: np.ndarray,
handles: np.ndarray,
anchors2: np.ndarray
):
assert(len(anchors1) == len(handles) == len(anchors2))
nppc = self.n_points_per_curve
new_points = np.zeros((nppc * len(anchors1), self.dim))
@ -320,16 +357,27 @@ class VMobject(Mobject):
self.set_points(new_points)
return self
def start_new_path(self, point):
def start_new_path(self, point: np.ndarray):
assert(self.get_num_points() % self.n_points_per_curve == 0)
self.append_points([point])
return self
def add_cubic_bezier_curve(self, anchor1, handle1, handle2, anchor2):
def add_cubic_bezier_curve(
self,
anchor1: npt.ArrayLike,
handle1: npt.ArrayLike,
handle2: npt.ArrayLike,
anchor2: npt.ArrayLike
):
new_points = get_quadratic_approximation_of_cubic(anchor1, handle1, handle2, anchor2)
self.append_points(new_points)
def add_cubic_bezier_curve_to(self, handle1, handle2, anchor):
def add_cubic_bezier_curve_to(
self,
handle1: npt.ArrayLike,
handle2: npt.ArrayLike,
anchor: npt.ArrayLike
):
"""
Add cubic bezier curve to the path.
"""
@ -342,14 +390,14 @@ class VMobject(Mobject):
else:
self.append_points(quadratic_approx)
def add_quadratic_bezier_curve_to(self, handle, anchor):
def add_quadratic_bezier_curve_to(self, handle: np.ndarray, anchor: np.ndarray):
self.throw_error_if_no_points()
if self.has_new_path_started():
self.append_points([handle, anchor])
else:
self.append_points([self.get_last_point(), handle, anchor])
def add_line_to(self, point):
def add_line_to(self, point: np.ndarray):
end = self.get_points()[-1]
alphas = np.linspace(0, 1, self.n_points_per_curve)
if self.long_lines:
@ -371,7 +419,7 @@ class VMobject(Mobject):
self.append_points(points)
return self
def add_smooth_curve_to(self, point):
def add_smooth_curve_to(self, point: np.ndarray):
if self.has_new_path_started():
self.add_line_to(point)
else:
@ -380,7 +428,7 @@ class VMobject(Mobject):
self.add_quadratic_bezier_curve_to(new_handle, point)
return self
def add_smooth_cubic_curve_to(self, handle, point):
def add_smooth_cubic_curve_to(self, handle: np.ndarray, point: np.ndarray):
self.throw_error_if_no_points()
if self.get_num_points() == 1:
new_handle = self.get_points()[-1]
@ -388,13 +436,13 @@ class VMobject(Mobject):
new_handle = self.get_reflection_of_last_handle()
self.add_cubic_bezier_curve_to(new_handle, handle, point)
def has_new_path_started(self):
def has_new_path_started(self) -> bool:
return self.get_num_points() % self.n_points_per_curve == 1
def get_last_point(self):
def get_last_point(self) -> np.ndarray:
return self.get_points()[-1]
def get_reflection_of_last_handle(self):
def get_reflection_of_last_handle(self) -> np.ndarray:
points = self.get_points()
return 2 * points[-1] - points[-2]
@ -402,12 +450,16 @@ class VMobject(Mobject):
if not self.is_closed():
self.add_line_to(self.get_subpaths()[-1][0])
def is_closed(self):
def is_closed(self) -> bool:
return self.consider_points_equals(
self.get_points()[0], self.get_points()[-1]
)
def subdivide_sharp_curves(self, angle_threshold=30 * DEGREES, recurse=True):
def subdivide_sharp_curves(
self,
angle_threshold: float = 30 * DEGREES,
recurse: bool = True
):
vmobs = [vm for vm in self.get_family(recurse) if vm.has_points()]
for vmob in vmobs:
new_points = []
@ -425,12 +477,12 @@ class VMobject(Mobject):
vmob.set_points(np.vstack(new_points))
return self
def add_points_as_corners(self, points):
def add_points_as_corners(self, points: Iterable[np.ndarray]):
for point in points:
self.add_line_to(point)
return points
def set_points_as_corners(self, points):
def set_points_as_corners(self, points: Iterable[np.ndarray]):
nppc = self.n_points_per_curve
points = np.array(points)
self.set_anchors_and_handles(*[
@ -439,7 +491,11 @@ class VMobject(Mobject):
])
return self
def set_points_smoothly(self, points, true_smooth=False):
def set_points_smoothly(
self,
points: Iterable[np.ndarray],
true_smooth: bool = False
):
self.set_points_as_corners(points)
if true_smooth:
self.make_smooth()
@ -447,7 +503,7 @@ class VMobject(Mobject):
self.make_approximately_smooth()
return self
def change_anchor_mode(self, mode):
def change_anchor_mode(self, mode: str):
assert(mode in ("jagged", "approx_smooth", "true_smooth"))
nppc = self.n_points_per_curve
for submob in self.family_members_with_points():
@ -492,12 +548,12 @@ class VMobject(Mobject):
self.change_anchor_mode("jagged")
return self
def add_subpath(self, points):
def add_subpath(self, points: Iterable[np.ndarray]):
assert(len(points) % self.n_points_per_curve == 0)
self.append_points(points)
return self
def append_vectorized_mobject(self, vectorized_mobject):
def append_vectorized_mobject(self, vectorized_mobject: VMobject):
new_points = list(vectorized_mobject.get_points())
if self.has_new_path_started():
@ -508,11 +564,11 @@ class VMobject(Mobject):
return self
#
def consider_points_equals(self, p0, p1):
def consider_points_equals(self, p0: np.ndarray, p1: np.ndarray) -> bool:
return get_norm(p1 - p0) < self.tolerance_for_point_equality
# Information about the curve
def get_bezier_tuples_from_points(self, points):
def get_bezier_tuples_from_points(self, points: Sequence[np.ndarray]):
nppc = self.n_points_per_curve
remainder = len(points) % nppc
points = points[:len(points) - remainder]
@ -524,7 +580,10 @@ class VMobject(Mobject):
def get_bezier_tuples(self):
return self.get_bezier_tuples_from_points(self.get_points())
def get_subpaths_from_points(self, points):
def get_subpaths_from_points(
self,
points: Sequence[np.ndarray]
) -> list[Sequence[np.ndarray]]:
nppc = self.n_points_per_curve
diffs = points[nppc - 1:-1:nppc] - points[nppc::nppc]
splits = (diffs * diffs).sum(1) > self.tolerance_for_point_equality
@ -541,28 +600,28 @@ class VMobject(Mobject):
if (i2 - i1) >= nppc
]
def get_subpaths(self):
def get_subpaths(self) -> list[Sequence[np.ndarray]]:
return self.get_subpaths_from_points(self.get_points())
def get_nth_curve_points(self, n):
def get_nth_curve_points(self, n: int) -> np.ndarray:
assert(n < self.get_num_curves())
nppc = self.n_points_per_curve
return self.get_points()[nppc * n:nppc * (n + 1)]
def get_nth_curve_function(self, n):
def get_nth_curve_function(self, n: int) -> Callable[[float], np.ndarray]:
return bezier(self.get_nth_curve_points(n))
def get_num_curves(self):
def get_num_curves(self) -> int:
return self.get_num_points() // self.n_points_per_curve
def quick_point_from_proportion(self, alpha):
def quick_point_from_proportion(self, alpha: float) -> np.ndarray:
# Assumes all curves have the same length, so is inaccurate
num_curves = self.get_num_curves()
n, residue = integer_interpolate(0, num_curves, alpha)
curve_func = self.get_nth_curve_function(n)
return curve_func(residue)
def point_from_proportion(self, alpha):
def point_from_proportion(self, alpha: float) -> np.ndarray:
if alpha <= 0:
return self.get_start()
elif alpha >= 1:
@ -584,7 +643,7 @@ class VMobject(Mobject):
residue = inverse_interpolate(partials[i - 1] / full, partials[i] / full, alpha)
return self.get_nth_curve_function(i - 1)(residue)
def get_anchors_and_handles(self):
def get_anchors_and_handles(self) -> list[np.ndarray]:
"""
returns anchors1, handles, anchors2,
where (anchors1[i], handles[i], anchors2[i])
@ -598,14 +657,14 @@ class VMobject(Mobject):
for i in range(nppc)
]
def get_start_anchors(self):
def get_start_anchors(self) -> np.ndarray:
return self.get_points()[0::self.n_points_per_curve]
def get_end_anchors(self):
def get_end_anchors(self) -> np.ndarray:
nppc = self.n_points_per_curve
return self.get_points()[nppc - 1::nppc]
def get_anchors(self):
def get_anchors(self) -> np.ndarray:
points = self.get_points()
if len(points) == 1:
return points
@ -614,7 +673,7 @@ class VMobject(Mobject):
self.get_end_anchors(),
))))
def get_points_without_null_curves(self, atol=1e-9):
def get_points_without_null_curves(self, atol: float=1e-9) -> np.ndarray:
nppc = self.n_points_per_curve
points = self.get_points()
distinct_curves = reduce(op.or_, [
@ -623,7 +682,7 @@ class VMobject(Mobject):
])
return points[distinct_curves.repeat(nppc)]
def get_arc_length(self, n_sample_points=None):
def get_arc_length(self, n_sample_points: int | None = None) -> float:
if n_sample_points is None:
n_sample_points = 4 * self.get_num_curves() + 1
points = np.array([
@ -634,7 +693,7 @@ class VMobject(Mobject):
norms = np.array([get_norm(d) for d in diffs])
return norms.sum()
def get_area_vector(self):
def get_area_vector(self) -> np.ndarray:
# Returns a vector whose length is the area bound by
# the polygon formed by the anchor points, pointing
# in a direction perpendicular to the polygon according
@ -654,7 +713,7 @@ class VMobject(Mobject):
sum((p0[:, 0] + p1[:, 0]) * (p1[:, 1] - p0[:, 1])), # Add up (x1 + x2)*(y2 - y1)
])
def get_unit_normal(self, recompute=False):
def get_unit_normal(self, recompute: bool = False) -> np.ndarray:
if not recompute:
return self.data["unit_normal"][0]
@ -680,7 +739,7 @@ class VMobject(Mobject):
return self
# Alignment
def align_points(self, vmobject):
def align_points(self, vmobject: VMobject):
if self.get_num_points() == len(vmobject.get_points()):
return
@ -723,7 +782,7 @@ class VMobject(Mobject):
vmobject.set_points(np.vstack(new_subpaths2))
return self
def insert_n_curves(self, n, recurse=True):
def insert_n_curves(self, n: int, recurse: bool = True):
for mob in self.get_family(recurse):
if mob.get_num_curves() > 0:
new_points = mob.insert_n_curves_to_point_list(n, mob.get_points())
@ -733,7 +792,7 @@ class VMobject(Mobject):
mob.set_points(new_points)
return self
def insert_n_curves_to_point_list(self, n, points):
def insert_n_curves_to_point_list(self, n: int, points: np.ndarray):
nppc = self.n_points_per_curve
if len(points) == 1:
return np.repeat(points, nppc * n, 0)
@ -766,7 +825,13 @@ class VMobject(Mobject):
new_points += partial_quadratic_bezier_points(group, a1, a2)
return np.vstack(new_points)
def interpolate(self, mobject1, mobject2, alpha, *args, **kwargs):
def interpolate(
self,
mobject1: VMobject,
mobject2: VMobject,
alpha: float,
*args, **kwargs
):
super().interpolate(mobject1, mobject2, alpha, *args, **kwargs)
if self.has_fill():
tri1 = mobject1.get_triangulation()
@ -775,7 +840,7 @@ class VMobject(Mobject):
self.refresh_triangulation()
return self
def pointwise_become_partial(self, vmobject, a, b):
def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float):
assert(isinstance(vmobject, VMobject))
if a <= 0 and b >= 1:
self.become(vmobject)
@ -817,7 +882,7 @@ class VMobject(Mobject):
self.set_points(new_points)
return self
def get_subcurve(self, a, b):
def get_subcurve(self, a: float, b: float) -> VMobject:
vmob = self.copy()
vmob.pointwise_become_partial(self, a, b)
return vmob
@ -829,7 +894,7 @@ class VMobject(Mobject):
mob.needs_new_triangulation = True
return self
def get_triangulation(self, normal_vector=None):
def get_triangulation(self, normal_vector: np.ndarray | None = None):
# Figure out how to triangulate the interior to know
# how to send the points as to the vertex shader.
# First triangles come directly from the points
@ -898,25 +963,30 @@ class VMobject(Mobject):
return wrapper
@triggers_refreshed_triangulation
def set_points(self, points):
def set_points(self, points: npt.ArrayLike):
super().set_points(points)
return self
@triggers_refreshed_triangulation
def set_data(self, data):
def set_data(self, data: dict):
super().set_data(data)
return self
# TODO, how to be smart about tangents here?
@triggers_refreshed_triangulation
def apply_function(self, function, make_smooth=False, **kwargs):
def apply_function(
self,
function: Callable[[np.ndarray], np.ndarray],
make_smooth: bool = False,
**kwargs
):
super().apply_function(function, **kwargs)
if self.make_smooth_after_applying_functions or make_smooth:
self.make_approximately_smooth()
return self
def flip(self, *args, **kwargs):
super().flip(*args, **kwargs)
def flip(self, axis: np.ndarray = UP, **kwargs):
super().flip(axis, **kwargs)
self.refresh_unit_normal()
self.refresh_triangulation()
return self
@ -942,20 +1012,20 @@ class VMobject(Mobject):
wrapper.refresh_id()
return self
def get_fill_shader_wrapper(self):
def get_fill_shader_wrapper(self) -> ShaderWrapper:
self.fill_shader_wrapper.vert_data = self.get_fill_shader_data()
self.fill_shader_wrapper.vert_indices = self.get_fill_shader_vert_indices()
self.fill_shader_wrapper.uniforms = self.get_shader_uniforms()
self.fill_shader_wrapper.depth_test = self.depth_test
return self.fill_shader_wrapper
def get_stroke_shader_wrapper(self):
def get_stroke_shader_wrapper(self) -> ShaderWrapper:
self.stroke_shader_wrapper.vert_data = self.get_stroke_shader_data()
self.stroke_shader_wrapper.uniforms = self.get_stroke_uniforms()
self.stroke_shader_wrapper.depth_test = self.depth_test
return self.stroke_shader_wrapper
def get_shader_wrapper_list(self):
def get_shader_wrapper_list(self) -> list[ShaderWrapper]:
# Build up data lists
fill_shader_wrappers = []
stroke_shader_wrappers = []
@ -984,13 +1054,13 @@ class VMobject(Mobject):
result.append(wrapper)
return result
def get_stroke_uniforms(self):
def get_stroke_uniforms(self) -> dict[str, float]:
result = dict(super().get_shader_uniforms())
result["joint_type"] = JOINT_TYPE_MAP[self.joint_type]
result["flat_stroke"] = float(self.flat_stroke)
return result
def get_stroke_shader_data(self):
def get_stroke_shader_data(self) -> np.ndarray:
points = self.get_points()
if len(self.stroke_data) != len(points):
self.stroke_data = resize_array(self.stroke_data, len(points))
@ -1009,7 +1079,7 @@ class VMobject(Mobject):
return self.stroke_data
def get_fill_shader_data(self):
def get_fill_shader_data(self) -> np.ndarray:
points = self.get_points()
if len(self.fill_data) != len(points):
self.fill_data = resize_array(self.fill_data, len(points))
@ -1025,18 +1095,18 @@ class VMobject(Mobject):
self.get_fill_shader_data()
self.get_stroke_shader_data()
def get_fill_shader_vert_indices(self):
def get_fill_shader_vert_indices(self) -> np.ndarray:
return self.get_triangulation()
class VGroup(VMobject):
def __init__(self, *vmobjects, **kwargs):
def __init__(self, *vmobjects: VMobject, **kwargs):
if not all([isinstance(m, VMobject) for m in vmobjects]):
raise Exception("All submobjects must be of type VMobject")
super().__init__(**kwargs)
self.add(*vmobjects)
def __add__(self: 'VGroup', other: 'VMobject' or 'VGroup'):
def __add__(self, other: VMobject | VGroup):
assert(isinstance(other, VMobject))
return self.add(other)
@ -1050,14 +1120,14 @@ class VectorizedPoint(Point, VMobject):
"artificial_height": 0.01,
}
def __init__(self, location=ORIGIN, **kwargs):
def __init__(self, location: np.ndarray = ORIGIN, **kwargs):
Point.__init__(self, **kwargs)
VMobject.__init__(self, **kwargs)
self.set_points(np.array([location]))
class CurvesAsSubmobjects(VGroup):
def __init__(self, vmobject, **kwargs):
def __init__(self, vmobject: VMobject, **kwargs):
super().__init__(**kwargs)
for tup in vmobject.get_bezier_tuples():
part = VMobject()
@ -1073,7 +1143,7 @@ class DashedVMobject(VMobject):
"color": WHITE
}
def __init__(self, vmobject, **kwargs):
def __init__(self, vmobject: VMobject, **kwargs):
super().__init__(**kwargs)
num_dashes = self.num_dashes
ps_ratio = self.positive_space_ratio

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import numpy as np
from manimlib.mobject.mobject import Mobject
@ -15,11 +17,11 @@ class ValueTracker(Mobject):
"value_type": np.float64,
}
def __init__(self, value=0, **kwargs):
def __init__(self, value: float | complex = 0, **kwargs):
self.value = value
super().__init__(**kwargs)
def init_data(self):
def init_data(self) -> None:
super().init_data()
self.data["value"] = np.array(
listify(self.value),
@ -27,17 +29,17 @@ class ValueTracker(Mobject):
dtype=self.value_type,
)
def get_value(self):
def get_value(self) -> float | complex:
result = self.data["value"][0, :]
if len(result) == 1:
return result[0]
return result
def set_value(self, value):
def set_value(self, value: float | complex):
self.data["value"][0, :] = value
return self
def increment_value(self, d_value):
def increment_value(self, d_value: float | complex) -> None:
self.set_value(self.get_value() + d_value)
@ -48,10 +50,10 @@ class ExponentialValueTracker(ValueTracker):
behaves
"""
def get_value(self):
def get_value(self) -> float | complex:
return np.exp(ValueTracker.get_value(self))
def set_value(self, value):
def set_value(self, value: float | complex):
return ValueTracker.set_value(self, np.log(value))

View file

@ -1,9 +1,13 @@
import numpy as np
from __future__ import annotations
import itertools as it
import random
from typing import Sequence, TypeVar, Callable, Iterable
import numpy as np
import numpy.typing as npt
from manimlib.constants import *
from manimlib.animation.composition import AnimationGroup
from manimlib.animation.indication import VShowPassingFlash
from manimlib.mobject.geometry import Arrow
@ -18,8 +22,19 @@ from manimlib.utils.rate_functions import linear
from manimlib.utils.simple_functions import sigmoid
from manimlib.utils.space_ops import get_norm
from typing import TYPE_CHECKING
def get_vectorized_rgb_gradient_function(min_value, max_value, color_map):
if TYPE_CHECKING:
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.coordinate_systems import CoordinateSystem
T = TypeVar("T")
def get_vectorized_rgb_gradient_function(
min_value: T,
max_value: T,
color_map: str
) -> Callable[[npt.ArrayLike], np.ndarray]:
rgbs = np.array(get_colormap_list(color_map))
def func(values):
@ -37,12 +52,19 @@ def get_vectorized_rgb_gradient_function(min_value, max_value, color_map):
return func
def get_rgb_gradient_function(min_value, max_value, color_map):
def get_rgb_gradient_function(
min_value: T,
max_value: T,
color_map: str
) -> Callable[[T], np.ndarray]:
vectorized_func = get_vectorized_rgb_gradient_function(min_value, max_value, color_map)
return lambda value: vectorized_func([value])[0]
def move_along_vector_field(mobject, func):
def move_along_vector_field(
mobject: Mobject,
func: Callable[[np.ndarray], np.ndarray]
) -> Mobject:
mobject.add_updater(
lambda m, dt: m.shift(
func(m.get_center()) * dt
@ -51,7 +73,10 @@ def move_along_vector_field(mobject, func):
return mobject
def move_submobjects_along_vector_field(mobject, func):
def move_submobjects_along_vector_field(
mobject: Mobject,
func: Callable[[np.ndarray], np.ndarray]
) -> Mobject:
def apply_nudge(mob, dt):
for submob in mob:
x, y = submob.get_center()[:2]
@ -62,7 +87,11 @@ def move_submobjects_along_vector_field(mobject, func):
return mobject
def move_points_along_vector_field(mobject, func, coordinate_system):
def move_points_along_vector_field(
mobject: Mobject,
func: Callable[[float, float], Iterable[float]],
coordinate_system: CoordinateSystem
) -> Mobject:
cs = coordinate_system
origin = cs.get_origin()
@ -74,7 +103,10 @@ def move_points_along_vector_field(mobject, func, coordinate_system):
return mobject
def get_sample_points_from_coordinate_system(coordinate_system, step_multiple):
def get_sample_points_from_coordinate_system(
coordinate_system: CoordinateSystem,
step_multiple: float
) -> it.product[tuple[np.ndarray, ...]]:
ranges = []
for range_args in coordinate_system.get_all_ranges():
_min, _max, step = range_args
@ -96,7 +128,12 @@ class VectorField(VGroup):
"vector_config": {},
}
def __init__(self, func, coordinate_system, **kwargs):
def __init__(
self,
func: Callable[[float, float], Sequence[float]],
coordinate_system: CoordinateSystem,
**kwargs
):
super().__init__(**kwargs)
self.func = func
self.coordinate_system = coordinate_system
@ -112,7 +149,7 @@ class VectorField(VGroup):
for coords in samples
))
def get_vector(self, coords, **kwargs):
def get_vector(self, coords: Iterable[float], **kwargs) -> Arrow:
vector_config = merge_dicts_recursively(
self.vector_config,
kwargs
@ -157,19 +194,24 @@ class StreamLines(VGroup):
"color_map": "3b1b_colormap",
}
def __init__(self, func, coordinate_system, **kwargs):
def __init__(
self,
func: Callable[[float, float], Sequence[float]],
coordinate_system: CoordinateSystem,
**kwargs
):
super().__init__(**kwargs)
self.func = func
self.coordinate_system = coordinate_system
self.draw_lines()
self.init_style()
def point_func(self, point):
def point_func(self, point: np.ndarray) -> np.ndarray:
in_coords = self.coordinate_system.p2c(point)
out_coords = self.func(*in_coords)
return self.coordinate_system.c2p(*out_coords)
def draw_lines(self):
def draw_lines(self) -> None:
lines = []
origin = self.coordinate_system.get_origin()
for point in self.get_start_points():
@ -194,7 +236,7 @@ class StreamLines(VGroup):
lines.append(line)
self.set_submobjects(lines)
def get_start_points(self):
def get_start_points(self) -> np.ndarray:
cs = self.coordinate_system
sample_coords = get_sample_points_from_coordinate_system(
cs, self.step_multiple,
@ -210,7 +252,7 @@ class StreamLines(VGroup):
for coords in sample_coords
])
def init_style(self):
def init_style(self) -> None:
if self.color_by_magnitude:
values_to_rgbs = get_vectorized_rgb_gradient_function(
*self.magnitude_range, self.color_map,
@ -247,7 +289,7 @@ class AnimatedStreamLines(VGroup):
},
}
def __init__(self, stream_lines, **kwargs):
def __init__(self, stream_lines: StreamLines, **kwargs):
super().__init__(**kwargs)
self.stream_lines = stream_lines
for line in stream_lines:
@ -262,7 +304,7 @@ class AnimatedStreamLines(VGroup):
self.add_updater(lambda m, dt: m.update(dt))
def update(self, dt):
def update(self, dt: float) -> None:
stream_lines = self.stream_lines
for line in stream_lines:
line.time += dt
@ -278,7 +320,7 @@ class ShowPassingFlashWithThinningStrokeWidth(AnimationGroup):
"remover": True
}
def __init__(self, vmobject, **kwargs):
def __init__(self, vmobject: VMobject, **kwargs):
digest_config(self, kwargs)
max_stroke_width = vmobject.get_stroke_width()
max_time_width = kwargs.pop("time_width", self.time_width)

View file

@ -1,12 +1,16 @@
import inspect
from __future__ import annotations
import time
import random
import inspect
import platform
import itertools as it
from functools import wraps
from typing import Iterable, Callable
from tqdm import tqdm as ProgressDisplay
import numpy as np
import time
import numpy.typing as npt
from manimlib.animation.animation import prepare_animation
from manimlib.animation.transform import MoveToTarget
@ -22,6 +26,12 @@ from manimlib.event_handler.event_type import EventType
from manimlib.event_handler import EVENT_DISPATCHER
from manimlib.logger import log
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from PIL.Image import Image
from manimlib.animation.animation import Animation
class Scene(object):
CONFIG = {
@ -50,13 +60,13 @@ class Scene(object):
else:
self.window = None
self.camera = self.camera_class(**self.camera_config)
self.camera: Camera = self.camera_class(**self.camera_config)
self.file_writer = SceneFileWriter(self, **self.file_writer_config)
self.mobjects = [self.camera.frame]
self.num_plays = 0
self.time = 0
self.skip_time = 0
self.original_skipping_status = self.skip_animations
self.mobjects: list[Mobject] = [self.camera.frame]
self.num_plays: int = 0
self.time: float = 0
self.skip_time: float = 0
self.original_skipping_status: bool = self.skip_animations
if self.start_at_animation_number is not None:
self.skip_animations = True
@ -70,9 +80,9 @@ class Scene(object):
random.seed(self.random_seed)
np.random.seed(self.random_seed)
def run(self):
self.virtual_animation_start_time = 0
self.real_animation_start_time = time.time()
def run(self) -> None:
self.virtual_animation_start_time: float = 0
self.real_animation_start_time: float = time.time()
self.file_writer.begin()
self.setup()
@ -82,7 +92,7 @@ class Scene(object):
pass
self.tear_down()
def setup(self):
def setup(self) -> None:
"""
This is meant to be implement by any scenes which
are comonly subclassed, and have some common setup
@ -90,18 +100,18 @@ class Scene(object):
"""
pass
def construct(self):
def construct(self) -> None:
# Where all the animation happens
# To be implemented in subclasses
pass
def tear_down(self):
def tear_down(self) -> None:
self.stop_skipping()
self.file_writer.finish()
if self.window and self.linger_after_completion:
self.interact()
def interact(self):
def interact(self) -> None:
# If there is a window, enter a loop
# which updates the frame while under
# the hood calling the pyglet event loop
@ -116,7 +126,7 @@ class Scene(object):
if self.quit_interaction:
self.unlock_mobject_data()
def embed(self, close_scene_on_exit=True):
def embed(self, close_scene_on_exit: bool = True) -> None:
if not self.preview:
# If the scene is just being
# written, ignore embed calls
@ -145,18 +155,18 @@ class Scene(object):
if close_scene_on_exit:
raise EndSceneEarlyException()
def __str__(self):
def __str__(self) -> str:
return self.__class__.__name__
# Only these methods should touch the camera
def get_image(self):
def get_image(self) -> Image:
return self.camera.get_image()
def show(self):
def show(self) -> None:
self.update_frame(ignore_skipping=True)
self.get_image().show()
def update_frame(self, dt=0, ignore_skipping=False):
def update_frame(self, dt: float = 0, ignore_skipping: bool = False) -> None:
self.increment_time(dt)
self.update_mobjects(dt)
if self.skip_animations and not ignore_skipping:
@ -174,22 +184,22 @@ class Scene(object):
if rt < vt:
self.update_frame(0)
def emit_frame(self):
def emit_frame(self) -> None:
if not self.skip_animations:
self.file_writer.write_frame(self.camera)
# Related to updating
def update_mobjects(self, dt):
def update_mobjects(self, dt: float) -> None:
for mobject in self.mobjects:
mobject.update(dt)
def should_update_mobjects(self):
def should_update_mobjects(self) -> bool:
return self.always_update_mobjects or any([
len(mob.get_family_updaters()) > 0
for mob in self.mobjects
])
def has_time_based_updaters(self):
def has_time_based_updaters(self) -> bool:
return any([
sm.has_time_based_updater()
for mob in self.mobjects()
@ -197,14 +207,14 @@ class Scene(object):
])
# Related to time
def get_time(self):
def get_time(self) -> float:
return self.time
def increment_time(self, dt):
def increment_time(self, dt: float) -> None:
self.time += dt
# Related to internal mobject organization
def get_top_level_mobjects(self):
def get_top_level_mobjects(self) -> list[Mobject]:
# Return only those which are not in the family
# of another mobject from the scene
mobjects = self.get_mobjects()
@ -218,10 +228,10 @@ class Scene(object):
return num_families == 1
return list(filter(is_top_level, mobjects))
def get_mobject_family_members(self):
def get_mobject_family_members(self) -> list[Mobject]:
return extract_mobject_family_members(self.mobjects)
def add(self, *new_mobjects):
def add(self, *new_mobjects: Mobject):
"""
Mobjects will be displayed, from background to
foreground in the order with which they are added.
@ -230,7 +240,7 @@ class Scene(object):
self.mobjects += new_mobjects
return self
def add_mobjects_among(self, values):
def add_mobjects_among(self, values: Iterable):
"""
This is meant mostly for quick prototyping,
e.g. to add all mobjects defined up to a point,
@ -242,17 +252,17 @@ class Scene(object):
))
return self
def remove(self, *mobjects_to_remove):
def remove(self, *mobjects_to_remove: Mobject):
self.mobjects = restructure_list_to_exclude_certain_family_members(
self.mobjects, mobjects_to_remove
)
return self
def bring_to_front(self, *mobjects):
def bring_to_front(self, *mobjects: Mobject):
self.add(*mobjects)
return self
def bring_to_back(self, *mobjects):
def bring_to_back(self, *mobjects: Mobject):
self.remove(*mobjects)
self.mobjects = list(mobjects) + self.mobjects
return self
@ -261,13 +271,18 @@ class Scene(object):
self.mobjects = []
return self
def get_mobjects(self):
def get_mobjects(self) -> list[Mobject]:
return list(self.mobjects)
def get_mobject_copies(self):
def get_mobject_copies(self) -> list[Mobject]:
return [m.copy() for m in self.mobjects]
def point_to_mobject(self, point, search_set=None, buff=0):
def point_to_mobject(
self,
point: np.ndarray,
search_set: Iterable[Mobject] | None = None,
buff: float = 0
) -> Mobject | None:
"""
E.g. if clicking on the scene, this returns the top layer mobject
under a given point
@ -280,7 +295,7 @@ class Scene(object):
return None
# Related to skipping
def update_skipping_status(self):
def update_skipping_status(self) -> None:
if self.start_at_animation_number is not None:
if self.num_plays == self.start_at_animation_number:
self.skip_time = self.time
@ -290,12 +305,18 @@ class Scene(object):
if self.num_plays >= self.end_at_animation_number:
raise EndSceneEarlyException()
def stop_skipping(self):
def stop_skipping(self) -> None:
self.virtual_animation_start_time = self.time
self.skip_animations = False
# Methods associated with running animations
def get_time_progression(self, run_time, n_iterations=None, desc="", override_skip_animations=False):
def get_time_progression(
self,
run_time: float,
n_iterations: int | None = None,
desc: str = "",
override_skip_animations: bool = False
) -> list[float] | np.ndarray | ProgressDisplay:
if self.skip_animations and not override_skip_animations:
return [run_time]
else:
@ -314,10 +335,13 @@ class Scene(object):
desc=desc,
)
def get_run_time(self, animations):
def get_run_time(self, animations: Iterable[Animation]) -> float:
return np.max([animation.run_time for animation in animations])
def get_animation_time_progression(self, animations):
def get_animation_time_progression(
self,
animations: Iterable[Animation]
) -> list[float] | np.ndarray | ProgressDisplay:
run_time = self.get_run_time(animations)
description = f"{self.num_plays} {animations[0]}"
if len(animations) > 1:
@ -325,14 +349,18 @@ class Scene(object):
time_progression = self.get_time_progression(run_time, desc=description)
return time_progression
def get_wait_time_progression(self, duration, stop_condition=None):
def get_wait_time_progression(
self,
duration: float,
stop_condition: Callable[[], bool] | None = None
) -> list[float] | np.ndarray | ProgressDisplay:
kw = {"desc": f"{self.num_plays} Waiting"}
if stop_condition is not None:
kw["n_iterations"] = -1 # So it doesn't show % progress
kw["override_skip_animations"] = True
return self.get_time_progression(duration, **kw)
def anims_from_play_args(self, *args, **kwargs):
def anims_from_play_args(self, *args, **kwargs) -> list[Animation]:
"""
Each arg can either be an animation, or a mobject method
followed by that methods arguments (and potentially follow
@ -422,7 +450,7 @@ class Scene(object):
self.num_plays += 1
return wrapper
def lock_static_mobject_data(self, *animations):
def lock_static_mobject_data(self, *animations: Animation) -> None:
movers = list(it.chain(*[
anim.mobject.get_family()
for anim in animations
@ -432,7 +460,7 @@ class Scene(object):
continue
self.camera.set_mobjects_as_static(mobject)
def unlock_mobject_data(self):
def unlock_mobject_data(self) -> None:
self.camera.release_static_mobjects()
def refresh_locked_data(self):
@ -440,7 +468,7 @@ class Scene(object):
self.lock_static_mobject_data()
return self
def begin_animations(self, animations):
def begin_animations(self, animations: Iterable[Animation]) -> None:
for animation in animations:
animation.begin()
# Anything animated that's not already in the
@ -451,7 +479,7 @@ class Scene(object):
if animation.mobject not in self.mobjects:
self.add(animation.mobject)
def progress_through_animations(self, animations):
def progress_through_animations(self, animations: Iterable[Animation]) -> None:
last_t = 0
for t in self.get_animation_time_progression(animations):
dt = t - last_t
@ -463,7 +491,7 @@ class Scene(object):
self.update_frame(dt)
self.emit_frame()
def finish_animations(self, animations):
def finish_animations(self, animations: Iterable[Animation]) -> None:
for animation in animations:
animation.finish()
animation.clean_up_from_scene(self)
@ -473,7 +501,7 @@ class Scene(object):
self.update_mobjects(0)
@handle_play_like_call
def play(self, *args, **kwargs):
def play(self, *args, **kwargs) -> None:
if len(args) == 0:
log.warning("Called Scene.play with no animations")
return
@ -485,11 +513,13 @@ class Scene(object):
self.unlock_mobject_data()
@handle_play_like_call
def wait(self,
duration=DEFAULT_WAIT_TIME,
stop_condition=None,
note=None,
ignore_presenter_mode=False):
def wait(
self,
duration: float = DEFAULT_WAIT_TIME,
stop_condition: Callable[[], bool] = None,
note: str = None,
ignore_presenter_mode: bool = False
):
if note:
log.info(note)
self.update_mobjects(dt=0) # Any problems with this?
@ -512,7 +542,11 @@ class Scene(object):
self.unlock_mobject_data()
return self
def wait_until(self, stop_condition, max_time=60):
def wait_until(
self,
stop_condition: Callable[[], bool],
max_time: float = 60
):
self.wait(max_time, stop_condition=stop_condition)
def force_skipping(self):
@ -525,14 +559,20 @@ class Scene(object):
self.skip_animations = self.original_skipping_status
return self
def add_sound(self, sound_file, time_offset=0, gain=None, **kwargs):
def add_sound(
self,
sound_file: str,
time_offset: float = 0,
gain: float | None = None,
gain_to_background: float | None = None
):
if self.skip_animations:
return
time = self.get_time() + time_offset
self.file_writer.add_sound(sound_file, time, gain, **kwargs)
self.file_writer.add_sound(sound_file, time, gain, gain_to_background)
# Helpers for interactive development
def save_state(self):
def save_state(self) -> None:
self.saved_state = {
"mobjects": self.mobjects,
"mobject_states": [
@ -541,7 +581,7 @@ class Scene(object):
],
}
def restore(self):
def restore(self) -> None:
if not hasattr(self, "saved_state"):
raise Exception("Trying to restore scene without having saved")
mobjects = self.saved_state["mobjects"]
@ -552,7 +592,11 @@ class Scene(object):
# Event handling
def on_mouse_motion(self, point, d_point):
def on_mouse_motion(
self,
point: np.ndarray,
d_point: np.ndarray
) -> None:
self.mouse_point.move_to(point)
event_data = {"point": point, "d_point": d_point}
@ -572,7 +616,13 @@ class Scene(object):
shift = np.dot(np.transpose(transform), shift)
frame.shift(shift)
def on_mouse_drag(self, point, d_point, buttons, modifiers):
def on_mouse_drag(
self,
point: np.ndarray,
d_point: np.ndarray,
buttons: int,
modifiers: int
) -> None:
self.mouse_drag_point.move_to(point)
event_data = {"point": point, "d_point": d_point, "buttons": buttons, "modifiers": modifiers}
@ -580,19 +630,33 @@ class Scene(object):
if propagate_event is not None and propagate_event is False:
return
def on_mouse_press(self, point, button, mods):
def on_mouse_press(
self,
point: np.ndarray,
button: int,
mods: int
) -> None:
event_data = {"point": point, "button": button, "mods": mods}
propagate_event = EVENT_DISPATCHER.dispatch(EventType.MousePressEvent, **event_data)
if propagate_event is not None and propagate_event is False:
return
def on_mouse_release(self, point, button, mods):
def on_mouse_release(
self,
point: np.ndarray,
button: int,
mods: int
) -> None:
event_data = {"point": point, "button": button, "mods": mods}
propagate_event = EVENT_DISPATCHER.dispatch(EventType.MouseReleaseEvent, **event_data)
if propagate_event is not None and propagate_event is False:
return
def on_mouse_scroll(self, point, offset):
def on_mouse_scroll(
self,
point: np.ndarray,
offset: np.ndarray
) -> None:
event_data = {"point": point, "offset": offset}
propagate_event = EVENT_DISPATCHER.dispatch(EventType.MouseScrollEvent, **event_data)
if propagate_event is not None and propagate_event is False:
@ -607,13 +671,21 @@ class Scene(object):
shift = np.dot(np.transpose(transform), offset)
frame.shift(-20.0 * shift)
def on_key_release(self, symbol, modifiers):
def on_key_release(
self,
symbol: int,
modifiers: int
) -> None:
event_data = {"symbol": symbol, "modifiers": modifiers}
propagate_event = EVENT_DISPATCHER.dispatch(EventType.KeyReleaseEvent, **event_data)
if propagate_event is not None and propagate_event is False:
return
def on_key_press(self, symbol, modifiers):
def on_key_press(
self,
symbol: int,
modifiers: int
) -> None:
try:
char = chr(symbol)
except OverflowError:
@ -634,16 +706,16 @@ class Scene(object):
elif char == "e" and modifiers == 3: # ctrl + shift + e
self.embed(close_scene_on_exit=False)
def on_resize(self, width: int, height: int):
def on_resize(self, width: int, height: int) -> None:
self.camera.reset_pixel_shape(width, height)
def on_show(self):
def on_show(self) -> None:
pass
def on_hide(self):
def on_hide(self) -> None:
pass
def on_close(self):
def on_close(self) -> None:
pass

View file

@ -1,10 +1,13 @@
import numpy as np
from pydub import AudioSegment
import shutil
import subprocess as sp
from __future__ import annotations
import os
import sys
import shutil
import platform
import subprocess as sp
import numpy as np
from pydub import AudioSegment
from tqdm import tqdm as ProgressDisplay
from manimlib.constants import FFMPEG_BIN
@ -15,6 +18,13 @@ from manimlib.utils.file_ops import get_sorted_integer_files
from manimlib.utils.sounds import get_full_sound_file_path
from manimlib.logger import log
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.scene.scene import Scene
from manimlib.camera.camera import Camera
from PIL.Image import Image
class SceneFileWriter(object):
CONFIG = {
@ -42,14 +52,14 @@ class SceneFileWriter(object):
def __init__(self, scene, **kwargs):
digest_config(self, kwargs)
self.scene = scene
self.writing_process = None
self.has_progress_display = False
self.scene: Scene = scene
self.writing_process: sp.Popen | None = None
self.has_progress_display: bool = False
self.init_output_directories()
self.init_audio()
# Output directories and files
def init_output_directories(self):
def init_output_directories(self) -> None:
out_dir = self.output_directory
if self.mirror_module_path:
module_dir = self.get_default_module_directory()
@ -69,13 +79,13 @@ class SceneFileWriter(object):
movie_dir, "partial_movie_files", scene_name,
))
def get_default_module_directory(self):
def get_default_module_directory(self) -> str:
path, _ = os.path.splitext(self.input_file_path)
if path.startswith("_"):
path = path[1:]
return path
def get_default_scene_name(self):
def get_default_scene_name(self) -> str:
name = str(self.scene)
saan = self.scene.start_at_animation_number
eaan = self.scene.end_at_animation_number
@ -85,7 +95,7 @@ class SceneFileWriter(object):
name += f"_{eaan}"
return name
def get_resolution_directory(self):
def get_resolution_directory(self) -> str:
pixel_height = self.scene.camera.pixel_height
frame_rate = self.scene.camera.frame_rate
return "{}p{}".format(
@ -93,10 +103,10 @@ class SceneFileWriter(object):
)
# Directory getters
def get_image_file_path(self):
def get_image_file_path(self) -> str:
return self.image_file_path
def get_next_partial_movie_path(self):
def get_next_partial_movie_path(self) -> str:
result = os.path.join(
self.partial_movie_directory,
"{:05}{}".format(
@ -106,19 +116,22 @@ class SceneFileWriter(object):
)
return result
def get_movie_file_path(self):
def get_movie_file_path(self) -> str:
return self.movie_file_path
# Sound
def init_audio(self):
self.includes_sound = False
def init_audio(self) -> None:
self.includes_sound: bool = False
def create_audio_segment(self):
def create_audio_segment(self) -> None:
self.audio_segment = AudioSegment.silent()
def add_audio_segment(self, new_segment,
time=None,
gain_to_background=None):
def add_audio_segment(
self,
new_segment: AudioSegment,
time: float | None = None,
gain_to_background: float | None = None
) -> None:
if not self.includes_sound:
self.includes_sound = True
self.create_audio_segment()
@ -142,27 +155,33 @@ class SceneFileWriter(object):
gain_during_overlay=gain_to_background,
)
def add_sound(self, sound_file, time=None, gain=None, **kwargs):
def add_sound(
self,
sound_file: str,
time: float | None = None,
gain: float | None = None,
gain_to_background: float | None = None
) -> None:
file_path = get_full_sound_file_path(sound_file)
new_segment = AudioSegment.from_file(file_path)
if gain:
new_segment = new_segment.apply_gain(gain)
self.add_audio_segment(new_segment, time, **kwargs)
self.add_audio_segment(new_segment, time, gain_to_background)
# Writers
def begin(self):
def begin(self) -> None:
if not self.break_into_partial_movies and self.write_to_movie:
self.open_movie_pipe(self.get_movie_file_path())
def begin_animation(self):
def begin_animation(self) -> None:
if self.break_into_partial_movies and self.write_to_movie:
self.open_movie_pipe(self.get_next_partial_movie_path())
def end_animation(self):
def end_animation(self) -> None:
if self.break_into_partial_movies and self.write_to_movie:
self.close_movie_pipe()
def finish(self):
def finish(self) -> None:
if self.write_to_movie:
if self.break_into_partial_movies:
self.combine_movie_files()
@ -177,7 +196,7 @@ class SceneFileWriter(object):
if self.should_open_file():
self.open_file()
def open_movie_pipe(self, file_path):
def open_movie_pipe(self, file_path: str) -> None:
stem, ext = os.path.splitext(file_path)
self.final_file_path = file_path
self.temp_file_path = stem + "_temp" + ext
@ -223,7 +242,7 @@ class SceneFileWriter(object):
)
self.has_progress_display = True
def set_progress_display_subdescription(self, sub_desc):
def set_progress_display_subdescription(self, sub_desc: str) -> None:
desc_len = self.progress_description_len
file = os.path.split(self.get_movie_file_path())[1]
full_desc = f"Rendering {file} ({sub_desc})"
@ -233,14 +252,14 @@ class SceneFileWriter(object):
full_desc += " " * (desc_len - len(full_desc))
self.progress_display.set_description(full_desc)
def write_frame(self, camera):
def write_frame(self, camera: Camera) -> None:
if self.write_to_movie:
raw_bytes = camera.get_raw_fbo_data()
self.writing_process.stdin.write(raw_bytes)
if self.has_progress_display:
self.progress_display.update()
def close_movie_pipe(self):
def close_movie_pipe(self) -> None:
self.writing_process.stdin.close()
self.writing_process.wait()
self.writing_process.terminate()
@ -248,7 +267,7 @@ class SceneFileWriter(object):
self.progress_display.close()
shutil.move(self.temp_file_path, self.final_file_path)
def combine_movie_files(self):
def combine_movie_files(self) -> None:
kwargs = {
"remove_non_integer_files": True,
"extension": self.movie_file_extension,
@ -296,7 +315,7 @@ class SceneFileWriter(object):
combine_process = sp.Popen(commands)
combine_process.wait()
def add_sound_to_video(self):
def add_sound_to_video(self) -> None:
movie_file_path = self.get_movie_file_path()
stem, ext = os.path.splitext(movie_file_path)
sound_file_path = stem + ".wav"
@ -327,22 +346,22 @@ class SceneFileWriter(object):
shutil.move(temp_file_path, movie_file_path)
os.remove(sound_file_path)
def save_final_image(self, image):
def save_final_image(self, image: Image) -> None:
file_path = self.get_image_file_path()
image.save(file_path)
self.print_file_ready_message(file_path)
def print_file_ready_message(self, file_path):
def print_file_ready_message(self, file_path: str) -> None:
if not self.quiet:
log.info(f"File ready at {file_path}")
def should_open_file(self):
def should_open_file(self) -> bool:
return any([
self.show_file_location_upon_completion,
self.open_file_upon_completion,
])
def open_file(self):
def open_file(self) -> None:
if self.quiet:
curr_stdout = sys.stdout
sys.stdout = open(os.devnull, "w")

View file

@ -1,8 +1,12 @@
from __future__ import annotations
import os
import re
import copy
from typing import Iterable
import moderngl
import numpy as np
import copy
from manimlib.utils.directories import get_shader_dir
from manimlib.utils.file_ops import find_file
@ -15,15 +19,16 @@ from manimlib.utils.file_ops import find_file
class ShaderWrapper(object):
def __init__(self,
vert_data=None,
vert_indices=None,
shader_folder=None,
uniforms=None, # A dictionary mapping names of uniform variables
texture_paths=None, # A dictionary mapping names to filepaths for textures.
depth_test=False,
render_primitive=moderngl.TRIANGLE_STRIP,
):
def __init__(
self,
vert_data: np.ndarray | None = None,
vert_indices: np.ndarray | None = None,
shader_folder: str | None = None,
uniforms: dict[str, float] | None = None, # A dictionary mapping names of uniform variables
texture_paths: dict[str, str] | None = None, # A dictionary mapping names to filepaths for textures.
depth_test: bool = False,
render_primitive: int = moderngl.TRIANGLE_STRIP,
):
self.vert_data = vert_data
self.vert_indices = vert_indices
self.vert_attributes = vert_data.dtype.names
@ -46,20 +51,20 @@ class ShaderWrapper(object):
result.texture_paths = dict(self.texture_paths)
return result
def is_valid(self):
def is_valid(self) -> bool:
return all([
self.vert_data is not None,
self.program_code["vertex_shader"] is not None,
self.program_code["fragment_shader"] is not None,
])
def get_id(self):
def get_id(self) -> str:
return self.id
def get_program_id(self):
def get_program_id(self) -> int:
return self.program_id
def create_id(self):
def create_id(self) -> str:
# A unique id for a shader
return "|".join(map(str, [
self.program_id,
@ -69,32 +74,32 @@ class ShaderWrapper(object):
self.render_primitive,
]))
def refresh_id(self):
def refresh_id(self) -> None:
self.program_id = self.create_program_id()
self.id = self.create_id()
def create_program_id(self):
def create_program_id(self) -> int:
return hash("".join((
self.program_code[f"{name}_shader"] or ""
for name in ("vertex", "geometry", "fragment")
)))
def init_program_code(self):
def get_code(name):
def init_program_code(self) -> None:
def get_code(name: str) -> str | None:
return get_shader_code_from_file(
os.path.join(self.shader_folder, f"{name}.glsl")
)
self.program_code = {
self.program_code: dict[str, str | None] = {
"vertex_shader": get_code("vert"),
"geometry_shader": get_code("geom"),
"fragment_shader": get_code("frag"),
}
def get_program_code(self):
def get_program_code(self) -> dict[str, str | None]:
return self.program_code
def replace_code(self, old, new):
def replace_code(self, old: str, new: str) -> None:
code_map = self.program_code
for (name, code) in code_map.items():
if code_map[name] is None:
@ -102,7 +107,7 @@ class ShaderWrapper(object):
code_map[name] = re.sub(old, new, code_map[name])
self.refresh_id()
def combine_with(self, *shader_wrappers):
def combine_with(self, *shader_wrappers: ShaderWrapper):
# Assume they are of the same type
if len(shader_wrappers) == 0:
return
@ -122,10 +127,10 @@ class ShaderWrapper(object):
# For caching
filename_to_code_map = {}
filename_to_code_map: dict[str, str] = {}
def get_shader_code_from_file(filename):
def get_shader_code_from_file(filename: str) -> str | None:
if not filename:
return None
if filename in filename_to_code_map:
@ -157,7 +162,7 @@ def get_shader_code_from_file(filename):
return result
def get_colormap_code(rgb_list):
def get_colormap_code(rgb_list: Iterable[float]) -> str:
data = ",".join(
"vec3({}, {}, {})".format(*rgb)
for rgb in rgb_list

View file

@ -1,5 +1,10 @@
from __future__ import annotations
from typing import Iterable, Callable, TypeVar, Sequence
from scipy import linalg
import numpy as np
import numpy.typing as npt
from manimlib.utils.simple_functions import choose
from manimlib.utils.space_ops import find_intersection
@ -8,9 +13,11 @@ from manimlib.utils.space_ops import midpoint
from manimlib.logger import log
CLOSED_THRESHOLD = 0.001
T = TypeVar("T")
def bezier(points):
def bezier(
points: Iterable[float | np.ndarray]
) -> Callable[[float], float | np.ndarray]:
n = len(points) - 1
def result(t):
@ -22,7 +29,11 @@ def bezier(points):
return result
def partial_bezier_points(points, a, b):
def partial_bezier_points(
points: Sequence[np.ndarray],
a: float,
b: float
) -> list[float]:
"""
Given an list of points which define
a bezier curve, and two numbers 0<=a<b<=1,
@ -48,7 +59,11 @@ def partial_bezier_points(points, a, b):
# Shortened version of partial_bezier_points just for quadratics,
# since this is called a fair amount
def partial_quadratic_bezier_points(points, a, b):
def partial_quadratic_bezier_points(
points: Sequence[np.ndarray],
a: float,
b: float
) -> list[float]:
if a == 1:
return 3 * [points[-1]]
@ -65,7 +80,7 @@ def partial_quadratic_bezier_points(points, a, b):
# Linear interpolation variants
def interpolate(start, end, alpha):
def interpolate(start: T, end: T, alpha: float) -> T:
try:
return (1 - alpha) * start + alpha * end
except TypeError:
@ -76,12 +91,22 @@ def interpolate(start, end, alpha):
sys.exit(2)
def set_array_by_interpolation(arr, arr1, arr2, alpha, interp_func=interpolate):
def set_array_by_interpolation(
arr: np.ndarray,
arr1: np.ndarray,
arr2: np.ndarray,
alpha: float,
interp_func: Callable[[np.ndarray, np.ndarray, float], np.ndarray] = interpolate
) -> np.ndarray:
arr[:] = interp_func(arr1, arr2, alpha)
return arr
def integer_interpolate(start, end, alpha):
def integer_interpolate(
start: T,
end: T,
alpha: float
) -> tuple[int, float]:
"""
alpha is a float between 0 and 1. This returns
an integer between start and end (inclusive) representing
@ -102,22 +127,30 @@ def integer_interpolate(start, end, alpha):
return (value, residue)
def mid(start, end):
def mid(start: T, end: T) -> T:
return (start + end) / 2.0
def inverse_interpolate(start, end, value):
def inverse_interpolate(start: T, end: T, value: T) -> float:
return np.true_divide(value - start, end - start)
def match_interpolate(new_start, new_end, old_start, old_end, old_value):
def match_interpolate(
new_start: T,
new_end: T,
old_start: T,
old_end: T,
old_value: T
) -> T:
return interpolate(
new_start, new_end,
inverse_interpolate(old_start, old_end, old_value)
)
def get_smooth_quadratic_bezier_handle_points(points):
def get_smooth_quadratic_bezier_handle_points(
points: Sequence[np.ndarray]
) -> np.ndarray | list[np.ndarray]:
"""
Figuring out which bezier curves most smoothly connect a sequence of points.
@ -149,7 +182,9 @@ def get_smooth_quadratic_bezier_handle_points(points):
return handles
def get_smooth_cubic_bezier_handle_points(points):
def get_smooth_cubic_bezier_handle_points(
points: npt.ArrayLike
) -> tuple[np.ndarray, np.ndarray]:
points = np.array(points)
num_handles = len(points) - 1
dim = points.shape[1]
@ -207,7 +242,10 @@ def get_smooth_cubic_bezier_handle_points(points):
return handle_pairs[0::2], handle_pairs[1::2]
def diag_to_matrix(l_and_u, diag):
def diag_to_matrix(
l_and_u: tuple[int, int],
diag: np.ndarray
) -> np.ndarray:
"""
Converts array whose rows represent diagonal
entries of a matrix into the matrix itself.
@ -224,13 +262,18 @@ def diag_to_matrix(l_and_u, diag):
return matrix
def is_closed(points):
def is_closed(points: Sequence[np.ndarray]) -> bool:
return np.allclose(points[0], points[-1])
# Given 4 control points for a cubic bezier curve (or arrays of such)
# return control points for 2 quadratics (or 2n quadratics) approximating them.
def get_quadratic_approximation_of_cubic(a0, h0, h1, a1):
def get_quadratic_approximation_of_cubic(
a0: npt.ArrayLike,
h0: npt.ArrayLike,
h1: npt.ArrayLike,
a1: npt.ArrayLike
) -> np.ndarray:
a0 = np.array(a0, ndmin=2)
h0 = np.array(h0, ndmin=2)
h1 = np.array(h1, ndmin=2)
@ -298,7 +341,9 @@ def get_quadratic_approximation_of_cubic(a0, h0, h1, a1):
return result
def get_smooth_quadratic_bezier_path_through(points):
def get_smooth_quadratic_bezier_path_through(
points: list[np.ndarray]
) -> np.ndarray:
# TODO
h0, h1 = get_smooth_cubic_bezier_handle_points(points)
a0 = points[:-1]

View file

@ -1,19 +1,31 @@
from __future__ import annotations
import time
import numpy as np
from typing import Callable
from manimlib.constants import BLACK
from manimlib.mobject.numbers import Integer
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.logger import log
from typing import TYPE_CHECKING
def print_family(mobject, n_tabs=0):
if TYPE_CHECKING:
from manimlib.mobject.mobject import Mobject
def print_family(mobject: Mobject, n_tabs: int = 0) -> None:
"""For debugging purposes"""
log.debug("\t" * n_tabs + str(mobject) + " " + str(id(mobject)))
for submob in mobject.submobjects:
print_family(submob, n_tabs + 1)
def index_labels(mobject, label_height=0.15):
def index_labels(
mobject: Mobject | np.ndarray,
label_height: float = 0.15
) -> VGroup:
labels = VGroup()
for n, submob in enumerate(mobject):
label = Integer(n)
@ -24,7 +36,7 @@ def index_labels(mobject, label_height=0.15):
return labels
def get_runtime(func):
def get_runtime(func: Callable) -> float:
now = time.time()
func()
return time.time() - now

View file

@ -1,48 +1,50 @@
from __future__ import annotations
import os
from manimlib.utils.file_ops import guarantee_existence
from manimlib.utils.customization import get_customization
def get_directories():
def get_directories() -> dict[str, str]:
return get_customization()["directories"]
def get_temp_dir():
def get_temp_dir() -> str:
return get_directories()["temporary_storage"]
def get_tex_dir():
def get_tex_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "Tex"))
def get_text_dir():
def get_text_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "Text"))
def get_mobject_data_dir():
def get_mobject_data_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "mobject_data"))
def get_downloads_dir():
def get_downloads_dir() -> str:
return guarantee_existence(os.path.join(get_temp_dir(), "manim_downloads"))
def get_output_dir():
def get_output_dir() -> str:
return guarantee_existence(get_directories()["output"])
def get_raster_image_dir():
def get_raster_image_dir() -> str:
return get_directories()["raster_images"]
def get_vector_image_dir():
def get_vector_image_dir() -> str:
return get_directories()["vector_images"]
def get_sound_dir():
def get_sound_dir() -> str:
return get_directories()["sounds"]
def get_shader_dir():
def get_shader_dir() -> str:
return get_directories()["shaders"]

View file

@ -1,7 +1,18 @@
from __future__ import annotations
import itertools as it
from typing import Iterable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.mobject import Mobject
def extract_mobject_family_members(mobject_list, only_those_with_points=False):
def extract_mobject_family_members(
mobject_list: Iterable[Mobject],
only_those_with_points: bool = False
) -> list[Mobject]:
result = list(it.chain(*[
mob.get_family()
for mob in mobject_list
@ -11,7 +22,10 @@ def extract_mobject_family_members(mobject_list, only_those_with_points=False):
return result
def restructure_list_to_exclude_certain_family_members(mobject_list, to_remove):
def restructure_list_to_exclude_certain_family_members(
mobject_list: list[Mobject],
to_remove: list[Mobject]
) -> list[Mobject]:
"""
Removes anything in to_remove from mobject_list, but in the event that one of
the items to be removed is a member of the family of an item in mobject_list,

View file

@ -1,9 +1,13 @@
from __future__ import annotations
import os
from typing import Iterable
import numpy as np
import validators
def add_extension_if_not_present(file_name, extension):
def add_extension_if_not_present(file_name: str, extension: str) -> str:
# This could conceivably be smarter about handling existing differing extensions
if(file_name[-len(extension):] != extension):
return file_name + extension
@ -11,13 +15,17 @@ def add_extension_if_not_present(file_name, extension):
return file_name
def guarantee_existence(path):
def guarantee_existence(path: str) -> str:
if not os.path.exists(path):
os.makedirs(path)
return os.path.abspath(path)
def find_file(file_name, directories=None, extensions=None):
def find_file(
file_name: str,
directories: Iterable[str] | None = None,
extensions: Iterable[str] | None = None
) -> str:
# Check if this is a file online first, and if so, download
# it to a temporary directory
if validators.url(file_name):
@ -47,13 +55,14 @@ def find_file(file_name, directories=None, extensions=None):
raise IOError(f"{file_name} not Found")
def get_sorted_integer_files(directory,
min_index=0,
max_index=np.inf,
remove_non_integer_files=False,
remove_indices_greater_than=None,
extension=None,
):
def get_sorted_integer_files(
directory: str,
min_index: float = 0,
max_index: float = np.inf,
remove_non_integer_files: bool = False,
remove_indices_greater_than: float | None = None,
extension: str | None = None,
) -> list[str]:
indexed_files = []
for file in os.listdir(directory):
if '.' in file:

View file

@ -1,12 +1,13 @@
import numpy as np
from PIL import Image
from typing import Iterable
from manimlib.utils.file_ops import find_file
from manimlib.utils.directories import get_raster_image_dir
from manimlib.utils.directories import get_vector_image_dir
def get_full_raster_image_path(image_file_name):
def get_full_raster_image_path(image_file_name: str) -> str:
return find_file(
image_file_name,
directories=[get_raster_image_dir()],
@ -14,7 +15,7 @@ def get_full_raster_image_path(image_file_name):
)
def get_full_vector_image_path(image_file_name):
def get_full_vector_image_path(image_file_name: str) -> str:
return find_file(
image_file_name,
directories=[get_vector_image_dir()],
@ -22,7 +23,7 @@ def get_full_vector_image_path(image_file_name):
)
def drag_pixels(frames):
def drag_pixels(frames: Iterable) -> list:
curr = frames[0]
new_frames = []
for frame in frames:
@ -31,7 +32,7 @@ def drag_pixels(frames):
return new_frames
def invert_image(image):
def invert_image(image: Iterable) -> Image:
arr = np.array(image)
arr = (255 * np.ones(arr.shape)).astype(arr.dtype) - arr
return Image.fromarray(arr)

View file

@ -1,7 +1,8 @@
import os
import yaml
import inspect
import importlib
import importlib
from typing import Any
from rich import box
from rich.rule import Rule
@ -10,13 +11,13 @@ from rich.console import Console
from rich.prompt import Prompt, Confirm
def get_manim_dir():
def get_manim_dir() -> str:
manimlib_module = importlib.import_module("manimlib")
manimlib_dir = os.path.dirname(inspect.getabsfile(manimlib_module))
return os.path.abspath(os.path.join(manimlib_dir, ".."))
def remove_empty_value(dictionary):
def remove_empty_value(dictionary: dict[str, Any]) -> None:
for key in list(dictionary.keys()):
if dictionary[key] == "":
dictionary.pop(key)
@ -24,7 +25,7 @@ def remove_empty_value(dictionary):
remove_empty_value(dictionary[key])
def init_customization():
def init_customization() -> None:
configuration = {
"directories": {
"mirror_module_path": False,

View file

@ -1,8 +1,15 @@
from __future__ import annotations
import itertools as it
from typing import Callable, Iterable, Sequence, TypeVar
import numpy as np
T = TypeVar("T")
S = TypeVar("S")
def remove_list_redundancies(l):
def remove_list_redundancies(l: Iterable[T]) -> list[T]:
"""
Used instead of list(set(l)) to maintain order
Keeps the last occurrence of each element
@ -17,7 +24,7 @@ def remove_list_redundancies(l):
return reversed_result
def list_update(l1, l2):
def list_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]:
"""
Used instead of list(set(l1).update(l2)) to maintain order,
making sure duplicates are removed from l1, not l2.
@ -25,26 +32,29 @@ def list_update(l1, l2):
return [e for e in l1 if e not in l2] + list(l2)
def list_difference_update(l1, l2):
def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]:
return [e for e in l1 if e not in l2]
def all_elements_are_instances(iterable, Class):
def all_elements_are_instances(iterable: Iterable, Class: type) -> bool:
return all([isinstance(e, Class) for e in iterable])
def adjacent_n_tuples(objects, n):
def adjacent_n_tuples(objects: Iterable[T], n: int) -> zip[tuple[T, T]]:
return zip(*[
[*objects[k:], *objects[:k]]
for k in range(n)
])
def adjacent_pairs(objects):
def adjacent_pairs(objects: Iterable[T]) -> zip[tuple[T, T]]:
return adjacent_n_tuples(objects, 2)
def batch_by_property(items, property_func):
def batch_by_property(
items: Iterable[T],
property_func: Callable[[T], S]
) -> list[tuple[T, S]]:
"""
Takes in a list, and returns a list of tuples, (batch, prop)
such that all items in a batch have the same output when
@ -71,7 +81,7 @@ def batch_by_property(items, property_func):
return batch_prop_pairs
def listify(obj):
def listify(obj) -> list:
if isinstance(obj, str):
return [obj]
try:
@ -80,13 +90,13 @@ def listify(obj):
return [obj]
def resize_array(nparray, length):
def resize_array(nparray: np.ndarray, length: int) -> np.ndarray:
if len(nparray) == length:
return nparray
return np.resize(nparray, (length, *nparray.shape[1:]))
def resize_preserving_order(nparray, length):
def resize_preserving_order(nparray: np.ndarray, length: int) -> np.ndarray:
if len(nparray) == 0:
return np.zeros((length, *nparray.shape[1:]))
if len(nparray) == length:
@ -95,7 +105,7 @@ def resize_preserving_order(nparray, length):
return nparray[indices]
def resize_with_interpolation(nparray, length):
def resize_with_interpolation(nparray: np.ndarray, length: int) -> np.ndarray:
if len(nparray) == length:
return nparray
if length == 0:
@ -108,7 +118,10 @@ def resize_with_interpolation(nparray, length):
])
def make_even(iterable_1, iterable_2):
def make_even(
iterable_1: Sequence[T],
iterable_2: Sequence[S]
) -> tuple[list[T], list[S]]:
len1 = len(iterable_1)
len2 = len(iterable_2)
if len1 == len2:
@ -120,7 +133,10 @@ def make_even(iterable_1, iterable_2):
)
def make_even_by_cycling(iterable_1, iterable_2):
def make_even_by_cycling(
iterable_1: Iterable[T],
iterable_2: Iterable[S]
) -> tuple[list[T], list[S]]:
length = max(len(iterable_1), len(iterable_2))
cycle1 = it.cycle(iterable_1)
cycle2 = it.cycle(iterable_2)
@ -130,7 +146,7 @@ def make_even_by_cycling(iterable_1, iterable_2):
)
def remove_nones(sequence):
def remove_nones(sequence: Iterable) -> list:
return [x for x in sequence if x]
@ -141,7 +157,7 @@ def concatenate_lists(*list_of_lists):
return [item for l in list_of_lists for item in l]
def hash_obj(obj):
def hash_obj(obj: object) -> int:
if isinstance(obj, dict):
new_obj = {k: hash_obj(v) for k, v in obj.items()}
return hash(tuple(frozenset(sorted(new_obj.items()))))

View file

@ -1,5 +1,7 @@
import numpy as np
import math
from typing import Callable
import numpy as np
from manimlib.constants import OUT
from manimlib.utils.bezier import interpolate
@ -9,7 +11,11 @@ from manimlib.utils.space_ops import rotation_matrix_transpose
STRAIGHT_PATH_THRESHOLD = 0.01
def straight_path(start_points, end_points, alpha):
def straight_path(
start_points: np.ndarray,
end_points: np.ndarray,
alpha: float
) -> np.ndarray:
"""
Same function as interpolate, but renamed to reflect
intent of being used to determine how a set of points move
@ -19,7 +25,10 @@ def straight_path(start_points, end_points, alpha):
return interpolate(start_points, end_points, alpha)
def path_along_arc(arc_angle, axis=OUT):
def path_along_arc(
arc_angle: float,
axis: np.ndarray = OUT
) -> Callable[[np.ndarray, np.ndarray, float], np.ndarray]:
"""
If vect is vector from start to end, [vect[:,1], -vect[:,0]] is
perpendicular to vect in the left direction.
@ -41,9 +50,9 @@ def path_along_arc(arc_angle, axis=OUT):
return path
def clockwise_path():
def clockwise_path() -> Callable[[np.ndarray, np.ndarray, float], np.ndarray]:
return path_along_arc(-np.pi)
def counterclockwise_path():
def counterclockwise_path() -> Callable[[np.ndarray, np.ndarray, float], np.ndarray]:
return path_along_arc(np.pi)

View file

@ -1,44 +1,46 @@
from typing import Callable
import numpy as np
from manimlib.utils.bezier import bezier
def linear(t):
def linear(t: float) -> float:
return t
def smooth(t):
def smooth(t: float) -> float:
# Zero first and second derivatives at t=0 and t=1.
# Equivalent to bezier([0, 0, 0, 1, 1, 1])
s = 1 - t
return (t**3) * (10 * s * s + 5 * s * t + t * t)
def rush_into(t):
def rush_into(t: float) -> float:
return 2 * smooth(0.5 * t)
def rush_from(t):
def rush_from(t: float) -> float:
return 2 * smooth(0.5 * (t + 1)) - 1
def slow_into(t):
def slow_into(t: float) -> float:
return np.sqrt(1 - (1 - t) * (1 - t))
def double_smooth(t):
def double_smooth(t: float) -> float:
if t < 0.5:
return 0.5 * smooth(2 * t)
else:
return 0.5 * (1 + smooth(2 * t - 1))
def there_and_back(t):
def there_and_back(t: float) -> float:
new_t = 2 * t if t < 0.5 else 2 * (1 - t)
return smooth(new_t)
def there_and_back_with_pause(t, pause_ratio=1. / 3):
def there_and_back_with_pause(t: float, pause_ratio: float = 1. / 3) -> float:
a = 1. / pause_ratio
if t < 0.5 - pause_ratio / 2:
return smooth(a * t)
@ -48,21 +50,28 @@ def there_and_back_with_pause(t, pause_ratio=1. / 3):
return smooth(a - a * t)
def running_start(t, pull_factor=-0.5):
def running_start(t: float, pull_factor: float = -0.5) -> float:
return bezier([0, 0, pull_factor, pull_factor, 1, 1, 1])(t)
def not_quite_there(func=smooth, proportion=0.7):
def not_quite_there(
func: Callable[[float], float] = smooth,
proportion: float = 0.7
) -> Callable[[float], float]:
def result(t):
return proportion * func(t)
return result
def wiggle(t, wiggles=2):
def wiggle(t: float, wiggles: float = 2) -> float:
return there_and_back(t) * np.sin(wiggles * np.pi * t)
def squish_rate_func(func, a=0.4, b=0.6):
def squish_rate_func(
func: Callable[[float], float],
a: float = 0.4,
b: float = 0.6
) -> Callable[[float], float]:
def result(t):
if a == b:
return a
@ -81,11 +90,11 @@ def squish_rate_func(func, a=0.4, b=0.6):
# "lingering", different from squish_rate_func's default params
def lingering(t):
def lingering(t: float) -> float:
return squish_rate_func(lambda t: t, 0, 0.8)(t)
def exponential_decay(t, half_life=0.1):
def exponential_decay(t: float, half_life: float = 0.1) -> float:
# The half-life should be rather small to minimize
# the cut-off error at the end
return 1 - np.exp(-t / half_life)

View file

@ -2,7 +2,7 @@ from manimlib.utils.file_ops import find_file
from manimlib.utils.directories import get_sound_dir
def get_full_sound_file_path(sound_file_name):
def get_full_sound_file_path(sound_file_name) -> str:
return find_file(
sound_file_name,
directories=[get_sound_dir()],

View file

@ -1,7 +1,12 @@
import numpy as np
from __future__ import annotations
import math
import operator as op
from functools import reduce
import math
from typing import Callable, Iterable, Sequence
import numpy as np
import numpy.typing as npt
from mapbox_earcut import triangulate_float32 as earcut
from manimlib.constants import RIGHT
@ -13,7 +18,7 @@ from manimlib.utils.iterables import adjacent_pairs
from manimlib.utils.simple_functions import clip
def cross(v1, v2):
def cross(v1: np.ndarray, v2: np.ndarray) -> list[np.ndarray]:
return [
v1[1] * v2[2] - v1[2] * v2[1],
v1[2] * v2[0] - v1[0] * v2[2],
@ -21,7 +26,7 @@ def cross(v1, v2):
]
def get_norm(vect):
def get_norm(vect: np.ndarray) -> np.flaoting:
return sum((x**2 for x in vect))**0.5
@ -29,7 +34,7 @@ def get_norm(vect):
# TODO, implement quaternion type
def quaternion_mult(*quats):
def quaternion_mult(*quats: Sequence[float]) -> list[float]:
if len(quats) == 0:
return [1, 0, 0, 0]
result = quats[0]
@ -45,13 +50,19 @@ def quaternion_mult(*quats):
return result
def quaternion_from_angle_axis(angle, axis, axis_normalized=False):
def quaternion_from_angle_axis(
angle: float,
axis: np.ndarray,
axis_normalized: bool = False
) -> list[float]:
if not axis_normalized:
axis = normalize(axis)
return [math.cos(angle / 2), *(math.sin(angle / 2) * axis)]
def angle_axis_from_quaternion(quaternion):
def angle_axis_from_quaternion(
quaternion: Sequence[float]
) -> tuple[float, np.ndarray]:
axis = normalize(
quaternion[1:],
fall_back=[1, 0, 0]
@ -62,14 +73,18 @@ def angle_axis_from_quaternion(quaternion):
return angle, axis
def quaternion_conjugate(quaternion):
def quaternion_conjugate(quaternion: Iterable) -> list:
result = list(quaternion)
for i in range(1, len(result)):
result[i] *= -1
return result
def rotate_vector(vector, angle, axis=OUT):
def rotate_vector(
vector: Iterable,
angle: float,
axis: np.ndarray = OUT
) -> np.ndarray | list[float]:
if len(vector) == 2:
# Use complex numbers...because why not
z = complex(*vector) * np.exp(complex(0, angle))
@ -88,13 +103,13 @@ def rotate_vector(vector, angle, axis=OUT):
return result
def thick_diagonal(dim, thickness=2):
def thick_diagonal(dim: int, thickness: int = 2) -> np.ndarray:
row_indices = np.arange(dim).repeat(dim).reshape((dim, dim))
col_indices = np.transpose(row_indices)
return (np.abs(row_indices - col_indices) < thickness).astype('uint8')
def rotation_matrix_transpose_from_quaternion(quat):
def rotation_matrix_transpose_from_quaternion(quat: Iterable) -> list[list[float]]:
quat_inv = quaternion_conjugate(quat)
return [
quaternion_mult(quat, [0, *basis], quat_inv)[1:]
@ -106,11 +121,11 @@ def rotation_matrix_transpose_from_quaternion(quat):
]
def rotation_matrix_from_quaternion(quat):
def rotation_matrix_from_quaternion(quat: Iterable) -> np.ndarray:
return np.transpose(rotation_matrix_transpose_from_quaternion(quat))
def rotation_matrix_transpose(angle, axis):
def rotation_matrix_transpose(angle: float, axis: np.ndarray) -> list[list[float]]:
if axis[0] == 0 and axis[1] == 0:
# axis = [0, 0, z] case is common enough it's worth
# having a shortcut
@ -126,14 +141,14 @@ def rotation_matrix_transpose(angle, axis):
return rotation_matrix_transpose_from_quaternion(quat)
def rotation_matrix(angle, axis):
def rotation_matrix(angle: float, axis: np.ndarray) -> np.ndarray:
"""
Rotation in R^3 about a specified axis of rotation.
"""
return np.transpose(rotation_matrix_transpose(angle, axis))
def rotation_about_z(angle):
def rotation_about_z(angle: float) -> list[list[float]]:
return [
[math.cos(angle), -math.sin(angle), 0],
[math.sin(angle), math.cos(angle), 0],
@ -141,7 +156,7 @@ def rotation_about_z(angle):
]
def z_to_vector(vector):
def z_to_vector(vector: np.ndarray) -> np.ndarray:
"""
Returns some matrix in SO(3) which takes the z-axis to the
(normalized) vector provided as an argument
@ -156,7 +171,7 @@ def z_to_vector(vector):
return rotation_matrix(angle, axis=axis)
def rotation_between_vectors(v1, v2):
def rotation_between_vectors(v1, v2) -> np.ndarray:
if np.all(np.isclose(v1, v2)):
return np.identity(3)
return rotation_matrix(
@ -165,14 +180,14 @@ def rotation_between_vectors(v1, v2):
)
def angle_of_vector(vector):
def angle_of_vector(vector: Sequence[float]) -> float:
"""
Returns polar coordinate theta when vector is project on xy plane
"""
return np.angle(complex(*vector[:2]))
def angle_between_vectors(v1, v2):
def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float:
"""
Returns the angle between two 3D vectors.
This angle will always be btw 0 and pi
@ -180,12 +195,15 @@ def angle_between_vectors(v1, v2):
return math.acos(clip(np.dot(normalize(v1), normalize(v2)), -1, 1))
def project_along_vector(point, vector):
def project_along_vector(point: np.ndarray, vector: np.ndarray) -> np.ndarray:
matrix = np.identity(3) - np.outer(vector, vector)
return np.dot(point, matrix.T)
def normalize(vect, fall_back=None):
def normalize(
vect: np.ndarray,
fall_back: np.ndarray | None = None
) -> np.ndarray:
norm = get_norm(vect)
if norm > 0:
return np.array(vect) / norm
@ -195,7 +213,10 @@ def normalize(vect, fall_back=None):
return np.zeros(len(vect))
def normalize_along_axis(array, axis, fall_back=None):
def normalize_along_axis(
array: np.ndarray,
axis: np.ndarray,
) -> np.ndarray:
norms = np.sqrt((array * array).sum(axis))
norms[norms == 0] = 1
buffed_norms = np.repeat(norms, array.shape[axis]).reshape(array.shape)
@ -203,7 +224,11 @@ def normalize_along_axis(array, axis, fall_back=None):
return array
def get_unit_normal(v1, v2, tol=1e-6):
def get_unit_normal(
v1: np.ndarray,
v2: np.ndarray,
tol: float=1e-6
) -> np.ndarray:
v1 = normalize(v1)
v2 = normalize(v2)
cp = cross(v1, v2)
@ -221,7 +246,7 @@ def get_unit_normal(v1, v2, tol=1e-6):
###
def compass_directions(n=4, start_vect=RIGHT):
def compass_directions(n: int = 4, start_vect: np.ndarray = RIGHT) -> np.ndarray:
angle = TAU / n
return np.array([
rotate_vector(start_vect, k * angle)
@ -229,28 +254,36 @@ def compass_directions(n=4, start_vect=RIGHT):
])
def complex_to_R3(complex_num):
def complex_to_R3(complex_num: complex) -> np.ndarray:
return np.array((complex_num.real, complex_num.imag, 0))
def R3_to_complex(point):
def R3_to_complex(point: Sequence[float]) -> complex:
return complex(*point[:2])
def complex_func_to_R3_func(complex_func):
def complex_func_to_R3_func(
complex_func: Callable[[complex], complex]
) -> Callable[[np.ndarray], np.ndarray]:
return lambda p: complex_to_R3(complex_func(R3_to_complex(p)))
def center_of_mass(points):
def center_of_mass(points: Iterable[npt.ArrayLike]) -> np.ndarray:
points = [np.array(point).astype("float") for point in points]
return sum(points) / len(points)
def midpoint(point1, point2):
def midpoint(
point1: Sequence[float],
point2: Sequence[float]
) -> np.ndarray:
return center_of_mass([point1, point2])
def line_intersection(line1, line2):
def line_intersection(
line1: Sequence[Sequence[float]],
line2: Sequence[Sequence[float]]
) -> np.ndarray:
"""
return intersection point of two lines,
each defined with a pair of vectors determining
@ -271,7 +304,13 @@ def line_intersection(line1, line2):
return np.array([x, y, 0])
def find_intersection(p0, v0, p1, v1, threshold=1e-5):
def find_intersection(
p0: npt.ArrayLike,
v0: npt.ArrayLike,
p1: npt.ArrayLike,
v1: npt.ArrayLike,
threshold: float = 1e-5
) -> np.ndarray:
"""
Return the intersection of a line passing through p0 in direction v0
with one passing through p1 in direction v1. (Or array of intersections
@ -300,7 +339,11 @@ def find_intersection(p0, v0, p1, v1, threshold=1e-5):
return p0 + ratio * v0
def get_closest_point_on_line(a, b, p):
def get_closest_point_on_line(
a: np.ndarray,
b: np.ndarray,
p: np.ndarray
) -> np.ndarray:
"""
It returns point x such that
x is on line ab and xp is perpendicular to ab.
@ -315,7 +358,7 @@ def get_closest_point_on_line(a, b, p):
return ((t * a) + ((1 - t) * b))
def get_winding_number(points):
def get_winding_number(points: Iterable[float]) -> float:
total_angle = 0
for p1, p2 in adjacent_pairs(points):
d_angle = angle_of_vector(p2) - angle_of_vector(p1)
@ -326,14 +369,18 @@ def get_winding_number(points):
##
def cross2d(a, b):
def cross2d(a: np.ndarray, b: np.ndarray) -> np.ndarray:
if len(a.shape) == 2:
return a[:, 0] * b[:, 1] - a[:, 1] * b[:, 0]
else:
return a[0] * b[1] - b[0] * a[1]
def tri_area(a, b, c):
def tri_area(
a: Sequence[float],
b: Sequence[float],
c: Sequence[float]
) -> float:
return 0.5 * abs(
a[0] * (b[1] - c[1]) +
b[0] * (c[1] - a[1]) +
@ -341,7 +388,12 @@ def tri_area(a, b, c):
)
def is_inside_triangle(p, a, b, c):
def is_inside_triangle(
p: np.ndarray,
a: np.ndarray,
b: np.ndarray,
c: np.ndarray
) -> bool:
"""
Test if point p is inside triangle abc
"""
@ -353,12 +405,12 @@ def is_inside_triangle(p, a, b, c):
return np.all(crosses > 0) or np.all(crosses < 0)
def norm_squared(v):
def norm_squared(v: Sequence[float]) -> float:
return v[0] * v[0] + v[1] * v[1] + v[2] * v[2]
# TODO, fails for polygons drawn over themselves
def earclip_triangulation(verts, ring_ends):
def earclip_triangulation(verts: np.ndarray, ring_ends: list[int]) -> list:
"""
Returns a list of indices giving a triangulation
of a polygon, potentially with holes

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import numpy as np
import moderngl_window as mglw
from moderngl_window.context.pyglet.window import Window as PygletWindow
@ -7,6 +9,11 @@ from screeninfo import get_monitors
from manimlib.utils.config_ops import digest_config
from manimlib.utils.customization import get_customization
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.scene.scene import Scene
class Window(PygletWindow):
fullscreen = False
@ -15,7 +22,12 @@ class Window(PygletWindow):
vsync = True
cursor = True
def __init__(self, scene, size=(1280, 720), **kwargs):
def __init__(
self,
scene: Scene,
size: tuple[int, int] = (1280, 720),
**kwargs
):
super().__init__(size=size)
digest_config(self, kwargs)
@ -37,7 +49,7 @@ class Window(PygletWindow):
self.position = initial_position
self.position = initial_position
def find_initial_position(self, size):
def find_initial_position(self, size: tuple[int, int]) -> tuple[int, int]:
custom_position = get_customization()["window_position"]
monitors = get_monitors()
mon_index = get_customization()["window_monitor"]
@ -59,7 +71,12 @@ class Window(PygletWindow):
)
# Delegate event handling to scene
def pixel_coords_to_space_coords(self, px, py, relative=False):
def pixel_coords_to_space_coords(
self,
px: int,
py: int,
relative: bool = False
) -> np.ndarray:
pw, ph = self.size
fw, fh = self.scene.camera.get_frame_shape()
fc = self.scene.camera.get_frame_center()
@ -72,59 +89,59 @@ class Window(PygletWindow):
0
])
def on_mouse_motion(self, x, y, dx, dy):
def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> None:
super().on_mouse_motion(x, y, dx, dy)
point = self.pixel_coords_to_space_coords(x, y)
d_point = self.pixel_coords_to_space_coords(dx, dy, relative=True)
self.scene.on_mouse_motion(point, d_point)
def on_mouse_drag(self, x, y, dx, dy, buttons, modifiers):
def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> None:
super().on_mouse_drag(x, y, dx, dy, buttons, modifiers)
point = self.pixel_coords_to_space_coords(x, y)
d_point = self.pixel_coords_to_space_coords(dx, dy, relative=True)
self.scene.on_mouse_drag(point, d_point, buttons, modifiers)
def on_mouse_press(self, x: int, y: int, button, mods):
def on_mouse_press(self, x: int, y: int, button: int, mods: int) -> None:
super().on_mouse_press(x, y, button, mods)
point = self.pixel_coords_to_space_coords(x, y)
self.scene.on_mouse_press(point, button, mods)
def on_mouse_release(self, x: int, y: int, button, mods):
def on_mouse_release(self, x: int, y: int, button: int, mods: int) -> None:
super().on_mouse_release(x, y, button, mods)
point = self.pixel_coords_to_space_coords(x, y)
self.scene.on_mouse_release(point, button, mods)
def on_mouse_scroll(self, x, y, x_offset: float, y_offset: float):
def on_mouse_scroll(self, x: int, y: int, x_offset: float, y_offset: float) -> None:
super().on_mouse_scroll(x, y, x_offset, y_offset)
point = self.pixel_coords_to_space_coords(x, y)
offset = self.pixel_coords_to_space_coords(x_offset, y_offset, relative=True)
self.scene.on_mouse_scroll(point, offset)
def on_key_press(self, symbol, modifiers):
def on_key_press(self, symbol: int, modifiers: int) -> None:
self.pressed_keys.add(symbol) # Modifiers?
super().on_key_press(symbol, modifiers)
self.scene.on_key_press(symbol, modifiers)
def on_key_release(self, symbol, modifiers):
def on_key_release(self, symbol: int, modifiers: int) -> None:
self.pressed_keys.difference_update({symbol}) # Modifiers?
super().on_key_release(symbol, modifiers)
self.scene.on_key_release(symbol, modifiers)
def on_resize(self, width: int, height: int):
def on_resize(self, width: int, height: int) -> None:
super().on_resize(width, height)
self.scene.on_resize(width, height)
def on_show(self):
def on_show(self) -> None:
super().on_show()
self.scene.on_show()
def on_hide(self):
def on_hide(self) -> None:
super().on_hide()
self.scene.on_hide()
def on_close(self):
def on_close(self) -> None:
super().on_close()
self.scene.on_close()
def is_key_pressed(self, symbol):
def is_key_pressed(self, symbol: int) -> bool:
return (symbol in self.pressed_keys)

View file

@ -18,7 +18,6 @@ classifiers =
Topic :: Scientific/Engineering
Topic :: Multimedia :: Video
Topic :: Multimedia :: Graphics
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9