chore: add type hints to manimlib.mobject.svg

This commit is contained in:
TonyCrane 2022-02-14 22:55:41 +08:00
parent 61c70b426c
commit 773e013af9
No known key found for this signature in database
GPG key ID: 2313A5058A9C637C
5 changed files with 226 additions and 128 deletions

View file

@ -1,11 +1,15 @@
import numpy as np from __future__ import annotations
import math import math
import copy import copy
from typing import Iterable
import numpy as np
from manimlib.animation.composition import AnimationGroup
from manimlib.constants import * from manimlib.constants import *
from manimlib.animation.fading import FadeIn from manimlib.animation.fading import FadeIn
from manimlib.animation.growing import GrowFromCenter from manimlib.animation.growing import GrowFromCenter
from manimlib.animation.composition import AnimationGroup
from manimlib.mobject.svg.tex_mobject import Tex from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.svg.tex_mobject import SingleStringTex from manimlib.mobject.svg.tex_mobject import SingleStringTex
from manimlib.mobject.svg.tex_mobject import TexText from manimlib.mobject.svg.tex_mobject import TexText
@ -14,6 +18,10 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import digest_config
from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import get_norm
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.mobject import Mobject
from manimlib.animation.animation import Animation
class Brace(SingleStringTex): class Brace(SingleStringTex):
CONFIG = { CONFIG = {
@ -21,7 +29,12 @@ class Brace(SingleStringTex):
"tex_string": r"\underbrace{\qquad}" "tex_string": r"\underbrace{\qquad}"
} }
def __init__(self, mobject, direction=DOWN, **kwargs): def __init__(
self,
mobject: Mobject,
direction: np.ndarray = DOWN,
**kwargs
):
digest_config(self, kwargs, locals()) digest_config(self, kwargs, locals())
angle = -math.atan2(*direction[:2]) + PI angle = -math.atan2(*direction[:2]) + PI
mobject.rotate(-angle, about_point=ORIGIN) mobject.rotate(-angle, about_point=ORIGIN)
@ -36,7 +49,7 @@ class Brace(SingleStringTex):
for mob in mobject, self: for mob in mobject, self:
mob.rotate(angle, about_point=ORIGIN) mob.rotate(angle, about_point=ORIGIN)
def set_initial_width(self, width): def set_initial_width(self, width: float):
width_diff = width - self.get_width() width_diff = width - self.get_width()
if width_diff > 0: if width_diff > 0:
for tip, rect, vect in [(self[0], self[1], RIGHT), (self[5], self[4], LEFT)]: for tip, rect, vect in [(self[0], self[1], RIGHT), (self[5], self[4], LEFT)]:
@ -49,7 +62,12 @@ class Brace(SingleStringTex):
self.set_width(width, stretch=True) self.set_width(width, stretch=True)
return self return self
def put_at_tip(self, mob, use_next_to=True, **kwargs): def put_at_tip(
self,
mob: Mobject,
use_next_to: bool = True,
**kwargs
):
if use_next_to: if use_next_to:
mob.next_to( mob.next_to(
self.get_tip(), self.get_tip(),
@ -63,24 +81,24 @@ class Brace(SingleStringTex):
mob.shift(self.get_direction() * shift_distance) mob.shift(self.get_direction() * shift_distance)
return self return self
def get_text(self, text, **kwargs): def get_text(self, text: str, **kwargs) -> Text:
buff = kwargs.pop("buff", SMALL_BUFF) buff = kwargs.pop("buff", SMALL_BUFF)
text_mob = Text(text, **kwargs) text_mob = Text(text, **kwargs)
self.put_at_tip(text_mob, buff=buff) self.put_at_tip(text_mob, buff=buff)
return text_mob return text_mob
def get_tex(self, *tex, **kwargs): def get_tex(self, *tex: str, **kwargs) -> Tex:
tex_mob = Tex(*tex) tex_mob = Tex(*tex)
self.put_at_tip(tex_mob, **kwargs) self.put_at_tip(tex_mob, **kwargs)
return tex_mob return tex_mob
def get_tip(self): def get_tip(self) -> np.ndarray:
# Very specific to the LaTeX representation # Very specific to the LaTeX representation
# of a brace, but it's the only way I can think # of a brace, but it's the only way I can think
# of to get the tip regardless of orientation. # of to get the tip regardless of orientation.
return self.get_all_points()[self.tip_point_index] return self.get_all_points()[self.tip_point_index]
def get_direction(self): def get_direction(self) -> np.ndarray:
vect = self.get_tip() - self.get_center() vect = self.get_tip() - self.get_center()
return vect / get_norm(vect) return vect / get_norm(vect)
@ -92,14 +110,20 @@ class BraceLabel(VMobject):
"label_buff": DEFAULT_MOBJECT_TO_MOBJECT_BUFFER "label_buff": DEFAULT_MOBJECT_TO_MOBJECT_BUFFER
} }
def __init__(self, obj, text, brace_direction=DOWN, **kwargs): def __init__(
self,
obj: list[VMobject] | Mobject,
text: Iterable[str] | str,
brace_direction: np.ndarray = DOWN,
**kwargs
) -> None:
VMobject.__init__(self, **kwargs) VMobject.__init__(self, **kwargs)
self.brace_direction = brace_direction self.brace_direction = brace_direction
if isinstance(obj, list): if isinstance(obj, list):
obj = VMobject(*obj) obj = VMobject(*obj)
self.brace = Brace(obj, brace_direction, **kwargs) self.brace = Brace(obj, brace_direction, **kwargs)
if isinstance(text, tuple) or isinstance(text, list): if isinstance(text, Iterable):
self.label = self.label_constructor(*text, **kwargs) self.label = self.label_constructor(*text, **kwargs)
else: else:
self.label = self.label_constructor(str(text)) self.label = self.label_constructor(str(text))
@ -109,10 +133,14 @@ class BraceLabel(VMobject):
self.brace.put_at_tip(self.label, buff=self.label_buff) self.brace.put_at_tip(self.label, buff=self.label_buff)
self.set_submobjects([self.brace, self.label]) self.set_submobjects([self.brace, self.label])
def creation_anim(self, label_anim=FadeIn, brace_anim=GrowFromCenter): def creation_anim(
self,
label_anim: Animation = FadeIn,
brace_anim: Animation=GrowFromCenter
) -> AnimationGroup:
return AnimationGroup(brace_anim(self.brace), label_anim(self.label)) return AnimationGroup(brace_anim(self.brace), label_anim(self.label))
def shift_brace(self, obj, **kwargs): def shift_brace(self, obj: list[VMobject] | Mobject, **kwargs):
if isinstance(obj, list): if isinstance(obj, list):
obj = VMobject(*obj) obj = VMobject(*obj)
self.brace = Brace(obj, self.brace_direction, **kwargs) self.brace = Brace(obj, self.brace_direction, **kwargs)
@ -120,7 +148,7 @@ class BraceLabel(VMobject):
self.submobjects[0] = self.brace self.submobjects[0] = self.brace
return self return self
def change_label(self, *text, **kwargs): def change_label(self, *text: str, **kwargs):
self.label = self.label_constructor(*text, **kwargs) self.label = self.label_constructor(*text, **kwargs)
if self.label_scale != 1: if self.label_scale != 1:
self.label.scale(self.label_scale) self.label.scale(self.label_scale)
@ -129,7 +157,7 @@ class BraceLabel(VMobject):
self.submobjects[1] = self.label self.submobjects[1] = self.label
return self return self
def change_brace_label(self, obj, *text): def change_brace_label(self, obj: list[VMobject] | Mobject, *text: str):
self.shift_brace(obj) self.shift_brace(obj)
self.change_label(*text) self.change_label(*text)
return self return self

