Merge branch 'refactor' into master

This commit is contained in:
YishiMichael 2022-04-24 08:24:27 +08:00
parent 304cf88451
commit 30e33b1baa
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
6 changed files with 706 additions and 875 deletions

View file

@ -1,12 +1,12 @@
from __future__ import annotations from __future__ import annotations
import itertools as it from abc import ABC, abstractmethod
from abc import abstractmethod
import numpy as np import numpy as np
from manimlib.animation.animation import Animation from manimlib.animation.animation import Animation
from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.svg.labelled_string import LabelledString
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.bezier import integer_interpolate from manimlib.utils.bezier import integer_interpolate
from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import digest_config
@ -17,10 +17,10 @@ from manimlib.utils.rate_functions import smooth
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from manimlib.mobject.mobject import Group from manimlib.mobject.mobject import Mobject
class ShowPartial(Animation): class ShowPartial(Animation, ABC):
""" """
Abstract class for ShowCreation and ShowPassingFlash Abstract class for ShowCreation and ShowPassingFlash
""" """
@ -176,7 +176,7 @@ class ShowIncreasingSubsets(Animation):
"int_func": np.round, "int_func": np.round,
} }
def __init__(self, group: Group, **kwargs): def __init__(self, group: Mobject, **kwargs):
self.all_submobs = list(group.submobjects) self.all_submobs = list(group.submobjects)
super().__init__(group, **kwargs) super().__init__(group, **kwargs)
@ -213,7 +213,9 @@ class AddTextWordByWord(ShowIncreasingSubsets):
def __init__(self, string_mobject, **kwargs): def __init__(self, string_mobject, **kwargs):
assert isinstance(string_mobject, LabelledString) assert isinstance(string_mobject, LabelledString)
grouped_mobject = string_mobject.submob_groups grouped_mobject = VGroup(*[
part for _, part in string_mobject.get_group_part_items()
])
digest_config(self, kwargs) digest_config(self, kwargs)
if self.run_time is None: if self.run_time is None:
self.run_time = self.time_per_word * len(grouped_mobject) self.run_time = self.time_per_word * len(grouped_mobject)

View file

@ -5,9 +5,9 @@ import itertools as it
import numpy as np import numpy as np
from manimlib.animation.composition import AnimationGroup from manimlib.animation.composition import AnimationGroup
from manimlib.animation.fading import FadeTransformPieces
from manimlib.animation.fading import FadeInFromPoint from manimlib.animation.fading import FadeInFromPoint
from manimlib.animation.fading import FadeOutToPoint from manimlib.animation.fading import FadeOutToPoint
from manimlib.animation.fading import FadeTransformPieces
from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import ReplacementTransform
from manimlib.animation.transform import Transform from manimlib.animation.transform import Transform
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
@ -16,13 +16,13 @@ from manimlib.mobject.svg.labelled_string import LabelledString
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import remove_list_redundancies
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from manimlib.mobject.svg.tex_mobject import SingleStringTex
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.scene.scene import Scene from manimlib.scene.scene import Scene
from manimlib.mobject.svg.tex_mobject import Tex, SingleStringTex
class TransformMatchingParts(AnimationGroup): class TransformMatchingParts(AnimationGroup):
@ -168,36 +168,36 @@ class TransformMatchingStrings(AnimationGroup):
assert isinstance(source, LabelledString) assert isinstance(source, LabelledString)
assert isinstance(target, LabelledString) assert isinstance(target, LabelledString)
anims = [] anims = []
source_indices = list(range(len(source.labelled_submobjects)))
target_indices = list(range(len(target.labelled_submobjects)))
def get_indices_lists(mobject, parts): source_submobs = [
return [ submob for _, submob in source.labelled_submobject_items
]
target_submobs = [
submob for _, submob in target.labelled_submobject_items
]
source_indices = list(range(len(source_submobs)))
target_indices = list(range(len(target_submobs)))
def get_filtered_indices_lists(parts, submobs, rest_indices):
return list(filter(
lambda indices_list: all([
index in rest_indices
for index in indices_list
]),
[ [
mobject.labelled_submobjects.index(submob) [submobs.index(submob) for submob in part]
for submob in part for part in parts
] ]
for part in parts ))
]
def add_anims_from(anim_class, func, source_args, target_args=None): def add_anims(anim_class, parts_pairs):
if target_args is None: for source_parts, target_parts in parts_pairs:
target_args = source_args.copy() source_indices_lists = get_filtered_indices_lists(
for source_arg, target_arg in zip(source_args, target_args): source_parts, source_submobs, source_indices
source_parts = func(source, source_arg) )
target_parts = func(target, target_arg) target_indices_lists = get_filtered_indices_lists(
source_indices_lists = list(filter( target_parts, target_submobs, target_indices
lambda indices_list: all([ )
index in source_indices
for index in indices_list
]), get_indices_lists(source, source_parts)
))
target_indices_lists = list(filter(
lambda indices_list: all([
index in target_indices
for index in indices_list
]), get_indices_lists(target, target_parts)
))
if not source_indices_lists or not target_indices_lists: if not source_indices_lists or not target_indices_lists:
continue continue
anims.append(anim_class(source_parts, target_parts, **kwargs)) anims.append(anim_class(source_parts, target_parts, **kwargs))
@ -206,41 +206,45 @@ class TransformMatchingStrings(AnimationGroup):
for index in it.chain(*target_indices_lists): for index in it.chain(*target_indices_lists):
target_indices.remove(index) target_indices.remove(index)
def get_common_substrs(substrs_from_source, substrs_from_target): def get_substr_to_parts_map(part_items):
return sorted([ result = {}
substr for substr in substrs_from_source for substr, part in part_items:
if substr and substr in substrs_from_target if substr not in result:
], key=len, reverse=True) result[substr] = []
result[substr].append(part)
def get_parts_from_keys(mobject, keys):
if isinstance(keys, str):
keys = [keys]
result = VGroup()
for key in keys:
if not isinstance(key, str):
raise TypeError(key)
result.add(*mobject.get_parts_by_string(key))
return result return result
add_anims_from( def add_anims_from(anim_class, func):
ReplacementTransform, get_parts_from_keys, source_substr_to_parts_map = get_substr_to_parts_map(func(source))
self.key_map.keys(), self.key_map.values() target_substr_to_parts_map = get_substr_to_parts_map(func(target))
add_anims(
anim_class,
[
(
VGroup(*source_substr_to_parts_map[substr]),
VGroup(*target_substr_to_parts_map[substr])
)
for substr in sorted([
s for s in source_substr_to_parts_map.keys()
if s and s in target_substr_to_parts_map.keys()
], key=len, reverse=True)
]
)
add_anims(
ReplacementTransform,
[
(source.select_parts(k), target.select_parts(v))
for k, v in self.key_map.items()
]
) )
add_anims_from( add_anims_from(
FadeTransformPieces, FadeTransformPieces,
LabelledString.get_parts_by_string, LabelledString.get_specified_part_items
get_common_substrs(
source.specified_substrs,
target.specified_substrs
)
) )
add_anims_from( add_anims_from(
FadeTransformPieces, FadeTransformPieces,
LabelledString.get_parts_by_group_substr, LabelledString.get_group_part_items
get_common_substrs(
source.group_substrs,
target.group_substrs
)
) )
rest_source = VGroup(*[source[index] for index in source_indices]) rest_source = VGroup(*[source[index] for index in source_indices])

View file

