chore: add type hints to manimlib.scene

This commit is contained in:
TonyCrane 2022-02-14 21:22:18 +08:00
parent 09ce4717aa
commit be5de32d70
No known key found for this signature in database
GPG key ID: 2313A5058A9C637C
4 changed files with 198 additions and 109 deletions

View file

@ -297,7 +297,7 @@ class Camera(object):
dtype=dtype, dtype=dtype,
) )
def get_image(self) -> Image: def get_image(self) -> Image.Image:
return Image.frombytes( return Image.frombytes(
'RGBA', 'RGBA',
self.get_pixel_shape(), self.get_pixel_shape(),

View file

@ -16,7 +16,7 @@ class EventDispatcher(object):
} }
self.mouse_point = np.array((0., 0., 0.)) self.mouse_point = np.array((0., 0., 0.))
self.mouse_drag_point = np.array((0., 0., 0.)) self.mouse_drag_point = np.array((0., 0., 0.))
self.pressed_keys: set[str] = set() self.pressed_keys: set[int] = set()
self.draggable_object_listners: list[EventListner] = [] self.draggable_object_listners: list[EventListner] = []
def add_listner(self, event_listner: EventListner): def add_listner(self, event_listner: EventListner):
@ -86,7 +86,7 @@ class EventDispatcher(object):
def get_mouse_drag_point(self) -> np.ndarray: def get_mouse_drag_point(self) -> np.ndarray:
return self.mouse_drag_point return self.mouse_drag_point
def is_key_pressed(self, symbol) -> bool: def is_key_pressed(self, symbol: int) -> bool:
return (symbol in self.pressed_keys) return (symbol in self.pressed_keys)
__iadd__ = add_listner __iadd__ = add_listner

View file