View file

@ -1,6 +1,10 @@
import itertools as it from __future__ import annotations
import re import re
import colour
import itertools as it
from types import MethodType from types import MethodType
from typing import Iterable, Union, Sequence
from manimlib.constants import BLACK from manimlib.constants import BLACK
from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.svg.svg_mobject import SVGMobject
@ -14,14 +18,16 @@ from manimlib.utils.tex_file_writing import get_tex_config
from manimlib.utils.tex_file_writing import display_during_execution from manimlib.utils.tex_file_writing import display_during_execution
from manimlib.logger import log from manimlib.logger import log
ManimColor = Union[str, colour.Color, Sequence[float]]
SCALE_FACTOR_PER_FONT_POINT = 0.001 SCALE_FACTOR_PER_FONT_POINT = 0.001
TEX_HASH_TO_MOB_MAP = {} TEX_HASH_TO_MOB_MAP: dict[int, VGroup] = {}
def _get_neighbouring_pairs(iterable): def _get_neighbouring_pairs(iterable: Iterable) -> list:
return list(adjacent_pairs(iterable))[:-1] return list(adjacent_pairs(iterable))[:-1]
@ -38,17 +44,19 @@ class _TexSVG(SVGMobject):
class _TexParser(object): class _TexParser(object):
def __init__(self, tex_string, additional_substrings): def __init__(self, tex_string: str, additional_substrings: str):
self.tex_string = tex_string self.tex_string = tex_string
self.whitespace_indices = self.get_whitespace_indices() self.whitespace_indices = self.get_whitespace_indices()
self.backslash_indices = self.get_backslash_indices() self.backslash_indices = self.get_backslash_indices()
self.script_indices = self.get_script_indices() self.script_indices = self.get_script_indices()
self.brace_indices_dict = self.get_brace_indices_dict() self.brace_indices_dict = self.get_brace_indices_dict()
self.tex_span_list = [] self.tex_span_list: list[tuple[int, int]] = []
self.script_span_to_char_dict = {} self.script_span_to_char_dict: dict[tuple[int, int], str] = {}
self.script_span_to_tex_span_dict = {} self.script_span_to_tex_span_dict: dict[
self.neighbouring_script_span_pairs = [] tuple[int, int], tuple[int, int]
self.specified_substrings = [] ] = {}
self.neighbouring_script_span_pairs: list[tuple[int, int]] = []
self.specified_substrings: list[str] = []
self.add_tex_span((0, len(tex_string))) self.add_tex_span((0, len(tex_string)))
self.break_up_by_scripts() self.break_up_by_scripts()
self.break_up_by_double_braces() self.break_up_by_double_braces()
@ -59,17 +67,17 @@ class _TexParser(object):
) )
self.containing_labels_dict = self.get_containing_labels_dict() self.containing_labels_dict = self.get_containing_labels_dict()
def add_tex_span(self, tex_span): def add_tex_span(self, tex_span: tuple[int, int]) -> None:
if tex_span not in self.tex_span_list: if tex_span not in self.tex_span_list:
self.tex_span_list.append(tex_span) self.tex_span_list.append(tex_span)
def get_whitespace_indices(self): def get_whitespace_indices(self) -> list[int]:
return [ return [
match_obj.start() match_obj.start()
for match_obj in re.finditer(r"\s", self.tex_string) for match_obj in re.finditer(r"\s", self.tex_string)
] ]
def get_backslash_indices(self): def get_backslash_indices(self) -> list[int]:
# Newlines (`\\`) don't count. # Newlines (`\\`) don't count.
return [ return [
match_obj.end() - 1 match_obj.end() - 1
@ -77,19 +85,19 @@ class _TexParser(object):
if len(match_obj.group()) % 2 == 1 if len(match_obj.group()) % 2 == 1
] ]
def filter_out_escaped_characters(self, indices): def filter_out_escaped_characters(self, indices) -> list[int]:
return list(filter( return list(filter(
lambda index: index - 1 not in self.backslash_indices, lambda index: index - 1 not in self.backslash_indices,
indices indices
)) ))
def get_script_indices(self): def get_script_indices(self) -> list[int]:
return self.filter_out_escaped_characters([ return self.filter_out_escaped_characters([
match_obj.start() match_obj.start()
for match_obj in re.finditer(r"[_^]", self.tex_string) for match_obj in re.finditer(r"[_^]", self.tex_string)
]) ])
def get_brace_indices_dict(self): def get_brace_indices_dict(self) -> dict[int, int]:
tex_string = self.tex_string tex_string = self.tex_string
indices = self.filter_out_escaped_characters([ indices = self.filter_out_escaped_characters([
match_obj.start() match_obj.start()
@ -105,7 +113,7 @@ class _TexParser(object):
result[left_brace_index] = index result[left_brace_index] = index
return result return result
def break_up_by_scripts(self): def break_up_by_scripts(self) -> None:
# Match subscripts & superscripts. # Match subscripts & superscripts.
tex_string = self.tex_string tex_string = self.tex_string
whitespace_indices = self.whitespace_indices whitespace_indices = self.whitespace_indices
@ -154,7 +162,7 @@ class _TexParser(object):
if span_0[1] == span_1[0]: if span_0[1] == span_1[0]:
self.neighbouring_script_span_pairs.append((span_0, span_1)) self.neighbouring_script_span_pairs.append((span_0, span_1))
def break_up_by_double_braces(self): def break_up_by_double_braces(self) -> None:
# Match paired double braces (`{{...}}`). # Match paired double braces (`{{...}}`).
tex_string = self.tex_string tex_string = self.tex_string
reversed_indices_dict = dict( reversed_indices_dict = dict(
@ -178,7 +186,10 @@ class _TexParser(object):
self.specified_substrings.append(tex_string[slice(*tex_span)]) self.specified_substrings.append(tex_string[slice(*tex_span)])
skip = True skip = True
def break_up_by_additional_substrings(self, additional_substrings): def break_up_by_additional_substrings(
self,
additional_substrings: Iterable[str]
) -> None:
stripped_substrings = sorted(remove_list_redundancies([ stripped_substrings = sorted(remove_list_redundancies([
string.strip() string.strip()
for string in additional_substrings for string in additional_substrings
@ -208,7 +219,7 @@ class _TexParser(object):
continue continue
self.add_tex_span((span_begin, span_end)) self.add_tex_span((span_begin, span_end))
def get_containing_labels_dict(self): def get_containing_labels_dict(self) -> dict[tuple[int, int], list[int]]:
tex_span_list = self.tex_span_list tex_span_list = self.tex_span_list
result = { result = {
tex_span: [] tex_span: []
@ -233,7 +244,7 @@ class _TexParser(object):
raise ValueError raise ValueError
return result return result
def get_labelled_tex_string(self): def get_labelled_tex_string(self) -> str:
indices, _, flags, labels = zip(*sorted([ indices, _, flags, labels = zip(*sorted([
(*tex_span[::(1, -1)[flag]], flag, label) (*tex_span[::(1, -1)[flag]], flag, label)
for label, tex_span in enumerate(self.tex_span_list) for label, tex_span in enumerate(self.tex_span_list)
@ -251,7 +262,7 @@ class _TexParser(object):
return "".join(it.chain(*zip(command_pieces, string_pieces))) return "".join(it.chain(*zip(command_pieces, string_pieces)))
@staticmethod @staticmethod
def get_color_command(label): def get_color_command(label: int) -> str:
rg, b = divmod(label, 256) rg, b = divmod(label, 256)
r, g = divmod(rg, 256) r, g = divmod(rg, 256)
return "".join([ return "".join([
@ -261,7 +272,7 @@ class _TexParser(object):
"}" "}"
]) ])
def get_sorted_submob_indices(self, submob_labels): def get_sorted_submob_indices(self, submob_labels: Iterable[int]) -> list[int]:
def script_span_to_submob_range(script_span): def script_span_to_submob_range(script_span):
tex_span = self.script_span_to_tex_span_dict[script_span] tex_span = self.script_span_to_tex_span_dict[script_span]
submob_indices = [ submob_indices = [
@ -295,7 +306,7 @@ class _TexParser(object):
] ]
return result return result
def get_submob_tex_strings(self, submob_labels): def get_submob_tex_strings(self, submob_labels: Iterable[int]) -> list[str]:
ordered_tex_spans = [ ordered_tex_spans = [
self.tex_span_list[label] for label in submob_labels self.tex_span_list[label] for label in submob_labels
] ]
@ -356,7 +367,10 @@ class _TexParser(object):
])) ]))
return result return result
def find_span_components_of_custom_span(self, custom_span): def find_span_components_of_custom_span(
self,
custom_span: tuple[int, int]
) -> list[tuple[int, int]] | None:
skipped_indices = sorted(it.chain( skipped_indices = sorted(it.chain(
self.whitespace_indices, self.whitespace_indices,
self.script_indices self.script_indices
@ -384,16 +398,19 @@ class _TexParser(object):
span_begin = next_begin span_begin = next_begin
return result return result
def get_containing_labels_by_tex_spans(self, tex_spans): def get_containing_labels_by_tex_spans(
self,
tex_spans: Iterable[tuple[int, int]]
) -> list[int]:
return remove_list_redundancies(list(it.chain(*[ return remove_list_redundancies(list(it.chain(*[
self.containing_labels_dict[tex_span] self.containing_labels_dict[tex_span]
for tex_span in tex_spans for tex_span in tex_spans
]))) ])))
def get_specified_substrings(self): def get_specified_substrings(self) -> list[str]:
return self.specified_substrings return self.specified_substrings
def get_isolated_substrings(self): def get_isolated_substrings(self) -> list[str]:
return remove_list_redundancies([ return remove_list_redundancies([
self.tex_string[slice(*tex_span)] self.tex_string[slice(*tex_span)]
for tex_span in self.tex_span_list for tex_span in self.tex_span_list
@ -412,7 +429,7 @@ class MTex(VMobject):
"use_plain_tex": False, "use_plain_tex": False,
} }
def __init__(self, tex_string, **kwargs): def __init__(self, tex_string: str, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
tex_string = tex_string.strip() tex_string = tex_string.strip()
# Prevent from passing an empty string. # Prevent from passing an empty string.
@ -431,12 +448,12 @@ class MTex(VMobject):
self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
@staticmethod @staticmethod
def color_to_label(color): def color_to_label(color: ManimColor) -> list:
r, g, b = color_to_int_rgb(color) r, g, b = color_to_int_rgb(color)
rg = r * 256 + g rg = r * 256 + g
return rg * 256 + b return rg * 256 + b
def generate_mobject(self): def generate_mobject(self) -> VGroup:
labelled_tex_string = self.__parser.get_labelled_tex_string() labelled_tex_string = self.__parser.get_labelled_tex_string()
labelled_tex_content = self.get_tex_file_content(labelled_tex_string) labelled_tex_content = self.get_tex_file_content(labelled_tex_string)
hash_val = hash((labelled_tex_content, self.use_plain_tex)) hash_val = hash((labelled_tex_content, self.use_plain_tex))
@ -471,7 +488,7 @@ class MTex(VMobject):
TEX_HASH_TO_MOB_MAP[hash_val] = mob TEX_HASH_TO_MOB_MAP[hash_val] = mob
return mob return mob
def get_tex_file_content(self, tex_string): def get_tex_file_content(self, tex_string: str) -> str:
if self.tex_environment: if self.tex_environment:
tex_string = "\n".join([ tex_string = "\n".join([
f"\\begin{{{self.tex_environment}}}", f"\\begin{{{self.tex_environment}}}",
@ -483,7 +500,7 @@ class MTex(VMobject):
return tex_string return tex_string
@staticmethod @staticmethod
def tex_content_to_glyphs(tex_content): def tex_content_to_glyphs(tex_content: str) -> _TexSVG:
tex_config = get_tex_config() tex_config = get_tex_config()
full_tex = tex_config["tex_body"].replace( full_tex = tex_config["tex_body"].replace(
tex_config["text_to_replace"], tex_config["text_to_replace"],
@ -492,7 +509,11 @@ class MTex(VMobject):
filename = tex_to_svg_file(full_tex) filename = tex_to_svg_file(full_tex)
return _TexSVG(filename) return _TexSVG(filename)
def build_mobject(self, svg_glyphs, glyph_labels): def build_mobject(
self,
svg_glyphs: _TexSVG | None,
glyph_labels: Iterable[int]
) -> VGroup:
if not svg_glyphs: if not svg_glyphs:
return VGroup() return VGroup()
@ -530,14 +551,17 @@ class MTex(VMobject):
submob.get_tex = MethodType(lambda inst: inst.tex_string, submob) submob.get_tex = MethodType(lambda inst: inst.tex_string, submob)
return VGroup(*rearranged_submobjects) return VGroup(*rearranged_submobjects)
def get_part_by_tex_spans(self, tex_spans): def get_part_by_tex_spans(
self,
tex_spans: Iterable[tuple[int, int]]
) -> VGroup:
labels = self.__parser.get_containing_labels_by_tex_spans(tex_spans) labels = self.__parser.get_containing_labels_by_tex_spans(tex_spans)
return VGroup(*filter( return VGroup(*filter(
lambda submob: submob.submob_label in labels, lambda submob: submob.submob_label in labels,
self.submobjects self.submobjects
)) ))
def get_part_by_custom_span(self, custom_span): def get_part_by_custom_span(self, custom_span: tuple[int, int]) -> VGroup:
tex_spans = self.__parser.find_span_components_of_custom_span( tex_spans = self.__parser.find_span_components_of_custom_span(
custom_span custom_span
) )
@ -546,7 +570,7 @@ class MTex(VMobject):
raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"") raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"")
return self.get_part_by_tex_spans(tex_spans) return self.get_part_by_tex_spans(tex_spans)
def get_parts_by_tex(self, tex): def get_parts_by_tex(self, tex: str) -> VGroup:
return VGroup(*[ return VGroup(*[
self.get_part_by_custom_span(match_obj.span()) self.get_part_by_custom_span(match_obj.span())
for match_obj in re.finditer( for match_obj in re.finditer(
@ -554,20 +578,23 @@ class MTex(VMobject):
) )
]) ])
def get_part_by_tex(self, tex, index=0): def get_part_by_tex(self, tex: str, index: int = 0) -> VGroup:
all_parts = self.get_parts_by_tex(tex) all_parts = self.get_parts_by_tex(tex)
return all_parts[index] return all_parts[index]
def set_color_by_tex(self, tex, color): def set_color_by_tex(self, tex: str, color: ManimColor):
self.get_parts_by_tex(tex).set_color(color) self.get_parts_by_tex(tex).set_color(color)
return self return self
def set_color_by_tex_to_color_map(self, tex_to_color_map): def set_color_by_tex_to_color_map(
self,
tex_to_color_map: dict[str, ManimColor]
):
for tex, color in tex_to_color_map.items(): for tex, color in tex_to_color_map.items():
self.set_color_by_tex(tex, color) self.set_color_by_tex(tex, color)
return self return self
def indices_of_part(self, part): def indices_of_part(self, part: Iterable[VGroup]) -> list[int]:
indices = [ indices = [
index for index, submob in enumerate(self.submobjects) index for index, submob in enumerate(self.submobjects)
if submob in part if submob in part
@ -576,23 +603,23 @@ class MTex(VMobject):
raise ValueError("Failed to find part in tex") raise ValueError("Failed to find part in tex")
return indices return indices
def indices_of_part_by_tex(self, tex, index=0): def indices_of_part_by_tex(self, tex: str, index: int = 0) -> list[int]:
part = self.get_part_by_tex(tex, index=index) part = self.get_part_by_tex(tex, index=index)
return self.indices_of_part(part) return self.indices_of_part(part)
def get_tex(self): def get_tex(self) -> str:
return self.tex_string return self.tex_string
def get_submob_tex(self): def get_submob_tex(self) -> list[str]:
return [ return [
submob.get_tex() submob.get_tex()
for submob in self.submobjects for submob in self.submobjects
] ]
def get_specified_substrings(self): def get_specified_substrings(self) -> list[str]:
return self.__parser.get_specified_substrings() return self.__parser.get_specified_substrings()
def get_isolated_substrings(self): def get_isolated_substrings(self) -> list[str]:
return self.__parser.get_isolated_substrings() return self.__parser.get_isolated_substrings()

View file

@ -1,7 +1,10 @@
from __future__ import annotations
import os import os
import re import re
import hashlib import hashlib
import itertools as it import itertools as it
from typing import Callable
import svgelements as se import svgelements as se
import numpy as np import numpy as np
@ -20,7 +23,7 @@ from manimlib.utils.images import get_full_vector_image_path
from manimlib.logger import log from manimlib.logger import log
def _convert_point_to_3d(x, y): def _convert_point_to_3d(x: float, y: float) -> np.ndarray:
return np.array([x, y, 0.0]) return np.array([x, y, 0.0])
@ -41,7 +44,7 @@ class SVGMobject(VMobject):
"path_string_config": {} "path_string_config": {}
} }
def __init__(self, file_name=None, **kwargs): def __init__(self, file_name: str | None = None, **kwargs):
digest_config(self, kwargs) digest_config(self, kwargs)
self.file_name = file_name or self.file_name self.file_name = file_name or self.file_name
if file_name is None: if file_name is None:
@ -51,7 +54,7 @@ class SVGMobject(VMobject):
super().__init__(**kwargs) super().__init__(**kwargs)
self.move_into_position() self.move_into_position()
def move_into_position(self): def move_into_position(self) -> None:
if self.should_center: if self.should_center:
self.center() self.center()
if self.height is not None: if self.height is not None:
@ -59,7 +62,7 @@ class SVGMobject(VMobject):
if self.width is not None: if self.width is not None:
self.set_width(self.width) self.set_width(self.width)
def init_colors(self): def init_colors(self) -> None:
# Remove fill_color, fill_opacity, # Remove fill_color, fill_opacity,
# stroke_width, stroke_color, stroke_opacity # stroke_width, stroke_color, stroke_opacity
# as each submobject may have those values specified in svg file # as each submobject may have those values specified in svg file
@ -68,7 +71,7 @@ class SVGMobject(VMobject):
self.set_flat_stroke(self.flat_stroke) self.set_flat_stroke(self.flat_stroke)
return self return self
def init_points(self): def init_points(self) -> None:
with open(self.file_path, "r") as svg_file: with open(self.file_path, "r") as svg_file:
svg_string = svg_file.read() svg_string = svg_file.read()
@ -96,7 +99,7 @@ class SVGMobject(VMobject):
self.flip(RIGHT) # Flip y self.flip(RIGHT) # Flip y
self.scale(0.75) self.scale(0.75)
def modify_svg_file(self, svg_string): def modify_svg_file(self, svg_string: str) -> str:
# svgelements cannot handle em, ex units # svgelements cannot handle em, ex units
# Convert them using 1em = 16px, 1ex = 0.5em = 8px # Convert them using 1em = 16px, 1ex = 0.5em = 8px
def convert_unit(match_obj): def convert_unit(match_obj):
@ -127,7 +130,7 @@ class SVGMobject(VMobject):
return result return result
def generate_context_values_from_config(self): def generate_context_values_from_config(self) -> dict[str]:
result = {} result = {}
if self.stroke_width is not None: if self.stroke_width is not None:
result["stroke-width"] = self.stroke_width result["stroke-width"] = self.stroke_width
@ -145,7 +148,7 @@ class SVGMobject(VMobject):
result["stroke-opacity"] = self.stroke_opacity result["stroke-opacity"] = self.stroke_opacity
return result return result
def get_mobjects_from(self, shape): def get_mobjects_from(self, shape) -> list[VMobject]:
if isinstance(shape, se.Group): if isinstance(shape, se.Group):
return list(it.chain(*( return list(it.chain(*(
self.get_mobjects_from(child) self.get_mobjects_from(child)
@ -161,7 +164,7 @@ class SVGMobject(VMobject):
return [mob] return [mob]
@staticmethod @staticmethod
def handle_transform(mob, matrix): def handle_transform(mob: VMobject, matrix: se.Matrix) -> VMobject:
mat = np.array([ mat = np.array([
[matrix.a, matrix.c], [matrix.a, matrix.c],
[matrix.b, matrix.d] [matrix.b, matrix.d]
@ -171,8 +174,10 @@ class SVGMobject(VMobject):
mob.shift(vec) mob.shift(vec)
return mob return mob
def get_mobject_from(self, shape): def get_mobject_from(self, shape: se.Shape | se.Text) -> VMobject | None:
shape_class_to_func_map = { shape_class_to_func_map: dict[
type, Callable[[se.Shape | se.Text], VMobject]
] = {
se.Path: self.path_to_mobject, se.Path: self.path_to_mobject,
se.SimpleLine: self.line_to_mobject, se.SimpleLine: self.line_to_mobject,
se.Rect: self.rect_to_mobject, se.Rect: self.rect_to_mobject,
@ -194,7 +199,10 @@ class SVGMobject(VMobject):
return None return None
@staticmethod @staticmethod
def apply_style_to_mobject(mob, shape): def apply_style_to_mobject(
mob: VMobject,
shape: se.Shape | se.Text
) -> VMobject:
mob.set_style( mob.set_style(
stroke_width=shape.stroke_width, stroke_width=shape.stroke_width,
stroke_color=shape.stroke.hex, stroke_color=shape.stroke.hex,
@ -204,16 +212,16 @@ class SVGMobject(VMobject):
) )
return mob return mob
def path_to_mobject(self, path): def path_to_mobject(self, path: se.Path) -> VMobjectFromSVGPath:
return VMobjectFromSVGPath(path, **self.path_string_config) return VMobjectFromSVGPath(path, **self.path_string_config)
def line_to_mobject(self, line): def line_to_mobject(self, line: se.Line) -> Line:
return Line( return Line(
start=_convert_point_to_3d(line.x1, line.y1), start=_convert_point_to_3d(line.x1, line.y1),
end=_convert_point_to_3d(line.x2, line.y2) end=_convert_point_to_3d(line.x2, line.y2)
) )
def rect_to_mobject(self, rect): def rect_to_mobject(self, rect: se.Rect) -> Rectangle | RoundedRectangle:
if rect.rx == 0 or rect.ry == 0: if rect.rx == 0 or rect.ry == 0:
mob = Rectangle( mob = Rectangle(
width=rect.width, width=rect.width,
@ -232,7 +240,7 @@ class SVGMobject(VMobject):
)) ))
return mob return mob
def circle_to_mobject(self, circle): def circle_to_mobject(self, circle: se.Circle) -> Circle:
# svgelements supports `rx` & `ry` but `r` # svgelements supports `rx` & `ry` but `r`
mob = Circle(radius=circle.rx) mob = Circle(radius=circle.rx)
mob.shift(_convert_point_to_3d( mob.shift(_convert_point_to_3d(
@ -240,7 +248,7 @@ class SVGMobject(VMobject):
)) ))
return mob return mob
def ellipse_to_mobject(self, ellipse): def ellipse_to_mobject(self, ellipse: se.Ellipse) -> Circle:
mob = Circle(radius=ellipse.rx) mob = Circle(radius=ellipse.rx)
mob.stretch_to_fit_height(2 * ellipse.ry) mob.stretch_to_fit_height(2 * ellipse.ry)
mob.shift(_convert_point_to_3d( mob.shift(_convert_point_to_3d(
@ -248,21 +256,21 @@ class SVGMobject(VMobject):
)) ))
return mob return mob
def polygon_to_mobject(self, polygon): def polygon_to_mobject(self, polygon: se.Polygon) -> Polygon:
points = [ points = [
_convert_point_to_3d(*point) _convert_point_to_3d(*point)
for point in polygon for point in polygon
] ]
return Polygon(*points) return Polygon(*points)
def polyline_to_mobject(self, polyline): def polyline_to_mobject(self, polyline: se.Polyline) -> Polyline:
points = [ points = [
_convert_point_to_3d(*point) _convert_point_to_3d(*point)
for point in polyline for point in polyline
] ]
return Polyline(*points) return Polyline(*points)
def text_to_mobject(self, text): def text_to_mobject(self, text: se.Text):
pass pass
@ -273,13 +281,13 @@ class VMobjectFromSVGPath(VMobject):
"should_remove_null_curves": False, "should_remove_null_curves": False,
} }
def __init__(self, path_obj, **kwargs): def __init__(self, path_obj: se.Path, **kwargs):
# Get rid of arcs # Get rid of arcs
path_obj.approximate_arcs_with_quads() path_obj.approximate_arcs_with_quads()
self.path_obj = path_obj self.path_obj = path_obj
super().__init__(**kwargs) super().__init__(**kwargs)
def init_points(self): def init_points(self) -> None:
# After a given svg_path has been converted into points, the result # After a given svg_path has been converted into points, the result
# will be saved to a file so that future calls for the same path # will be saved to a file so that future calls for the same path
# don't need to retrace the same computation. # don't need to retrace the same computation.
@ -305,7 +313,7 @@ class VMobjectFromSVGPath(VMobject):
np.save(points_filepath, self.get_points()) np.save(points_filepath, self.get_points())
np.save(tris_filepath, self.get_triangulation()) np.save(tris_filepath, self.get_triangulation())
def handle_commands(self): def handle_commands(self) -> None:
segment_class_to_func_map = { segment_class_to_func_map = {
se.Move: (self.start_new_path, ("end",)), se.Move: (self.start_new_path, ("end",)),
se.Close: (self.close_path, ()), se.Close: (self.close_path, ()),

View file

@ -1,5 +1,9 @@
from __future__ import annotations
from typing import Iterable, Sequence, Union
from functools import reduce from functools import reduce
import operator as op import operator as op
import colour
import re import re
from manimlib.constants import * from manimlib.constants import *
@ -13,10 +17,13 @@ from manimlib.utils.tex_file_writing import get_tex_config
from manimlib.utils.tex_file_writing import display_during_execution from manimlib.utils.tex_file_writing import display_during_execution
ManimColor = Union[str, colour.Color, Sequence[float]]
SCALE_FACTOR_PER_FONT_POINT = 0.001 SCALE_FACTOR_PER_FONT_POINT = 0.001
tex_string_with_color_to_mob_map: dict[
tex_string_with_color_to_mob_map = {} tuple[ManimColor, str], SVGMobject
] = {}
class SingleStringTex(VMobject): class SingleStringTex(VMobject):
@ -31,7 +38,7 @@ class SingleStringTex(VMobject):
"math_mode": True, "math_mode": True,
} }
def __init__(self, tex_string, **kwargs): def __init__(self, tex_string: str, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
assert(isinstance(tex_string, str)) assert(isinstance(tex_string, str))
self.tex_string = tex_string self.tex_string = tex_string
@ -66,7 +73,7 @@ class SingleStringTex(VMobject):
self.set_flat_stroke(self.flat_stroke) self.set_flat_stroke(self.flat_stroke)
return self return self
def get_tex_file_body(self, tex_string): def get_tex_file_body(self, tex_string: str) -> str:
new_tex = self.get_modified_expression(tex_string) new_tex = self.get_modified_expression(tex_string)
if self.math_mode: if self.math_mode:
new_tex = "\\begin{align*}\n" + new_tex + "\n\\end{align*}" new_tex = "\\begin{align*}\n" + new_tex + "\n\\end{align*}"
@ -79,10 +86,10 @@ class SingleStringTex(VMobject):
new_tex new_tex
) )
def get_modified_expression(self, tex_string): def get_modified_expression(self, tex_string: str) -> str:
return self.modify_special_strings(tex_string.strip()) return self.modify_special_strings(tex_string.strip())
def modify_special_strings(self, tex): def modify_special_strings(self, tex: str) -> str:
tex = tex.strip() tex = tex.strip()
should_add_filler = reduce(op.or_, [ should_add_filler = reduce(op.or_, [
# Fraction line needs something to be over # Fraction line needs something to be over
@ -134,7 +141,7 @@ class SingleStringTex(VMobject):
tex = "" tex = ""
return tex return tex
def balance_braces(self, tex): def balance_braces(self, tex: str) -> str:
""" """
Makes Tex resiliant to unmatched braces Makes Tex resiliant to unmatched braces
""" """
@ -154,7 +161,7 @@ class SingleStringTex(VMobject):
tex += num_unclosed_brackets * "}" tex += num_unclosed_brackets * "}"
return tex return tex
def get_tex(self): def get_tex(self) -> str:
return self.tex_string return self.tex_string
def organize_submobjects_left_to_right(self): def organize_submobjects_left_to_right(self):
@ -169,7 +176,7 @@ class Tex(SingleStringTex):
"tex_to_color_map": {}, "tex_to_color_map": {},
} }
def __init__(self, *tex_strings, **kwargs): def __init__(self, *tex_strings: str, **kwargs):
digest_config(self, kwargs) digest_config(self, kwargs)
self.tex_strings = self.break_up_tex_strings(tex_strings) self.tex_strings = self.break_up_tex_strings(tex_strings)
full_string = self.arg_separator.join(self.tex_strings) full_string = self.arg_separator.join(self.tex_strings)
@ -180,7 +187,7 @@ class Tex(SingleStringTex):
if self.organize_left_to_right: if self.organize_left_to_right:
self.organize_submobjects_left_to_right() self.organize_submobjects_left_to_right()
def break_up_tex_strings(self, tex_strings): def break_up_tex_strings(self, tex_strings: Iterable[str]) -> Iterable[str]:
# Separate out any strings specified in the isolate # Separate out any strings specified in the isolate
# or tex_to_color_map lists. # or tex_to_color_map lists.
substrings_to_isolate = [*self.isolate, *self.tex_to_color_map.keys()] substrings_to_isolate = [*self.isolate, *self.tex_to_color_map.keys()]
@ -228,7 +235,12 @@ class Tex(SingleStringTex):
self.set_submobjects(new_submobjects) self.set_submobjects(new_submobjects)
return self return self
def get_parts_by_tex(self, tex, substring=True, case_sensitive=True): def get_parts_by_tex(
self,
tex: str,
substring: bool = True,
case_sensitive: bool = True
) -> VGroup:
def test(tex1, tex2): def test(tex1, tex2):
if not case_sensitive: if not case_sensitive:
tex1 = tex1.lower() tex1 = tex1.lower()
@ -243,27 +255,36 @@ class Tex(SingleStringTex):
self.submobjects self.submobjects
)) ))
def get_part_by_tex(self, tex, **kwargs): def get_part_by_tex(self, tex: str, **kwargs) -> SingleStringTex | None:
all_parts = self.get_parts_by_tex(tex, **kwargs) all_parts = self.get_parts_by_tex(tex, **kwargs)
return all_parts[0] if all_parts else None return all_parts[0] if all_parts else None
def set_color_by_tex(self, tex, color, **kwargs): def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs):
self.get_parts_by_tex(tex, **kwargs).set_color(color) self.get_parts_by_tex(tex, **kwargs).set_color(color)
return self return self
def set_color_by_tex_to_color_map(self, tex_to_color_map, **kwargs): def set_color_by_tex_to_color_map(
self,
tex_to_color_map: dict[str, ManimColor],
**kwargs
):
for tex, color in list(tex_to_color_map.items()): for tex, color in list(tex_to_color_map.items()):
self.set_color_by_tex(tex, color, **kwargs) self.set_color_by_tex(tex, color, **kwargs)
return self return self
def index_of_part(self, part, start=0): def index_of_part(self, part: SingleStringTex, start: int = 0) -> int:
return self.submobjects.index(part, start) return self.submobjects.index(part, start)
def index_of_part_by_tex(self, tex, start=0, **kwargs): def index_of_part_by_tex(self, tex: str, start: int = 0, **kwargs) -> int:
part = self.get_part_by_tex(tex, **kwargs) part = self.get_part_by_tex(tex, **kwargs)
return self.index_of_part(part, start) return self.index_of_part(part, start)
def slice_by_tex(self, start_tex=None, stop_tex=None, **kwargs): def slice_by_tex(
self,
start_tex: str | None = None,
stop_tex: str | None = None,
**kwargs
) -> VGroup:
if start_tex is None: if start_tex is None:
start_index = 0 start_index = 0
else: else:
@ -275,10 +296,10 @@ class Tex(SingleStringTex):
stop_index = self.index_of_part_by_tex(stop_tex, start=start_index, **kwargs) stop_index = self.index_of_part_by_tex(stop_tex, start=start_index, **kwargs)
return self[start_index:stop_index] return self[start_index:stop_index]
def sort_alphabetically(self): def sort_alphabetically(self) -> None:
self.submobjects.sort(key=lambda m: m.get_tex()) self.submobjects.sort(key=lambda m: m.get_tex())
def set_bstroke(self, color=BLACK, width=4): def set_bstroke(self, color: ManimColor = BLACK, width: float = 4):
self.set_stroke(color, width, background=True) self.set_stroke(color, width, background=True)
return self return self
@ -297,7 +318,7 @@ class BulletedList(TexText):
"alignment": "", "alignment": "",
} }
def __init__(self, *items, **kwargs): def __init__(self, *items: str, **kwargs):
line_separated_items = [s + "\\\\" for s in items] line_separated_items = [s + "\\\\" for s in items]
TexText.__init__(self, *line_separated_items, **kwargs) TexText.__init__(self, *line_separated_items, **kwargs)
for part in self: for part in self:
@ -310,7 +331,7 @@ class BulletedList(TexText):
buff=self.buff buff=self.buff
) )
def fade_all_but(self, index_or_string, opacity=0.5): def fade_all_but(self, index_or_string: int | str, opacity: float = 0.5) -> None:
arg = index_or_string arg = index_or_string
if isinstance(arg, str): if isinstance(arg, str):
part = self.get_part_by_tex(arg) part = self.get_part_by_tex(arg)
@ -348,7 +369,7 @@ class Title(TexText):
"underline_buff": MED_SMALL_BUFF, "underline_buff": MED_SMALL_BUFF,
} }
def __init__(self, *text_parts, **kwargs): def __init__(self, *text_parts: str, **kwargs):
TexText.__init__(self, *text_parts, **kwargs) TexText.__init__(self, *text_parts, **kwargs)
self.scale(self.scale_factor) self.scale(self.scale_factor)
self.to_edge(UP) self.to_edge(UP)

View file

@ -1,17 +1,19 @@
import hashlib from __future__ import annotations
import os import os
import re import re
import io import io
import typing import hashlib
import xml.etree.ElementTree as ET
import functools import functools
from pathlib import Path
import xml.etree.ElementTree as ET
from contextlib import contextmanager
from typing import Iterable, Sequence, Union
import pygments import pygments
import pygments.lexers import pygments.lexers
import pygments.styles import pygments.styles
from contextlib import contextmanager
from pathlib import Path
import manimpango import manimpango
from manimlib.logger import log from manimlib.logger import log
from manimlib.constants import * from manimlib.constants import *
@ -23,6 +25,12 @@ from manimlib.utils.customization import get_customization
from manimlib.utils.directories import get_downloads_dir, get_text_dir from manimlib.utils.directories import get_downloads_dir, get_text_dir
from manimpango import PangoUtils, TextSetting, MarkupUtils from manimpango import PangoUtils, TextSetting, MarkupUtils
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import colour
from manimlib.mobject.types.vectorized_mobject import VMobject
ManimColor = Union[str, colour.Color, Sequence[float]]
TEXT_MOB_SCALE_FACTOR = 0.0076 TEXT_MOB_SCALE_FACTOR = 0.0076
DEFAULT_LINE_SPACING_SCALE = 0.6 DEFAULT_LINE_SPACING_SCALE = 0.6
@ -50,7 +58,7 @@ class Text(SVGMobject):
"disable_ligatures": True, "disable_ligatures": True,
} }
def __init__(self, text, **kwargs): def __init__(self, text: str, **kwargs):
self.full2short(kwargs) self.full2short(kwargs)
digest_config(self, kwargs) digest_config(self, kwargs)
if self.size: if self.size:
@ -60,9 +68,9 @@ class Text(SVGMobject):
) )
self.font_size = self.size self.font_size = self.size
if self.lsh == -1: if self.lsh == -1:
self.lsh = self.font_size + self.font_size * DEFAULT_LINE_SPACING_SCALE self.lsh: float = self.font_size + self.font_size * DEFAULT_LINE_SPACING_SCALE
else: else:
self.lsh = self.font_size + self.font_size * self.lsh self.lsh: float = self.font_size + self.font_size * self.lsh
text_without_tabs = text text_without_tabs = text
if text.find('\t') != -1: if text.find('\t') != -1:
text_without_tabs = text.replace('\t', ' ' * self.tab_width) text_without_tabs = text.replace('\t', ' ' * self.tab_width)
@ -87,14 +95,14 @@ class Text(SVGMobject):
if self.height is None: if self.height is None:
self.scale(TEXT_MOB_SCALE_FACTOR) self.scale(TEXT_MOB_SCALE_FACTOR)
def remove_empty_path(self, file_name): def remove_empty_path(self, file_name: str) -> None:
with open(file_name, 'r') as fpr: with open(file_name, 'r') as fpr:
content = fpr.read() content = fpr.read()
content = re.sub(r'<path .*?d=""/>', '', content) content = re.sub(r'<path .*?d=""/>', '', content)
with open(file_name, 'w') as fpw: with open(file_name, 'w') as fpw:
fpw.write(content) fpw.write(content)
def apply_space_chars(self): def apply_space_chars(self) -> None:
submobs = self.submobjects.copy() submobs = self.submobjects.copy()
for char_index in range(len(self.text)): for char_index in range(len(self.text)):
if self.text[char_index] in [" ", "\t", "\n"]: if self.text[char_index] in [" ", "\t", "\n"]:
@ -103,7 +111,7 @@ class Text(SVGMobject):
submobs.insert(char_index, space) submobs.insert(char_index, space)
self.set_submobjects(submobs) self.set_submobjects(submobs)
def find_indexes(self, word): def find_indexes(self, word: str) -> list[tuple[int, int]]:
m = re.match(r'\[([0-9\-]{0,}):([0-9\-]{0,})\]', word) m = re.match(r'\[([0-9\-]{0,}):([0-9\-]{0,})\]', word)
if m: if m:
start = int(m.group(1)) if m.group(1) != '' else 0 start = int(m.group(1)) if m.group(1) != '' else 0
@ -119,20 +127,20 @@ class Text(SVGMobject):
index = self.text.find(word, index + len(word)) index = self.text.find(word, index + len(word))
return indexes return indexes
def get_parts_by_text(self, word): def get_parts_by_text(self, word: str) -> VGroup:
return VGroup(*( return VGroup(*(
self[i:j] self[i:j]
for i, j in self.find_indexes(word) for i, j in self.find_indexes(word)
)) ))
def get_part_by_text(self, word): def get_part_by_text(self, word: str) -> VMobject | None:
parts = self.get_parts_by_text(word) parts = self.get_parts_by_text(word)
if len(parts) > 0: if len(parts) > 0:
return parts[0] return parts[0]
else: else:
return None return None
def full2short(self, config): def full2short(self, config: dict[str]) -> None:
for kwargs in [config, self.CONFIG]: for kwargs in [config, self.CONFIG]:
if kwargs.__contains__('line_spacing_height'): if kwargs.__contains__('line_spacing_height'):
kwargs['lsh'] = kwargs.pop('line_spacing_height') kwargs['lsh'] = kwargs.pop('line_spacing_height')
@ -147,19 +155,25 @@ class Text(SVGMobject):
if kwargs.__contains__('text2weight'): if kwargs.__contains__('text2weight'):
kwargs['t2w'] = kwargs.pop('text2weight') kwargs['t2w'] = kwargs.pop('text2weight')
def set_color_by_t2c(self, t2c=None): def set_color_by_t2c(
self,
t2c: dict[str, ManimColor] | None = None
) -> None:
t2c = t2c if t2c else self.t2c t2c = t2c if t2c else self.t2c
for word, color in t2c.items(): for word, color in t2c.items():
for start, end in self.find_indexes(word): for start, end in self.find_indexes(word):
self[start:end].set_color(color) self[start:end].set_color(color)
def set_color_by_t2g(self, t2g=None): def set_color_by_t2g(
self,
t2g: dict[str, Iterable[ManimColor]] | None = None
) -> None:
t2g = t2g if t2g else self.t2g t2g = t2g if t2g else self.t2g
for word, gradient in t2g.items(): for word, gradient in t2g.items():
for start, end in self.find_indexes(word): for start, end in self.find_indexes(word):
self[start:end].set_color_by_gradient(*gradient) self[start:end].set_color_by_gradient(*gradient)
def text2hash(self): def text2hash(self) -> str:
settings = self.font + self.slant + self.weight settings = self.font + self.slant + self.weight
settings += str(self.t2f) + str(self.t2s) + str(self.t2w) settings += str(self.t2f) + str(self.t2s) + str(self.t2w)
settings += str(self.lsh) + str(self.font_size) settings += str(self.lsh) + str(self.font_size)
@ -168,7 +182,7 @@ class Text(SVGMobject):
hasher.update(id_str.encode()) hasher.update(id_str.encode())
return hasher.hexdigest()[:16] return hasher.hexdigest()[:16]
def text2settings(self): def text2settings(self) -> list[TextSetting]:
""" """
Substrings specified in t2f, t2s, t2w can occupy each other. Substrings specified in t2f, t2s, t2w can occupy each other.
For each category of style, a stack following first-in-last-out is constructed, For each category of style, a stack following first-in-last-out is constructed,
@ -227,7 +241,7 @@ class Text(SVGMobject):
del self.line_num del self.line_num
return settings return settings
def text2svg(self): def text2svg(self) -> str:
# anti-aliasing # anti-aliasing
size = self.font_size size = self.font_size
lsh = self.lsh lsh = self.lsh
@ -503,7 +517,7 @@ class Code(Text):
"char_width": None "char_width": None
} }
def __init__(self, code, **kwargs): def __init__(self, code: str, **kwargs):
self.full2short(kwargs) self.full2short(kwargs)
digest_config(self, kwargs) digest_config(self, kwargs)
code = code.lstrip("\n") # avoid mismatches of character indices code = code.lstrip("\n") # avoid mismatches of character indices
@ -536,7 +550,7 @@ class Code(Text):
if self.char_width is not None: if self.char_width is not None:
self.set_monospace(self.char_width) self.set_monospace(self.char_width)
def set_monospace(self, char_width): def set_monospace(self, char_width: float) -> None:
current_char_index = 0 current_char_index = 0
for i, char in enumerate(self.text): for i, char in enumerate(self.text):
if char == "\n": if char == "\n":
@ -548,7 +562,7 @@ class Code(Text):
@contextmanager @contextmanager
def register_font(font_file: typing.Union[str, Path]): def register_font(font_file: str | Path):
"""Temporarily add a font file to Pango's search path. """Temporarily add a font file to Pango's search path.
This searches for the font_file at various places. The order it searches it described below. This searches for the font_file at various places. The order it searches it described below.
1. Absolute path. 1. Absolute path.