@ -1,30 +1,43 @@
from __future__ import annotations from __future__ import annotations
import re
import colour
import itertools as it
from typing import Iterable, Union, Sequence
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import itertools as it
import numpy as np
import re
from manimlib.constants import BLACK, WHITE from manimlib.constants import WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_int_rgb
from manimlib.utils.color import color_to_rgb from manimlib.utils.color import color_to_rgb
from manimlib.utils.color import rgb_to_hex from manimlib.utils.color import rgb_to_hex
from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import remove_list_redundancies from manimlib.utils.iterables import remove_list_redundancies
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from manimlib.mobject.types.vectorized_mobject import VMobject from colour import Color
ManimColor = Union[str, colour.Color, Sequence[float]] from typing import Iterable, Sequence, TypeVar, Union
ManimColor = Union[str, Color]
Span = tuple[int, int] Span = tuple[int, int]
Selector = Union[
str,
re.Pattern,
tuple[Union[int, None], Union[int, None]],
Iterable[Union[
str,
re.Pattern,
tuple[Union[int, None], Union[int, None]]
]]
]
T = TypeVar("T")
class _StringSVG(SVGMobject): class LabelledString(SVGMobject, ABC):
"""
An abstract base class for `MTex` and `MarkupText`
"""
CONFIG = { CONFIG = {
"height": None, "height": None,
"stroke_width": 0, "stroke_width": 0,
@ -33,42 +46,30 @@ class _StringSVG(SVGMobject):
"should_subdivide_sharp_curves": True, "should_subdivide_sharp_curves": True,
"should_remove_null_curves": True, "should_remove_null_curves": True,
}, },
}
class LabelledString(_StringSVG, ABC):
"""
An abstract base class for `MTex` and `MarkupText`
"""
CONFIG = {
"base_color": WHITE, "base_color": WHITE,
"use_plain_file": False,
"isolate": [], "isolate": [],
} }
def __init__(self, string: str, **kwargs): def __init__(self, string: str, **kwargs):
self.string = string self.string = string
digest_config(self, kwargs) digest_config(self, kwargs)
if self.base_color is None:
self.base_color = WHITE
self.base_color_int = self.color_to_int(self.base_color)
# Convert `base_color` to hex code. self.full_span = (0, len(self.string))
self.base_color = rgb_to_hex(color_to_rgb(
self.base_color \
or self.svg_default.get("color", None) \
or self.svg_default.get("fill_color", None) \
or WHITE
))
self.svg_default["fill_color"] = BLACK
self.pre_parse()
self.parse() self.parse()
super().__init__() super().__init__(**kwargs)
self.post_parse() self.labelled_submobject_items = [
(submob.label, submob)
for submob in self.submobjects
]
def get_file_path(self) -> str: def get_file_path(self) -> str:
return self.get_file_path_(use_plain_file=False) return self.get_file_path_(is_labelled=False)
def get_file_path_(self, use_plain_file: bool) -> str: def get_file_path_(self, is_labelled: bool) -> str:
content = self.get_content(use_plain_file) content = self.get_content(is_labelled)
return self.get_file_path_by_content(content) return self.get_file_path_by_content(content)
@abstractmethod @abstractmethod
@ -78,91 +79,135 @@ class LabelledString(_StringSVG, ABC):
def generate_mobject(self) -> None: def generate_mobject(self) -> None:
super().generate_mobject() super().generate_mobject()
submob_labels = [ num_labels = len(self.label_span_list)
self.color_to_label(submob.get_fill_color()) if num_labels:
for submob in self.submobjects file_path = self.get_file_path_(is_labelled=True)
] labelled_svg = SVGMobject(file_path)
if self.use_plain_file or self.has_predefined_local_colors: submob_color_ints = [
file_path = self.get_file_path_(use_plain_file=True) self.color_to_int(submob.get_fill_color())
plain_svg = _StringSVG( for submob in labelled_svg.submobjects
file_path, ]
svg_default=self.svg_default,
path_string_config=self.path_string_config
)
self.set_submobjects(plain_svg.submobjects)
else: else:
self.set_fill(self.base_color) submob_color_ints = [0] * len(self.submobjects)
for submob, label in zip(self.submobjects, submob_labels):
submob.label = label
def pre_parse(self) -> None: if len(self.submobjects) != len(submob_color_ints):
self.string_len = len(self.string) raise ValueError(
self.full_span = (0, self.string_len) "Cannot align submobjects of the labelled svg "
"to the original svg"
)
unrecognized_color_ints = self.remove_redundancies(sorted(filter(
lambda color_int: color_int > num_labels,
submob_color_ints
)))
if unrecognized_color_ints:
raise ValueError(
"Unrecognized color label(s) detected: "
f"{','.join(map(self.int_to_hex, unrecognized_color_ints))}"
)
for submob, color_int in zip(self.submobjects, submob_color_ints):
submob.label = color_int - 1
def parse(self) -> None: def parse(self) -> None:
self.command_repl_items = self.get_command_repl_items() self.command_repl_items = self.get_command_repl_items()
self.command_spans = self.get_command_spans()
self.extra_entity_spans = self.get_extra_entity_spans()
self.entity_spans = self.get_entity_spans()
self.extra_ignored_spans = self.get_extra_ignored_spans()
self.skipped_spans = self.get_skipped_spans()
self.internal_specified_spans = self.get_internal_specified_spans()
self.external_specified_spans = self.get_external_specified_spans()
self.specified_spans = self.get_specified_spans() self.specified_spans = self.get_specified_spans()
self.label_span_list = self.get_label_span_list()
self.check_overlapping() self.check_overlapping()
self.label_span_list = self.get_label_span_list()
def post_parse(self) -> None: if len(self.label_span_list) >= 16777216:
self.labelled_submobject_items = [ raise ValueError("Cannot handle that many substrings")
(submob.label, submob)
for submob in self.submobjects
]
self.labelled_submobjects = self.get_labelled_submobjects()
self.specified_substrs = self.get_specified_substrs()
self.group_items = self.get_group_items()
self.group_substrs = self.get_group_substrs()
self.submob_groups = self.get_submob_groups()
# Toolkits # Toolkits
def get_substr(self, span: Span) -> str: def get_substr(self, span: Span) -> str:
return self.string[slice(*span)] return self.string[slice(*span)]
def finditer( def match(self, pattern: str | re.Pattern, **kwargs) -> re.Pattern | None:
self, pattern: str, flags: int = 0, **kwargs if isinstance(pattern, str):
) -> Iterable[re.Match]: pattern = re.compile(pattern)
return re.compile(pattern, flags).finditer(self.string, **kwargs) return re.compile(pattern).match(self.string, **kwargs)
def search( def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]:
self, pattern: str, flags: int = 0, **kwargs if isinstance(pattern, str):
) -> re.Match | None: pattern = re.compile(pattern)
return re.compile(pattern, flags).search(self.string, **kwargs)
def match(
self, pattern: str, flags: int = 0, **kwargs
) -> re.Match | None:
return re.compile(pattern, flags).match(self.string, **kwargs)
def find_spans(self, pattern: str, **kwargs) -> list[Span]:
return [ return [
match_obj.span() match_obj.span()
for match_obj in self.finditer(pattern, **kwargs) for match_obj in pattern.finditer(self.string, **kwargs)
] ]
def find_substr(self, substr: str, **kwargs) -> list[Span]: def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]:
if not substr: return [index for index, _ in self.find_spans(pattern, **kwargs)]
return []
return self.find_spans(re.escape(substr), **kwargs)
def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]:
return list(it.chain(*[
self.find_substr(substr, **kwargs)
for substr in remove_list_redundancies(substrs)
]))
@staticmethod @staticmethod
def get_neighbouring_pairs(iterable: list) -> list[tuple]: def is_single_selector(selector: Selector) -> bool:
return list(zip(iterable[:-1], iterable[1:])) if isinstance(selector, str):
return True
if isinstance(selector, re.Pattern):
return True
if isinstance(selector, tuple):
if len(selector) == 2 and all([
isinstance(index, int) or index is None
for index in selector
]):
return True
return False
def find_spans_by_selector(self, selector: Selector) -> list[Span]:
if self.is_single_selector(selector):
selector = (selector,)
result = []
for sel in selector:
if not self.is_single_selector(sel):
raise TypeError(f"Invalid selector: '{sel}'")
if isinstance(sel, str):
spans = self.find_spans(re.escape(sel))
elif isinstance(sel, re.Pattern):
spans = self.find_spans(sel)
else:
string_len = self.full_span[1]
span = tuple([
(
min(index, string_len)
if index >= 0
else max(index + string_len, 0)
)
if index is not None else default_index
for index, default_index in zip(sel, self.full_span)
])
spans = [span]
result.extend(spans)
return sorted(filter(
lambda span: span[0] < span[1],
self.remove_redundancies(result)
))
@staticmethod
def chain(*iterables: Iterable[T]) -> list[T]:
return list(it.chain(*iterables))
@staticmethod
def remove_redundancies(vals: Sequence[T]) -> list[T]:
return remove_list_redundancies(vals)
@staticmethod
def get_neighbouring_pairs(vals: Sequence[T]) -> list[tuple[T, T]]:
return list(zip(vals[:-1], vals[1:]))
@staticmethod
def compress_neighbours(vals: Sequence[T]) -> list[tuple[T, Span]]:
if not vals:
return []
unique_vals = [vals[0]]
indices = [0]
for index, val in enumerate(vals):
if val == unique_vals[-1]:
continue
unique_vals.append(val)
indices.append(index)
indices.append(len(vals))
spans = LabelledString.get_neighbouring_pairs(indices)
return list(zip(unique_vals, spans))
@staticmethod @staticmethod
def span_contains(span_0: Span, span_1: Span) -> bool: def span_contains(span_0: Span, span_1: Span) -> bool:
@ -182,194 +227,88 @@ class LabelledString(_StringSVG, ABC):
)) ))
@staticmethod @staticmethod
def compress_neighbours(vals: list[int]) -> list[tuple[int, Span]]: def merge_inserted_strings_from_pairs(
if not vals: inserted_string_pairs: list[tuple[Span, tuple[str, str]]]
) -> list[tuple[int, str]]:
if not inserted_string_pairs:
return [] return []
unique_vals = [vals[0]] spans = [
indices = [0] span for span, _ in inserted_string_pairs
for index, val in enumerate(vals): ]
if val == unique_vals[-1]: sorted_index_flag_pairs = sorted(
continue it.product(range(len(spans)), range(2)),
unique_vals.append(val) key=lambda t: (
indices.append(index) spans[t[0]][t[1]],
indices.append(len(vals)) np.sign(spans[t[0]][1 - t[1]] - spans[t[0]][t[1]]),
spans = LabelledString.get_neighbouring_pairs(indices) -spans[t[0]][1 - t[1]],
return list(zip(unique_vals, spans)) t[1],
(1, -1)[t[1]] * t[0]
@staticmethod
def find_region_index(seq: list[int], val: int) -> int:
# Returns an integer in `range(-1, len(seq))` satisfying
# `seq[result] <= val < seq[result + 1]`.
# `seq` should be sorted in ascending order.
if not seq or val < seq[0]:
return -1
result = len(seq) - 1
while val < seq[result]:
result -= 1
return result
@staticmethod
def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int:
sorted_seq = sorted(seq)
index = LabelledString.find_region_index(sorted_seq, val)
return sorted_seq[index + index_shift]
@staticmethod
def generate_span_repl_dict(
inserted_string_pairs: list[tuple[Span, tuple[str, str]]],
other_repl_items: list[tuple[Span, str]]
) -> dict[Span, str]:
result = dict(other_repl_items)
if not inserted_string_pairs:
return result
indices, _, _, inserted_strings = zip(*sorted([
(
span[flag],
-flag,
-span[1 - flag],
str_pair[flag]
) )
for span, str_pair in inserted_string_pairs )
for flag in range(2) indices, inserted_strings = zip(*[
])) list(zip(*inserted_string_pairs[item_index]))[flag]
result.update({ for item_index, flag in sorted_index_flag_pairs
(index, index): "".join(inserted_strings[slice(*item_span)]) ])
return [
(index, "".join(inserted_strings[slice(*item_span)]))
for index, item_span for index, item_span
in LabelledString.compress_neighbours(indices) in LabelledString.compress_neighbours(indices)
}) ]
return result
def get_replaced_substr( def get_replaced_substr(
self, span: Span, span_repl_dict: dict[Span, str] self, span: Span, repl_items: list[tuple[Span, str]]
) -> str: ) -> str:
repl_spans = sorted(filter( if not repl_items:
lambda repl_span: self.span_contains(span, repl_span), return self.get_substr(span)
span_repl_dict.keys()
))
if not all(
span_0[1] <= span_1[0]
for span_0, span_1 in self.get_neighbouring_pairs(repl_spans)
):
raise ValueError("Overlapping replacement")
sorted_repl_items = sorted(repl_items, key=lambda t: t[0])
repl_spans, repl_strs = zip(*sorted_repl_items)
pieces = [ pieces = [
self.get_substr(piece_span) self.get_substr(piece_span)
for piece_span in self.get_complement_spans(repl_spans, span) for piece_span in self.get_complement_spans(repl_spans, span)
] ]
repl_strs = [span_repl_dict[repl_span] for repl_span in repl_spans] repl_strs = [*repl_strs, ""]
repl_strs.append("") return "".join(self.chain(*zip(pieces, repl_strs)))
return "".join(it.chain(*zip(pieces, repl_strs)))
def get_replaced_string(
self,
inserted_string_pairs: list[tuple[Span, tuple[str, str]]],
repl_items: list[tuple[Span, str]]
) -> str:
all_repl_items = self.chain(
repl_items,
[
((index, index), inserted_string)
for index, inserted_string
in self.merge_inserted_strings_from_pairs(
inserted_string_pairs
)
]
)
return self.get_replaced_substr(self.full_span, all_repl_items)
@staticmethod @staticmethod
def rslide(index: int, skipped: list[Span]) -> int: def color_to_int(color: ManimColor) -> int:
transfer_dict = dict(sorted(skipped)) hex_code = rgb_to_hex(color_to_rgb(color))
while index in transfer_dict.keys(): return int(hex_code[1:], 16)
index = transfer_dict[index]
return index
@staticmethod
def lslide(index: int, skipped: list[Span]) -> int:
transfer_dict = dict(sorted([
skipped_span[::-1] for skipped_span in skipped
], reverse=True))
while index in transfer_dict.keys():
index = transfer_dict[index]
return index
@staticmethod
def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int:
r, g, b = rgb_tuple
rg = r * 256 + g
return rg * 256 + b
@staticmethod
def int_to_rgb(rgb_int: int) -> tuple[int, int, int]:
rg, b = divmod(rgb_int, 256)
r, g = divmod(rg, 256)
return r, g, b
@staticmethod @staticmethod
def int_to_hex(rgb_int: int) -> str: def int_to_hex(rgb_int: int) -> str:
return "#{:06x}".format(rgb_int).upper() return "#{:06x}".format(rgb_int).upper()
@staticmethod
def hex_to_int(rgb_hex: str) -> int:
return int(rgb_hex[1:], 16)
@staticmethod
def color_to_label(color: ManimColor) -> int:
rgb_tuple = color_to_int_rgb(color)
rgb = LabelledString.rgb_to_int(rgb_tuple)
return rgb - 1
# Parsing # Parsing
@abstractmethod @abstractmethod
def get_command_repl_items(self) -> list[tuple[Span, str]]: def get_command_repl_items(self) -> list[tuple[Span, str]]:
return [] return []
def get_command_spans(self) -> list[Span]:
return [cmd_span for cmd_span, _ in self.command_repl_items]
@abstractmethod @abstractmethod
def get_extra_entity_spans(self) -> list[Span]:
return []
def get_entity_spans(self) -> list[Span]:
return list(it.chain(
self.command_spans,
self.extra_entity_spans
))
@abstractmethod
def get_extra_ignored_spans(self) -> list[int]:
return []
def get_skipped_spans(self) -> list[Span]:
return list(it.chain(
self.find_spans(r"\s"),
self.command_spans,
self.extra_ignored_spans
))
def shrink_span(self, span: Span) -> Span:
return (
self.rslide(span[0], self.skipped_spans),
self.lslide(span[1], self.skipped_spans)
)
@abstractmethod
def get_internal_specified_spans(self) -> list[Span]:
return []
@abstractmethod
def get_external_specified_spans(self) -> list[Span]:
return []
def get_specified_spans(self) -> list[Span]: def get_specified_spans(self) -> list[Span]:
spans = list(it.chain(
self.internal_specified_spans,
self.external_specified_spans,
self.find_substrs(self.isolate)
))
shrinked_spans = list(filter(
lambda span: span[0] < span[1] and not any([
entity_span[0] < index < entity_span[1]
for index in span
for entity_span in self.entity_spans
]),
[self.shrink_span(span) for span in spans]
))
return remove_list_redundancies(shrinked_spans)
@abstractmethod
def get_label_span_list(self) -> list[Span]:
return [] return []
def check_overlapping(self) -> None: def check_overlapping(self) -> None:
for span_0, span_1 in it.product(self.label_span_list, repeat=2): for span_0, span_1 in it.product(self.specified_spans, repeat=2):
if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: if not span_0[0] < span_1[0] < span_0[1] < span_1[1]:
continue continue
raise ValueError( raise ValueError(
@ -378,29 +317,20 @@ class LabelledString(_StringSVG, ABC):
) )
@abstractmethod @abstractmethod
def get_content(self, use_plain_file: bool) -> str: def get_label_span_list(self) -> list[Span]:
return "" return []
@abstractmethod @abstractmethod
def has_predefined_local_colors(self) -> bool: def get_content(self, is_labelled: bool) -> str:
return False return ""
# Post-parsing # Selector
def get_labelled_submobjects(self) -> list[VMobject]:
return [submob for _, submob in self.labelled_submobject_items]
@abstractmethod
def get_cleaned_substr(self, span: Span) -> str: def get_cleaned_substr(self, span: Span) -> str:
span_repl_dict = dict.fromkeys(self.command_spans, "") return ""
return self.get_replaced_substr(span, span_repl_dict)
def get_specified_substrs(self) -> list[str]: def get_group_part_items(self) -> list[tuple[str, VGroup]]:
return remove_list_redundancies([
self.get_cleaned_substr(span)
for span in self.specified_spans
])
def get_group_items(self) -> list[tuple[str, VGroup]]:
if not self.labelled_submobject_items: if not self.labelled_submobject_items:
return [] return []
@ -425,118 +355,56 @@ class LabelledString(_StringSVG, ABC):
ordered_spans ordered_spans
) )
] ]
shrinked_spans = [ group_substrs = [
self.shrink_span(span) self.get_cleaned_substr(span) if span[0] < span[1] else ""
for span in self.get_complement_spans( for span in self.get_complement_spans(
interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) interval_spans, (ordered_spans[0][0], ordered_spans[-1][1])
) )
] ]
group_substrs = [
self.get_cleaned_substr(span) if span[0] < span[1] else ""
for span in shrinked_spans
]
submob_groups = VGroup(*[ submob_groups = VGroup(*[
VGroup(*labelled_submobjects[slice(*submob_span)]) VGroup(*labelled_submobjects[slice(*submob_span)])
for submob_span in labelled_submob_spans for submob_span in labelled_submob_spans
]) ])
return list(zip(group_substrs, submob_groups)) return list(zip(group_substrs, submob_groups))
def get_group_substrs(self) -> list[str]: def get_specified_part_items(self) -> list[tuple[str, VGroup]]:
return [group_substr for group_substr, _ in self.group_items] return [
(
def get_submob_groups(self) -> list[VGroup]: self.get_substr(span),
return [submob_group for _, submob_group in self.group_items] self.select_part_by_span(span)
def get_parts_by_group_substr(self, substr: str) -> VGroup:
return VGroup(*[
group
for group_substr, group in self.group_items
if group_substr == substr
])
# Selector
def find_span_components(
self, custom_span: Span, substring: bool = True
) -> list[Span]:
shrinked_span = self.shrink_span(custom_span)
if shrinked_span[0] >= shrinked_span[1]:
return []
if substring:
indices = remove_list_redundancies(list(it.chain(
self.full_span,
*self.label_span_list
)))
span_begin = self.take_nearest_value(
indices, shrinked_span[0], 0
) )
span_end = self.take_nearest_value( for span in self.specified_spans
indices, shrinked_span[1] - 1, 1 ]
)
else:
span_begin, span_end = shrinked_span
span_choices = sorted(filter( def select_part_by_span(self, custom_span: Span) -> VGroup:
lambda span: self.span_contains((span_begin, span_end), span),
self.label_span_list
))
# Choose spans that reach the farthest.
span_choices_dict = dict(span_choices)
result = []
while span_begin < span_end:
if span_begin not in span_choices_dict.keys():
span_begin += 1
continue
next_begin = span_choices_dict[span_begin]
result.append((span_begin, next_begin))
span_begin = next_begin
return result
def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup:
labels = [ labels = [
label for label, span in enumerate(self.label_span_list) label for label, span in enumerate(self.label_span_list)
if any([ if self.span_contains(custom_span, span)
self.span_contains(span_component, span)
for span_component in self.find_span_components(
custom_span, **kwargs
)
])
] ]
return VGroup(*[ return VGroup(*[
submob for label, submob in self.labelled_submobject_items submob for label, submob in self.labelled_submobject_items
if label in labels if label in labels
]) ])
def get_parts_by_string( def select_parts(self, selector: Selector) -> VGroup:
self, substr: str, return VGroup(*filter(
case_sensitive: bool = True, regex: bool = False, **kwargs lambda part: part.submobjects,
) -> VGroup: [
flags = 0 self.select_part_by_span(span)
if not case_sensitive: for span in self.find_spans_by_selector(selector)
flags |= re.I ]
pattern = substr if regex else re.escape(substr) ))
return VGroup(*[
self.get_part_by_custom_span(span, **kwargs)
for span in self.find_spans(pattern, flags=flags)
if span[0] < span[1]
])
def get_part_by_string( def select_part(self, selector: Selector, index: int = 0) -> VGroup:
self, substr: str, index: int = 0, **kwargs return self.select_parts(selector)[index]
) -> VMobject:
return self.get_parts_by_string(substr, **kwargs)[index]
def set_color_by_string(self, substr: str, color: ManimColor, **kwargs): def set_parts_color(self, selector: Selector, color: ManimColor):
self.get_parts_by_string(substr, **kwargs).set_color(color) self.select_parts(selector).set_color(color)
return self return self
def set_color_by_string_to_color_map( def set_parts_color_by_dict(self, color_map: dict[Selector, ManimColor]):
self, string_to_color_map: dict[str, ManimColor], **kwargs for selector, color in color_map.items():
): self.set_parts_color(selector, color)
for substr, color in string_to_color_map.items():
self.set_color_by_string(substr, color, **kwargs)
return self return self
def get_string(self) -> str: def get_string(self) -> str:

View file

@ -1,27 +1,46 @@
from __future__ import annotations from __future__ import annotations
import itertools as it import re
import colour
from typing import Union, Sequence
from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.svg.labelled_string import LabelledString
from manimlib.utils.tex_file_writing import tex_to_svg_file
from manimlib.utils.tex_file_writing import get_tex_config
from manimlib.utils.tex_file_writing import display_during_execution from manimlib.utils.tex_file_writing import display_during_execution
from manimlib.utils.tex_file_writing import get_tex_config
from manimlib.utils.tex_file_writing import tex_to_svg_file
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from manimlib.mobject.types.vectorized_mobject import VMobject from colour import Color
from typing import Iterable, Union
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
ManimColor = Union[str, colour.Color, Sequence[float]]
ManimColor = Union[str, Color]
Span = tuple[int, int] Span = tuple[int, int]
Selector = Union[
str,
re.Pattern,
tuple[Union[int, None], Union[int, None]],
Iterable[Union[
str,
re.Pattern,
tuple[Union[int, None], Union[int, None]]
]]
]
SCALE_FACTOR_PER_FONT_POINT = 0.001 SCALE_FACTOR_PER_FONT_POINT = 0.001
TEX_COLOR_COMMANDS_DICT = {
"\\color": (1, False),
"\\textcolor": (1, False),
"\\pagecolor": (1, True),
"\\colorbox": (1, True),
"\\fcolorbox": (2, True),
}
class MTex(LabelledString): class MTex(LabelledString):
CONFIG = { CONFIG = {
"font_size": 48, "font_size": 48,
@ -32,7 +51,7 @@ class MTex(LabelledString):
def __init__(self, tex_string: str, **kwargs): def __init__(self, tex_string: str, **kwargs):
# Prevent from passing an empty string. # Prevent from passing an empty string.
if not tex_string: if not tex_string.strip():
tex_string = "\\\\" tex_string = "\\\\"
self.tex_string = tex_string self.tex_string = tex_string
super().__init__(tex_string, **kwargs) super().__init__(tex_string, **kwargs)
@ -47,7 +66,6 @@ class MTex(LabelledString):
self.svg_default, self.svg_default,
self.path_string_config, self.path_string_config,
self.base_color, self.base_color,
self.use_plain_file,
self.isolate, self.isolate,
self.tex_string, self.tex_string,
self.alignment, self.alignment,
@ -61,85 +79,95 @@ class MTex(LabelledString):
tex_config["text_to_replace"], tex_config["text_to_replace"],
content content
) )
with display_during_execution(f"Writing \"{self.tex_string}\""): with display_during_execution(f"Writing \"{self.string}\""):
file_path = tex_to_svg_file(full_tex) file_path = tex_to_svg_file(full_tex)
return file_path return file_path
def pre_parse(self) -> None: def parse(self) -> None:
super().pre_parse()
self.backslash_indices = self.get_backslash_indices() self.backslash_indices = self.get_backslash_indices()
self.brace_index_pairs = self.get_brace_index_pairs() self.command_spans = self.get_command_spans()
self.script_char_spans = self.get_script_char_spans() self.brace_spans = self.get_brace_spans()
self.script_char_indices = self.get_script_char_indices()
self.script_content_spans = self.get_script_content_spans() self.script_content_spans = self.get_script_content_spans()
self.script_spans = self.get_script_spans() self.script_spans = self.get_script_spans()
super().parse()
# Toolkits # Toolkits
@staticmethod @staticmethod
def get_color_command_str(rgb_int: int) -> str: def get_color_command_str(rgb_int: int) -> str:
rgb_tuple = MTex.int_to_rgb(rgb_int) rg, b = divmod(rgb_int, 256)
return "".join([ r, g = divmod(rg, 256)
"\\color[RGB]", return f"\\color[RGB]{{{r}, {g}, {b}}}"
"{",
",".join(map(str, rgb_tuple)),
"}"
])
# Pre-parsing @staticmethod
def shrink_span(span: Span, skippable_indices: list[int]) -> Span:
span_begin, span_end = span
while span_begin in skippable_indices:
span_begin += 1
while span_end - 1 in skippable_indices:
span_end -= 1
return (span_begin, span_end)
# Parsing
def get_backslash_indices(self) -> list[int]: def get_backslash_indices(self) -> list[int]:
# The latter of `\\` doesn't count. # The latter of `\\` doesn't count.
return list(it.chain(*[ return self.find_indices(r"\\.")
range(span[0], span[1], 2)
for span in self.find_spans(r"\\+")
]))
def get_unescaped_char_spans(self, chars: str): def get_command_spans(self) -> list[Span]:
return sorted(filter( return [
lambda span: span[0] - 1 not in self.backslash_indices, self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span()
self.find_substrs(list(chars)) for index in self.backslash_indices
]
def get_unescaped_char_indices(self, char: str) -> list[int]:
return list(filter(
lambda index: index - 1 not in self.backslash_indices,
self.find_indices(re.escape(char))
)) ))
def get_brace_index_pairs(self) -> list[Span]: def get_brace_spans(self) -> list[Span]:
left_brace_indices = [] span_begins = []
right_brace_indices = [] span_ends = []
left_brace_indices_stack = [] span_begins_stack = []
for span in self.get_unescaped_char_spans("{}"): char_items = sorted([
index = span[0] (index, char)
if self.get_substr(span) == "{": for char in "{}"
left_brace_indices_stack.append(index) for index in self.get_unescaped_char_indices(char)
])
for index, char in char_items:
if char == "{":
span_begins_stack.append(index)
else: else:
if not left_brace_indices_stack: if not span_begins_stack:
raise ValueError("Missing '{' inserted") raise ValueError("Missing '{' inserted")
left_brace_index = left_brace_indices_stack.pop() span_begins.append(span_begins_stack.pop())
left_brace_indices.append(left_brace_index) span_ends.append(index + 1)
right_brace_indices.append(index) if span_begins_stack:
if left_brace_indices_stack:
raise ValueError("Missing '}' inserted") raise ValueError("Missing '}' inserted")
return list(zip(left_brace_indices, right_brace_indices)) return list(zip(span_begins, span_ends))
def get_script_char_spans(self) -> list[int]: def get_script_char_indices(self) -> list[int]:
return self.get_unescaped_char_spans("_^") return self.chain(*[
self.get_unescaped_char_indices(char)
for char in "_^"
])
def get_script_content_spans(self) -> list[Span]: def get_script_content_spans(self) -> list[Span]:
result = [] result = []
brace_indices_dict = dict(self.brace_index_pairs) script_entity_dict = dict(self.chain(
script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+" self.brace_spans,
for script_char_span in self.script_char_spans: self.command_spans
span_begin = self.match(r"\s*", pos=script_char_span[1]).end() ))
if span_begin in brace_indices_dict.keys(): for index in self.script_char_indices:
span_end = brace_indices_dict[span_begin] + 1 span_begin = self.match(r"\s*", pos=index + 1).end()
if span_begin in script_entity_dict.keys():
span_end = script_entity_dict[span_begin]
else: else:
match_obj = self.match(script_pattern, pos=span_begin) match_obj = self.match(r".", pos=span_begin)
if not match_obj: if match_obj is None:
script_name = { continue
"_": "subscript",
"^": "superscript"
}[script_char]
raise ValueError(
f"Unclear {script_name} detected while parsing. "
"Please use braces to clarify"
)
span_end = match_obj.end() span_end = match_obj.end()
result.append((span_begin, span_end)) result.append((span_begin, span_end))
return result return result
@ -147,110 +175,102 @@ class MTex(LabelledString):
def get_script_spans(self) -> list[Span]: def get_script_spans(self) -> list[Span]:
return [ return [
( (
self.search(r"\s*$", endpos=script_char_span[0]).start(), self.match(r"[\s\S]*?(\s*)$", endpos=index).start(1),
script_content_span[1] script_content_span[1]
) )
for script_char_span, script_content_span in zip( for index, script_content_span in zip(
self.script_char_spans, self.script_content_spans self.script_char_indices, self.script_content_spans
) )
] ]
# Parsing
def get_command_repl_items(self) -> list[tuple[Span, str]]: def get_command_repl_items(self) -> list[tuple[Span, str]]:
color_related_command_dict = {
"color": (1, False),
"textcolor": (1, False),
"pagecolor": (1, True),
"colorbox": (1, True),
"fcolorbox": (2, True),
}
result = [] result = []
backslash_indices = self.backslash_indices brace_spans_dict = dict(self.brace_spans)
right_brace_indices = [ brace_begins = list(brace_spans_dict.keys())
right_index for cmd_span in self.command_spans:
for left_index, right_index in self.brace_index_pairs cmd_name = self.get_substr(cmd_span)
] if cmd_name not in TEX_COLOR_COMMANDS_DICT.keys():
pattern = "".join([
r"\\",
"(",
"|".join(color_related_command_dict.keys()),
")",
r"(?![a-zA-Z])"
])
for match_obj in self.finditer(pattern):
span_begin, cmd_end = match_obj.span()
if span_begin not in backslash_indices:
continue continue
cmd_name = match_obj.group(1) n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name]
n_braces, substitute_cmd = color_related_command_dict[cmd_name] span_begin, span_end = cmd_span
span_end = self.take_nearest_value( for _ in range(n_braces):
right_brace_indices, cmd_end, n_braces span_end = brace_spans_dict[min(filter(
) + 1 lambda index: index >= span_end,
brace_begins
))]
if substitute_cmd: if substitute_cmd:
repl_str = "\\" + cmd_name + n_braces * "{black}" repl_str = cmd_name + n_braces * "{black}"
else: else:
repl_str = "" repl_str = ""
result.append(((span_begin, span_end), repl_str)) result.append(((span_begin, span_end), repl_str))
return result return result
def get_extra_entity_spans(self) -> list[Span]: def get_specified_spans(self) -> list[Span]:
return [ # Match paired double braces (`{{...}}`).
self.match(r"\\([a-zA-Z]+|.)", pos=index).span() sorted_brace_spans = sorted(self.brace_spans, key=lambda t: t[1])
for index in self.backslash_indices inner_brace_spans = [
sorted_brace_spans[span_span[0]]
for _, span_span in self.compress_neighbours([
(brace_span[0] + index, brace_span[1] - index)
for index, brace_span in enumerate(sorted_brace_spans)
])
if span_span[1] - span_span[0] >= 2
]
inner_brace_content_spans = [
(span[0] + 1, span[1] - 1)
for span in inner_brace_spans
if span[1] - span[0] > 2
] ]
def get_extra_ignored_spans(self) -> list[int]: result = self.remove_redundancies(self.chain(
return self.script_char_spans.copy() inner_brace_content_spans,
*[
def get_internal_specified_spans(self) -> list[Span]: self.find_spans_by_selector(selector)
# Match paired double braces (`{{...}}`). for selector in self.tex_to_color_map.keys()
result = [] ],
reversed_brace_indices_dict = dict([ self.find_spans_by_selector(self.isolate)
pair[::-1] for pair in self.brace_index_pairs ))
]) return list(filter(
skip = False lambda span: not any([
for prev_right_index, right_index in self.get_neighbouring_pairs( entity_begin < index < entity_end
list(reversed_brace_indices_dict.keys()) for index in span
): for entity_begin, entity_end in self.command_spans
if skip: ]),
skip = False result
continue ))
if right_index != prev_right_index + 1:
continue
left_index = reversed_brace_indices_dict[right_index]
prev_left_index = reversed_brace_indices_dict[prev_right_index]
if left_index != prev_left_index - 1:
continue
result.append((left_index, right_index + 1))
skip = True
return result
def get_external_specified_spans(self) -> list[Span]:
return self.find_substrs(list(self.tex_to_color_map.keys()))
def get_label_span_list(self) -> list[Span]: def get_label_span_list(self) -> list[Span]:
reversed_script_spans_dict = dict([
script_span[::-1] for script_span in self.script_spans
])
skippable_indices = self.chain(
self.find_indices(r"\s"),
self.script_char_indices
)
result = self.script_content_spans.copy() result = self.script_content_spans.copy()
for span_begin, span_end in self.specified_spans: for span in self.specified_spans:
shrinked_end = self.lslide(span_end, self.script_spans) span_begin, span_end = self.shrink_span(span, skippable_indices)
if span_begin >= shrinked_end: while span_end in reversed_script_spans_dict.keys():
span_end = reversed_script_spans_dict[span_end]
if span_begin >= span_end:
continue continue
shrinked_span = (span_begin, shrinked_end) shrinked_span = (span_begin, span_end)
if shrinked_span in result: if shrinked_span in result:
continue continue
result.append(shrinked_span) result.append(shrinked_span)
return result return result
def get_content(self, use_plain_file: bool) -> str: def get_content(self, is_labelled: bool) -> str:
if use_plain_file: if is_labelled:
span_repl_dict = {} extended_label_span_list = []
else: script_spans_dict = dict(self.script_spans)
extended_label_span_list = [ for span in self.label_span_list:
span if span not in self.script_content_spans:
if span in self.script_content_spans span_begin, span_end = span
else (span[0], self.rslide(span[1], self.script_spans)) while span_end in script_spans_dict.keys():
for span in self.label_span_list span_end = script_spans_dict[span_end]
] span = (span_begin, span_end)
extended_label_span_list.append(span)
inserted_string_pairs = [ inserted_string_pairs = [
(span, ( (span, (
"{{" + self.get_color_command_str(label + 1), "{{" + self.get_color_command_str(label + 1),
@ -258,43 +278,52 @@ class MTex(LabelledString):
)) ))
for label, span in enumerate(extended_label_span_list) for label, span in enumerate(extended_label_span_list)
] ]
span_repl_dict = self.generate_span_repl_dict( result = self.get_replaced_string(
inserted_string_pairs, inserted_string_pairs, self.command_repl_items
self.command_repl_items
) )
result = self.get_replaced_substr(self.full_span, span_repl_dict) else:
result = self.string
if self.tex_environment: if self.tex_environment:
result = "\n".join([ if isinstance(self.tex_environment, str):
f"\\begin{{{self.tex_environment}}}", prefix = f"\\begin{{{self.tex_environment}}}"
result, suffix = f"\\end{{{self.tex_environment}}}"
f"\\end{{{self.tex_environment}}}" else:
]) prefix, suffix = self.tex_environment
result = "\n".join([prefix, result, suffix])
if self.alignment: if self.alignment:
result = "\n".join([self.alignment, result]) result = "\n".join([self.alignment, result])
if use_plain_file: if not is_labelled:
result = "\n".join([ result = "\n".join([
self.get_color_command_str(self.hex_to_int(self.base_color)), self.get_color_command_str(self.base_color_int),
result result
]) ])
return result return result
@property # Selector
def has_predefined_local_colors(self) -> bool:
return bool(self.command_repl_items)
# Post-parsing
def get_cleaned_substr(self, span: Span) -> str: def get_cleaned_substr(self, span: Span) -> str:
substr = super().get_cleaned_substr(span) if not self.brace_spans:
if not self.brace_index_pairs: brace_begins, brace_ends = [], []
return substr else:
brace_begins, brace_ends = zip(*self.brace_spans)
left_brace_indices = list(brace_begins)
right_brace_indices = [index - 1 for index in brace_ends]
skippable_indices = self.chain(
self.find_indices(r"\s"),
self.script_char_indices,
left_brace_indices,
right_brace_indices
)
shrinked_span = self.shrink_span(span, skippable_indices)
if shrinked_span[0] >= shrinked_span[1]:
return ""
# Balance braces. # Balance braces.
left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs)
unclosed_left_braces = 0 unclosed_left_braces = 0
unclosed_right_braces = 0 unclosed_right_braces = 0
for index in range(*span): for index in range(*shrinked_span):
if index in left_brace_indices: if index in left_brace_indices:
unclosed_left_braces += 1 unclosed_left_braces += 1
elif index in right_brace_indices: elif index in right_brace_indices:
@ -304,27 +333,25 @@ class MTex(LabelledString):
unclosed_left_braces -= 1 unclosed_left_braces -= 1
return "".join([ return "".join([
unclosed_right_braces * "{", unclosed_right_braces * "{",
substr, self.get_substr(shrinked_span),
unclosed_left_braces * "}" unclosed_left_braces * "}"
]) ])
# Method alias # Method alias
def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup: def get_parts_by_tex(self, selector: Selector) -> VGroup:
return self.get_parts_by_string(tex, **kwargs) return self.select_parts(selector)
def get_part_by_tex(self, tex: str, **kwargs) -> VMobject: def get_part_by_tex(self, selector: Selector) -> VGroup:
return self.get_part_by_string(tex, **kwargs) return self.select_part(selector)
def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs): def set_color_by_tex(self, selector: Selector, color: ManimColor):
return self.set_color_by_string(tex, color, **kwargs) return self.set_parts_color(selector, color)
def set_color_by_tex_to_color_map( def set_color_by_tex_to_color_map(
self, tex_to_color_map: dict[str, ManimColor], **kwargs self, color_map: dict[Selector, ManimColor]
): ):
return self.set_color_by_string_to_color_map( return self.set_parts_color_by_dict(color_map)
tex_to_color_map, **kwargs
)
def get_tex(self) -> str: def get_tex(self) -> str:
return self.get_string() return self.get_string()

