diff --git a/manimlib/mobject/svg/brace.py b/manimlib/mobject/svg/brace.py index 31217a28..f9d96cec 100644 --- a/manimlib/mobject/svg/brace.py +++ b/manimlib/mobject/svg/brace.py @@ -1,11 +1,15 @@ -import numpy as np +from __future__ import annotations + import math import copy +from typing import Iterable + +import numpy as np -from manimlib.animation.composition import AnimationGroup from manimlib.constants import * from manimlib.animation.fading import FadeIn from manimlib.animation.growing import GrowFromCenter +from manimlib.animation.composition import AnimationGroup from manimlib.mobject.svg.tex_mobject import Tex from manimlib.mobject.svg.tex_mobject import SingleStringTex from manimlib.mobject.svg.tex_mobject import TexText @@ -14,6 +18,10 @@ from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config from manimlib.utils.space_ops import get_norm +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from manimlib.mobject.mobject import Mobject + from manimlib.animation.animation import Animation class Brace(SingleStringTex): CONFIG = { @@ -21,7 +29,12 @@ class Brace(SingleStringTex): "tex_string": r"\underbrace{\qquad}" } - def __init__(self, mobject, direction=DOWN, **kwargs): + def __init__( + self, + mobject: Mobject, + direction: np.ndarray = DOWN, + **kwargs + ): digest_config(self, kwargs, locals()) angle = -math.atan2(*direction[:2]) + PI mobject.rotate(-angle, about_point=ORIGIN) @@ -36,7 +49,7 @@ class Brace(SingleStringTex): for mob in mobject, self: mob.rotate(angle, about_point=ORIGIN) - def set_initial_width(self, width): + def set_initial_width(self, width: float): width_diff = width - self.get_width() if width_diff > 0: for tip, rect, vect in [(self[0], self[1], RIGHT), (self[5], self[4], LEFT)]: @@ -49,7 +62,12 @@ class Brace(SingleStringTex): self.set_width(width, stretch=True) return self - def put_at_tip(self, mob, use_next_to=True, **kwargs): + def put_at_tip( + self, + mob: Mobject, + use_next_to: bool = True, + **kwargs + ): if use_next_to: mob.next_to( self.get_tip(), @@ -63,24 +81,24 @@ class Brace(SingleStringTex): mob.shift(self.get_direction() * shift_distance) return self - def get_text(self, text, **kwargs): + def get_text(self, text: str, **kwargs) -> Text: buff = kwargs.pop("buff", SMALL_BUFF) text_mob = Text(text, **kwargs) self.put_at_tip(text_mob, buff=buff) return text_mob - def get_tex(self, *tex, **kwargs): + def get_tex(self, *tex: str, **kwargs) -> Tex: tex_mob = Tex(*tex) self.put_at_tip(tex_mob, **kwargs) return tex_mob - def get_tip(self): + def get_tip(self) -> np.ndarray: # Very specific to the LaTeX representation # of a brace, but it's the only way I can think # of to get the tip regardless of orientation. return self.get_all_points()[self.tip_point_index] - def get_direction(self): + def get_direction(self) -> np.ndarray: vect = self.get_tip() - self.get_center() return vect / get_norm(vect) @@ -92,14 +110,20 @@ class BraceLabel(VMobject): "label_buff": DEFAULT_MOBJECT_TO_MOBJECT_BUFFER } - def __init__(self, obj, text, brace_direction=DOWN, **kwargs): + def __init__( + self, + obj: list[VMobject] | Mobject, + text: Iterable[str] | str, + brace_direction: np.ndarray = DOWN, + **kwargs + ) -> None: VMobject.__init__(self, **kwargs) self.brace_direction = brace_direction if isinstance(obj, list): obj = VMobject(*obj) self.brace = Brace(obj, brace_direction, **kwargs) - if isinstance(text, tuple) or isinstance(text, list): + if isinstance(text, Iterable): self.label = self.label_constructor(*text, **kwargs) else: self.label = self.label_constructor(str(text)) @@ -109,10 +133,14 @@ class BraceLabel(VMobject): self.brace.put_at_tip(self.label, buff=self.label_buff) self.set_submobjects([self.brace, self.label]) - def creation_anim(self, label_anim=FadeIn, brace_anim=GrowFromCenter): + def creation_anim( + self, + label_anim: Animation = FadeIn, + brace_anim: Animation=GrowFromCenter + ) -> AnimationGroup: return AnimationGroup(brace_anim(self.brace), label_anim(self.label)) - def shift_brace(self, obj, **kwargs): + def shift_brace(self, obj: list[VMobject] | Mobject, **kwargs): if isinstance(obj, list): obj = VMobject(*obj) self.brace = Brace(obj, self.brace_direction, **kwargs) @@ -120,7 +148,7 @@ class BraceLabel(VMobject): self.submobjects[0] = self.brace return self - def change_label(self, *text, **kwargs): + def change_label(self, *text: str, **kwargs): self.label = self.label_constructor(*text, **kwargs) if self.label_scale != 1: self.label.scale(self.label_scale) @@ -129,7 +157,7 @@ class BraceLabel(VMobject): self.submobjects[1] = self.label return self - def change_brace_label(self, obj, *text): + def change_brace_label(self, obj: list[VMobject] | Mobject, *text: str): self.shift_brace(obj) self.change_label(*text) return self diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 84c0cbf5..c7b1438c 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,6 +1,10 @@ -import itertools as it +from __future__ import annotations + import re +import colour +import itertools as it from types import MethodType +from typing import Iterable, Union, Sequence from manimlib.constants import BLACK from manimlib.mobject.svg.svg_mobject import SVGMobject @@ -14,14 +18,16 @@ from manimlib.utils.tex_file_writing import get_tex_config from manimlib.utils.tex_file_writing import display_during_execution from manimlib.logger import log +ManimColor = Union[str, colour.Color, Sequence[float]] + SCALE_FACTOR_PER_FONT_POINT = 0.001 -TEX_HASH_TO_MOB_MAP = {} +TEX_HASH_TO_MOB_MAP: dict[int, VGroup] = {} -def _get_neighbouring_pairs(iterable): +def _get_neighbouring_pairs(iterable: Iterable) -> list: return list(adjacent_pairs(iterable))[:-1] @@ -38,17 +44,19 @@ class _TexSVG(SVGMobject): class _TexParser(object): - def __init__(self, tex_string, additional_substrings): + def __init__(self, tex_string: str, additional_substrings: str): self.tex_string = tex_string self.whitespace_indices = self.get_whitespace_indices() self.backslash_indices = self.get_backslash_indices() self.script_indices = self.get_script_indices() self.brace_indices_dict = self.get_brace_indices_dict() - self.tex_span_list = [] - self.script_span_to_char_dict = {} - self.script_span_to_tex_span_dict = {} - self.neighbouring_script_span_pairs = [] - self.specified_substrings = [] + self.tex_span_list: list[tuple[int, int]] = [] + self.script_span_to_char_dict: dict[tuple[int, int], str] = {} + self.script_span_to_tex_span_dict: dict[ + tuple[int, int], tuple[int, int] + ] = {} + self.neighbouring_script_span_pairs: list[tuple[int, int]] = [] + self.specified_substrings: list[str] = [] self.add_tex_span((0, len(tex_string))) self.break_up_by_scripts() self.break_up_by_double_braces() @@ -59,17 +67,17 @@ class _TexParser(object): ) self.containing_labels_dict = self.get_containing_labels_dict() - def add_tex_span(self, tex_span): + def add_tex_span(self, tex_span: tuple[int, int]) -> None: if tex_span not in self.tex_span_list: self.tex_span_list.append(tex_span) - def get_whitespace_indices(self): + def get_whitespace_indices(self) -> list[int]: return [ match_obj.start() for match_obj in re.finditer(r"\s", self.tex_string) ] - def get_backslash_indices(self): + def get_backslash_indices(self) -> list[int]: # Newlines (`\\`) don't count. return [ match_obj.end() - 1 @@ -77,19 +85,19 @@ class _TexParser(object): if len(match_obj.group()) % 2 == 1 ] - def filter_out_escaped_characters(self, indices): + def filter_out_escaped_characters(self, indices) -> list[int]: return list(filter( lambda index: index - 1 not in self.backslash_indices, indices )) - def get_script_indices(self): + def get_script_indices(self) -> list[int]: return self.filter_out_escaped_characters([ match_obj.start() for match_obj in re.finditer(r"[_^]", self.tex_string) ]) - def get_brace_indices_dict(self): + def get_brace_indices_dict(self) -> dict[int, int]: tex_string = self.tex_string indices = self.filter_out_escaped_characters([ match_obj.start() @@ -105,7 +113,7 @@ class _TexParser(object): result[left_brace_index] = index return result - def break_up_by_scripts(self): + def break_up_by_scripts(self) -> None: # Match subscripts & superscripts. tex_string = self.tex_string whitespace_indices = self.whitespace_indices @@ -154,7 +162,7 @@ class _TexParser(object): if span_0[1] == span_1[0]: self.neighbouring_script_span_pairs.append((span_0, span_1)) - def break_up_by_double_braces(self): + def break_up_by_double_braces(self) -> None: # Match paired double braces (`{{...}}`). tex_string = self.tex_string reversed_indices_dict = dict( @@ -178,7 +186,10 @@ class _TexParser(object): self.specified_substrings.append(tex_string[slice(*tex_span)]) skip = True - def break_up_by_additional_substrings(self, additional_substrings): + def break_up_by_additional_substrings( + self, + additional_substrings: Iterable[str] + ) -> None: stripped_substrings = sorted(remove_list_redundancies([ string.strip() for string in additional_substrings @@ -208,7 +219,7 @@ class _TexParser(object): continue self.add_tex_span((span_begin, span_end)) - def get_containing_labels_dict(self): + def get_containing_labels_dict(self) -> dict[tuple[int, int], list[int]]: tex_span_list = self.tex_span_list result = { tex_span: [] @@ -233,7 +244,7 @@ class _TexParser(object): raise ValueError return result - def get_labelled_tex_string(self): + def get_labelled_tex_string(self) -> str: indices, _, flags, labels = zip(*sorted([ (*tex_span[::(1, -1)[flag]], flag, label) for label, tex_span in enumerate(self.tex_span_list) @@ -251,7 +262,7 @@ class _TexParser(object): return "".join(it.chain(*zip(command_pieces, string_pieces))) @staticmethod - def get_color_command(label): + def get_color_command(label: int) -> str: rg, b = divmod(label, 256) r, g = divmod(rg, 256) return "".join([ @@ -261,7 +272,7 @@ class _TexParser(object): "}" ]) - def get_sorted_submob_indices(self, submob_labels): + def get_sorted_submob_indices(self, submob_labels: Iterable[int]) -> list[int]: def script_span_to_submob_range(script_span): tex_span = self.script_span_to_tex_span_dict[script_span] submob_indices = [ @@ -295,7 +306,7 @@ class _TexParser(object): ] return result - def get_submob_tex_strings(self, submob_labels): + def get_submob_tex_strings(self, submob_labels: Iterable[int]) -> list[str]: ordered_tex_spans = [ self.tex_span_list[label] for label in submob_labels ] @@ -356,7 +367,10 @@ class _TexParser(object): ])) return result - def find_span_components_of_custom_span(self, custom_span): + def find_span_components_of_custom_span( + self, + custom_span: tuple[int, int] + ) -> list[tuple[int, int]] | None: skipped_indices = sorted(it.chain( self.whitespace_indices, self.script_indices @@ -384,16 +398,19 @@ class _TexParser(object): span_begin = next_begin return result - def get_containing_labels_by_tex_spans(self, tex_spans): + def get_containing_labels_by_tex_spans( + self, + tex_spans: Iterable[tuple[int, int]] + ) -> list[int]: return remove_list_redundancies(list(it.chain(*[ self.containing_labels_dict[tex_span] for tex_span in tex_spans ]))) - def get_specified_substrings(self): + def get_specified_substrings(self) -> list[str]: return self.specified_substrings - def get_isolated_substrings(self): + def get_isolated_substrings(self) -> list[str]: return remove_list_redundancies([ self.tex_string[slice(*tex_span)] for tex_span in self.tex_span_list @@ -412,7 +429,7 @@ class MTex(VMobject): "use_plain_tex": False, } - def __init__(self, tex_string, **kwargs): + def __init__(self, tex_string: str, **kwargs): super().__init__(**kwargs) tex_string = tex_string.strip() # Prevent from passing an empty string. @@ -431,12 +448,12 @@ class MTex(VMobject): self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) @staticmethod - def color_to_label(color): + def color_to_label(color: ManimColor) -> list: r, g, b = color_to_int_rgb(color) rg = r * 256 + g return rg * 256 + b - def generate_mobject(self): + def generate_mobject(self) -> VGroup: labelled_tex_string = self.__parser.get_labelled_tex_string() labelled_tex_content = self.get_tex_file_content(labelled_tex_string) hash_val = hash((labelled_tex_content, self.use_plain_tex)) @@ -471,7 +488,7 @@ class MTex(VMobject): TEX_HASH_TO_MOB_MAP[hash_val] = mob return mob - def get_tex_file_content(self, tex_string): + def get_tex_file_content(self, tex_string: str) -> str: if self.tex_environment: tex_string = "\n".join([ f"\\begin{{{self.tex_environment}}}", @@ -483,7 +500,7 @@ class MTex(VMobject): return tex_string @staticmethod - def tex_content_to_glyphs(tex_content): + def tex_content_to_glyphs(tex_content: str) -> _TexSVG: tex_config = get_tex_config() full_tex = tex_config["tex_body"].replace( tex_config["text_to_replace"], @@ -492,7 +509,11 @@ class MTex(VMobject): filename = tex_to_svg_file(full_tex) return _TexSVG(filename) - def build_mobject(self, svg_glyphs, glyph_labels): + def build_mobject( + self, + svg_glyphs: _TexSVG | None, + glyph_labels: Iterable[int] + ) -> VGroup: if not svg_glyphs: return VGroup() @@ -530,14 +551,17 @@ class MTex(VMobject): submob.get_tex = MethodType(lambda inst: inst.tex_string, submob) return VGroup(*rearranged_submobjects) - def get_part_by_tex_spans(self, tex_spans): + def get_part_by_tex_spans( + self, + tex_spans: Iterable[tuple[int, int]] + ) -> VGroup: labels = self.__parser.get_containing_labels_by_tex_spans(tex_spans) return VGroup(*filter( lambda submob: submob.submob_label in labels, self.submobjects )) - def get_part_by_custom_span(self, custom_span): + def get_part_by_custom_span(self, custom_span: tuple[int, int]) -> VGroup: tex_spans = self.__parser.find_span_components_of_custom_span( custom_span ) @@ -546,7 +570,7 @@ class MTex(VMobject): raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"") return self.get_part_by_tex_spans(tex_spans) - def get_parts_by_tex(self, tex): + def get_parts_by_tex(self, tex: str) -> VGroup: return VGroup(*[ self.get_part_by_custom_span(match_obj.span()) for match_obj in re.finditer( @@ -554,20 +578,23 @@ class MTex(VMobject): ) ]) - def get_part_by_tex(self, tex, index=0): + def get_part_by_tex(self, tex: str, index: int = 0) -> VGroup: all_parts = self.get_parts_by_tex(tex) return all_parts[index] - def set_color_by_tex(self, tex, color): + def set_color_by_tex(self, tex: str, color: ManimColor): self.get_parts_by_tex(tex).set_color(color) return self - def set_color_by_tex_to_color_map(self, tex_to_color_map): + def set_color_by_tex_to_color_map( + self, + tex_to_color_map: dict[str, ManimColor] + ): for tex, color in tex_to_color_map.items(): self.set_color_by_tex(tex, color) return self - def indices_of_part(self, part): + def indices_of_part(self, part: Iterable[VGroup]) -> list[int]: indices = [ index for index, submob in enumerate(self.submobjects) if submob in part @@ -576,23 +603,23 @@ class MTex(VMobject): raise ValueError("Failed to find part in tex") return indices - def indices_of_part_by_tex(self, tex, index=0): + def indices_of_part_by_tex(self, tex: str, index: int = 0) -> list[int]: part = self.get_part_by_tex(tex, index=index) return self.indices_of_part(part) - def get_tex(self): + def get_tex(self) -> str: return self.tex_string - def get_submob_tex(self): + def get_submob_tex(self) -> list[str]: return [ submob.get_tex() for submob in self.submobjects ] - def get_specified_substrings(self): + def get_specified_substrings(self) -> list[str]: return self.__parser.get_specified_substrings() - def get_isolated_substrings(self): + def get_isolated_substrings(self) -> list[str]: return self.__parser.get_isolated_substrings() diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index fd79dffa..ac3c66b5 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import os import re import hashlib import itertools as it +from typing import Callable import svgelements as se import numpy as np @@ -20,7 +23,7 @@ from manimlib.utils.images import get_full_vector_image_path from manimlib.logger import log -def _convert_point_to_3d(x, y): +def _convert_point_to_3d(x: float, y: float) -> np.ndarray: return np.array([x, y, 0.0]) @@ -41,7 +44,7 @@ class SVGMobject(VMobject): "path_string_config": {} } - def __init__(self, file_name=None, **kwargs): + def __init__(self, file_name: str | None = None, **kwargs): digest_config(self, kwargs) self.file_name = file_name or self.file_name if file_name is None: @@ -51,7 +54,7 @@ class SVGMobject(VMobject): super().__init__(**kwargs) self.move_into_position() - def move_into_position(self): + def move_into_position(self) -> None: if self.should_center: self.center() if self.height is not None: @@ -59,7 +62,7 @@ class SVGMobject(VMobject): if self.width is not None: self.set_width(self.width) - def init_colors(self): + def init_colors(self) -> None: # Remove fill_color, fill_opacity, # stroke_width, stroke_color, stroke_opacity # as each submobject may have those values specified in svg file @@ -68,7 +71,7 @@ class SVGMobject(VMobject): self.set_flat_stroke(self.flat_stroke) return self - def init_points(self): + def init_points(self) -> None: with open(self.file_path, "r") as svg_file: svg_string = svg_file.read() @@ -96,7 +99,7 @@ class SVGMobject(VMobject): self.flip(RIGHT) # Flip y self.scale(0.75) - def modify_svg_file(self, svg_string): + def modify_svg_file(self, svg_string: str) -> str: # svgelements cannot handle em, ex units # Convert them using 1em = 16px, 1ex = 0.5em = 8px def convert_unit(match_obj): @@ -127,7 +130,7 @@ class SVGMobject(VMobject): return result - def generate_context_values_from_config(self): + def generate_context_values_from_config(self) -> dict[str]: result = {} if self.stroke_width is not None: result["stroke-width"] = self.stroke_width @@ -145,7 +148,7 @@ class SVGMobject(VMobject): result["stroke-opacity"] = self.stroke_opacity return result - def get_mobjects_from(self, shape): + def get_mobjects_from(self, shape) -> list[VMobject]: if isinstance(shape, se.Group): return list(it.chain(*( self.get_mobjects_from(child) @@ -161,7 +164,7 @@ class SVGMobject(VMobject): return [mob] @staticmethod - def handle_transform(mob, matrix): + def handle_transform(mob: VMobject, matrix: se.Matrix) -> VMobject: mat = np.array([ [matrix.a, matrix.c], [matrix.b, matrix.d] @@ -171,8 +174,10 @@ class SVGMobject(VMobject): mob.shift(vec) return mob - def get_mobject_from(self, shape): - shape_class_to_func_map = { + def get_mobject_from(self, shape: se.Shape | se.Text) -> VMobject | None: + shape_class_to_func_map: dict[ + type, Callable[[se.Shape | se.Text], VMobject] + ] = { se.Path: self.path_to_mobject, se.SimpleLine: self.line_to_mobject, se.Rect: self.rect_to_mobject, @@ -194,7 +199,10 @@ class SVGMobject(VMobject): return None @staticmethod - def apply_style_to_mobject(mob, shape): + def apply_style_to_mobject( + mob: VMobject, + shape: se.Shape | se.Text + ) -> VMobject: mob.set_style( stroke_width=shape.stroke_width, stroke_color=shape.stroke.hex, @@ -204,16 +212,16 @@ class SVGMobject(VMobject): ) return mob - def path_to_mobject(self, path): + def path_to_mobject(self, path: se.Path) -> VMobjectFromSVGPath: return VMobjectFromSVGPath(path, **self.path_string_config) - def line_to_mobject(self, line): + def line_to_mobject(self, line: se.Line) -> Line: return Line( start=_convert_point_to_3d(line.x1, line.y1), end=_convert_point_to_3d(line.x2, line.y2) ) - def rect_to_mobject(self, rect): + def rect_to_mobject(self, rect: se.Rect) -> Rectangle | RoundedRectangle: if rect.rx == 0 or rect.ry == 0: mob = Rectangle( width=rect.width, @@ -232,7 +240,7 @@ class SVGMobject(VMobject): )) return mob - def circle_to_mobject(self, circle): + def circle_to_mobject(self, circle: se.Circle) -> Circle: # svgelements supports `rx` & `ry` but `r` mob = Circle(radius=circle.rx) mob.shift(_convert_point_to_3d( @@ -240,7 +248,7 @@ class SVGMobject(VMobject): )) return mob - def ellipse_to_mobject(self, ellipse): + def ellipse_to_mobject(self, ellipse: se.Ellipse) -> Circle: mob = Circle(radius=ellipse.rx) mob.stretch_to_fit_height(2 * ellipse.ry) mob.shift(_convert_point_to_3d( @@ -248,21 +256,21 @@ class SVGMobject(VMobject): )) return mob - def polygon_to_mobject(self, polygon): + def polygon_to_mobject(self, polygon: se.Polygon) -> Polygon: points = [ _convert_point_to_3d(*point) for point in polygon ] return Polygon(*points) - def polyline_to_mobject(self, polyline): + def polyline_to_mobject(self, polyline: se.Polyline) -> Polyline: points = [ _convert_point_to_3d(*point) for point in polyline ] return Polyline(*points) - def text_to_mobject(self, text): + def text_to_mobject(self, text: se.Text): pass @@ -273,13 +281,13 @@ class VMobjectFromSVGPath(VMobject): "should_remove_null_curves": False, } - def __init__(self, path_obj, **kwargs): + def __init__(self, path_obj: se.Path, **kwargs): # Get rid of arcs path_obj.approximate_arcs_with_quads() self.path_obj = path_obj super().__init__(**kwargs) - def init_points(self): + def init_points(self) -> None: # After a given svg_path has been converted into points, the result # will be saved to a file so that future calls for the same path # don't need to retrace the same computation. @@ -305,7 +313,7 @@ class VMobjectFromSVGPath(VMobject): np.save(points_filepath, self.get_points()) np.save(tris_filepath, self.get_triangulation()) - def handle_commands(self): + def handle_commands(self) -> None: segment_class_to_func_map = { se.Move: (self.start_new_path, ("end",)), se.Close: (self.close_path, ()), diff --git a/manimlib/mobject/svg/tex_mobject.py b/manimlib/mobject/svg/tex_mobject.py index c81a781b..0e765c48 100644 --- a/manimlib/mobject/svg/tex_mobject.py +++ b/manimlib/mobject/svg/tex_mobject.py @@ -1,5 +1,9 @@ +from __future__ import annotations + +from typing import Iterable, Sequence, Union from functools import reduce import operator as op +import colour import re from manimlib.constants import * @@ -13,10 +17,13 @@ from manimlib.utils.tex_file_writing import get_tex_config from manimlib.utils.tex_file_writing import display_during_execution +ManimColor = Union[str, colour.Color, Sequence[float]] + SCALE_FACTOR_PER_FONT_POINT = 0.001 - -tex_string_with_color_to_mob_map = {} +tex_string_with_color_to_mob_map: dict[ + tuple[ManimColor, str], SVGMobject +] = {} class SingleStringTex(VMobject): @@ -31,7 +38,7 @@ class SingleStringTex(VMobject): "math_mode": True, } - def __init__(self, tex_string, **kwargs): + def __init__(self, tex_string: str, **kwargs): super().__init__(**kwargs) assert(isinstance(tex_string, str)) self.tex_string = tex_string @@ -66,7 +73,7 @@ class SingleStringTex(VMobject): self.set_flat_stroke(self.flat_stroke) return self - def get_tex_file_body(self, tex_string): + def get_tex_file_body(self, tex_string: str) -> str: new_tex = self.get_modified_expression(tex_string) if self.math_mode: new_tex = "\\begin{align*}\n" + new_tex + "\n\\end{align*}" @@ -79,10 +86,10 @@ class SingleStringTex(VMobject): new_tex ) - def get_modified_expression(self, tex_string): + def get_modified_expression(self, tex_string: str) -> str: return self.modify_special_strings(tex_string.strip()) - def modify_special_strings(self, tex): + def modify_special_strings(self, tex: str) -> str: tex = tex.strip() should_add_filler = reduce(op.or_, [ # Fraction line needs something to be over @@ -134,7 +141,7 @@ class SingleStringTex(VMobject): tex = "" return tex - def balance_braces(self, tex): + def balance_braces(self, tex: str) -> str: """ Makes Tex resiliant to unmatched braces """ @@ -154,7 +161,7 @@ class SingleStringTex(VMobject): tex += num_unclosed_brackets * "}" return tex - def get_tex(self): + def get_tex(self) -> str: return self.tex_string def organize_submobjects_left_to_right(self): @@ -169,7 +176,7 @@ class Tex(SingleStringTex): "tex_to_color_map": {}, } - def __init__(self, *tex_strings, **kwargs): + def __init__(self, *tex_strings: str, **kwargs): digest_config(self, kwargs) self.tex_strings = self.break_up_tex_strings(tex_strings) full_string = self.arg_separator.join(self.tex_strings) @@ -180,7 +187,7 @@ class Tex(SingleStringTex): if self.organize_left_to_right: self.organize_submobjects_left_to_right() - def break_up_tex_strings(self, tex_strings): + def break_up_tex_strings(self, tex_strings: Iterable[str]) -> Iterable[str]: # Separate out any strings specified in the isolate # or tex_to_color_map lists. substrings_to_isolate = [*self.isolate, *self.tex_to_color_map.keys()] @@ -228,7 +235,12 @@ class Tex(SingleStringTex): self.set_submobjects(new_submobjects) return self - def get_parts_by_tex(self, tex, substring=True, case_sensitive=True): + def get_parts_by_tex( + self, + tex: str, + substring: bool = True, + case_sensitive: bool = True + ) -> VGroup: def test(tex1, tex2): if not case_sensitive: tex1 = tex1.lower() @@ -243,27 +255,36 @@ class Tex(SingleStringTex): self.submobjects )) - def get_part_by_tex(self, tex, **kwargs): + def get_part_by_tex(self, tex: str, **kwargs) -> SingleStringTex | None: all_parts = self.get_parts_by_tex(tex, **kwargs) return all_parts[0] if all_parts else None - def set_color_by_tex(self, tex, color, **kwargs): + def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs): self.get_parts_by_tex(tex, **kwargs).set_color(color) return self - def set_color_by_tex_to_color_map(self, tex_to_color_map, **kwargs): + def set_color_by_tex_to_color_map( + self, + tex_to_color_map: dict[str, ManimColor], + **kwargs + ): for tex, color in list(tex_to_color_map.items()): self.set_color_by_tex(tex, color, **kwargs) return self - def index_of_part(self, part, start=0): + def index_of_part(self, part: SingleStringTex, start: int = 0) -> int: return self.submobjects.index(part, start) - def index_of_part_by_tex(self, tex, start=0, **kwargs): + def index_of_part_by_tex(self, tex: str, start: int = 0, **kwargs) -> int: part = self.get_part_by_tex(tex, **kwargs) return self.index_of_part(part, start) - def slice_by_tex(self, start_tex=None, stop_tex=None, **kwargs): + def slice_by_tex( + self, + start_tex: str | None = None, + stop_tex: str | None = None, + **kwargs + ) -> VGroup: if start_tex is None: start_index = 0 else: @@ -275,10 +296,10 @@ class Tex(SingleStringTex): stop_index = self.index_of_part_by_tex(stop_tex, start=start_index, **kwargs) return self[start_index:stop_index] - def sort_alphabetically(self): + def sort_alphabetically(self) -> None: self.submobjects.sort(key=lambda m: m.get_tex()) - def set_bstroke(self, color=BLACK, width=4): + def set_bstroke(self, color: ManimColor = BLACK, width: float = 4): self.set_stroke(color, width, background=True) return self @@ -297,7 +318,7 @@ class BulletedList(TexText): "alignment": "", } - def __init__(self, *items, **kwargs): + def __init__(self, *items: str, **kwargs): line_separated_items = [s + "\\\\" for s in items] TexText.__init__(self, *line_separated_items, **kwargs) for part in self: @@ -310,7 +331,7 @@ class BulletedList(TexText): buff=self.buff ) - def fade_all_but(self, index_or_string, opacity=0.5): + def fade_all_but(self, index_or_string: int | str, opacity: float = 0.5) -> None: arg = index_or_string if isinstance(arg, str): part = self.get_part_by_tex(arg) @@ -348,7 +369,7 @@ class Title(TexText): "underline_buff": MED_SMALL_BUFF, } - def __init__(self, *text_parts, **kwargs): + def __init__(self, *text_parts: str, **kwargs): TexText.__init__(self, *text_parts, **kwargs) self.scale(self.scale_factor) self.to_edge(UP) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index e412f08d..668cddbf 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -1,17 +1,19 @@ -import hashlib +from __future__ import annotations + import os import re import io -import typing -import xml.etree.ElementTree as ET +import hashlib import functools +from pathlib import Path +import xml.etree.ElementTree as ET +from contextlib import contextmanager +from typing import Iterable, Sequence, Union + import pygments import pygments.lexers import pygments.styles -from contextlib import contextmanager -from pathlib import Path - import manimpango from manimlib.logger import log from manimlib.constants import * @@ -23,6 +25,12 @@ from manimlib.utils.customization import get_customization from manimlib.utils.directories import get_downloads_dir, get_text_dir from manimpango import PangoUtils, TextSetting, MarkupUtils +from typing import TYPE_CHECKING +if TYPE_CHECKING: + import colour + from manimlib.mobject.types.vectorized_mobject import VMobject + ManimColor = Union[str, colour.Color, Sequence[float]] + TEXT_MOB_SCALE_FACTOR = 0.0076 DEFAULT_LINE_SPACING_SCALE = 0.6 @@ -50,7 +58,7 @@ class Text(SVGMobject): "disable_ligatures": True, } - def __init__(self, text, **kwargs): + def __init__(self, text: str, **kwargs): self.full2short(kwargs) digest_config(self, kwargs) if self.size: @@ -60,9 +68,9 @@ class Text(SVGMobject): ) self.font_size = self.size if self.lsh == -1: - self.lsh = self.font_size + self.font_size * DEFAULT_LINE_SPACING_SCALE + self.lsh: float = self.font_size + self.font_size * DEFAULT_LINE_SPACING_SCALE else: - self.lsh = self.font_size + self.font_size * self.lsh + self.lsh: float = self.font_size + self.font_size * self.lsh text_without_tabs = text if text.find('\t') != -1: text_without_tabs = text.replace('\t', ' ' * self.tab_width) @@ -87,14 +95,14 @@ class Text(SVGMobject): if self.height is None: self.scale(TEXT_MOB_SCALE_FACTOR) - def remove_empty_path(self, file_name): + def remove_empty_path(self, file_name: str) -> None: with open(file_name, 'r') as fpr: content = fpr.read() content = re.sub(r'', '', content) with open(file_name, 'w') as fpw: fpw.write(content) - def apply_space_chars(self): + def apply_space_chars(self) -> None: submobs = self.submobjects.copy() for char_index in range(len(self.text)): if self.text[char_index] in [" ", "\t", "\n"]: @@ -103,7 +111,7 @@ class Text(SVGMobject): submobs.insert(char_index, space) self.set_submobjects(submobs) - def find_indexes(self, word): + def find_indexes(self, word: str) -> list[tuple[int, int]]: m = re.match(r'\[([0-9\-]{0,}):([0-9\-]{0,})\]', word) if m: start = int(m.group(1)) if m.group(1) != '' else 0 @@ -119,20 +127,20 @@ class Text(SVGMobject): index = self.text.find(word, index + len(word)) return indexes - def get_parts_by_text(self, word): + def get_parts_by_text(self, word: str) -> VGroup: return VGroup(*( self[i:j] for i, j in self.find_indexes(word) )) - def get_part_by_text(self, word): + def get_part_by_text(self, word: str) -> VMobject | None: parts = self.get_parts_by_text(word) if len(parts) > 0: return parts[0] else: return None - def full2short(self, config): + def full2short(self, config: dict[str]) -> None: for kwargs in [config, self.CONFIG]: if kwargs.__contains__('line_spacing_height'): kwargs['lsh'] = kwargs.pop('line_spacing_height') @@ -147,19 +155,25 @@ class Text(SVGMobject): if kwargs.__contains__('text2weight'): kwargs['t2w'] = kwargs.pop('text2weight') - def set_color_by_t2c(self, t2c=None): + def set_color_by_t2c( + self, + t2c: dict[str, ManimColor] | None = None + ) -> None: t2c = t2c if t2c else self.t2c for word, color in t2c.items(): for start, end in self.find_indexes(word): self[start:end].set_color(color) - def set_color_by_t2g(self, t2g=None): + def set_color_by_t2g( + self, + t2g: dict[str, Iterable[ManimColor]] | None = None + ) -> None: t2g = t2g if t2g else self.t2g for word, gradient in t2g.items(): for start, end in self.find_indexes(word): self[start:end].set_color_by_gradient(*gradient) - def text2hash(self): + def text2hash(self) -> str: settings = self.font + self.slant + self.weight settings += str(self.t2f) + str(self.t2s) + str(self.t2w) settings += str(self.lsh) + str(self.font_size) @@ -168,7 +182,7 @@ class Text(SVGMobject): hasher.update(id_str.encode()) return hasher.hexdigest()[:16] - def text2settings(self): + def text2settings(self) -> list[TextSetting]: """ Substrings specified in t2f, t2s, t2w can occupy each other. For each category of style, a stack following first-in-last-out is constructed, @@ -227,7 +241,7 @@ class Text(SVGMobject): del self.line_num return settings - def text2svg(self): + def text2svg(self) -> str: # anti-aliasing size = self.font_size lsh = self.lsh @@ -503,7 +517,7 @@ class Code(Text): "char_width": None } - def __init__(self, code, **kwargs): + def __init__(self, code: str, **kwargs): self.full2short(kwargs) digest_config(self, kwargs) code = code.lstrip("\n") # avoid mismatches of character indices @@ -536,7 +550,7 @@ class Code(Text): if self.char_width is not None: self.set_monospace(self.char_width) - def set_monospace(self, char_width): + def set_monospace(self, char_width: float) -> None: current_char_index = 0 for i, char in enumerate(self.text): if char == "\n": @@ -548,7 +562,7 @@ class Code(Text): @contextmanager -def register_font(font_file: typing.Union[str, Path]): +def register_font(font_file: str | Path): """Temporarily add a font file to Pango's search path. This searches for the font_file at various places. The order it searches it described below. 1. Absolute path.