chore: add type hints to manimlib.mobject.svg

This commit is contained in:
TonyCrane 2022-02-14 22:55:41 +08:00
parent 61c70b426c
commit 773e013af9
No known key found for this signature in database
GPG key ID: 2313A5058A9C637C
5 changed files with 226 additions and 128 deletions

View file

@ -1,11 +1,15 @@
import numpy as np
from __future__ import annotations
import math
import copy
from typing import Iterable
import numpy as np
from manimlib.animation.composition import AnimationGroup
from manimlib.constants import *
from manimlib.animation.fading import FadeIn
from manimlib.animation.growing import GrowFromCenter
from manimlib.animation.composition import AnimationGroup
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.svg.tex_mobject import SingleStringTex
from manimlib.mobject.svg.tex_mobject import TexText
@ -14,6 +18,10 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config
from manimlib.utils.space_ops import get_norm
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.mobject import Mobject
from manimlib.animation.animation import Animation
class Brace(SingleStringTex):
CONFIG = {
@ -21,7 +29,12 @@ class Brace(SingleStringTex):
"tex_string": r"\underbrace{\qquad}"
}
def __init__(self, mobject, direction=DOWN, **kwargs):
def __init__(
self,
mobject: Mobject,
direction: np.ndarray = DOWN,
**kwargs
):
digest_config(self, kwargs, locals())
angle = -math.atan2(*direction[:2]) + PI
mobject.rotate(-angle, about_point=ORIGIN)
@ -36,7 +49,7 @@ class Brace(SingleStringTex):
for mob in mobject, self:
mob.rotate(angle, about_point=ORIGIN)
def set_initial_width(self, width):
def set_initial_width(self, width: float):
width_diff = width - self.get_width()
if width_diff > 0:
for tip, rect, vect in [(self[0], self[1], RIGHT), (self[5], self[4], LEFT)]:
@ -49,7 +62,12 @@ class Brace(SingleStringTex):
self.set_width(width, stretch=True)
return self
def put_at_tip(self, mob, use_next_to=True, **kwargs):
def put_at_tip(
self,
mob: Mobject,
use_next_to: bool = True,
**kwargs
):
if use_next_to:
mob.next_to(
self.get_tip(),
@ -63,24 +81,24 @@ class Brace(SingleStringTex):
mob.shift(self.get_direction() * shift_distance)
return self
def get_text(self, text, **kwargs):
def get_text(self, text: str, **kwargs) -> Text:
buff = kwargs.pop("buff", SMALL_BUFF)
text_mob = Text(text, **kwargs)
self.put_at_tip(text_mob, buff=buff)
return text_mob
def get_tex(self, *tex, **kwargs):
def get_tex(self, *tex: str, **kwargs) -> Tex:
tex_mob = Tex(*tex)
self.put_at_tip(tex_mob, **kwargs)
return tex_mob
def get_tip(self):
def get_tip(self) -> np.ndarray:
# Very specific to the LaTeX representation
# of a brace, but it's the only way I can think
# of to get the tip regardless of orientation.
return self.get_all_points()[self.tip_point_index]
def get_direction(self):
def get_direction(self) -> np.ndarray:
vect = self.get_tip() - self.get_center()
return vect / get_norm(vect)
@ -92,14 +110,20 @@ class BraceLabel(VMobject):
"label_buff": DEFAULT_MOBJECT_TO_MOBJECT_BUFFER
}
def __init__(self, obj, text, brace_direction=DOWN, **kwargs):
def __init__(
self,
obj: list[VMobject] | Mobject,
text: Iterable[str] | str,
brace_direction: np.ndarray = DOWN,
**kwargs
) -> None:
VMobject.__init__(self, **kwargs)
self.brace_direction = brace_direction
if isinstance(obj, list):
obj = VMobject(*obj)
self.brace = Brace(obj, brace_direction, **kwargs)
if isinstance(text, tuple) or isinstance(text, list):
if isinstance(text, Iterable):
self.label = self.label_constructor(*text, **kwargs)
else:
self.label = self.label_constructor(str(text))
@ -109,10 +133,14 @@ class BraceLabel(VMobject):
self.brace.put_at_tip(self.label, buff=self.label_buff)
self.set_submobjects([self.brace, self.label])
def creation_anim(self, label_anim=FadeIn, brace_anim=GrowFromCenter):
def creation_anim(
self,
label_anim: Animation = FadeIn,
brace_anim: Animation=GrowFromCenter
) -> AnimationGroup:
return AnimationGroup(brace_anim(self.brace), label_anim(self.label))
def shift_brace(self, obj, **kwargs):
def shift_brace(self, obj: list[VMobject] | Mobject, **kwargs):
if isinstance(obj, list):
obj = VMobject(*obj)
self.brace = Brace(obj, self.brace_direction, **kwargs)
@ -120,7 +148,7 @@ class BraceLabel(VMobject):
self.submobjects[0] = self.brace
return self
def change_label(self, *text, **kwargs):
def change_label(self, *text: str, **kwargs):
self.label = self.label_constructor(*text, **kwargs)
if self.label_scale != 1:
self.label.scale(self.label_scale)
@ -129,7 +157,7 @@ class BraceLabel(VMobject):
self.submobjects[1] = self.label
return self
def change_brace_label(self, obj, *text):
def change_brace_label(self, obj: list[VMobject] | Mobject, *text: str):
self.shift_brace(obj)
self.change_label(*text)
return self