View file

@ -1,93 +1,63 @@
from __future__ import annotations from __future__ import annotations
import os
import re
import itertools as it
from pathlib import Path
from contextlib import contextmanager from contextlib import contextmanager
import typing import os
from typing import Iterable, Sequence, Union from pathlib import Path
import re
import manimpango
import pygments import pygments
import pygments.formatters import pygments.formatters
import pygments.lexers import pygments.lexers
from manimpango import MarkupUtils from manimlib.constants import DEFAULT_PIXEL_WIDTH, FRAME_WIDTH
from manimlib.constants import NORMAL
from manimlib.logger import log from manimlib.logger import log
from manimlib.constants import *
from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.svg.labelled_string import LabelledString
from manimlib.utils.customization import get_customization
from manimlib.utils.tex_file_writing import tex_hash
from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import digest_config
from manimlib.utils.customization import get_customization
from manimlib.utils.directories import get_downloads_dir from manimlib.utils.directories import get_downloads_dir
from manimlib.utils.directories import get_text_dir from manimlib.utils.directories import get_text_dir
from manimlib.utils.iterables import remove_list_redundancies from manimlib.utils.tex_file_writing import tex_hash
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from manimlib.mobject.types.vectorized_mobject import VMobject from colour import Color
from typing import Iterable, Union
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
ManimColor = Union[str, colour.Color, Sequence[float]]
ManimColor = Union[str, Color]
Span = tuple[int, int] Span = tuple[int, int]
Selector = Union[
str,
re.Pattern,
tuple[Union[int, None], Union[int, None]],
Iterable[Union[
str,
re.Pattern,
tuple[Union[int, None], Union[int, None]]
]]
]
TEXT_MOB_SCALE_FACTOR = 0.0076 TEXT_MOB_SCALE_FACTOR = 0.0076
DEFAULT_LINE_SPACING_SCALE = 0.6 DEFAULT_LINE_SPACING_SCALE = 0.6
# Ensure the canvas is large enough to hold all glyphs.
DEFAULT_CANVAS_WIDTH = 16384
DEFAULT_CANVAS_HEIGHT = 16384
# See https://docs.gtk.org/Pango/pango_markup.html # See https://docs.gtk.org/Pango/pango_markup.html
# A tag containing two aliases will cause warning, MARKUP_COLOR_KEYS = (
# so only use the first key of each group of aliases. "foreground", "fgcolor", "color",
SPAN_ATTR_KEY_ALIAS_LIST = ( "background", "bgcolor",
("font", "font_desc"), "underline_color",
("font_family", "face"), "overline_color",
("font_size", "size"), "strikethrough_color"
("font_style", "style"),
("font_weight", "weight"),
("font_variant", "variant"),
("font_stretch", "stretch"),
("font_features",),
("foreground", "fgcolor", "color"),
("background", "bgcolor"),
("alpha", "fgalpha"),
("background_alpha", "bgalpha"),
("underline",),
("underline_color",),
("overline",),
("overline_color",),
("rise",),
("baseline_shift",),
("font_scale",),
("strikethrough",),
("strikethrough_color",),
("fallback",),
("lang",),
("letter_spacing",),
("gravity",),
("gravity_hint",),
("show",),
("insert_hyphens",),
("allow_breaks",),
("line_height",),
("text_transform",),
("segment",),
) )
COLOR_RELATED_KEYS = ( MARKUP_TAG_CONVERSION_DICT = {
"foreground",
"background",
"underline_color",
"overline_color",
"strikethrough_color"
)
SPAN_ATTR_KEY_CONVERSION = {
key: key_alias_list[0]
for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST
for key in key_alias_list
}
TAG_TO_ATTR_DICT = {
"b": {"font_weight": "bold"}, "b": {"font_weight": "bold"},
"big": {"font_size": "larger"}, "big": {"font_size": "larger"},
"i": {"font_style": "italic"}, "i": {"font_style": "italic"},
@ -96,7 +66,7 @@ TAG_TO_ATTR_DICT = {
"sup": {"baseline_shift": "superscript", "font_scale": "superscript"}, "sup": {"baseline_shift": "superscript", "font_scale": "superscript"},
"small": {"font_size": "smaller"}, "small": {"font_size": "smaller"},
"tt": {"font_family": "monospace"}, "tt": {"font_family": "monospace"},
"u": {"underline": "single"}, "u": {"underline": "single"}
} }
@ -120,7 +90,7 @@ class MarkupText(LabelledString):
"justify": False, "justify": False,
"indent": 0, "indent": 0,
"alignment": "LEFT", "alignment": "LEFT",
"line_width_factor": None, "line_width": None,
"font": "", "font": "",
"slant": NORMAL, "slant": NORMAL,
"weight": NORMAL, "weight": NORMAL,
@ -141,9 +111,7 @@ class MarkupText(LabelledString):
if not self.font: if not self.font:
self.font = get_customization()["style"]["font"] self.font = get_customization()["style"]["font"]
if self.is_markup: if self.is_markup:
validate_error = MarkupUtils.validate(text) self.validate_markup_string(text)
if validate_error:
raise ValueError(validate_error)
self.text = text self.text = text
super().__init__(text, **kwargs) super().__init__(text, **kwargs)
@ -165,7 +133,6 @@ class MarkupText(LabelledString):
self.svg_default, self.svg_default,
self.path_string_config, self.path_string_config,
self.base_color, self.base_color,
self.use_plain_file,
self.isolate, self.isolate,
self.text, self.text,
self.is_markup, self.is_markup,
@ -174,7 +141,7 @@ class MarkupText(LabelledString):
self.justify, self.justify,
self.indent, self.indent,
self.alignment, self.alignment,
self.line_width_factor, self.line_width,
self.font, self.font,
self.slant, self.slant,
self.weight, self.weight,
@ -201,23 +168,32 @@ class MarkupText(LabelledString):
kwargs[short_name] = kwargs.pop(long_name) kwargs[short_name] = kwargs.pop(long_name)
def get_file_path_by_content(self, content: str) -> str: def get_file_path_by_content(self, content: str) -> str:
hash_content = str((
content,
self.justify,
self.indent,
self.alignment,
self.line_width
))
svg_file = os.path.join( svg_file = os.path.join(
get_text_dir(), tex_hash(content) + ".svg" get_text_dir(), tex_hash(hash_content) + ".svg"
) )
if not os.path.exists(svg_file): if not os.path.exists(svg_file):
self.markup_to_svg(content, svg_file) self.markup_to_svg(content, svg_file)
return svg_file return svg_file
def markup_to_svg(self, markup_str: str, file_name: str) -> str: def markup_to_svg(self, markup_str: str, file_name: str) -> str:
self.validate_markup_string(markup_str)
# `manimpango` is under construction, # `manimpango` is under construction,
# so the following code is intended to suit its interface # so the following code is intended to suit its interface
alignment = _Alignment(self.alignment) alignment = _Alignment(self.alignment)
if self.line_width_factor is None: if self.line_width is None:
pango_width = -1 pango_width = -1
else: else:
pango_width = self.line_width_factor * DEFAULT_PIXEL_WIDTH pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH
return MarkupUtils.text2svg( return manimpango.MarkupUtils.text2svg(
text=markup_str, text=markup_str,
font="", # Already handled font="", # Already handled
slant="NORMAL", # Already handled slant="NORMAL", # Already handled
@ -228,8 +204,8 @@ class MarkupText(LabelledString):
file_name=file_name, file_name=file_name,
START_X=0, START_X=0,
START_Y=0, START_Y=0,
width=DEFAULT_PIXEL_WIDTH, width=DEFAULT_CANVAS_WIDTH,
height=DEFAULT_PIXEL_HEIGHT, height=DEFAULT_CANVAS_HEIGHT,
justify=self.justify, justify=self.justify,
indent=self.indent, indent=self.indent,
line_spacing=None, # Already handled line_spacing=None, # Already handled
@ -237,13 +213,23 @@ class MarkupText(LabelledString):
pango_width=pango_width pango_width=pango_width
) )
def pre_parse(self) -> None: @staticmethod
super().pre_parse() def validate_markup_string(markup_str: str) -> None:
self.tag_items_from_markup = self.get_tag_items_from_markup() validate_error = manimpango.MarkupUtils.validate(markup_str)
self.global_dict_from_config = self.get_global_dict_from_config() if not validate_error:
self.local_dicts_from_markup = self.get_local_dicts_from_markup() return
self.local_dicts_from_config = self.get_local_dicts_from_config() raise ValueError(
self.predefined_attr_dicts = self.get_predefined_attr_dicts() f"Invalid markup string \"{markup_str}\"\n"
f"{validate_error}"
)
def parse(self) -> None:
self.global_attr_dict = self.get_global_attr_dict()
self.tag_pairs_from_markup = self.get_tag_pairs_from_markup()
self.tag_spans = self.get_tag_spans()
self.items_from_markup = self.get_items_from_markup()
self.specified_items = self.get_specified_items()
super().parse()
# Toolkits # Toolkits
@ -254,87 +240,50 @@ class MarkupText(LabelledString):
for key, val in attr_dict.items() for key, val in attr_dict.items()
]) ])
@staticmethod # Parsing
def merge_attr_dicts(
attr_dict_items: list[Span, str, typing.Any] def get_global_attr_dict(self) -> dict[str, str]:
) -> list[tuple[Span, dict[str, str]]]: result = {
index_seq = [0] "foreground": self.int_to_hex(self.base_color_int),
attr_dict_list = [{}] "font_family": self.font,
for span, attr_dict in attr_dict_items: "font_style": self.slant,
if span[0] >= span[1]: "font_weight": self.weight,
continue "font_size": str(self.font_size * 1024),
region_indices = [ }
MarkupText.find_region_index(index_seq, index) # `line_height` attribute is supported since Pango 1.50.
for index in span pango_version = manimpango.pango_version()
] if tuple(map(int, pango_version.split("."))) < (1, 50):
for flag in (1, 0): if self.lsh is not None:
if index_seq[region_indices[flag]] == span[flag]: log.warning(
continue f"Pango version {pango_version} found (< 1.50), "
region_index = region_indices[flag] "unable to set `line_height` attribute"
index_seq.insert(region_index + 1, span[flag])
attr_dict_list.insert(
region_index + 1, attr_dict_list[region_index].copy()
) )
region_indices[flag] += 1 else:
if flag == 0: line_spacing_scale = self.lsh or DEFAULT_LINE_SPACING_SCALE
region_indices[1] += 1 result["line_height"] = str(((line_spacing_scale) + 1) * 0.6)
for key, val in attr_dict.items(): return result
if not key:
continue
for mid_dict in attr_dict_list[slice(*region_indices)]:
mid_dict[key] = val
return list(zip(
MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1]
))
def find_substr_or_span( def get_tag_pairs_from_markup(
self, substr_or_span: str | tuple[int | None, int | None]
) -> list[Span]:
if isinstance(substr_or_span, str):
return self.find_substr(substr_or_span)
span = tuple([
(
min(index, self.string_len)
if index >= 0
else max(index + self.string_len, 0)
)
if index is not None else default_index
for index, default_index in zip(substr_or_span, self.full_span)
])
if span[0] >= span[1]:
return []
return [span]
# Pre-parsing
def get_tag_items_from_markup(
self self
) -> list[tuple[Span, Span, dict[str, str]]]: ) -> list[tuple[Span, Span, dict[str, str]]]:
if not self.is_markup: if not self.is_markup:
return [] return []
tag_pattern = r"""<(/?)(\w+)\s*((?:\w+\s*\=\s*(['"]).*?\4\s*)*)>""" tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*(['"])[\s\S]*?\5\s*)*)>"""
attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2""" attr_pattern = r"""(\w+)\s*\=\s*(['"])([\s\S]*?)\2"""
begin_match_obj_stack = [] begin_match_obj_stack = []
match_obj_pairs = [] match_obj_pairs = []
for match_obj in self.finditer(tag_pattern): for match_obj in re.finditer(tag_pattern, self.string):
if not match_obj.group(1): if not match_obj.group(1):
begin_match_obj_stack.append(match_obj) begin_match_obj_stack.append(match_obj)
else: else:
match_obj_pairs.append( match_obj_pairs.append(
(begin_match_obj_stack.pop(), match_obj) (begin_match_obj_stack.pop(), match_obj)
) )
if begin_match_obj_stack:
raise ValueError("Unclosed tag(s) detected")
result = [] result = []
for begin_match_obj, end_match_obj in match_obj_pairs: for begin_match_obj, end_match_obj in match_obj_pairs:
tag_name = begin_match_obj.group(2) tag_name = begin_match_obj.group(2)
if tag_name != end_match_obj.group(2):
raise ValueError("Unmatched tag names")
if end_match_obj.group(3):
raise ValueError("Attributes shan't exist in ending tags")
if tag_name == "span": if tag_name == "span":
attr_dict = { attr_dict = {
match.group(1): match.group(3) match.group(1): match.group(3)
@ -342,189 +291,170 @@ class MarkupText(LabelledString):
attr_pattern, begin_match_obj.group(3) attr_pattern, begin_match_obj.group(3)
) )
} }
elif tag_name in TAG_TO_ATTR_DICT.keys():
if begin_match_obj.group(3):
raise ValueError(
f"Attributes shan't exist in tag '{tag_name}'"
)
attr_dict = TAG_TO_ATTR_DICT[tag_name].copy()
else: else:
raise ValueError(f"Unknown tag: '{tag_name}'") attr_dict = MARKUP_TAG_CONVERSION_DICT.get(tag_name, {})
result.append( result.append(
(begin_match_obj.span(), end_match_obj.span(), attr_dict) (begin_match_obj.span(), end_match_obj.span(), attr_dict)
) )
return result return result
def get_global_dict_from_config(self) -> dict[str, typing.Any]: def get_tag_spans(self) -> list[Span]:
result = { return [
"line_height": ( tag_span
(self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 for begin_tag, end_tag, _ in self.tag_pairs_from_markup
) * 0.6, for tag_span in (begin_tag, end_tag)
"font_family": self.font, ]
"font_size": self.font_size * 1024,
"font_style": self.slant,
"font_weight": self.weight
}
result.update(self.global_config)
return result
def get_local_dicts_from_markup( def get_items_from_markup(self) -> list[Span]:
self return [
) -> list[Span, dict[str, str]]: ((begin_tag_span[1], end_tag_span[0]), attr_dict)
return sorted([
((begin_tag_span[0], end_tag_span[1]), attr_dict)
for begin_tag_span, end_tag_span, attr_dict for begin_tag_span, end_tag_span, attr_dict
in self.tag_items_from_markup in self.tag_pairs_from_markup
]) if begin_tag_span[1] < end_tag_span[0]
]
def get_local_dicts_from_config( def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]:
self result = self.chain(
) -> list[Span, dict[str, typing.Any]]: self.items_from_markup,
[
(span, {key: val})
for t2x_dict, key in (
(self.t2c, "foreground"),
(self.t2f, "font_family"),
(self.t2s, "font_style"),
(self.t2w, "font_weight")
)
for selector, val in t2x_dict.items()
for span in self.find_spans_by_selector(selector)
],
[
(span, local_config)
for selector, local_config in self.local_configs.items()
for span in self.find_spans_by_selector(selector)
],
[
(span, {})
for span in self.find_spans_by_selector(self.isolate)
]
)
entity_spans = self.tag_spans.copy()
if self.is_markup:
entity_spans.extend(self.find_spans(r"&[\s\S]*?;"))
return [ return [
(span, {key: val}) (span, attr_dict)
for t2x_dict, key in ( for span, attr_dict in result
(self.t2c, "foreground"), if not any([
(self.t2f, "font_family"), entity_begin < index < entity_end
(self.t2s, "font_style"), for index in span
(self.t2w, "font_weight") for entity_begin, entity_end in entity_spans
) ])
for substr_or_span, val in t2x_dict.items()
for span in self.find_substr_or_span(substr_or_span)
] + [
(span, local_config)
for substr_or_span, local_config in self.local_configs.items()
for span in self.find_substr_or_span(substr_or_span)
] ]
def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]:
attr_dict_items = [
(self.full_span, self.global_dict_from_config),
*self.local_dicts_from_markup,
*self.local_dicts_from_config
]
return [
(span, {
SPAN_ATTR_KEY_CONVERSION[key.lower()]: str(val)
for key, val in attr_dict.items()
})
for span, attr_dict in attr_dict_items
]
# Parsing
def get_command_repl_items(self) -> list[tuple[Span, str]]: def get_command_repl_items(self) -> list[tuple[Span, str]]:
result = [ result = [
(tag_span, "") (tag_span, "") for tag_span in self.tag_spans
for begin_tag, end_tag, _ in self.tag_items_from_markup
for tag_span in (begin_tag, end_tag)
] ]
if not self.is_markup: if not self.is_markup:
result += [ result.extend([
(span, escaped) (span, escaped)
for char, escaped in ( for char, escaped in (
("&", "&amp;"), ("&", "&amp;"),
(">", "&gt;"), (">", "&gt;"),
("<", "&lt;") ("<", "&lt;")
) )
for span in self.find_substr(char) for span in self.find_spans(re.escape(char))
] ])
return result return result
def get_extra_entity_spans(self) -> list[Span]: def get_specified_spans(self) -> list[Span]:
if not self.is_markup: return self.remove_redundancies([
return [] span for span, _ in self.specified_items
return self.find_spans(r"&.*?;") ])
def get_extra_ignored_spans(self) -> list[int]:
return []
def get_internal_specified_spans(self) -> list[Span]:
return [span for span, _ in self.local_dicts_from_markup]
def get_external_specified_spans(self) -> list[Span]:
return [span for span, _ in self.local_dicts_from_config]
def get_label_span_list(self) -> list[Span]: def get_label_span_list(self) -> list[Span]:
breakup_indices = remove_list_redundancies(list(it.chain(*it.chain( interval_spans = sorted(self.chain(
self.find_spans(r"\s+"), self.tag_spans,
self.find_spans(r"\b"), [
self.specified_spans (index, index)
)))) for span in self.specified_spans
breakup_indices = sorted(filter( for index in span
lambda index: not any([
span[0] < index < span[1]
for span in self.entity_spans
]),
breakup_indices
))
return list(filter(
lambda span: self.get_substr(span).strip(),
self.get_neighbouring_pairs(breakup_indices)
))
def get_content(self, use_plain_file: bool) -> str:
if use_plain_file:
attr_dict_items = [
(self.full_span, {"foreground": self.base_color}),
*self.predefined_attr_dicts,
*[
(span, {})
for span in self.label_span_list
]
] ]
))
text_spans = self.get_complement_spans(interval_spans, self.full_span)
if self.is_markup:
pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+"
else: else:
attr_dict_items = [ pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+"
(self.full_span, {"foreground": BLACK}), return self.chain(*[
*[ self.find_spans(pattern, pos=span_begin, endpos=span_end)
for span_begin, span_end in text_spans
])
def get_content(self, is_labelled: bool) -> str:
predefined_items = [
(self.full_span, self.global_attr_dict),
(self.full_span, self.global_config),
*self.specified_items
]
if is_labelled:
attr_dict_items = self.chain(
[
(span, { (span, {
key: BLACK if key in COLOR_RELATED_KEYS else val key:
"black" if key.lower() in MARKUP_COLOR_KEYS else val
for key, val in attr_dict.items() for key, val in attr_dict.items()
}) })
for span, attr_dict in self.predefined_attr_dicts for span, attr_dict in predefined_items
], ],
*[ [
(span, {"foreground": self.int_to_hex(label + 1)}) (span, {"foreground": self.int_to_hex(label + 1)})
for label, span in enumerate(self.label_span_list) for label, span in enumerate(self.label_span_list)
] ]
] )
else:
attr_dict_items = self.chain(
predefined_items,
[
(span, {})
for span in self.label_span_list
]
)
inserted_string_pairs = [ inserted_string_pairs = [
(span, ( (span, (
f"<span {self.get_attr_dict_str(attr_dict)}>", f"<span {self.get_attr_dict_str(attr_dict)}>",
"</span>" "</span>"
)) ))
for span, attr_dict in self.merge_attr_dicts(attr_dict_items) for span, attr_dict in attr_dict_items if attr_dict
] ]
span_repl_dict = self.generate_span_repl_dict( return self.get_replaced_string(
inserted_string_pairs, self.command_repl_items inserted_string_pairs, self.command_repl_items
) )
return self.get_replaced_substr(self.full_span, span_repl_dict)
@property # Selector
def has_predefined_local_colors(self) -> bool:
return any([ def get_cleaned_substr(self, span: Span) -> str:
key in COLOR_RELATED_KEYS repl_items = list(filter(
for _, attr_dict in self.predefined_attr_dicts lambda repl_item: self.span_contains(span, repl_item[0]),
for key in attr_dict.keys() self.command_repl_items
]) ))
return self.get_replaced_substr(span, repl_items).strip()
# Method alias # Method alias
def get_parts_by_text(self, text: str, **kwargs) -> VGroup: def get_parts_by_text(self, selector: Selector) -> VGroup:
return self.get_parts_by_string(text, **kwargs) return self.select_parts(selector)
def get_part_by_text(self, text: str, **kwargs) -> VMobject: def get_part_by_text(self, selector: Selector) -> VGroup:
return self.get_part_by_string(text, **kwargs) return self.select_part(selector)
def set_color_by_text(self, text: str, color: ManimColor, **kwargs): def set_color_by_text(self, selector: Selector, color: ManimColor):
return self.set_color_by_string(text, color, **kwargs) return self.set_parts_color(selector, color)
def set_color_by_text_to_color_map( def set_color_by_text_to_color_map(
self, text_to_color_map: dict[str, ManimColor], **kwargs self, color_map: dict[Selector, ManimColor]
): ):
return self.set_color_by_string_to_color_map( return self.set_parts_color_by_dict(color_map)
text_to_color_map, **kwargs
)
def get_text(self) -> str: def get_text(self) -> str:
return self.get_string() return self.get_string()

View file

@ -13,7 +13,7 @@ if TYPE_CHECKING:
S = TypeVar("S") S = TypeVar("S")
def remove_list_redundancies(l: Iterable[T]) -> list[T]: def remove_list_redundancies(l: Sequence[T]) -> list[T]:
""" """
Used instead of list(set(l)) to maintain order Used instead of list(set(l)) to maintain order
Keeps the last occurrence of each element Keeps the last occurrence of each element
@ -40,14 +40,14 @@ def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]:
return [e for e in l1 if e not in l2] return [e for e in l1 if e not in l2]
def adjacent_n_tuples(objects: Iterable[T], n: int) -> zip[tuple[T, T]]: def adjacent_n_tuples(objects: Sequence[T], n: int) -> zip[tuple[T, T]]:
return zip(*[ return zip(*[
[*objects[k:], *objects[:k]] [*objects[k:], *objects[:k]]
for k in range(n) for k in range(n)
]) ])
def adjacent_pairs(objects: Iterable[T]) -> zip[tuple[T, T]]: def adjacent_pairs(objects: Sequence[T]) -> zip[tuple[T, T]]:
return adjacent_n_tuples(objects, 2) return adjacent_n_tuples(objects, 2)