chore: add type hints to manimlib.scene

This commit is contained in:
TonyCrane 2022-02-14 21:22:18 +08:00
parent 09ce4717aa
commit be5de32d70
No known key found for this signature in database
GPG key ID: 2313A5058A9C637C
4 changed files with 198 additions and 109 deletions

View file

@ -297,7 +297,7 @@ class Camera(object):
dtype=dtype,
)
def get_image(self) -> Image:
def get_image(self) -> Image.Image:
return Image.frombytes(
'RGBA',
self.get_pixel_shape(),

View file

@ -16,7 +16,7 @@ class EventDispatcher(object):
}
self.mouse_point = np.array((0., 0., 0.))
self.mouse_drag_point = np.array((0., 0., 0.))
self.pressed_keys: set[str] = set()
self.pressed_keys: set[int] = set()
self.draggable_object_listners: list[EventListner] = []
def add_listner(self, event_listner: EventListner):
@ -86,7 +86,7 @@ class EventDispatcher(object):
def get_mouse_drag_point(self) -> np.ndarray:
return self.mouse_drag_point
def is_key_pressed(self, symbol) -> bool:
def is_key_pressed(self, symbol: int) -> bool:
return (symbol in self.pressed_keys)
__iadd__ = add_listner

View file

@ -1,12 +1,16 @@
import inspect
from __future__ import annotations
import time
import random
import inspect
import platform
import itertools as it
from functools import wraps
from typing import Iterable, Callable
from tqdm import tqdm as ProgressDisplay
import numpy as np
import time
import numpy.typing as npt
from manimlib.animation.animation import prepare_animation
from manimlib.animation.transform import MoveToTarget
@ -22,6 +26,11 @@ from manimlib.event_handler.event_type import EventType
from manimlib.event_handler import EVENT_DISPATCHER
from manimlib.logger import log
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from PIL.Image import Image
from manimlib.animation.animation import Animation
class Scene(object):
CONFIG = {
@ -50,13 +59,13 @@ class Scene(object):
else:
self.window = None
self.camera = self.camera_class(**self.camera_config)
self.camera: Camera = self.camera_class(**self.camera_config)
self.file_writer = SceneFileWriter(self, **self.file_writer_config)
self.mobjects = [self.camera.frame]
self.num_plays = 0
self.time = 0
self.skip_time = 0
self.original_skipping_status = self.skip_animations
self.mobjects: list[Mobject] = [self.camera.frame]
self.num_plays: int = 0
self.time: float = 0
self.skip_time: float = 0
self.original_skipping_status: bool = self.skip_animations
if self.start_at_animation_number is not None:
self.skip_animations = True
@ -70,9 +79,9 @@ class Scene(object):
random.seed(self.random_seed)
np.random.seed(self.random_seed)
def run(self):
self.virtual_animation_start_time = 0
self.real_animation_start_time = time.time()
def run(self) -> None:
self.virtual_animation_start_time: float = 0
self.real_animation_start_time: float = time.time()
self.file_writer.begin()
self.setup()
@ -82,7 +91,7 @@ class Scene(object):
pass
self.tear_down()
def setup(self):
def setup(self) -> None:
"""
This is meant to be implement by any scenes which
are comonly subclassed, and have some common setup
@ -90,18 +99,18 @@ class Scene(object):
"""
pass
def construct(self):
def construct(self) -> None:
# Where all the animation happens
# To be implemented in subclasses
pass
def tear_down(self):
def tear_down(self) -> None:
self.stop_skipping()
self.file_writer.finish()
if self.window and self.linger_after_completion:
self.interact()
def interact(self):
def interact(self) -> None:
# If there is a window, enter a loop
# which updates the frame while under
# the hood calling the pyglet event loop
@ -116,7 +125,7 @@ class Scene(object):
if self.quit_interaction:
self.unlock_mobject_data()
def embed(self, close_scene_on_exit=True):
def embed(self, close_scene_on_exit: bool = True) -> None:
if not self.preview:
# If the scene is just being
# written, ignore embed calls
@ -145,18 +154,18 @@ class Scene(object):
if close_scene_on_exit:
raise EndSceneEarlyException()
def __str__(self):
def __str__(self) -> str:
return self.__class__.__name__
# Only these methods should touch the camera
def get_image(self):
def get_image(self) -> Image:
return self.camera.get_image()
def show(self):
def show(self) -> None:
self.update_frame(ignore_skipping=True)
self.get_image().show()
def update_frame(self, dt=0, ignore_skipping=False):
def update_frame(self, dt: float = 0, ignore_skipping: bool = False) -> None:
self.increment_time(dt)
self.update_mobjects(dt)
if self.skip_animations and not ignore_skipping:
@ -174,22 +183,22 @@ class Scene(object):
if rt < vt:
self.update_frame(0)
def emit_frame(self):
def emit_frame(self) -> None:
if not self.skip_animations:
self.file_writer.write_frame(self.camera)
# Related to updating
def update_mobjects(self, dt):
def update_mobjects(self, dt: float) -> None:
for mobject in self.mobjects:
mobject.update(dt)
def should_update_mobjects(self):
def should_update_mobjects(self) -> bool:
return self.always_update_mobjects or any([
len(mob.get_family_updaters()) > 0
for mob in self.mobjects
])
def has_time_based_updaters(self):
def has_time_based_updaters(self) -> bool:
return any([
sm.has_time_based_updater()
for mob in self.mobjects()
@ -197,14 +206,14 @@ class Scene(object):
])
# Related to time
def get_time(self):
def get_time(self) -> float:
return self.time
def increment_time(self, dt):
def increment_time(self, dt: float) -> None:
self.time += dt
# Related to internal mobject organization
def get_top_level_mobjects(self):
def get_top_level_mobjects(self) -> list[Mobject]:
# Return only those which are not in the family
# of another mobject from the scene
mobjects = self.get_mobjects()
@ -218,10 +227,10 @@ class Scene(object):
return num_families == 1
return list(filter(is_top_level, mobjects))
def get_mobject_family_members(self):
def get_mobject_family_members(self) -> list[Mobject]:
return extract_mobject_family_members(self.mobjects)
def add(self, *new_mobjects):
def add(self, *new_mobjects: Mobject):
"""
Mobjects will be displayed, from background to
foreground in the order with which they are added.
@ -230,7 +239,7 @@ class Scene(object):
self.mobjects += new_mobjects
return self
def add_mobjects_among(self, values):
def add_mobjects_among(self, values: Iterable):
"""
This is meant mostly for quick prototyping,
e.g. to add all mobjects defined up to a point,
@ -242,17 +251,17 @@ class Scene(object):
))
return self
def remove(self, *mobjects_to_remove):
def remove(self, *mobjects_to_remove: Mobject):
self.mobjects = restructure_list_to_exclude_certain_family_members(
self.mobjects, mobjects_to_remove
)
return self
def bring_to_front(self, *mobjects):
def bring_to_front(self, *mobjects: Mobject):
self.add(*mobjects)
return self
def bring_to_back(self, *mobjects):
def bring_to_back(self, *mobjects: Mobject):
self.remove(*mobjects)
self.mobjects = list(mobjects) + self.mobjects
return self
@ -261,13 +270,18 @@ class Scene(object):
self.mobjects = []
return self
def get_mobjects(self):
def get_mobjects(self) -> list[Mobject]:
return list(self.mobjects)
def get_mobject_copies(self):
def get_mobject_copies(self) -> list[Mobject]:
return [m.copy() for m in self.mobjects]
def point_to_mobject(self, point, search_set=None, buff=0):
def point_to_mobject(
self,
point: np.ndarray,
search_set: Iterable[Mobject] | None = None,
buff: float = 0
) -> Mobject | None:
"""
E.g. if clicking on the scene, this returns the top layer mobject
under a given point
@ -280,7 +294,7 @@ class Scene(object):
return None
# Related to skipping
def update_skipping_status(self):
def update_skipping_status(self) -> None:
if self.start_at_animation_number is not None:
if self.num_plays == self.start_at_animation_number:
self.skip_time = self.time
@ -290,12 +304,18 @@ class Scene(object):
if self.num_plays >= self.end_at_animation_number:
raise EndSceneEarlyException()
def stop_skipping(self):
def stop_skipping(self) -> None:
self.virtual_animation_start_time = self.time
self.skip_animations = False
# Methods associated with running animations
def get_time_progression(self, run_time, n_iterations=None, desc="", override_skip_animations=False):
def get_time_progression(
self,
run_time: float,
n_iterations: int | None = None,
desc: str = "",
override_skip_animations: bool = False
) -> list[float] | np.ndarray | ProgressDisplay:
if self.skip_animations and not override_skip_animations:
return [run_time]
else:
@ -314,10 +334,13 @@ class Scene(object):
desc=desc,
)
def get_run_time(self, animations):
def get_run_time(self, animations: Iterable[Animation]) -> float:
return np.max([animation.run_time for animation in animations])
def get_animation_time_progression(self, animations):
def get_animation_time_progression(
self,
animations: Iterable[Animation]
) -> list[float] | np.ndarray | ProgressDisplay:
run_time = self.get_run_time(animations)
description = f"{self.num_plays} {animations[0]}"
if len(animations) > 1:
@ -325,14 +348,18 @@ class Scene(object):
time_progression = self.get_time_progression(run_time, desc=description)
return time_progression
def get_wait_time_progression(self, duration, stop_condition=None):
def get_wait_time_progression(
self,
duration: float,
stop_condition: Callable[[], bool] | None = None
) -> list[float] | np.ndarray | ProgressDisplay:
kw = {"desc": f"{self.num_plays} Waiting"}
if stop_condition is not None:
kw["n_iterations"] = -1 # So it doesn't show % progress
kw["override_skip_animations"] = True
return self.get_time_progression(duration, **kw)
def anims_from_play_args(self, *args, **kwargs):
def anims_from_play_args(self, *args, **kwargs) -> list[Animation]:
"""
Each arg can either be an animation, or a mobject method
followed by that methods arguments (and potentially follow
@ -422,7 +449,7 @@ class Scene(object):
self.num_plays += 1
return wrapper
def lock_static_mobject_data(self, *animations):
def lock_static_mobject_data(self, *animations: Animation) -> None:
movers = list(it.chain(*[
anim.mobject.get_family()
for anim in animations
@ -432,7 +459,7 @@ class Scene(object):
continue
self.camera.set_mobjects_as_static(mobject)
def unlock_mobject_data(self):
def unlock_mobject_data(self) -> None:
self.camera.release_static_mobjects()
def refresh_locked_data(self):
@ -440,7 +467,7 @@ class Scene(object):
self.lock_static_mobject_data()
return self
def begin_animations(self, animations):
def begin_animations(self, animations: Iterable[Animation]) -> None:
for animation in animations:
animation.begin()
# Anything animated that's not already in the
@ -451,7 +478,7 @@ class Scene(object):
if animation.mobject not in self.mobjects:
self.add(animation.mobject)
def progress_through_animations(self, animations):
def progress_through_animations(self, animations: Iterable[Animation]) -> None:
last_t = 0
for t in self.get_animation_time_progression(animations):
dt = t - last_t
@ -463,7 +490,7 @@ class Scene(object):
self.update_frame(dt)
self.emit_frame()
def finish_animations(self, animations):
def finish_animations(self, animations: Iterable[Animation]) -> None:
for animation in animations:
animation.finish()
animation.clean_up_from_scene(self)
@ -473,7 +500,7 @@ class Scene(object):
self.update_mobjects(0)
@handle_play_like_call
def play(self, *args, **kwargs):
def play(self, *args, **kwargs) -> None:
if len(args) == 0:
log.warning("Called Scene.play with no animations")
return
@ -485,11 +512,13 @@ class Scene(object):
self.unlock_mobject_data()
@handle_play_like_call
def wait(self,
duration=DEFAULT_WAIT_TIME,
stop_condition=None,
note=None,
ignore_presenter_mode=False):
def wait(
self,
duration: float = DEFAULT_WAIT_TIME,
stop_condition: Callable[[], bool] = None,
note: str = None,
ignore_presenter_mode: bool = False
):
if note:
log.info(note)
self.update_mobjects(dt=0) # Any problems with this?
@ -512,7 +541,11 @@ class Scene(object):
self.unlock_mobject_data()
return self
def wait_until(self, stop_condition, max_time=60):
def wait_until(
self,
stop_condition: Callable[[], bool],
max_time: float = 60
):
self.wait(max_time, stop_condition=stop_condition)
def force_skipping(self):
@ -525,14 +558,20 @@ class Scene(object):
self.skip_animations = self.original_skipping_status
return self
def add_sound(self, sound_file, time_offset=0, gain=None, **kwargs):
def add_sound(
self,
sound_file: str,
time_offset: float = 0,
gain: float | None = None,
gain_to_background: float | None = None
):
if self.skip_animations:
return
time = self.get_time() + time_offset
self.file_writer.add_sound(sound_file, time, gain, **kwargs)
self.file_writer.add_sound(sound_file, time, gain, gain_to_background)
# Helpers for interactive development
def save_state(self):
def save_state(self) -> None:
self.saved_state = {
"mobjects": self.mobjects,
"mobject_states": [
@ -541,7 +580,7 @@ class Scene(object):
],
}
def restore(self):
def restore(self) -> None:
if not hasattr(self, "saved_state"):
raise Exception("Trying to restore scene without having saved")
mobjects = self.saved_state["mobjects"]
@ -552,7 +591,11 @@ class Scene(object):
# Event handling
def on_mouse_motion(self, point, d_point):
def on_mouse_motion(
self,
point: np.ndarray,
d_point: np.ndarray
) -> None:
self.mouse_point.move_to(point)
event_data = {"point": point, "d_point": d_point}
@ -572,7 +615,13 @@ class Scene(object):
shift = np.dot(np.transpose(transform), shift)
frame.shift(shift)
def on_mouse_drag(self, point, d_point, buttons, modifiers):
def on_mouse_drag(
self,
point: np.ndarray,
d_point: np.ndarray,
buttons: int,
modifiers: int
) -> None:
self.mouse_drag_point.move_to(point)
event_data = {"point": point, "d_point": d_point, "buttons": buttons, "modifiers": modifiers}
@ -580,19 +629,33 @@ class Scene(object):
if propagate_event is not None and propagate_event is False:
return
def on_mouse_press(self, point, button, mods):
def on_mouse_press(
self,
point: np.ndarray,
button: int,
mods: int
) -> None:
event_data = {"point": point, "button": button, "mods": mods}
propagate_event = EVENT_DISPATCHER.dispatch(EventType.MousePressEvent, **event_data)
if propagate_event is not None and propagate_event is False:
return
def on_mouse_release(self, point, button, mods):
def on_mouse_release(
self,
point: np.ndarray,
button: int,
mods: int
) -> None:
event_data = {"point": point, "button": button, "mods": mods}
propagate_event = EVENT_DISPATCHER.dispatch(EventType.MouseReleaseEvent, **event_data)
if propagate_event is not None and propagate_event is False:
return
def on_mouse_scroll(self, point, offset):
def on_mouse_scroll(
self,
point: np.ndarray,
offset: np.ndarray
) -> None:
event_data = {"point": point, "offset": offset}
propagate_event = EVENT_DISPATCHER.dispatch(EventType.MouseScrollEvent, **event_data)
if propagate_event is not None and propagate_event is False:
@ -607,13 +670,21 @@ class Scene(object):
shift = np.dot(np.transpose(transform), offset)
frame.shift(-20.0 * shift)
def on_key_release(self, symbol, modifiers):
def on_key_release(
self,
symbol: int,
modifiers: int
) -> None:
event_data = {"symbol": symbol, "modifiers": modifiers}
propagate_event = EVENT_DISPATCHER.dispatch(EventType.KeyReleaseEvent, **event_data)
if propagate_event is not None and propagate_event is False:
return
def on_key_press(self, symbol, modifiers):
def on_key_press(
self,
symbol: int,
modifiers: int
) -> None:
try:
char = chr(symbol)
except OverflowError:
@ -634,16 +705,16 @@ class Scene(object):
elif char == "e":
self.embed(close_scene_on_exit=False)
def on_resize(self, width: int, height: int):
def on_resize(self, width: int, height: int) -> None:
self.camera.reset_pixel_shape(width, height)
def on_show(self):
def on_show(self) -> None:
pass
def on_hide(self):
def on_hide(self) -> None:
pass
def on_close(self):
def on_close(self) -> None:
pass

View file

@ -1,10 +1,13 @@
import numpy as np
from pydub import AudioSegment
import shutil
import subprocess as sp
from __future__ import annotations
import os
import sys
import shutil
import platform
import subprocess as sp
import numpy as np
from pydub import AudioSegment
from tqdm import tqdm as ProgressDisplay
from manimlib.constants import FFMPEG_BIN
@ -15,6 +18,12 @@ from manimlib.utils.file_ops import get_sorted_integer_files
from manimlib.utils.sounds import get_full_sound_file_path
from manimlib.logger import log
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.scene.scene import Scene
from manimlib.camera.camera import Camera
from PIL.Image import Image
class SceneFileWriter(object):
CONFIG = {
@ -42,14 +51,14 @@ class SceneFileWriter(object):
def __init__(self, scene, **kwargs):
digest_config(self, kwargs)
self.scene = scene
self.writing_process = None
self.has_progress_display = False
self.scene: Scene = scene
self.writing_process: sp.Popen | None = None
self.has_progress_display: bool = False
self.init_output_directories()
self.init_audio()
# Output directories and files
def init_output_directories(self):
def init_output_directories(self) -> None:
out_dir = self.output_directory
if self.mirror_module_path:
module_dir = self.get_default_module_directory()
@ -69,13 +78,13 @@ class SceneFileWriter(object):
movie_dir, "partial_movie_files", scene_name,
))
def get_default_module_directory(self):
def get_default_module_directory(self) -> str:
path, _ = os.path.splitext(self.input_file_path)
if path.startswith("_"):
path = path[1:]
return path
def get_default_scene_name(self):
def get_default_scene_name(self) -> str:
name = str(self.scene)
saan = self.scene.start_at_animation_number
eaan = self.scene.end_at_animation_number
@ -85,7 +94,7 @@ class SceneFileWriter(object):
name += f"_{eaan}"
return name
def get_resolution_directory(self):
def get_resolution_directory(self) -> str:
pixel_height = self.scene.camera.pixel_height
frame_rate = self.scene.camera.frame_rate
return "{}p{}".format(
@ -93,10 +102,10 @@ class SceneFileWriter(object):
)
# Directory getters
def get_image_file_path(self):
def get_image_file_path(self) -> str:
return self.image_file_path
def get_next_partial_movie_path(self):
def get_next_partial_movie_path(self) -> str:
result = os.path.join(
self.partial_movie_directory,
"{:05}{}".format(
@ -106,19 +115,22 @@ class SceneFileWriter(object):
)
return result
def get_movie_file_path(self):
def get_movie_file_path(self) -> str:
return self.movie_file_path
# Sound
def init_audio(self):
self.includes_sound = False
def init_audio(self) -> None:
self.includes_sound: bool = False
def create_audio_segment(self):
def create_audio_segment(self) -> None:
self.audio_segment = AudioSegment.silent()
def add_audio_segment(self, new_segment,
time=None,
gain_to_background=None):
def add_audio_segment(
self,
new_segment: AudioSegment,
time: float | None = None,
gain_to_background: float | None = None
) -> None:
if not self.includes_sound:
self.includes_sound = True
self.create_audio_segment()
@ -142,27 +154,33 @@ class SceneFileWriter(object):
gain_during_overlay=gain_to_background,
)
def add_sound(self, sound_file, time=None, gain=None, **kwargs):
def add_sound(
self,
sound_file: str,
time: float | None = None,
gain: float | None = None,
gain_to_background: float | None = None
) -> None:
file_path = get_full_sound_file_path(sound_file)
new_segment = AudioSegment.from_file(file_path)
if gain:
new_segment = new_segment.apply_gain(gain)
self.add_audio_segment(new_segment, time, **kwargs)
self.add_audio_segment(new_segment, time, gain_to_background)
# Writers
def begin(self):
def begin(self) -> None:
if not self.break_into_partial_movies and self.write_to_movie:
self.open_movie_pipe(self.get_movie_file_path())
def begin_animation(self):
def begin_animation(self) -> None:
if self.break_into_partial_movies and self.write_to_movie:
self.open_movie_pipe(self.get_next_partial_movie_path())
def end_animation(self):
def end_animation(self) -> None:
if self.break_into_partial_movies and self.write_to_movie:
self.close_movie_pipe()
def finish(self):
def finish(self) -> None:
if self.write_to_movie:
if self.break_into_partial_movies:
self.combine_movie_files()
@ -177,7 +195,7 @@ class SceneFileWriter(object):
if self.should_open_file():
self.open_file()
def open_movie_pipe(self, file_path):
def open_movie_pipe(self, file_path: str) -> None:
stem, ext = os.path.splitext(file_path)
self.final_file_path = file_path
self.temp_file_path = stem + "_temp" + ext
@ -223,7 +241,7 @@ class SceneFileWriter(object):
)
self.has_progress_display = True
def set_progress_display_subdescription(self, sub_desc):
def set_progress_display_subdescription(self, sub_desc: str) -> None:
desc_len = self.progress_description_len
file = os.path.split(self.get_movie_file_path())[1]
full_desc = f"Rendering {file} ({sub_desc})"
@ -233,14 +251,14 @@ class SceneFileWriter(object):
full_desc += " " * (desc_len - len(full_desc))
self.progress_display.set_description(full_desc)
def write_frame(self, camera):
def write_frame(self, camera: Camera) -> None:
if self.write_to_movie:
raw_bytes = camera.get_raw_fbo_data()
self.writing_process.stdin.write(raw_bytes)
if self.has_progress_display:
self.progress_display.update()
def close_movie_pipe(self):
def close_movie_pipe(self) -> None:
self.writing_process.stdin.close()
self.writing_process.wait()
self.writing_process.terminate()
@ -248,7 +266,7 @@ class SceneFileWriter(object):
self.progress_display.close()
shutil.move(self.temp_file_path, self.final_file_path)
def combine_movie_files(self):
def combine_movie_files(self) -> None:
kwargs = {
"remove_non_integer_files": True,
"extension": self.movie_file_extension,
@ -296,7 +314,7 @@ class SceneFileWriter(object):
combine_process = sp.Popen(commands)
combine_process.wait()
def add_sound_to_video(self):
def add_sound_to_video(self) -> None:
movie_file_path = self.get_movie_file_path()
stem, ext = os.path.splitext(movie_file_path)
sound_file_path = stem + ".wav"
@ -327,22 +345,22 @@ class SceneFileWriter(object):
shutil.move(temp_file_path, movie_file_path)
os.remove(sound_file_path)
def save_final_image(self, image):
def save_final_image(self, image: Image) -> None:
file_path = self.get_image_file_path()
image.save(file_path)
self.print_file_ready_message(file_path)
def print_file_ready_message(self, file_path):
def print_file_ready_message(self, file_path: str) -> None:
if not self.quiet:
log.info(f"File ready at {file_path}")
def should_open_file(self):
def should_open_file(self) -> bool:
return any([
self.show_file_location_upon_completion,
self.open_file_upon_completion,
])
def open_file(self):
def open_file(self) -> None:
if self.quiet:
curr_stdout = sys.stdout
sys.stdout = open(os.devnull, "w")