@ -1,12 +1,16 @@
import inspect from __future__ import annotations
import time
import random import random
import inspect
import platform import platform
import itertools as it import itertools as it
from functools import wraps from functools import wraps
from typing import Iterable, Callable
from tqdm import tqdm as ProgressDisplay from tqdm import tqdm as ProgressDisplay
import numpy as np import numpy as np
import time import numpy.typing as npt
from manimlib.animation.animation import prepare_animation from manimlib.animation.animation import prepare_animation
from manimlib.animation.transform import MoveToTarget from manimlib.animation.transform import MoveToTarget
@ -22,6 +26,11 @@ from manimlib.event_handler.event_type import EventType
from manimlib.event_handler import EVENT_DISPATCHER from manimlib.event_handler import EVENT_DISPATCHER
from manimlib.logger import log from manimlib.logger import log
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from PIL.Image import Image
from manimlib.animation.animation import Animation
class Scene(object): class Scene(object):
CONFIG = { CONFIG = {
@ -50,13 +59,13 @@ class Scene(object):
else: else:
self.window = None self.window = None
self.camera = self.camera_class(**self.camera_config) self.camera: Camera = self.camera_class(**self.camera_config)
self.file_writer = SceneFileWriter(self, **self.file_writer_config) self.file_writer = SceneFileWriter(self, **self.file_writer_config)
self.mobjects = [self.camera.frame] self.mobjects: list[Mobject] = [self.camera.frame]
self.num_plays = 0 self.num_plays: int = 0
self.time = 0 self.time: float = 0
self.skip_time = 0 self.skip_time: float = 0
self.original_skipping_status = self.skip_animations self.original_skipping_status: bool = self.skip_animations
if self.start_at_animation_number is not None: if self.start_at_animation_number is not None:
self.skip_animations = True self.skip_animations = True
@ -70,9 +79,9 @@ class Scene(object):
random.seed(self.random_seed) random.seed(self.random_seed)
np.random.seed(self.random_seed) np.random.seed(self.random_seed)
def run(self): def run(self) -> None:
self.virtual_animation_start_time = 0 self.virtual_animation_start_time: float = 0
self.real_animation_start_time = time.time() self.real_animation_start_time: float = time.time()
self.file_writer.begin() self.file_writer.begin()
self.setup() self.setup()
@ -82,7 +91,7 @@ class Scene(object):
pass pass
self.tear_down() self.tear_down()
def setup(self): def setup(self) -> None:
""" """
This is meant to be implement by any scenes which This is meant to be implement by any scenes which
are comonly subclassed, and have some common setup are comonly subclassed, and have some common setup
@ -90,18 +99,18 @@ class Scene(object):
""" """
pass pass
def construct(self): def construct(self) -> None:
# Where all the animation happens # Where all the animation happens
# To be implemented in subclasses # To be implemented in subclasses
pass pass
def tear_down(self): def tear_down(self) -> None:
self.stop_skipping() self.stop_skipping()
self.file_writer.finish() self.file_writer.finish()
if self.window and self.linger_after_completion: if self.window and self.linger_after_completion:
self.interact() self.interact()
def interact(self): def interact(self) -> None:
# If there is a window, enter a loop # If there is a window, enter a loop
# which updates the frame while under # which updates the frame while under
# the hood calling the pyglet event loop # the hood calling the pyglet event loop
@ -116,7 +125,7 @@ class Scene(object):
if self.quit_interaction: if self.quit_interaction:
self.unlock_mobject_data() self.unlock_mobject_data()
def embed(self, close_scene_on_exit=True): def embed(self, close_scene_on_exit: bool = True) -> None:
if not self.preview: if not self.preview:
# If the scene is just being # If the scene is just being
# written, ignore embed calls # written, ignore embed calls
@ -145,18 +154,18 @@ class Scene(object):
if close_scene_on_exit: if close_scene_on_exit:
raise EndSceneEarlyException() raise EndSceneEarlyException()
def __str__(self): def __str__(self) -> str:
return self.__class__.__name__ return self.__class__.__name__
# Only these methods should touch the camera # Only these methods should touch the camera
def get_image(self): def get_image(self) -> Image:
return self.camera.get_image() return self.camera.get_image()
def show(self): def show(self) -> None:
self.update_frame(ignore_skipping=True) self.update_frame(ignore_skipping=True)
self.get_image().show() self.get_image().show()
def update_frame(self, dt=0, ignore_skipping=False): def update_frame(self, dt: float = 0, ignore_skipping: bool = False) -> None:
self.increment_time(dt) self.increment_time(dt)
self.update_mobjects(dt) self.update_mobjects(dt)
if self.skip_animations and not ignore_skipping: if self.skip_animations and not ignore_skipping:
@ -174,22 +183,22 @@ class Scene(object):
if rt < vt: if rt < vt:
self.update_frame(0) self.update_frame(0)
def emit_frame(self): def emit_frame(self) -> None:
if not self.skip_animations: if not self.skip_animations:
self.file_writer.write_frame(self.camera) self.file_writer.write_frame(self.camera)
# Related to updating # Related to updating
def update_mobjects(self, dt): def update_mobjects(self, dt: float) -> None:
for mobject in self.mobjects: for mobject in self.mobjects:
mobject.update(dt) mobject.update(dt)
def should_update_mobjects(self): def should_update_mobjects(self) -> bool:
return self.always_update_mobjects or any([ return self.always_update_mobjects or any([
len(mob.get_family_updaters()) > 0 len(mob.get_family_updaters()) > 0
for mob in self.mobjects for mob in self.mobjects
]) ])
def has_time_based_updaters(self): def has_time_based_updaters(self) -> bool:
return any([ return any([
sm.has_time_based_updater() sm.has_time_based_updater()
for mob in self.mobjects() for mob in self.mobjects()
@ -197,14 +206,14 @@ class Scene(object):
]) ])
# Related to time # Related to time
def get_time(self): def get_time(self) -> float:
return self.time return self.time
def increment_time(self, dt): def increment_time(self, dt: float) -> None:
self.time += dt self.time += dt
# Related to internal mobject organization # Related to internal mobject organization
def get_top_level_mobjects(self): def get_top_level_mobjects(self) -> list[Mobject]:
# Return only those which are not in the family # Return only those which are not in the family
# of another mobject from the scene # of another mobject from the scene
mobjects = self.get_mobjects() mobjects = self.get_mobjects()
@ -218,10 +227,10 @@ class Scene(object):
return num_families == 1 return num_families == 1
return list(filter(is_top_level, mobjects)) return list(filter(is_top_level, mobjects))
def get_mobject_family_members(self): def get_mobject_family_members(self) -> list[Mobject]:
return extract_mobject_family_members(self.mobjects) return extract_mobject_family_members(self.mobjects)
def add(self, *new_mobjects): def add(self, *new_mobjects: Mobject):
""" """
Mobjects will be displayed, from background to Mobjects will be displayed, from background to
foreground in the order with which they are added. foreground in the order with which they are added.
@ -230,7 +239,7 @@ class Scene(object):
self.mobjects += new_mobjects self.mobjects += new_mobjects
return self return self
def add_mobjects_among(self, values): def add_mobjects_among(self, values: Iterable):
""" """
This is meant mostly for quick prototyping, This is meant mostly for quick prototyping,
e.g. to add all mobjects defined up to a point, e.g. to add all mobjects defined up to a point,
@ -242,17 +251,17 @@ class Scene(object):
)) ))
return self return self
def remove(self, *mobjects_to_remove): def remove(self, *mobjects_to_remove: Mobject):
self.mobjects = restructure_list_to_exclude_certain_family_members( self.mobjects = restructure_list_to_exclude_certain_family_members(
self.mobjects, mobjects_to_remove self.mobjects, mobjects_to_remove
) )
return self return self
def bring_to_front(self, *mobjects): def bring_to_front(self, *mobjects: Mobject):
self.add(*mobjects) self.add(*mobjects)
return self return self
def bring_to_back(self, *mobjects): def bring_to_back(self, *mobjects: Mobject):
self.remove(*mobjects) self.remove(*mobjects)
self.mobjects = list(mobjects) + self.mobjects self.mobjects = list(mobjects) + self.mobjects
return self return self
@ -261,13 +270,18 @@ class Scene(object):
self.mobjects = [] self.mobjects = []
return self return self
def get_mobjects(self): def get_mobjects(self) -> list[Mobject]:
return list(self.mobjects) return list(self.mobjects)
def get_mobject_copies(self): def get_mobject_copies(self) -> list[Mobject]:
return [m.copy() for m in self.mobjects] return [m.copy() for m in self.mobjects]
def point_to_mobject(self, point, search_set=None, buff=0): def point_to_mobject(
self,
point: np.ndarray,
search_set: Iterable[Mobject] | None = None,
buff: float = 0
) -> Mobject | None:
""" """
E.g. if clicking on the scene, this returns the top layer mobject E.g. if clicking on the scene, this returns the top layer mobject
under a given point under a given point
@ -280,7 +294,7 @@ class Scene(object):
return None return None
# Related to skipping # Related to skipping
def update_skipping_status(self): def update_skipping_status(self) -> None:
if self.start_at_animation_number is not None: if self.start_at_animation_number is not None:
if self.num_plays == self.start_at_animation_number: if self.num_plays == self.start_at_animation_number:
self.skip_time = self.time self.skip_time = self.time
@ -290,12 +304,18 @@ class Scene(object):
if self.num_plays >= self.end_at_animation_number: if self.num_plays >= self.end_at_animation_number:
raise EndSceneEarlyException() raise EndSceneEarlyException()
def stop_skipping(self): def stop_skipping(self) -> None:
self.virtual_animation_start_time = self.time self.virtual_animation_start_time = self.time
self.skip_animations = False self.skip_animations = False
# Methods associated with running animations # Methods associated with running animations
def get_time_progression(self, run_time, n_iterations=None, desc="", override_skip_animations=False): def get_time_progression(
self,
run_time: float,
n_iterations: int | None = None,
desc: str = "",
override_skip_animations: bool = False
) -> list[float] | np.ndarray | ProgressDisplay:
if self.skip_animations and not override_skip_animations: if self.skip_animations and not override_skip_animations:
return [run_time] return [run_time]
else: else:
@ -314,10 +334,13 @@ class Scene(object):
desc=desc, desc=desc,
) )
def get_run_time(self, animations): def get_run_time(self, animations: Iterable[Animation]) -> float:
return np.max([animation.run_time for animation in animations]) return np.max([animation.run_time for animation in animations])
def get_animation_time_progression(self, animations): def get_animation_time_progression(
self,
animations: Iterable[Animation]
) -> list[float] | np.ndarray | ProgressDisplay:
run_time = self.get_run_time(animations) run_time = self.get_run_time(animations)
description = f"{self.num_plays} {animations[0]}" description = f"{self.num_plays} {animations[0]}"
if len(animations) > 1: if len(animations) > 1:
@ -325,14 +348,18 @@ class Scene(object):
time_progression = self.get_time_progression(run_time, desc=description) time_progression = self.get_time_progression(run_time, desc=description)
return time_progression return time_progression
def get_wait_time_progression(self, duration, stop_condition=None): def get_wait_time_progression(
self,
duration: float,
stop_condition: Callable[[], bool] | None = None
) -> list[float] | np.ndarray | ProgressDisplay:
kw = {"desc": f"{self.num_plays} Waiting"} kw = {"desc": f"{self.num_plays} Waiting"}
if stop_condition is not None: if stop_condition is not None:
kw["n_iterations"] = -1 # So it doesn't show % progress kw["n_iterations"] = -1 # So it doesn't show % progress
kw["override_skip_animations"] = True kw["override_skip_animations"] = True
return self.get_time_progression(duration, **kw) return self.get_time_progression(duration, **kw)
def anims_from_play_args(self, *args, **kwargs): def anims_from_play_args(self, *args, **kwargs) -> list[Animation]:
""" """
Each arg can either be an animation, or a mobject method Each arg can either be an animation, or a mobject method
followed by that methods arguments (and potentially follow followed by that methods arguments (and potentially follow
@ -422,7 +449,7 @@ class Scene(object):
self.num_plays += 1 self.num_plays += 1
return wrapper return wrapper
def lock_static_mobject_data(self, *animations): def lock_static_mobject_data(self, *animations: Animation) -> None:
movers = list(it.chain(*[ movers = list(it.chain(*[
anim.mobject.get_family() anim.mobject.get_family()
for anim in animations for anim in animations
@ -432,7 +459,7 @@ class Scene(object):
continue continue
self.camera.set_mobjects_as_static(mobject) self.camera.set_mobjects_as_static(mobject)
def unlock_mobject_data(self): def unlock_mobject_data(self) -> None:
self.camera.release_static_mobjects() self.camera.release_static_mobjects()
def refresh_locked_data(self): def refresh_locked_data(self):
@ -440,7 +467,7 @@ class Scene(object):
self.lock_static_mobject_data() self.lock_static_mobject_data()
return self return self
def begin_animations(self, animations): def begin_animations(self, animations: Iterable[Animation]) -> None:
for animation in animations: for animation in animations:
animation.begin() animation.begin()
# Anything animated that's not already in the # Anything animated that's not already in the
@ -451,7 +478,7 @@ class Scene(object):
if animation.mobject not in self.mobjects: if animation.mobject not in self.mobjects:
self.add(animation.mobject) self.add(animation.mobject)
def progress_through_animations(self, animations): def progress_through_animations(self, animations: Iterable[Animation]) -> None:
last_t = 0 last_t = 0
for t in self.get_animation_time_progression(animations): for t in self.get_animation_time_progression(animations):
dt = t - last_t dt = t - last_t
@ -463,7 +490,7 @@ class Scene(object):
self.update_frame(dt) self.update_frame(dt)
self.emit_frame() self.emit_frame()
def finish_animations(self, animations): def finish_animations(self, animations: Iterable[Animation]) -> None:
for animation in animations: for animation in animations:
animation.finish() animation.finish()
animation.clean_up_from_scene(self) animation.clean_up_from_scene(self)
@ -473,7 +500,7 @@ class Scene(object):
self.update_mobjects(0) self.update_mobjects(0)
@handle_play_like_call @handle_play_like_call
def play(self, *args, **kwargs): def play(self, *args, **kwargs) -> None:
if len(args) == 0: if len(args) == 0:
log.warning("Called Scene.play with no animations") log.warning("Called Scene.play with no animations")
return return
@ -485,11 +512,13 @@ class Scene(object):
self.unlock_mobject_data() self.unlock_mobject_data()
@handle_play_like_call @handle_play_like_call
def wait(self, def wait(
duration=DEFAULT_WAIT_TIME, self,
stop_condition=None, duration: float = DEFAULT_WAIT_TIME,
note=None, stop_condition: Callable[[], bool] = None,
ignore_presenter_mode=False): note: str = None,
ignore_presenter_mode: bool = False
):
if note: if note:
log.info(note) log.info(note)
self.update_mobjects(dt=0) # Any problems with this? self.update_mobjects(dt=0) # Any problems with this?
@ -512,7 +541,11 @@ class Scene(object):
self.unlock_mobject_data() self.unlock_mobject_data()
return self return self
def wait_until(self, stop_condition, max_time=60): def wait_until(
self,
stop_condition: Callable[[], bool],
max_time: float = 60
):
self.wait(max_time, stop_condition=stop_condition) self.wait(max_time, stop_condition=stop_condition)
def force_skipping(self): def force_skipping(self):
@ -525,14 +558,20 @@ class Scene(object):
self.skip_animations = self.original_skipping_status self.skip_animations = self.original_skipping_status
return self return self
def add_sound(self, sound_file, time_offset=0, gain=None, **kwargs): def add_sound(
self,
sound_file: str,
time_offset: float = 0,
gain: float | None = None,
gain_to_background: float | None = None
):
if self.skip_animations: if self.skip_animations:
return return
time = self.get_time() + time_offset time = self.get_time() + time_offset
self.file_writer.add_sound(sound_file, time, gain, **kwargs) self.file_writer.add_sound(sound_file, time, gain, gain_to_background)
# Helpers for interactive development # Helpers for interactive development
def save_state(self): def save_state(self) -> None:
self.saved_state = { self.saved_state = {
"mobjects": self.mobjects, "mobjects": self.mobjects,
"mobject_states": [ "mobject_states": [
@ -541,7 +580,7 @@ class Scene(object):
], ],
} }
def restore(self): def restore(self) -> None:
if not hasattr(self, "saved_state"): if not hasattr(self, "saved_state"):
raise Exception("Trying to restore scene without having saved") raise Exception("Trying to restore scene without having saved")
mobjects = self.saved_state["mobjects"] mobjects = self.saved_state["mobjects"]
@ -552,7 +591,11 @@ class Scene(object):
# Event handling # Event handling
def on_mouse_motion(self, point, d_point): def on_mouse_motion(
self,
point: np.ndarray,
d_point: np.ndarray
) -> None:
self.mouse_point.move_to(point) self.mouse_point.move_to(point)
event_data = {"point": point, "d_point": d_point} event_data = {"point": point, "d_point": d_point}
@ -572,7 +615,13 @@ class Scene(object):
shift = np.dot(np.transpose(transform), shift) shift = np.dot(np.transpose(transform), shift)
frame.shift(shift) frame.shift(shift)
def on_mouse_drag(self, point, d_point, buttons, modifiers): def on_mouse_drag(
self,
point: np.ndarray,
d_point: np.ndarray,
buttons: int,
modifiers: int
) -> None:
self.mouse_drag_point.move_to(point) self.mouse_drag_point.move_to(point)
event_data = {"point": point, "d_point": d_point, "buttons": buttons, "modifiers": modifiers} event_data = {"point": point, "d_point": d_point, "buttons": buttons, "modifiers": modifiers}
@ -580,19 +629,33 @@ class Scene(object):
if propagate_event is not None and propagate_event is False: if propagate_event is not None and propagate_event is False:
return return
def on_mouse_press(self, point, button, mods): def on_mouse_press(
self,
point: np.ndarray,
button: int,
mods: int
) -> None:
event_data = {"point": point, "button": button, "mods": mods} event_data = {"point": point, "button": button, "mods": mods}
propagate_event = EVENT_DISPATCHER.dispatch(EventType.MousePressEvent, **event_data) propagate_event = EVENT_DISPATCHER.dispatch(EventType.MousePressEvent, **event_data)
if propagate_event is not None and propagate_event is False: if propagate_event is not None and propagate_event is False:
return return
def on_mouse_release(self, point, button, mods): def on_mouse_release(
self,
point: np.ndarray,
button: int,
mods: int
) -> None:
event_data = {"point": point, "button": button, "mods": mods} event_data = {"point": point, "button": button, "mods": mods}
propagate_event = EVENT_DISPATCHER.dispatch(EventType.MouseReleaseEvent, **event_data) propagate_event = EVENT_DISPATCHER.dispatch(EventType.MouseReleaseEvent, **event_data)
if propagate_event is not None and propagate_event is False: if propagate_event is not None and propagate_event is False:
return return
def on_mouse_scroll(self, point, offset): def on_mouse_scroll(
self,
point: np.ndarray,
offset: np.ndarray
) -> None:
event_data = {"point": point, "offset": offset} event_data = {"point": point, "offset": offset}
propagate_event = EVENT_DISPATCHER.dispatch(EventType.MouseScrollEvent, **event_data) propagate_event = EVENT_DISPATCHER.dispatch(EventType.MouseScrollEvent, **event_data)
if propagate_event is not None and propagate_event is False: if propagate_event is not None and propagate_event is False:
@ -607,13 +670,21 @@ class Scene(object):
shift = np.dot(np.transpose(transform), offset) shift = np.dot(np.transpose(transform), offset)
frame.shift(-20.0 * shift) frame.shift(-20.0 * shift)
def on_key_release(self, symbol, modifiers): def on_key_release(
self,
symbol: int,
modifiers: int
) -> None:
event_data = {"symbol": symbol, "modifiers": modifiers} event_data = {"symbol": symbol, "modifiers": modifiers}
propagate_event = EVENT_DISPATCHER.dispatch(EventType.KeyReleaseEvent, **event_data) propagate_event = EVENT_DISPATCHER.dispatch(EventType.KeyReleaseEvent, **event_data)
if propagate_event is not None and propagate_event is False: if propagate_event is not None and propagate_event is False:
return return
def on_key_press(self, symbol, modifiers): def on_key_press(
self,
symbol: int,
modifiers: int
) -> None:
try: try:
char = chr(symbol) char = chr(symbol)
except OverflowError: except OverflowError:
@ -634,16 +705,16 @@ class Scene(object):
elif char == "e": elif char == "e":
self.embed(close_scene_on_exit=False) self.embed(close_scene_on_exit=False)
def on_resize(self, width: int, height: int): def on_resize(self, width: int, height: int) -> None:
self.camera.reset_pixel_shape(width, height) self.camera.reset_pixel_shape(width, height)
def on_show(self): def on_show(self) -> None:
pass pass
def on_hide(self): def on_hide(self) -> None:
pass pass
def on_close(self): def on_close(self) -> None:
pass pass

View file

@ -1,10 +1,13 @@
import numpy as np from __future__ import annotations
from pydub import AudioSegment
import shutil
import subprocess as sp
import os import os
import sys import sys
import shutil
import platform import platform
import subprocess as sp
import numpy as np
from pydub import AudioSegment
from tqdm import tqdm as ProgressDisplay from tqdm import tqdm as ProgressDisplay
from manimlib.constants import FFMPEG_BIN from manimlib.constants import FFMPEG_BIN
@ -15,6 +18,12 @@ from manimlib.utils.file_ops import get_sorted_integer_files
from manimlib.utils.sounds import get_full_sound_file_path from manimlib.utils.sounds import get_full_sound_file_path
from manimlib.logger import log from manimlib.logger import log
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.scene.scene import Scene
from manimlib.camera.camera import Camera
from PIL.Image import Image
class SceneFileWriter(object): class SceneFileWriter(object):
CONFIG = { CONFIG = {
@ -42,14 +51,14 @@ class SceneFileWriter(object):
def __init__(self, scene, **kwargs): def __init__(self, scene, **kwargs):
digest_config(self, kwargs) digest_config(self, kwargs)
self.scene = scene self.scene: Scene = scene
self.writing_process = None self.writing_process: sp.Popen | None = None
self.has_progress_display = False self.has_progress_display: bool = False
self.init_output_directories() self.init_output_directories()
self.init_audio() self.init_audio()
# Output directories and files # Output directories and files
def init_output_directories(self): def init_output_directories(self) -> None:
out_dir = self.output_directory out_dir = self.output_directory
if self.mirror_module_path: if self.mirror_module_path:
module_dir = self.get_default_module_directory() module_dir = self.get_default_module_directory()
@ -69,13 +78,13 @@ class SceneFileWriter(object):
movie_dir, "partial_movie_files", scene_name, movie_dir, "partial_movie_files", scene_name,
)) ))
def get_default_module_directory(self): def get_default_module_directory(self) -> str:
path, _ = os.path.splitext(self.input_file_path) path, _ = os.path.splitext(self.input_file_path)
if path.startswith("_"): if path.startswith("_"):
path = path[1:] path = path[1:]
return path return path
def get_default_scene_name(self): def get_default_scene_name(self) -> str:
name = str(self.scene) name = str(self.scene)
saan = self.scene.start_at_animation_number saan = self.scene.start_at_animation_number
eaan = self.scene.end_at_animation_number eaan = self.scene.end_at_animation_number
@ -85,7 +94,7 @@ class SceneFileWriter(object):
name += f"_{eaan}" name += f"_{eaan}"
return name return name
def get_resolution_directory(self): def get_resolution_directory(self) -> str:
pixel_height = self.scene.camera.pixel_height pixel_height = self.scene.camera.pixel_height
frame_rate = self.scene.camera.frame_rate frame_rate = self.scene.camera.frame_rate
return "{}p{}".format( return "{}p{}".format(
@ -93,10 +102,10 @@ class SceneFileWriter(object):
) )
# Directory getters # Directory getters
def get_image_file_path(self): def get_image_file_path(self) -> str:
return self.image_file_path return self.image_file_path
def get_next_partial_movie_path(self): def get_next_partial_movie_path(self) -> str:
result = os.path.join( result = os.path.join(
self.partial_movie_directory, self.partial_movie_directory,
"{:05}{}".format( "{:05}{}".format(
@ -106,19 +115,22 @@ class SceneFileWriter(object):
) )
return result return result
def get_movie_file_path(self): def get_movie_file_path(self) -> str:
return self.movie_file_path return self.movie_file_path
# Sound # Sound
def init_audio(self): def init_audio(self) -> None:
self.includes_sound = False self.includes_sound: bool = False
def create_audio_segment(self): def create_audio_segment(self) -> None:
self.audio_segment = AudioSegment.silent() self.audio_segment = AudioSegment.silent()
def add_audio_segment(self, new_segment, def add_audio_segment(
time=None, self,
gain_to_background=None): new_segment: AudioSegment,
time: float | None = None,
gain_to_background: float | None = None
) -> None:
if not self.includes_sound: if not self.includes_sound:
self.includes_sound = True self.includes_sound = True
self.create_audio_segment() self.create_audio_segment()
@ -142,27 +154,33 @@ class SceneFileWriter(object):
gain_during_overlay=gain_to_background, gain_during_overlay=gain_to_background,
) )
def add_sound(self, sound_file, time=None, gain=None, **kwargs): def add_sound(
self,
sound_file: str,
time: float | None = None,
gain: float | None = None,
gain_to_background: float | None = None
) -> None:
file_path = get_full_sound_file_path(sound_file) file_path = get_full_sound_file_path(sound_file)
new_segment = AudioSegment.from_file(file_path) new_segment = AudioSegment.from_file(file_path)
if gain: if gain:
new_segment = new_segment.apply_gain(gain) new_segment = new_segment.apply_gain(gain)
self.add_audio_segment(new_segment, time, **kwargs) self.add_audio_segment(new_segment, time, gain_to_background)
# Writers # Writers
def begin(self): def begin(self) -> None:
if not self.break_into_partial_movies and self.write_to_movie: if not self.break_into_partial_movies and self.write_to_movie:
self.open_movie_pipe(self.get_movie_file_path()) self.open_movie_pipe(self.get_movie_file_path())
def begin_animation(self): def begin_animation(self) -> None:
if self.break_into_partial_movies and self.write_to_movie: if self.break_into_partial_movies and self.write_to_movie:
self.open_movie_pipe(self.get_next_partial_movie_path()) self.open_movie_pipe(self.get_next_partial_movie_path())
def end_animation(self): def end_animation(self) -> None:
if self.break_into_partial_movies and self.write_to_movie: if self.break_into_partial_movies and self.write_to_movie:
self.close_movie_pipe() self.close_movie_pipe()
def finish(self): def finish(self) -> None:
if self.write_to_movie: if self.write_to_movie:
if self.break_into_partial_movies: if self.break_into_partial_movies:
self.combine_movie_files() self.combine_movie_files()
@ -177,7 +195,7 @@ class SceneFileWriter(object):
if self.should_open_file(): if self.should_open_file():
self.open_file() self.open_file()
def open_movie_pipe(self, file_path): def open_movie_pipe(self, file_path: str) -> None:
stem, ext = os.path.splitext(file_path) stem, ext = os.path.splitext(file_path)
self.final_file_path = file_path self.final_file_path = file_path
self.temp_file_path = stem + "_temp" + ext self.temp_file_path = stem + "_temp" + ext
@ -223,7 +241,7 @@ class SceneFileWriter(object):
) )
self.has_progress_display = True self.has_progress_display = True
def set_progress_display_subdescription(self, sub_desc): def set_progress_display_subdescription(self, sub_desc: str) -> None:
desc_len = self.progress_description_len desc_len = self.progress_description_len
file = os.path.split(self.get_movie_file_path())[1] file = os.path.split(self.get_movie_file_path())[1]
full_desc = f"Rendering {file} ({sub_desc})" full_desc = f"Rendering {file} ({sub_desc})"
@ -233,14 +251,14 @@ class SceneFileWriter(object):
full_desc += " " * (desc_len - len(full_desc)) full_desc += " " * (desc_len - len(full_desc))
self.progress_display.set_description(full_desc) self.progress_display.set_description(full_desc)
def write_frame(self, camera): def write_frame(self, camera: Camera) -> None:
if self.write_to_movie: if self.write_to_movie:
raw_bytes = camera.get_raw_fbo_data() raw_bytes = camera.get_raw_fbo_data()
self.writing_process.stdin.write(raw_bytes) self.writing_process.stdin.write(raw_bytes)
if self.has_progress_display: if self.has_progress_display:
self.progress_display.update() self.progress_display.update()
def close_movie_pipe(self): def close_movie_pipe(self) -> None:
self.writing_process.stdin.close() self.writing_process.stdin.close()
self.writing_process.wait() self.writing_process.wait()
self.writing_process.terminate() self.writing_process.terminate()
@ -248,7 +266,7 @@ class SceneFileWriter(object):
self.progress_display.close() self.progress_display.close()
shutil.move(self.temp_file_path, self.final_file_path) shutil.move(self.temp_file_path, self.final_file_path)
def combine_movie_files(self): def combine_movie_files(self) -> None:
kwargs = { kwargs = {
"remove_non_integer_files": True, "remove_non_integer_files": True,
"extension": self.movie_file_extension, "extension": self.movie_file_extension,
@ -296,7 +314,7 @@ class SceneFileWriter(object):
combine_process = sp.Popen(commands) combine_process = sp.Popen(commands)
combine_process.wait() combine_process.wait()
def add_sound_to_video(self): def add_sound_to_video(self) -> None:
movie_file_path = self.get_movie_file_path() movie_file_path = self.get_movie_file_path()
stem, ext = os.path.splitext(movie_file_path) stem, ext = os.path.splitext(movie_file_path)
sound_file_path = stem + ".wav" sound_file_path = stem + ".wav"
@ -327,22 +345,22 @@ class SceneFileWriter(object):
shutil.move(temp_file_path, movie_file_path) shutil.move(temp_file_path, movie_file_path)
os.remove(sound_file_path) os.remove(sound_file_path)
def save_final_image(self, image): def save_final_image(self, image: Image) -> None:
file_path = self.get_image_file_path() file_path = self.get_image_file_path()
image.save(file_path) image.save(file_path)
self.print_file_ready_message(file_path) self.print_file_ready_message(file_path)
def print_file_ready_message(self, file_path): def print_file_ready_message(self, file_path: str) -> None:
if not self.quiet: if not self.quiet:
log.info(f"File ready at {file_path}") log.info(f"File ready at {file_path}")
def should_open_file(self): def should_open_file(self) -> bool:
return any([ return any([
self.show_file_location_upon_completion, self.show_file_location_upon_completion,
self.open_file_upon_completion, self.open_file_upon_completion,
]) ])
def open_file(self): def open_file(self) -> None:
if self.quiet: if self.quiet:
curr_stdout = sys.stdout curr_stdout = sys.stdout
sys.stdout = open(os.devnull, "w") sys.stdout = open(os.devnull, "w")