View file

@ -1,6 +1,10 @@
import itertools as it
from __future__ import annotations
import re
import colour
import itertools as it
from types import MethodType
from typing import Iterable, Union, Sequence
from manimlib.constants import BLACK
from manimlib.mobject.svg.svg_mobject import SVGMobject
@ -14,14 +18,16 @@ from manimlib.utils.tex_file_writing import get_tex_config
from manimlib.utils.tex_file_writing import display_during_execution
from manimlib.logger import log
ManimColor = Union[str, colour.Color, Sequence[float]]
SCALE_FACTOR_PER_FONT_POINT = 0.001
TEX_HASH_TO_MOB_MAP = {}
TEX_HASH_TO_MOB_MAP: dict[int, VGroup] = {}
def _get_neighbouring_pairs(iterable):
def _get_neighbouring_pairs(iterable: Iterable) -> list:
return list(adjacent_pairs(iterable))[:-1]
@ -38,17 +44,19 @@ class _TexSVG(SVGMobject):
class _TexParser(object):
def __init__(self, tex_string, additional_substrings):
def __init__(self, tex_string: str, additional_substrings: str):
self.tex_string = tex_string
self.whitespace_indices = self.get_whitespace_indices()
self.backslash_indices = self.get_backslash_indices()
self.script_indices = self.get_script_indices()
self.brace_indices_dict = self.get_brace_indices_dict()
self.tex_span_list = []
self.script_span_to_char_dict = {}
self.script_span_to_tex_span_dict = {}
self.neighbouring_script_span_pairs = []
self.specified_substrings = []
self.tex_span_list: list[tuple[int, int]] = []
self.script_span_to_char_dict: dict[tuple[int, int], str] = {}
self.script_span_to_tex_span_dict: dict[
tuple[int, int], tuple[int, int]
] = {}
self.neighbouring_script_span_pairs: list[tuple[int, int]] = []
self.specified_substrings: list[str] = []
self.add_tex_span((0, len(tex_string)))
self.break_up_by_scripts()
self.break_up_by_double_braces()
@ -59,17 +67,17 @@ class _TexParser(object):
)
self.containing_labels_dict = self.get_containing_labels_dict()
def add_tex_span(self, tex_span):
def add_tex_span(self, tex_span: tuple[int, int]) -> None:
if tex_span not in self.tex_span_list:
self.tex_span_list.append(tex_span)
def get_whitespace_indices(self):
def get_whitespace_indices(self) -> list[int]:
return [
match_obj.start()
for match_obj in re.finditer(r"\s", self.tex_string)
]
def get_backslash_indices(self):
def get_backslash_indices(self) -> list[int]:
# Newlines (`\\`) don't count.
return [
match_obj.end() - 1
@ -77,19 +85,19 @@ class _TexParser(object):
if len(match_obj.group()) % 2 == 1
]
def filter_out_escaped_characters(self, indices):
def filter_out_escaped_characters(self, indices) -> list[int]:
return list(filter(
lambda index: index - 1 not in self.backslash_indices,
indices
))
def get_script_indices(self):
def get_script_indices(self) -> list[int]:
return self.filter_out_escaped_characters([
match_obj.start()
for match_obj in re.finditer(r"[_^]", self.tex_string)
])
def get_brace_indices_dict(self):
def get_brace_indices_dict(self) -> dict[int, int]:
tex_string = self.tex_string
indices = self.filter_out_escaped_characters([
match_obj.start()
@ -105,7 +113,7 @@ class _TexParser(object):
result[left_brace_index] = index
return result
def break_up_by_scripts(self):
def break_up_by_scripts(self) -> None:
# Match subscripts & superscripts.
tex_string = self.tex_string
whitespace_indices = self.whitespace_indices
@ -154,7 +162,7 @@ class _TexParser(object):
if span_0[1] == span_1[0]:
self.neighbouring_script_span_pairs.append((span_0, span_1))
def break_up_by_double_braces(self):
def break_up_by_double_braces(self) -> None:
# Match paired double braces (`{{...}}`).
tex_string = self.tex_string
reversed_indices_dict = dict(
@ -178,7 +186,10 @@ class _TexParser(object):
self.specified_substrings.append(tex_string[slice(*tex_span)])
skip = True
def break_up_by_additional_substrings(self, additional_substrings):
def break_up_by_additional_substrings(
self,
additional_substrings: Iterable[str]
) -> None:
stripped_substrings = sorted(remove_list_redundancies([
string.strip()
for string in additional_substrings
@ -208,7 +219,7 @@ class _TexParser(object):
continue
self.add_tex_span((span_begin, span_end))
def get_containing_labels_dict(self):
def get_containing_labels_dict(self) -> dict[tuple[int, int], list[int]]:
tex_span_list = self.tex_span_list
result = {
tex_span: []
@ -233,7 +244,7 @@ class _TexParser(object):
raise ValueError
return result
def get_labelled_tex_string(self):
def get_labelled_tex_string(self) -> str:
indices, _, flags, labels = zip(*sorted([
(*tex_span[::(1, -1)[flag]], flag, label)
for label, tex_span in enumerate(self.tex_span_list)
@ -251,7 +262,7 @@ class _TexParser(object):
return "".join(it.chain(*zip(command_pieces, string_pieces)))
@staticmethod
def get_color_command(label):
def get_color_command(label: int) -> str:
rg, b = divmod(label, 256)
r, g = divmod(rg, 256)
return "".join([
@ -261,7 +272,7 @@ class _TexParser(object):
"}"
])
def get_sorted_submob_indices(self, submob_labels):
def get_sorted_submob_indices(self, submob_labels: Iterable[int]) -> list[int]:
def script_span_to_submob_range(script_span):
tex_span = self.script_span_to_tex_span_dict[script_span]
submob_indices = [
@ -295,7 +306,7 @@ class _TexParser(object):
]
return result
def get_submob_tex_strings(self, submob_labels):
def get_submob_tex_strings(self, submob_labels: Iterable[int]) -> list[str]:
ordered_tex_spans = [
self.tex_span_list[label] for label in submob_labels
]
@ -356,7 +367,10 @@ class _TexParser(object):
]))
return result
def find_span_components_of_custom_span(self, custom_span):
def find_span_components_of_custom_span(
self,
custom_span: tuple[int, int]
) -> list[tuple[int, int]] | None:
skipped_indices = sorted(it.chain(
self.whitespace_indices,
self.script_indices
@ -384,16 +398,19 @@ class _TexParser(object):
span_begin = next_begin
return result
def get_containing_labels_by_tex_spans(self, tex_spans):
def get_containing_labels_by_tex_spans(
self,
tex_spans: Iterable[tuple[int, int]]
) -> list[int]:
return remove_list_redundancies(list(it.chain(*[
self.containing_labels_dict[tex_span]
for tex_span in tex_spans
])))
def get_specified_substrings(self):
def get_specified_substrings(self) -> list[str]:
return self.specified_substrings
def get_isolated_substrings(self):
def get_isolated_substrings(self) -> list[str]:
return remove_list_redundancies([
self.tex_string[slice(*tex_span)]
for tex_span in self.tex_span_list
@ -412,7 +429,7 @@ class MTex(VMobject):
"use_plain_tex": False,
}
def __init__(self, tex_string, **kwargs):
def __init__(self, tex_string: str, **kwargs):
super().__init__(**kwargs)
tex_string = tex_string.strip()
# Prevent from passing an empty string.
@ -431,12 +448,12 @@ class MTex(VMobject):
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
@staticmethod
def color_to_label(color):
def color_to_label(color: ManimColor) -> list:
r, g, b = color_to_int_rgb(color)
rg = r * 256 + g
return rg * 256 + b
def generate_mobject(self):
def generate_mobject(self) -> VGroup:
labelled_tex_string = self.__parser.get_labelled_tex_string()
labelled_tex_content = self.get_tex_file_content(labelled_tex_string)
hash_val = hash((labelled_tex_content, self.use_plain_tex))
@ -471,7 +488,7 @@ class MTex(VMobject):
TEX_HASH_TO_MOB_MAP[hash_val] = mob
return mob
def get_tex_file_content(self, tex_string):
def get_tex_file_content(self, tex_string: str) -> str:
if self.tex_environment:
tex_string = "\n".join([
f"\\begin{{{self.tex_environment}}}",
@ -483,7 +500,7 @@ class MTex(VMobject):
return tex_string
@staticmethod
def tex_content_to_glyphs(tex_content):
def tex_content_to_glyphs(tex_content: str) -> _TexSVG:
tex_config = get_tex_config()
full_tex = tex_config["tex_body"].replace(
tex_config["text_to_replace"],
@ -492,7 +509,11 @@ class MTex(VMobject):
filename = tex_to_svg_file(full_tex)
return _TexSVG(filename)
def build_mobject(self, svg_glyphs, glyph_labels):
def build_mobject(
self,
svg_glyphs: _TexSVG | None,
glyph_labels: Iterable[int]
) -> VGroup:
if not svg_glyphs:
return VGroup()
@ -530,14 +551,17 @@ class MTex(VMobject):
submob.get_tex = MethodType(lambda inst: inst.tex_string, submob)
return VGroup(*rearranged_submobjects)
def get_part_by_tex_spans(self, tex_spans):
def get_part_by_tex_spans(
self,
tex_spans: Iterable[tuple[int, int]]
) -> VGroup:
labels = self.__parser.get_containing_labels_by_tex_spans(tex_spans)
return VGroup(*filter(
lambda submob: submob.submob_label in labels,
self.submobjects
))
def get_part_by_custom_span(self, custom_span):
def get_part_by_custom_span(self, custom_span: tuple[int, int]) -> VGroup:
tex_spans = self.__parser.find_span_components_of_custom_span(
custom_span
)
@ -546,7 +570,7 @@ class MTex(VMobject):
raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"")
return self.get_part_by_tex_spans(tex_spans)
def get_parts_by_tex(self, tex):
def get_parts_by_tex(self, tex: str) -> VGroup:
return VGroup(*[
self.get_part_by_custom_span(match_obj.span())
for match_obj in re.finditer(
@ -554,20 +578,23 @@ class MTex(VMobject):
)
])
def get_part_by_tex(self, tex, index=0):
def get_part_by_tex(self, tex: str, index: int = 0) -> VGroup:
all_parts = self.get_parts_by_tex(tex)
return all_parts[index]
def set_color_by_tex(self, tex, color):
def set_color_by_tex(self, tex: str, color: ManimColor):
self.get_parts_by_tex(tex).set_color(color)
return self
def set_color_by_tex_to_color_map(self, tex_to_color_map):
def set_color_by_tex_to_color_map(
self,
tex_to_color_map: dict[str, ManimColor]
):
for tex, color in tex_to_color_map.items():
self.set_color_by_tex(tex, color)
return self
def indices_of_part(self, part):
def indices_of_part(self, part: Iterable[VGroup]) -> list[int]:
indices = [
index for index, submob in enumerate(self.submobjects)
if submob in part
@ -576,23 +603,23 @@ class MTex(VMobject):
raise ValueError("Failed to find part in tex")
return indices
def indices_of_part_by_tex(self, tex, index=0):
def indices_of_part_by_tex(self, tex: str, index: int = 0) -> list[int]:
part = self.get_part_by_tex(tex, index=index)
return self.indices_of_part(part)
def get_tex(self):
def get_tex(self) -> str:
return self.tex_string
def get_submob_tex(self):
def get_submob_tex(self) -> list[str]:
return [
submob.get_tex()
for submob in self.submobjects
]
def get_specified_substrings(self):
def get_specified_substrings(self) -> list[str]:
return self.__parser.get_specified_substrings()
def get_isolated_substrings(self):
def get_isolated_substrings(self) -> list[str]:
return self.__parser.get_isolated_substrings()

View file

@ -1,7 +1,10 @@
from __future__ import annotations
import os
import re
import hashlib
import itertools as it
from typing import Callable
import svgelements as se
import numpy as np
@ -20,7 +23,7 @@ from manimlib.utils.images import get_full_vector_image_path
from manimlib.logger import log
def _convert_point_to_3d(x, y):
def _convert_point_to_3d(x: float, y: float) -> np.ndarray:
return np.array([x, y, 0.0])
@ -41,7 +44,7 @@ class SVGMobject(VMobject):
"path_string_config": {}
}
def __init__(self, file_name=None, **kwargs):
def __init__(self, file_name: str | None = None, **kwargs):
digest_config(self, kwargs)
self.file_name = file_name or self.file_name
if file_name is None:
@ -51,7 +54,7 @@ class SVGMobject(VMobject):
super().__init__(**kwargs)
self.move_into_position()
def move_into_position(self):
def move_into_position(self) -> None:
if self.should_center:
self.center()
if self.height is not None:
@ -59,7 +62,7 @@ class SVGMobject(VMobject):
if self.width is not None:
self.set_width(self.width)
def init_colors(self):
def init_colors(self) -> None:
# Remove fill_color, fill_opacity,
# stroke_width, stroke_color, stroke_opacity
# as each submobject may have those values specified in svg file
@ -68,7 +71,7 @@ class SVGMobject(VMobject):
self.set_flat_stroke(self.flat_stroke)
return self
def init_points(self):
def init_points(self) -> None:
with open(self.file_path, "r") as svg_file:
svg_string = svg_file.read()
@ -96,7 +99,7 @@ class SVGMobject(VMobject):
self.flip(RIGHT) # Flip y
self.scale(0.75)
def modify_svg_file(self, svg_string):
def modify_svg_file(self, svg_string: str) -> str:
# svgelements cannot handle em, ex units
# Convert them using 1em = 16px, 1ex = 0.5em = 8px
def convert_unit(match_obj):
@ -127,7 +130,7 @@ class SVGMobject(VMobject):
return result
def generate_context_values_from_config(self):
def generate_context_values_from_config(self) -> dict[str]:
result = {}
if self.stroke_width is not None:
result["stroke-width"] = self.stroke_width
@ -145,7 +148,7 @@ class SVGMobject(VMobject):
result["stroke-opacity"] = self.stroke_opacity
return result
def get_mobjects_from(self, shape):
def get_mobjects_from(self, shape) -> list[VMobject]:
if isinstance(shape, se.Group):
return list(it.chain(*(
self.get_mobjects_from(child)
@ -161,7 +164,7 @@ class SVGMobject(VMobject):
return [mob]
@staticmethod
def handle_transform(mob, matrix):
def handle_transform(mob: VMobject, matrix: se.Matrix) -> VMobject:
mat = np.array([
[matrix.a, matrix.c],
[matrix.b, matrix.d]
@ -171,8 +174,10 @@ class SVGMobject(VMobject):
mob.shift(vec)
return mob
def get_mobject_from(self, shape):
shape_class_to_func_map = {
def get_mobject_from(self, shape: se.Shape | se.Text) -> VMobject | None:
shape_class_to_func_map: dict[
type, Callable[[se.Shape | se.Text], VMobject]
] = {
se.Path: self.path_to_mobject,
se.SimpleLine: self.line_to_mobject,
se.Rect: self.rect_to_mobject,
@ -194,7 +199,10 @@ class SVGMobject(VMobject):
return None
@staticmethod
def apply_style_to_mobject(mob, shape):
def apply_style_to_mobject(
mob: VMobject,
shape: se.Shape | se.Text
) -> VMobject:
mob.set_style(
stroke_width=shape.stroke_width,
stroke_color=shape.stroke.hex,
@ -204,16 +212,16 @@ class SVGMobject(VMobject):
)
return mob
def path_to_mobject(self, path):
def path_to_mobject(self, path: se.Path) -> VMobjectFromSVGPath:
return VMobjectFromSVGPath(path, **self.path_string_config)
def line_to_mobject(self, line):
def line_to_mobject(self, line: se.Line) -> Line:
return Line(
start=_convert_point_to_3d(line.x1, line.y1),
end=_convert_point_to_3d(line.x2, line.y2)
)
def rect_to_mobject(self, rect):
def rect_to_mobject(self, rect: se.Rect) -> Rectangle | RoundedRectangle:
if rect.rx == 0 or rect.ry == 0:
mob = Rectangle(
width=rect.width,
@ -232,7 +240,7 @@ class SVGMobject(VMobject):
))
return mob
def circle_to_mobject(self, circle):
def circle_to_mobject(self, circle: se.Circle) -> Circle:
# svgelements supports `rx` & `ry` but `r`
mob = Circle(radius=circle.rx)
mob.shift(_convert_point_to_3d(
@ -240,7 +248,7 @@ class SVGMobject(VMobject):
))
return mob
def ellipse_to_mobject(self, ellipse):
def ellipse_to_mobject(self, ellipse: se.Ellipse) -> Circle:
mob = Circle(radius=ellipse.rx)
mob.stretch_to_fit_height(2 * ellipse.ry)
mob.shift(_convert_point_to_3d(
@ -248,21 +256,21 @@ class SVGMobject(VMobject):
))
return mob
def polygon_to_mobject(self, polygon):
def polygon_to_mobject(self, polygon: se.Polygon) -> Polygon:
points = [
_convert_point_to_3d(*point)
for point in polygon
]
return Polygon(*points)
def polyline_to_mobject(self, polyline):
def polyline_to_mobject(self, polyline: se.Polyline) -> Polyline:
points = [
_convert_point_to_3d(*point)
for point in polyline
]
return Polyline(*points)
def text_to_mobject(self, text):
def text_to_mobject(self, text: se.Text):
pass
@ -273,13 +281,13 @@ class VMobjectFromSVGPath(VMobject):
"should_remove_null_curves": False,
}
def __init__(self, path_obj, **kwargs):
def __init__(self, path_obj: se.Path, **kwargs):
# Get rid of arcs
path_obj.approximate_arcs_with_quads()
self.path_obj = path_obj
super().__init__(**kwargs)
def init_points(self):
def init_points(self) -> None:
# After a given svg_path has been converted into points, the result
# will be saved to a file so that future calls for the same path
# don't need to retrace the same computation.
@ -305,7 +313,7 @@ class VMobjectFromSVGPath(VMobject):
np.save(points_filepath, self.get_points())
np.save(tris_filepath, self.get_triangulation())
def handle_commands(self):
def handle_commands(self) -> None:
segment_class_to_func_map = {
se.Move: (self.start_new_path, ("end",)),
se.Close: (self.close_path, ()),

View file

@ -1,5 +1,9 @@
from __future__ import annotations
from typing import Iterable, Sequence, Union
from functools import reduce
import operator as op
import colour
import re
from manimlib.constants import *
@ -13,10 +17,13 @@ from manimlib.utils.tex_file_writing import get_tex_config
from manimlib.utils.tex_file_writing import display_during_execution
ManimColor = Union[str, colour.Color, Sequence[float]]
SCALE_FACTOR_PER_FONT_POINT = 0.001
tex_string_with_color_to_mob_map = {}
tex_string_with_color_to_mob_map: dict[
tuple[ManimColor, str], SVGMobject
] = {}
class SingleStringTex(VMobject):
@ -31,7 +38,7 @@ class SingleStringTex(VMobject):
"math_mode": True,
}
def __init__(self, tex_string, **kwargs):
def __init__(self, tex_string: str, **kwargs):
super().__init__(**kwargs)
assert(isinstance(tex_string, str))
self.tex_string = tex_string
@ -66,7 +73,7 @@ class SingleStringTex(VMobject):
self.set_flat_stroke(self.flat_stroke)
return self
def get_tex_file_body(self, tex_string):
def get_tex_file_body(self, tex_string: str) -> str:
new_tex = self.get_modified_expression(tex_string)
if self.math_mode:
new_tex = "\\begin{align*}\n" + new_tex + "\n\\end{align*}"
@ -79,10 +86,10 @@ class SingleStringTex(VMobject):
new_tex
)
def get_modified_expression(self, tex_string):
def get_modified_expression(self, tex_string: str) -> str:
return self.modify_special_strings(tex_string.strip())
def modify_special_strings(self, tex):
def modify_special_strings(self, tex: str) -> str:
tex = tex.strip()
should_add_filler = reduce(op.or_, [
# Fraction line needs something to be over
@ -134,7 +141,7 @@ class SingleStringTex(VMobject):
tex = ""
return tex
def balance_braces(self, tex):
def balance_braces(self, tex: str) -> str:
"""
Makes Tex resiliant to unmatched braces
"""
@ -154,7 +161,7 @@ class SingleStringTex(VMobject):
tex += num_unclosed_brackets * "}"
return tex
def get_tex(self):
def get_tex(self) -> str:
return self.tex_string
def organize_submobjects_left_to_right(self):
@ -169,7 +176,7 @@ class Tex(SingleStringTex):
"tex_to_color_map": {},
}
def __init__(self, *tex_strings, **kwargs):
def __init__(self, *tex_strings: str, **kwargs):
digest_config(self, kwargs)
self.tex_strings = self.break_up_tex_strings(tex_strings)
full_string = self.arg_separator.join(self.tex_strings)
@ -180,7 +187,7 @@ class Tex(SingleStringTex):
if self.organize_left_to_right:
self.organize_submobjects_left_to_right()
def break_up_tex_strings(self, tex_strings):
def break_up_tex_strings(self, tex_strings: Iterable[str]) -> Iterable[str]:
# Separate out any strings specified in the isolate
# or tex_to_color_map lists.
substrings_to_isolate = [*self.isolate, *self.tex_to_color_map.keys()]
@ -228,7 +235,12 @@ class Tex(SingleStringTex):
self.set_submobjects(new_submobjects)
return self
def get_parts_by_tex(self, tex, substring=True, case_sensitive=True):
def get_parts_by_tex(
self,
tex: str,
substring: bool = True,
case_sensitive: bool = True
) -> VGroup:
def test(tex1, tex2):
if not case_sensitive:
tex1 = tex1.lower()
@ -243,27 +255,36 @@ class Tex(SingleStringTex):
self.submobjects
))
def get_part_by_tex(self, tex, **kwargs):
def get_part_by_tex(self, tex: str, **kwargs) -> SingleStringTex | None:
all_parts = self.get_parts_by_tex(tex, **kwargs)
return all_parts[0] if all_parts else None
def set_color_by_tex(self, tex, color, **kwargs):
def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs):
self.get_parts_by_tex(tex, **kwargs).set_color(color)
return self
def set_color_by_tex_to_color_map(self, tex_to_color_map, **kwargs):
def set_color_by_tex_to_color_map(
self,
tex_to_color_map: dict[str, ManimColor],
**kwargs
):
for tex, color in list(tex_to_color_map.items()):
self.set_color_by_tex(tex, color, **kwargs)
return self
def index_of_part(self, part, start=0):
def index_of_part(self, part: SingleStringTex, start: int = 0) -> int:
return self.submobjects.index(part, start)
def index_of_part_by_tex(self, tex, start=0, **kwargs):
def index_of_part_by_tex(self, tex: str, start: int = 0, **kwargs) -> int:
part = self.get_part_by_tex(tex, **kwargs)
return self.index_of_part(part, start)
def slice_by_tex(self, start_tex=None, stop_tex=None, **kwargs):
def slice_by_tex(
self,
start_tex: str | None = None,
stop_tex: str | None = None,
**kwargs
) -> VGroup:
if start_tex is None:
start_index = 0
else:
@ -275,10 +296,10 @@ class Tex(SingleStringTex):
stop_index = self.index_of_part_by_tex(stop_tex, start=start_index, **kwargs)
return self[start_index:stop_index]
def sort_alphabetically(self):
def sort_alphabetically(self) -> None:
self.submobjects.sort(key=lambda m: m.get_tex())
def set_bstroke(self, color=BLACK, width=4):
def set_bstroke(self, color: ManimColor = BLACK, width: float = 4):
self.set_stroke(color, width, background=True)
return self
@ -297,7 +318,7 @@ class BulletedList(TexText):
"alignment": "",
}
def __init__(self, *items, **kwargs):
def __init__(self, *items: str, **kwargs):
line_separated_items = [s + "\\\\" for s in items]
TexText.__init__(self, *line_separated_items, **kwargs)
for part in self:
@ -310,7 +331,7 @@ class BulletedList(TexText):
buff=self.buff
)
def fade_all_but(self, index_or_string, opacity=0.5):
def fade_all_but(self, index_or_string: int | str, opacity: float = 0.5) -> None:
arg = index_or_string
if isinstance(arg, str):
part = self.get_part_by_tex(arg)
@ -348,7 +369,7 @@ class Title(TexText):
"underline_buff": MED_SMALL_BUFF,
}
def __init__(self, *text_parts, **kwargs):
def __init__(self, *text_parts: str, **kwargs):
TexText.__init__(self, *text_parts, **kwargs)
self.scale(self.scale_factor)
self.to_edge(UP)

View file

@ -1,17 +1,19 @@
import hashlib
from __future__ import annotations
import os
import re
import io
import typing
import xml.etree.ElementTree as ET
import hashlib
import functools
from pathlib import Path
import xml.etree.ElementTree as ET
from contextlib import contextmanager
from typing import Iterable, Sequence, Union
import pygments
import pygments.lexers
import pygments.styles
from contextlib import contextmanager
from pathlib import Path
import manimpango
from manimlib.logger import log
from manimlib.constants import *
@ -23,6 +25,12 @@ from manimlib.utils.customization import get_customization
from manimlib.utils.directories import get_downloads_dir, get_text_dir
from manimpango import PangoUtils, TextSetting, MarkupUtils
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import colour
from manimlib.mobject.types.vectorized_mobject import VMobject
ManimColor = Union[str, colour.Color, Sequence[float]]
TEXT_MOB_SCALE_FACTOR = 0.0076
DEFAULT_LINE_SPACING_SCALE = 0.6
@ -50,7 +58,7 @@ class Text(SVGMobject):
"disable_ligatures": True,
}
def __init__(self, text, **kwargs):
def __init__(self, text: str, **kwargs):
self.full2short(kwargs)
digest_config(self, kwargs)
if self.size:
@ -60,9 +68,9 @@ class Text(SVGMobject):
)
self.font_size = self.size
if self.lsh == -1:
self.lsh = self.font_size + self.font_size * DEFAULT_LINE_SPACING_SCALE
self.lsh: float = self.font_size + self.font_size * DEFAULT_LINE_SPACING_SCALE
else:
self.lsh = self.font_size + self.font_size * self.lsh
self.lsh: float = self.font_size + self.font_size * self.lsh
text_without_tabs = text
if text.find('\t') != -1:
text_without_tabs = text.replace('\t', ' ' * self.tab_width)
@ -87,14 +95,14 @@ class Text(SVGMobject):
if self.height is None:
self.scale(TEXT_MOB_SCALE_FACTOR)
def remove_empty_path(self, file_name):
def remove_empty_path(self, file_name: str) -> None:
with open(file_name, 'r') as fpr:
content = fpr.read()
content = re.sub(r'<path .*?d=""/>', '', content)
with open(file_name, 'w') as fpw:
fpw.write(content)
def apply_space_chars(self):
def apply_space_chars(self) -> None:
submobs = self.submobjects.copy()
for char_index in range(len(self.text)):
if self.text[char_index] in [" ", "\t", "\n"]:
@ -103,7 +111,7 @@ class Text(SVGMobject):
submobs.insert(char_index, space)
self.set_submobjects(submobs)
def find_indexes(self, word):
def find_indexes(self, word: str) -> list[tuple[int, int]]:
m = re.match(r'\[([0-9\-]{0,}):([0-9\-]{0,})\]', word)
if m:
start = int(m.group(1)) if m.group(1) != '' else 0
@ -119,20 +127,20 @@ class Text(SVGMobject):
index = self.text.find(word, index + len(word))
return indexes
def get_parts_by_text(self, word):
def get_parts_by_text(self, word: str) -> VGroup:
return VGroup(*(
self[i:j]
for i, j in self.find_indexes(word)
))
def get_part_by_text(self, word):
def get_part_by_text(self, word: str) -> VMobject | None:
parts = self.get_parts_by_text(word)
if len(parts) > 0:
return parts[0]
else:
return None
def full2short(self, config):
def full2short(self, config: dict[str]) -> None:
for kwargs in [config, self.CONFIG]:
if kwargs.__contains__('line_spacing_height'):
kwargs['lsh'] = kwargs.pop('line_spacing_height')
@ -147,19 +155,25 @@ class Text(SVGMobject):
if kwargs.__contains__('text2weight'):
kwargs['t2w'] = kwargs.pop('text2weight')
def set_color_by_t2c(self, t2c=None):
def set_color_by_t2c(
self,
t2c: dict[str, ManimColor] | None = None
) -> None:
t2c = t2c if t2c else self.t2c
for word, color in t2c.items():
for start, end in self.find_indexes(word):
self[start:end].set_color(color)
def set_color_by_t2g(self, t2g=None):
def set_color_by_t2g(
self,
t2g: dict[str, Iterable[ManimColor]] | None = None
) -> None:
t2g = t2g if t2g else self.t2g
for word, gradient in t2g.items():
for start, end in self.find_indexes(word):
self[start:end].set_color_by_gradient(*gradient)
def text2hash(self):
def text2hash(self) -> str:
settings = self.font + self.slant + self.weight
settings += str(self.t2f) + str(self.t2s) + str(self.t2w)
settings += str(self.lsh) + str(self.font_size)
@ -168,7 +182,7 @@ class Text(SVGMobject):
hasher.update(id_str.encode())
return hasher.hexdigest()[:16]
def text2settings(self):
def text2settings(self) -> list[TextSetting]:
"""
Substrings specified in t2f, t2s, t2w can occupy each other.
For each category of style, a stack following first-in-last-out is constructed,
@ -227,7 +241,7 @@ class Text(SVGMobject):
del self.line_num
return settings
def text2svg(self):
def text2svg(self) -> str:
# anti-aliasing
size = self.font_size
lsh = self.lsh
@ -503,7 +517,7 @@ class Code(Text):
"char_width": None
}
def __init__(self, code, **kwargs):
def __init__(self, code: str, **kwargs):
self.full2short(kwargs)
digest_config(self, kwargs)
code = code.lstrip("\n") # avoid mismatches of character indices
@ -536,7 +550,7 @@ class Code(Text):
if self.char_width is not None:
self.set_monospace(self.char_width)
def set_monospace(self, char_width):
def set_monospace(self, char_width: float) -> None:
current_char_index = 0
for i, char in enumerate(self.text):
if char == "\n":
@ -548,7 +562,7 @@ class Code(Text):
@contextmanager
def register_font(font_file: typing.Union[str, Path]):
def register_font(font_file: str | Path):
"""Temporarily add a font file to Pango's search path.
This searches for the font_file at various places. The order it searches it described below.
1. Absolute path.