From f307c2a2988226c954f660f3b5325f8c742f550d Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Tue, 12 Apr 2022 11:13:05 +0800 Subject: [PATCH 01/64] Add type annotations for color.py --- manimlib/utils/color.py | 80 +++++++++++++++++++++-------------------- 1 file changed, 41 insertions(+), 39 deletions(-) diff --git a/manimlib/utils/color.py b/manimlib/utils/color.py index 93cc9577..5ae235fd 100644 --- a/manimlib/utils/color.py +++ b/manimlib/utils/color.py @@ -1,6 +1,10 @@ -import random +from __future__ import annotations + +from typing import Iterable, Union from colour import Color +from colour import hex2rgb +from colour import rgb2hex import numpy as np from manimlib.constants import WHITE @@ -8,8 +12,10 @@ from manimlib.constants import COLORMAP_3B1B from manimlib.utils.bezier import interpolate from manimlib.utils.iterables import resize_with_interpolation +ManimColor = Union[str, Color] -def color_to_rgb(color): + +def color_to_rgb(color: ManimColor) -> np.ndarray: if isinstance(color, str): return hex_to_rgb(color) elif isinstance(color, Color): @@ -18,55 +24,48 @@ def color_to_rgb(color): raise Exception("Invalid color type") -def color_to_rgba(color, alpha=1): +def color_to_rgba(color: ManimColor, alpha: float = 1.0) -> np.ndarray: return np.array([*color_to_rgb(color), alpha]) -def rgb_to_color(rgb): +def rgb_to_color(rgb: Iterable[float]) -> Color: try: - return Color(rgb=rgb) + return Color(rgb=tuple(rgb)) except ValueError: return Color(WHITE) -def rgba_to_color(rgba): - return rgb_to_color(rgba[:3]) +def rgba_to_color(rgba: Iterable[float]) -> Color: + return rgb_to_color(tuple(rgba)[:3]) -def rgb_to_hex(rgb): - return "#" + "".join( - hex(int_x // 16)[2] + hex(int_x % 16)[2] - for x in rgb - for int_x in [int(255 * x)] - ) +def rgb_to_hex(rgb: Iterable[float]) -> str: + return rgb2hex(rgb, force_long=True).upper() -def hex_to_rgb(hex_code): - hex_part = hex_code[1:] - if len(hex_part) == 3: - hex_part = "".join([2 * c for c in hex_part]) - return np.array([ - int(hex_part[i:i + 2], 16) / 255 - for i in range(0, 6, 2) - ]) +def hex_to_rgb(hex_code: str) -> tuple[float]: + return hex2rgb(hex_code) -def invert_color(color): +def invert_color(color: ManimColor) -> Color: return rgb_to_color(1.0 - color_to_rgb(color)) -def color_to_int_rgb(color): +def color_to_int_rgb(color: ManimColor) -> np.ndarray: return (255 * color_to_rgb(color)).astype('uint8') -def color_to_int_rgba(color, opacity=1.0): +def color_to_int_rgba(color: ManimColor, opacity: float = 1.0) -> np.ndarray: alpha = int(255 * opacity) return np.array([*color_to_int_rgb(color), alpha]) -def color_gradient(reference_colors, length_of_output): +def color_gradient( + reference_colors: Iterable[ManimColor], + length_of_output: int +) -> list[Color]: if length_of_output == 0: - return reference_colors[0] + return [] rgbs = list(map(color_to_rgb, reference_colors)) alphas = np.linspace(0, (len(rgbs) - 1), length_of_output) floors = alphas.astype('int') @@ -80,30 +79,33 @@ def color_gradient(reference_colors, length_of_output): ] -def interpolate_color(color1, color2, alpha): +def interpolate_color( + color1: ManimColor, + color2: ManimColor, + alpha: float +) -> Color: rgb = interpolate(color_to_rgb(color1), color_to_rgb(color2), alpha) return rgb_to_color(rgb) -def average_color(*colors): +def average_color(*colors: ManimColor) -> Color: rgbs = np.array(list(map(color_to_rgb, colors))) return rgb_to_color(rgbs.mean(0)) -def random_bright_color(): +def random_color() -> Color: + return Color(rgb=tuple(np.random.random(3))) + + +def random_bright_color() -> Color: color = random_color() - curr_rgb = color_to_rgb(color) - new_rgb = interpolate( - curr_rgb, np.ones(len(curr_rgb)), 0.5 - ) - return Color(rgb=new_rgb) + return average_color(color, Color(WHITE)) -def random_color(): - return Color(rgb=(random.random() for i in range(3))) - - -def get_colormap_list(map_name="viridis", n_colors=9): +def get_colormap_list( + map_name: str = "viridis", + n_colors: int = 9 +) -> np.ndarray: """ Options for map_name: 3b1b_colormap From 0cf31995789e8bfe3722e0183e2cf55ab92457e5 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Tue, 12 Apr 2022 11:26:19 +0800 Subject: [PATCH 02/64] Adjust return type --- manimlib/utils/color.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/manimlib/utils/color.py b/manimlib/utils/color.py index 5ae235fd..1687d5b3 100644 --- a/manimlib/utils/color.py +++ b/manimlib/utils/color.py @@ -43,8 +43,8 @@ def rgb_to_hex(rgb: Iterable[float]) -> str: return rgb2hex(rgb, force_long=True).upper() -def hex_to_rgb(hex_code: str) -> tuple[float]: - return hex2rgb(hex_code) +def hex_to_rgb(hex_code: str) -> np.ndarray: + return np.array(hex2rgb(hex_code)) def invert_color(color: ManimColor) -> Color: From 9ef9961d0e56137d3a8738e7dfbd61a3ae8b6b93 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Tue, 12 Apr 2022 19:19:59 +0800 Subject: [PATCH 03/64] Sort imports --- manimlib/__init__.py | 12 +-- manimlib/__main__.py | 8 +- manimlib/animation/animation.py | 3 +- manimlib/animation/composition.py | 8 +- manimlib/animation/creation.py | 10 +-- manimlib/animation/fading.py | 4 +- manimlib/animation/growing.py | 5 +- manimlib/animation/indication.py | 27 ++++--- manimlib/animation/movement.py | 5 +- manimlib/animation/numbers.py | 7 +- manimlib/animation/rotation.py | 7 +- manimlib/animation/specialized.py | 10 ++- manimlib/animation/transform.py | 13 ++-- .../animation/transform_matching_parts.py | 6 +- manimlib/animation/update.py | 10 +-- manimlib/camera/camera.py | 25 ++++--- manimlib/config.py | 8 +- manimlib/constants.py | 1 + manimlib/event_handler/event_dispatcher.py | 2 +- manimlib/event_handler/event_listner.py | 7 +- manimlib/extract_scene.py | 4 +- manimlib/logger.py | 1 + manimlib/mobject/changing.py | 15 ++-- manimlib/mobject/coordinate_systems.py | 21 ++++-- manimlib/mobject/frame.py | 3 +- manimlib/mobject/functions.py | 10 ++- manimlib/mobject/geometry.py | 25 ++++--- manimlib/mobject/interactive.py | 32 +++++--- manimlib/mobject/matrix.py | 15 ++-- manimlib/mobject/mobject.py | 68 ++++++++++------- manimlib/mobject/mobject_update_utils.py | 4 +- manimlib/mobject/number_line.py | 11 ++- manimlib/mobject/numbers.py | 11 ++- manimlib/mobject/probability.py | 16 ++-- manimlib/mobject/shape_matchers.py | 15 ++-- manimlib/mobject/svg/brace.py | 33 +++++---- manimlib/mobject/svg/drawings.py | 2 +- manimlib/mobject/svg/labelled_string.py | 13 ++-- manimlib/mobject/svg/mtex_mobject.py | 15 ++-- manimlib/mobject/svg/svg_mobject.py | 9 +-- manimlib/mobject/svg/tex_mobject.py | 19 +++-- manimlib/mobject/svg/text_mobject.py | 34 +++++---- manimlib/mobject/three_dimensions.py | 10 ++- manimlib/mobject/types/dot_cloud.py | 11 ++- manimlib/mobject/types/image_mobject.py | 2 +- manimlib/mobject/types/point_cloud_mobject.py | 19 +++-- manimlib/mobject/types/surface.py | 10 ++- manimlib/mobject/types/vectorized_mobject.py | 73 ++++++++++++------- manimlib/mobject/vector_field.py | 19 +++-- manimlib/once_useful_constructs/fractals.py | 1 + manimlib/scene/sample_space_scene.py | 5 +- manimlib/scene/scene.py | 23 +++--- manimlib/scene/scene_file_writer.py | 13 ++-- manimlib/scene/vector_space_scene.py | 6 +- manimlib/shader_wrapper.py | 9 ++- manimlib/utils/bezier.py | 25 ++++--- manimlib/utils/color.py | 11 ++- manimlib/utils/customization.py | 1 + manimlib/utils/debug.py | 7 +- manimlib/utils/directories.py | 2 +- manimlib/utils/family_ops.py | 3 +- manimlib/utils/file_ops.py | 6 +- manimlib/utils/images.py | 2 +- manimlib/utils/init_config.py | 15 ++-- manimlib/utils/iterables.py | 11 ++- manimlib/utils/paths.py | 1 + manimlib/utils/rate_functions.py | 3 +- manimlib/utils/simple_functions.py | 7 +- manimlib/utils/sounds.py | 2 +- manimlib/utils/space_ops.py | 20 ++--- manimlib/utils/tex_file_writing.py | 10 +-- manimlib/window.py | 1 + 72 files changed, 527 insertions(+), 355 deletions(-) diff --git a/manimlib/__init__.py b/manimlib/__init__.py index a0147cf7..40d396ed 100644 --- a/manimlib/__init__.py +++ b/manimlib/__init__.py @@ -20,17 +20,16 @@ from manimlib.animation.update import * from manimlib.camera.camera import * -from manimlib.window import * - from manimlib.mobject.boolean_ops import * -from manimlib.mobject.coordinate_systems import * from manimlib.mobject.changing import * +from manimlib.mobject.coordinate_systems import * from manimlib.mobject.frame import * from manimlib.mobject.functions import * from manimlib.mobject.geometry import * from manimlib.mobject.interactive import * from manimlib.mobject.matrix import * from manimlib.mobject.mobject import * +from manimlib.mobject.mobject_update_utils import * from manimlib.mobject.number_line import * from manimlib.mobject.numbers import * from manimlib.mobject.probability import * @@ -43,12 +42,11 @@ from manimlib.mobject.svg.svg_mobject import * from manimlib.mobject.svg.tex_mobject import * from manimlib.mobject.svg.text_mobject import * from manimlib.mobject.three_dimensions import * +from manimlib.mobject.types.dot_cloud import * from manimlib.mobject.types.image_mobject import * from manimlib.mobject.types.point_cloud_mobject import * from manimlib.mobject.types.surface import * from manimlib.mobject.types.vectorized_mobject import * -from manimlib.mobject.types.dot_cloud import * -from manimlib.mobject.mobject_update_utils import * from manimlib.mobject.value_tracker import * from manimlib.mobject.vector_field import * @@ -61,11 +59,13 @@ from manimlib.utils.config_ops import * from manimlib.utils.customization import * from manimlib.utils.debug import * from manimlib.utils.directories import * +from manimlib.utils.file_ops import * from manimlib.utils.images import * from manimlib.utils.iterables import * -from manimlib.utils.file_ops import * from manimlib.utils.paths import * from manimlib.utils.rate_functions import * from manimlib.utils.simple_functions import * from manimlib.utils.sounds import * from manimlib.utils.space_ops import * + +from manimlib.window import * diff --git a/manimlib/__main__.py b/manimlib/__main__.py index d6af540d..d5dddc7e 100644 --- a/manimlib/__main__.py +++ b/manimlib/__main__.py @@ -1,9 +1,9 @@ #!/usr/bin/env python -import manimlib.config -import manimlib.logger -import manimlib.extract_scene -import manimlib.utils.init_config from manimlib import __version__ +import manimlib.config +import manimlib.extract_scene +import manimlib.logger +import manimlib.utils.init_config def main(): diff --git a/manimlib/animation/animation.py b/manimlib/animation/animation.py index 8ec26de2..bce11513 100644 --- a/manimlib/animation/animation.py +++ b/manimlib/animation/animation.py @@ -1,7 +1,6 @@ from __future__ import annotations from copy import deepcopy -from typing import Callable from manimlib.mobject.mobject import _AnimationBuilder from manimlib.mobject.mobject import Mobject @@ -12,6 +11,8 @@ from manimlib.utils.simple_functions import clip from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import Callable + from manimlib.scene.scene import Scene diff --git a/manimlib/animation/composition.py b/manimlib/animation/composition.py index f282bc9c..c9758af1 100644 --- a/manimlib/animation/composition.py +++ b/manimlib/animation/composition.py @@ -1,9 +1,9 @@ from __future__ import annotations import numpy as np -from typing import Callable -from manimlib.animation.animation import Animation, prepare_animation +from manimlib.animation.animation import Animation +from manimlib.animation.animation import prepare_animation from manimlib.mobject.mobject import Group from manimlib.utils.bezier import integer_interpolate from manimlib.utils.bezier import interpolate @@ -15,8 +15,10 @@ from manimlib.utils.simple_functions import clip from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.scene.scene import Scene + from typing import Callable + from manimlib.mobject.mobject import Mobject + from manimlib.scene.scene import Scene DEFAULT_LAGGED_START_LAG_RATIO = 0.05 diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 6499d0af..5882df8f 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -1,12 +1,10 @@ from __future__ import annotations -import itertools as it -from abc import abstractmethod +from abc import ABC, abstractmethod import numpy as np from manimlib.animation.animation import Animation -from manimlib.animation.composition import Succession from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.bezier import integer_interpolate @@ -18,10 +16,10 @@ from manimlib.utils.rate_functions import smooth from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.mobject import Group + from manimlib.mobject.mobject import Mobject -class ShowPartial(Animation): +class ShowPartial(Animation, ABC): """ Abstract class for ShowCreation and ShowPassingFlash """ @@ -177,7 +175,7 @@ class ShowIncreasingSubsets(Animation): "int_func": np.round, } - def __init__(self, group: Group, **kwargs): + def __init__(self, group: Mobject, **kwargs): self.all_submobs = list(group.submobjects) super().__init__(group, **kwargs) diff --git a/manimlib/animation/fading.py b/manimlib/animation/fading.py index 39c149f0..2fc63d43 100644 --- a/manimlib/animation/fading.py +++ b/manimlib/animation/fading.py @@ -4,17 +4,17 @@ import numpy as np from manimlib.animation.animation import Animation from manimlib.animation.transform import Transform -from manimlib.mobject.mobject import Group from manimlib.constants import ORIGIN +from manimlib.mobject.mobject import Group from manimlib.utils.bezier import interpolate from manimlib.utils.rate_functions import there_and_back from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.scene.scene import Scene from manimlib.mobject.mobject import Mobject from manimlib.mobject.types.vectorized_mobject import VMobject + from manimlib.scene.scene import Scene DEFAULT_FADE_LAG_RATIO = 0 diff --git a/manimlib/animation/growing.py b/manimlib/animation/growing.py index 1b3c3cd7..b29982f2 100644 --- a/manimlib/animation/growing.py +++ b/manimlib/animation/growing.py @@ -1,14 +1,15 @@ from __future__ import annotations -from manimlib.constants import PI from manimlib.animation.transform import Transform +from manimlib.constants import PI from typing import TYPE_CHECKING if TYPE_CHECKING: import numpy as np - from manimlib.mobject.mobject import Mobject + from manimlib.mobject.geometry import Arrow + from manimlib.mobject.mobject import Mobject class GrowFromPoint(Transform): diff --git a/manimlib/animation/indication.py b/manimlib/animation/indication.py index 5f210773..92c7d013 100644 --- a/manimlib/animation/indication.py +++ b/manimlib/animation/indication.py @@ -1,40 +1,43 @@ from __future__ import annotations -import math -from typing import Union, Sequence - import numpy as np -from manimlib.constants import * from manimlib.animation.animation import Animation -from manimlib.animation.movement import Homotopy from manimlib.animation.composition import AnimationGroup from manimlib.animation.composition import Succession from manimlib.animation.creation import ShowCreation from manimlib.animation.creation import ShowPartial from manimlib.animation.fading import FadeOut from manimlib.animation.fading import FadeIn +from manimlib.animation.movement import Homotopy from manimlib.animation.transform import Transform -from manimlib.mobject.types.vectorized_mobject import VMobject +from manimlib.constants import ORIGIN, RIGHT, UP +from manimlib.constants import SMALL_BUFF +from manimlib.constants import TAU +from manimlib.constants import GREY, YELLOW from manimlib.mobject.geometry import Circle from manimlib.mobject.geometry import Dot +from manimlib.mobject.geometry import Line from manimlib.mobject.shape_matchers import SurroundingRectangle from manimlib.mobject.shape_matchers import Underline +from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VGroup -from manimlib.mobject.geometry import Line from manimlib.utils.bezier import interpolate from manimlib.utils.config_ops import digest_config -from manimlib.utils.rate_functions import there_and_back -from manimlib.utils.rate_functions import wiggle from manimlib.utils.rate_functions import smooth from manimlib.utils.rate_functions import squish_rate_func +from manimlib.utils.rate_functions import there_and_back +from manimlib.utils.rate_functions import wiggle from typing import TYPE_CHECKING if TYPE_CHECKING: - import colour + from colour import Color + from typing import Union + from manimlib.mobject.mobject import Mobject - ManimColor = Union[str, colour.Color, Sequence[float]] + + ManimColor = Union[str, Color] class FocusOn(Transform): @@ -217,7 +220,7 @@ class VShowPassingFlash(Animation): if abs(x - mu) > 3 * sigma: return 0 z = (x - mu) / sigma - return math.exp(-0.5 * z * z) + return np.exp(-0.5 * z * z) kernel_array = list(map(gauss_kernel, np.linspace(0, 1, len(anchor_widths)))) scaled_widths = anchor_widths * kernel_array diff --git a/manimlib/animation/movement.py b/manimlib/animation/movement.py index 78cbbee8..0edf38b1 100644 --- a/manimlib/animation/movement.py +++ b/manimlib/animation/movement.py @@ -1,14 +1,15 @@ from __future__ import annotations -from typing import Callable, Sequence - from manimlib.animation.animation import Animation from manimlib.utils.rate_functions import linear from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import Callable, Sequence + import numpy as np + from manimlib.mobject.mobject import Mobject diff --git a/manimlib/animation/numbers.py b/manimlib/animation/numbers.py index 0a992b39..5b6e9223 100644 --- a/manimlib/animation/numbers.py +++ b/manimlib/animation/numbers.py @@ -1,11 +1,14 @@ from __future__ import annotations -from typing import Callable - from manimlib.animation.animation import Animation from manimlib.mobject.numbers import DecimalNumber from manimlib.utils.bezier import interpolate +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Callable + class ChangingDecimal(Animation): CONFIG = { diff --git a/manimlib/animation/rotation.py b/manimlib/animation/rotation.py index 7993c3cf..058d9066 100644 --- a/manimlib/animation/rotation.py +++ b/manimlib/animation/rotation.py @@ -1,10 +1,8 @@ from __future__ import annotations from manimlib.animation.animation import Animation -from manimlib.constants import OUT -from manimlib.constants import PI -from manimlib.constants import TAU -from manimlib.constants import ORIGIN +from manimlib.constants import ORIGIN, OUT +from manimlib.constants import PI, TAU from manimlib.utils.rate_functions import linear from manimlib.utils.rate_functions import smooth @@ -12,6 +10,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: import numpy as np + from manimlib.mobject.mobject import Mobject diff --git a/manimlib/animation/specialized.py b/manimlib/animation/specialized.py index 376e37e2..26f26868 100644 --- a/manimlib/animation/specialized.py +++ b/manimlib/animation/specialized.py @@ -1,15 +1,17 @@ from __future__ import annotations -import numpy as np - from manimlib.animation.composition import LaggedStart from manimlib.animation.transform import Restore -from manimlib.constants import WHITE -from manimlib.constants import BLACK +from manimlib.constants import BLACK, WHITE from manimlib.mobject.geometry import Circle from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.config_ops import digest_config +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import numpy as np + class Broadcast(LaggedStart): CONFIG = { diff --git a/manimlib/animation/transform.py b/manimlib/animation/transform.py index a426b21e..37307a0c 100644 --- a/manimlib/animation/transform.py +++ b/manimlib/animation/transform.py @@ -1,15 +1,13 @@ from __future__ import annotations import inspect -from typing import Callable, Union, Sequence import numpy as np -import numpy.typing as npt from manimlib.animation.animation import Animation from manimlib.constants import DEFAULT_POINTWISE_FUNCTION_RUN_TIME -from manimlib.constants import OUT from manimlib.constants import DEGREES +from manimlib.constants import OUT from manimlib.mobject.mobject import Group from manimlib.mobject.mobject import Mobject from manimlib.utils.config_ops import digest_config @@ -21,9 +19,14 @@ from manimlib.utils.rate_functions import squish_rate_func from typing import TYPE_CHECKING if TYPE_CHECKING: - import colour + from colour import Color + from typing import Callable, Union + + import numpy.typing as npt + from manimlib.scene.scene import Scene - ManimColor = Union[str, colour.Color, Sequence[float]] + + ManimColor = Union[str, Color] class Transform(Animation): diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index dab88005..e3c65a49 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -5,9 +5,9 @@ import itertools as it import numpy as np from manimlib.animation.composition import AnimationGroup -from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.fading import FadeInFromPoint from manimlib.animation.fading import FadeOutToPoint +from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject @@ -16,13 +16,13 @@ from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config -from manimlib.utils.iterables import remove_list_redundancies from typing import TYPE_CHECKING if TYPE_CHECKING: + from manimlib.mobject.svg.tex_mobject import SingleStringTex + from manimlib.mobject.svg.tex_mobject import Tex from manimlib.scene.scene import Scene - from manimlib.mobject.svg.tex_mobject import Tex, SingleStringTex class TransformMatchingParts(AnimationGroup): diff --git a/manimlib/animation/update.py b/manimlib/animation/update.py index 43fafa42..2a929584 100644 --- a/manimlib/animation/update.py +++ b/manimlib/animation/update.py @@ -1,13 +1,12 @@ from __future__ import annotations -import operator as op -from typing import Callable - from manimlib.animation.animation import Animation from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import Callable + from manimlib.mobject.mobject import Mobject @@ -47,10 +46,7 @@ class MaintainPositionRelativeTo(Animation): **kwargs ): self.tracked_mobject = tracked_mobject - self.diff = op.sub( - mobject.get_center(), - tracked_mobject.get_center(), - ) + self.diff = mobject.get_center() - tracked_mobject.get_center() super().__init__(mobject, **kwargs) def interpolate_mobject(self, alpha: float) -> None: diff --git a/manimlib/camera/camera.py b/manimlib/camera/camera.py index 40037a3d..79ce1366 100644 --- a/manimlib/camera/camera.py +++ b/manimlib/camera/camera.py @@ -1,19 +1,23 @@ from __future__ import annotations -import moderngl -from colour import Color -import OpenGL.GL as gl +import itertools as it import math -import itertools as it - +import moderngl import numpy as np -from scipy.spatial.transform import Rotation +import OpenGL.GL as gl from PIL import Image +from scipy.spatial.transform import Rotation -from manimlib.constants import * +from manimlib.constants import BLACK +from manimlib.constants import DEGREES, RADIANS +from manimlib.constants import DEFAULT_FRAME_RATE +from manimlib.constants import DEFAULT_PIXEL_HEIGHT, DEFAULT_PIXEL_WIDTH +from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH +from manimlib.constants import DOWN, LEFT, ORIGIN, OUT, RIGHT, UP from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Point +from manimlib.utils.color import color_to_rgba from manimlib.utils.config_ops import digest_config from manimlib.utils.simple_functions import fdiv from manimlib.utils.space_ops import normalize @@ -180,10 +184,9 @@ class Camera(object): def __init__(self, ctx: moderngl.Context | None = None, **kwargs): digest_config(self, kwargs, locals()) self.rgb_max_val: float = np.iinfo(self.pixel_array_dtype).max - self.background_rgba: list[float] = [ - *Color(self.background_color).get_rgb(), - self.background_opacity - ] + self.background_rgba: list[float] = list(color_to_rgba( + self.background_color, self.background_opacity + )) self.init_frame() self.init_context(ctx) self.init_shaders() diff --git a/manimlib/config.py b/manimlib/config.py index a2e68c51..a57c3535 100644 --- a/manimlib/config.py +++ b/manimlib/config.py @@ -1,16 +1,16 @@ import argparse import colour -import inspect +from contextlib import contextmanager import importlib +import inspect import os +from screeninfo import get_monitors import sys import yaml -from contextlib import contextmanager -from screeninfo import get_monitors +from manimlib.logger import log from manimlib.utils.config_ops import merge_dicts_recursively from manimlib.utils.init_config import init_customization -from manimlib.logger import log __config_file__ = "custom_config.yml" diff --git a/manimlib/constants.py b/manimlib/constants.py index 590a9cda..dcc03766 100644 --- a/manimlib/constants.py +++ b/manimlib/constants.py @@ -1,5 +1,6 @@ import numpy as np + # Sizes relevant to default camera frame ASPECT_RATIO = 16.0 / 9.0 FRAME_HEIGHT = 8.0 diff --git a/manimlib/event_handler/event_dispatcher.py b/manimlib/event_handler/event_dispatcher.py index 34eb55eb..2ec7c49d 100644 --- a/manimlib/event_handler/event_dispatcher.py +++ b/manimlib/event_handler/event_dispatcher.py @@ -2,8 +2,8 @@ from __future__ import annotations import numpy as np -from manimlib.event_handler.event_type import EventType from manimlib.event_handler.event_listner import EventListner +from manimlib.event_handler.event_type import EventType class EventDispatcher(object): diff --git a/manimlib/event_handler/event_listner.py b/manimlib/event_handler/event_listner.py index 4552cf8c..6f32121c 100644 --- a/manimlib/event_handler/event_listner.py +++ b/manimlib/event_handler/event_listner.py @@ -1,10 +1,13 @@ from __future__ import annotations -from typing import Callable, TYPE_CHECKING +from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.mobject import Mobject + from typing import Callable + from manimlib.event_handler.event_type import EventType + from manimlib.mobject.mobject import Mobject + class EventListner(object): def __init__( diff --git a/manimlib/extract_scene.py b/manimlib/extract_scene.py index abec96ec..55a73f63 100644 --- a/manimlib/extract_scene.py +++ b/manimlib/extract_scene.py @@ -1,10 +1,10 @@ +import copy import inspect import sys -import copy -from manimlib.scene.scene import Scene from manimlib.config import get_custom_config from manimlib.logger import log +from manimlib.scene.scene import Scene class BlankScene(Scene): diff --git a/manimlib/logger.py b/manimlib/logger.py index b04ae7ae..71567a1d 100644 --- a/manimlib/logger.py +++ b/manimlib/logger.py @@ -1,4 +1,5 @@ import logging + from rich.logging import RichHandler __all__ = ["log"] diff --git a/manimlib/mobject/changing.py b/manimlib/mobject/changing.py index 76d92bab..7a954708 100644 --- a/manimlib/mobject/changing.py +++ b/manimlib/mobject/changing.py @@ -1,19 +1,18 @@ from __future__ import annotations -from typing import Callable - import numpy as np -from manimlib.constants import BLUE_D -from manimlib.constants import BLUE_B -from manimlib.constants import BLUE_E -from manimlib.constants import GREY_BROWN -from manimlib.constants import WHITE +from manimlib.constants import BLUE_B, BLUE_D, BLUE_E, GREY_BROWN, WHITE from manimlib.mobject.mobject import Mobject -from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VGroup +from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.rate_functions import smooth +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Callable + class AnimatedBoundary(VGroup): CONFIG = { diff --git a/manimlib/mobject/coordinate_systems.py b/manimlib/mobject/coordinate_systems.py index 2d21d9d4..09a84a97 100644 --- a/manimlib/mobject/coordinate_systems.py +++ b/manimlib/mobject/coordinate_systems.py @@ -1,16 +1,20 @@ from __future__ import annotations +from abc import ABC, abstractmethod import numbers -from abc import abstractmethod -from typing import Type, TypeVar, Union, Callable, Iterable, Sequence import numpy as np -from manimlib.constants import * +from manimlib.constants import BLACK, BLUE, BLUE_D, GREEN, GREY_A, WHITE +from manimlib.constants import DEGREES, PI +from manimlib.constants import DL, DOWN, DR, LEFT, ORIGIN, OUT, RIGHT, UP +from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH +from manimlib.constants import FRAME_X_RADIUS, FRAME_Y_RADIUS +from manimlib.constants import MED_SMALL_BUFF, SMALL_BUFF from manimlib.mobject.functions import ParametricCurve from manimlib.mobject.geometry import Arrow -from manimlib.mobject.geometry import Line from manimlib.mobject.geometry import DashedLine +from manimlib.mobject.geometry import Line from manimlib.mobject.geometry import Rectangle from manimlib.mobject.number_line import NumberLine from manimlib.mobject.svg.tex_mobject import Tex @@ -25,16 +29,19 @@ from manimlib.utils.space_ops import rotate_vector from typing import TYPE_CHECKING if TYPE_CHECKING: - import colour + from colour import Color + from typing import Callable, Iterable, Sequence, Type, TypeVar, Union + from manimlib.mobject.mobject import Mobject + T = TypeVar("T", bound=Mobject) - ManimColor = Union[str, colour.Color, Sequence[float]] + ManimColor = Union[str, Color] EPSILON = 1e-8 -class CoordinateSystem(): +class CoordinateSystem(ABC): """ Abstract class for Axes and NumberPlane """ diff --git a/manimlib/mobject/frame.py b/manimlib/mobject/frame.py index 7523ab51..fc496695 100644 --- a/manimlib/mobject/frame.py +++ b/manimlib/mobject/frame.py @@ -1,4 +1,5 @@ -from manimlib.constants import * +from manimlib.constants import BLACK, GREY_E +from manimlib.constants import FRAME_HEIGHT from manimlib.mobject.geometry import Rectangle from manimlib.utils.config_ops import digest_config diff --git a/manimlib/mobject/functions.py b/manimlib/mobject/functions.py index ec374ed0..b6033fb8 100644 --- a/manimlib/mobject/functions.py +++ b/manimlib/mobject/functions.py @@ -1,13 +1,17 @@ from __future__ import annotations -from typing import Callable, Sequence - from isosurfaces import plot_isoline -from manimlib.constants import * +from manimlib.constants import FRAME_X_RADIUS, FRAME_Y_RADIUS +from manimlib.constants import YELLOW from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Callable, Sequence + class ParametricCurve(VMobject): CONFIG = { diff --git a/manimlib/mobject/geometry.py b/manimlib/mobject/geometry.py index f555ee52..8be3624c 100644 --- a/manimlib/mobject/geometry.py +++ b/manimlib/mobject/geometry.py @@ -2,23 +2,24 @@ from __future__ import annotations import math import numbers -from typing import Sequence, Union -import colour import numpy as np -from manimlib.constants import * +from manimlib.constants import DOWN, LEFT, ORIGIN, OUT, RIGHT, UP +from manimlib.constants import GREY_A, RED, WHITE +from manimlib.constants import MED_SMALL_BUFF +from manimlib.constants import PI, TAU from manimlib.mobject.mobject import Mobject +from manimlib.mobject.types.vectorized_mobject import DashedVMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject -from manimlib.mobject.types.vectorized_mobject import DashedVMobject from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import adjacent_n_tuples from manimlib.utils.iterables import adjacent_pairs -from manimlib.utils.simple_functions import fdiv from manimlib.utils.simple_functions import clip -from manimlib.utils.space_ops import angle_of_vector +from manimlib.utils.simple_functions import fdiv from manimlib.utils.space_ops import angle_between_vectors +from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import compass_directions from manimlib.utils.space_ops import find_intersection from manimlib.utils.space_ops import get_norm @@ -26,7 +27,13 @@ from manimlib.utils.space_ops import normalize from manimlib.utils.space_ops import rotate_vector from manimlib.utils.space_ops import rotation_matrix_transpose -ManimColor = Union[str, colour.Color, Sequence[float]] +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from colour import Color + from typing import Union + + ManimColor = Union[str, Color] DEFAULT_DOT_RADIUS = 0.08 @@ -716,8 +723,8 @@ class Arrow(Line): def set_stroke( self, - color: ManimColor | None = None, - width: float | None = None, + color: ManimColor | Iterable[ManimColor] | None = None, + width: float | Iterable[float] | None = None, *args, **kwargs ): super().set_stroke(color=color, width=width, *args, **kwargs) diff --git a/manimlib/mobject/interactive.py b/manimlib/mobject/interactive.py index b50425ef..fd3a3b3b 100644 --- a/manimlib/mobject/interactive.py +++ b/manimlib/mobject/interactive.py @@ -1,22 +1,32 @@ from __future__ import annotations -from typing import Callable - import numpy as np from pyglet.window import key as PygletWindowKeys from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH -from manimlib.constants import LEFT, RIGHT, UP, DOWN, ORIGIN -from manimlib.constants import SMALL_BUFF, MED_SMALL_BUFF, MED_LARGE_BUFF -from manimlib.constants import BLACK, GREY_A, GREY_C, RED, GREEN, BLUE, WHITE -from manimlib.mobject.mobject import Mobject, Group -from manimlib.mobject.types.vectorized_mobject import VGroup -from manimlib.mobject.geometry import Dot, Line, Square, Rectangle, RoundedRectangle, Circle +from manimlib.constants import DOWN, LEFT, ORIGIN, RIGHT, UP +from manimlib.constants import MED_LARGE_BUFF, MED_SMALL_BUFF, SMALL_BUFF +from manimlib.constants import BLACK, BLUE, GREEN, GREY_A, GREY_C, RED, WHITE +from manimlib.mobject.mobject import Group +from manimlib.mobject.mobject import Mobject +from manimlib.mobject.geometry import Circle +from manimlib.mobject.geometry import Dot +from manimlib.mobject.geometry import Line +from manimlib.mobject.geometry import Rectangle +from manimlib.mobject.geometry import RoundedRectangle +from manimlib.mobject.geometry import Square from manimlib.mobject.svg.text_mobject import Text +from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.value_tracker import ValueTracker +from manimlib.utils.color import rgb_to_hex from manimlib.utils.config_ops import digest_config -from manimlib.utils.space_ops import get_norm, get_closest_point_on_line -from manimlib.utils.color import rgb_to_color, color_to_rgba, rgb_to_hex +from manimlib.utils.space_ops import get_closest_point_on_line +from manimlib.utils.space_ops import get_norm + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Callable # Interactive Mobjects @@ -336,7 +346,7 @@ class ColorSliders(Group): g = self.g_slider.get_value() / 255 b = self.b_slider.get_value() / 255 alpha = self.a_slider.get_value() - return color_to_rgba(rgb_to_color((r, g, b)), alpha=alpha) + return np.array((r, g, b, alpha)) def get_picked_color(self) -> str: rgba = self.get_value() diff --git a/manimlib/mobject/matrix.py b/manimlib/mobject/matrix.py index b53cba0a..76688a4d 100644 --- a/manimlib/mobject/matrix.py +++ b/manimlib/mobject/matrix.py @@ -1,12 +1,12 @@ from __future__ import annotations import itertools as it -from typing import Union, Sequence import numpy as np -import numpy.typing as npt -from manimlib.constants import * +from manimlib.constants import DEFAULT_MOBJECT_TO_MOBJECT_BUFFER +from manimlib.constants import DOWN, LEFT, RIGHT, UP +from manimlib.constants import WHITE from manimlib.mobject.numbers import DecimalNumber from manimlib.mobject.numbers import Integer from manimlib.mobject.shape_matchers import BackgroundRectangle @@ -18,9 +18,14 @@ from manimlib.mobject.types.vectorized_mobject import VMobject from typing import TYPE_CHECKING if TYPE_CHECKING: - import colour + from colour import Color + from typing import Union + + import numpy.typing as npt + from manimlib.mobject.mobject import Mobject - ManimColor = Union[str, colour.Color, Sequence[float]] + + ManimColor = Union[str, Color] VECTOR_LABEL_SCALE_FACTOR = 0.8 diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 17fe9ad0..d1e5e6eb 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -5,44 +5,54 @@ import copy import random import itertools as it from functools import wraps -from typing import Iterable, Callable, Union, Sequence -import colour import moderngl import numpy as np -import numpy.typing as npt -from manimlib.constants import * +from manimlib.constants import DEFAULT_MOBJECT_TO_EDGE_BUFFER +from manimlib.constants import DEFAULT_MOBJECT_TO_MOBJECT_BUFFER +from manimlib.constants import DOWN, IN, LEFT, ORIGIN, OUT, RIGHT, UP +from manimlib.constants import FRAME_X_RADIUS, FRAME_Y_RADIUS +from manimlib.constants import MED_SMALL_BUFF +from manimlib.constants import TAU +from manimlib.constants import WHITE +from manimlib.event_handler import EVENT_DISPATCHER +from manimlib.event_handler.event_listner import EventListner +from manimlib.event_handler.event_type import EventType +from manimlib.shader_wrapper import get_colormap_code +from manimlib.shader_wrapper import ShaderWrapper from manimlib.utils.color import color_gradient +from manimlib.utils.color import color_to_rgb from manimlib.utils.color import get_colormap_list from manimlib.utils.color import rgb_to_hex -from manimlib.utils.color import color_to_rgb from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import batch_by_property from manimlib.utils.iterables import list_update +from manimlib.utils.iterables import listify +from manimlib.utils.iterables import make_even from manimlib.utils.iterables import resize_array from manimlib.utils.iterables import resize_preserving_order from manimlib.utils.iterables import resize_with_interpolation -from manimlib.utils.iterables import make_even -from manimlib.utils.iterables import listify -from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import integer_interpolate +from manimlib.utils.bezier import interpolate from manimlib.utils.paths import straight_path from manimlib.utils.simple_functions import get_parameters from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import rotation_matrix_transpose -from manimlib.shader_wrapper import ShaderWrapper -from manimlib.shader_wrapper import get_colormap_code -from manimlib.event_handler import EVENT_DISPATCHER -from manimlib.event_handler.event_listner import EventListner -from manimlib.event_handler.event_type import EventType +from typing import TYPE_CHECKING -TimeBasedUpdater = Callable[["Mobject", float], None] -NonTimeUpdater = Callable[["Mobject"], None] -Updater = Union[TimeBasedUpdater, NonTimeUpdater] -ManimColor = Union[str, colour.Color, Sequence[float]] +if TYPE_CHECKING: + from colour import Color + from typing import Callable, Iterable, Sequence, Union + + import numpy.typing as npt + + TimeBasedUpdater = Callable[["Mobject", float], None] + NonTimeUpdater = Callable[["Mobject"], None] + Updater = Union[TimeBasedUpdater, NonTimeUpdater] + ManimColor = Union[str, Color] class Mobject(object): @@ -635,7 +645,7 @@ class Mobject(object): def scale( self, - scale_factor: float | npt.ArrayLike, + scale_factor: float | Iterable[float], min_scale_factor: float = 1e-8, about_point: np.ndarray | None = None, about_edge: np.ndarray = ORIGIN @@ -649,10 +659,7 @@ class Mobject(object): Otherwise, if about_point is given a value, scaling is done with respect to that point. """ - if isinstance(scale_factor, Iterable): - scale_factor = np.array(scale_factor).clip(min=min_scale_factor) - else: - scale_factor = max(scale_factor, min_scale_factor) + scale_factor = np.resize(scale_factor, self.dim).clip(min=min_scale_factor) self.apply_points_function( lambda points: scale_factor * points, about_point=about_point, @@ -1038,8 +1045,8 @@ class Mobject(object): def set_rgba_array_by_color( self, - color: ManimColor | None = None, - opacity: float | None = None, + color: ManimColor | Iterable[ManimColor] | None = None, + opacity: float | Iterable[float] | None = None, name: str = "rgbas", recurse: bool = True ): @@ -1061,7 +1068,12 @@ class Mobject(object): mob.data[name][:, 3] = resize_array(opacities, size) return self - def set_color(self, color: ManimColor, opacity: float | None = None, recurse: bool = True): + def set_color( + self, + color: ManimColor | Iterable[ManimColor] | None, + opacity: float | Iterable[float] | None = None, + recurse: bool = True + ): self.set_rgba_array_by_color(color, opacity, recurse=False) # Recurse to submobjects differently from how set_rgba_array_by_color # in case they implement set_color differently @@ -1070,7 +1082,11 @@ class Mobject(object): submob.set_color(color, recurse=True) return self - def set_opacity(self, opacity: float, recurse: bool = True): + def set_opacity( + self, + opacity: float | Iterable[float] | None, + recurse: bool = True + ): self.set_rgba_array_by_color(color=None, opacity=opacity, recurse=False) if recurse: for submob in self.submobjects: diff --git a/manimlib/mobject/mobject_update_utils.py b/manimlib/mobject/mobject_update_utils.py index b32a0ff3..4fdcfecd 100644 --- a/manimlib/mobject/mobject_update_utils.py +++ b/manimlib/mobject/mobject_update_utils.py @@ -1,7 +1,6 @@ from __future__ import annotations import inspect -from typing import Callable from manimlib.constants import DEGREES from manimlib.constants import RIGHT @@ -11,7 +10,10 @@ from manimlib.utils.simple_functions import clip from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import Callable + import numpy as np + from manimlib.animation.animation import Animation diff --git a/manimlib/mobject/number_line.py b/manimlib/mobject/number_line.py index bc96b55a..2553ac3c 100644 --- a/manimlib/mobject/number_line.py +++ b/manimlib/mobject/number_line.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Iterable, Sequence - -from manimlib.constants import * +from manimlib.constants import DOWN, LEFT, RIGHT, UP +from manimlib.constants import GREY_B +from manimlib.constants import MED_SMALL_BUFF from manimlib.mobject.geometry import Line from manimlib.mobject.numbers import DecimalNumber from manimlib.mobject.types.vectorized_mobject import VGroup @@ -12,6 +12,11 @@ from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import merge_dicts_recursively from manimlib.utils.simple_functions import fdiv +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Iterable, Sequence + class NumberLine(Line): CONFIG = { diff --git a/manimlib/mobject/numbers.py b/manimlib/mobject/numbers.py index beac837c..a22df66d 100644 --- a/manimlib/mobject/numbers.py +++ b/manimlib/mobject/numbers.py @@ -1,13 +1,16 @@ from __future__ import annotations -from typing import TypeVar, Type - -from manimlib.constants import * +from manimlib.constants import DOWN, LEFT, RIGHT, UP from manimlib.mobject.svg.tex_mobject import SingleStringTex from manimlib.mobject.svg.text_mobject import Text from manimlib.mobject.types.vectorized_mobject import VMobject -T = TypeVar("T", bound=VMobject) +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Type, TypeVar + + T = TypeVar("T", bound=VMobject) class DecimalNumber(VMobject): diff --git a/manimlib/mobject/probability.py b/manimlib/mobject/probability.py index 9f4bdeab..6c6fe69e 100644 --- a/manimlib/mobject/probability.py +++ b/manimlib/mobject/probability.py @@ -1,9 +1,8 @@ from __future__ import annotations -from typing import Iterable, Union, Sequence -import colour - -from manimlib.constants import * +from manimlib.constants import BLUE, BLUE_E, GREEN_E, GREY_B, GREY_D, MAROON_B, YELLOW +from manimlib.constants import DOWN, LEFT, RIGHT, UP +from manimlib.constants import MED_LARGE_BUFF, MED_SMALL_BUFF, SMALL_BUFF from manimlib.mobject.geometry import Line from manimlib.mobject.geometry import Rectangle from manimlib.mobject.mobject import Mobject @@ -14,7 +13,14 @@ from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_gradient from manimlib.utils.iterables import listify -ManimColor = Union[str, colour.Color, Sequence[float]] +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from colour import Color + from typing import Iterable, Union + + ManimColor = Union[str, Color] + EPSILON = 0.0001 diff --git a/manimlib/mobject/shape_matchers.py b/manimlib/mobject/shape_matchers.py index a1ffe5fd..78cc8100 100644 --- a/manimlib/mobject/shape_matchers.py +++ b/manimlib/mobject/shape_matchers.py @@ -1,20 +1,25 @@ from __future__ import annotations -from manimlib.constants import * +from colour import Color + +from manimlib.constants import BLACK, RED, YELLOW +from manimlib.constants import DL, DOWN, DR, LEFT, RIGHT, UL, UR +from manimlib.constants import SMALL_BUFF from manimlib.mobject.geometry import Line from manimlib.mobject.geometry import Rectangle from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject -from manimlib.utils.color import Color -from manimlib.utils.customization import get_customization from manimlib.utils.config_ops import digest_config +from manimlib.utils.customization import get_customization from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Union, Sequence + from typing import Union + from manimlib.mobject.mobject import Mobject - ManimColor = Union[str, Color, Sequence[float]] + + ManimColor = Union[str, Color] class SurroundingRectangle(Rectangle): diff --git a/manimlib/mobject/svg/brace.py b/manimlib/mobject/svg/brace.py index 659b0604..e49cebaf 100644 --- a/manimlib/mobject/svg/brace.py +++ b/manimlib/mobject/svg/brace.py @@ -2,27 +2,32 @@ from __future__ import annotations import math import copy -from typing import Iterable import numpy as np -from manimlib.constants import * +from manimlib.constants import DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, SMALL_BUFF +from manimlib.constants import DOWN, LEFT, ORIGIN, RIGHT, UP +from manimlib.constants import PI +from manimlib.animation.composition import AnimationGroup from manimlib.animation.fading import FadeIn from manimlib.animation.growing import GrowFromCenter -from manimlib.animation.composition import AnimationGroup -from manimlib.mobject.svg.tex_mobject import Tex from manimlib.mobject.svg.tex_mobject import SingleStringTex +from manimlib.mobject.svg.tex_mobject import Tex from manimlib.mobject.svg.tex_mobject import TexText from manimlib.mobject.svg.text_mobject import Text from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config +from manimlib.utils.iterables import listify from manimlib.utils.space_ops import get_norm from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.mobject import Mobject + from typing import Iterable + from manimlib.animation.animation import Animation + from manimlib.mobject.mobject import Mobject + class Brace(SingleStringTex): CONFIG = { @@ -113,8 +118,8 @@ class BraceLabel(VMobject): def __init__( self, - obj: list[VMobject] | Mobject, - text: Iterable[str] | str, + obj: VMobject | list[VMobject], + text: str | Iterable[str], brace_direction: np.ndarray = DOWN, **kwargs ) -> None: @@ -124,12 +129,8 @@ class BraceLabel(VMobject): obj = VMobject(*obj) self.brace = Brace(obj, brace_direction, **kwargs) - if isinstance(text, Iterable): - self.label = self.label_constructor(*text, **kwargs) - else: - self.label = self.label_constructor(str(text)) - if self.label_scale != 1: - self.label.scale(self.label_scale) + self.label = self.label_constructor(*listify(text), **kwargs) + self.label.scale(self.label_scale) self.brace.put_at_tip(self.label, buff=self.label_buff) self.set_submobjects([self.brace, self.label]) @@ -137,11 +138,11 @@ class BraceLabel(VMobject): def creation_anim( self, label_anim: Animation = FadeIn, - brace_anim: Animation=GrowFromCenter + brace_anim: Animation = GrowFromCenter ) -> AnimationGroup: return AnimationGroup(brace_anim(self.brace), label_anim(self.label)) - def shift_brace(self, obj: list[VMobject] | Mobject, **kwargs): + def shift_brace(self, obj: VMobject | list[VMobject], **kwargs): if isinstance(obj, list): obj = VMobject(*obj) self.brace = Brace(obj, self.brace_direction, **kwargs) @@ -158,7 +159,7 @@ class BraceLabel(VMobject): self.submobjects[1] = self.label return self - def change_brace_label(self, obj: list[VMobject] | Mobject, *text: str): + def change_brace_label(self, obj: VMobject | list[VMobject], *text: str): self.shift_brace(obj) self.change_label(*text) return self diff --git a/manimlib/mobject/svg/drawings.py b/manimlib/mobject/svg/drawings.py index 6eee3428..ed2102c3 100644 --- a/manimlib/mobject/svg/drawings.py +++ b/manimlib/mobject/svg/drawings.py @@ -20,8 +20,8 @@ from manimlib.utils.config_ops import digest_config from manimlib.utils.rate_functions import linear from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import complex_to_R3 -from manimlib.utils.space_ops import rotate_vector from manimlib.utils.space_ops import midpoint +from manimlib.utils.space_ops import rotate_vector class Checkmark(TexText): diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 58c47094..1cd2b2a1 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -1,10 +1,8 @@ from __future__ import annotations -import re -import colour -import itertools as it -from typing import Iterable, Union, Sequence from abc import ABC, abstractmethod +import itertools as it +import re from manimlib.constants import BLACK, WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject @@ -15,12 +13,15 @@ from manimlib.utils.color import rgb_to_hex from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import remove_list_redundancies - from typing import TYPE_CHECKING if TYPE_CHECKING: + from colour import Color + from typing import Iterable, Union + from manimlib.mobject.types.vectorized_mobject import VMobject - ManimColor = Union[str, colour.Color, Sequence[float]] + + ManimColor = Union[str, Color] Span = tuple[int, int] diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index fb7922e1..9f35f620 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,21 +1,22 @@ from __future__ import annotations import itertools as it -import colour -from typing import Union, Sequence from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.utils.tex_file_writing import tex_to_svg_file -from manimlib.utils.tex_file_writing import get_tex_config from manimlib.utils.tex_file_writing import display_during_execution - +from manimlib.utils.tex_file_writing import get_tex_config +from manimlib.utils.tex_file_writing import tex_to_svg_file from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.types.vectorized_mobject import VMobject + from colour import Color + from typing import Union + from manimlib.mobject.types.vectorized_mobject import VGroup - ManimColor = Union[str, colour.Color, Sequence[float]] + from manimlib.mobject.types.vectorized_mobject import VMobject + + ManimColor = Union[str, Color] Span = tuple[int, int] diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index b44c107f..09a11ff8 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -1,17 +1,17 @@ from __future__ import annotations -import os import hashlib import itertools as it -from typing import Callable +import os from xml.etree import ElementTree as ET -import svgelements as se import numpy as np +import svgelements as se from manimlib.constants import RIGHT -from manimlib.mobject.geometry import Line +from manimlib.logger import log from manimlib.mobject.geometry import Circle +from manimlib.mobject.geometry import Line from manimlib.mobject.geometry import Polygon from manimlib.mobject.geometry import Polyline from manimlib.mobject.geometry import Rectangle @@ -21,7 +21,6 @@ from manimlib.utils.config_ops import digest_config from manimlib.utils.directories import get_mobject_data_dir from manimlib.utils.images import get_full_vector_image_path from manimlib.utils.iterables import hash_obj -from manimlib.logger import log SVG_HASH_TO_MOB_MAP: dict[int, VMobject] = {} diff --git a/manimlib/mobject/svg/tex_mobject.py b/manimlib/mobject/svg/tex_mobject.py index 619f5bc9..fb444608 100644 --- a/manimlib/mobject/svg/tex_mobject.py +++ b/manimlib/mobject/svg/tex_mobject.py @@ -1,21 +1,28 @@ from __future__ import annotations -from typing import Iterable, Sequence, Union from functools import reduce import operator as op -import colour import re -from manimlib.constants import * +from manimlib.constants import BLACK, WHITE +from manimlib.constants import DOWN, LEFT, RIGHT, UP +from manimlib.constants import FRAME_WIDTH +from manimlib.constants import MED_LARGE_BUFF, MED_SMALL_BUFF, SMALL_BUFF from manimlib.mobject.geometry import Line from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.config_ops import digest_config -from manimlib.utils.tex_file_writing import tex_to_svg_file -from manimlib.utils.tex_file_writing import get_tex_config from manimlib.utils.tex_file_writing import display_during_execution +from manimlib.utils.tex_file_writing import get_tex_config +from manimlib.utils.tex_file_writing import tex_to_svg_file -ManimColor = Union[str, colour.Color, Sequence[float]] +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from colour import Color + from typing import Iterable, Union + + ManimColor = Union[str, Color] SCALE_FACTOR_PER_FONT_POINT = 0.001 diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index c3c3be19..d2bf0b6f 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -1,36 +1,38 @@ from __future__ import annotations -import os -import re -import itertools as it -from pathlib import Path from contextlib import contextmanager -import typing -from typing import Iterable, Sequence, Union +import itertools as it +import os +from pathlib import Path +import re +from manimpango import MarkupUtils import pygments import pygments.formatters import pygments.lexers -from manimpango import MarkupUtils - +from manimlib.constants import BLACK +from manimlib.constants import DEFAULT_PIXEL_HEIGHT, DEFAULT_PIXEL_WIDTH +from manimlib.constants import NORMAL from manimlib.logger import log -from manimlib.constants import * from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.utils.customization import get_customization -from manimlib.utils.tex_file_writing import tex_hash from manimlib.utils.config_ops import digest_config +from manimlib.utils.customization import get_customization from manimlib.utils.directories import get_downloads_dir from manimlib.utils.directories import get_text_dir from manimlib.utils.iterables import remove_list_redundancies - +from manimlib.utils.tex_file_writing import tex_hash from typing import TYPE_CHECKING if TYPE_CHECKING: + from colour import Color + from typing import Any, Union + from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VGroup - ManimColor = Union[str, colour.Color, Sequence[float]] + + ManimColor = Union[str, Color] Span = tuple[int, int] @@ -256,7 +258,7 @@ class MarkupText(LabelledString): @staticmethod def merge_attr_dicts( - attr_dict_items: list[Span, str, typing.Any] + attr_dict_items: list[Span, str, Any] ) -> list[tuple[Span, dict[str, str]]]: index_seq = [0] attr_dict_list = [{}] @@ -356,7 +358,7 @@ class MarkupText(LabelledString): ) return result - def get_global_dict_from_config(self) -> dict[str, typing.Any]: + def get_global_dict_from_config(self) -> dict[str, Any]: result = { "line_height": ( (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 @@ -380,7 +382,7 @@ class MarkupText(LabelledString): def get_local_dicts_from_config( self - ) -> list[Span, dict[str, typing.Any]]: + ) -> list[Span, dict[str, Any]]: return [ (span, {key: val}) for t2x_dict, key in ( diff --git a/manimlib/mobject/three_dimensions.py b/manimlib/mobject/three_dimensions.py index c6c2d946..a42aa7ba 100644 --- a/manimlib/mobject/three_dimensions.py +++ b/manimlib/mobject/three_dimensions.py @@ -2,19 +2,21 @@ from __future__ import annotations import math -from manimlib.constants import * -from manimlib.mobject.types.surface import Surface +from manimlib.constants import BLUE, BLUE_D, BLUE_E +from manimlib.constants import IN, ORIGIN, OUT, RIGHT +from manimlib.constants import PI, TAU from manimlib.mobject.types.surface import SGroup +from manimlib.mobject.types.surface import Surface from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject -from manimlib.mobject.geometry import Square from manimlib.mobject.geometry import Polygon +from manimlib.mobject.geometry import Square from manimlib.utils.bezier import interpolate from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import adjacent_pairs +from manimlib.utils.space_ops import compass_directions from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import z_to_vector -from manimlib.utils.space_ops import compass_directions class SurfaceMesh(VGroup): diff --git a/manimlib/mobject/types/dot_cloud.py b/manimlib/mobject/types/dot_cloud.py index 3e48aa8f..5975b3ee 100644 --- a/manimlib/mobject/types/dot_cloud.py +++ b/manimlib/mobject/types/dot_cloud.py @@ -1,15 +1,18 @@ from __future__ import annotations -import numpy as np -import numpy.typing as npt import moderngl +import numpy as np -from manimlib.constants import GREY_C -from manimlib.constants import YELLOW +from manimlib.constants import GREY_C, YELLOW from manimlib.constants import ORIGIN from manimlib.mobject.types.point_cloud_mobject import PMobject from manimlib.utils.iterables import resize_preserving_order +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import numpy.typing as npt + DEFAULT_DOT_RADIUS = 0.05 DEFAULT_GLOW_DOT_RADIUS = 0.2 diff --git a/manimlib/mobject/types/image_mobject.py b/manimlib/mobject/types/image_mobject.py index d3f11f2b..0f9c4d0d 100644 --- a/manimlib/mobject/types/image_mobject.py +++ b/manimlib/mobject/types/image_mobject.py @@ -3,7 +3,7 @@ from __future__ import annotations import numpy as np from PIL import Image -from manimlib.constants import * +from manimlib.constants import DL, DR, UL, UR from manimlib.mobject.mobject import Mobject from manimlib.utils.bezier import inverse_interpolate from manimlib.utils.images import get_full_raster_image_path diff --git a/manimlib/mobject/types/point_cloud_mobject.py b/manimlib/mobject/types/point_cloud_mobject.py index 5de41173..f05ca8ea 100644 --- a/manimlib/mobject/types/point_cloud_mobject.py +++ b/manimlib/mobject/types/point_cloud_mobject.py @@ -1,19 +1,22 @@ from __future__ import annotations -from typing import Callable, Sequence, Union - -import colour -import numpy.typing as npt - -from manimlib.constants import * +from manimlib.constants import BLACK +from manimlib.constants import ORIGIN from manimlib.mobject.mobject import Mobject from manimlib.utils.color import color_gradient from manimlib.utils.color import color_to_rgba -from manimlib.utils.iterables import resize_with_interpolation from manimlib.utils.iterables import resize_array +from manimlib.utils.iterables import resize_with_interpolation +from typing import TYPE_CHECKING -ManimColor = Union[str, colour.Color, Sequence[float]] +if TYPE_CHECKING: + from colour import Color + from typing import Callable, Union + + import numpy.typing as npt + + ManimColor = Union[str, Color] class PMobject(Mobject): diff --git a/manimlib/mobject/types/surface.py b/manimlib/mobject/types/surface.py index cc3e32d5..339f62f1 100644 --- a/manimlib/mobject/types/surface.py +++ b/manimlib/mobject/types/surface.py @@ -1,12 +1,10 @@ from __future__ import annotations -from typing import Iterable, Callable - import moderngl import numpy as np -import numpy.typing as npt -from manimlib.constants import * +from manimlib.constants import GREY +from manimlib.constants import OUT from manimlib.mobject.mobject import Mobject from manimlib.utils.bezier import integer_interpolate from manimlib.utils.bezier import interpolate @@ -17,6 +15,10 @@ from manimlib.utils.space_ops import normalize_along_axis from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import Callable, Iterable + + import numpy.typing as npt + from manimlib.camera.camera import Camera diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index f5b47859..44a6bccb 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -1,30 +1,33 @@ from __future__ import annotations -import operator as op +from functools import reduce +from functools import wraps import itertools as it -from functools import reduce, wraps -from typing import Iterable, Sequence, Callable, Union +import operator as op -import colour import moderngl -import numpy.typing as npt +import numpy as np -from manimlib.constants import * +from manimlib.constants import BLACK, WHITE +from manimlib.constants import DEFAULT_STROKE_WIDTH +from manimlib.constants import DEGREES +from manimlib.constants import JOINT_TYPE_MAP +from manimlib.constants import ORIGIN, OUT, UP from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Point from manimlib.utils.bezier import bezier -from manimlib.utils.bezier import get_smooth_quadratic_bezier_handle_points -from manimlib.utils.bezier import get_smooth_cubic_bezier_handle_points from manimlib.utils.bezier import get_quadratic_approximation_of_cubic +from manimlib.utils.bezier import get_smooth_cubic_bezier_handle_points +from manimlib.utils.bezier import get_smooth_quadratic_bezier_handle_points +from manimlib.utils.bezier import integer_interpolate from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import inverse_interpolate -from manimlib.utils.bezier import integer_interpolate from manimlib.utils.bezier import partial_quadratic_bezier_points from manimlib.utils.color import rgb_to_hex +from manimlib.utils.iterables import listify from manimlib.utils.iterables import make_even from manimlib.utils.iterables import resize_array from manimlib.utils.iterables import resize_with_interpolation -from manimlib.utils.iterables import listify from manimlib.utils.space_ops import angle_between_vectors from manimlib.utils.space_ops import cross2d from manimlib.utils.space_ops import earclip_triangulation @@ -33,8 +36,15 @@ from manimlib.utils.space_ops import get_unit_normal from manimlib.utils.space_ops import z_to_vector from manimlib.shader_wrapper import ShaderWrapper +from typing import TYPE_CHECKING -ManimColor = Union[str, colour.Color, Sequence[float]] +if TYPE_CHECKING: + from colour import Color + from typing import Callable, Iterable, Sequence, Union + + import numpy.typing as npt + + ManimColor = Union[str, Color] class VMobject(Mobject): @@ -130,8 +140,8 @@ class VMobject(Mobject): def set_fill( self, - color: ManimColor | None = None, - opacity: float | None = None, + color: ManimColor | Iterable[ManimColor] | None = None, + opacity: float | Iterable[float] | None = None, recurse: bool = True ): self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse) @@ -139,9 +149,9 @@ class VMobject(Mobject): def set_stroke( self, - color: ManimColor | None = None, - width: float | npt.ArrayLike | None = None, - opacity: float | None = None, + color: ManimColor | Iterable[ManimColor] | None = None, + width: float | Iterable[float] | None = None, + opacity: float | Iterable[float] | None = None, background: bool | None = None, recurse: bool = True ): @@ -162,8 +172,8 @@ class VMobject(Mobject): def set_backstroke( self, - color: ManimColor = BLACK, - width: float | npt.ArrayLike = 3, + color: ManimColor | Iterable[ManimColor] = BLACK, + width: float | Iterable[float] = 3, background: bool = True ): self.set_stroke(color, width, background=background) @@ -177,13 +187,13 @@ class VMobject(Mobject): def set_style( self, - fill_color: ManimColor | None = None, - fill_opacity: float | None = None, + fill_color: ManimColor | Iterable[ManimColor] | None = None, + fill_opacity: float | Iterable[float] | None = None, fill_rgba: npt.ArrayLike | None = None, - stroke_color: ManimColor | None = None, - stroke_opacity: float | None = None, + stroke_color: ManimColor | Iterable[ManimColor] | None = None, + stroke_opacity: float | Iterable[float] | None = None, stroke_rgba: npt.ArrayLike | None = None, - stroke_width: float | npt.ArrayLike | None = None, + stroke_width: float | Iterable[float] | None = None, stroke_background: bool = True, reflectiveness: float | None = None, gloss: float | None = None, @@ -247,12 +257,21 @@ class VMobject(Mobject): sm1.match_style(sm2) return self - def set_color(self, color: ManimColor, recurse: bool = True): - self.set_fill(color, recurse=recurse) - self.set_stroke(color, recurse=recurse) + def set_color( + self, + color: ManimColor | Iterable[ManimColor] | None, + opacity: float | Iterable[float] | None = None, + recurse: bool = True + ): + self.set_fill(color, opacity=opacity, recurse=recurse) + self.set_stroke(color, opacity=opacity, recurse=recurse) return self - def set_opacity(self, opacity: float, recurse: bool = True): + def set_opacity( + self, + opacity: float | Iterable[float] | None, + recurse: bool = True + ): self.set_fill(opacity=opacity, recurse=recurse) self.set_stroke(opacity=opacity, recurse=recurse) return self diff --git a/manimlib/mobject/vector_field.py b/manimlib/mobject/vector_field.py index b17b55de..016e6f24 100644 --- a/manimlib/mobject/vector_field.py +++ b/manimlib/mobject/vector_field.py @@ -1,23 +1,21 @@ from __future__ import annotations import itertools as it -import random -from typing import Sequence, TypeVar, Callable, Iterable import numpy as np -import numpy.typing as npt -from manimlib.constants import * +from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH +from manimlib.constants import WHITE from manimlib.animation.composition import AnimationGroup from manimlib.animation.indication import VShowPassingFlash from manimlib.mobject.geometry import Arrow from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject -from manimlib.utils.bezier import inverse_interpolate from manimlib.utils.bezier import interpolate +from manimlib.utils.bezier import inverse_interpolate from manimlib.utils.color import get_colormap_list -from manimlib.utils.config_ops import merge_dicts_recursively from manimlib.utils.config_ops import digest_config +from manimlib.utils.config_ops import merge_dicts_recursively from manimlib.utils.rate_functions import linear from manimlib.utils.simple_functions import sigmoid from manimlib.utils.space_ops import get_norm @@ -25,8 +23,13 @@ from manimlib.utils.space_ops import get_norm from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.mobject import Mobject + from typing import Callable, Iterable, Sequence, TypeVar + + import numpy.typing as npt + from manimlib.mobject.coordinate_systems import CoordinateSystem + from manimlib.mobject.mobject import Mobject + T = TypeVar("T") @@ -299,7 +302,7 @@ class AnimatedStreamLines(VGroup): **self.line_anim_config, ) line.anim.begin() - line.time = -self.lag_range * random.random() + line.time = -self.lag_range * np.random.random() self.add(line.anim.mobject) self.add_updater(lambda m, dt: m.update(dt)) diff --git a/manimlib/once_useful_constructs/fractals.py b/manimlib/once_useful_constructs/fractals.py index 6285f554..57512262 100644 --- a/manimlib/once_useful_constructs/fractals.py +++ b/manimlib/once_useful_constructs/fractals.py @@ -1,4 +1,5 @@ from functools import reduce +import random from manimlib.constants import * # from manimlib.for_3b1b_videos.pi_creature import PiCreature diff --git a/manimlib/scene/sample_space_scene.py b/manimlib/scene/sample_space_scene.py index cc1c73f5..b68c23d9 100644 --- a/manimlib/scene/sample_space_scene.py +++ b/manimlib/scene/sample_space_scene.py @@ -2,10 +2,11 @@ from manimlib.animation.animation import Animation from manimlib.animation.transform import MoveToTarget from manimlib.animation.transform import Transform from manimlib.animation.update import UpdateFromFunc -from manimlib.constants import * -from manimlib.scene.scene import Scene +from manimlib.constants import DOWN, RIGHT +from manimlib.constants import MED_LARGE_BUFF, SMALL_BUFF from manimlib.mobject.probability import SampleSpace from manimlib.mobject.types.vectorized_mobject import VGroup +from manimlib.scene.scene import Scene class SampleSpaceScene(Scene): diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index 3b649c96..ace012f8 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -1,35 +1,36 @@ from __future__ import annotations -import time -import random -import inspect -import platform -import itertools as it from functools import wraps -from typing import Iterable, Callable +import inspect +import itertools as it +import platform +import random +import time -from tqdm import tqdm as ProgressDisplay import numpy as np -import numpy.typing as npt +from tqdm import tqdm as ProgressDisplay from manimlib.animation.animation import prepare_animation from manimlib.animation.transform import MoveToTarget from manimlib.camera.camera import Camera from manimlib.constants import DEFAULT_WAIT_TIME +from manimlib.event_handler import EVENT_DISPATCHER +from manimlib.event_handler.event_type import EventType +from manimlib.logger import log from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Point from manimlib.scene.scene_file_writer import SceneFileWriter from manimlib.utils.config_ops import digest_config from manimlib.utils.family_ops import extract_mobject_family_members from manimlib.utils.family_ops import restructure_list_to_exclude_certain_family_members -from manimlib.event_handler.event_type import EventType -from manimlib.event_handler import EVENT_DISPATCHER -from manimlib.logger import log from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import Callable, Iterable + from PIL.Image import Image + from manimlib.animation.animation import Animation diff --git a/manimlib/scene/scene_file_writer.py b/manimlib/scene/scene_file_writer.py index cb948ab5..becdb44c 100644 --- a/manimlib/scene/scene_file_writer.py +++ b/manimlib/scene/scene_file_writer.py @@ -1,30 +1,31 @@ from __future__ import annotations import os -import sys -import shutil import platform +import shutil import subprocess as sp +import sys import numpy as np from pydub import AudioSegment from tqdm import tqdm as ProgressDisplay from manimlib.constants import FFMPEG_BIN +from manimlib.logger import log from manimlib.utils.config_ops import digest_config -from manimlib.utils.file_ops import guarantee_existence from manimlib.utils.file_ops import add_extension_if_not_present from manimlib.utils.file_ops import get_sorted_integer_files +from manimlib.utils.file_ops import guarantee_existence from manimlib.utils.sounds import get_full_sound_file_path -from manimlib.logger import log from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.scene.scene import Scene - from manimlib.camera.camera import Camera from PIL.Image import Image + from manimlib.camera.camera import Camera + from manimlib.scene.scene import Scene + class SceneFileWriter(object): CONFIG = { diff --git a/manimlib/scene/vector_space_scene.py b/manimlib/scene/vector_space_scene.py index e85c96d2..4b5a2efc 100644 --- a/manimlib/scene/vector_space_scene.py +++ b/manimlib/scene/vector_space_scene.py @@ -8,7 +8,10 @@ from manimlib.animation.growing import GrowArrow from manimlib.animation.transform import ApplyFunction from manimlib.animation.transform import ApplyPointwiseFunction from manimlib.animation.transform import Transform -from manimlib.constants import * +from manimlib.constants import BLACK, BLUE_D, GREEN_C, RED_C, GREY, WHITE, YELLOW +from manimlib.constants import DL, DOWN, ORIGIN, RIGHT, UP +from manimlib.constants import FRAME_WIDTH, FRAME_X_RADIUS, FRAME_Y_RADIUS +from manimlib.constants import SMALL_BUFF from manimlib.mobject.coordinate_systems import Axes from manimlib.mobject.coordinate_systems import NumberPlane from manimlib.mobject.geometry import Arrow @@ -30,6 +33,7 @@ from manimlib.utils.rate_functions import rush_into from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import get_norm + X_COLOR = GREEN_C Y_COLOR = RED_C Z_COLOR = BLUE_D diff --git a/manimlib/shader_wrapper.py b/manimlib/shader_wrapper.py index bd32a7ae..5ed2364f 100644 --- a/manimlib/shader_wrapper.py +++ b/manimlib/shader_wrapper.py @@ -1,9 +1,8 @@ from __future__ import annotations +import copy import os import re -import copy -from typing import Iterable import moderngl import numpy as np @@ -11,6 +10,12 @@ import numpy as np from manimlib.utils.directories import get_shader_dir from manimlib.utils.file_ops import find_file +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Iterable + + # Mobjects that should be rendered with # the same shader will be organized and # clumped together based on keeping track diff --git a/manimlib/utils/bezier.py b/manimlib/utils/bezier.py index 71d3d2b9..293b2228 100644 --- a/manimlib/utils/bezier.py +++ b/manimlib/utils/bezier.py @@ -1,19 +1,26 @@ from __future__ import annotations -from typing import Iterable, Callable, TypeVar, Sequence - -from scipy import linalg import numpy as np -import numpy.typing as npt +from scipy import linalg -from manimlib.utils.simple_functions import choose -from manimlib.utils.space_ops import find_intersection -from manimlib.utils.space_ops import cross2d -from manimlib.utils.space_ops import midpoint from manimlib.logger import log +from manimlib.utils.simple_functions import choose +from manimlib.utils.space_ops import cross2d +from manimlib.utils.space_ops import find_intersection +from manimlib.utils.space_ops import midpoint + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Callable, Iterable, Sequence, TypeVar + + import numpy.typing as npt + + T = TypeVar("T") + CLOSED_THRESHOLD = 0.001 -T = TypeVar("T") + def bezier( points: Iterable[float | np.ndarray] diff --git a/manimlib/utils/color.py b/manimlib/utils/color.py index 1687d5b3..ae6517e0 100644 --- a/manimlib/utils/color.py +++ b/manimlib/utils/color.py @@ -1,18 +1,21 @@ from __future__ import annotations -from typing import Iterable, Union - from colour import Color from colour import hex2rgb from colour import rgb2hex import numpy as np -from manimlib.constants import WHITE from manimlib.constants import COLORMAP_3B1B +from manimlib.constants import WHITE from manimlib.utils.bezier import interpolate from manimlib.utils.iterables import resize_with_interpolation -ManimColor = Union[str, Color] +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Iterable, Union + + ManimColor = Union[str, Color] def color_to_rgb(color: ManimColor) -> np.ndarray: diff --git a/manimlib/utils/customization.py b/manimlib/utils/customization.py index bf79b1b8..94923b43 100644 --- a/manimlib/utils/customization.py +++ b/manimlib/utils/customization.py @@ -4,6 +4,7 @@ import tempfile from manimlib.config import get_custom_config from manimlib.config import get_manim_dir + CUSTOMIZATION = {} diff --git a/manimlib/utils/debug.py b/manimlib/utils/debug.py index 29aa6a3c..f877af8b 100644 --- a/manimlib/utils/debug.py +++ b/manimlib/utils/debug.py @@ -1,17 +1,18 @@ from __future__ import annotations -import time import numpy as np -from typing import Callable +import time from manimlib.constants import BLACK +from manimlib.logger import log from manimlib.mobject.numbers import Integer from manimlib.mobject.types.vectorized_mobject import VGroup -from manimlib.logger import log from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import Callable + from manimlib.mobject.mobject import Mobject diff --git a/manimlib/utils/directories.py b/manimlib/utils/directories.py index 87970523..daf1714c 100644 --- a/manimlib/utils/directories.py +++ b/manimlib/utils/directories.py @@ -2,8 +2,8 @@ from __future__ import annotations import os -from manimlib.utils.file_ops import guarantee_existence from manimlib.utils.customization import get_customization +from manimlib.utils.file_ops import guarantee_existence def get_directories() -> dict[str, str]: diff --git a/manimlib/utils/family_ops.py b/manimlib/utils/family_ops.py index fc1a8b93..5db6186c 100644 --- a/manimlib/utils/family_ops.py +++ b/manimlib/utils/family_ops.py @@ -1,11 +1,12 @@ from __future__ import annotations import itertools as it -from typing import Iterable from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import Iterable + from manimlib.mobject.mobject import Mobject diff --git a/manimlib/utils/file_ops.py b/manimlib/utils/file_ops.py index a50366bc..e1419e91 100644 --- a/manimlib/utils/file_ops.py +++ b/manimlib/utils/file_ops.py @@ -1,11 +1,15 @@ from __future__ import annotations import os -from typing import Iterable import numpy as np import validators +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Iterable + def add_extension_if_not_present(file_name: str, extension: str) -> str: # This could conceivably be smarter about handling existing differing extensions diff --git a/manimlib/utils/images.py b/manimlib/utils/images.py index cab0a45f..35a573a0 100644 --- a/manimlib/utils/images.py +++ b/manimlib/utils/images.py @@ -2,9 +2,9 @@ import numpy as np from PIL import Image from typing import Iterable -from manimlib.utils.file_ops import find_file from manimlib.utils.directories import get_raster_image_dir from manimlib.utils.directories import get_vector_image_dir +from manimlib.utils.file_ops import find_file def get_full_raster_image_path(image_file_name: str) -> str: diff --git a/manimlib/utils/init_config.py b/manimlib/utils/init_config.py index 36ae9d4b..2e5d4b32 100644 --- a/manimlib/utils/init_config.py +++ b/manimlib/utils/init_config.py @@ -1,16 +1,21 @@ from __future__ import annotations +import importlib +import inspect import os import yaml -import inspect -import importlib -from typing import Any from rich import box +from rich.console import Console +from rich.prompt import Confirm +from rich.prompt import Prompt from rich.rule import Rule from rich.table import Table -from rich.console import Console -from rich.prompt import Prompt, Confirm + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any def get_manim_dir() -> str: diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py index 99788a42..05d92597 100644 --- a/manimlib/utils/iterables.py +++ b/manimlib/utils/iterables.py @@ -1,11 +1,14 @@ from __future__ import annotations -from typing import Callable, Iterable, Sequence, TypeVar - import numpy as np -T = TypeVar("T") -S = TypeVar("S") +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Callable, Iterable, Sequence, TypeVar + + T = TypeVar("T") + S = TypeVar("S") def remove_list_redundancies(l: Iterable[T]) -> list[T]: diff --git a/manimlib/utils/paths.py b/manimlib/utils/paths.py index 3bbf092d..6effed15 100644 --- a/manimlib/utils/paths.py +++ b/manimlib/utils/paths.py @@ -8,6 +8,7 @@ from manimlib.utils.bezier import interpolate from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import rotation_matrix_transpose + STRAIGHT_PATH_THRESHOLD = 0.01 diff --git a/manimlib/utils/rate_functions.py b/manimlib/utils/rate_functions.py index 79057734..cb33734b 100644 --- a/manimlib/utils/rate_functions.py +++ b/manimlib/utils/rate_functions.py @@ -1,6 +1,5 @@ -from typing import Callable - import numpy as np +from typing import Callable from manimlib.utils.bezier import bezier diff --git a/manimlib/utils/simple_functions.py b/manimlib/utils/simple_functions.py index c6a7c5d1..1371a744 100644 --- a/manimlib/utils/simple_functions.py +++ b/manimlib/utils/simple_functions.py @@ -1,7 +1,8 @@ -import inspect -import numpy as np -import math from functools import lru_cache +import inspect +import math + +import numpy as np def sigmoid(x): diff --git a/manimlib/utils/sounds.py b/manimlib/utils/sounds.py index 79501284..dc37ff00 100644 --- a/manimlib/utils/sounds.py +++ b/manimlib/utils/sounds.py @@ -1,5 +1,5 @@ -from manimlib.utils.file_ops import find_file from manimlib.utils.directories import get_sound_dir +from manimlib.utils.file_ops import find_file def get_full_sound_file_path(sound_file_name) -> str: diff --git a/manimlib/utils/space_ops.py b/manimlib/utils/space_ops.py index 29c67ab4..e5ca47f0 100644 --- a/manimlib/utils/space_ops.py +++ b/manimlib/utils/space_ops.py @@ -1,25 +1,27 @@ from __future__ import annotations +from functools import reduce import math import operator as op -from functools import reduce -from typing import Callable, Iterable, Sequence import platform -import numpy as np -import numpy.typing as npt from mapbox_earcut import triangulate_float32 as earcut +import numpy as np from scipy.spatial.transform import Rotation from tqdm import tqdm as ProgressDisplay -from manimlib.constants import RIGHT -from manimlib.constants import DOWN -from manimlib.constants import OUT -from manimlib.constants import PI -from manimlib.constants import TAU +from manimlib.constants import DOWN, OUT, RIGHT +from manimlib.constants import PI, TAU from manimlib.utils.iterables import adjacent_pairs from manimlib.utils.simple_functions import clip +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Callable, Iterable, Sequence + + import numpy.typing as npt + def cross(v1: np.ndarray, v2: np.ndarray) -> list[np.ndarray]: return [ diff --git a/manimlib/utils/tex_file_writing.py b/manimlib/utils/tex_file_writing.py index 1ca7ed39..5b57be64 100644 --- a/manimlib/utils/tex_file_writing.py +++ b/manimlib/utils/tex_file_writing.py @@ -1,12 +1,12 @@ -import sys -import os -import hashlib from contextlib import contextmanager +import hashlib +import os +import sys -from manimlib.utils.directories import get_tex_dir -from manimlib.config import get_manim_dir from manimlib.config import get_custom_config +from manimlib.config import get_manim_dir from manimlib.logger import log +from manimlib.utils.directories import get_tex_dir SAVED_TEX_CONFIG = {} diff --git a/manimlib/window.py b/manimlib/window.py index 0d9d3a47..f124c117 100644 --- a/manimlib/window.py +++ b/manimlib/window.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np + import moderngl_window as mglw from moderngl_window.context.pyglet.window import Window as PygletWindow from moderngl_window.timers.clock import Timer From fbebaf0c75b032df6d3e7203c8c1b7038541b484 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Tue, 12 Apr 2022 19:39:19 +0800 Subject: [PATCH 04/64] Sort imports --- manimlib/animation/indication.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/manimlib/animation/indication.py b/manimlib/animation/indication.py index 92c7d013..160464b3 100644 --- a/manimlib/animation/indication.py +++ b/manimlib/animation/indication.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math import numpy as np from manimlib.animation.animation import Animation @@ -220,7 +221,7 @@ class VShowPassingFlash(Animation): if abs(x - mu) > 3 * sigma: return 0 z = (x - mu) / sigma - return np.exp(-0.5 * z * z) + return math.exp(-0.5 * z * z) kernel_array = list(map(gauss_kernel, np.linspace(0, 1, len(anchor_widths)))) scaled_widths = anchor_widths * kernel_array From 93790cde641ff440feedc6f1b348320e05518bc1 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Tue, 12 Apr 2022 20:03:48 +0800 Subject: [PATCH 05/64] Add import annotations statement --- manimlib/utils/images.py | 8 +++++++- manimlib/utils/paths.py | 8 +++++++- manimlib/utils/rate_functions.py | 8 +++++++- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/manimlib/utils/images.py b/manimlib/utils/images.py index 35a573a0..17f9628c 100644 --- a/manimlib/utils/images.py +++ b/manimlib/utils/images.py @@ -1,11 +1,17 @@ +from __future__ import annotations + import numpy as np from PIL import Image -from typing import Iterable from manimlib.utils.directories import get_raster_image_dir from manimlib.utils.directories import get_vector_image_dir from manimlib.utils.file_ops import find_file +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Iterable + def get_full_raster_image_path(image_file_name: str) -> str: return find_file( diff --git a/manimlib/utils/paths.py b/manimlib/utils/paths.py index 6effed15..67192e45 100644 --- a/manimlib/utils/paths.py +++ b/manimlib/utils/paths.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import math -from typing import Callable import numpy as np @@ -8,6 +9,11 @@ from manimlib.utils.bezier import interpolate from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import rotation_matrix_transpose +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Callable + STRAIGHT_PATH_THRESHOLD = 0.01 diff --git a/manimlib/utils/rate_functions.py b/manimlib/utils/rate_functions.py index cb33734b..45646760 100644 --- a/manimlib/utils/rate_functions.py +++ b/manimlib/utils/rate_functions.py @@ -1,8 +1,14 @@ +from __future__ import annotations + import numpy as np -from typing import Callable from manimlib.utils.bezier import bezier +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Callable + def linear(t: float) -> float: return t From 296ab84b46853735aeeb5ae770ecd1c7f541a648 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Tue, 12 Apr 2022 20:21:25 +0800 Subject: [PATCH 06/64] Adjust annotation --- manimlib/mobject/mobject.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index d1e5e6eb..959e4ddb 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -645,7 +645,7 @@ class Mobject(object): def scale( self, - scale_factor: float | Iterable[float], + scale_factor: float | Sequence[float], min_scale_factor: float = 1e-8, about_point: np.ndarray | None = None, about_edge: np.ndarray = ORIGIN From b11ce7ff7c722c2092ffe8532e986ffd61faf782 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Tue, 12 Apr 2022 20:22:13 +0800 Subject: [PATCH 07/64] Adjust annotation --- manimlib/mobject/mobject.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 959e4ddb..43a20c9c 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -645,7 +645,7 @@ class Mobject(object): def scale( self, - scale_factor: float | Sequence[float], + scale_factor: float | npt.ArrayLike, min_scale_factor: float = 1e-8, about_point: np.ndarray | None = None, about_edge: np.ndarray = ORIGIN From 42444d090e2577b043e3901394f387b736c65ac2 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Tue, 12 Apr 2022 21:09:25 +0800 Subject: [PATCH 08/64] Add missing import --- manimlib/mobject/geometry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manimlib/mobject/geometry.py b/manimlib/mobject/geometry.py index 8be3624c..5a9958a9 100644 --- a/manimlib/mobject/geometry.py +++ b/manimlib/mobject/geometry.py @@ -5,7 +5,7 @@ import numbers import numpy as np -from manimlib.constants import DOWN, LEFT, ORIGIN, OUT, RIGHT, UP +from manimlib.constants import DL, DOWN, DR, LEFT, ORIGIN, OUT, RIGHT, UL, UP, UR from manimlib.constants import GREY_A, RED, WHITE from manimlib.constants import MED_SMALL_BUFF from manimlib.constants import PI, TAU From bff9f74b04ecde39fde17934fe8b0af8229de2c3 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Tue, 12 Apr 2022 23:19:10 +0800 Subject: [PATCH 09/64] Prevent from passing an empty string --- manimlib/mobject/svg/mtex_mobject.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 9f35f620..a219cc19 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -33,7 +33,7 @@ class MTex(LabelledString): def __init__(self, tex_string: str, **kwargs): # Prevent from passing an empty string. - if not tex_string: + if not tex_string.strip(): tex_string = "\\\\" self.tex_string = tex_string super().__init__(tex_string, **kwargs) From 0c1e5b337b36768f5be2ae1a0ecda80684d7a9ff Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Wed, 13 Apr 2022 22:51:55 +0800 Subject: [PATCH 10/64] Support passing in complete environment tags --- manimlib/mobject/svg/labelled_string.py | 2 ++ manimlib/mobject/svg/mtex_mobject.py | 11 ++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 1cd2b2a1..3f45c38a 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -373,6 +373,8 @@ class LabelledString(_StringSVG, ABC): return [] def check_overlapping(self) -> None: + if len(self.label_span_list) >= 16777216: + raise ValueError("Cannot label that many substrings") for span_0, span_1 in it.product(self.label_span_list, repeat=2): if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: continue diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index a219cc19..70128b1f 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -266,11 +266,12 @@ class MTex(LabelledString): result = self.get_replaced_substr(self.full_span, span_repl_dict) if self.tex_environment: - result = "\n".join([ - f"\\begin{{{self.tex_environment}}}", - result, - f"\\end{{{self.tex_environment}}}" - ]) + if isinstance(self.tex_environment, str): + prefix = f"\\begin{{{self.tex_environment}}}" + suffix = f"\\end{{{self.tex_environment}}}" + else: + prefix, suffix = self.tex_environment + result = "\n".join([prefix, result, suffix]) if self.alignment: result = "\n".join([self.alignment, result]) if use_plain_file: From eec6b01a72f0a9ec4e4960c259b59e3b191b7640 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Thu, 14 Apr 2022 21:07:31 +0800 Subject: [PATCH 11/64] Refactor labelled_string.py --- manimlib/mobject/svg/labelled_string.py | 100 ++++++++++-------------- manimlib/mobject/svg/mtex_mobject.py | 18 +---- manimlib/mobject/svg/text_mobject.py | 62 +++++++-------- 3 files changed, 71 insertions(+), 109 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 3f45c38a..765d96cb 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -4,10 +4,9 @@ from abc import ABC, abstractmethod import itertools as it import re -from manimlib.constants import BLACK, WHITE +from manimlib.constants import WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup -from manimlib.utils.color import color_to_int_rgb from manimlib.utils.color import color_to_rgb from manimlib.utils.color import rgb_to_hex from manimlib.utils.config_ops import digest_config @@ -25,7 +24,10 @@ if TYPE_CHECKING: Span = tuple[int, int] -class _StringSVG(SVGMobject): +class LabelledString(SVGMobject, ABC): + """ + An abstract base class for `MTex` and `MarkupText` + """ CONFIG = { "height": None, "stroke_width": 0, @@ -34,16 +36,6 @@ class _StringSVG(SVGMobject): "should_subdivide_sharp_curves": True, "should_remove_null_curves": True, }, - } - - -class LabelledString(_StringSVG, ABC): - """ - An abstract base class for `MTex` and `MarkupText` - """ - CONFIG = { - "base_color": WHITE, - "use_plain_file": False, "isolate": [], } @@ -51,14 +43,11 @@ class LabelledString(_StringSVG, ABC): self.string = string digest_config(self, kwargs) - # Convert `base_color` to hex code. - self.base_color = rgb_to_hex(color_to_rgb( - self.base_color \ - or self.svg_default.get("color", None) \ - or self.svg_default.get("fill_color", None) \ + self.base_color_int = self.color_to_int( + self.svg_default.get("fill_color") \ + or self.svg_default.get("color") \ or WHITE - )) - self.svg_default["fill_color"] = BLACK + ) self.pre_parse() self.parse() @@ -66,7 +55,7 @@ class LabelledString(_StringSVG, ABC): self.post_parse() def get_file_path(self) -> str: - return self.get_file_path_(use_plain_file=False) + return self.get_file_path_(use_plain_file=True) def get_file_path_(self, use_plain_file: bool) -> str: content = self.get_content(use_plain_file) @@ -79,22 +68,34 @@ class LabelledString(_StringSVG, ABC): def generate_mobject(self) -> None: super().generate_mobject() - submob_labels = [ - self.color_to_label(submob.get_fill_color()) - for submob in self.submobjects - ] - if self.use_plain_file or self.has_predefined_local_colors: - file_path = self.get_file_path_(use_plain_file=True) - plain_svg = _StringSVG( - file_path, - svg_default=self.svg_default, - path_string_config=self.path_string_config - ) - self.set_submobjects(plain_svg.submobjects) + if self.label_span_list: + file_path = self.get_file_path_(use_plain_file=False) + labelled_svg = SVGMobject(file_path) + submob_color_ints = [ + self.color_to_int(submob.get_fill_color()) + for submob in labelled_svg.submobjects + ] else: - self.set_fill(self.base_color) - for submob, label in zip(self.submobjects, submob_labels): - submob.label = label + submob_color_ints = [0] * len(self.submobjects) + + if len(self.submobjects) != len(submob_color_ints): + raise ValueError( + "Cannot align submobjects of the labelled svg " + "to the original svg" + ) + + unrecognized_color_ints = remove_list_redundancies(sorted(filter( + lambda color_int: color_int > len(self.label_span_list), + submob_color_ints + ))) + if unrecognized_color_ints: + raise ValueError( + "Unrecognized color label(s) detected: " + f"{','.join(map(self.int_to_hex, unrecognized_color_ints))}" + ) + + for submob, color_int in zip(self.submobjects, submob_color_ints): + submob.label = color_int - 1 def pre_parse(self) -> None: self.string_len = len(self.string) @@ -283,31 +284,14 @@ class LabelledString(_StringSVG, ABC): return index @staticmethod - def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int: - r, g, b = rgb_tuple - rg = r * 256 + g - return rg * 256 + b - - @staticmethod - def int_to_rgb(rgb_int: int) -> tuple[int, int, int]: - rg, b = divmod(rgb_int, 256) - r, g = divmod(rg, 256) - return r, g, b + def color_to_int(color: ManimColor) -> int: + hex_code = rgb_to_hex(color_to_rgb(color)) + return int(hex_code[1:], 16) @staticmethod def int_to_hex(rgb_int: int) -> str: return "#{:06x}".format(rgb_int).upper() - @staticmethod - def hex_to_int(rgb_hex: str) -> int: - return int(rgb_hex[1:], 16) - - @staticmethod - def color_to_label(color: ManimColor) -> int: - rgb_tuple = color_to_int_rgb(color) - rgb = LabelledString.rgb_to_int(rgb_tuple) - return rgb - 1 - # Parsing @abstractmethod @@ -387,10 +371,6 @@ class LabelledString(_StringSVG, ABC): def get_content(self, use_plain_file: bool) -> str: return "" - @abstractmethod - def has_predefined_local_colors(self) -> bool: - return False - # Post-parsing def get_labelled_submobjects(self) -> list[VMobject]: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 70128b1f..8c7d4843 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -47,8 +47,6 @@ class MTex(LabelledString): self.__class__.__name__, self.svg_default, self.path_string_config, - self.base_color, - self.use_plain_file, self.isolate, self.tex_string, self.alignment, @@ -78,13 +76,9 @@ class MTex(LabelledString): @staticmethod def get_color_command_str(rgb_int: int) -> str: - rgb_tuple = MTex.int_to_rgb(rgb_int) - return "".join([ - "\\color[RGB]", - "{", - ",".join(map(str, rgb_tuple)), - "}" - ]) + rg, b = divmod(rgb_int, 256) + r, g = divmod(rg, 256) + return f"\\color[RGB]{{{r}, {g}, {b}}}" # Pre-parsing @@ -276,15 +270,11 @@ class MTex(LabelledString): result = "\n".join([self.alignment, result]) if use_plain_file: result = "\n".join([ - self.get_color_command_str(self.hex_to_int(self.base_color)), + self.get_color_command_str(self.base_color_int), result ]) return result - @property - def has_predefined_local_colors(self) -> bool: - return bool(self.command_repl_items) - # Post-parsing def get_cleaned_substr(self, span: Span) -> str: diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index d2bf0b6f..fdcdb5fe 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -27,7 +27,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from colour import Color - from typing import Any, Union + from typing import Union from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VGroup @@ -43,7 +43,7 @@ DEFAULT_LINE_SPACING_SCALE = 0.6 # See https://docs.gtk.org/Pango/pango_markup.html # A tag containing two aliases will cause warning, # so only use the first key of each group of aliases. -SPAN_ATTR_KEY_ALIAS_LIST = ( +MARKUP_KEY_ALIAS_LIST = ( ("font", "font_desc"), ("font_family", "face"), ("font_size", "size"), @@ -77,19 +77,14 @@ SPAN_ATTR_KEY_ALIAS_LIST = ( ("text_transform",), ("segment",), ) -COLOR_RELATED_KEYS = ( +MARKUP_COLOR_KEYS = ( "foreground", - "background", - "underline_color", - "overline_color", - "strikethrough_color" + "background", + "underline_color", + "overline_color", + "strikethrough_color" ) -SPAN_ATTR_KEY_CONVERSION = { - key: key_alias_list[0] - for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST - for key in key_alias_list -} -TAG_TO_ATTR_DICT = { +MARKUP_TAG_CONVERSION_DICT = { "b": {"font_weight": "bold"}, "big": {"font_size": "larger"}, "i": {"font_style": "italic"}, @@ -166,8 +161,6 @@ class MarkupText(LabelledString): self.__class__.__name__, self.svg_default, self.path_string_config, - self.base_color, - self.use_plain_file, self.isolate, self.text, self.is_markup, @@ -258,7 +251,7 @@ class MarkupText(LabelledString): @staticmethod def merge_attr_dicts( - attr_dict_items: list[Span, str, Any] + attr_dict_items: list[tuple[Span, dict[str, str]]] ) -> list[tuple[Span, dict[str, str]]]: index_seq = [0] attr_dict_list = [{}] @@ -344,12 +337,12 @@ class MarkupText(LabelledString): attr_pattern, begin_match_obj.group(3) ) } - elif tag_name in TAG_TO_ATTR_DICT.keys(): + elif tag_name in MARKUP_TAG_CONVERSION_DICT.keys(): if begin_match_obj.group(3): raise ValueError( f"Attributes shan't exist in tag '{tag_name}'" ) - attr_dict = TAG_TO_ATTR_DICT[tag_name].copy() + attr_dict = MARKUP_TAG_CONVERSION_DICT[tag_name].copy() else: raise ValueError(f"Unknown tag: '{tag_name}'") @@ -358,13 +351,13 @@ class MarkupText(LabelledString): ) return result - def get_global_dict_from_config(self) -> dict[str, Any]: + def get_global_dict_from_config(self) -> dict[str, str]: result = { - "line_height": ( + "line_height": str(( (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 - ) * 0.6, + ) * 0.6), "font_family": self.font, - "font_size": self.font_size * 1024, + "font_size": str(self.font_size * 1024), "font_style": self.slant, "font_weight": self.weight } @@ -382,7 +375,7 @@ class MarkupText(LabelledString): def get_local_dicts_from_config( self - ) -> list[Span, dict[str, Any]]: + ) -> list[Span, dict[str, str]]: return [ (span, {key: val}) for t2x_dict, key in ( @@ -405,9 +398,14 @@ class MarkupText(LabelledString): *self.local_dicts_from_markup, *self.local_dicts_from_config ] + key_conversion_dict = { + key: key_alias_list[0] + for key_alias_list in MARKUP_KEY_ALIAS_LIST + for key in key_alias_list + } return [ (span, { - SPAN_ATTR_KEY_CONVERSION[key.lower()]: str(val) + key_conversion_dict[key.lower()]: val for key, val in attr_dict.items() }) for span, attr_dict in attr_dict_items @@ -442,7 +440,7 @@ class MarkupText(LabelledString): return [] def get_internal_specified_spans(self) -> list[Span]: - return [span for span, _ in self.local_dicts_from_markup] + return [] def get_external_specified_spans(self) -> list[Span]: return [span for span, _ in self.local_dicts_from_config] @@ -468,7 +466,9 @@ class MarkupText(LabelledString): def get_content(self, use_plain_file: bool) -> str: if use_plain_file: attr_dict_items = [ - (self.full_span, {"foreground": self.base_color}), + (self.full_span, { + "foreground": self.int_to_hex(self.base_color_int) + }), *self.predefined_attr_dicts, *[ (span, {}) @@ -480,7 +480,7 @@ class MarkupText(LabelledString): (self.full_span, {"foreground": BLACK}), *[ (span, { - key: BLACK if key in COLOR_RELATED_KEYS else val + key: BLACK if key in MARKUP_COLOR_KEYS else val for key, val in attr_dict.items() }) for span, attr_dict in self.predefined_attr_dicts @@ -502,14 +502,6 @@ class MarkupText(LabelledString): ) return self.get_replaced_substr(self.full_span, span_repl_dict) - @property - def has_predefined_local_colors(self) -> bool: - return any([ - key in COLOR_RELATED_KEYS - for _, attr_dict in self.predefined_attr_dicts - for key in attr_dict.keys() - ]) - # Method alias def get_parts_by_text(self, text: str, **kwargs) -> VGroup: From 4c324767bdaa5120d17958f1711b725bbe07f0f1 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Fri, 15 Apr 2022 00:55:02 +0800 Subject: [PATCH 12/64] Recover Mobject.scale method --- manimlib/mobject/mobject.py | 6 +++++- manimlib/mobject/numbers.py | 2 ++ manimlib/mobject/svg/labelled_string.py | 11 ++++++++--- manimlib/mobject/svg/text_mobject.py | 17 ++++++++++------- 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 43a20c9c..a9116d03 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -7,6 +7,7 @@ import itertools as it from functools import wraps import moderngl +import numbers import numpy as np from manimlib.constants import DEFAULT_MOBJECT_TO_EDGE_BUFFER @@ -659,7 +660,10 @@ class Mobject(object): Otherwise, if about_point is given a value, scaling is done with respect to that point. """ - scale_factor = np.resize(scale_factor, self.dim).clip(min=min_scale_factor) + if isinstance(scale_factor, numbers.Number): + scale_factor = max(scale_factor, min_scale_factor) + else: + scale_factor = np.array(scale_factor).clip(min=min_scale_factor) self.apply_points_function( lambda points: scale_factor * points, about_point=about_point, diff --git a/manimlib/mobject/numbers.py b/manimlib/mobject/numbers.py index a22df66d..6d88c647 100644 --- a/manimlib/mobject/numbers.py +++ b/manimlib/mobject/numbers.py @@ -1,5 +1,7 @@ from __future__ import annotations +import numpy as np + from manimlib.constants import DOWN, LEFT, RIGHT, UP from manimlib.mobject.svg.tex_mobject import SingleStringTex from manimlib.mobject.svg.text_mobject import Text diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 765d96cb..3d0efb4f 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -311,6 +311,12 @@ class LabelledString(SVGMobject, ABC): self.extra_entity_spans )) + def index_not_in_entity_spans(self, index: int) -> bool: + return not any([ + entity_span[0] < index < entity_span[1] + for entity_span in self.entity_spans + ]) + @abstractmethod def get_extra_ignored_spans(self) -> list[int]: return [] @@ -343,10 +349,9 @@ class LabelledString(SVGMobject, ABC): self.find_substrs(self.isolate) )) shrinked_spans = list(filter( - lambda span: span[0] < span[1] and not any([ - entity_span[0] < index < entity_span[1] + lambda span: span[0] < span[1] and all([ + self.index_not_in_entity_spans(index) for index in span - for entity_span in self.entity_spans ]), [self.shrink_span(span) for span in spans] )) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index fdcdb5fe..1e3f073a 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -452,11 +452,7 @@ class MarkupText(LabelledString): self.specified_spans )))) breakup_indices = sorted(filter( - lambda index: not any([ - span[0] < index < span[1] - for span in self.entity_spans - ]), - breakup_indices + self.index_not_in_entity_spans, breakup_indices )) return list(filter( lambda span: self.get_substr(span).strip(), @@ -464,12 +460,19 @@ class MarkupText(LabelledString): )) def get_content(self, use_plain_file: bool) -> str: + filtered_attr_dicts = list(filter( + lambda item: all([ + self.index_not_in_entity_spans(index) + for index in item[0] + ]), + self.predefined_attr_dicts + )) if use_plain_file: attr_dict_items = [ (self.full_span, { "foreground": self.int_to_hex(self.base_color_int) }), - *self.predefined_attr_dicts, + *filtered_attr_dicts, *[ (span, {}) for span in self.label_span_list @@ -483,7 +486,7 @@ class MarkupText(LabelledString): key: BLACK if key in MARKUP_COLOR_KEYS else val for key, val in attr_dict.items() }) - for span, attr_dict in self.predefined_attr_dicts + for span, attr_dict in filtered_attr_dicts ], *[ (span, {"foreground": self.int_to_hex(label + 1)}) From 020bd8727113f7c11580c3ae3381671924cea8bb Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Fri, 15 Apr 2022 13:27:50 +0800 Subject: [PATCH 13/64] Add back base_color attribute --- manimlib/mobject/svg/labelled_string.py | 29 ++++++++++++++----------- manimlib/mobject/svg/mtex_mobject.py | 1 + manimlib/mobject/svg/text_mobject.py | 5 +++-- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 3d0efb4f..d20b3c11 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -36,18 +36,15 @@ class LabelledString(SVGMobject, ABC): "should_subdivide_sharp_curves": True, "should_remove_null_curves": True, }, + "base_color": WHITE, "isolate": [], } def __init__(self, string: str, **kwargs): self.string = string digest_config(self, kwargs) - - self.base_color_int = self.color_to_int( - self.svg_default.get("fill_color") \ - or self.svg_default.get("color") \ - or WHITE - ) + if self.base_color is None: + self.base_color = WHITE self.pre_parse() self.parse() @@ -68,7 +65,8 @@ class LabelledString(SVGMobject, ABC): def generate_mobject(self) -> None: super().generate_mobject() - if self.label_span_list: + num_labels = len(self.label_span_list) + if num_labels: file_path = self.get_file_path_(use_plain_file=False) labelled_svg = SVGMobject(file_path) submob_color_ints = [ @@ -85,7 +83,7 @@ class LabelledString(SVGMobject, ABC): ) unrecognized_color_ints = remove_list_redundancies(sorted(filter( - lambda color_int: color_int > len(self.label_span_list), + lambda color_int: color_int > num_labels, submob_color_ints ))) if unrecognized_color_ints: @@ -100,6 +98,7 @@ class LabelledString(SVGMobject, ABC): def pre_parse(self) -> None: self.string_len = len(self.string) self.full_span = (0, self.string_len) + self.base_color_int = self.color_to_int(self.base_color) def parse(self) -> None: self.command_repl_items = self.get_command_repl_items() @@ -311,7 +310,7 @@ class LabelledString(SVGMobject, ABC): self.extra_entity_spans )) - def index_not_in_entity_spans(self, index: int) -> bool: + def is_splittable_index(self, index: int) -> bool: return not any([ entity_span[0] < index < entity_span[1] for entity_span in self.entity_spans @@ -348,12 +347,16 @@ class LabelledString(SVGMobject, ABC): self.external_specified_spans, self.find_substrs(self.isolate) )) - shrinked_spans = list(filter( - lambda span: span[0] < span[1] and all([ - self.index_not_in_entity_spans(index) + filtered_spans = list(filter( + lambda span: all([ + self.is_splittable_index(index) for index in span ]), - [self.shrink_span(span) for span in spans] + spans + )) + shrinked_spans = list(filter( + lambda span: span[0] < span[1], + [self.shrink_span(span) for span in filtered_spans] )) return remove_list_redundancies(shrinked_spans) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 8c7d4843..91d7675b 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -47,6 +47,7 @@ class MTex(LabelledString): self.__class__.__name__, self.svg_default, self.path_string_config, + self.base_color, self.isolate, self.tex_string, self.alignment, diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 1e3f073a..2c076551 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -161,6 +161,7 @@ class MarkupText(LabelledString): self.__class__.__name__, self.svg_default, self.path_string_config, + self.base_color, self.isolate, self.text, self.is_markup, @@ -452,7 +453,7 @@ class MarkupText(LabelledString): self.specified_spans )))) breakup_indices = sorted(filter( - self.index_not_in_entity_spans, breakup_indices + self.is_splittable_index, breakup_indices )) return list(filter( lambda span: self.get_substr(span).strip(), @@ -462,7 +463,7 @@ class MarkupText(LabelledString): def get_content(self, use_plain_file: bool) -> str: filtered_attr_dicts = list(filter( lambda item: all([ - self.index_not_in_entity_spans(index) + self.is_splittable_index(index) for index in item[0] ]), self.predefined_attr_dicts From 09952756ce817067f15b950addc3cec75c85d93c Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Fri, 15 Apr 2022 13:48:24 +0800 Subject: [PATCH 14/64] Support hashing Color type in hash_obj --- manimlib/utils/iterables.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py index 05d92597..ecbdbc1a 100644 --- a/manimlib/utils/iterables.py +++ b/manimlib/utils/iterables.py @@ -1,5 +1,7 @@ from __future__ import annotations +from colour import Color + import numpy as np from typing import TYPE_CHECKING @@ -79,7 +81,7 @@ def batch_by_property( return batch_prop_pairs -def listify(obj) -> list: +def listify(obj: object) -> list: if isinstance(obj, str): return [obj] try: @@ -139,4 +141,7 @@ def hash_obj(obj: object) -> int: if isinstance(obj, (set, tuple, list)): return hash(tuple(hash_obj(e) for e in obj)) + if isinstance(obj, Color): + return hash(obj.get_rgb()) + return hash(obj) From 0a810bb4f1e92ce77d20555207d597ba82ba877c Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Fri, 15 Apr 2022 22:54:06 +0800 Subject: [PATCH 15/64] Refactor LabelledString --- .../animation/transform_matching_parts.py | 16 ++- manimlib/mobject/svg/labelled_string.py | 106 +++++++++--------- manimlib/mobject/svg/mtex_mobject.py | 60 ++++++---- manimlib/mobject/svg/text_mobject.py | 82 ++++++-------- 4 files changed, 128 insertions(+), 136 deletions(-) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index e3c65a49..325c2fb0 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -213,14 +213,12 @@ class TransformMatchingStrings(AnimationGroup): ], key=len, reverse=True) def get_parts_from_keys(mobject, keys): - if isinstance(keys, str): + if not isinstance(keys, list): keys = [keys] - result = VGroup() - for key in keys: - if not isinstance(key, str): - raise TypeError(key) - result.add(*mobject.get_parts_by_string(key)) - return result + return VGroup(*it.chain(*[ + mobject.select_parts(key) + for key in keys + ])) add_anims_from( ReplacementTransform, get_parts_from_keys, @@ -228,7 +226,7 @@ class TransformMatchingStrings(AnimationGroup): ) add_anims_from( FadeTransformPieces, - LabelledString.get_parts_by_string, + LabelledString.select_parts, get_common_substrs( source.specified_substrs, target.specified_substrs @@ -236,7 +234,7 @@ class TransformMatchingStrings(AnimationGroup): ) add_anims_from( FadeTransformPieces, - LabelledString.get_parts_by_group_substr, + LabelledString.select_parts_by_group_substr, get_common_substrs( source.group_substrs, target.group_substrs diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index d20b3c11..9c927e0e 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -22,6 +22,11 @@ if TYPE_CHECKING: ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ] class LabelledString(SVGMobject, ABC): @@ -52,10 +57,10 @@ class LabelledString(SVGMobject, ABC): self.post_parse() def get_file_path(self) -> str: - return self.get_file_path_(use_plain_file=True) + return self.get_file_path_(is_labelled=False) - def get_file_path_(self, use_plain_file: bool) -> str: - content = self.get_content(use_plain_file) + def get_file_path_(self, is_labelled: bool) -> str: + content = self.get_content(is_labelled) return self.get_file_path_by_content(content) @abstractmethod @@ -67,7 +72,7 @@ class LabelledString(SVGMobject, ABC): num_labels = len(self.label_span_list) if num_labels: - file_path = self.get_file_path_(use_plain_file=False) + file_path = self.get_file_path_(is_labelled=True) labelled_svg = SVGMobject(file_path) submob_color_ints = [ self.color_to_int(submob.get_fill_color()) @@ -132,37 +137,31 @@ class LabelledString(SVGMobject, ABC): def get_substr(self, span: Span) -> str: return self.string[slice(*span)] - def finditer( - self, pattern: str, flags: int = 0, **kwargs - ) -> Iterable[re.Match]: - return re.compile(pattern, flags).finditer(self.string, **kwargs) - - def search( - self, pattern: str, flags: int = 0, **kwargs - ) -> re.Match | None: - return re.compile(pattern, flags).search(self.string, **kwargs) - - def match( - self, pattern: str, flags: int = 0, **kwargs - ) -> re.Match | None: - return re.compile(pattern, flags).match(self.string, **kwargs) - - def find_spans(self, pattern: str, **kwargs) -> list[Span]: + def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]: + if isinstance(pattern, str): + pattern = re.compile(pattern) return [ match_obj.span() - for match_obj in self.finditer(pattern, **kwargs) + for match_obj in pattern.finditer(self.string, **kwargs) ] - def find_substr(self, substr: str, **kwargs) -> list[Span]: - if not substr: - return [] - return self.find_spans(re.escape(substr), **kwargs) - - def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]: - return list(it.chain(*[ - self.find_substr(substr, **kwargs) - for substr in remove_list_redundancies(substrs) - ])) + def find_spans_by_selector(self, selector: Selector) -> list[Span]: + if isinstance(selector, str): + result = self.find_spans(re.escape(selector)) + elif isinstance(selector, re.Pattern): + result = self.find_spans(selector) + else: + span = tuple([ + ( + min(index, self.string_len) + if index >= 0 + else max(index + self.string_len, 0) + ) + if index is not None else default_index + for index, default_index in zip(selector, self.full_span) + ]) + result = [span] + return list(filter(lambda span: span[0] < span[1], result)) @staticmethod def get_neighbouring_pairs(iterable: list) -> list[tuple]: @@ -345,7 +344,10 @@ class LabelledString(SVGMobject, ABC): spans = list(it.chain( self.internal_specified_spans, self.external_specified_spans, - self.find_substrs(self.isolate) + *[ + self.find_spans_by_selector(selector) + for selector in self.isolate + ] )) filtered_spans = list(filter( lambda span: all([ @@ -376,7 +378,7 @@ class LabelledString(SVGMobject, ABC): ) @abstractmethod - def get_content(self, use_plain_file: bool) -> str: + def get_content(self, is_labelled: bool) -> str: return "" # Post-parsing @@ -441,7 +443,7 @@ class LabelledString(SVGMobject, ABC): def get_submob_groups(self) -> list[VGroup]: return [submob_group for _, submob_group in self.group_items] - def get_parts_by_group_substr(self, substr: str) -> VGroup: + def select_parts_by_group_substr(self, substr: str) -> VGroup: return VGroup(*[ group for group_substr, group in self.group_items @@ -488,7 +490,7 @@ class LabelledString(SVGMobject, ABC): span_begin = next_begin return result - def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: + def select_part_by_span(self, custom_span: Span, **kwargs) -> VGroup: labels = [ label for label, span in enumerate(self.label_span_list) if any([ @@ -503,34 +505,28 @@ class LabelledString(SVGMobject, ABC): if label in labels ]) - def get_parts_by_string( - self, substr: str, - case_sensitive: bool = True, regex: bool = False, **kwargs - ) -> VGroup: - flags = 0 - if not case_sensitive: - flags |= re.I - pattern = substr if regex else re.escape(substr) + def select_parts(self, selector: Selector, **kwargs) -> VGroup: return VGroup(*[ - self.get_part_by_custom_span(span, **kwargs) - for span in self.find_spans(pattern, flags=flags) - if span[0] < span[1] + self.select_part_by_span(span, **kwargs) + for span in self.find_spans_by_selector(selector) ]) - def get_part_by_string( - self, substr: str, index: int = 0, **kwargs + def select_part( + self, selector: Selector, index: int = 0, **kwargs ) -> VMobject: - return self.get_parts_by_string(substr, **kwargs)[index] + return self.select_parts(selector, **kwargs)[index] - def set_color_by_string(self, substr: str, color: ManimColor, **kwargs): - self.get_parts_by_string(substr, **kwargs).set_color(color) + def set_parts_color( + self, selector: Selector, color: ManimColor, **kwargs + ): + self.select_parts(selector, **kwargs).set_color(color) return self - def set_color_by_string_to_color_map( - self, string_to_color_map: dict[str, ManimColor], **kwargs + def set_parts_color_by_dict( + self, color_map: dict[Selector, ManimColor], **kwargs ): - for substr, color in string_to_color_map.items(): - self.set_color_by_string(substr, color, **kwargs) + for selector, color in color_map.items(): + self.set_parts_color(selector, color, **kwargs) return self def get_string(self) -> str: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 91d7675b..dad69df5 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,6 +1,7 @@ from __future__ import annotations import itertools as it +import re from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.utils.tex_file_writing import display_during_execution @@ -18,6 +19,11 @@ if TYPE_CHECKING: ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ] SCALE_FACTOR_PER_FONT_POINT = 0.001 @@ -61,7 +67,7 @@ class MTex(LabelledString): tex_config["text_to_replace"], content ) - with display_during_execution(f"Writing \"{self.tex_string}\""): + with display_during_execution(f"Writing \"{self.string}\""): file_path = tex_to_svg_file(full_tex) return file_path @@ -93,7 +99,10 @@ class MTex(LabelledString): def get_unescaped_char_spans(self, chars: str): return sorted(filter( lambda span: span[0] - 1 not in self.backslash_indices, - self.find_substrs(list(chars)) + list(it.chain(*[ + self.find_spans(re.escape(char)) + for char in chars + ])) )) def get_brace_index_pairs(self) -> list[Span]: @@ -121,8 +130,8 @@ class MTex(LabelledString): result = [] brace_indices_dict = dict(self.brace_index_pairs) script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+" - for script_char_span in self.script_char_spans: - span_begin = self.match(r"\s*", pos=script_char_span[1]).end() + for char_span in self.script_char_spans: + span_begin = self.find_spans(r"\s*", pos=char_span[1])[0][1] if span_begin in brace_indices_dict.keys(): span_end = brace_indices_dict[span_begin] + 1 else: @@ -143,10 +152,10 @@ class MTex(LabelledString): def get_script_spans(self) -> list[Span]: return [ ( - self.search(r"\s*$", endpos=script_char_span[0]).start(), + self.find_spans(r"\s*", endpos=char_span[0])[-1][0], script_content_span[1] ) - for script_char_span, script_content_span in zip( + for char_span, script_content_span in zip( self.script_char_spans, self.script_content_spans ) ] @@ -174,7 +183,7 @@ class MTex(LabelledString): ")", r"(?![a-zA-Z])" ]) - for match_obj in self.finditer(pattern): + for match_obj in re.finditer(pattern, self.string): span_begin, cmd_end = match_obj.span() if span_begin not in backslash_indices: continue @@ -192,7 +201,7 @@ class MTex(LabelledString): def get_extra_entity_spans(self) -> list[Span]: return [ - self.match(r"\\([a-zA-Z]+|.)", pos=index).span() + self.find_spans(r"\\([a-zA-Z]+|.?)", pos=index)[0] for index in self.backslash_indices ] @@ -223,7 +232,10 @@ class MTex(LabelledString): return result def get_external_specified_spans(self) -> list[Span]: - return self.find_substrs(list(self.tex_to_color_map.keys())) + return list(it.chain(*[ + self.find_spans_by_selector(selector) + for selector in self.tex_to_color_map.keys() + ])) def get_label_span_list(self) -> list[Span]: result = self.script_content_spans.copy() @@ -237,10 +249,8 @@ class MTex(LabelledString): result.append(shrinked_span) return result - def get_content(self, use_plain_file: bool) -> str: - if use_plain_file: - span_repl_dict = {} - else: + def get_content(self, is_labelled: bool) -> str: + if is_labelled: extended_label_span_list = [ span if span in self.script_content_spans @@ -258,6 +268,8 @@ class MTex(LabelledString): inserted_string_pairs, self.command_repl_items ) + else: + span_repl_dict = {} result = self.get_replaced_substr(self.full_span, span_repl_dict) if self.tex_environment: @@ -269,7 +281,7 @@ class MTex(LabelledString): result = "\n".join([prefix, result, suffix]) if self.alignment: result = "\n".join([self.alignment, result]) - if use_plain_file: + if not is_labelled: result = "\n".join([ self.get_color_command_str(self.base_color_int), result @@ -303,21 +315,21 @@ class MTex(LabelledString): # Method alias - def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup: - return self.get_parts_by_string(tex, **kwargs) + def get_parts_by_tex(self, selector: Selector, **kwargs) -> VGroup: + return self.select_parts(selector, **kwargs) - def get_part_by_tex(self, tex: str, **kwargs) -> VMobject: - return self.get_part_by_string(tex, **kwargs) + def get_part_by_tex(self, selector: Selector, **kwargs) -> VMobject: + return self.select_part(selector, **kwargs) - def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs): - return self.set_color_by_string(tex, color, **kwargs) + def set_color_by_tex( + self, selector: Selector, color: ManimColor, **kwargs + ): + return self.set_parts_color(selector, color, **kwargs) def set_color_by_tex_to_color_map( - self, tex_to_color_map: dict[str, ManimColor], **kwargs + self, color_map: dict[Selector, ManimColor], **kwargs ): - return self.set_color_by_string_to_color_map( - tex_to_color_map, **kwargs - ) + return self.set_parts_color_by_dict(color_map, **kwargs) def get_tex(self) -> str: return self.get_string() diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 2c076551..f79bb79a 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -34,6 +34,11 @@ if TYPE_CHECKING: ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ] TEXT_MOB_SCALE_FACTOR = 0.0076 @@ -283,25 +288,6 @@ class MarkupText(LabelledString): MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] )) - def find_substr_or_span( - self, substr_or_span: str | tuple[int | None, int | None] - ) -> list[Span]: - if isinstance(substr_or_span, str): - return self.find_substr(substr_or_span) - - span = tuple([ - ( - min(index, self.string_len) - if index >= 0 - else max(index + self.string_len, 0) - ) - if index is not None else default_index - for index, default_index in zip(substr_or_span, self.full_span) - ]) - if span[0] >= span[1]: - return [] - return [span] - # Pre-parsing def get_tag_items_from_markup( @@ -314,7 +300,7 @@ class MarkupText(LabelledString): attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2""" begin_match_obj_stack = [] match_obj_pairs = [] - for match_obj in self.finditer(tag_pattern): + for match_obj in re.finditer(tag_pattern, self.string): if not match_obj.group(1): begin_match_obj_stack.append(match_obj) else: @@ -385,12 +371,12 @@ class MarkupText(LabelledString): (self.t2s, "font_style"), (self.t2w, "font_weight") ) - for substr_or_span, val in t2x_dict.items() - for span in self.find_substr_or_span(substr_or_span) + for selector, val in t2x_dict.items() + for span in self.find_spans_by_selector(selector) ] + [ (span, local_config) - for substr_or_span, local_config in self.local_configs.items() - for span in self.find_substr_or_span(substr_or_span) + for selector, local_config in self.local_configs.items() + for span in self.find_spans_by_selector(selector) ] def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]: @@ -428,7 +414,7 @@ class MarkupText(LabelledString): (">", ">"), ("<", "<") ) - for span in self.find_substr(char) + for span in self.find_spans(re.escape(char)) ] return result @@ -460,7 +446,7 @@ class MarkupText(LabelledString): self.get_neighbouring_pairs(breakup_indices) )) - def get_content(self, use_plain_file: bool) -> str: + def get_content(self, is_labelled: bool) -> str: filtered_attr_dicts = list(filter( lambda item: all([ self.is_splittable_index(index) @@ -468,18 +454,7 @@ class MarkupText(LabelledString): ]), self.predefined_attr_dicts )) - if use_plain_file: - attr_dict_items = [ - (self.full_span, { - "foreground": self.int_to_hex(self.base_color_int) - }), - *filtered_attr_dicts, - *[ - (span, {}) - for span in self.label_span_list - ] - ] - else: + if is_labelled: attr_dict_items = [ (self.full_span, {"foreground": BLACK}), *[ @@ -494,6 +469,17 @@ class MarkupText(LabelledString): for label, span in enumerate(self.label_span_list) ] ] + else: + attr_dict_items = [ + (self.full_span, { + "foreground": self.int_to_hex(self.base_color_int) + }), + *filtered_attr_dicts, + *[ + (span, {}) + for span in self.label_span_list + ] + ] inserted_string_pairs = [ (span, ( f"", @@ -508,21 +494,21 @@ class MarkupText(LabelledString): # Method alias - def get_parts_by_text(self, text: str, **kwargs) -> VGroup: - return self.get_parts_by_string(text, **kwargs) + def get_parts_by_text(self, selector: Selector, **kwargs) -> VGroup: + return self.select_parts(selector, **kwargs) - def get_part_by_text(self, text: str, **kwargs) -> VMobject: - return self.get_part_by_string(text, **kwargs) + def get_part_by_text(self, selector: Selector, **kwargs) -> VMobject: + return self.select_part(selector, **kwargs) - def set_color_by_text(self, text: str, color: ManimColor, **kwargs): - return self.set_color_by_string(text, color, **kwargs) + def set_color_by_text( + self, selector: Selector, color: ManimColor, **kwargs + ): + return self.set_parts_color(selector, color, **kwargs) def set_color_by_text_to_color_map( - self, text_to_color_map: dict[str, ManimColor], **kwargs + self, color_map: dict[Selector, ManimColor], **kwargs ): - return self.set_color_by_string_to_color_map( - text_to_color_map, **kwargs - ) + return self.set_parts_color_by_dict(color_map, **kwargs) def get_text(self) -> str: return self.get_string() From 14dfd776dcab16ffb0fd77331c6571f91ae73748 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Fri, 15 Apr 2022 23:26:41 +0800 Subject: [PATCH 16/64] Refactor LabelledString --- manimlib/mobject/svg/mtex_mobject.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index dad69df5..10d3989e 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -135,8 +135,8 @@ class MTex(LabelledString): if span_begin in brace_indices_dict.keys(): span_end = brace_indices_dict[span_begin] + 1 else: - match_obj = self.match(script_pattern, pos=span_begin) - if not match_obj: + span = self.find_spans(script_pattern, pos=span_begin)[0] + if span[0] != span_begin: script_name = { "_": "subscript", "^": "superscript" @@ -145,7 +145,7 @@ class MTex(LabelledString): f"Unclear {script_name} detected while parsing. " "Please use braces to clarify" ) - span_end = match_obj.end() + span_end = span[1] result.append((span_begin, span_end)) return result From dbefc3b25631c7296b41c145e93d27b2e6c4b445 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Fri, 15 Apr 2022 23:30:42 +0800 Subject: [PATCH 17/64] Refactor LabelledString --- manimlib/mobject/svg/mtex_mobject.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 10d3989e..f0586904 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -135,17 +135,17 @@ class MTex(LabelledString): if span_begin in brace_indices_dict.keys(): span_end = brace_indices_dict[span_begin] + 1 else: - span = self.find_spans(script_pattern, pos=span_begin)[0] - if span[0] != span_begin: + spans = self.find_spans(script_pattern, pos=span_begin) + if not spans or spans[0][0] != span_begin: script_name = { "_": "subscript", "^": "superscript" - }[script_char] + }[self.get_string(char_span)] raise ValueError( f"Unclear {script_name} detected while parsing. " "Please use braces to clarify" ) - span_end = span[1] + span_end = spans[0][1] result.append((span_begin, span_end)) return result From a1e77b0ce2ae949ab84f2c4aee278a32b25c5f20 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Fri, 15 Apr 2022 23:58:06 +0800 Subject: [PATCH 18/64] Refactor LabelledString --- manimlib/mobject/svg/mtex_mobject.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index f0586904..46ea13e8 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -142,7 +142,8 @@ class MTex(LabelledString): "^": "superscript" }[self.get_string(char_span)] raise ValueError( - f"Unclear {script_name} detected while parsing. " + f"Unclear {script_name} detected while parsing " + f"(position {char_span[0]}). " "Please use braces to clarify" ) span_end = spans[0][1] From 4690edec3e57e70af4c9c0f0108bc8741e64fd95 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sat, 16 Apr 2022 00:24:55 +0800 Subject: [PATCH 19/64] Refactor LabelledString --- manimlib/mobject/svg/labelled_string.py | 8 ++++++-- manimlib/mobject/svg/mtex_mobject.py | 12 ++++++------ manimlib/mobject/svg/text_mobject.py | 2 +- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 9c927e0e..a0d1343e 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -103,6 +103,7 @@ class LabelledString(SVGMobject, ABC): def pre_parse(self) -> None: self.string_len = len(self.string) self.full_span = (0, self.string_len) + self.space_spans = self.find_spans(r"\s+") self.base_color_int = self.color_to_int(self.base_color) def parse(self) -> None: @@ -137,14 +138,17 @@ class LabelledString(SVGMobject, ABC): def get_substr(self, span: Span) -> str: return self.string[slice(*span)] - def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]: + def find_spans(self, pattern: str | re.Pattern) -> list[Span]: if isinstance(pattern, str): pattern = re.compile(pattern) return [ match_obj.span() - for match_obj in pattern.finditer(self.string, **kwargs) + for match_obj in pattern.finditer(self.string) ] + def match_at(self, pattern: str, pos: int) -> re.Pattern | None: + return re.compile(pattern).match(self.string, pos=pos) + def find_spans_by_selector(self, selector: Selector) -> list[Span]: if isinstance(selector, str): result = self.find_spans(re.escape(selector)) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 46ea13e8..878ba94e 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -131,12 +131,12 @@ class MTex(LabelledString): brace_indices_dict = dict(self.brace_index_pairs) script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+" for char_span in self.script_char_spans: - span_begin = self.find_spans(r"\s*", pos=char_span[1])[0][1] + span_begin = self.rslide(char_span[1], self.space_spans) if span_begin in brace_indices_dict.keys(): span_end = brace_indices_dict[span_begin] + 1 else: - spans = self.find_spans(script_pattern, pos=span_begin) - if not spans or spans[0][0] != span_begin: + match_obj = self.match_at(script_pattern, span_begin) + if match_obj is None: script_name = { "_": "subscript", "^": "superscript" @@ -146,14 +146,14 @@ class MTex(LabelledString): f"(position {char_span[0]}). " "Please use braces to clarify" ) - span_end = spans[0][1] + span_end = match_obj.end() result.append((span_begin, span_end)) return result def get_script_spans(self) -> list[Span]: return [ ( - self.find_spans(r"\s*", endpos=char_span[0])[-1][0], + self.lslide(char_span[0], self.space_spans), script_content_span[1] ) for char_span, script_content_span in zip( @@ -202,7 +202,7 @@ class MTex(LabelledString): def get_extra_entity_spans(self) -> list[Span]: return [ - self.find_spans(r"\\([a-zA-Z]+|.?)", pos=index)[0] + self.match_at(r"\\([a-zA-Z]+|.?)", index).span() for index in self.backslash_indices ] diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index f79bb79a..d02e7bc2 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -434,8 +434,8 @@ class MarkupText(LabelledString): def get_label_span_list(self) -> list[Span]: breakup_indices = remove_list_redundancies(list(it.chain(*it.chain( - self.find_spans(r"\s+"), self.find_spans(r"\b"), + self.space_spans, self.specified_spans )))) breakup_indices = sorted(filter( From ac4620483c18e9c395c475e941cd9fbd1ba64f0a Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sat, 16 Apr 2022 12:53:43 +0800 Subject: [PATCH 20/64] Support flexible selector types --- .../animation/transform_matching_parts.py | 16 +---- manimlib/mobject/svg/labelled_string.py | 59 ++++++++++++------- manimlib/mobject/svg/mtex_mobject.py | 3 +- manimlib/mobject/svg/text_mobject.py | 3 +- 4 files changed, 46 insertions(+), 35 deletions(-) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index 325c2fb0..be8af624 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -212,29 +212,19 @@ class TransformMatchingStrings(AnimationGroup): if substr and substr in substrs_from_target ], key=len, reverse=True) - def get_parts_from_keys(mobject, keys): - if not isinstance(keys, list): - keys = [keys] - return VGroup(*it.chain(*[ - mobject.select_parts(key) - for key in keys - ])) - add_anims_from( - ReplacementTransform, get_parts_from_keys, + ReplacementTransform, LabelledString.select_parts, self.key_map.keys(), self.key_map.values() ) add_anims_from( - FadeTransformPieces, - LabelledString.select_parts, + FadeTransformPieces, LabelledString.select_parts, get_common_substrs( source.specified_substrs, target.specified_substrs ) ) add_anims_from( - FadeTransformPieces, - LabelledString.select_parts_by_group_substr, + FadeTransformPieces, LabelledString.select_parts_by_group_substr, get_common_substrs( source.group_substrs, target.group_substrs diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index a0d1343e..3624a0c1 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -22,11 +22,12 @@ if TYPE_CHECKING: ManimColor = Union[str, Color] Span = tuple[int, int] - Selector = Union[ + SingleSelector = Union[ str, re.Pattern, tuple[Union[int, None], Union[int, None]] ] + Selector = Union[SingleSelector, Iterable[SingleSelector]] class LabelledString(SVGMobject, ABC): @@ -149,22 +150,43 @@ class LabelledString(SVGMobject, ABC): def match_at(self, pattern: str, pos: int) -> re.Pattern | None: return re.compile(pattern).match(self.string, pos=pos) - def find_spans_by_selector(self, selector: Selector) -> list[Span]: + @staticmethod + def is_single_selector(selector: Selector) -> bool: if isinstance(selector, str): - result = self.find_spans(re.escape(selector)) - elif isinstance(selector, re.Pattern): - result = self.find_spans(selector) - else: - span = tuple([ - ( - min(index, self.string_len) - if index >= 0 - else max(index + self.string_len, 0) - ) - if index is not None else default_index - for index, default_index in zip(selector, self.full_span) - ]) - result = [span] + return True + if isinstance(selector, re.Pattern): + return True + if isinstance(selector, tuple): + if len(selector) == 2 and all([ + isinstance(index, int) or index is None + for index in selector + ]): + return True + return False + + def find_spans_by_selector(self, selector: Selector) -> list[Span]: + if self.is_single_selector(selector): + selector = (selector,) + result = [] + for sel in selector: + if not self.is_single_selector(sel): + raise TypeError(f"Invalid selector: '{sel}'") + if isinstance(sel, str): + spans = self.find_spans(re.escape(sel)) + elif isinstance(sel, re.Pattern): + spans = self.find_spans(sel) + else: + span = tuple([ + ( + min(index, self.string_len) + if index >= 0 + else max(index + self.string_len, 0) + ) + if index is not None else default_index + for index, default_index in zip(sel, self.full_span) + ]) + spans = [span] + result.extend(spans) return list(filter(lambda span: span[0] < span[1], result)) @staticmethod @@ -348,10 +370,7 @@ class LabelledString(SVGMobject, ABC): spans = list(it.chain( self.internal_specified_spans, self.external_specified_spans, - *[ - self.find_spans_by_selector(selector) - for selector in self.isolate - ] + self.find_spans_by_selector(self.isolate) )) filtered_spans = list(filter( lambda span: all([ diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 878ba94e..eba37668 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -19,11 +19,12 @@ if TYPE_CHECKING: ManimColor = Union[str, Color] Span = tuple[int, int] - Selector = Union[ + SingleSelector = Union[ str, re.Pattern, tuple[Union[int, None], Union[int, None]] ] + Selector = Union[SingleSelector, Iterable[SingleSelector]] SCALE_FACTOR_PER_FONT_POINT = 0.001 diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index d02e7bc2..a5e0c321 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -34,11 +34,12 @@ if TYPE_CHECKING: ManimColor = Union[str, Color] Span = tuple[int, int] - Selector = Union[ + SingleSelector = Union[ str, re.Pattern, tuple[Union[int, None], Union[int, None]] ] + Selector = Union[SingleSelector, Iterable[SingleSelector]] TEXT_MOB_SCALE_FACTOR = 0.0076 From bc18894040a2952894327c3f5245c44f6fc41bb6 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sat, 16 Apr 2022 13:59:42 +0800 Subject: [PATCH 21/64] Remove empty results in LabelledString.select_parts --- manimlib/mobject/svg/labelled_string.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 3624a0c1..dcc175a9 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -187,7 +187,13 @@ class LabelledString(SVGMobject, ABC): ]) spans = [span] result.extend(spans) - return list(filter(lambda span: span[0] < span[1], result)) + return sorted( + filter( + lambda span: span[0] < span[1], + remove_list_redundancies(result) + ), + key=lambda span: (span[0], -span[1]) + ) @staticmethod def get_neighbouring_pairs(iterable: list) -> list[tuple]: @@ -529,10 +535,13 @@ class LabelledString(SVGMobject, ABC): ]) def select_parts(self, selector: Selector, **kwargs) -> VGroup: - return VGroup(*[ - self.select_part_by_span(span, **kwargs) - for span in self.find_spans_by_selector(selector) - ]) + return VGroup(*filter( + lambda part: part.submobjects, + [ + self.select_part_by_span(span, **kwargs) + for span in self.find_spans_by_selector(selector) + ] + )) def select_part( self, selector: Selector, index: int = 0, **kwargs From 654da85cf622a442279e26941b2785199e46eea4 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sat, 16 Apr 2022 14:09:59 +0800 Subject: [PATCH 22/64] Adjust typings --- manimlib/animation/transform_matching_parts.py | 10 +++++++--- manimlib/mobject/svg/labelled_string.py | 8 +------- manimlib/mobject/svg/mtex_mobject.py | 3 +-- manimlib/mobject/svg/text_mobject.py | 3 +-- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index be8af624..ec41fdcb 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -168,13 +168,17 @@ class TransformMatchingStrings(AnimationGroup): assert isinstance(source, LabelledString) assert isinstance(target, LabelledString) anims = [] - source_indices = list(range(len(source.labelled_submobjects))) - target_indices = list(range(len(target.labelled_submobjects))) + source_indices = list(range(len(source.labelled_submobject_items))) + target_indices = list(range(len(target.labelled_submobject_items))) def get_indices_lists(mobject, parts): + labelled_submobjects = [ + submob + for _, submob in mobject.labelled_submobject_items + ] return [ [ - mobject.labelled_submobjects.index(submob) + labelled_submobjects.index(submob) for submob in part ] for part in parts diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index dcc175a9..5ecc9d7c 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -18,8 +18,6 @@ if TYPE_CHECKING: from colour import Color from typing import Iterable, Union - from manimlib.mobject.types.vectorized_mobject import VMobject - ManimColor = Union[str, Color] Span = tuple[int, int] SingleSelector = Union[ @@ -125,7 +123,6 @@ class LabelledString(SVGMobject, ABC): (submob.label, submob) for submob in self.submobjects ] - self.labelled_submobjects = self.get_labelled_submobjects() self.specified_substrs = self.get_specified_substrs() self.group_items = self.get_group_items() self.group_substrs = self.get_group_substrs() @@ -412,9 +409,6 @@ class LabelledString(SVGMobject, ABC): # Post-parsing - def get_labelled_submobjects(self) -> list[VMobject]: - return [submob for _, submob in self.labelled_submobject_items] - def get_cleaned_substr(self, span: Span) -> str: span_repl_dict = dict.fromkeys(self.command_spans, "") return self.get_replaced_substr(span, span_repl_dict) @@ -545,7 +539,7 @@ class LabelledString(SVGMobject, ABC): def select_part( self, selector: Selector, index: int = 0, **kwargs - ) -> VMobject: + ) -> VGroup: return self.select_parts(selector, **kwargs)[index] def set_parts_color( diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index eba37668..ce2f4cd1 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -15,7 +15,6 @@ if TYPE_CHECKING: from typing import Union from manimlib.mobject.types.vectorized_mobject import VGroup - from manimlib.mobject.types.vectorized_mobject import VMobject ManimColor = Union[str, Color] Span = tuple[int, int] @@ -320,7 +319,7 @@ class MTex(LabelledString): def get_parts_by_tex(self, selector: Selector, **kwargs) -> VGroup: return self.select_parts(selector, **kwargs) - def get_part_by_tex(self, selector: Selector, **kwargs) -> VMobject: + def get_part_by_tex(self, selector: Selector, **kwargs) -> VGroup: return self.select_part(selector, **kwargs) def set_color_by_tex( diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index a5e0c321..5b88e241 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -29,7 +29,6 @@ if TYPE_CHECKING: from colour import Color from typing import Union - from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VGroup ManimColor = Union[str, Color] @@ -498,7 +497,7 @@ class MarkupText(LabelledString): def get_parts_by_text(self, selector: Selector, **kwargs) -> VGroup: return self.select_parts(selector, **kwargs) - def get_part_by_text(self, selector: Selector, **kwargs) -> VMobject: + def get_part_by_text(self, selector: Selector, **kwargs) -> VGroup: return self.select_part(selector, **kwargs) def set_color_by_text( From 0406ef70bbdd01bf686492ad8e1d86782871241d Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sat, 16 Apr 2022 14:37:28 +0800 Subject: [PATCH 23/64] Adjust typings for sounds.py and tex_file_writing.py --- manimlib/utils/sounds.py | 4 +++- manimlib/utils/tex_file_writing.py | 16 +++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/manimlib/utils/sounds.py b/manimlib/utils/sounds.py index dc37ff00..f34ba435 100644 --- a/manimlib/utils/sounds.py +++ b/manimlib/utils/sounds.py @@ -1,8 +1,10 @@ +from __future__ import annotations + from manimlib.utils.directories import get_sound_dir from manimlib.utils.file_ops import find_file -def get_full_sound_file_path(sound_file_name) -> str: +def get_full_sound_file_path(sound_file_name: str) -> str: return find_file( sound_file_name, directories=[get_sound_dir()], diff --git a/manimlib/utils/tex_file_writing.py b/manimlib/utils/tex_file_writing.py index 5b57be64..3e4591a9 100644 --- a/manimlib/utils/tex_file_writing.py +++ b/manimlib/utils/tex_file_writing.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import contextmanager import hashlib import os @@ -12,7 +14,7 @@ from manimlib.utils.directories import get_tex_dir SAVED_TEX_CONFIG = {} -def get_tex_config(): +def get_tex_config() -> dict[str, str]: """ Returns a dict which should look something like this: { @@ -37,13 +39,13 @@ def get_tex_config(): return SAVED_TEX_CONFIG -def tex_hash(tex_file_content): +def tex_hash(tex_file_content: str) -> int: # Truncating at 16 bytes for cleanliness hasher = hashlib.sha256(tex_file_content.encode()) return hasher.hexdigest()[:16] -def tex_to_svg_file(tex_file_content): +def tex_to_svg_file(tex_file_content: str) -> str: svg_file = os.path.join( get_tex_dir(), tex_hash(tex_file_content) + ".svg" ) @@ -53,7 +55,7 @@ def tex_to_svg_file(tex_file_content): return svg_file -def tex_to_svg(tex_file_content, svg_file): +def tex_to_svg(tex_file_content: str, svg_file: str) -> str: tex_file = svg_file.replace(".svg", ".tex") with open(tex_file, "w", encoding="utf-8") as outfile: outfile.write(tex_file_content) @@ -69,7 +71,7 @@ def tex_to_svg(tex_file_content, svg_file): return svg_file -def tex_to_dvi(tex_file): +def tex_to_dvi(tex_file: str) -> str: tex_config = get_tex_config() program = tex_config["executable"] file_type = tex_config["intermediate_filetype"] @@ -96,7 +98,7 @@ def tex_to_dvi(tex_file): return result -def dvi_to_svg(dvi_file, regen_if_exists=False): +def dvi_to_svg(dvi_file: str) -> str: """ Converts a dvi, which potentially has multiple slides, into a directory full of enumerated pngs corresponding with these slides. @@ -123,7 +125,7 @@ def dvi_to_svg(dvi_file, regen_if_exists=False): # TODO, perhaps this should live elsewhere @contextmanager -def display_during_execution(message): +def display_during_execution(message: str) -> None: # Only show top line to_print = message.split("\n")[0] max_characters = os.get_terminal_size().columns - 1 From b387bc0c95e15a882c875c6a72d6669f5af26116 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sat, 16 Apr 2022 15:29:23 +0800 Subject: [PATCH 24/64] Adjust typings --- manimlib/mobject/svg/labelled_string.py | 10 +++++++--- manimlib/mobject/svg/mtex_mobject.py | 10 +++++++--- manimlib/mobject/svg/text_mobject.py | 10 +++++++--- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 5ecc9d7c..2e81221f 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -20,12 +20,16 @@ if TYPE_CHECKING: ManimColor = Union[str, Color] Span = tuple[int, int] - SingleSelector = Union[ + Selector = Union[ str, re.Pattern, - tuple[Union[int, None], Union[int, None]] + tuple[Union[int, None], Union[int, None]], + Iterable[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ] ] - Selector = Union[SingleSelector, Iterable[SingleSelector]] class LabelledString(SVGMobject, ABC): diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index ce2f4cd1..17f96b20 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -18,12 +18,16 @@ if TYPE_CHECKING: ManimColor = Union[str, Color] Span = tuple[int, int] - SingleSelector = Union[ + Selector = Union[ str, re.Pattern, - tuple[Union[int, None], Union[int, None]] + tuple[Union[int, None], Union[int, None]], + Iterable[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ] ] - Selector = Union[SingleSelector, Iterable[SingleSelector]] SCALE_FACTOR_PER_FONT_POINT = 0.001 diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 5b88e241..22978c95 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -33,12 +33,16 @@ if TYPE_CHECKING: ManimColor = Union[str, Color] Span = tuple[int, int] - SingleSelector = Union[ + Selector = Union[ str, re.Pattern, - tuple[Union[int, None], Union[int, None]] + tuple[Union[int, None], Union[int, None]], + Iterable[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ] ] - Selector = Union[SingleSelector, Iterable[SingleSelector]] TEXT_MOB_SCALE_FACTOR = 0.0076 From 58127e7511ae7d1e2a8e0e4dcffb14de42d2dc05 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sat, 16 Apr 2022 15:34:32 +0800 Subject: [PATCH 25/64] import Iterables --- manimlib/mobject/svg/mtex_mobject.py | 2 +- manimlib/mobject/svg/text_mobject.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 17f96b20..4e115ad9 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from colour import Color - from typing import Union + from typing import Iterable, Union from manimlib.mobject.types.vectorized_mobject import VGroup diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 22978c95..5e540724 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -27,7 +27,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from colour import Color - from typing import Union + from typing import Iterable, Union from manimlib.mobject.types.vectorized_mobject import VGroup From 4f5173b633c2ef925e21623ad6466f255437ff37 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sat, 16 Apr 2022 15:45:55 +0800 Subject: [PATCH 26/64] Adjust typing --- manimlib/mobject/svg/labelled_string.py | 4 ++-- manimlib/mobject/svg/mtex_mobject.py | 4 ++-- manimlib/mobject/svg/text_mobject.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 2e81221f..1bf5cad5 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -24,11 +24,11 @@ if TYPE_CHECKING: str, re.Pattern, tuple[Union[int, None], Union[int, None]], - Iterable[ + Iterable[Union[ str, re.Pattern, tuple[Union[int, None], Union[int, None]] - ] + ]] ] diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 4e115ad9..8d7f8a90 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -22,11 +22,11 @@ if TYPE_CHECKING: str, re.Pattern, tuple[Union[int, None], Union[int, None]], - Iterable[ + Iterable[Union[ str, re.Pattern, tuple[Union[int, None], Union[int, None]] - ] + ]] ] diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 5e540724..f4000b3d 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -37,11 +37,11 @@ if TYPE_CHECKING: str, re.Pattern, tuple[Union[int, None], Union[int, None]], - Iterable[ + Iterable[Union[ str, re.Pattern, tuple[Union[int, None], Union[int, None]] - ] + ]] ] From e9298c5faf52ce5072277088018250605d29c743 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sat, 16 Apr 2022 16:31:55 +0800 Subject: [PATCH 27/64] Remove sorting key --- manimlib/mobject/svg/labelled_string.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 1bf5cad5..6be55ef4 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -188,13 +188,10 @@ class LabelledString(SVGMobject, ABC): ]) spans = [span] result.extend(spans) - return sorted( - filter( - lambda span: span[0] < span[1], - remove_list_redundancies(result) - ), - key=lambda span: (span[0], -span[1]) - ) + return sorted(filter( + lambda span: span[0] < span[1], + remove_list_redundancies(result) + )) @staticmethod def get_neighbouring_pairs(iterable: list) -> list[tuple]: From 0e0244128cc04e33838ccde3778ddb472b060acc Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sun, 17 Apr 2022 13:57:03 +0800 Subject: [PATCH 28/64] Refactor LabelledString and relevant classes --- manimlib/animation/creation.py | 5 +- .../animation/transform_matching_parts.py | 108 ++++--- manimlib/mobject/svg/labelled_string.py | 252 ++++++---------- manimlib/mobject/svg/mtex_mobject.py | 248 ++++++++------- manimlib/mobject/svg/text_mobject.py | 285 ++++++++---------- 5 files changed, 412 insertions(+), 486 deletions(-) diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 5882df8f..6ad6a9bd 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -6,6 +6,7 @@ import numpy as np from manimlib.animation.animation import Animation from manimlib.mobject.svg.labelled_string import LabelledString +from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.bezier import integer_interpolate from manimlib.utils.config_ops import digest_config @@ -212,7 +213,9 @@ class AddTextWordByWord(ShowIncreasingSubsets): def __init__(self, string_mobject, **kwargs): assert isinstance(string_mobject, LabelledString) - grouped_mobject = string_mobject.submob_groups + grouped_mobject = VGroup(*[ + part for _, part in string_mobject.get_group_part_items() + ]) digest_config(self, kwargs) if self.run_time is None: self.run_time = self.time_per_word * len(grouped_mobject) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index ec41fdcb..e84f1d9d 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -168,40 +168,36 @@ class TransformMatchingStrings(AnimationGroup): assert isinstance(source, LabelledString) assert isinstance(target, LabelledString) anims = [] - source_indices = list(range(len(source.labelled_submobject_items))) - target_indices = list(range(len(target.labelled_submobject_items))) - def get_indices_lists(mobject, parts): - labelled_submobjects = [ - submob - for _, submob in mobject.labelled_submobject_items - ] - return [ + source_submobs = [ + submob for _, submob in source.labelled_submobject_items + ] + target_submobs = [ + submob for _, submob in target.labelled_submobject_items + ] + source_indices = list(range(len(source_submobs))) + target_indices = list(range(len(target_submobs))) + + def get_filtered_indices_lists(parts, submobs, rest_indices): + return list(filter( + lambda indices_list: all([ + index in rest_indices + for index in indices_list + ]), [ - labelled_submobjects.index(submob) - for submob in part + [submobs.index(submob) for submob in part] + for part in parts ] - for part in parts - ] + )) - def add_anims_from(anim_class, func, source_args, target_args=None): - if target_args is None: - target_args = source_args.copy() - for source_arg, target_arg in zip(source_args, target_args): - source_parts = func(source, source_arg) - target_parts = func(target, target_arg) - source_indices_lists = list(filter( - lambda indices_list: all([ - index in source_indices - for index in indices_list - ]), get_indices_lists(source, source_parts) - )) - target_indices_lists = list(filter( - lambda indices_list: all([ - index in target_indices - for index in indices_list - ]), get_indices_lists(target, target_parts) - )) + def add_anims(anim_class, parts_pairs): + for source_parts, target_parts in parts_pairs: + source_indices_lists = get_filtered_indices_lists( + source_parts, source_submobs, source_indices + ) + target_indices_lists = get_filtered_indices_lists( + target_parts, target_submobs, target_indices + ) if not source_indices_lists or not target_indices_lists: continue anims.append(anim_class(source_parts, target_parts, **kwargs)) @@ -210,29 +206,45 @@ class TransformMatchingStrings(AnimationGroup): for index in it.chain(*target_indices_lists): target_indices.remove(index) - def get_common_substrs(substrs_from_source, substrs_from_target): - return sorted([ - substr for substr in substrs_from_source - if substr and substr in substrs_from_target - ], key=len, reverse=True) + def get_substr_to_parts_map(part_items): + result = {} + for substr, part in part_items: + if substr not in result: + result[substr] = [] + result[substr].append(part) + return result - add_anims_from( - ReplacementTransform, LabelledString.select_parts, - self.key_map.keys(), self.key_map.values() + def add_anims_from(anim_class, func): + source_substr_to_parts_map = get_substr_to_parts_map(func(source)) + target_substr_to_parts_map = get_substr_to_parts_map(func(target)) + add_anims( + anim_class, + [ + ( + VGroup(*source_substr_to_parts_map[substr]), + VGroup(*target_substr_to_parts_map[substr]) + ) + for substr in sorted([ + s for s in source_substr_to_parts_map.keys() + if s and s in target_substr_to_parts_map.keys() + ], key=len, reverse=True) + ] + ) + + add_anims( + ReplacementTransform, + [ + (source.select_parts(k), target.select_parts(v)) + for k, v in self.key_map.items() + ] ) add_anims_from( - FadeTransformPieces, LabelledString.select_parts, - get_common_substrs( - source.specified_substrs, - target.specified_substrs - ) + FadeTransformPieces, + LabelledString.get_specified_part_items ) add_anims_from( - FadeTransformPieces, LabelledString.select_parts_by_group_substr, - get_common_substrs( - source.group_substrs, - target.group_substrs - ) + FadeTransformPieces, + LabelledString.get_group_part_items ) rest_source = VGroup(*[source[index] for index in source_indices]) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 6be55ef4..7b3f457f 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -85,6 +85,7 @@ class LabelledString(SVGMobject, ABC): submob_color_ints = [0] * len(self.submobjects) if len(self.submobjects) != len(submob_color_ints): + print(len(self.submobjects), len(submob_color_ints)) raise ValueError( "Cannot align submobjects of the labelled svg " "to the original svg" @@ -106,31 +107,25 @@ class LabelledString(SVGMobject, ABC): def pre_parse(self) -> None: self.string_len = len(self.string) self.full_span = (0, self.string_len) - self.space_spans = self.find_spans(r"\s+") self.base_color_int = self.color_to_int(self.base_color) def parse(self) -> None: - self.command_repl_items = self.get_command_repl_items() - self.command_spans = self.get_command_spans() - self.extra_entity_spans = self.get_extra_entity_spans() + self.skippable_indices = self.get_skippable_indices() self.entity_spans = self.get_entity_spans() - self.extra_ignored_spans = self.get_extra_ignored_spans() - self.skipped_spans = self.get_skipped_spans() - self.internal_specified_spans = self.get_internal_specified_spans() - self.external_specified_spans = self.get_external_specified_spans() + self.bracket_spans = self.get_bracket_spans() + self.extra_isolated_items = self.get_extra_isolated_items() + self.specified_items = self.get_specified_items() self.specified_spans = self.get_specified_spans() - self.label_span_list = self.get_label_span_list() self.check_overlapping() + self.label_span_list = self.get_label_span_list() + if len(self.label_span_list) >= 16777216: + raise ValueError("Cannot handle that many substrings") def post_parse(self) -> None: self.labelled_submobject_items = [ (submob.label, submob) for submob in self.submobjects ] - self.specified_substrs = self.get_specified_substrs() - self.group_items = self.get_group_items() - self.group_substrs = self.get_group_substrs() - self.submob_groups = self.get_submob_groups() def copy(self): return self.deepcopy() @@ -140,16 +135,21 @@ class LabelledString(SVGMobject, ABC): def get_substr(self, span: Span) -> str: return self.string[slice(*span)] - def find_spans(self, pattern: str | re.Pattern) -> list[Span]: + def match(self, pattern: str | re.Pattern, **kwargs) -> re.Pattern | None: + if isinstance(pattern, str): + pattern = re.compile(pattern) + return re.compile(pattern).match(self.string, **kwargs) + + def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]: if isinstance(pattern, str): pattern = re.compile(pattern) return [ match_obj.span() - for match_obj in pattern.finditer(self.string) + for match_obj in pattern.finditer(self.string, **kwargs) ] - def match_at(self, pattern: str, pos: int) -> re.Pattern | None: - return re.compile(pattern).match(self.string, pos=pos) + def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: + return [index for index, _ in self.find_spans(pattern, **kwargs)] @staticmethod def is_single_selector(selector: Selector) -> bool: @@ -230,41 +230,24 @@ class LabelledString(SVGMobject, ABC): spans = LabelledString.get_neighbouring_pairs(indices) return list(zip(unique_vals, spans)) - @staticmethod - def find_region_index(seq: list[int], val: int) -> int: - # Returns an integer in `range(-1, len(seq))` satisfying - # `seq[result] <= val < seq[result + 1]`. - # `seq` should be sorted in ascending order. - if not seq or val < seq[0]: - return -1 - result = len(seq) - 1 - while val < seq[result]: - result -= 1 - return result - - @staticmethod - def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int: - sorted_seq = sorted(seq) - index = LabelledString.find_region_index(sorted_seq, val) - return sorted_seq[index + index_shift] - @staticmethod def generate_span_repl_dict( inserted_string_pairs: list[tuple[Span, tuple[str, str]]], - other_repl_items: list[tuple[Span, str]] + repl_items: list[tuple[Span, str]] ) -> dict[Span, str]: - result = dict(other_repl_items) + result = dict(repl_items) if not inserted_string_pairs: return result - indices, _, _, inserted_strings = zip(*sorted([ + indices, _, _, _, inserted_strings = zip(*sorted([ ( - span[flag], + item[0][flag], -flag, - -span[1 - flag], - str_pair[flag] + -item[0][1 - flag], + (1, -1)[flag] * item_index, + item[1][flag] ) - for span, str_pair in inserted_string_pairs + for item_index, item in enumerate(inserted_string_pairs) for flag in range(2) ])) result.update({ @@ -295,22 +278,6 @@ class LabelledString(SVGMobject, ABC): repl_strs.append("") return "".join(it.chain(*zip(pieces, repl_strs))) - @staticmethod - def rslide(index: int, skipped: list[Span]) -> int: - transfer_dict = dict(sorted(skipped)) - while index in transfer_dict.keys(): - index = transfer_dict[index] - return index - - @staticmethod - def lslide(index: int, skipped: list[Span]) -> int: - transfer_dict = dict(sorted([ - skipped_span[::-1] for skipped_span in skipped - ], reverse=True)) - while index in transfer_dict.keys(): - index = transfer_dict[index] - return index - @staticmethod def color_to_int(color: ManimColor) -> int: hex_code = rgb_to_hex(color_to_rgb(color)) @@ -323,80 +290,63 @@ class LabelledString(SVGMobject, ABC): # Parsing @abstractmethod - def get_command_repl_items(self) -> list[tuple[Span, str]]: + def get_skippable_indices(self) -> list[int]: return [] - def get_command_spans(self) -> list[Span]: - return [cmd_span for cmd_span, _ in self.command_repl_items] + @staticmethod + def shrink_span(span: Span, skippable_indices: list[int]) -> Span: + span_begin, span_end = span + while span_begin in skippable_indices: + span_begin += 1 + while span_end - 1 in skippable_indices: + span_end -= 1 + return (span_begin, span_end) @abstractmethod - def get_extra_entity_spans(self) -> list[Span]: - return [] - def get_entity_spans(self) -> list[Span]: - return list(it.chain( - self.command_spans, - self.extra_entity_spans + return [] + + @abstractmethod + def get_bracket_spans(self) -> list[Span]: + return [] + + @abstractmethod + def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]: + return [] + + def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]: + span_items = list(it.chain( + self.extra_isolated_items, + [ + (span, {}) + for span in self.find_spans_by_selector(self.isolate) + ] )) - - def is_splittable_index(self, index: int) -> bool: - return not any([ - entity_span[0] < index < entity_span[1] - for entity_span in self.entity_spans - ]) - - @abstractmethod - def get_extra_ignored_spans(self) -> list[int]: - return [] - - def get_skipped_spans(self) -> list[Span]: - return list(it.chain( - self.find_spans(r"\s"), - self.command_spans, - self.extra_ignored_spans - )) - - def shrink_span(self, span: Span) -> Span: - return ( - self.rslide(span[0], self.skipped_spans), - self.lslide(span[1], self.skipped_spans) - ) - - @abstractmethod - def get_internal_specified_spans(self) -> list[Span]: - return [] - - @abstractmethod - def get_external_specified_spans(self) -> list[Span]: - return [] + result = [] + for span, attr_dict in span_items: + shrinked_span = self.shrink_span(span, self.skippable_indices) + if shrinked_span[0] >= shrinked_span[1]: + continue + if any([ + entity_span[0] < index < entity_span[1] + for index in shrinked_span + for entity_span in self.entity_spans + ]): + continue + result.append((shrinked_span, attr_dict)) + return result def get_specified_spans(self) -> list[Span]: - spans = list(it.chain( - self.internal_specified_spans, - self.external_specified_spans, - self.find_spans_by_selector(self.isolate) - )) - filtered_spans = list(filter( - lambda span: all([ - self.is_splittable_index(index) - for index in span - ]), - spans - )) - shrinked_spans = list(filter( - lambda span: span[0] < span[1], - [self.shrink_span(span) for span in filtered_spans] - )) - return remove_list_redundancies(shrinked_spans) - - @abstractmethod - def get_label_span_list(self) -> list[Span]: - return [] + return remove_list_redundancies([ + span for span, _ in self.specified_items + ]) def check_overlapping(self) -> None: - if len(self.label_span_list) >= 16777216: - raise ValueError("Cannot label that many substrings") - for span_0, span_1 in it.product(self.label_span_list, repeat=2): + spans = remove_list_redundancies(list(it.chain( + self.specified_spans, + self.bracket_spans + ))) + for span_0, span_1 in it.product(spans, repeat=2): if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: continue raise ValueError( @@ -404,23 +354,21 @@ class LabelledString(SVGMobject, ABC): f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'" ) + @abstractmethod + def get_label_span_list(self) -> list[Span]: + return [] + @abstractmethod def get_content(self, is_labelled: bool) -> str: return "" # Post-parsing + @abstractmethod def get_cleaned_substr(self, span: Span) -> str: - span_repl_dict = dict.fromkeys(self.command_spans, "") - return self.get_replaced_substr(span, span_repl_dict) + return "" - def get_specified_substrs(self) -> list[str]: - return remove_list_redundancies([ - self.get_cleaned_substr(span) - for span in self.specified_spans - ]) - - def get_group_items(self) -> list[tuple[str, VGroup]]: + def get_group_part_items(self) -> list[tuple[str, VGroup]]: if not self.labelled_submobject_items: return [] @@ -445,41 +393,33 @@ class LabelledString(SVGMobject, ABC): ordered_spans ) ] - shrinked_spans = [ - self.shrink_span(span) + group_substrs = [ + self.get_cleaned_substr(span) if span[0] < span[1] else "" for span in self.get_complement_spans( interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) ) ] - group_substrs = [ - self.get_cleaned_substr(span) if span[0] < span[1] else "" - for span in shrinked_spans - ] submob_groups = VGroup(*[ VGroup(*labelled_submobjects[slice(*submob_span)]) for submob_span in labelled_submob_spans ]) return list(zip(group_substrs, submob_groups)) - def get_group_substrs(self) -> list[str]: - return [group_substr for group_substr, _ in self.group_items] - - def get_submob_groups(self) -> list[VGroup]: - return [submob_group for _, submob_group in self.group_items] - - def select_parts_by_group_substr(self, substr: str) -> VGroup: - return VGroup(*[ - group - for group_substr, group in self.group_items - if group_substr == substr - ]) + def get_specified_part_items(self) -> list[tuple[str, VGroup]]: + return [ + ( + self.get_substr(span), + self.select_part_by_span(span, substring=False) + ) + for span in self.specified_spans + ] # Selector def find_span_components( self, custom_span: Span, substring: bool = True ) -> list[Span]: - shrinked_span = self.shrink_span(custom_span) + shrinked_span = self.shrink_span(custom_span, self.skippable_indices) if shrinked_span[0] >= shrinked_span[1]: return [] @@ -488,12 +428,12 @@ class LabelledString(SVGMobject, ABC): self.full_span, *self.label_span_list ))) - span_begin = self.take_nearest_value( - indices, shrinked_span[0], 0 - ) - span_end = self.take_nearest_value( - indices, shrinked_span[1] - 1, 1 - ) + span_begin = max(filter( + lambda index: index <= shrinked_span[0], indices + )) + span_end = min(filter( + lambda index: index >= shrinked_span[1], indices + )) else: span_begin, span_end = shrinked_span diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 8d7f8a90..1c3a5f20 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -33,6 +33,15 @@ if TYPE_CHECKING: SCALE_FACTOR_PER_FONT_POINT = 0.001 +TEX_COLOR_COMMANDS_DICT = { + "\\color": (1, False), + "\\textcolor": (1, False), + "\\pagecolor": (1, True), + "\\colorbox": (1, True), + "\\fcolorbox": (2, True), +} + + class MTex(LabelledString): CONFIG = { "font_size": 48, @@ -78,10 +87,12 @@ class MTex(LabelledString): def pre_parse(self) -> None: super().pre_parse() self.backslash_indices = self.get_backslash_indices() - self.brace_index_pairs = self.get_brace_index_pairs() - self.script_char_spans = self.get_script_char_spans() + self.command_spans = self.get_command_spans() + self.brace_spans = self.get_brace_spans() + self.script_char_indices = self.get_script_char_indices() self.script_content_spans = self.get_script_content_spans() self.script_spans = self.get_script_spans() + self.command_repl_items = self.get_command_repl_items() # Toolkits @@ -95,61 +106,61 @@ class MTex(LabelledString): def get_backslash_indices(self) -> list[int]: # The latter of `\\` doesn't count. - return list(it.chain(*[ - range(span[0], span[1], 2) - for span in self.find_spans(r"\\+") - ])) + return self.find_indices(r"\\.") - def get_unescaped_char_spans(self, chars: str): - return sorted(filter( - lambda span: span[0] - 1 not in self.backslash_indices, - list(it.chain(*[ - self.find_spans(re.escape(char)) - for char in chars - ])) + def get_command_spans(self) -> list[Span]: + return [ + self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span() + for index in self.backslash_indices + ] + + def get_unescaped_char_indices(self, char: str) -> list[int]: + return list(filter( + lambda index: index - 1 not in self.backslash_indices, + self.find_indices(re.escape(char)) )) - def get_brace_index_pairs(self) -> list[Span]: - left_brace_indices = [] - right_brace_indices = [] - left_brace_indices_stack = [] - for span in self.get_unescaped_char_spans("{}"): - index = span[0] - if self.get_substr(span) == "{": - left_brace_indices_stack.append(index) + def get_brace_spans(self) -> list[Span]: + span_begins = [] + span_ends = [] + span_begins_stack = [] + char_items = sorted([ + (index, char) + for char in "{}" + for index in self.get_unescaped_char_indices(char) + ]) + for index, char in char_items: + if char == "{": + span_begins_stack.append(index) else: - if not left_brace_indices_stack: + if not span_begins_stack: raise ValueError("Missing '{' inserted") - left_brace_index = left_brace_indices_stack.pop() - left_brace_indices.append(left_brace_index) - right_brace_indices.append(index) - if left_brace_indices_stack: + span_begins.append(span_begins_stack.pop()) + span_ends.append(index + 1) + if span_begins_stack: raise ValueError("Missing '}' inserted") - return list(zip(left_brace_indices, right_brace_indices)) + return list(zip(span_begins, span_ends)) - def get_script_char_spans(self) -> list[int]: - return self.get_unescaped_char_spans("_^") + def get_script_char_indices(self) -> list[int]: + return list(it.chain(*[ + self.get_unescaped_char_indices(char) + for char in "_^" + ])) def get_script_content_spans(self) -> list[Span]: result = [] - brace_indices_dict = dict(self.brace_index_pairs) - script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+" - for char_span in self.script_char_spans: - span_begin = self.rslide(char_span[1], self.space_spans) - if span_begin in brace_indices_dict.keys(): - span_end = brace_indices_dict[span_begin] + 1 + script_entity_dict = dict(it.chain( + self.brace_spans, + self.command_spans + )) + for index in self.script_char_indices: + span_begin = self.match(r"\s*", pos=index + 1).end() + if span_begin in script_entity_dict.keys(): + span_end = script_entity_dict[span_begin] else: - match_obj = self.match_at(script_pattern, span_begin) + match_obj = self.match(r".", pos=span_begin) if match_obj is None: - script_name = { - "_": "subscript", - "^": "superscript" - }[self.get_string(char_span)] - raise ValueError( - f"Unclear {script_name} detected while parsing " - f"(position {char_span[0]}). " - "Please use braces to clarify" - ) + continue span_end = match_obj.end() result.append((span_begin, span_end)) return result @@ -157,46 +168,29 @@ class MTex(LabelledString): def get_script_spans(self) -> list[Span]: return [ ( - self.lslide(char_span[0], self.space_spans), + self.match(r"[\s\S]*?(\s*)$", endpos=index).start(1), script_content_span[1] ) - for char_span, script_content_span in zip( - self.script_char_spans, self.script_content_spans + for index, script_content_span in zip( + self.script_char_indices, self.script_content_spans ) ] - # Parsing - def get_command_repl_items(self) -> list[tuple[Span, str]]: - color_related_command_dict = { - "color": (1, False), - "textcolor": (1, False), - "pagecolor": (1, True), - "colorbox": (1, True), - "fcolorbox": (2, True), - } result = [] - backslash_indices = self.backslash_indices - right_brace_indices = [ - right_index - for left_index, right_index in self.brace_index_pairs - ] - pattern = "".join([ - r"\\", - "(", - "|".join(color_related_command_dict.keys()), - ")", - r"(?![a-zA-Z])" - ]) - for match_obj in re.finditer(pattern, self.string): - span_begin, cmd_end = match_obj.span() - if span_begin not in backslash_indices: + brace_spans_dict = dict(self.brace_spans) + brace_begins = list(brace_spans_dict.keys()) + for cmd_span in self.command_spans: + cmd_name = self.get_substr(cmd_span) + if cmd_name not in TEX_COLOR_COMMANDS_DICT.keys(): continue - cmd_name = match_obj.group(1) - n_braces, substitute_cmd = color_related_command_dict[cmd_name] - span_end = self.take_nearest_value( - right_brace_indices, cmd_end, n_braces - ) + 1 + n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name] + span_begin, span_end = cmd_span + for _ in n_braces: + span_end = brace_spans_dict[min(filter( + lambda index: index >= span_end, + brace_begins + ))] if substitute_cmd: repl_str = "\\" + cmd_name + n_braces * "{black}" else: @@ -204,51 +198,60 @@ class MTex(LabelledString): result.append(((span_begin, span_end), repl_str)) return result - def get_extra_entity_spans(self) -> list[Span]: - return [ - self.match_at(r"\\([a-zA-Z]+|.?)", index).span() - for index in self.backslash_indices - ] + # Parsing - def get_extra_ignored_spans(self) -> list[int]: - return self.script_char_spans.copy() + def get_skippable_indices(self) -> list[int]: + return list(it.chain( + self.find_indices(r"\s"), + self.script_char_indices + )) - def get_internal_specified_spans(self) -> list[Span]: - # Match paired double braces (`{{...}}`). + def get_entity_spans(self) -> list[Span]: + return self.command_spans.copy() + + def get_bracket_spans(self) -> list[Span]: + return self.brace_spans.copy() + + def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]: result = [] - reversed_brace_indices_dict = dict([ - pair[::-1] for pair in self.brace_index_pairs - ]) + + # Match paired double braces (`{{...}}`). + reversed_brace_spans_dict = dict(sorted([ + pair[::-1] for pair in self.brace_spans + ])) skip = False - for prev_right_index, right_index in self.get_neighbouring_pairs( - list(reversed_brace_indices_dict.keys()) + for prev_brace_end, brace_end in self.get_neighbouring_pairs( + list(reversed_brace_spans_dict.keys()) ): if skip: skip = False continue - if right_index != prev_right_index + 1: + if brace_end != prev_brace_end + 1: continue - left_index = reversed_brace_indices_dict[right_index] - prev_left_index = reversed_brace_indices_dict[prev_right_index] - if left_index != prev_left_index - 1: + brace_begin = reversed_brace_spans_dict[brace_end] + prev_brace_begin = reversed_brace_spans_dict[prev_brace_end] + if brace_begin != prev_brace_begin - 1: continue - result.append((left_index, right_index + 1)) + result.append((brace_begin, brace_end)) skip = True - return result - def get_external_specified_spans(self) -> list[Span]: - return list(it.chain(*[ + result.extend(it.chain(*[ self.find_spans_by_selector(selector) for selector in self.tex_to_color_map.keys() ])) + return [(span, {}) for span in result] def get_label_span_list(self) -> list[Span]: result = self.script_content_spans.copy() + reversed_script_spans_dict = dict([ + script_span[::-1] for script_span in self.script_spans + ]) for span_begin, span_end in self.specified_spans: - shrinked_end = self.lslide(span_end, self.script_spans) - if span_begin >= shrinked_end: + while span_end in reversed_script_spans_dict.keys(): + span_end = reversed_script_spans_dict[span_end] + if span_begin >= span_end: continue - shrinked_span = (span_begin, shrinked_end) + shrinked_span = (span_begin, span_end) if shrinked_span in result: continue result.append(shrinked_span) @@ -256,12 +259,15 @@ class MTex(LabelledString): def get_content(self, is_labelled: bool) -> str: if is_labelled: - extended_label_span_list = [ - span - if span in self.script_content_spans - else (span[0], self.rslide(span[1], self.script_spans)) - for span in self.label_span_list - ] + extended_label_span_list = [] + script_spans_dict = dict(self.script_spans) + for span in self.label_span_list: + if span not in self.script_content_spans: + span_begin, span_end = span + while span_end in script_spans_dict.keys(): + span_end = script_spans_dict[span_end] + span = (span_begin, span_end) + extended_label_span_list.append(span) inserted_string_pairs = [ (span, ( "{{" + self.get_color_command_str(label + 1), @@ -270,8 +276,7 @@ class MTex(LabelledString): for label, span in enumerate(extended_label_span_list) ] span_repl_dict = self.generate_span_repl_dict( - inserted_string_pairs, - self.command_repl_items + inserted_string_pairs, self.command_repl_items ) else: span_repl_dict = {} @@ -296,15 +301,26 @@ class MTex(LabelledString): # Post-parsing def get_cleaned_substr(self, span: Span) -> str: - substr = super().get_cleaned_substr(span) - if not self.brace_index_pairs: - return substr + if not self.brace_spans: + brace_begins, brace_ends = [], [] + else: + brace_begins, brace_ends = zip(*self.brace_spans) + left_brace_indices = list(brace_begins) + right_brace_indices = [index - 1 for index in brace_ends] + skippable_indices = list(it.chain( + self.skippable_indices, + left_brace_indices, + right_brace_indices + )) + shrinked_span = self.shrink_span(span, skippable_indices) + + if shrinked_span[0] >= shrinked_span[1]: + return "" # Balance braces. - left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs) unclosed_left_braces = 0 unclosed_right_braces = 0 - for index in range(*span): + for index in range(*shrinked_span): if index in left_brace_indices: unclosed_left_braces += 1 elif index in right_brace_indices: @@ -314,7 +330,7 @@ class MTex(LabelledString): unclosed_left_braces -= 1 return "".join([ unclosed_right_braces * "{", - substr, + self.get_substr(shrinked_span), unclosed_left_braces * "}" ]) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index f4000b3d..372e23cf 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -20,7 +20,6 @@ from manimlib.utils.config_ops import digest_config from manimlib.utils.customization import get_customization from manimlib.utils.directories import get_downloads_dir from manimlib.utils.directories import get_text_dir -from manimlib.utils.iterables import remove_list_redundancies from manimlib.utils.tex_file_writing import tex_hash from typing import TYPE_CHECKING @@ -244,11 +243,9 @@ class MarkupText(LabelledString): def pre_parse(self) -> None: super().pre_parse() - self.tag_items_from_markup = self.get_tag_items_from_markup() - self.global_dict_from_config = self.get_global_dict_from_config() - self.local_dicts_from_markup = self.get_local_dicts_from_markup() - self.local_dicts_from_config = self.get_local_dicts_from_config() - self.predefined_attr_dicts = self.get_predefined_attr_dicts() + self.tag_pairs_from_markup = self.get_tag_pairs_from_markup() + self.tag_spans = self.get_tag_spans() + self.items_from_markup = self.get_items_from_markup() # Toolkits @@ -259,42 +256,9 @@ class MarkupText(LabelledString): for key, val in attr_dict.items() ]) - @staticmethod - def merge_attr_dicts( - attr_dict_items: list[tuple[Span, dict[str, str]]] - ) -> list[tuple[Span, dict[str, str]]]: - index_seq = [0] - attr_dict_list = [{}] - for span, attr_dict in attr_dict_items: - if span[0] >= span[1]: - continue - region_indices = [ - MarkupText.find_region_index(index_seq, index) - for index in span - ] - for flag in (1, 0): - if index_seq[region_indices[flag]] == span[flag]: - continue - region_index = region_indices[flag] - index_seq.insert(region_index + 1, span[flag]) - attr_dict_list.insert( - region_index + 1, attr_dict_list[region_index].copy() - ) - region_indices[flag] += 1 - if flag == 0: - region_indices[1] += 1 - for key, val in attr_dict.items(): - if not key: - continue - for mid_dict in attr_dict_list[slice(*region_indices)]: - mid_dict[key] = val - return list(zip( - MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] - )) - # Pre-parsing - def get_tag_items_from_markup( + def get_tag_pairs_from_markup( self ) -> list[tuple[Span, Span, dict[str, str]]]: if not self.is_markup: @@ -342,52 +306,64 @@ class MarkupText(LabelledString): ) return result - def get_global_dict_from_config(self) -> dict[str, str]: - result = { - "line_height": str(( - (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 - ) * 0.6), - "font_family": self.font, - "font_size": str(self.font_size * 1024), - "font_style": self.slant, - "font_weight": self.weight - } - result.update(self.global_config) - return result - - def get_local_dicts_from_markup( - self - ) -> list[Span, dict[str, str]]: - return sorted([ - ((begin_tag_span[0], end_tag_span[1]), attr_dict) - for begin_tag_span, end_tag_span, attr_dict - in self.tag_items_from_markup - ]) - - def get_local_dicts_from_config( - self - ) -> list[Span, dict[str, str]]: + def get_tag_spans(self) -> list[Span]: return [ - (span, {key: val}) - for t2x_dict, key in ( - (self.t2c, "foreground"), - (self.t2f, "font_family"), - (self.t2s, "font_style"), - (self.t2w, "font_weight") - ) - for selector, val in t2x_dict.items() - for span in self.find_spans_by_selector(selector) - ] + [ - (span, local_config) - for selector, local_config in self.local_configs.items() - for span in self.find_spans_by_selector(selector) + tag_span + for begin_tag, end_tag, _ in self.tag_pairs_from_markup + for tag_span in (begin_tag, end_tag) ] - def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]: - attr_dict_items = [ - (self.full_span, self.global_dict_from_config), - *self.local_dicts_from_markup, - *self.local_dicts_from_config + def get_items_from_markup(self) -> list[Span]: + return [ + ((begin_tag_span[0], end_tag_span[1]), attr_dict) + for begin_tag_span, end_tag_span, attr_dict + in self.tag_pairs_from_markup + ] + + # Parsing + + def get_skippable_indices(self) -> list[int]: + return self.find_indices(r"\s") + + def get_entity_spans(self) -> list[Span]: + result = self.tag_spans.copy() + if self.is_markup: + result.extend(self.find_spans(r"&[\s\S]*?;")) + return result + + def get_bracket_spans(self) -> list[Span]: + return [span for span, _ in self.items_from_markup] + + def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]: + result = [ + (self.full_span, { + "line_height": str(( + (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 + ) * 0.6), + "font_family": self.font, + "font_size": str(self.font_size * 1024), + "font_style": self.slant, + "font_weight": self.weight, + "foreground": self.int_to_hex(self.base_color_int) + }), + (self.full_span, self.global_config), + *self.items_from_markup, + *[ + (span, {key: val}) + for t2x_dict, key in ( + (self.t2c, "foreground"), + (self.t2f, "font_family"), + (self.t2s, "font_style"), + (self.t2w, "font_weight") + ) + for selector, val in t2x_dict.items() + for span in self.find_spans_by_selector(selector) + ], + *[ + (span, local_config) + for selector, local_config in self.local_configs.items() + for span in self.find_spans_by_selector(selector) + ] ] key_conversion_dict = { key: key_alias_list[0] @@ -399,19 +375,63 @@ class MarkupText(LabelledString): key_conversion_dict[key.lower()]: val for key, val in attr_dict.items() }) + for span, attr_dict in result + ] + + def get_label_span_list(self) -> list[Span]: + interval_spans = sorted(it.chain( + self.tag_spans, + [ + (index, index) + for span in self.specified_spans + for index in span + ] + )) + text_spans = self.get_complement_spans(interval_spans, self.full_span) + if self.is_markup: + pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" + else: + pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+" + return list(it.chain(*[ + self.find_spans(pattern, pos=span_begin, endpos=span_end) + for span_begin, span_end in text_spans + ])) + + def get_content(self, is_labelled: bool) -> str: + if is_labelled: + attr_dict_items = list(it.chain( + [ + (span, { + key: BLACK if key in MARKUP_COLOR_KEYS else val + for key, val in attr_dict.items() + }) + for span, attr_dict in self.specified_items + ], + [ + (span, {"foreground": self.int_to_hex(label + 1)}) + for label, span in enumerate(self.label_span_list) + ] + )) + else: + attr_dict_items = list(it.chain( + self.specified_items, + [ + (span, {}) + for span in self.label_span_list + ] + )) + inserted_string_pairs = [ + (span, ( + f"", + "" + )) for span, attr_dict in attr_dict_items ] - - # Parsing - - def get_command_repl_items(self) -> list[tuple[Span, str]]: - result = [ - (tag_span, "") - for begin_tag, end_tag, _ in self.tag_items_from_markup - for tag_span in (begin_tag, end_tag) + repl_items = [ + (tag_span, "") for tag_span in self.tag_spans ] if not self.is_markup: - result += [ + repl_items.extend([ (span, escaped) for char, escaped in ( ("&", "&"), @@ -419,83 +439,18 @@ class MarkupText(LabelledString): ("<", "<") ) for span in self.find_spans(re.escape(char)) - ] - return result - - def get_extra_entity_spans(self) -> list[Span]: - if not self.is_markup: - return [] - return self.find_spans(r"&.*?;") - - def get_extra_ignored_spans(self) -> list[int]: - return [] - - def get_internal_specified_spans(self) -> list[Span]: - return [] - - def get_external_specified_spans(self) -> list[Span]: - return [span for span, _ in self.local_dicts_from_config] - - def get_label_span_list(self) -> list[Span]: - breakup_indices = remove_list_redundancies(list(it.chain(*it.chain( - self.find_spans(r"\b"), - self.space_spans, - self.specified_spans - )))) - breakup_indices = sorted(filter( - self.is_splittable_index, breakup_indices - )) - return list(filter( - lambda span: self.get_substr(span).strip(), - self.get_neighbouring_pairs(breakup_indices) - )) - - def get_content(self, is_labelled: bool) -> str: - filtered_attr_dicts = list(filter( - lambda item: all([ - self.is_splittable_index(index) - for index in item[0] - ]), - self.predefined_attr_dicts - )) - if is_labelled: - attr_dict_items = [ - (self.full_span, {"foreground": BLACK}), - *[ - (span, { - key: BLACK if key in MARKUP_COLOR_KEYS else val - for key, val in attr_dict.items() - }) - for span, attr_dict in filtered_attr_dicts - ], - *[ - (span, {"foreground": self.int_to_hex(label + 1)}) - for label, span in enumerate(self.label_span_list) - ] - ] - else: - attr_dict_items = [ - (self.full_span, { - "foreground": self.int_to_hex(self.base_color_int) - }), - *filtered_attr_dicts, - *[ - (span, {}) - for span in self.label_span_list - ] - ] - inserted_string_pairs = [ - (span, ( - f"", - "" - )) - for span, attr_dict in self.merge_attr_dicts(attr_dict_items) - ] + ]) span_repl_dict = self.generate_span_repl_dict( - inserted_string_pairs, self.command_repl_items + inserted_string_pairs, repl_items ) return self.get_replaced_substr(self.full_span, span_repl_dict) + # Post-parsing + + def get_cleaned_substr(self, span: Span) -> str: + repl_dict = dict.fromkeys(self.tag_spans, "") + return self.get_replaced_substr(span, repl_dict).strip() + # Method alias def get_parts_by_text(self, selector: Selector, **kwargs) -> VGroup: From cbb7e69f68d22b848413da95bb3021f31c8e1849 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Mon, 18 Apr 2022 18:47:57 +0800 Subject: [PATCH 29/64] Refactor LabelledString and relevant classes --- manimlib/mobject/number_line.py | 2 + manimlib/mobject/svg/labelled_string.py | 25 ++-- manimlib/mobject/svg/mtex_mobject.py | 14 +- manimlib/mobject/svg/svg_mobject.py | 4 +- manimlib/mobject/svg/text_mobject.py | 180 ++++++++++-------------- manimlib/utils/iterables.py | 10 +- 6 files changed, 100 insertions(+), 135 deletions(-) diff --git a/manimlib/mobject/number_line.py b/manimlib/mobject/number_line.py index 2553ac3c..e16382e6 100644 --- a/manimlib/mobject/number_line.py +++ b/manimlib/mobject/number_line.py @@ -1,5 +1,7 @@ from __future__ import annotations +import numpy as np + from manimlib.constants import DOWN, LEFT, RIGHT, UP from manimlib.constants import GREY_B from manimlib.constants import MED_SMALL_BUFF diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 7b3f457f..55b8fca6 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -53,11 +53,16 @@ class LabelledString(SVGMobject, ABC): digest_config(self, kwargs) if self.base_color is None: self.base_color = WHITE + self.base_color_int = self.color_to_int(self.base_color) - self.pre_parse() + self.string_len = len(self.string) + self.full_span = (0, self.string_len) self.parse() super().__init__() - self.post_parse() + self.labelled_submobject_items = [ + (submob.label, submob) + for submob in self.submobjects + ] def get_file_path(self) -> str: return self.get_file_path_(is_labelled=False) @@ -85,7 +90,6 @@ class LabelledString(SVGMobject, ABC): submob_color_ints = [0] * len(self.submobjects) if len(self.submobjects) != len(submob_color_ints): - print(len(self.submobjects), len(submob_color_ints)) raise ValueError( "Cannot align submobjects of the labelled svg " "to the original svg" @@ -104,11 +108,6 @@ class LabelledString(SVGMobject, ABC): for submob, color_int in zip(self.submobjects, submob_color_ints): submob.label = color_int - 1 - def pre_parse(self) -> None: - self.string_len = len(self.string) - self.full_span = (0, self.string_len) - self.base_color_int = self.color_to_int(self.base_color) - def parse(self) -> None: self.skippable_indices = self.get_skippable_indices() self.entity_spans = self.get_entity_spans() @@ -121,12 +120,6 @@ class LabelledString(SVGMobject, ABC): if len(self.label_span_list) >= 16777216: raise ValueError("Cannot handle that many substrings") - def post_parse(self) -> None: - self.labelled_submobject_items = [ - (submob.label, submob) - for submob in self.submobjects - ] - def copy(self): return self.deepcopy() @@ -362,7 +355,7 @@ class LabelledString(SVGMobject, ABC): def get_content(self, is_labelled: bool) -> str: return "" - # Post-parsing + # Selector @abstractmethod def get_cleaned_substr(self, span: Span) -> str: @@ -414,8 +407,6 @@ class LabelledString(SVGMobject, ABC): for span in self.specified_spans ] - # Selector - def find_span_components( self, custom_span: Span, substring: bool = True ) -> list[Span]: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 1c3a5f20..8b99b924 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -84,8 +84,7 @@ class MTex(LabelledString): file_path = tex_to_svg_file(full_tex) return file_path - def pre_parse(self) -> None: - super().pre_parse() + def parse(self) -> None: self.backslash_indices = self.get_backslash_indices() self.command_spans = self.get_command_spans() self.brace_spans = self.get_brace_spans() @@ -93,6 +92,7 @@ class MTex(LabelledString): self.script_content_spans = self.get_script_content_spans() self.script_spans = self.get_script_spans() self.command_repl_items = self.get_command_repl_items() + super().parse() # Toolkits @@ -102,7 +102,7 @@ class MTex(LabelledString): r, g = divmod(rg, 256) return f"\\color[RGB]{{{r}, {g}, {b}}}" - # Pre-parsing + # Parsing def get_backslash_indices(self) -> list[int]: # The latter of `\\` doesn't count. @@ -186,20 +186,18 @@ class MTex(LabelledString): continue n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name] span_begin, span_end = cmd_span - for _ in n_braces: + for _ in range(n_braces): span_end = brace_spans_dict[min(filter( lambda index: index >= span_end, brace_begins ))] if substitute_cmd: - repl_str = "\\" + cmd_name + n_braces * "{black}" + repl_str = cmd_name + n_braces * "{black}" else: repl_str = "" result.append(((span_begin, span_end), repl_str)) return result - # Parsing - def get_skippable_indices(self) -> list[int]: return list(it.chain( self.find_indices(r"\s"), @@ -298,7 +296,7 @@ class MTex(LabelledString): ]) return result - # Post-parsing + # Selector def get_cleaned_substr(self, span: Span) -> str: if not self.brace_spans: diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index 09a11ff8..88b99d12 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -198,9 +198,9 @@ class SVGMobject(VMobject): ) -> VMobject: mob.set_style( stroke_width=shape.stroke_width, - stroke_color=shape.stroke.hex, + stroke_color=shape.stroke.hexrgb, stroke_opacity=shape.stroke.opacity, - fill_color=shape.fill.hex, + fill_color=shape.fill.hexrgb, fill_opacity=shape.fill.opacity ) return mob diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 372e23cf..13f5b9c0 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -6,13 +6,12 @@ import os from pathlib import Path import re -from manimpango import MarkupUtils +import manimpango import pygments import pygments.formatters import pygments.lexers -from manimlib.constants import BLACK -from manimlib.constants import DEFAULT_PIXEL_HEIGHT, DEFAULT_PIXEL_WIDTH +from manimlib.constants import DEFAULT_PIXEL_WIDTH, FRAME_WIDTH from manimlib.constants import NORMAL from manimlib.logger import log from manimlib.mobject.svg.labelled_string import LabelledString @@ -46,48 +45,15 @@ if TYPE_CHECKING: TEXT_MOB_SCALE_FACTOR = 0.0076 DEFAULT_LINE_SPACING_SCALE = 0.6 +# Ensure the canvas is large enough to hold all glyphs. +DEFAULT_CANVAS_WIDTH = 16384 +DEFAULT_CANVAS_HEIGHT = 16384 # See https://docs.gtk.org/Pango/pango_markup.html -# A tag containing two aliases will cause warning, -# so only use the first key of each group of aliases. -MARKUP_KEY_ALIAS_LIST = ( - ("font", "font_desc"), - ("font_family", "face"), - ("font_size", "size"), - ("font_style", "style"), - ("font_weight", "weight"), - ("font_variant", "variant"), - ("font_stretch", "stretch"), - ("font_features",), - ("foreground", "fgcolor", "color"), - ("background", "bgcolor"), - ("alpha", "fgalpha"), - ("background_alpha", "bgalpha"), - ("underline",), - ("underline_color",), - ("overline",), - ("overline_color",), - ("rise",), - ("baseline_shift",), - ("font_scale",), - ("strikethrough",), - ("strikethrough_color",), - ("fallback",), - ("lang",), - ("letter_spacing",), - ("gravity",), - ("gravity_hint",), - ("show",), - ("insert_hyphens",), - ("allow_breaks",), - ("line_height",), - ("text_transform",), - ("segment",), -) MARKUP_COLOR_KEYS = ( - "foreground", - "background", + "foreground", "fgcolor", "color", + "background", "bgcolor", "underline_color", "overline_color", "strikethrough_color" @@ -125,7 +91,7 @@ class MarkupText(LabelledString): "justify": False, "indent": 0, "alignment": "LEFT", - "line_width_factor": None, + "line_width": None, "font": "", "slant": NORMAL, "weight": NORMAL, @@ -146,9 +112,7 @@ class MarkupText(LabelledString): if not self.font: self.font = get_customization()["style"]["font"] if self.is_markup: - validate_error = MarkupUtils.validate(text) - if validate_error: - raise ValueError(validate_error) + self.validate_markup_string(text) self.text = text super().__init__(text, **kwargs) @@ -178,7 +142,7 @@ class MarkupText(LabelledString): self.justify, self.indent, self.alignment, - self.line_width_factor, + self.line_width, self.font, self.slant, self.weight, @@ -205,23 +169,32 @@ class MarkupText(LabelledString): kwargs[short_name] = kwargs.pop(long_name) def get_file_path_by_content(self, content: str) -> str: + hash_content = str(( + content, + self.justify, + self.indent, + self.alignment, + self.line_width + )) svg_file = os.path.join( - get_text_dir(), tex_hash(content) + ".svg" + get_text_dir(), tex_hash(hash_content) + ".svg" ) if not os.path.exists(svg_file): self.markup_to_svg(content, svg_file) return svg_file def markup_to_svg(self, markup_str: str, file_name: str) -> str: + self.validate_markup_string(markup_str) + # `manimpango` is under construction, # so the following code is intended to suit its interface alignment = _Alignment(self.alignment) - if self.line_width_factor is None: + if self.line_width is None: pango_width = -1 else: - pango_width = self.line_width_factor * DEFAULT_PIXEL_WIDTH + pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH - return MarkupUtils.text2svg( + return manimpango.MarkupUtils.text2svg( text=markup_str, font="", # Already handled slant="NORMAL", # Already handled @@ -232,8 +205,8 @@ class MarkupText(LabelledString): file_name=file_name, START_X=0, START_Y=0, - width=DEFAULT_PIXEL_WIDTH, - height=DEFAULT_PIXEL_HEIGHT, + width=DEFAULT_CANVAS_WIDTH, + height=DEFAULT_CANVAS_HEIGHT, justify=self.justify, indent=self.indent, line_spacing=None, # Already handled @@ -241,11 +214,22 @@ class MarkupText(LabelledString): pango_width=pango_width ) - def pre_parse(self) -> None: - super().pre_parse() + @staticmethod + def validate_markup_string(markup_str: str) -> None: + validate_error = manimpango.MarkupUtils.validate(markup_str) + if not validate_error: + return + raise ValueError( + f"Invalid markup string \"{markup_str}\"\n" + f"{validate_error}" + ) + + def parse(self) -> None: + self.global_attr_dict = self.get_global_attr_dict() self.tag_pairs_from_markup = self.get_tag_pairs_from_markup() self.tag_spans = self.get_tag_spans() self.items_from_markup = self.get_items_from_markup() + super().parse() # Toolkits @@ -256,7 +240,24 @@ class MarkupText(LabelledString): for key, val in attr_dict.items() ]) - # Pre-parsing + # Parsing + + def get_global_attr_dict(self) -> dict[str, str]: + result = { + "font_size": str(self.font_size * 1024), + "foreground": self.int_to_hex(self.base_color_int), + "font_family": self.font, + "font_style": self.slant, + "font_weight": self.weight, + } + # `line_height` attribute is supported since Pango 1.50. + if tuple(map(int, manimpango.pango_version().split("."))) >= (1, 50): + result.update({ + "line_height": str(( + (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 + ) * 0.6), + }) + return result def get_tag_pairs_from_markup( self @@ -264,8 +265,8 @@ class MarkupText(LabelledString): if not self.is_markup: return [] - tag_pattern = r"""<(/?)(\w+)\s*((?:\w+\s*\=\s*(['"]).*?\4\s*)*)>""" - attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2""" + tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*(['"])[\s\S]*?\5\s*)*)>""" + attr_pattern = r"""(\w+)\s*\=\s*(['"])([\s\S]*?)\2""" begin_match_obj_stack = [] match_obj_pairs = [] for match_obj in re.finditer(tag_pattern, self.string): @@ -275,16 +276,10 @@ class MarkupText(LabelledString): match_obj_pairs.append( (begin_match_obj_stack.pop(), match_obj) ) - if begin_match_obj_stack: - raise ValueError("Unclosed tag(s) detected") result = [] for begin_match_obj, end_match_obj in match_obj_pairs: tag_name = begin_match_obj.group(2) - if tag_name != end_match_obj.group(2): - raise ValueError("Unmatched tag names") - if end_match_obj.group(3): - raise ValueError("Attributes shan't exist in ending tags") if tag_name == "span": attr_dict = { match.group(1): match.group(3) @@ -292,14 +287,8 @@ class MarkupText(LabelledString): attr_pattern, begin_match_obj.group(3) ) } - elif tag_name in MARKUP_TAG_CONVERSION_DICT.keys(): - if begin_match_obj.group(3): - raise ValueError( - f"Attributes shan't exist in tag '{tag_name}'" - ) - attr_dict = MARKUP_TAG_CONVERSION_DICT[tag_name].copy() else: - raise ValueError(f"Unknown tag: '{tag_name}'") + attr_dict = MARKUP_TAG_CONVERSION_DICT.get(tag_name, {}) result.append( (begin_match_obj.span(), end_match_obj.span(), attr_dict) @@ -320,8 +309,6 @@ class MarkupText(LabelledString): in self.tag_pairs_from_markup ] - # Parsing - def get_skippable_indices(self) -> list[int]: return self.find_indices(r"\s") @@ -335,20 +322,9 @@ class MarkupText(LabelledString): return [span for span, _ in self.items_from_markup] def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]: - result = [ - (self.full_span, { - "line_height": str(( - (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 - ) * 0.6), - "font_family": self.font, - "font_size": str(self.font_size * 1024), - "font_style": self.slant, - "font_weight": self.weight, - "foreground": self.int_to_hex(self.base_color_int) - }), - (self.full_span, self.global_config), - *self.items_from_markup, - *[ + return list(it.chain( + self.items_from_markup, + [ (span, {key: val}) for t2x_dict, key in ( (self.t2c, "foreground"), @@ -359,24 +335,12 @@ class MarkupText(LabelledString): for selector, val in t2x_dict.items() for span in self.find_spans_by_selector(selector) ], - *[ + [ (span, local_config) for selector, local_config in self.local_configs.items() for span in self.find_spans_by_selector(selector) ] - ] - key_conversion_dict = { - key: key_alias_list[0] - for key_alias_list in MARKUP_KEY_ALIAS_LIST - for key in key_alias_list - } - return [ - (span, { - key_conversion_dict[key.lower()]: val - for key, val in attr_dict.items() - }) - for span, attr_dict in result - ] + )) def get_label_span_list(self) -> list[Span]: interval_spans = sorted(it.chain( @@ -398,14 +362,20 @@ class MarkupText(LabelledString): ])) def get_content(self, is_labelled: bool) -> str: + predefined_items = [ + (self.full_span, self.global_attr_dict), + (self.full_span, self.global_config), + *self.specified_items + ] if is_labelled: attr_dict_items = list(it.chain( [ (span, { - key: BLACK if key in MARKUP_COLOR_KEYS else val + key: + "black" if key.lower() in MARKUP_COLOR_KEYS else val for key, val in attr_dict.items() }) - for span, attr_dict in self.specified_items + for span, attr_dict in predefined_items ], [ (span, {"foreground": self.int_to_hex(label + 1)}) @@ -414,7 +384,7 @@ class MarkupText(LabelledString): )) else: attr_dict_items = list(it.chain( - self.specified_items, + predefined_items, [ (span, {}) for span in self.label_span_list @@ -425,7 +395,7 @@ class MarkupText(LabelledString): f"", "" )) - for span, attr_dict in attr_dict_items + for span, attr_dict in attr_dict_items if attr_dict ] repl_items = [ (tag_span, "") for tag_span in self.tag_spans @@ -445,7 +415,7 @@ class MarkupText(LabelledString): ) return self.get_replaced_substr(self.full_span, span_repl_dict) - # Post-parsing + # Selector def get_cleaned_substr(self, span: Span) -> str: repl_dict = dict.fromkeys(self.tag_spans, "") diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py index ecbdbc1a..bdaa76c2 100644 --- a/manimlib/utils/iterables.py +++ b/manimlib/utils/iterables.py @@ -135,10 +135,14 @@ def make_even( def hash_obj(obj: object) -> int: if isinstance(obj, dict): - new_obj = {k: hash_obj(v) for k, v in obj.items()} - return hash(tuple(frozenset(sorted(new_obj.items())))) + return hash(tuple(sorted([ + (hash_obj(k), hash_obj(v)) for k, v in obj.items() + ]))) - if isinstance(obj, (set, tuple, list)): + if isinstance(obj, set): + return hash(tuple(sorted(hash_obj(e) for e in obj))) + + if isinstance(obj, (tuple, list)): return hash(tuple(hash_obj(e) for e in obj)) if isinstance(obj, Color): From 8852921b3d8c627a2f2855cd6b8e633d17a958aa Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Mon, 18 Apr 2022 19:44:32 +0800 Subject: [PATCH 30/64] Refactor double brace parsing --- manimlib/mobject/svg/mtex_mobject.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 8b99b924..d4f502f4 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -214,23 +214,19 @@ class MTex(LabelledString): result = [] # Match paired double braces (`{{...}}`). - reversed_brace_spans_dict = dict(sorted([ - pair[::-1] for pair in self.brace_spans - ])) + sorted_brace_spans = sorted( + self.brace_spans, key=lambda span: span[1] + ) skip = False - for prev_brace_end, brace_end in self.get_neighbouring_pairs( - list(reversed_brace_spans_dict.keys()) + for prev_span, span in self.get_neighbouring_pairs( + sorted_brace_spans ): if skip: skip = False continue - if brace_end != prev_brace_end + 1: + if span[0] != prev_span[0] - 1 or span[1] != prev_span[1] + 1: continue - brace_begin = reversed_brace_spans_dict[brace_end] - prev_brace_begin = reversed_brace_spans_dict[prev_brace_end] - if brace_begin != prev_brace_begin - 1: - continue - result.append((brace_begin, brace_end)) + result.append(span) skip = True result.extend(it.chain(*[ From c04615c4e97cb6c962ffd8b673d27608adbd26fb Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Thu, 21 Apr 2022 14:30:39 -0700 Subject: [PATCH 31/64] In Mobject.set_uniforms, copy uniforms that are numpy arrays --- manimlib/mobject/mobject.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index d09773ed..170f34a2 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -136,8 +136,10 @@ class Mobject(object): return self def set_uniforms(self, uniforms: dict): - for key in uniforms: - self.uniforms[key] = uniforms[key] # Copy? + for key, value in uniforms.items(): + if isinstance(value, np.ndarray): + value = value.copy() + self.uniforms[key] = value return self @property From fe3e10acd29a3dd6f8b485c0e36ead819f2d937b Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Thu, 21 Apr 2022 14:32:27 -0700 Subject: [PATCH 32/64] Updates to copying based on pickle serializing --- manimlib/mobject/mobject.py | 80 +++++++++++++++++++++---------------- 1 file changed, 46 insertions(+), 34 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 170f34a2..a14a3ad6 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -462,32 +462,21 @@ class Mobject(object): self.assemble_family() return self - # Creating new Mobjects from this one + # Copying and serialization - def replicate(self, n: int) -> Group: - return self.get_group_class()( - *(self.copy() for x in range(n)) - ) - - def get_grid(self, n_rows: int, n_cols: int, height: float | None = None, **kwargs): - """ - Returns a new mobject containing multiple copies of this one - arranged in a grid - """ - grid = self.replicate(n_rows * n_cols) - grid.arrange_in_grid(n_rows, n_cols, **kwargs) - if height is not None: - grid.set_height(height) - return grid - - # Copying + def serialize(self): + pre, self.parents = self.parents, [] + result = pickle.dumps(self) + self.parents = pre + return result def copy(self): - self.parents = [] try: - return pickle.loads(pickle.dumps(self)) + serial = self.serialize() + return pickle.loads(serial) except AttributeError: return copy.deepcopy(self) + return result def deepcopy(self): # This used to be different from copy, so is now just here for backward compatibility @@ -513,7 +502,7 @@ class Mobject(object): self.become(self.saved_state) return self - def save_to_file(self, file_path): + def save_to_file(self, file_path: str): if not file_path.endswith(".mob"): file_path += ".mob" if os.path.exists(file_path): @@ -521,7 +510,7 @@ class Mobject(object): if cont != "y": return with open(file_path, "wb") as fp: - pickle.dump(self, fp) + fp.write(self.serialize()) log.info(f"Saved mobject to {file_path}") return self @@ -534,6 +523,41 @@ class Mobject(object): mobject = pickle.load(fp) return mobject + def become(self, mobject: Mobject): + """ + Edit all data and submobjects to be idential + to another mobject + """ + self.align_family(mobject) + for sm1, sm2 in zip(self.get_family(), mobject.get_family()): + sm1.set_data(sm2.data) + sm1.set_uniforms(sm2.uniforms) + sm1.shader_folder = sm2.shader_folder + sm1.texture_paths = sm2.texture_paths + sm1.depth_test = sm2.depth_test + sm1.render_primitive = sm2.render_primitive + self.refresh_shader_wrapper_id() + self.refresh_bounding_box(recurse_down=True) + return self + + # Creating new Mobjects from this one + + def replicate(self, n: int) -> Group: + serial = self.serialize() + group_class = self.get_group_class() + return group_class(*(pickle.loads(serial) for _ in range(n))) + + def get_grid(self, n_rows: int, n_cols: int, height: float | None = None, **kwargs) -> Group: + """ + Returns a new mobject containing multiple copies of this one + arranged in a grid + """ + grid = self.replicate(n_rows * n_cols) + grid.arrange_in_grid(n_rows, n_cols, **kwargs) + if height is not None: + grid.set_height(height) + return grid + # Updating def init_updaters(self): @@ -1521,18 +1545,6 @@ class Mobject(object): """ pass # To implement in subclass - def become(self, mobject: Mobject): - """ - Edit all data and submobjects to be idential - to another mobject - """ - self.align_family(mobject) - for sm1, sm2 in zip(self.get_family(), mobject.get_family()): - sm1.set_data(sm2.data) - sm1.set_uniforms(sm2.uniforms) - self.refresh_bounding_box(recurse_down=True) - return self - # Locking data def lock_data(self, keys: Iterable[str]): From 9d5e2b32fa9215219d11a601829126cea40410d1 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Thu, 21 Apr 2022 14:32:39 -0700 Subject: [PATCH 33/64] Add VHighlight --- manimlib/mobject/types/vectorized_mobject.py | 24 ++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 6615f715..538ea655 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -10,6 +10,8 @@ import moderngl import numpy.typing as npt from manimlib.constants import * +from manimlib.constants import GREY_C +from manimlib.constants import GREY_E from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Point from manimlib.utils.bezier import bezier @@ -20,6 +22,7 @@ from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import inverse_interpolate from manimlib.utils.bezier import integer_interpolate from manimlib.utils.bezier import partial_quadratic_bezier_points +from manimlib.utils.color import color_gradient from manimlib.utils.color import rgb_to_hex from manimlib.utils.iterables import make_even from manimlib.utils.iterables import resize_array @@ -1174,3 +1177,24 @@ class DashedVMobject(VMobject): # Family is already taken care of by get_subcurve # implementation self.match_style(vmobject, recurse=False) + + +class VHighlight(VGroup): + def __init__( + self, + vmobject: VMobject, + n_layers: int = 3, + color_bounds: tuple[ManimColor] = (GREY_C, GREY_E), + max_stroke_width: float = 10.0, + ): + outline = vmobject.replicate(n_layers) + outline.set_fill(opacity=0) + added_widths = np.linspace(0, max_stroke_width, n_layers + 1)[1:] + colors = color_gradient(color_bounds, n_layers) + for part, added_width, color in zip(reversed(outline), added_widths, colors): + for sm in part.family_members_with_points(): + part.set_stroke( + width=sm.get_stroke_width() + added_width, + color=color, + ) + super().__init__(*outline) From f53f202dcd250111a90604329a4630fe62919375 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Thu, 21 Apr 2022 15:00:58 -0700 Subject: [PATCH 34/64] A few small cleanups --- manimlib/mobject/mobject.py | 8 +----- manimlib/scene/interactive_scene.py | 42 +++++++++++++++-------------- 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index a14a3ad6..10ebc35b 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -502,13 +502,7 @@ class Mobject(object): self.become(self.saved_state) return self - def save_to_file(self, file_path: str): - if not file_path.endswith(".mob"): - file_path += ".mob" - if os.path.exists(file_path): - cont = input(f"{file_path} already exists. Overwrite (y/n)? ") - if cont != "y": - return + def save_to_file(self, file_path: str, supress_overwrite_warning: bool = False): with open(file_path, "wb") as fp: fp.write(self.serialize()) log.info(f"Saved mobject to {file_path}") diff --git a/manimlib/scene/interactive_scene.py b/manimlib/scene/interactive_scene.py index d1cc6d1f..3d76dd1e 100644 --- a/manimlib/scene/interactive_scene.py +++ b/manimlib/scene/interactive_scene.py @@ -138,14 +138,6 @@ class InteractiveScene(Scene): palette.fix_in_frame() return palette - def get_stroke_highlight(self, vmobject): - outline = vmobject.copy() - for sm, osm in zip(vmobject.get_family(), outline.get_family()): - osm.set_fill(opacity=0) - osm.set_stroke(YELLOW, width=sm.get_stroke_width() + 1.5) - outline.add_updater(lambda o: o.replace(vmobject)) - return outline - def get_corner_dots(self, mobject): dots = DotCloud(**self.corner_dot_config) radius = self.corner_dot_config["radius"] @@ -160,8 +152,10 @@ class InteractiveScene(Scene): return dots def get_highlight(self, mobject): - if isinstance(mobject, VMobject) and mobject.has_points(): - return self.get_stroke_highlight(mobject) + if isinstance(mobject, VMobject) and mobject.has_points() and not self.select_top_level_mobs: + result = VHighlight(mobject) + result.add_updater(lambda m: m.replace(mobject)) + return result else: return self.get_corner_dots(mobject) @@ -182,10 +176,14 @@ class InteractiveScene(Scene): return rect def add_to_selection(self, *mobjects): - mobs = list(filter(lambda m: m not in self.unselectables, mobjects)) - self.selection.add(*mobjects) - self.selection_highlight.add(*map(self.get_highlight, mobs)) - self.saved_selection_state = [(mob, mob.copy()) for mob in self.selection] + mobs = list(filter( + lambda m: m not in self.unselectables and m not in self.selection, + mobjects + )) + if mobs: + self.selection.add(*mobs) + self.selection_highlight.add(*map(self.get_highlight, mobs)) + self.saved_selection_state = [(mob, mob.copy()) for mob in self.selection] def toggle_from_selection(self, *mobjects): for mob in mobjects: @@ -382,16 +380,16 @@ class InteractiveScene(Scene): def on_mouse_motion(self, point: np.ndarray, d_point: np.ndarray) -> None: super().on_mouse_motion(point, d_point) # Move selection - if self.window.is_key_pressed(ord("g")): + if self.window.is_key_pressed(ord(GRAB_KEY)): self.selection.move_to(point - self.mouse_to_selection) # Move selection restricted to horizontal - elif self.window.is_key_pressed(ord("h")): + elif self.window.is_key_pressed(ord(HORIZONTAL_GRAB_KEY)): self.selection.set_x((point - self.mouse_to_selection)[0]) # Move selection restricted to vertical - elif self.window.is_key_pressed(ord("v")): + elif self.window.is_key_pressed(ord(VERTICAL_GRAB_KEY)): self.selection.set_y((point - self.mouse_to_selection)[1]) # Scale selection - elif self.window.is_key_pressed(ord("t")): + elif self.window.is_key_pressed(ord(RESIZE_KEY)): # TODO, allow for scaling about the opposite corner vect = point - self.scale_about_point scalar = get_norm(vect) / get_norm(self.scale_ref_vect) @@ -411,10 +409,14 @@ class InteractiveScene(Scene): ))) mob = self.point_to_mobject(point, to_search) if mob is not None: - self.selection.set_color(mob.get_fill_color()) + self.selection.set_color(mob.get_color()) self.remove(self.color_palette) elif self.window.is_key_pressed(SHIFT_SYMBOL): - mob = self.point_to_mobject(point) + mob = self.point_to_mobject( + point, + search_set=self.get_selection_search_set(), + buff=SMALL_BUFF + ) if mob is not None: self.toggle_from_selection(mob) else: From 3a60ab144bd3c7b5f13851797fa0f7c507dc33c8 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Thu, 21 Apr 2022 15:01:30 -0700 Subject: [PATCH 35/64] Remove saved mobject directory logic from InteractiveScene --- manimlib/scene/interactive_scene.py | 30 ++++------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/manimlib/scene/interactive_scene.py b/manimlib/scene/interactive_scene.py index 3d76dd1e..8304c453 100644 --- a/manimlib/scene/interactive_scene.py +++ b/manimlib/scene/interactive_scene.py @@ -240,31 +240,6 @@ class InteractiveScene(Scene): self.remove(*self.selection) self.clear_selection() - def saved_selection_to_file(self): - directory = self.file_writer.get_saved_mobject_directory() - files = os.listdir(directory) - for mob in self.selection: - file_name = str(mob) + "_0.mob" - index = 0 - while file_name in files: - file_name = file_name.replace(str(index), str(index + 1)) - index += 1 - if platform.system() == 'Darwin': - user_name = os.popen(f""" - osascript -e ' - set chosenfile to (choose file name default name "{file_name}" default location "{directory}") - POSIX path of chosenfile' - """).read() - user_name = user_name.replace("\n", "") - else: - user_name = input( - f"Enter mobject file name (default is {file_name}): " - ) - if user_name: - file_name = user_name - files.append(file_name) - self.save_mobect(mob, file_name) - def undo(self): mobs = [] for mob, state in self.saved_selection_state: @@ -355,7 +330,10 @@ class InteractiveScene(Scene): self.undo() # Command + s -> Save selections to file elif char == "s" and modifiers == COMMAND_MODIFIER: - self.saved_selection_to_file() + to_save = self.selection + if len(to_save) == 1: + to_save = to_save[0] + self.save_mobect(to_save) # Keyboard movements elif symbol in ARROW_SYMBOLS: nudge = self.selection_nudge_size From 4caa03332367631d2fff15afd7e56b15fe8701ee Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Thu, 21 Apr 2022 15:01:54 -0700 Subject: [PATCH 36/64] Allow for sweeping selection --- manimlib/scene/interactive_scene.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/manimlib/scene/interactive_scene.py b/manimlib/scene/interactive_scene.py index 8304c453..4c479030 100644 --- a/manimlib/scene/interactive_scene.py +++ b/manimlib/scene/interactive_scene.py @@ -1,11 +1,9 @@ import numpy as np import itertools as it import pyperclip -import os -import platform from manimlib.animation.fading import FadeIn -from manimlib.constants import MANIM_COLORS, WHITE, YELLOW +from manimlib.constants import MANIM_COLORS, WHITE from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT, DL, UL, UR, DR from manimlib.constants import FRAME_WIDTH, SMALL_BUFF from manimlib.constants import SHIFT_SYMBOL, DELETE_SYMBOL, ARROW_SYMBOLS @@ -18,6 +16,7 @@ from manimlib.mobject.svg.tex_mobject import Tex from manimlib.mobject.svg.text_mobject import Text from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VGroup +from manimlib.mobject.types.vectorized_mobject import VHighlight from manimlib.mobject.types.dot_cloud import DotCloud from manimlib.scene.scene import Scene from manimlib.utils.tex_file_writing import LatexError @@ -375,6 +374,14 @@ class InteractiveScene(Scene): scalar * self.scale_ref_width, about_point=self.scale_about_point ) + # Add to selection + elif self.window.is_key_pressed(ord(SELECT_KEY)) and self.window.is_key_pressed(SHIFT_SYMBOL): + mob = self.point_to_mobject( + point, search_set=self.get_selection_search_set(), + buff=SMALL_BUFF + ) + if mob is not None: + self.add_to_selection(mob) def on_mouse_release(self, point: np.ndarray, button: int, mods: int) -> None: super().on_mouse_release(point, button, mods) From 78a707877212d0df5383cc26e82a62209bc13bee Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Thu, 21 Apr 2022 15:02:11 -0700 Subject: [PATCH 37/64] Move saved mobject directory logic to scene_file_writer.py --- manimlib/scene/scene.py | 10 ++++++---- manimlib/scene/scene_file_writer.py | 31 +++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index 01037f62..5f91bb37 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -605,10 +605,12 @@ class Scene(object): mob.become(mob_state) self.mobjects.append(mob) - def save_mobect(self, mobject: Mobject, file_name: str): - directory = self.file_writer.get_saved_mobject_directory() - path = os.path.join(directory, file_name) - mobject.save_to_file(path) + def save_mobect(self, mobject: Mobject, file_path: str | None = None) -> None: + if file_path is None: + file_path = self.file_writer.get_saved_mobject_path(mobject) + if file_path is None: + return + mobject.save_to_file(file_path) def load_mobject(self, file_name): if os.path.exists(file_name): diff --git a/manimlib/scene/scene_file_writer.py b/manimlib/scene/scene_file_writer.py index 869db27b..a46ec01d 100644 --- a/manimlib/scene/scene_file_writer.py +++ b/manimlib/scene/scene_file_writer.py @@ -11,6 +11,7 @@ from pydub import AudioSegment from tqdm import tqdm as ProgressDisplay from manimlib.constants import FFMPEG_BIN +from manimlib.mobject.mobject import Mobject from manimlib.utils.config_ops import digest_config from manimlib.utils.file_ops import guarantee_existence from manimlib.utils.file_ops import add_extension_if_not_present @@ -127,6 +128,36 @@ class SceneFileWriter(object): str(self.scene), )) + def get_saved_mobject_path(self, mobject: Mobject) -> str | None: + directory = self.get_saved_mobject_directory() + files = os.listdir(directory) + default_name = str(mobject) + "_0.mob" + index = 0 + while default_name in files: + default_name = default_name.replace(str(index), str(index + 1)) + index += 1 + if platform.system() == 'Darwin': + cmds = [ + "osascript", "-e", + f""" + set chosenfile to (choose file name default name "{default_name}" default location "{directory}") + POSIX path of chosenfile + """, + ] + process = sp.Popen(cmds, stdout=sp.PIPE) + file_path = process.stdout.read().decode("utf-8").split("\n")[0] + if not file_path: + return + else: + user_name = input(f"Enter mobject file name (default is {default_name}): ") + file_path = os.path.join(directory, user_name or default_name) + if os.path.exists(file_path) or os.path.exists(file_path + ".mob"): + if input(f"{file_path} already exists. Overwrite (y/n)? ") != "y": + return + if not file_path.endswith(".mob"): + file_path = file_path + ".mob" + return file_path + # Sound def init_audio(self) -> None: self.includes_sound: bool = False From b4b72d1b68d0993b96a6af76c4bb6816f77f0f12 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Thu, 21 Apr 2022 15:31:46 -0700 Subject: [PATCH 38/64] Allow stretched-resizing --- manimlib/constants.py | 1 + manimlib/scene/interactive_scene.py | 24 +++++++++++++++++------- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/manimlib/constants.py b/manimlib/constants.py index cc73c0ac..88859df7 100644 --- a/manimlib/constants.py +++ b/manimlib/constants.py @@ -74,6 +74,7 @@ DEFAULT_STROKE_WIDTH = 4 # For keyboard interactions CTRL_SYMBOL = 65508 SHIFT_SYMBOL = 65505 +COMMAND_SYMBOL = 65517 DELETE_SYMBOL = 65288 ARROW_SYMBOLS = list(range(65361, 65365)) diff --git a/manimlib/scene/interactive_scene.py b/manimlib/scene/interactive_scene.py index 4c479030..7ce23a89 100644 --- a/manimlib/scene/interactive_scene.py +++ b/manimlib/scene/interactive_scene.py @@ -6,7 +6,7 @@ from manimlib.animation.fading import FadeIn from manimlib.constants import MANIM_COLORS, WHITE from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT, DL, UL, UR, DR from manimlib.constants import FRAME_WIDTH, SMALL_BUFF -from manimlib.constants import SHIFT_SYMBOL, DELETE_SYMBOL, ARROW_SYMBOLS +from manimlib.constants import SHIFT_SYMBOL, CTRL_SYMBOL, DELETE_SYMBOL, ARROW_SYMBOLS from manimlib.constants import SHIFT_MODIFIER, COMMAND_MODIFIER from manimlib.mobject.mobject import Mobject from manimlib.mobject.geometry import Rectangle @@ -258,6 +258,7 @@ class InteractiveScene(Scene): self.scale_about_point = center self.scale_ref_vect = mp - self.scale_about_point self.scale_ref_width = self.selection.get_width() + self.scale_ref_height = self.selection.get_height() # Event handlers @@ -367,13 +368,22 @@ class InteractiveScene(Scene): self.selection.set_y((point - self.mouse_to_selection)[1]) # Scale selection elif self.window.is_key_pressed(ord(RESIZE_KEY)): - # TODO, allow for scaling about the opposite corner vect = point - self.scale_about_point - scalar = get_norm(vect) / get_norm(self.scale_ref_vect) - self.selection.set_width( - scalar * self.scale_ref_width, - about_point=self.scale_about_point - ) + if self.window.is_key_pressed(CTRL_SYMBOL): + for i in (0, 1): + scalar = vect[i] / self.scale_ref_vect[i] + self.selection.rescale_to_fit( + scalar * [self.scale_ref_width, self.scale_ref_height][i], + dim=i, + about_point=self.scale_about_point, + stretch=True, + ) + else: + scalar = get_norm(vect) / get_norm(self.scale_ref_vect) + self.selection.set_width( + scalar * self.scale_ref_width, + about_point=self.scale_about_point + ) # Add to selection elif self.window.is_key_pressed(ord(SELECT_KEY)) and self.window.is_key_pressed(SHIFT_SYMBOL): mob = self.point_to_mobject( From f8c8a399c91b6bd5e4612f028c2beda0acf21afc Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Fri, 22 Apr 2022 15:31:13 +0800 Subject: [PATCH 39/64] Revert some files --- manimlib/animation/creation.py | 14 +- .../animation/transform_matching_parts.py | 120 ++--- manimlib/mobject/svg/labelled_string.py | 502 ++++++++++-------- manimlib/mobject/svg/mtex_mobject.py | 342 ++++++------ manimlib/mobject/svg/text_mobject.py | 501 +++++++++-------- manimlib/mobject/types/image_mobject.py | 3 + 6 files changed, 799 insertions(+), 683 deletions(-) diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 6ad6a9bd..27460899 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -1,12 +1,12 @@ from __future__ import annotations -from abc import ABC, abstractmethod +import itertools as it +from abc import abstractmethod import numpy as np from manimlib.animation.animation import Animation from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.bezier import integer_interpolate from manimlib.utils.config_ops import digest_config @@ -17,10 +17,10 @@ from manimlib.utils.rate_functions import smooth from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.mobject import Mobject + from manimlib.mobject.mobject import Group -class ShowPartial(Animation, ABC): +class ShowPartial(Animation): """ Abstract class for ShowCreation and ShowPassingFlash """ @@ -176,7 +176,7 @@ class ShowIncreasingSubsets(Animation): "int_func": np.round, } - def __init__(self, group: Mobject, **kwargs): + def __init__(self, group: Group, **kwargs): self.all_submobs = list(group.submobjects) super().__init__(group, **kwargs) @@ -213,9 +213,7 @@ class AddTextWordByWord(ShowIncreasingSubsets): def __init__(self, string_mobject, **kwargs): assert isinstance(string_mobject, LabelledString) - grouped_mobject = VGroup(*[ - part for _, part in string_mobject.get_group_part_items() - ]) + grouped_mobject = string_mobject.submob_groups digest_config(self, kwargs) if self.run_time is None: self.run_time = self.time_per_word * len(grouped_mobject) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index e84f1d9d..dab88005 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -5,9 +5,9 @@ import itertools as it import numpy as np from manimlib.animation.composition import AnimationGroup +from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.fading import FadeInFromPoint from manimlib.animation.fading import FadeOutToPoint -from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject @@ -16,13 +16,13 @@ from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config +from manimlib.utils.iterables import remove_list_redundancies from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.svg.tex_mobject import SingleStringTex - from manimlib.mobject.svg.tex_mobject import Tex from manimlib.scene.scene import Scene + from manimlib.mobject.svg.tex_mobject import Tex, SingleStringTex class TransformMatchingParts(AnimationGroup): @@ -168,36 +168,36 @@ class TransformMatchingStrings(AnimationGroup): assert isinstance(source, LabelledString) assert isinstance(target, LabelledString) anims = [] + source_indices = list(range(len(source.labelled_submobjects))) + target_indices = list(range(len(target.labelled_submobjects))) - source_submobs = [ - submob for _, submob in source.labelled_submobject_items - ] - target_submobs = [ - submob for _, submob in target.labelled_submobject_items - ] - source_indices = list(range(len(source_submobs))) - target_indices = list(range(len(target_submobs))) - - def get_filtered_indices_lists(parts, submobs, rest_indices): - return list(filter( - lambda indices_list: all([ - index in rest_indices - for index in indices_list - ]), + def get_indices_lists(mobject, parts): + return [ [ - [submobs.index(submob) for submob in part] - for part in parts + mobject.labelled_submobjects.index(submob) + for submob in part ] - )) + for part in parts + ] - def add_anims(anim_class, parts_pairs): - for source_parts, target_parts in parts_pairs: - source_indices_lists = get_filtered_indices_lists( - source_parts, source_submobs, source_indices - ) - target_indices_lists = get_filtered_indices_lists( - target_parts, target_submobs, target_indices - ) + def add_anims_from(anim_class, func, source_args, target_args=None): + if target_args is None: + target_args = source_args.copy() + for source_arg, target_arg in zip(source_args, target_args): + source_parts = func(source, source_arg) + target_parts = func(target, target_arg) + source_indices_lists = list(filter( + lambda indices_list: all([ + index in source_indices + for index in indices_list + ]), get_indices_lists(source, source_parts) + )) + target_indices_lists = list(filter( + lambda indices_list: all([ + index in target_indices + for index in indices_list + ]), get_indices_lists(target, target_parts) + )) if not source_indices_lists or not target_indices_lists: continue anims.append(anim_class(source_parts, target_parts, **kwargs)) @@ -206,45 +206,41 @@ class TransformMatchingStrings(AnimationGroup): for index in it.chain(*target_indices_lists): target_indices.remove(index) - def get_substr_to_parts_map(part_items): - result = {} - for substr, part in part_items: - if substr not in result: - result[substr] = [] - result[substr].append(part) + def get_common_substrs(substrs_from_source, substrs_from_target): + return sorted([ + substr for substr in substrs_from_source + if substr and substr in substrs_from_target + ], key=len, reverse=True) + + def get_parts_from_keys(mobject, keys): + if isinstance(keys, str): + keys = [keys] + result = VGroup() + for key in keys: + if not isinstance(key, str): + raise TypeError(key) + result.add(*mobject.get_parts_by_string(key)) return result - def add_anims_from(anim_class, func): - source_substr_to_parts_map = get_substr_to_parts_map(func(source)) - target_substr_to_parts_map = get_substr_to_parts_map(func(target)) - add_anims( - anim_class, - [ - ( - VGroup(*source_substr_to_parts_map[substr]), - VGroup(*target_substr_to_parts_map[substr]) - ) - for substr in sorted([ - s for s in source_substr_to_parts_map.keys() - if s and s in target_substr_to_parts_map.keys() - ], key=len, reverse=True) - ] + add_anims_from( + ReplacementTransform, get_parts_from_keys, + self.key_map.keys(), self.key_map.values() + ) + add_anims_from( + FadeTransformPieces, + LabelledString.get_parts_by_string, + get_common_substrs( + source.specified_substrs, + target.specified_substrs ) - - add_anims( - ReplacementTransform, - [ - (source.select_parts(k), target.select_parts(v)) - for k, v in self.key_map.items() - ] ) add_anims_from( FadeTransformPieces, - LabelledString.get_specified_part_items - ) - add_anims_from( - FadeTransformPieces, - LabelledString.get_group_part_items + LabelledString.get_parts_by_group_substr, + get_common_substrs( + source.group_substrs, + target.group_substrs + ) ) rest_source = VGroup(*[source[index] for index in source_indices]) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 55b8fca6..f1354f0c 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -1,41 +1,30 @@ from __future__ import annotations -from abc import ABC, abstractmethod -import itertools as it import re +import colour +import itertools as it +from typing import Iterable, Union, Sequence +from abc import ABC, abstractmethod -from manimlib.constants import WHITE +from manimlib.constants import BLACK, WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup +from manimlib.utils.color import color_to_int_rgb from manimlib.utils.color import color_to_rgb from manimlib.utils.color import rgb_to_hex from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import remove_list_redundancies + from typing import TYPE_CHECKING if TYPE_CHECKING: - from colour import Color - from typing import Iterable, Union - - ManimColor = Union[str, Color] + from manimlib.mobject.types.vectorized_mobject import VMobject + ManimColor = Union[str, colour.Color, Sequence[float]] Span = tuple[int, int] - Selector = Union[ - str, - re.Pattern, - tuple[Union[int, None], Union[int, None]], - Iterable[Union[ - str, - re.Pattern, - tuple[Union[int, None], Union[int, None]] - ]] - ] -class LabelledString(SVGMobject, ABC): - """ - An abstract base class for `MTex` and `MarkupText` - """ +class _StringSVG(SVGMobject): CONFIG = { "height": None, "stroke_width": 0, @@ -44,31 +33,42 @@ class LabelledString(SVGMobject, ABC): "should_subdivide_sharp_curves": True, "should_remove_null_curves": True, }, + } + + +class LabelledString(_StringSVG, ABC): + """ + An abstract base class for `MTex` and `MarkupText` + """ + CONFIG = { "base_color": WHITE, + "use_plain_file": False, "isolate": [], } def __init__(self, string: str, **kwargs): self.string = string digest_config(self, kwargs) - if self.base_color is None: - self.base_color = WHITE - self.base_color_int = self.color_to_int(self.base_color) - self.string_len = len(self.string) - self.full_span = (0, self.string_len) + # Convert `base_color` to hex code. + self.base_color = rgb_to_hex(color_to_rgb( + self.base_color \ + or self.svg_default.get("color", None) \ + or self.svg_default.get("fill_color", None) \ + or WHITE + )) + self.svg_default["fill_color"] = BLACK + + self.pre_parse() self.parse() super().__init__() - self.labelled_submobject_items = [ - (submob.label, submob) - for submob in self.submobjects - ] + self.post_parse() def get_file_path(self) -> str: - return self.get_file_path_(is_labelled=False) + return self.get_file_path_(use_plain_file=False) - def get_file_path_(self, is_labelled: bool) -> str: - content = self.get_content(is_labelled) + def get_file_path_(self, use_plain_file: bool) -> str: + content = self.get_content(use_plain_file) return self.get_file_path_by_content(content) @abstractmethod @@ -78,113 +78,87 @@ class LabelledString(SVGMobject, ABC): def generate_mobject(self) -> None: super().generate_mobject() - num_labels = len(self.label_span_list) - if num_labels: - file_path = self.get_file_path_(is_labelled=True) - labelled_svg = SVGMobject(file_path) - submob_color_ints = [ - self.color_to_int(submob.get_fill_color()) - for submob in labelled_svg.submobjects - ] + submob_labels = [ + self.color_to_label(submob.get_fill_color()) + for submob in self.submobjects + ] + if self.use_plain_file or self.has_predefined_local_colors: + file_path = self.get_file_path_(use_plain_file=True) + plain_svg = _StringSVG( + file_path, + svg_default=self.svg_default, + path_string_config=self.path_string_config + ) + self.set_submobjects(plain_svg.submobjects) else: - submob_color_ints = [0] * len(self.submobjects) + self.set_fill(self.base_color) + for submob, label in zip(self.submobjects, submob_labels): + submob.label = label - if len(self.submobjects) != len(submob_color_ints): - raise ValueError( - "Cannot align submobjects of the labelled svg " - "to the original svg" - ) - - unrecognized_color_ints = remove_list_redundancies(sorted(filter( - lambda color_int: color_int > num_labels, - submob_color_ints - ))) - if unrecognized_color_ints: - raise ValueError( - "Unrecognized color label(s) detected: " - f"{','.join(map(self.int_to_hex, unrecognized_color_ints))}" - ) - - for submob, color_int in zip(self.submobjects, submob_color_ints): - submob.label = color_int - 1 + def pre_parse(self) -> None: + self.string_len = len(self.string) + self.full_span = (0, self.string_len) def parse(self) -> None: - self.skippable_indices = self.get_skippable_indices() + self.command_repl_items = self.get_command_repl_items() + self.command_spans = self.get_command_spans() + self.extra_entity_spans = self.get_extra_entity_spans() self.entity_spans = self.get_entity_spans() - self.bracket_spans = self.get_bracket_spans() - self.extra_isolated_items = self.get_extra_isolated_items() - self.specified_items = self.get_specified_items() + self.extra_ignored_spans = self.get_extra_ignored_spans() + self.skipped_spans = self.get_skipped_spans() + self.internal_specified_spans = self.get_internal_specified_spans() + self.external_specified_spans = self.get_external_specified_spans() self.specified_spans = self.get_specified_spans() - self.check_overlapping() self.label_span_list = self.get_label_span_list() - if len(self.label_span_list) >= 16777216: - raise ValueError("Cannot handle that many substrings") + self.check_overlapping() - def copy(self): - return self.deepcopy() + def post_parse(self) -> None: + self.labelled_submobject_items = [ + (submob.label, submob) + for submob in self.submobjects + ] + self.labelled_submobjects = self.get_labelled_submobjects() + self.specified_substrs = self.get_specified_substrs() + self.group_items = self.get_group_items() + self.group_substrs = self.get_group_substrs() + self.submob_groups = self.get_submob_groups() # Toolkits def get_substr(self, span: Span) -> str: return self.string[slice(*span)] - def match(self, pattern: str | re.Pattern, **kwargs) -> re.Pattern | None: - if isinstance(pattern, str): - pattern = re.compile(pattern) - return re.compile(pattern).match(self.string, **kwargs) + def finditer( + self, pattern: str, flags: int = 0, **kwargs + ) -> Iterable[re.Match]: + return re.compile(pattern, flags).finditer(self.string, **kwargs) - def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]: - if isinstance(pattern, str): - pattern = re.compile(pattern) + def search( + self, pattern: str, flags: int = 0, **kwargs + ) -> re.Match | None: + return re.compile(pattern, flags).search(self.string, **kwargs) + + def match( + self, pattern: str, flags: int = 0, **kwargs + ) -> re.Match | None: + return re.compile(pattern, flags).match(self.string, **kwargs) + + def find_spans(self, pattern: str, **kwargs) -> list[Span]: return [ match_obj.span() - for match_obj in pattern.finditer(self.string, **kwargs) + for match_obj in self.finditer(pattern, **kwargs) ] - def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: - return [index for index, _ in self.find_spans(pattern, **kwargs)] + def find_substr(self, substr: str, **kwargs) -> list[Span]: + if not substr: + return [] + return self.find_spans(re.escape(substr), **kwargs) - @staticmethod - def is_single_selector(selector: Selector) -> bool: - if isinstance(selector, str): - return True - if isinstance(selector, re.Pattern): - return True - if isinstance(selector, tuple): - if len(selector) == 2 and all([ - isinstance(index, int) or index is None - for index in selector - ]): - return True - return False - - def find_spans_by_selector(self, selector: Selector) -> list[Span]: - if self.is_single_selector(selector): - selector = (selector,) - result = [] - for sel in selector: - if not self.is_single_selector(sel): - raise TypeError(f"Invalid selector: '{sel}'") - if isinstance(sel, str): - spans = self.find_spans(re.escape(sel)) - elif isinstance(sel, re.Pattern): - spans = self.find_spans(sel) - else: - span = tuple([ - ( - min(index, self.string_len) - if index >= 0 - else max(index + self.string_len, 0) - ) - if index is not None else default_index - for index, default_index in zip(sel, self.full_span) - ]) - spans = [span] - result.extend(spans) - return sorted(filter( - lambda span: span[0] < span[1], - remove_list_redundancies(result) - )) + def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]: + return list(it.chain(*[ + self.find_substr(substr, **kwargs) + for substr in remove_list_redundancies(substrs) + ])) @staticmethod def get_neighbouring_pairs(iterable: list) -> list[tuple]: @@ -223,24 +197,41 @@ class LabelledString(SVGMobject, ABC): spans = LabelledString.get_neighbouring_pairs(indices) return list(zip(unique_vals, spans)) + @staticmethod + def find_region_index(seq: list[int], val: int) -> int: + # Returns an integer in `range(-1, len(seq))` satisfying + # `seq[result] <= val < seq[result + 1]`. + # `seq` should be sorted in ascending order. + if not seq or val < seq[0]: + return -1 + result = len(seq) - 1 + while val < seq[result]: + result -= 1 + return result + + @staticmethod + def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int: + sorted_seq = sorted(seq) + index = LabelledString.find_region_index(sorted_seq, val) + return sorted_seq[index + index_shift] + @staticmethod def generate_span_repl_dict( inserted_string_pairs: list[tuple[Span, tuple[str, str]]], - repl_items: list[tuple[Span, str]] + other_repl_items: list[tuple[Span, str]] ) -> dict[Span, str]: - result = dict(repl_items) + result = dict(other_repl_items) if not inserted_string_pairs: return result - indices, _, _, _, inserted_strings = zip(*sorted([ + indices, _, _, inserted_strings = zip(*sorted([ ( - item[0][flag], + span[flag], -flag, - -item[0][1 - flag], - (1, -1)[flag] * item_index, - item[1][flag] + -span[1 - flag], + str_pair[flag] ) - for item_index, item in enumerate(inserted_string_pairs) + for span, str_pair in inserted_string_pairs for flag in range(2) ])) result.update({ @@ -272,74 +263,113 @@ class LabelledString(SVGMobject, ABC): return "".join(it.chain(*zip(pieces, repl_strs))) @staticmethod - def color_to_int(color: ManimColor) -> int: - hex_code = rgb_to_hex(color_to_rgb(color)) - return int(hex_code[1:], 16) + def rslide(index: int, skipped: list[Span]) -> int: + transfer_dict = dict(sorted(skipped)) + while index in transfer_dict.keys(): + index = transfer_dict[index] + return index + + @staticmethod + def lslide(index: int, skipped: list[Span]) -> int: + transfer_dict = dict(sorted([ + skipped_span[::-1] for skipped_span in skipped + ], reverse=True)) + while index in transfer_dict.keys(): + index = transfer_dict[index] + return index + + @staticmethod + def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int: + r, g, b = rgb_tuple + rg = r * 256 + g + return rg * 256 + b + + @staticmethod + def int_to_rgb(rgb_int: int) -> tuple[int, int, int]: + rg, b = divmod(rgb_int, 256) + r, g = divmod(rg, 256) + return r, g, b @staticmethod def int_to_hex(rgb_int: int) -> str: return "#{:06x}".format(rgb_int).upper() + @staticmethod + def hex_to_int(rgb_hex: str) -> int: + return int(rgb_hex[1:], 16) + + @staticmethod + def color_to_label(color: ManimColor) -> int: + rgb_tuple = color_to_int_rgb(color) + rgb = LabelledString.rgb_to_int(rgb_tuple) + return rgb - 1 + # Parsing @abstractmethod - def get_skippable_indices(self) -> list[int]: + def get_command_repl_items(self) -> list[tuple[Span, str]]: return [] - @staticmethod - def shrink_span(span: Span, skippable_indices: list[int]) -> Span: - span_begin, span_end = span - while span_begin in skippable_indices: - span_begin += 1 - while span_end - 1 in skippable_indices: - span_end -= 1 - return (span_begin, span_end) + def get_command_spans(self) -> list[Span]: + return [cmd_span for cmd_span, _ in self.command_repl_items] @abstractmethod + def get_extra_entity_spans(self) -> list[Span]: + return [] + def get_entity_spans(self) -> list[Span]: - return [] - - @abstractmethod - def get_bracket_spans(self) -> list[Span]: - return [] - - @abstractmethod - def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]: - return [] - - def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]: - span_items = list(it.chain( - self.extra_isolated_items, - [ - (span, {}) - for span in self.find_spans_by_selector(self.isolate) - ] + return list(it.chain( + self.command_spans, + self.extra_entity_spans )) - result = [] - for span, attr_dict in span_items: - shrinked_span = self.shrink_span(span, self.skippable_indices) - if shrinked_span[0] >= shrinked_span[1]: - continue - if any([ - entity_span[0] < index < entity_span[1] - for index in shrinked_span - for entity_span in self.entity_spans - ]): - continue - result.append((shrinked_span, attr_dict)) - return result + + @abstractmethod + def get_extra_ignored_spans(self) -> list[int]: + return [] + + def get_skipped_spans(self) -> list[Span]: + return list(it.chain( + self.find_spans(r"\s"), + self.command_spans, + self.extra_ignored_spans + )) + + def shrink_span(self, span: Span) -> Span: + return ( + self.rslide(span[0], self.skipped_spans), + self.lslide(span[1], self.skipped_spans) + ) + + @abstractmethod + def get_internal_specified_spans(self) -> list[Span]: + return [] + + @abstractmethod + def get_external_specified_spans(self) -> list[Span]: + return [] def get_specified_spans(self) -> list[Span]: - return remove_list_redundancies([ - span for span, _ in self.specified_items - ]) + spans = list(it.chain( + self.internal_specified_spans, + self.external_specified_spans, + self.find_substrs(self.isolate) + )) + shrinked_spans = list(filter( + lambda span: span[0] < span[1] and not any([ + entity_span[0] < index < entity_span[1] + for index in span + for entity_span in self.entity_spans + ]), + [self.shrink_span(span) for span in spans] + )) + return remove_list_redundancies(shrinked_spans) + + @abstractmethod + def get_label_span_list(self) -> list[Span]: + return [] def check_overlapping(self) -> None: - spans = remove_list_redundancies(list(it.chain( - self.specified_spans, - self.bracket_spans - ))) - for span_0, span_1 in it.product(spans, repeat=2): + for span_0, span_1 in it.product(self.label_span_list, repeat=2): if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: continue raise ValueError( @@ -348,20 +378,29 @@ class LabelledString(SVGMobject, ABC): ) @abstractmethod - def get_label_span_list(self) -> list[Span]: - return [] - - @abstractmethod - def get_content(self, is_labelled: bool) -> str: + def get_content(self, use_plain_file: bool) -> str: return "" - # Selector - @abstractmethod + def has_predefined_local_colors(self) -> bool: + return False + + # Post-parsing + + def get_labelled_submobjects(self) -> list[VMobject]: + return [submob for _, submob in self.labelled_submobject_items] + def get_cleaned_substr(self, span: Span) -> str: - return "" + span_repl_dict = dict.fromkeys(self.command_spans, "") + return self.get_replaced_substr(span, span_repl_dict) - def get_group_part_items(self) -> list[tuple[str, VGroup]]: + def get_specified_substrs(self) -> list[str]: + return remove_list_redundancies([ + self.get_cleaned_substr(span) + for span in self.specified_spans + ]) + + def get_group_items(self) -> list[tuple[str, VGroup]]: if not self.labelled_submobject_items: return [] @@ -386,31 +425,41 @@ class LabelledString(SVGMobject, ABC): ordered_spans ) ] - group_substrs = [ - self.get_cleaned_substr(span) if span[0] < span[1] else "" + shrinked_spans = [ + self.shrink_span(span) for span in self.get_complement_spans( interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) ) ] + group_substrs = [ + self.get_cleaned_substr(span) if span[0] < span[1] else "" + for span in shrinked_spans + ] submob_groups = VGroup(*[ VGroup(*labelled_submobjects[slice(*submob_span)]) for submob_span in labelled_submob_spans ]) return list(zip(group_substrs, submob_groups)) - def get_specified_part_items(self) -> list[tuple[str, VGroup]]: - return [ - ( - self.get_substr(span), - self.select_part_by_span(span, substring=False) - ) - for span in self.specified_spans - ] + def get_group_substrs(self) -> list[str]: + return [group_substr for group_substr, _ in self.group_items] + + def get_submob_groups(self) -> list[VGroup]: + return [submob_group for _, submob_group in self.group_items] + + def get_parts_by_group_substr(self, substr: str) -> VGroup: + return VGroup(*[ + group + for group_substr, group in self.group_items + if group_substr == substr + ]) + + # Selector def find_span_components( self, custom_span: Span, substring: bool = True ) -> list[Span]: - shrinked_span = self.shrink_span(custom_span, self.skippable_indices) + shrinked_span = self.shrink_span(custom_span) if shrinked_span[0] >= shrinked_span[1]: return [] @@ -419,12 +468,12 @@ class LabelledString(SVGMobject, ABC): self.full_span, *self.label_span_list ))) - span_begin = max(filter( - lambda index: index <= shrinked_span[0], indices - )) - span_end = min(filter( - lambda index: index >= shrinked_span[1], indices - )) + span_begin = self.take_nearest_value( + indices, shrinked_span[0], 0 + ) + span_end = self.take_nearest_value( + indices, shrinked_span[1] - 1, 1 + ) else: span_begin, span_end = shrinked_span @@ -445,7 +494,7 @@ class LabelledString(SVGMobject, ABC): span_begin = next_begin return result - def select_part_by_span(self, custom_span: Span, **kwargs) -> VGroup: + def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: labels = [ label for label, span in enumerate(self.label_span_list) if any([ @@ -460,31 +509,34 @@ class LabelledString(SVGMobject, ABC): if label in labels ]) - def select_parts(self, selector: Selector, **kwargs) -> VGroup: - return VGroup(*filter( - lambda part: part.submobjects, - [ - self.select_part_by_span(span, **kwargs) - for span in self.find_spans_by_selector(selector) - ] - )) - - def select_part( - self, selector: Selector, index: int = 0, **kwargs + def get_parts_by_string( + self, substr: str, + case_sensitive: bool = True, regex: bool = False, **kwargs ) -> VGroup: - return self.select_parts(selector, **kwargs)[index] + flags = 0 + if not case_sensitive: + flags |= re.I + pattern = substr if regex else re.escape(substr) + return VGroup(*[ + self.get_part_by_custom_span(span, **kwargs) + for span in self.find_spans(pattern, flags=flags) + if span[0] < span[1] + ]) - def set_parts_color( - self, selector: Selector, color: ManimColor, **kwargs - ): - self.select_parts(selector, **kwargs).set_color(color) + def get_part_by_string( + self, substr: str, index: int = 0, **kwargs + ) -> VMobject: + return self.get_parts_by_string(substr, **kwargs)[index] + + def set_color_by_string(self, substr: str, color: ManimColor, **kwargs): + self.get_parts_by_string(substr, **kwargs).set_color(color) return self - def set_parts_color_by_dict( - self, color_map: dict[Selector, ManimColor], **kwargs + def set_color_by_string_to_color_map( + self, string_to_color_map: dict[str, ManimColor], **kwargs ): - for selector, color in color_map.items(): - self.set_parts_color(selector, color, **kwargs) + for substr, color in string_to_color_map.items(): + self.set_color_by_string(substr, color, **kwargs) return self def get_string(self) -> str: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index d4f502f4..fb7922e1 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,47 +1,27 @@ from __future__ import annotations import itertools as it -import re +import colour +from typing import Union, Sequence from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.utils.tex_file_writing import display_during_execution -from manimlib.utils.tex_file_writing import get_tex_config from manimlib.utils.tex_file_writing import tex_to_svg_file +from manimlib.utils.tex_file_writing import get_tex_config +from manimlib.utils.tex_file_writing import display_during_execution + from typing import TYPE_CHECKING if TYPE_CHECKING: - from colour import Color - from typing import Iterable, Union - + from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VGroup - - ManimColor = Union[str, Color] + ManimColor = Union[str, colour.Color, Sequence[float]] Span = tuple[int, int] - Selector = Union[ - str, - re.Pattern, - tuple[Union[int, None], Union[int, None]], - Iterable[Union[ - str, - re.Pattern, - tuple[Union[int, None], Union[int, None]] - ]] - ] SCALE_FACTOR_PER_FONT_POINT = 0.001 -TEX_COLOR_COMMANDS_DICT = { - "\\color": (1, False), - "\\textcolor": (1, False), - "\\pagecolor": (1, True), - "\\colorbox": (1, True), - "\\fcolorbox": (2, True), -} - - class MTex(LabelledString): CONFIG = { "font_size": 48, @@ -52,7 +32,7 @@ class MTex(LabelledString): def __init__(self, tex_string: str, **kwargs): # Prevent from passing an empty string. - if not tex_string.strip(): + if not tex_string: tex_string = "\\\\" self.tex_string = tex_string super().__init__(tex_string, **kwargs) @@ -67,6 +47,7 @@ class MTex(LabelledString): self.svg_default, self.path_string_config, self.base_color, + self.use_plain_file, self.isolate, self.tex_string, self.alignment, @@ -80,87 +61,85 @@ class MTex(LabelledString): tex_config["text_to_replace"], content ) - with display_during_execution(f"Writing \"{self.string}\""): + with display_during_execution(f"Writing \"{self.tex_string}\""): file_path = tex_to_svg_file(full_tex) return file_path - def parse(self) -> None: + def pre_parse(self) -> None: + super().pre_parse() self.backslash_indices = self.get_backslash_indices() - self.command_spans = self.get_command_spans() - self.brace_spans = self.get_brace_spans() - self.script_char_indices = self.get_script_char_indices() + self.brace_index_pairs = self.get_brace_index_pairs() + self.script_char_spans = self.get_script_char_spans() self.script_content_spans = self.get_script_content_spans() self.script_spans = self.get_script_spans() - self.command_repl_items = self.get_command_repl_items() - super().parse() # Toolkits @staticmethod def get_color_command_str(rgb_int: int) -> str: - rg, b = divmod(rgb_int, 256) - r, g = divmod(rg, 256) - return f"\\color[RGB]{{{r}, {g}, {b}}}" + rgb_tuple = MTex.int_to_rgb(rgb_int) + return "".join([ + "\\color[RGB]", + "{", + ",".join(map(str, rgb_tuple)), + "}" + ]) - # Parsing + # Pre-parsing def get_backslash_indices(self) -> list[int]: # The latter of `\\` doesn't count. - return self.find_indices(r"\\.") + return list(it.chain(*[ + range(span[0], span[1], 2) + for span in self.find_spans(r"\\+") + ])) - def get_command_spans(self) -> list[Span]: - return [ - self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span() - for index in self.backslash_indices - ] - - def get_unescaped_char_indices(self, char: str) -> list[int]: - return list(filter( - lambda index: index - 1 not in self.backslash_indices, - self.find_indices(re.escape(char)) + def get_unescaped_char_spans(self, chars: str): + return sorted(filter( + lambda span: span[0] - 1 not in self.backslash_indices, + self.find_substrs(list(chars)) )) - def get_brace_spans(self) -> list[Span]: - span_begins = [] - span_ends = [] - span_begins_stack = [] - char_items = sorted([ - (index, char) - for char in "{}" - for index in self.get_unescaped_char_indices(char) - ]) - for index, char in char_items: - if char == "{": - span_begins_stack.append(index) + def get_brace_index_pairs(self) -> list[Span]: + left_brace_indices = [] + right_brace_indices = [] + left_brace_indices_stack = [] + for span in self.get_unescaped_char_spans("{}"): + index = span[0] + if self.get_substr(span) == "{": + left_brace_indices_stack.append(index) else: - if not span_begins_stack: + if not left_brace_indices_stack: raise ValueError("Missing '{' inserted") - span_begins.append(span_begins_stack.pop()) - span_ends.append(index + 1) - if span_begins_stack: + left_brace_index = left_brace_indices_stack.pop() + left_brace_indices.append(left_brace_index) + right_brace_indices.append(index) + if left_brace_indices_stack: raise ValueError("Missing '}' inserted") - return list(zip(span_begins, span_ends)) + return list(zip(left_brace_indices, right_brace_indices)) - def get_script_char_indices(self) -> list[int]: - return list(it.chain(*[ - self.get_unescaped_char_indices(char) - for char in "_^" - ])) + def get_script_char_spans(self) -> list[int]: + return self.get_unescaped_char_spans("_^") def get_script_content_spans(self) -> list[Span]: result = [] - script_entity_dict = dict(it.chain( - self.brace_spans, - self.command_spans - )) - for index in self.script_char_indices: - span_begin = self.match(r"\s*", pos=index + 1).end() - if span_begin in script_entity_dict.keys(): - span_end = script_entity_dict[span_begin] + brace_indices_dict = dict(self.brace_index_pairs) + script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+" + for script_char_span in self.script_char_spans: + span_begin = self.match(r"\s*", pos=script_char_span[1]).end() + if span_begin in brace_indices_dict.keys(): + span_end = brace_indices_dict[span_begin] + 1 else: - match_obj = self.match(r".", pos=span_begin) - if match_obj is None: - continue + match_obj = self.match(script_pattern, pos=span_begin) + if not match_obj: + script_name = { + "_": "subscript", + "^": "superscript" + }[script_char] + raise ValueError( + f"Unclear {script_name} detected while parsing. " + "Please use braces to clarify" + ) span_end = match_obj.end() result.append((span_begin, span_end)) return result @@ -168,100 +147,110 @@ class MTex(LabelledString): def get_script_spans(self) -> list[Span]: return [ ( - self.match(r"[\s\S]*?(\s*)$", endpos=index).start(1), + self.search(r"\s*$", endpos=script_char_span[0]).start(), script_content_span[1] ) - for index, script_content_span in zip( - self.script_char_indices, self.script_content_spans + for script_char_span, script_content_span in zip( + self.script_char_spans, self.script_content_spans ) ] + # Parsing + def get_command_repl_items(self) -> list[tuple[Span, str]]: + color_related_command_dict = { + "color": (1, False), + "textcolor": (1, False), + "pagecolor": (1, True), + "colorbox": (1, True), + "fcolorbox": (2, True), + } result = [] - brace_spans_dict = dict(self.brace_spans) - brace_begins = list(brace_spans_dict.keys()) - for cmd_span in self.command_spans: - cmd_name = self.get_substr(cmd_span) - if cmd_name not in TEX_COLOR_COMMANDS_DICT.keys(): + backslash_indices = self.backslash_indices + right_brace_indices = [ + right_index + for left_index, right_index in self.brace_index_pairs + ] + pattern = "".join([ + r"\\", + "(", + "|".join(color_related_command_dict.keys()), + ")", + r"(?![a-zA-Z])" + ]) + for match_obj in self.finditer(pattern): + span_begin, cmd_end = match_obj.span() + if span_begin not in backslash_indices: continue - n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name] - span_begin, span_end = cmd_span - for _ in range(n_braces): - span_end = brace_spans_dict[min(filter( - lambda index: index >= span_end, - brace_begins - ))] + cmd_name = match_obj.group(1) + n_braces, substitute_cmd = color_related_command_dict[cmd_name] + span_end = self.take_nearest_value( + right_brace_indices, cmd_end, n_braces + ) + 1 if substitute_cmd: - repl_str = cmd_name + n_braces * "{black}" + repl_str = "\\" + cmd_name + n_braces * "{black}" else: repl_str = "" result.append(((span_begin, span_end), repl_str)) return result - def get_skippable_indices(self) -> list[int]: - return list(it.chain( - self.find_indices(r"\s"), - self.script_char_indices - )) + def get_extra_entity_spans(self) -> list[Span]: + return [ + self.match(r"\\([a-zA-Z]+|.)", pos=index).span() + for index in self.backslash_indices + ] - def get_entity_spans(self) -> list[Span]: - return self.command_spans.copy() - - def get_bracket_spans(self) -> list[Span]: - return self.brace_spans.copy() - - def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]: - result = [] + def get_extra_ignored_spans(self) -> list[int]: + return self.script_char_spans.copy() + def get_internal_specified_spans(self) -> list[Span]: # Match paired double braces (`{{...}}`). - sorted_brace_spans = sorted( - self.brace_spans, key=lambda span: span[1] - ) + result = [] + reversed_brace_indices_dict = dict([ + pair[::-1] for pair in self.brace_index_pairs + ]) skip = False - for prev_span, span in self.get_neighbouring_pairs( - sorted_brace_spans + for prev_right_index, right_index in self.get_neighbouring_pairs( + list(reversed_brace_indices_dict.keys()) ): if skip: skip = False continue - if span[0] != prev_span[0] - 1 or span[1] != prev_span[1] + 1: + if right_index != prev_right_index + 1: continue - result.append(span) + left_index = reversed_brace_indices_dict[right_index] + prev_left_index = reversed_brace_indices_dict[prev_right_index] + if left_index != prev_left_index - 1: + continue + result.append((left_index, right_index + 1)) skip = True + return result - result.extend(it.chain(*[ - self.find_spans_by_selector(selector) - for selector in self.tex_to_color_map.keys() - ])) - return [(span, {}) for span in result] + def get_external_specified_spans(self) -> list[Span]: + return self.find_substrs(list(self.tex_to_color_map.keys())) def get_label_span_list(self) -> list[Span]: result = self.script_content_spans.copy() - reversed_script_spans_dict = dict([ - script_span[::-1] for script_span in self.script_spans - ]) for span_begin, span_end in self.specified_spans: - while span_end in reversed_script_spans_dict.keys(): - span_end = reversed_script_spans_dict[span_end] - if span_begin >= span_end: + shrinked_end = self.lslide(span_end, self.script_spans) + if span_begin >= shrinked_end: continue - shrinked_span = (span_begin, span_end) + shrinked_span = (span_begin, shrinked_end) if shrinked_span in result: continue result.append(shrinked_span) return result - def get_content(self, is_labelled: bool) -> str: - if is_labelled: - extended_label_span_list = [] - script_spans_dict = dict(self.script_spans) - for span in self.label_span_list: - if span not in self.script_content_spans: - span_begin, span_end = span - while span_end in script_spans_dict.keys(): - span_end = script_spans_dict[span_end] - span = (span_begin, span_end) - extended_label_span_list.append(span) + def get_content(self, use_plain_file: bool) -> str: + if use_plain_file: + span_repl_dict = {} + else: + extended_label_span_list = [ + span + if span in self.script_content_spans + else (span[0], self.rslide(span[1], self.script_spans)) + for span in self.label_span_list + ] inserted_string_pairs = [ (span, ( "{{" + self.get_color_command_str(label + 1), @@ -270,51 +259,42 @@ class MTex(LabelledString): for label, span in enumerate(extended_label_span_list) ] span_repl_dict = self.generate_span_repl_dict( - inserted_string_pairs, self.command_repl_items + inserted_string_pairs, + self.command_repl_items ) - else: - span_repl_dict = {} result = self.get_replaced_substr(self.full_span, span_repl_dict) if self.tex_environment: - if isinstance(self.tex_environment, str): - prefix = f"\\begin{{{self.tex_environment}}}" - suffix = f"\\end{{{self.tex_environment}}}" - else: - prefix, suffix = self.tex_environment - result = "\n".join([prefix, result, suffix]) + result = "\n".join([ + f"\\begin{{{self.tex_environment}}}", + result, + f"\\end{{{self.tex_environment}}}" + ]) if self.alignment: result = "\n".join([self.alignment, result]) - if not is_labelled: + if use_plain_file: result = "\n".join([ - self.get_color_command_str(self.base_color_int), + self.get_color_command_str(self.hex_to_int(self.base_color)), result ]) return result - # Selector + @property + def has_predefined_local_colors(self) -> bool: + return bool(self.command_repl_items) + + # Post-parsing def get_cleaned_substr(self, span: Span) -> str: - if not self.brace_spans: - brace_begins, brace_ends = [], [] - else: - brace_begins, brace_ends = zip(*self.brace_spans) - left_brace_indices = list(brace_begins) - right_brace_indices = [index - 1 for index in brace_ends] - skippable_indices = list(it.chain( - self.skippable_indices, - left_brace_indices, - right_brace_indices - )) - shrinked_span = self.shrink_span(span, skippable_indices) - - if shrinked_span[0] >= shrinked_span[1]: - return "" + substr = super().get_cleaned_substr(span) + if not self.brace_index_pairs: + return substr # Balance braces. + left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs) unclosed_left_braces = 0 unclosed_right_braces = 0 - for index in range(*shrinked_span): + for index in range(*span): if index in left_brace_indices: unclosed_left_braces += 1 elif index in right_brace_indices: @@ -324,27 +304,27 @@ class MTex(LabelledString): unclosed_left_braces -= 1 return "".join([ unclosed_right_braces * "{", - self.get_substr(shrinked_span), + substr, unclosed_left_braces * "}" ]) # Method alias - def get_parts_by_tex(self, selector: Selector, **kwargs) -> VGroup: - return self.select_parts(selector, **kwargs) + def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup: + return self.get_parts_by_string(tex, **kwargs) - def get_part_by_tex(self, selector: Selector, **kwargs) -> VGroup: - return self.select_part(selector, **kwargs) + def get_part_by_tex(self, tex: str, **kwargs) -> VMobject: + return self.get_part_by_string(tex, **kwargs) - def set_color_by_tex( - self, selector: Selector, color: ManimColor, **kwargs - ): - return self.set_parts_color(selector, color, **kwargs) + def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs): + return self.set_color_by_string(tex, color, **kwargs) def set_color_by_tex_to_color_map( - self, color_map: dict[Selector, ManimColor], **kwargs + self, tex_to_color_map: dict[str, ManimColor], **kwargs ): - return self.set_parts_color_by_dict(color_map, **kwargs) + return self.set_color_by_string_to_color_map( + tex_to_color_map, **kwargs + ) def get_tex(self) -> str: return self.get_string() diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 13f5b9c0..c3c3be19 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -1,64 +1,93 @@ from __future__ import annotations -from contextlib import contextmanager -import itertools as it import os -from pathlib import Path import re +import itertools as it +from pathlib import Path +from contextlib import contextmanager +import typing +from typing import Iterable, Sequence, Union -import manimpango import pygments import pygments.formatters import pygments.lexers -from manimlib.constants import DEFAULT_PIXEL_WIDTH, FRAME_WIDTH -from manimlib.constants import NORMAL +from manimpango import MarkupUtils + from manimlib.logger import log +from manimlib.constants import * from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.utils.config_ops import digest_config from manimlib.utils.customization import get_customization +from manimlib.utils.tex_file_writing import tex_hash +from manimlib.utils.config_ops import digest_config from manimlib.utils.directories import get_downloads_dir from manimlib.utils.directories import get_text_dir -from manimlib.utils.tex_file_writing import tex_hash +from manimlib.utils.iterables import remove_list_redundancies + from typing import TYPE_CHECKING if TYPE_CHECKING: - from colour import Color - from typing import Iterable, Union - + from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VGroup - - ManimColor = Union[str, Color] + ManimColor = Union[str, colour.Color, Sequence[float]] Span = tuple[int, int] - Selector = Union[ - str, - re.Pattern, - tuple[Union[int, None], Union[int, None]], - Iterable[Union[ - str, - re.Pattern, - tuple[Union[int, None], Union[int, None]] - ]] - ] TEXT_MOB_SCALE_FACTOR = 0.0076 DEFAULT_LINE_SPACING_SCALE = 0.6 -# Ensure the canvas is large enough to hold all glyphs. -DEFAULT_CANVAS_WIDTH = 16384 -DEFAULT_CANVAS_HEIGHT = 16384 # See https://docs.gtk.org/Pango/pango_markup.html -MARKUP_COLOR_KEYS = ( - "foreground", "fgcolor", "color", - "background", "bgcolor", - "underline_color", - "overline_color", - "strikethrough_color" +# A tag containing two aliases will cause warning, +# so only use the first key of each group of aliases. +SPAN_ATTR_KEY_ALIAS_LIST = ( + ("font", "font_desc"), + ("font_family", "face"), + ("font_size", "size"), + ("font_style", "style"), + ("font_weight", "weight"), + ("font_variant", "variant"), + ("font_stretch", "stretch"), + ("font_features",), + ("foreground", "fgcolor", "color"), + ("background", "bgcolor"), + ("alpha", "fgalpha"), + ("background_alpha", "bgalpha"), + ("underline",), + ("underline_color",), + ("overline",), + ("overline_color",), + ("rise",), + ("baseline_shift",), + ("font_scale",), + ("strikethrough",), + ("strikethrough_color",), + ("fallback",), + ("lang",), + ("letter_spacing",), + ("gravity",), + ("gravity_hint",), + ("show",), + ("insert_hyphens",), + ("allow_breaks",), + ("line_height",), + ("text_transform",), + ("segment",), ) -MARKUP_TAG_CONVERSION_DICT = { +COLOR_RELATED_KEYS = ( + "foreground", + "background", + "underline_color", + "overline_color", + "strikethrough_color" +) +SPAN_ATTR_KEY_CONVERSION = { + key: key_alias_list[0] + for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST + for key in key_alias_list +} +TAG_TO_ATTR_DICT = { "b": {"font_weight": "bold"}, "big": {"font_size": "larger"}, "i": {"font_style": "italic"}, @@ -91,7 +120,7 @@ class MarkupText(LabelledString): "justify": False, "indent": 0, "alignment": "LEFT", - "line_width": None, + "line_width_factor": None, "font": "", "slant": NORMAL, "weight": NORMAL, @@ -112,7 +141,9 @@ class MarkupText(LabelledString): if not self.font: self.font = get_customization()["style"]["font"] if self.is_markup: - self.validate_markup_string(text) + validate_error = MarkupUtils.validate(text) + if validate_error: + raise ValueError(validate_error) self.text = text super().__init__(text, **kwargs) @@ -134,6 +165,7 @@ class MarkupText(LabelledString): self.svg_default, self.path_string_config, self.base_color, + self.use_plain_file, self.isolate, self.text, self.is_markup, @@ -142,7 +174,7 @@ class MarkupText(LabelledString): self.justify, self.indent, self.alignment, - self.line_width, + self.line_width_factor, self.font, self.slant, self.weight, @@ -169,32 +201,23 @@ class MarkupText(LabelledString): kwargs[short_name] = kwargs.pop(long_name) def get_file_path_by_content(self, content: str) -> str: - hash_content = str(( - content, - self.justify, - self.indent, - self.alignment, - self.line_width - )) svg_file = os.path.join( - get_text_dir(), tex_hash(hash_content) + ".svg" + get_text_dir(), tex_hash(content) + ".svg" ) if not os.path.exists(svg_file): self.markup_to_svg(content, svg_file) return svg_file def markup_to_svg(self, markup_str: str, file_name: str) -> str: - self.validate_markup_string(markup_str) - # `manimpango` is under construction, # so the following code is intended to suit its interface alignment = _Alignment(self.alignment) - if self.line_width is None: + if self.line_width_factor is None: pango_width = -1 else: - pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH + pango_width = self.line_width_factor * DEFAULT_PIXEL_WIDTH - return manimpango.MarkupUtils.text2svg( + return MarkupUtils.text2svg( text=markup_str, font="", # Already handled slant="NORMAL", # Already handled @@ -205,8 +228,8 @@ class MarkupText(LabelledString): file_name=file_name, START_X=0, START_Y=0, - width=DEFAULT_CANVAS_WIDTH, - height=DEFAULT_CANVAS_HEIGHT, + width=DEFAULT_PIXEL_WIDTH, + height=DEFAULT_PIXEL_HEIGHT, justify=self.justify, indent=self.indent, line_spacing=None, # Already handled @@ -214,22 +237,13 @@ class MarkupText(LabelledString): pango_width=pango_width ) - @staticmethod - def validate_markup_string(markup_str: str) -> None: - validate_error = manimpango.MarkupUtils.validate(markup_str) - if not validate_error: - return - raise ValueError( - f"Invalid markup string \"{markup_str}\"\n" - f"{validate_error}" - ) - - def parse(self) -> None: - self.global_attr_dict = self.get_global_attr_dict() - self.tag_pairs_from_markup = self.get_tag_pairs_from_markup() - self.tag_spans = self.get_tag_spans() - self.items_from_markup = self.get_items_from_markup() - super().parse() + def pre_parse(self) -> None: + super().pre_parse() + self.tag_items_from_markup = self.get_tag_items_from_markup() + self.global_dict_from_config = self.get_global_dict_from_config() + self.local_dicts_from_markup = self.get_local_dicts_from_markup() + self.local_dicts_from_config = self.get_local_dicts_from_config() + self.predefined_attr_dicts = self.get_predefined_attr_dicts() # Toolkits @@ -240,46 +254,87 @@ class MarkupText(LabelledString): for key, val in attr_dict.items() ]) - # Parsing + @staticmethod + def merge_attr_dicts( + attr_dict_items: list[Span, str, typing.Any] + ) -> list[tuple[Span, dict[str, str]]]: + index_seq = [0] + attr_dict_list = [{}] + for span, attr_dict in attr_dict_items: + if span[0] >= span[1]: + continue + region_indices = [ + MarkupText.find_region_index(index_seq, index) + for index in span + ] + for flag in (1, 0): + if index_seq[region_indices[flag]] == span[flag]: + continue + region_index = region_indices[flag] + index_seq.insert(region_index + 1, span[flag]) + attr_dict_list.insert( + region_index + 1, attr_dict_list[region_index].copy() + ) + region_indices[flag] += 1 + if flag == 0: + region_indices[1] += 1 + for key, val in attr_dict.items(): + if not key: + continue + for mid_dict in attr_dict_list[slice(*region_indices)]: + mid_dict[key] = val + return list(zip( + MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] + )) - def get_global_attr_dict(self) -> dict[str, str]: - result = { - "font_size": str(self.font_size * 1024), - "foreground": self.int_to_hex(self.base_color_int), - "font_family": self.font, - "font_style": self.slant, - "font_weight": self.weight, - } - # `line_height` attribute is supported since Pango 1.50. - if tuple(map(int, manimpango.pango_version().split("."))) >= (1, 50): - result.update({ - "line_height": str(( - (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 - ) * 0.6), - }) - return result + def find_substr_or_span( + self, substr_or_span: str | tuple[int | None, int | None] + ) -> list[Span]: + if isinstance(substr_or_span, str): + return self.find_substr(substr_or_span) - def get_tag_pairs_from_markup( + span = tuple([ + ( + min(index, self.string_len) + if index >= 0 + else max(index + self.string_len, 0) + ) + if index is not None else default_index + for index, default_index in zip(substr_or_span, self.full_span) + ]) + if span[0] >= span[1]: + return [] + return [span] + + # Pre-parsing + + def get_tag_items_from_markup( self ) -> list[tuple[Span, Span, dict[str, str]]]: if not self.is_markup: return [] - tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*(['"])[\s\S]*?\5\s*)*)>""" - attr_pattern = r"""(\w+)\s*\=\s*(['"])([\s\S]*?)\2""" + tag_pattern = r"""<(/?)(\w+)\s*((?:\w+\s*\=\s*(['"]).*?\4\s*)*)>""" + attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2""" begin_match_obj_stack = [] match_obj_pairs = [] - for match_obj in re.finditer(tag_pattern, self.string): + for match_obj in self.finditer(tag_pattern): if not match_obj.group(1): begin_match_obj_stack.append(match_obj) else: match_obj_pairs.append( (begin_match_obj_stack.pop(), match_obj) ) + if begin_match_obj_stack: + raise ValueError("Unclosed tag(s) detected") result = [] for begin_match_obj, end_match_obj in match_obj_pairs: tag_name = begin_match_obj.group(2) + if tag_name != end_match_obj.group(2): + raise ValueError("Unmatched tag names") + if end_match_obj.group(3): + raise ValueError("Attributes shan't exist in ending tags") if tag_name == "span": attr_dict = { match.group(1): match.group(3) @@ -287,157 +342,189 @@ class MarkupText(LabelledString): attr_pattern, begin_match_obj.group(3) ) } + elif tag_name in TAG_TO_ATTR_DICT.keys(): + if begin_match_obj.group(3): + raise ValueError( + f"Attributes shan't exist in tag '{tag_name}'" + ) + attr_dict = TAG_TO_ATTR_DICT[tag_name].copy() else: - attr_dict = MARKUP_TAG_CONVERSION_DICT.get(tag_name, {}) + raise ValueError(f"Unknown tag: '{tag_name}'") result.append( (begin_match_obj.span(), end_match_obj.span(), attr_dict) ) return result - def get_tag_spans(self) -> list[Span]: - return [ - tag_span - for begin_tag, end_tag, _ in self.tag_pairs_from_markup - for tag_span in (begin_tag, end_tag) - ] - - def get_items_from_markup(self) -> list[Span]: - return [ - ((begin_tag_span[0], end_tag_span[1]), attr_dict) - for begin_tag_span, end_tag_span, attr_dict - in self.tag_pairs_from_markup - ] - - def get_skippable_indices(self) -> list[int]: - return self.find_indices(r"\s") - - def get_entity_spans(self) -> list[Span]: - result = self.tag_spans.copy() - if self.is_markup: - result.extend(self.find_spans(r"&[\s\S]*?;")) + def get_global_dict_from_config(self) -> dict[str, typing.Any]: + result = { + "line_height": ( + (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 + ) * 0.6, + "font_family": self.font, + "font_size": self.font_size * 1024, + "font_style": self.slant, + "font_weight": self.weight + } + result.update(self.global_config) return result - def get_bracket_spans(self) -> list[Span]: - return [span for span, _ in self.items_from_markup] + def get_local_dicts_from_markup( + self + ) -> list[Span, dict[str, str]]: + return sorted([ + ((begin_tag_span[0], end_tag_span[1]), attr_dict) + for begin_tag_span, end_tag_span, attr_dict + in self.tag_items_from_markup + ]) - def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]: - return list(it.chain( - self.items_from_markup, - [ - (span, {key: val}) - for t2x_dict, key in ( - (self.t2c, "foreground"), - (self.t2f, "font_family"), - (self.t2s, "font_style"), - (self.t2w, "font_weight") - ) - for selector, val in t2x_dict.items() - for span in self.find_spans_by_selector(selector) - ], - [ - (span, local_config) - for selector, local_config in self.local_configs.items() - for span in self.find_spans_by_selector(selector) - ] - )) - - def get_label_span_list(self) -> list[Span]: - interval_spans = sorted(it.chain( - self.tag_spans, - [ - (index, index) - for span in self.specified_spans - for index in span - ] - )) - text_spans = self.get_complement_spans(interval_spans, self.full_span) - if self.is_markup: - pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" - else: - pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+" - return list(it.chain(*[ - self.find_spans(pattern, pos=span_begin, endpos=span_end) - for span_begin, span_end in text_spans - ])) - - def get_content(self, is_labelled: bool) -> str: - predefined_items = [ - (self.full_span, self.global_attr_dict), - (self.full_span, self.global_config), - *self.specified_items + def get_local_dicts_from_config( + self + ) -> list[Span, dict[str, typing.Any]]: + return [ + (span, {key: val}) + for t2x_dict, key in ( + (self.t2c, "foreground"), + (self.t2f, "font_family"), + (self.t2s, "font_style"), + (self.t2w, "font_weight") + ) + for substr_or_span, val in t2x_dict.items() + for span in self.find_substr_or_span(substr_or_span) + ] + [ + (span, local_config) + for substr_or_span, local_config in self.local_configs.items() + for span in self.find_substr_or_span(substr_or_span) ] - if is_labelled: - attr_dict_items = list(it.chain( - [ - (span, { - key: - "black" if key.lower() in MARKUP_COLOR_KEYS else val - for key, val in attr_dict.items() - }) - for span, attr_dict in predefined_items - ], - [ - (span, {"foreground": self.int_to_hex(label + 1)}) - for label, span in enumerate(self.label_span_list) - ] - )) - else: - attr_dict_items = list(it.chain( - predefined_items, - [ - (span, {}) - for span in self.label_span_list - ] - )) - inserted_string_pairs = [ - (span, ( - f"", - "" - )) - for span, attr_dict in attr_dict_items if attr_dict + + def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]: + attr_dict_items = [ + (self.full_span, self.global_dict_from_config), + *self.local_dicts_from_markup, + *self.local_dicts_from_config ] - repl_items = [ - (tag_span, "") for tag_span in self.tag_spans + return [ + (span, { + SPAN_ATTR_KEY_CONVERSION[key.lower()]: str(val) + for key, val in attr_dict.items() + }) + for span, attr_dict in attr_dict_items + ] + + # Parsing + + def get_command_repl_items(self) -> list[tuple[Span, str]]: + result = [ + (tag_span, "") + for begin_tag, end_tag, _ in self.tag_items_from_markup + for tag_span in (begin_tag, end_tag) ] if not self.is_markup: - repl_items.extend([ + result += [ (span, escaped) for char, escaped in ( ("&", "&"), (">", ">"), ("<", "<") ) - for span in self.find_spans(re.escape(char)) - ]) + for span in self.find_substr(char) + ] + return result + + def get_extra_entity_spans(self) -> list[Span]: + if not self.is_markup: + return [] + return self.find_spans(r"&.*?;") + + def get_extra_ignored_spans(self) -> list[int]: + return [] + + def get_internal_specified_spans(self) -> list[Span]: + return [span for span, _ in self.local_dicts_from_markup] + + def get_external_specified_spans(self) -> list[Span]: + return [span for span, _ in self.local_dicts_from_config] + + def get_label_span_list(self) -> list[Span]: + breakup_indices = remove_list_redundancies(list(it.chain(*it.chain( + self.find_spans(r"\s+"), + self.find_spans(r"\b"), + self.specified_spans + )))) + breakup_indices = sorted(filter( + lambda index: not any([ + span[0] < index < span[1] + for span in self.entity_spans + ]), + breakup_indices + )) + return list(filter( + lambda span: self.get_substr(span).strip(), + self.get_neighbouring_pairs(breakup_indices) + )) + + def get_content(self, use_plain_file: bool) -> str: + if use_plain_file: + attr_dict_items = [ + (self.full_span, {"foreground": self.base_color}), + *self.predefined_attr_dicts, + *[ + (span, {}) + for span in self.label_span_list + ] + ] + else: + attr_dict_items = [ + (self.full_span, {"foreground": BLACK}), + *[ + (span, { + key: BLACK if key in COLOR_RELATED_KEYS else val + for key, val in attr_dict.items() + }) + for span, attr_dict in self.predefined_attr_dicts + ], + *[ + (span, {"foreground": self.int_to_hex(label + 1)}) + for label, span in enumerate(self.label_span_list) + ] + ] + inserted_string_pairs = [ + (span, ( + f"", + "" + )) + for span, attr_dict in self.merge_attr_dicts(attr_dict_items) + ] span_repl_dict = self.generate_span_repl_dict( - inserted_string_pairs, repl_items + inserted_string_pairs, self.command_repl_items ) return self.get_replaced_substr(self.full_span, span_repl_dict) - # Selector - - def get_cleaned_substr(self, span: Span) -> str: - repl_dict = dict.fromkeys(self.tag_spans, "") - return self.get_replaced_substr(span, repl_dict).strip() + @property + def has_predefined_local_colors(self) -> bool: + return any([ + key in COLOR_RELATED_KEYS + for _, attr_dict in self.predefined_attr_dicts + for key in attr_dict.keys() + ]) # Method alias - def get_parts_by_text(self, selector: Selector, **kwargs) -> VGroup: - return self.select_parts(selector, **kwargs) + def get_parts_by_text(self, text: str, **kwargs) -> VGroup: + return self.get_parts_by_string(text, **kwargs) - def get_part_by_text(self, selector: Selector, **kwargs) -> VGroup: - return self.select_part(selector, **kwargs) + def get_part_by_text(self, text: str, **kwargs) -> VMobject: + return self.get_part_by_string(text, **kwargs) - def set_color_by_text( - self, selector: Selector, color: ManimColor, **kwargs - ): - return self.set_parts_color(selector, color, **kwargs) + def set_color_by_text(self, text: str, color: ManimColor, **kwargs): + return self.set_color_by_string(text, color, **kwargs) def set_color_by_text_to_color_map( - self, color_map: dict[Selector, ManimColor], **kwargs + self, text_to_color_map: dict[str, ManimColor], **kwargs ): - return self.set_parts_color_by_dict(color_map, **kwargs) + return self.set_color_by_string_to_color_map( + text_to_color_map, **kwargs + ) def get_text(self) -> str: return self.get_string() diff --git a/manimlib/mobject/types/image_mobject.py b/manimlib/mobject/types/image_mobject.py index 0f9c4d0d..dd993319 100644 --- a/manimlib/mobject/types/image_mobject.py +++ b/manimlib/mobject/types/image_mobject.py @@ -48,6 +48,9 @@ class ImageMobject(Mobject): mob.data["opacity"] = np.array([[o] for o in listify(opacity)]) return self + def set_color(self, color, opacity=None, recurse=None): + return self + def point_to_rgb(self, point: np.ndarray) -> np.ndarray: x0, y0 = self.get_corner(UL)[:2] x1, y1 = self.get_corner(DR)[:2] From 37075590b50c5250bd2c4e98cfca5afe25e2d96f Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Fri, 22 Apr 2022 16:42:45 +0800 Subject: [PATCH 40/64] Sort imports --- manimlib/__init__.py | 6 +++--- manimlib/mobject/mobject.py | 10 +++++----- manimlib/mobject/three_dimensions.py | 2 ++ manimlib/scene/interactive_scene.py | 22 +++++++++++----------- manimlib/scene/scene.py | 10 ++++------ requirements.txt | 22 +++++++++++----------- setup.cfg | 23 ++++++++++++----------- 7 files changed, 48 insertions(+), 47 deletions(-) diff --git a/manimlib/__init__.py b/manimlib/__init__.py index 1688de0a..2043738c 100644 --- a/manimlib/__init__.py +++ b/manimlib/__init__.py @@ -4,6 +4,8 @@ __version__ = pkg_resources.get_distribution("manimgl").version from manimlib.constants import * +from manimlib.window import * + from manimlib.animation.animation import * from manimlib.animation.composition import * from manimlib.animation.creation import * @@ -50,8 +52,8 @@ from manimlib.mobject.types.vectorized_mobject import * from manimlib.mobject.value_tracker import * from manimlib.mobject.vector_field import * -from manimlib.scene.scene import * from manimlib.scene.interactive_scene import * +from manimlib.scene.scene import * from manimlib.scene.three_d_scene import * from manimlib.utils.bezier import * @@ -68,5 +70,3 @@ from manimlib.utils.rate_functions import * from manimlib.utils.simple_functions import * from manimlib.utils.sounds import * from manimlib.utils.space_ops import * - -from manimlib.window import * diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 9dffc83c..3b628970 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -1,12 +1,12 @@ from __future__ import annotations import copy -import sys -import random -import itertools as it from functools import wraps -import pickle +import itertools as it import os +import pickle +import random +import sys import moderngl import numbers @@ -1224,7 +1224,7 @@ class Mobject(object): bb = self.get_bounding_box() return np.array([ [bb[indices[-i + 1]][i] for i in range(3)] - for indices in it.product(*3 * [[0, 2]]) + for indices in it.product([0, 2], repeat=3) ]) def get_center(self) -> np.ndarray: diff --git a/manimlib/mobject/three_dimensions.py b/manimlib/mobject/three_dimensions.py index a42aa7ba..3a1a8010 100644 --- a/manimlib/mobject/three_dimensions.py +++ b/manimlib/mobject/three_dimensions.py @@ -2,6 +2,8 @@ from __future__ import annotations import math +import numpy as np + from manimlib.constants import BLUE, BLUE_D, BLUE_E from manimlib.constants import IN, ORIGIN, OUT, RIGHT from manimlib.constants import PI, TAU diff --git a/manimlib/scene/interactive_scene.py b/manimlib/scene/interactive_scene.py index d1cc6d1f..82b06c82 100644 --- a/manimlib/scene/interactive_scene.py +++ b/manimlib/scene/interactive_scene.py @@ -1,29 +1,29 @@ -import numpy as np import itertools as it -import pyperclip +import numpy as np import os import platform +import pyperclip from manimlib.animation.fading import FadeIn -from manimlib.constants import MANIM_COLORS, WHITE, YELLOW -from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT, DL, UL, UR, DR +from manimlib.constants import ARROW_SYMBOLS, DELETE_SYMBOL, SHIFT_SYMBOL +from manimlib.constants import COMMAND_MODIFIER, SHIFT_MODIFIER +from manimlib.constants import DL, DOWN, DR, LEFT, ORIGIN, RIGHT, UL, UP, UR from manimlib.constants import FRAME_WIDTH, SMALL_BUFF -from manimlib.constants import SHIFT_SYMBOL, DELETE_SYMBOL, ARROW_SYMBOLS -from manimlib.constants import SHIFT_MODIFIER, COMMAND_MODIFIER -from manimlib.mobject.mobject import Mobject +from manimlib.constants import MANIM_COLORS, WHITE, YELLOW +from manimlib.logger import log from manimlib.mobject.geometry import Rectangle from manimlib.mobject.geometry import Square from manimlib.mobject.mobject import Group +from manimlib.mobject.mobject import Mobject from manimlib.mobject.svg.tex_mobject import Tex from manimlib.mobject.svg.text_mobject import Text -from manimlib.mobject.types.vectorized_mobject import VMobject -from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.dot_cloud import DotCloud +from manimlib.mobject.types.vectorized_mobject import VGroup +from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.scene.scene import Scene -from manimlib.utils.tex_file_writing import LatexError from manimlib.utils.family_ops import extract_mobject_family_members from manimlib.utils.space_ops import get_norm -from manimlib.logger import log +from manimlib.utils.tex_file_writing import LatexError SELECT_KEY = 's' diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index ec3b097d..6b86de14 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -3,12 +3,10 @@ from __future__ import annotations from functools import wraps import inspect import itertools as it +import os import platform import random import time -import platform -from functools import wraps -import os import numpy as np from tqdm import tqdm as ProgressDisplay @@ -18,16 +16,16 @@ from manimlib.animation.transform import MoveToTarget from manimlib.camera.camera import Camera from manimlib.config import get_custom_config from manimlib.constants import ARROW_SYMBOLS -from manimlib.constants import SHIFT_MODIFIER, CTRL_MODIFIER, COMMAND_MODIFIER from manimlib.constants import DEFAULT_WAIT_TIME +from manimlib.constants import COMMAND_MODIFIER, CTRL_MODIFIER, SHIFT_MODIFIER from manimlib.event_handler import EVENT_DISPATCHER from manimlib.event_handler.event_type import EventType from manimlib.logger import log +from manimlib.mobject.mobject import Group from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Point -from manimlib.mobject.mobject import Group -from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VGroup +from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.scene.scene_file_writer import SceneFileWriter from manimlib.utils.config_ops import digest_config from manimlib.utils.family_ops import extract_mobject_family_members diff --git a/requirements.txt b/requirements.txt index a5225a8b..5c6ae599 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,23 +1,23 @@ colour -numpy -Pillow -scipy -sympy -tqdm +ipython +isosurfaces +manimpango>=0.4.0.post0,<0.5.0 mapbox-earcut matplotlib moderngl moderngl_window -skia-pathops +numpy +Pillow pydub pygments +PyOpenGL pyperclip pyyaml rich +scipy screeninfo -validators -ipython -PyOpenGL -manimpango>=0.4.0.post0,<0.5.0 -isosurfaces +skia-pathops svgelements +sympy +tqdm +validators diff --git a/setup.cfg b/setup.cfg index 934f051c..82c033f9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,27 +30,28 @@ packages = find: include_package_data = True install_requires = colour - numpy - Pillow - scipy - sympy - tqdm + ipython + isosurfaces + manimpango>=0.4.0.post0,<0.5.0 mapbox-earcut matplotlib moderngl moderngl_window - skia-pathops + numpy + Pillow pydub pygments + PyOpenGL + pyperclip pyyaml rich + scipy screeninfo - validators - ipython - PyOpenGL - manimpango>=0.4.0.post0,<0.5.0 - isosurfaces + skia-pathops svgelements + sympy + tqdm + validators [options.entry_points] console_scripts = From 1b2460f02a694314897437b9b8755443ed290cc1 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 08:14:05 -0700 Subject: [PATCH 41/64] Remove refresh_shader_wrapper_id from Mobject.become --- manimlib/mobject/mobject.py | 1 - 1 file changed, 1 deletion(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 10ebc35b..2e174a84 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -530,7 +530,6 @@ class Mobject(object): sm1.texture_paths = sm2.texture_paths sm1.depth_test = sm2.depth_test sm1.render_primitive = sm2.render_primitive - self.refresh_shader_wrapper_id() self.refresh_bounding_box(recurse_down=True) return self From 5927f6a1cd8a3be10391132e23df58dbcf216502 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 08:14:29 -0700 Subject: [PATCH 42/64] Default to "" for scene_file_writer output dir --- manimlib/scene/scene_file_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manimlib/scene/scene_file_writer.py b/manimlib/scene/scene_file_writer.py index a46ec01d..4b6b70c4 100644 --- a/manimlib/scene/scene_file_writer.py +++ b/manimlib/scene/scene_file_writer.py @@ -61,7 +61,7 @@ class SceneFileWriter(object): # Output directories and files def init_output_directories(self) -> None: - out_dir = self.output_directory + out_dir = self.output_directory or "" if self.mirror_module_path: module_dir = self.get_default_module_directory() out_dir = os.path.join(out_dir, module_dir) From c96bdc243e57c17bb75bf12d73ab5bf119cf1464 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 08:16:17 -0700 Subject: [PATCH 43/64] Update Scene.embed to play nicely with gui interactions --- manimlib/scene/scene.py | 49 ++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index 5f91bb37..4303d737 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -14,7 +14,6 @@ import numpy as np from manimlib.animation.animation import prepare_animation from manimlib.animation.transform import MoveToTarget from manimlib.camera.camera import Camera -from manimlib.config import get_custom_config from manimlib.constants import DEFAULT_WAIT_TIME from manimlib.constants import ARROW_SYMBOLS from manimlib.constants import SHIFT_MODIFIER, CTRL_MODIFIER, COMMAND_MODIFIER @@ -143,32 +142,40 @@ class Scene(object): def embed(self, close_scene_on_exit: bool = True) -> None: if not self.preview: - # If the scene is just being - # written, ignore embed calls + # If the scene is just being written, ignore embed calls return self.stop_skipping() self.linger_after_completion = False self.update_frame() - - # Save scene state at the point of embedding self.save_state() - from IPython.terminal.embed import InteractiveShellEmbed - shell = InteractiveShellEmbed() - # Have the frame update after each command - shell.events.register('post_run_cell', lambda *a, **kw: self.refresh_static_mobjects()) - shell.events.register('post_run_cell', lambda *a, **kw: self.update_frame()) - # Use the locals of the caller as the local namespace - # once embedded, and add a few custom shortcuts + # Configure and launch embedded terminal + from IPython.terminal import embed, pt_inputhooks + shell = embed.InteractiveShellEmbed.instance() + + # Use the locals namespace of the caller local_ns = inspect.currentframe().f_back.f_locals - local_ns["touch"] = self.interact - local_ns["i2g"] = self.ids_to_group - for term in ("play", "wait", "add", "remove", "clear", "save_state", "restore"): + # Add a few custom shortcuts + for term in ("play", "wait", "add", "remove", "clear", "save_state", "restore", "i2g", "i2m"): local_ns[term] = getattr(self, term) - log.info("Tips: Now the embed iPython terminal is open. But you can't interact with" - " the window directly. To do so, you need to type `touch()` or `self.interact()`") - exec(get_custom_config()["universal_import_line"]) + + # Enables gui interactions during the embed + def inputhook(context): + while not context.input_is_ready(): + self.update_frame() + + pt_inputhooks.register("manim", inputhook) + shell.enable_gui("manim") + + # Have the frame update after each command + def post_cell_func(*args, **kwargs): + self.refresh_static_mobjects() + + shell.events.register("post_run_cell", post_cell_func) + + # Launch shell, with stack_depth=2 indicating we should use caller globals/locals shell(local_ns=local_ns, stack_depth=2) + # End scene when exiting an embed if close_scene_on_exit: raise EndSceneEarlyException() @@ -331,6 +338,12 @@ class Scene(object): map(self.id_to_mobject, id_values) )) + def i2g(self, *id_values): + return self.ids_to_group(*id_values) + + def i2m(self, id_value): + return self.id_to_mobject(id_value) + # Related to skipping def update_skipping_status(self) -> None: if self.start_at_animation_number is not None: From 2737d9a736885a594dd101ffe07bb82e00069333 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 08:33:18 -0700 Subject: [PATCH 44/64] Have BlankScene inherit from InteractiveScene --- manimlib/extract_scene.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/manimlib/extract_scene.py b/manimlib/extract_scene.py index abec96ec..b13afe5b 100644 --- a/manimlib/extract_scene.py +++ b/manimlib/extract_scene.py @@ -3,11 +3,12 @@ import sys import copy from manimlib.scene.scene import Scene +from manimlib.scene.interactive_scene import InteractiveScene from manimlib.config import get_custom_config from manimlib.logger import log -class BlankScene(Scene): +class BlankScene(InteractiveScene): def construct(self): exec(get_custom_config()["universal_import_line"]) self.embed() From 581228b08f8bbd31fb0386e46860d76b79a6af98 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 08:33:57 -0700 Subject: [PATCH 45/64] Have scene keep track of a map from mobject ids to mobjects for all it's ever seen --- manimlib/scene/scene.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index 4303d737..a540cd09 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -76,6 +76,7 @@ class Scene(object): self.camera: Camera = self.camera_class(**self.camera_config) self.file_writer = SceneFileWriter(self, **self.file_writer_config) self.mobjects: list[Mobject] = [self.camera.frame] + self.id_to_mobject_map: dict[int, Mobject] = dict() self.num_plays: int = 0 self.time: float = 0 self.skip_time: float = 0 @@ -263,6 +264,11 @@ class Scene(object): """ self.remove(*new_mobjects) self.mobjects += new_mobjects + self.id_to_mobject_map.update({ + id(sm): sm + for m in new_mobjects + for sm in m.get_family() + }) return self def add_mobjects_among(self, values: Iterable): @@ -326,11 +332,7 @@ class Scene(object): return Group(*mobjects) def id_to_mobject(self, id_value): - for mob in self.mobjects: - for sm in mob.get_family(): - if id(sm) == id_value: - return sm - return None + return self.id_to_mobject_map[id_value] def ids_to_group(self, *id_values): return self.get_group(*filter( From e0f5686d667152582f052021cd62bd2ef8c6b470 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 10:16:43 -0700 Subject: [PATCH 46/64] Fix bug with trying to close window during embed --- manimlib/scene/scene.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index a540cd09..d3617f5b 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -163,7 +163,11 @@ class Scene(object): # Enables gui interactions during the embed def inputhook(context): while not context.input_is_ready(): - self.update_frame() + if self.window.is_closing: + pass + # self.window.destroy() + else: + self.update_frame(dt=0) pt_inputhooks.register("manim", inputhook) shell.enable_gui("manim") From bb7fa2c8aa68d7c7992517cfde3c7d0e804e13e8 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 10:17:15 -0700 Subject: [PATCH 47/64] Update behavior of -e flag to take in (optional) strings as inputs --- manimlib/config.py | 63 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 15 deletions(-) diff --git a/manimlib/config.py b/manimlib/config.py index a2e68c51..479797d8 100644 --- a/manimlib/config.py +++ b/manimlib/config.py @@ -117,16 +117,18 @@ def parse_cli(): ) parser.add_argument( "-n", "--start_at_animation_number", - help="Start rendering not from the first animation, but" - "from another, specified by its index. If you pass" - "in two comma separated values, e.g. \"3,6\", it will end" + help="Start rendering not from the first animation, but " + "from another, specified by its index. If you pass " + "in two comma separated values, e.g. \"3,6\", it will end " "the rendering at the second value", ) parser.add_argument( - "-e", "--embed", metavar="LINENO", - help="Takes a line number as an argument, and results" - "in the scene being called as if the line `self.embed()`" - "was inserted into the scene code at that line number." + "-e", "--embed", + nargs="*", + help="Creates a new file where the line `self.embed` is inserted " + "into the Scenes construct method. " + "If a string is passed in, the line will be inserted below the " + "last line of code including that string." ) parser.add_argument( "-r", "--resolution", @@ -186,12 +188,43 @@ def get_module(file_name): @contextmanager -def insert_embed_line(file_name, lineno): +def insert_embed_line(file_name: str, scene_names: list[str], strings_to_match: str): + """ + This is hacky, but convenient. When user includes the argument "-e", it will try + to recreate a file that inserts the line `self.embed()` into the end of the scene's + construct method. If there is an argument passed in, it will insert the line after + the last line in the sourcefile which includes that string. + """ with open(file_name, 'r') as fp: lines = fp.readlines() - line = lines[lineno - 1] - n_spaces = len(line) - len(line.lstrip()) - lines.insert(lineno, " " * n_spaces + "self.embed()\n") + + line = None + if strings_to_match: + matching_lines = [ + line for line in lines + if any(s in line for s in strings_to_match) + ] + if matching_lines: + line = matching_lines[-1] + n_spaces = len(line) - len(line.lstrip()) + lines.insert(lines.index(line), " " * n_spaces + "self.embed()\n") + if line is None: + lineno = 0 + in_scene = False + in_construct = False + n_spaces = 8 + # Search for scene definition + for lineno, line in enumerate(lines): + indent = len(line) - len(line.lstrip()) + if line.startswith(f"class {scene_names[0]}"): + in_scene = True + elif in_scene and "def construct" in line: + in_construct = True + n_spaces = indent + 4 + elif in_construct: + if len(line.strip()) > 0 and indent < n_spaces: + break + lines.insert(lineno, " " * n_spaces + "self.embed()\n") alt_file = file_name.replace(".py", "_inserted_embed.py") with open(alt_file, 'w') as fp: @@ -296,10 +329,10 @@ def get_configuration(args): "quiet": args.quiet, } - if args.embed is None: - module = get_module(args.file) - else: - with insert_embed_line(args.file, int(args.embed)) as alt_file: + module = get_module(args.file) + + if args.embed is not None: + with insert_embed_line(args.file, args.scene_names, args.embed) as alt_file: module = get_module(alt_file) config = { From b9751e9d06068f27a327b419c52fd3c9d68db2e6 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 10:17:29 -0700 Subject: [PATCH 48/64] Add cursor location label --- manimlib/scene/interactive_scene.py | 103 ++++++++++++++++++---------- 1 file changed, 67 insertions(+), 36 deletions(-) diff --git a/manimlib/scene/interactive_scene.py b/manimlib/scene/interactive_scene.py index 7ce23a89..b805134e 100644 --- a/manimlib/scene/interactive_scene.py +++ b/manimlib/scene/interactive_scene.py @@ -3,15 +3,16 @@ import itertools as it import pyperclip from manimlib.animation.fading import FadeIn -from manimlib.constants import MANIM_COLORS, WHITE +from manimlib.constants import MANIM_COLORS, WHITE, GREY_C from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT, DL, UL, UR, DR from manimlib.constants import FRAME_WIDTH, SMALL_BUFF from manimlib.constants import SHIFT_SYMBOL, CTRL_SYMBOL, DELETE_SYMBOL, ARROW_SYMBOLS from manimlib.constants import SHIFT_MODIFIER, COMMAND_MODIFIER -from manimlib.mobject.mobject import Mobject from manimlib.mobject.geometry import Rectangle from manimlib.mobject.geometry import Square from manimlib.mobject.mobject import Group +from manimlib.mobject.mobject import Mobject +from manimlib.mobject.numbers import DecimalNumber from manimlib.mobject.svg.tex_mobject import Tex from manimlib.mobject.svg.text_mobject import Text from manimlib.mobject.types.vectorized_mobject import VMobject @@ -31,6 +32,7 @@ HORIZONTAL_GRAB_KEY = 'h' VERTICAL_GRAB_KEY = 'v' RESIZE_KEY = 't' COLOR_KEY = 'c' +CURSOR_LOCATION_KEY = 'l' # Note, a lot of the functionality here is still buggy and very much a work in progress. @@ -65,16 +67,23 @@ class InteractiveScene(Scene): selection_rectangle_stroke_width = 1.0 colors = MANIM_COLORS selection_nudge_size = 0.05 + cursor_location_config = dict( + font_size=14, + fill_color=GREY_C, + num_decimal_places=3, + ) def setup(self): self.selection = Group() self.selection_highlight = Group() self.selection_rectangle = self.get_selection_rectangle() self.color_palette = self.get_color_palette() + self.cursor_location_label = self.get_cursor_location_label() self.unselectables = [ self.selection, self.selection_highlight, self.selection_rectangle, + self.cursor_location_label, self.camera.frame ] self.saved_selection_state = [] @@ -83,6 +92,57 @@ class InteractiveScene(Scene): self.is_selecting = False self.add(self.selection_highlight) + def get_selection_rectangle(self): + rect = Rectangle( + stroke_color=self.selection_rectangle_stroke_color, + stroke_width=self.selection_rectangle_stroke_width, + ) + rect.fix_in_frame() + rect.fixed_corner = ORIGIN + rect.add_updater(self.update_selection_rectangle) + return rect + + def update_selection_rectangle(self, rect): + p1 = rect.fixed_corner + p2 = self.mouse_point.get_center() + rect.set_points_as_corners([ + p1, [p2[0], p1[1], 0], + p2, [p1[0], p2[1], 0], + p1, + ]) + return rect + + def get_color_palette(self): + palette = VGroup(*( + Square(fill_color=color, fill_opacity=1, side_length=1) + for color in self.colors + )) + palette.set_stroke(width=0) + palette.arrange(RIGHT, buff=0.5) + palette.set_width(FRAME_WIDTH - 0.5) + palette.to_edge(DOWN, buff=SMALL_BUFF) + palette.fix_in_frame() + return palette + + def get_cursor_location_label(self): + decimals = VGroup(*( + DecimalNumber(**self.cursor_location_config) + for n in range(3) + )) + + def update_coords(decimals): + for mob, coord in zip(decimals, self.mouse_point.get_location()): + mob.set_value(coord) + decimals.arrange(RIGHT, buff=decimals.get_height()) + decimals.to_corner(DR, buff=SMALL_BUFF) + decimals.fix_in_frame() + return decimals + + decimals.add_updater(update_coords) + return decimals + + # Related to selection + def toggle_selection_mode(self): self.select_top_level_mobs = not self.select_top_level_mobs self.refresh_selection_scope() @@ -115,28 +175,6 @@ class InteractiveScene(Scene): ) self.refresh_selection_highlight() - def get_selection_rectangle(self): - rect = Rectangle( - stroke_color=self.selection_rectangle_stroke_color, - stroke_width=self.selection_rectangle_stroke_width, - ) - rect.fix_in_frame() - rect.fixed_corner = ORIGIN - rect.add_updater(self.update_selection_rectangle) - return rect - - def get_color_palette(self): - palette = VGroup(*( - Square(fill_color=color, fill_opacity=1, side_length=1) - for color in self.colors - )) - palette.set_stroke(width=0) - palette.arrange(RIGHT, buff=0.5) - palette.set_width(FRAME_WIDTH - 0.5) - palette.to_edge(DOWN, buff=SMALL_BUFF) - palette.fix_in_frame() - return palette - def get_corner_dots(self, mobject): dots = DotCloud(**self.corner_dot_config) radius = self.corner_dot_config["radius"] @@ -164,16 +202,6 @@ class InteractiveScene(Scene): for mob in self.selection ]) - def update_selection_rectangle(self, rect): - p1 = rect.fixed_corner - p2 = self.mouse_point.get_center() - rect.set_points_as_corners([ - p1, [p2[0], p1[1], 0], - p2, [p1[0], p2[1], 0], - p1, - ]) - return rect - def add_to_selection(self, *mobjects): mobs = list(filter( lambda m: m not in self.unselectables and m not in self.selection, @@ -201,7 +229,7 @@ class InteractiveScene(Scene): mob.make_movable() super().add(*new_mobjects) - # Selection operations + # Functions for keyboard actions def copy_selection(self): ids = map(id, self.selection) @@ -288,6 +316,8 @@ class InteractiveScene(Scene): self.add(self.color_palette) else: self.remove(self.color_palette) + elif char == CURSOR_LOCATION_KEY and modifiers == 0: + self.add(self.cursor_location_label) # Command + c -> Copy mobject ids to clipboard elif char == "c" and modifiers == COMMAND_MODIFIER: self.copy_selection() @@ -350,7 +380,8 @@ class InteractiveScene(Scene): for mob in reversed(self.get_selection_search_set()): if mob.is_movable() and self.selection_rectangle.is_touching(mob): self.add_to_selection(mob) - + elif chr(symbol) == CURSOR_LOCATION_KEY: + self.remove(self.cursor_location_label) elif symbol == SHIFT_SYMBOL: if self.window.is_key_pressed(ord(RESIZE_KEY)): self.prepare_resizing(about_corner=False) From 4d8698a0e88333f6481c08d1b84b6e44f9dc4543 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 11:42:26 -0700 Subject: [PATCH 49/64] Add Mobject.deserialize --- manimlib/mobject/mobject.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 2e174a84..3caeda69 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -470,6 +470,10 @@ class Mobject(object): self.parents = pre return result + def deserialize(self, data: bytes): + self.become(pickle.loads(data)) + return self + def copy(self): try: serial = self.serialize() From cf466006faa00fc12dc22f5732dc21ccedaa5a63 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 11:44:28 -0700 Subject: [PATCH 50/64] Add undo and redo stacks for scene, together with Command + Z functionality --- manimlib/scene/interactive_scene.py | 25 ++++++----- manimlib/scene/scene.py | 70 ++++++++++++++++++++--------- 2 files changed, 63 insertions(+), 32 deletions(-) diff --git a/manimlib/scene/interactive_scene.py b/manimlib/scene/interactive_scene.py index b805134e..3c75334a 100644 --- a/manimlib/scene/interactive_scene.py +++ b/manimlib/scene/interactive_scene.py @@ -86,7 +86,6 @@ class InteractiveScene(Scene): self.cursor_location_label, self.camera.frame ] - self.saved_selection_state = [] self.select_top_level_mobs = True self.is_selecting = False @@ -210,7 +209,6 @@ class InteractiveScene(Scene): if mobs: self.selection.add(*mobs) self.selection_highlight.add(*map(self.get_highlight, mobs)) - self.saved_selection_state = [(mob, mob.copy()) for mob in self.selection] def toggle_from_selection(self, *mobjects): for mob in mobjects: @@ -267,14 +265,8 @@ class InteractiveScene(Scene): self.remove(*self.selection) self.clear_selection() - def undo(self): - mobs = [] - for mob, state in self.saved_selection_state: - mob.become(state) - mobs.append(mob) - if mob not in self.mobjects: - self.add(mob) - self.selection.set_submobjects(mobs) + def restore_state(self, mobject_states: list[tuple[Mobject, Mobject]]): + super().restore_state(mobject_states) self.refresh_selection_highlight() def prepare_resizing(self, about_corner=False): @@ -313,9 +305,11 @@ class InteractiveScene(Scene): if len(self.selection) == 0: return if self.color_palette not in self.mobjects: + self.save_state() self.add(self.color_palette) else: self.remove(self.color_palette) + # Show coordiantes of cursor location elif char == CURSOR_LOCATION_KEY and modifiers == 0: self.add(self.cursor_location_label) # Command + c -> Copy mobject ids to clipboard @@ -355,15 +349,18 @@ class InteractiveScene(Scene): # Command + t -> Toggle selection mode elif char == "t" and modifiers == COMMAND_MODIFIER: self.toggle_selection_mode() - # Command + z -> Restore selection to original state + # Command + z -> Undo elif char == "z" and modifiers == COMMAND_MODIFIER: self.undo() + # Command + shift + z -> Redo + elif char == "z" and modifiers == COMMAND_MODIFIER | SHIFT_MODIFIER: + self.redo() # Command + s -> Save selections to file elif char == "s" and modifiers == COMMAND_MODIFIER: to_save = self.selection if len(to_save) == 1: to_save = to_save[0] - self.save_mobect(to_save) + self.save_mobject_to_file(to_save) # Keyboard movements elif symbol in ARROW_SYMBOLS: nudge = self.selection_nudge_size @@ -372,6 +369,10 @@ class InteractiveScene(Scene): vect = [LEFT, UP, RIGHT, DOWN][ARROW_SYMBOLS.index(symbol)] self.selection.shift(nudge * vect) + # Conditions for saving state + if char in [GRAB_KEY, HORIZONTAL_GRAB_KEY, VERTICAL_GRAB_KEY, RESIZE_KEY]: + self.save_state() + def on_key_release(self, symbol: int, modifiers: int) -> None: super().on_key_release(symbol, modifiers) if chr(symbol) == SELECT_KEY: diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index d3617f5b..6ba65367 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -61,6 +61,7 @@ class Scene(object): "presenter_mode": False, "linger_after_completion": True, "pan_sensitivity": 3, + "max_num_saved_states": 20, } def __init__(self, **kwargs): @@ -70,6 +71,8 @@ class Scene(object): self.window = Window(scene=self, **self.window_config) self.camera_config["ctx"] = self.window.ctx self.camera_config["frame_rate"] = 30 # Where's that 30 from? + self.undo_stack = [] + self.redo_stack = [] else: self.window = None @@ -88,12 +91,16 @@ class Scene(object): self.mouse_point = Point() self.mouse_drag_point = Point() self.hold_on_wait = self.presenter_mode + self.inside_embed = False # Much nicer to work with deterministic scenes if self.random_seed is not None: random.seed(self.random_seed) np.random.seed(self.random_seed) + def __str__(self) -> str: + return self.__class__.__name__ + def run(self) -> None: self.virtual_animation_start_time: float = 0 self.real_animation_start_time: float = time.time() @@ -143,22 +150,28 @@ class Scene(object): def embed(self, close_scene_on_exit: bool = True) -> None: if not self.preview: - # If the scene is just being written, ignore embed calls + # Ignore embed calls when there is no preview return + self.inside_embed = True self.stop_skipping() self.linger_after_completion = False self.update_frame() self.save_state() - # Configure and launch embedded terminal + # Configure and launch embedded IPython terminal from IPython.terminal import embed, pt_inputhooks shell = embed.InteractiveShellEmbed.instance() # Use the locals namespace of the caller local_ns = inspect.currentframe().f_back.f_locals # Add a few custom shortcuts - for term in ("play", "wait", "add", "remove", "clear", "save_state", "restore", "i2g", "i2m"): - local_ns[term] = getattr(self, term) + local_ns.update({ + name: getattr(self, name) + for name in [ + "play", "wait", "add", "remove", "clear", + "save_state", "undo", "redo", "i2g", "i2m" + ] + }) # Enables gui interactions during the embed def inputhook(context): @@ -172,7 +185,7 @@ class Scene(object): pt_inputhooks.register("manim", inputhook) shell.enable_gui("manim") - # Have the frame update after each command + # Operation to run after each ipython command def post_cell_func(*args, **kwargs): self.refresh_static_mobjects() @@ -181,14 +194,13 @@ class Scene(object): # Launch shell, with stack_depth=2 indicating we should use caller globals/locals shell(local_ns=local_ns, stack_depth=2) + self.inside_embed = False # End scene when exiting an embed if close_scene_on_exit: raise EndSceneEarlyException() - def __str__(self) -> str: - return self.__class__.__name__ - # Only these methods should touch the camera + def get_image(self) -> Image: return self.camera.get_image() @@ -219,6 +231,7 @@ class Scene(object): self.file_writer.write_frame(self.camera) # Related to updating + def update_mobjects(self, dt: float) -> None: for mobject in self.mobjects: mobject.update(dt) @@ -237,6 +250,7 @@ class Scene(object): ]) # Related to time + def get_time(self) -> float: return self.time @@ -244,6 +258,7 @@ class Scene(object): self.time += dt # Related to internal mobject organization + def get_top_level_mobjects(self) -> list[Mobject]: # Return only those which are not in the family # of another mobject from the scene @@ -351,6 +366,7 @@ class Scene(object): return self.id_to_mobject(id_value) # Related to skipping + def update_skipping_status(self) -> None: if self.start_at_animation_number is not None: if self.num_plays == self.start_at_animation_number: @@ -366,6 +382,7 @@ class Scene(object): self.skip_animations = False # Methods associated with running animations + def get_time_progression( self, run_time: float, @@ -489,6 +506,8 @@ class Scene(object): def handle_play_like_call(func): @wraps(func) def wrapper(self, *args, **kwargs): + if self.inside_embed: + self.save_state() self.update_skipping_status() should_write = not self.skip_animations if should_write: @@ -610,21 +629,32 @@ class Scene(object): self.file_writer.add_sound(sound_file, time, gain, gain_to_background) # Helpers for interactive development + + def get_state(self) -> list[tuple[Mobject, Mobject]]: + return [(mob, mob.copy()) for mob in self.mobjects] + + def restore_state(self, mobject_states: list[tuple[Mobject, Mobject]]): + self.mobjects = [mob.become(mob_copy) for mob, mob_copy in mobject_states] + def save_state(self) -> None: - self.saved_state = [ - (mob, mob.copy()) - for mob in self.mobjects - ] + if not self.preview: + return + self.redo_stack = [] + self.undo_stack.append(self.get_state()) + if len(self.undo_stack) > self.max_num_saved_states: + self.undo_stack.pop(0) - def restore(self) -> None: - if not hasattr(self, "saved_state"): - raise Exception("Trying to restore scene without having saved") - self.mobjects = [] - for mob, mob_state in self.saved_state: - mob.become(mob_state) - self.mobjects.append(mob) + def undo(self): + if self.undo_stack: + self.redo_stack.append(self.get_state()) + self.restore_state(self.undo_stack.pop()) - def save_mobect(self, mobject: Mobject, file_path: str | None = None) -> None: + def redo(self): + if self.redo_stack: + self.undo_stack.append(self.get_state()) + self.restore_state(self.redo_stack.pop()) + + def save_mobject_to_file(self, mobject: Mobject, file_path: str | None = None) -> None: if file_path is None: file_path = self.file_writer.get_saved_mobject_path(mobject) if file_path is None: From b2e0aee93e9e7f6777c57aa69be56c03fd208447 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 11:46:18 -0700 Subject: [PATCH 51/64] Get rid of ctrl + shift + e embed option --- manimlib/scene/scene.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index 6ba65367..3f295ab5 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -16,7 +16,7 @@ from manimlib.animation.transform import MoveToTarget from manimlib.camera.camera import Camera from manimlib.constants import DEFAULT_WAIT_TIME from manimlib.constants import ARROW_SYMBOLS -from manimlib.constants import SHIFT_MODIFIER, CTRL_MODIFIER, COMMAND_MODIFIER +from manimlib.constants import COMMAND_MODIFIER from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Point from manimlib.mobject.mobject import Group @@ -42,7 +42,6 @@ FRAME_SHIFT_KEY = 'f' ZOOM_KEY = 'z' RESET_FRAME_KEY = 'r' QUIT_KEY = 'q' -EMBED_KEY = 'e' class Scene(object): @@ -787,9 +786,6 @@ class Scene(object): # Space or right arrow elif char == " " or symbol == ARROW_SYMBOLS[2]: self.hold_on_wait = False - # ctrl + shift + e - elif char == EMBED_KEY and modifiers == CTRL_MODIFIER | SHIFT_MODIFIER: - self.embed(close_scene_on_exit=False) def on_resize(self, width: int, height: int) -> None: self.camera.reset_pixel_shape(width, height) From 71c14969dffc8762a43f9646a0c3dc024a51b8df Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 15:41:23 -0700 Subject: [PATCH 52/64] Refactor -e flag hackiness --- manimlib/config.py | 78 ++++++++++++++++++++++++++++------------------ 1 file changed, 48 insertions(+), 30 deletions(-) diff --git a/manimlib/config.py b/manimlib/config.py index 479797d8..59b4f567 100644 --- a/manimlib/config.py +++ b/manimlib/config.py @@ -124,7 +124,8 @@ def parse_cli(): ) parser.add_argument( "-e", "--embed", - nargs="*", + nargs="?", + const="", help="Creates a new file where the line `self.embed` is inserted " "into the Scenes construct method. " "If a string is passed in, the line will be inserted below the " @@ -187,8 +188,12 @@ def get_module(file_name): return module +def get_indent(line: str): + return len(line) - len(line.lstrip()) + + @contextmanager -def insert_embed_line(file_name: str, scene_names: list[str], strings_to_match: str): +def insert_embed_line(file_name: str, scene_name: str, line_marker: str): """ This is hacky, but convenient. When user includes the argument "-e", it will try to recreate a file that inserts the line `self.embed()` into the end of the scene's @@ -198,34 +203,47 @@ def insert_embed_line(file_name: str, scene_names: list[str], strings_to_match: with open(file_name, 'r') as fp: lines = fp.readlines() - line = None - if strings_to_match: - matching_lines = [ - line for line in lines - if any(s in line for s in strings_to_match) - ] - if matching_lines: - line = matching_lines[-1] - n_spaces = len(line) - len(line.lstrip()) - lines.insert(lines.index(line), " " * n_spaces + "self.embed()\n") - if line is None: - lineno = 0 - in_scene = False - in_construct = False - n_spaces = 8 - # Search for scene definition - for lineno, line in enumerate(lines): - indent = len(line) - len(line.lstrip()) - if line.startswith(f"class {scene_names[0]}"): - in_scene = True - elif in_scene and "def construct" in line: - in_construct = True - n_spaces = indent + 4 - elif in_construct: - if len(line.strip()) > 0 and indent < n_spaces: - break - lines.insert(lineno, " " * n_spaces + "self.embed()\n") + try: + scene_line_number = next( + i for i, line in enumerate(lines) + if line.startswith(f"class {scene_name}") + ) + except StopIteration: + log.error(f"No scene {scene_name}") + prev_line_num = None + n_spaces = None + if len(line_marker) == 0: + # Find the end of the construct method + in_construct = False + for index in range(scene_line_number, len(lines) - 1): + line = lines[index] + if line.lstrip().startswith("def construct"): + in_construct = True + n_spaces = get_indent(line) + 4 + elif in_construct: + if len(line.strip()) > 0 and get_indent(line) < n_spaces: + prev_line_num = index - 2 + break + elif line_marker.isdigit(): + # Treat the argument as a line number + prev_line_num = int(line_marker) - 1 + elif len(line_marker) > 0: + # Treat the argument as a string + try: + prev_line_num = next( + i + for i in range(len(lines) - 1, scene_line_number, -1) + if line_marker in lines[i] + ) + except StopIteration: + log.error(f"No lines matching {line_marker}") + sys.exit(2) + + # Insert and write new file + if n_spaces is None: + n_spaces = get_indent(lines[prev_line_num]) + lines.insert(prev_line_num + 1, " " * n_spaces + "self.embed()\n") alt_file = file_name.replace(".py", "_inserted_embed.py") with open(alt_file, 'w') as fp: fp.writelines(lines) @@ -332,7 +350,7 @@ def get_configuration(args): module = get_module(args.file) if args.embed is not None: - with insert_embed_line(args.file, args.scene_names, args.embed) as alt_file: + with insert_embed_line(args.file, args.scene_names[0], args.embed) as alt_file: module = get_module(alt_file) config = { From 59506b89cc73fff3b3736245dd72e61dcebf9a2c Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 19:02:44 -0700 Subject: [PATCH 53/64] Revert to original copying scheme --- manimlib/mobject/mobject.py | 79 ++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 22 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 3caeda69..5649e616 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -464,40 +464,76 @@ class Mobject(object): # Copying and serialization + def stash_mobject_pointers(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + uncopied_attrs = ["parents", "target", "saved_state"] + stash = dict() + for attr in uncopied_attrs: + if hasattr(self, attr): + value = getattr(self, attr) + stash[attr] = value + null_value = [] if isinstance(value, Iterable) else None + setattr(self, attr, null_value) + result = func(self, *args, **kwargs) + self.__dict__.update(stash) + return result + return wrapper + + @stash_mobject_pointers def serialize(self): - pre, self.parents = self.parents, [] - result = pickle.dumps(self) - self.parents = pre - return result + return pickle.dumps(self) def deserialize(self, data: bytes): self.become(pickle.loads(data)) return self - def copy(self): - try: - serial = self.serialize() - return pickle.loads(serial) - except AttributeError: - return copy.deepcopy(self) + @stash_mobject_pointers + def copy(self, deep: bool = False): + if deep: + try: + # Often faster than deepcopy + return pickle.loads(self.serialize()) + except AttributeError: + return copy.deepcopy(self) + + result = copy.copy(self) + + # The line above is only a shallow copy, so the internal + # data which are numpyu arrays or other mobjects still + # need to be further copied. + result.data = dict(self.data) + for key in result.data: + result.data[key] = result.data[key].copy() + + result.uniforms = dict(self.uniforms) + for key in result.uniforms: + if isinstance(result.uniforms[key], np.ndarray): + result.uniforms[key] = result.uniforms[key].copy() + + result.submobjects = [] + result.add(*(sm.copy() for sm in self.submobjects)) + result.match_updaters(self) + + family = self.get_family() + for attr, value in list(self.__dict__.items()): + if isinstance(value, Mobject) and value in family and value is not self: + setattr(result, attr, result.family[self.family.index(value)]) + if isinstance(value, np.ndarray): + setattr(result, attr, value.copy()) + if isinstance(value, ShaderWrapper): + setattr(result, attr, value.copy()) return result def deepcopy(self): - # This used to be different from copy, so is now just here for backward compatibility - return self.copy() + return self.copy(deep=True) def generate_target(self, use_deepcopy: bool = False): - # TODO, remove now pointless use_deepcopy arg - self.target = None # Prevent exponential explosion - self.target = self.copy() + self.target = self.copy(deep=use_deepcopy) return self.target def save_state(self, use_deepcopy: bool = False): - # TODO, remove now pointless use_deepcopy arg - if hasattr(self, "saved_state"): - # Prevent exponential growth of data - self.saved_state = None - self.saved_state = self.copy() + self.saved_state = self.copy(deep=use_deepcopy) return self def restore(self): @@ -540,9 +576,8 @@ class Mobject(object): # Creating new Mobjects from this one def replicate(self, n: int) -> Group: - serial = self.serialize() group_class = self.get_group_class() - return group_class(*(pickle.loads(serial) for _ in range(n))) + return group_class(*(self.copy() for _ in range(n))) def get_grid(self, n_rows: int, n_cols: int, height: float | None = None, **kwargs) -> Group: """ From 7b342a27591a07298fb2d61d717324fabfa01f9b Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 19:03:00 -0700 Subject: [PATCH 54/64] Remove unnecessary lines --- manimlib/scene/interactive_scene.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/manimlib/scene/interactive_scene.py b/manimlib/scene/interactive_scene.py index 3c75334a..0e7bc1fb 100644 --- a/manimlib/scene/interactive_scene.py +++ b/manimlib/scene/interactive_scene.py @@ -23,7 +23,6 @@ from manimlib.scene.scene import Scene from manimlib.utils.tex_file_writing import LatexError from manimlib.utils.family_ops import extract_mobject_family_members from manimlib.utils.space_ops import get_norm -from manimlib.logger import log SELECT_KEY = 's' @@ -320,8 +319,6 @@ class InteractiveScene(Scene): self.paste_selection() # Command + x -> Cut elif char == "x" and modifiers == COMMAND_MODIFIER: - # TODO, this copy won't work, because once the objects are removed, - # they're not searched for in the pasting. self.copy_selection() self.delete_selection() # Delete From 3961005fd708333a3e77856d10e78451faa04075 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 19:17:39 -0700 Subject: [PATCH 55/64] Rename is_movable to interaction_allowed --- manimlib/mobject/mobject.py | 12 ++++++------ manimlib/scene/interactive_scene.py | 19 +++++++++++++------ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 5649e616..bf6b7465 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -84,7 +84,7 @@ class Mobject(object): self.locked_data_keys: set[str] = set() self.needs_new_bounding_box: bool = True self._is_animating: bool = False - self._is_movable: bool = False + self.interaction_allowed: bool = False self.init_data() self.init_uniforms() @@ -692,20 +692,20 @@ class Mobject(object): # Check if mark as static or not for camera def is_changing(self) -> bool: - return self._is_animating or self.has_updaters or self._is_movable + return self._is_animating or self.has_updaters or self.interaction_allowed def set_animating_status(self, is_animating: bool, recurse: bool = True) -> None: for mob in self.get_family(recurse): mob._is_animating = is_animating return self - def make_movable(self, value: bool = True, recurse: bool = True) -> None: + def allow_interaction(self, value: bool = True, recurse: bool = True) -> None: for mob in self.get_family(recurse): - mob._is_movable = value + mob.interaction_allowed = value return self - def is_movable(self) -> bool: - return self._is_movable + def is_interaction_allowed(self) -> bool: + return self.interaction_allowed # Transforming operations diff --git a/manimlib/scene/interactive_scene.py b/manimlib/scene/interactive_scene.py index 0e7bc1fb..86fd36b7 100644 --- a/manimlib/scene/interactive_scene.py +++ b/manimlib/scene/interactive_scene.py @@ -145,8 +145,11 @@ class InteractiveScene(Scene): self.select_top_level_mobs = not self.select_top_level_mobs self.refresh_selection_scope() - def get_selection_search_set(self): - mobs = [m for m in self.mobjects if m not in self.unselectables] + def get_selection_search_set(self) -> list[Mobject]: + mobs = [ + m for m in self.mobjects + if m not in self.unselectables and m.is_interaction_allowed() + ] if self.select_top_level_mobs: return mobs else: @@ -173,7 +176,7 @@ class InteractiveScene(Scene): ) self.refresh_selection_highlight() - def get_corner_dots(self, mobject): + def get_corner_dots(self, mobject: Mobject) -> Mobject: dots = DotCloud(**self.corner_dot_config) radius = self.corner_dot_config["radius"] if mobject.get_depth() < 1e-2: @@ -186,7 +189,7 @@ class InteractiveScene(Scene): ])) return dots - def get_highlight(self, mobject): + def get_highlight(self, mobject: Mobject) -> Mobject: if isinstance(mobject, VMobject) and mobject.has_points() and not self.select_top_level_mobs: result = VHighlight(mobject) result.add_updater(lambda m: m.replace(mobject)) @@ -223,9 +226,13 @@ class InteractiveScene(Scene): def add(self, *new_mobjects: Mobject): for mob in new_mobjects: - mob.make_movable() + mob.allow_interaction() super().add(*new_mobjects) + def disable_interaction(self, *mobjects: Mobject): + for mob in mobjects: + mob.allow_interaction(False) + # Functions for keyboard actions def copy_selection(self): @@ -376,7 +383,7 @@ class InteractiveScene(Scene): self.is_selecting = False self.remove(self.selection_rectangle) for mob in reversed(self.get_selection_search_set()): - if mob.is_movable() and self.selection_rectangle.is_touching(mob): + if self.selection_rectangle.is_touching(mob): self.add_to_selection(mob) elif chr(symbol) == CURSOR_LOCATION_KEY: self.remove(self.cursor_location_label) From 62289045cc8e102121cfe4d7739f3c89102046fb Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 19:42:47 -0700 Subject: [PATCH 56/64] Fix animating Mobject.restore bug --- manimlib/mobject/mobject.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index bf6b7465..277f75dc 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -85,6 +85,8 @@ class Mobject(object): self.needs_new_bounding_box: bool = True self._is_animating: bool = False self.interaction_allowed: bool = False + self.saved_state = None + self.target = None self.init_data() self.init_uniforms() @@ -473,7 +475,7 @@ class Mobject(object): if hasattr(self, attr): value = getattr(self, attr) stash[attr] = value - null_value = [] if isinstance(value, Iterable) else None + null_value = [] if isinstance(value, list) else None setattr(self, attr, null_value) result = func(self, *args, **kwargs) self.__dict__.update(stash) @@ -530,14 +532,16 @@ class Mobject(object): def generate_target(self, use_deepcopy: bool = False): self.target = self.copy(deep=use_deepcopy) + self.target.saved_state = self.saved_state return self.target def save_state(self, use_deepcopy: bool = False): self.saved_state = self.copy(deep=use_deepcopy) + self.saved_state.target = self.target return self def restore(self): - if not hasattr(self, "saved_state") or self.save_state is None: + if not hasattr(self, "saved_state") or self.saved_state is None: raise Exception("Trying to restore without having saved") self.become(self.saved_state) return self From 04bca6cafbb1482b8f25cfb34ce83316d8a095c9 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 23:14:00 -0700 Subject: [PATCH 57/64] Refresh static mobjects on undo's and redo's --- manimlib/scene/scene.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index 3f295ab5..1af974b4 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -647,11 +647,13 @@ class Scene(object): if self.undo_stack: self.redo_stack.append(self.get_state()) self.restore_state(self.undo_stack.pop()) + self.refresh_static_mobjects() def redo(self): if self.redo_stack: self.undo_stack.append(self.get_state()) self.restore_state(self.redo_stack.pop()) + self.refresh_static_mobjects() def save_mobject_to_file(self, mobject: Mobject, file_path: str | None = None) -> None: if file_path is None: From 754316bf586be5a59839f8bac6fb9fcc47da0efb Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 23:14:19 -0700 Subject: [PATCH 58/64] Factor out event handling --- manimlib/scene/interactive_scene.py | 300 ++++++++++++++++------------ 1 file changed, 174 insertions(+), 126 deletions(-) diff --git a/manimlib/scene/interactive_scene.py b/manimlib/scene/interactive_scene.py index 86fd36b7..eb42665c 100644 --- a/manimlib/scene/interactive_scene.py +++ b/manimlib/scene/interactive_scene.py @@ -27,8 +27,9 @@ from manimlib.utils.space_ops import get_norm SELECT_KEY = 's' GRAB_KEY = 'g' -HORIZONTAL_GRAB_KEY = 'h' -VERTICAL_GRAB_KEY = 'v' +X_GRAB_KEY = 'h' +Y_GRAB_KEY = 'v' +GRAB_KEYS = [GRAB_KEY, X_GRAB_KEY, Y_GRAB_KEY] RESIZE_KEY = 't' COLOR_KEY = 'c' CURSOR_LOCATION_KEY = 'l' @@ -36,6 +37,7 @@ CURSOR_LOCATION_KEY = 'l' # Note, a lot of the functionality here is still buggy and very much a work in progress. + class InteractiveScene(Scene): """ To select mobjects on screen, hold ctrl and move the mouse to highlight a region, @@ -86,8 +88,10 @@ class InteractiveScene(Scene): self.camera.frame ] self.select_top_level_mobs = True + self.regenerate_selection_search_set() self.is_selecting = False + self.is_grabbing = False self.add(self.selection_highlight) def get_selection_rectangle(self): @@ -144,18 +148,22 @@ class InteractiveScene(Scene): def toggle_selection_mode(self): self.select_top_level_mobs = not self.select_top_level_mobs self.refresh_selection_scope() + self.regenerate_selection_search_set() def get_selection_search_set(self) -> list[Mobject]: - mobs = [ - m for m in self.mobjects - if m not in self.unselectables and m.is_interaction_allowed() - ] + return self.selection_search_set + + def regenerate_selection_search_set(self): + selectable = list(filter( + lambda m: m not in self.unselectables, + self.mobjects + )) if self.select_top_level_mobs: - return mobs + self.selection_search_set = selectable else: - return [ + self.selection_search_set = [ submob - for mob in mobs + for mob in selectable for submob in mob.family_members_with_points() ] @@ -208,30 +216,47 @@ class InteractiveScene(Scene): lambda m: m not in self.unselectables and m not in self.selection, mobjects )) - if mobs: - self.selection.add(*mobs) - self.selection_highlight.add(*map(self.get_highlight, mobs)) + if len(mobs) == 0: + return + self.selection.add(*mobs) + self.selection_highlight.add(*map(self.get_highlight, mobs)) + for mob in mobs: + mob.set_animating_status(True) + self.refresh_static_mobjects() def toggle_from_selection(self, *mobjects): for mob in mobjects: if mob in self.selection: self.selection.remove(mob) + mob.set_animating_status(False) else: self.add_to_selection(mob) self.refresh_selection_highlight() def clear_selection(self): + for mob in self.selection: + mob.set_animating_status(False) self.selection.set_submobjects([]) self.selection_highlight.set_submobjects([]) + self.refresh_static_mobjects() def add(self, *new_mobjects: Mobject): - for mob in new_mobjects: - mob.allow_interaction() super().add(*new_mobjects) + self.regenerate_selection_search_set() + + def remove(self, *mobjects: Mobject): + super().remove(*mobjects) + self.regenerate_selection_search_set() def disable_interaction(self, *mobjects: Mobject): for mob in mobjects: - mob.allow_interaction(False) + self.unselectables.append(mob) + self.regenerate_selection_search_set() + + def enable_interaction(self, *mobjects: Mobject): + for mob in mobjects: + if mob in self.unselectables: + self.unselectables.remove(mob) # Functions for keyboard actions @@ -247,11 +272,11 @@ class InteractiveScene(Scene): mobs = map(self.id_to_mobject, ids) mob_copies = [m.copy() for m in mobs if m is not None] self.clear_selection() - self.add_to_selection(*mob_copies) self.play(*( FadeIn(mc, run_time=0.5, scale=1.5) for mc in mob_copies )) + self.add_to_selection(*mob_copies) return except ValueError: pass @@ -275,6 +300,23 @@ class InteractiveScene(Scene): super().restore_state(mobject_states) self.refresh_selection_highlight() + def enable_selection(self): + self.is_selecting = True + self.add(self.selection_rectangle) + self.selection_rectangle.fixed_corner = self.mouse_point.get_center().copy() + + def gather_new_selection(self): + self.is_selecting = False + self.remove(self.selection_rectangle) + for mob in reversed(self.get_selection_search_set()): + if self.selection_rectangle.is_touching(mob): + self.add_to_selection(mob) + + def prepare_grab(self): + mp = self.mouse_point.get_center() + self.mouse_to_selection = mp - self.selection.get_center() + self.is_grabbing = True + def prepare_resizing(self, about_corner=False): center = self.selection.get_center() mp = self.mouse_point.get_center() @@ -286,169 +328,175 @@ class InteractiveScene(Scene): self.scale_ref_width = self.selection.get_width() self.scale_ref_height = self.selection.get_height() - # Event handlers + def toggle_color_palette(self): + if len(self.selection) == 0: + return + if self.color_palette not in self.mobjects: + self.save_state() + self.add(self.color_palette) + else: + self.remove(self.color_palette) + + def group_selection(self): + group = self.get_group(*self.selection) + self.add(group) + self.clear_selection() + self.add_to_selection(group) + + def ungroup_selection(self): + pieces = [] + for mob in list(self.selection): + self.remove(mob) + pieces.extend(list(mob)) + self.clear_selection() + self.add(*pieces) + self.add_to_selection(*pieces) + + def nudge_selection(self, vect: np.ndarray, large: bool = False): + nudge = self.selection_nudge_size + if large: + nudge *= 10 + self.selection.shift(nudge * vect) + + def save_selection_to_file(self): + if len(self.selection) == 1: + self.save_mobject_to_file(self.selection[0]) + else: + self.save_mobject_to_file(self.selection) def on_key_press(self, symbol: int, modifiers: int) -> None: super().on_key_press(symbol, modifiers) char = chr(symbol) - # Enable selection if char == SELECT_KEY and modifiers == 0: - self.is_selecting = True - self.add(self.selection_rectangle) - self.selection_rectangle.fixed_corner = self.mouse_point.get_center().copy() - # Prepare for move - elif char in [GRAB_KEY, HORIZONTAL_GRAB_KEY, VERTICAL_GRAB_KEY] and modifiers == 0: - mp = self.mouse_point.get_center() - self.mouse_to_selection = mp - self.selection.get_center() - # Prepare for resizing + self.enable_selection() + elif char in GRAB_KEYS and modifiers == 0: + self.prepare_grab() elif char == RESIZE_KEY and modifiers in [0, SHIFT_MODIFIER]: self.prepare_resizing(about_corner=(modifiers == SHIFT_MODIFIER)) elif symbol == SHIFT_SYMBOL: if self.window.is_key_pressed(ord("t")): self.prepare_resizing(about_corner=True) - # Show color palette elif char == COLOR_KEY and modifiers == 0: - if len(self.selection) == 0: - return - if self.color_palette not in self.mobjects: - self.save_state() - self.add(self.color_palette) - else: - self.remove(self.color_palette) - # Show coordiantes of cursor location + self.toggle_color_palette() elif char == CURSOR_LOCATION_KEY and modifiers == 0: self.add(self.cursor_location_label) - # Command + c -> Copy mobject ids to clipboard elif char == "c" and modifiers == COMMAND_MODIFIER: self.copy_selection() - # Command + v -> Paste elif char == "v" and modifiers == COMMAND_MODIFIER: self.paste_selection() - # Command + x -> Cut elif char == "x" and modifiers == COMMAND_MODIFIER: self.copy_selection() self.delete_selection() - # Delete elif symbol == DELETE_SYMBOL: self.delete_selection() - # Command + a -> Select all elif char == "a" and modifiers == COMMAND_MODIFIER: self.clear_selection() self.add_to_selection(*self.mobjects) - # Command + g -> Group selection elif char == "g" and modifiers == COMMAND_MODIFIER: - group = self.get_group(*self.selection) - self.add(group) - self.clear_selection() - self.add_to_selection(group) - # Command + shift + g -> Ungroup the selection + self.group_selection() elif char == "g" and modifiers == COMMAND_MODIFIER | SHIFT_MODIFIER: - pieces = [] - for mob in list(self.selection): - self.remove(mob) - pieces.extend(list(mob)) - self.clear_selection() - self.add(*pieces) - self.add_to_selection(*pieces) - # Command + t -> Toggle selection mode + self.ungroup_selection() elif char == "t" and modifiers == COMMAND_MODIFIER: self.toggle_selection_mode() - # Command + z -> Undo elif char == "z" and modifiers == COMMAND_MODIFIER: self.undo() - # Command + shift + z -> Redo elif char == "z" and modifiers == COMMAND_MODIFIER | SHIFT_MODIFIER: self.redo() - # Command + s -> Save selections to file elif char == "s" and modifiers == COMMAND_MODIFIER: - to_save = self.selection - if len(to_save) == 1: - to_save = to_save[0] - self.save_mobject_to_file(to_save) - # Keyboard movements + self.save_selection_to_file() elif symbol in ARROW_SYMBOLS: - nudge = self.selection_nudge_size - if (modifiers & SHIFT_MODIFIER): - nudge *= 10 - vect = [LEFT, UP, RIGHT, DOWN][ARROW_SYMBOLS.index(symbol)] - self.selection.shift(nudge * vect) + self.nudge_selection( + vect=[LEFT, UP, RIGHT, DOWN][ARROW_SYMBOLS.index(symbol)], + large=(modifiers & SHIFT_MODIFIER), + ) # Conditions for saving state - if char in [GRAB_KEY, HORIZONTAL_GRAB_KEY, VERTICAL_GRAB_KEY, RESIZE_KEY]: + if char in [GRAB_KEY, X_GRAB_KEY, Y_GRAB_KEY, RESIZE_KEY]: self.save_state() def on_key_release(self, symbol: int, modifiers: int) -> None: super().on_key_release(symbol, modifiers) if chr(symbol) == SELECT_KEY: - self.is_selecting = False - self.remove(self.selection_rectangle) - for mob in reversed(self.get_selection_search_set()): - if self.selection_rectangle.is_touching(mob): - self.add_to_selection(mob) + self.gather_new_selection() + if chr(symbol) in GRAB_KEYS: + self.is_grabbing = False elif chr(symbol) == CURSOR_LOCATION_KEY: self.remove(self.cursor_location_label) - elif symbol == SHIFT_SYMBOL: - if self.window.is_key_pressed(ord(RESIZE_KEY)): - self.prepare_resizing(about_corner=False) + elif symbol == SHIFT_SYMBOL and self.window.is_key_pressed(ord(RESIZE_KEY)): + self.prepare_resizing(about_corner=False) + + # Mouse actions + def handle_grabbing(self, point: np.ndarray): + diff = point - self.mouse_to_selection + if self.window.is_key_pressed(ord(GRAB_KEY)): + self.selection.move_to(diff) + elif self.window.is_key_pressed(ord(X_GRAB_KEY)): + self.selection.set_x(diff[0]) + elif self.window.is_key_pressed(ord(Y_GRAB_KEY)): + self.selection.set_y(diff[1]) + + def handle_resizing(self, point: np.ndarray): + vect = point - self.scale_about_point + if self.window.is_key_pressed(CTRL_SYMBOL): + for i in (0, 1): + scalar = vect[i] / self.scale_ref_vect[i] + self.selection.rescale_to_fit( + scalar * [self.scale_ref_width, self.scale_ref_height][i], + dim=i, + about_point=self.scale_about_point, + stretch=True, + ) + else: + scalar = get_norm(vect) / get_norm(self.scale_ref_vect) + self.selection.set_width( + scalar * self.scale_ref_width, + about_point=self.scale_about_point + ) + + def handle_sweeping_selection(self, point: np.ndarray): + mob = self.point_to_mobject( + point, search_set=self.get_selection_search_set(), + buff=SMALL_BUFF + ) + if mob is not None: + self.add_to_selection(mob) + + def choose_color(self, point: np.ndarray): + # Search through all mobject on the screen, not just the palette + to_search = [ + sm + for mobject in self.mobjects + for sm in mobject.family_members_with_points() + if mobject not in self.unselectables + ] + mob = self.point_to_mobject(point, to_search) + if mob is not None: + self.selection.set_color(mob.get_color()) + self.remove(self.color_palette) + + def toggle_clicked_mobject_from_selection(self, point: np.ndarray): + mob = self.point_to_mobject( + point, + search_set=self.get_selection_search_set(), + buff=SMALL_BUFF + ) + if mob is not None: + self.toggle_from_selection(mob) def on_mouse_motion(self, point: np.ndarray, d_point: np.ndarray) -> None: super().on_mouse_motion(point, d_point) - # Move selection - if self.window.is_key_pressed(ord(GRAB_KEY)): - self.selection.move_to(point - self.mouse_to_selection) - # Move selection restricted to horizontal - elif self.window.is_key_pressed(ord(HORIZONTAL_GRAB_KEY)): - self.selection.set_x((point - self.mouse_to_selection)[0]) - # Move selection restricted to vertical - elif self.window.is_key_pressed(ord(VERTICAL_GRAB_KEY)): - self.selection.set_y((point - self.mouse_to_selection)[1]) - # Scale selection + if self.is_grabbing: + self.handle_grabbing(point) elif self.window.is_key_pressed(ord(RESIZE_KEY)): - vect = point - self.scale_about_point - if self.window.is_key_pressed(CTRL_SYMBOL): - for i in (0, 1): - scalar = vect[i] / self.scale_ref_vect[i] - self.selection.rescale_to_fit( - scalar * [self.scale_ref_width, self.scale_ref_height][i], - dim=i, - about_point=self.scale_about_point, - stretch=True, - ) - else: - scalar = get_norm(vect) / get_norm(self.scale_ref_vect) - self.selection.set_width( - scalar * self.scale_ref_width, - about_point=self.scale_about_point - ) - # Add to selection + self.handle_resizing(point) elif self.window.is_key_pressed(ord(SELECT_KEY)) and self.window.is_key_pressed(SHIFT_SYMBOL): - mob = self.point_to_mobject( - point, search_set=self.get_selection_search_set(), - buff=SMALL_BUFF - ) - if mob is not None: - self.add_to_selection(mob) + self.handle_sweeping_selection(point) def on_mouse_release(self, point: np.ndarray, button: int, mods: int) -> None: super().on_mouse_release(point, button, mods) if self.color_palette in self.mobjects: - # Search through all mobject on the screne, not just the palette - to_search = list(it.chain(*( - mobject.family_members_with_points() - for mobject in self.mobjects - if mobject not in self.unselectables - ))) - mob = self.point_to_mobject(point, to_search) - if mob is not None: - self.selection.set_color(mob.get_color()) - self.remove(self.color_palette) + self.choose_color(point) elif self.window.is_key_pressed(SHIFT_SYMBOL): - mob = self.point_to_mobject( - point, - search_set=self.get_selection_search_set(), - buff=SMALL_BUFF - ) - if mob is not None: - self.toggle_from_selection(mob) + self.toggle_clicked_mobject_from_selection(point) else: self.clear_selection() From f70e91348c8241bcb96470e7881dd92d9d3386d3 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 22 Apr 2022 23:14:57 -0700 Subject: [PATCH 59/64] Remove Mobject.interaction_allowed, in favor of using _is_animating for multiple purposes --- manimlib/mobject/mobject.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 277f75dc..9a70c4b9 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -84,7 +84,6 @@ class Mobject(object): self.locked_data_keys: set[str] = set() self.needs_new_bounding_box: bool = True self._is_animating: bool = False - self.interaction_allowed: bool = False self.saved_state = None self.target = None @@ -495,7 +494,7 @@ class Mobject(object): if deep: try: # Often faster than deepcopy - return pickle.loads(self.serialize()) + return pickle.loads(pickle.dumps(self)) except AttributeError: return copy.deepcopy(self) @@ -696,21 +695,13 @@ class Mobject(object): # Check if mark as static or not for camera def is_changing(self) -> bool: - return self._is_animating or self.has_updaters or self.interaction_allowed + return self._is_animating or self.has_updaters def set_animating_status(self, is_animating: bool, recurse: bool = True) -> None: for mob in self.get_family(recurse): mob._is_animating = is_animating return self - def allow_interaction(self, value: bool = True, recurse: bool = True) -> None: - for mob in self.get_family(recurse): - mob.interaction_allowed = value - return self - - def is_interaction_allowed(self) -> bool: - return self.interaction_allowed - # Transforming operations def shift(self, vector: np.ndarray): From 0fd8491c515ad23ca308099abe0f39fc38e2dd0e Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Sat, 23 Apr 2022 09:20:44 -0700 Subject: [PATCH 60/64] Move Command + z and Command + shift + z behavior to Scene --- manimlib/scene/interactive_scene.py | 4 ---- manimlib/scene/scene.py | 6 +++++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/manimlib/scene/interactive_scene.py b/manimlib/scene/interactive_scene.py index a59adeea..68603add 100644 --- a/manimlib/scene/interactive_scene.py +++ b/manimlib/scene/interactive_scene.py @@ -398,10 +398,6 @@ class InteractiveScene(Scene): self.ungroup_selection() elif char == "t" and modifiers == COMMAND_MODIFIER: self.toggle_selection_mode() - elif char == "z" and modifiers == COMMAND_MODIFIER: - self.undo() - elif char == "z" and modifiers == COMMAND_MODIFIER | SHIFT_MODIFIER: - self.redo() elif char == "s" and modifiers == COMMAND_MODIFIER: self.save_selection_to_file() elif symbol in ARROW_SYMBOLS: diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index ca15eec3..a67c91e1 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -2,7 +2,6 @@ from __future__ import annotations from functools import wraps import inspect -import itertools as it import os import platform import random @@ -17,6 +16,7 @@ from manimlib.camera.camera import Camera from manimlib.constants import ARROW_SYMBOLS from manimlib.constants import DEFAULT_WAIT_TIME from manimlib.constants import COMMAND_MODIFIER +from manimlib.constants import SHIFT_MODIFIER from manimlib.event_handler import EVENT_DISPATCHER from manimlib.event_handler.event_type import EventType from manimlib.logger import log @@ -785,6 +785,10 @@ class Scene(object): if char == RESET_FRAME_KEY: self.camera.frame.to_default_state() + elif char == "z" and modifiers == COMMAND_MODIFIER: + self.undo() + elif char == "z" and modifiers == COMMAND_MODIFIER | SHIFT_MODIFIER: + self.redo() # command + q elif char == QUIT_KEY and modifiers == COMMAND_MODIFIER: self.quit_interaction = True From d733687834a64cebd43cd14d745176a46aec0ec6 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Sat, 23 Apr 2022 10:16:11 -0700 Subject: [PATCH 61/64] Have -e write over original source file, then correct --- manimlib/config.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/manimlib/config.py b/manimlib/config.py index 2236e0d6..1f863b8e 100644 --- a/manimlib/config.py +++ b/manimlib/config.py @@ -243,15 +243,15 @@ def insert_embed_line(file_name: str, scene_name: str, line_marker: str): # Insert and write new file if n_spaces is None: n_spaces = get_indent(lines[prev_line_num]) - lines.insert(prev_line_num + 1, " " * n_spaces + "self.embed()\n") - alt_file = file_name.replace(".py", "_inserted_embed.py") - with open(alt_file, 'w') as fp: - fp.writelines(lines) - + new_lines = list(lines) + new_lines.insert(prev_line_num + 1, " " * n_spaces + "self.embed()\n") + with open(file_name, 'w') as fp: + fp.writelines(new_lines) try: - yield alt_file + yield file_name finally: - os.remove(alt_file) + with open(file_name, 'w') as fp: + fp.writelines(lines) def get_custom_config(): From 587bc4d0bd662c8ecaf369c371d35e175ffeb08e Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Sat, 23 Apr 2022 10:16:23 -0700 Subject: [PATCH 62/64] Add necessary import --- manimlib/mobject/functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/manimlib/mobject/functions.py b/manimlib/mobject/functions.py index b6033fb8..9ec232a7 100644 --- a/manimlib/mobject/functions.py +++ b/manimlib/mobject/functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from isosurfaces import plot_isoline +import numpy as np from manimlib.constants import FRAME_X_RADIUS, FRAME_Y_RADIUS from manimlib.constants import YELLOW From 902c2c002d6ca03c8080b2bd02ca36f2b8a748b6 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Sat, 23 Apr 2022 10:16:35 -0700 Subject: [PATCH 63/64] Slight copy refactor --- manimlib/mobject/mobject.py | 39 ++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 164cc9d7..c074c9f7 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -500,28 +500,31 @@ class Mobject(object): self.become(pickle.loads(data)) return self + def deepcopy(self): + try: + # Often faster than deepcopy + return pickle.loads(pickle.dumps(self)) + except AttributeError: + return copy.deepcopy(self) + @stash_mobject_pointers def copy(self, deep: bool = False): if deep: - try: - # Often faster than deepcopy - return pickle.loads(pickle.dumps(self)) - except AttributeError: - return copy.deepcopy(self) + return self.deepcopy() result = copy.copy(self) # The line above is only a shallow copy, so the internal # data which are numpyu arrays or other mobjects still # need to be further copied. - result.data = dict(self.data) - for key in result.data: - result.data[key] = result.data[key].copy() - - result.uniforms = dict(self.uniforms) - for key in result.uniforms: - if isinstance(result.uniforms[key], np.ndarray): - result.uniforms[key] = result.uniforms[key].copy() + result.data = { + key: np.array(value) + for key, value in self.data.items() + } + result.uniforms = { + key: np.array(value) + for key, value in self.uniforms.items() + } result.submobjects = [] result.add(*(sm.copy() for sm in self.submobjects)) @@ -529,17 +532,17 @@ class Mobject(object): family = self.get_family() for attr, value in list(self.__dict__.items()): - if isinstance(value, Mobject) and value in family and value is not self: - setattr(result, attr, result.family[self.family.index(value)]) + if isinstance(value, Mobject) and value is not self: + if value in family: + setattr(result, attr, result.family[self.family.index(value)]) + else: + setattr(result, attr, value.copy()) if isinstance(value, np.ndarray): setattr(result, attr, value.copy()) if isinstance(value, ShaderWrapper): setattr(result, attr, value.copy()) return result - def deepcopy(self): - return self.copy(deep=True) - def generate_target(self, use_deepcopy: bool = False): self.target = self.copy(deep=use_deepcopy) self.target.saved_state = self.saved_state From d9475a686002dbe27bfbddcf5b020d4e7bf85f97 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Sat, 23 Apr 2022 10:16:48 -0700 Subject: [PATCH 64/64] Remove unnecessary imports --- manimlib/mobject/svg/svg_mobject.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index 88b99d12..5a918e66 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -1,7 +1,6 @@ from __future__ import annotations import hashlib -import itertools as it import os from xml.etree import ElementTree as ET @@ -17,7 +16,6 @@ from manimlib.mobject.geometry import Polyline from manimlib.mobject.geometry import Rectangle from manimlib.mobject.geometry import RoundedRectangle from manimlib.mobject.types.vectorized_mobject import VMobject -from manimlib.utils.config_ops import digest_config from manimlib.utils.directories import get_mobject_data_dir from manimlib.utils.images import get_full_vector_image_path from manimlib.utils.iterables import hash_obj