Merge pull request #1795 from YishiMichael/refactor

Refactor StringMobject and relevant classes
This commit is contained in:
Grant Sanderson 2022-05-17 09:17:53 -07:00 committed by GitHub
commit cd240f2a80
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 937 additions and 1249 deletions

View file

@ -38,8 +38,8 @@ from manimlib.mobject.probability import *
from manimlib.mobject.shape_matchers import *
from manimlib.mobject.svg.brace import *
from manimlib.mobject.svg.drawings import *
from manimlib.mobject.svg.labelled_string import *
from manimlib.mobject.svg.mtex_mobject import *
from manimlib.mobject.svg.string_mobject import *
from manimlib.mobject.svg.svg_mobject import *
from manimlib.mobject.svg.tex_mobject import *
from manimlib.mobject.svg.text_mobject import *

View file

@ -1,12 +1,12 @@
from __future__ import annotations
import itertools as it
from abc import abstractmethod
from abc import ABC, abstractmethod
import numpy as np
from manimlib.animation.animation import Animation
from manimlib.mobject.svg.labelled_string import LabelledString
from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.bezier import integer_interpolate
from manimlib.utils.config_ops import digest_config
@ -17,10 +17,10 @@ from manimlib.utils.rate_functions import smooth
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.mobject import Group
from manimlib.mobject.mobject import Mobject
class ShowPartial(Animation):
class ShowPartial(Animation, ABC):
"""
Abstract class for ShowCreation and ShowPassingFlash
"""
@ -176,7 +176,7 @@ class ShowIncreasingSubsets(Animation):
"int_func": np.round,
}
def __init__(self, group: Group, **kwargs):
def __init__(self, group: Mobject, **kwargs):
self.all_submobs = list(group.submobjects)
super().__init__(group, **kwargs)
@ -212,8 +212,8 @@ class AddTextWordByWord(ShowIncreasingSubsets):
}
def __init__(self, string_mobject, **kwargs):
assert isinstance(string_mobject, LabelledString)
grouped_mobject = string_mobject.submob_groups
assert isinstance(string_mobject, StringMobject)
grouped_mobject = string_mobject.build_groups()
digest_config(self, kwargs)
if self.run_time is None:
self.run_time = self.time_per_word * len(grouped_mobject)

View file

@ -5,24 +5,24 @@ import itertools as it
import numpy as np
from manimlib.animation.composition import AnimationGroup
from manimlib.animation.fading import FadeTransformPieces
from manimlib.animation.fading import FadeInFromPoint
from manimlib.animation.fading import FadeOutToPoint
from manimlib.animation.fading import FadeTransformPieces
from manimlib.animation.transform import ReplacementTransform
from manimlib.animation.transform import Transform
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
from manimlib.mobject.svg.labelled_string import LabelledString
from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import remove_list_redundancies
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.svg.tex_mobject import SingleStringTex
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.scene.scene import Scene
from manimlib.mobject.svg.tex_mobject import Tex, SingleStringTex
class TransformMatchingParts(AnimationGroup):
@ -155,92 +155,89 @@ class TransformMatchingTex(TransformMatchingParts):
class TransformMatchingStrings(AnimationGroup):
CONFIG = {
"key_map": dict(),
"key_map": {},
"transform_mismatches": False,
}
def __init__(self,
source: LabelledString,
target: LabelledString,
source: StringMobject,
target: StringMobject,
**kwargs
):
digest_config(self, kwargs)
assert isinstance(source, LabelledString)
assert isinstance(target, LabelledString)
assert isinstance(source, StringMobject)
assert isinstance(target, StringMobject)
anims = []
source_indices = list(range(len(source.labelled_submobjects)))
target_indices = list(range(len(target.labelled_submobjects)))
source_indices = list(range(len(source.labels)))
target_indices = list(range(len(target.labels)))
def get_indices_lists(mobject, parts):
return [
[
mobject.labelled_submobjects.index(submob)
for submob in part
]
for part in parts
]
def add_anims_from(anim_class, func, source_args, target_args=None):
if target_args is None:
target_args = source_args.copy()
for source_arg, target_arg in zip(source_args, target_args):
source_parts = func(source, source_arg)
target_parts = func(target, target_arg)
source_indices_lists = list(filter(
lambda indices_list: all([
index in source_indices
for index in indices_list
]), get_indices_lists(source, source_parts)
))
target_indices_lists = list(filter(
lambda indices_list: all([
index in target_indices
for index in indices_list
]), get_indices_lists(target, target_parts)
))
if not source_indices_lists or not target_indices_lists:
def get_filtered_indices_lists(indices_lists, rest_indices):
result = []
for indices_list in indices_lists:
if not indices_list:
continue
anims.append(anim_class(source_parts, target_parts, **kwargs))
for index in it.chain(*source_indices_lists):
source_indices.remove(index)
for index in it.chain(*target_indices_lists):
target_indices.remove(index)
def get_common_substrs(substrs_from_source, substrs_from_target):
return sorted([
substr for substr in substrs_from_source
if substr and substr in substrs_from_target
], key=len, reverse=True)
def get_parts_from_keys(mobject, keys):
if isinstance(keys, str):
keys = [keys]
result = VGroup()
for key in keys:
if not isinstance(key, str):
raise TypeError(key)
result.add(*mobject.get_parts_by_string(key))
if not all(index in rest_indices for index in indices_list):
continue
result.append(indices_list)
for index in indices_list:
rest_indices.remove(index)
return result
add_anims_from(
ReplacementTransform, get_parts_from_keys,
self.key_map.keys(), self.key_map.values()
def add_anims(anim_class, indices_lists_pairs):
for source_indices_lists, target_indices_lists in indices_lists_pairs:
source_indices_lists = get_filtered_indices_lists(
source_indices_lists, source_indices
)
target_indices_lists = get_filtered_indices_lists(
target_indices_lists, target_indices
)
if not source_indices_lists or not target_indices_lists:
continue
anims.append(anim_class(
source.build_parts_from_indices_lists(source_indices_lists),
target.build_parts_from_indices_lists(target_indices_lists),
**kwargs
))
def get_substr_to_indices_lists_map(part_items):
result = {}
for substr, indices_list in part_items:
if substr not in result:
result[substr] = []
result[substr].append(indices_list)
return result
def add_anims_from(anim_class, func):
source_substr_map = get_substr_to_indices_lists_map(func(source))
target_substr_map = get_substr_to_indices_lists_map(func(target))
common_substrings = sorted([
s for s in source_substr_map if s and s in target_substr_map
], key=len, reverse=True)
add_anims(
anim_class,
[
(source_substr_map[substr], target_substr_map[substr])
for substr in common_substrings
]
)
add_anims(
ReplacementTransform,
[
(
source.get_submob_indices_lists_by_selector(k),
target.get_submob_indices_lists_by_selector(v)
)
for k, v in self.key_map.items()
]
)
add_anims_from(
FadeTransformPieces,
LabelledString.get_parts_by_string,
get_common_substrs(
source.specified_substrs,
target.specified_substrs
)
StringMobject.get_specified_part_items
)
add_anims_from(
FadeTransformPieces,
LabelledString.get_parts_by_group_substr,
get_common_substrs(
source.group_substrs,
target.group_substrs
)
StringMobject.get_group_part_items
)
rest_source = VGroup(*[source[index] for index in source_indices])

View file

@ -1,543 +0,0 @@
from __future__ import annotations
import re
import colour
import itertools as it
from typing import Iterable, Union, Sequence
from abc import ABC, abstractmethod
from manimlib.constants import BLACK, WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_int_rgb
from manimlib.utils.color import color_to_rgb
from manimlib.utils.color import rgb_to_hex
from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import remove_list_redundancies
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.types.vectorized_mobject import VMobject
ManimColor = Union[str, colour.Color, Sequence[float]]
Span = tuple[int, int]
class _StringSVG(SVGMobject):
CONFIG = {
"height": None,
"stroke_width": 0,
"stroke_color": WHITE,
"path_string_config": {
"should_subdivide_sharp_curves": True,
"should_remove_null_curves": True,
},
}
class LabelledString(_StringSVG, ABC):
"""
An abstract base class for `MTex` and `MarkupText`
"""
CONFIG = {
"base_color": WHITE,
"use_plain_file": False,
"isolate": [],
}
def __init__(self, string: str, **kwargs):
self.string = string
digest_config(self, kwargs)
# Convert `base_color` to hex code.
self.base_color = rgb_to_hex(color_to_rgb(
self.base_color \
or self.svg_default.get("color", None) \
or self.svg_default.get("fill_color", None) \
or WHITE
))
self.svg_default["fill_color"] = BLACK
self.pre_parse()
self.parse()
super().__init__()
self.post_parse()
def get_file_path(self) -> str:
return self.get_file_path_(use_plain_file=False)
def get_file_path_(self, use_plain_file: bool) -> str:
content = self.get_content(use_plain_file)
return self.get_file_path_by_content(content)
@abstractmethod
def get_file_path_by_content(self, content: str) -> str:
return ""
def generate_mobject(self) -> None:
super().generate_mobject()
submob_labels = [
self.color_to_label(submob.get_fill_color())
for submob in self.submobjects
]
if self.use_plain_file or self.has_predefined_local_colors:
file_path = self.get_file_path_(use_plain_file=True)
plain_svg = _StringSVG(
file_path,
svg_default=self.svg_default,
path_string_config=self.path_string_config
)
self.set_submobjects(plain_svg.submobjects)
else:
self.set_fill(self.base_color)
for submob, label in zip(self.submobjects, submob_labels):
submob.label = label
def pre_parse(self) -> None:
self.string_len = len(self.string)
self.full_span = (0, self.string_len)
def parse(self) -> None:
self.command_repl_items = self.get_command_repl_items()
self.command_spans = self.get_command_spans()
self.extra_entity_spans = self.get_extra_entity_spans()
self.entity_spans = self.get_entity_spans()
self.extra_ignored_spans = self.get_extra_ignored_spans()
self.skipped_spans = self.get_skipped_spans()
self.internal_specified_spans = self.get_internal_specified_spans()
self.external_specified_spans = self.get_external_specified_spans()
self.specified_spans = self.get_specified_spans()
self.label_span_list = self.get_label_span_list()
self.check_overlapping()
def post_parse(self) -> None:
self.labelled_submobject_items = [
(submob.label, submob)
for submob in self.submobjects
]
self.labelled_submobjects = self.get_labelled_submobjects()
self.specified_substrs = self.get_specified_substrs()
self.group_items = self.get_group_items()
self.group_substrs = self.get_group_substrs()
self.submob_groups = self.get_submob_groups()
# Toolkits
def get_substr(self, span: Span) -> str:
return self.string[slice(*span)]
def finditer(
self, pattern: str, flags: int = 0, **kwargs
) -> Iterable[re.Match]:
return re.compile(pattern, flags).finditer(self.string, **kwargs)
def search(
self, pattern: str, flags: int = 0, **kwargs
) -> re.Match | None:
return re.compile(pattern, flags).search(self.string, **kwargs)
def match(
self, pattern: str, flags: int = 0, **kwargs
) -> re.Match | None:
return re.compile(pattern, flags).match(self.string, **kwargs)
def find_spans(self, pattern: str, **kwargs) -> list[Span]:
return [
match_obj.span()
for match_obj in self.finditer(pattern, **kwargs)
]
def find_substr(self, substr: str, **kwargs) -> list[Span]:
if not substr:
return []
return self.find_spans(re.escape(substr), **kwargs)
def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]:
return list(it.chain(*[
self.find_substr(substr, **kwargs)
for substr in remove_list_redundancies(substrs)
]))
@staticmethod
def get_neighbouring_pairs(iterable: list) -> list[tuple]:
return list(zip(iterable[:-1], iterable[1:]))
@staticmethod
def span_contains(span_0: Span, span_1: Span) -> bool:
return span_0[0] <= span_1[0] and span_0[1] >= span_1[1]
@staticmethod
def get_complement_spans(
interval_spans: list[Span], universal_span: Span
) -> list[Span]:
if not interval_spans:
return [universal_span]
span_ends, span_begins = zip(*interval_spans)
return list(zip(
(universal_span[0], *span_begins),
(*span_ends, universal_span[1])
))
@staticmethod
def compress_neighbours(vals: list[int]) -> list[tuple[int, Span]]:
if not vals:
return []
unique_vals = [vals[0]]
indices = [0]
for index, val in enumerate(vals):
if val == unique_vals[-1]:
continue
unique_vals.append(val)
indices.append(index)
indices.append(len(vals))
spans = LabelledString.get_neighbouring_pairs(indices)
return list(zip(unique_vals, spans))
@staticmethod
def find_region_index(seq: list[int], val: int) -> int:
# Returns an integer in `range(-1, len(seq))` satisfying
# `seq[result] <= val < seq[result + 1]`.
# `seq` should be sorted in ascending order.
if not seq or val < seq[0]:
return -1
result = len(seq) - 1
while val < seq[result]:
result -= 1
return result
@staticmethod
def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int:
sorted_seq = sorted(seq)
index = LabelledString.find_region_index(sorted_seq, val)
return sorted_seq[index + index_shift]
@staticmethod
def generate_span_repl_dict(
inserted_string_pairs: list[tuple[Span, tuple[str, str]]],
other_repl_items: list[tuple[Span, str]]
) -> dict[Span, str]:
result = dict(other_repl_items)
if not inserted_string_pairs:
return result
indices, _, _, inserted_strings = zip(*sorted([
(
span[flag],
-flag,
-span[1 - flag],
str_pair[flag]
)
for span, str_pair in inserted_string_pairs
for flag in range(2)
]))
result.update({
(index, index): "".join(inserted_strings[slice(*item_span)])
for index, item_span
in LabelledString.compress_neighbours(indices)
})
return result
def get_replaced_substr(
self, span: Span, span_repl_dict: dict[Span, str]
) -> str:
repl_spans = sorted(filter(
lambda repl_span: self.span_contains(span, repl_span),
span_repl_dict.keys()
))
if not all(
span_0[1] <= span_1[0]
for span_0, span_1 in self.get_neighbouring_pairs(repl_spans)
):
raise ValueError("Overlapping replacement")
pieces = [
self.get_substr(piece_span)
for piece_span in self.get_complement_spans(repl_spans, span)
]
repl_strs = [span_repl_dict[repl_span] for repl_span in repl_spans]
repl_strs.append("")
return "".join(it.chain(*zip(pieces, repl_strs)))
@staticmethod
def rslide(index: int, skipped: list[Span]) -> int:
transfer_dict = dict(sorted(skipped))
while index in transfer_dict.keys():
index = transfer_dict[index]
return index
@staticmethod
def lslide(index: int, skipped: list[Span]) -> int:
transfer_dict = dict(sorted([
skipped_span[::-1] for skipped_span in skipped
], reverse=True))
while index in transfer_dict.keys():
index = transfer_dict[index]
return index
@staticmethod
def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int:
r, g, b = rgb_tuple
rg = r * 256 + g
return rg * 256 + b
@staticmethod
def int_to_rgb(rgb_int: int) -> tuple[int, int, int]:
rg, b = divmod(rgb_int, 256)
r, g = divmod(rg, 256)
return r, g, b
@staticmethod
def int_to_hex(rgb_int: int) -> str:
return "#{:06x}".format(rgb_int).upper()
@staticmethod
def hex_to_int(rgb_hex: str) -> int:
return int(rgb_hex[1:], 16)
@staticmethod
def color_to_label(color: ManimColor) -> int:
rgb_tuple = color_to_int_rgb(color)
rgb = LabelledString.rgb_to_int(rgb_tuple)
return rgb - 1
# Parsing
@abstractmethod
def get_command_repl_items(self) -> list[tuple[Span, str]]:
return []
def get_command_spans(self) -> list[Span]:
return [cmd_span for cmd_span, _ in self.command_repl_items]
@abstractmethod
def get_extra_entity_spans(self) -> list[Span]:
return []
def get_entity_spans(self) -> list[Span]:
return list(it.chain(
self.command_spans,
self.extra_entity_spans
))
@abstractmethod
def get_extra_ignored_spans(self) -> list[int]:
return []
def get_skipped_spans(self) -> list[Span]:
return list(it.chain(
self.find_spans(r"\s"),
self.command_spans,
self.extra_ignored_spans
))
def shrink_span(self, span: Span) -> Span:
return (
self.rslide(span[0], self.skipped_spans),
self.lslide(span[1], self.skipped_spans)
)
@abstractmethod
def get_internal_specified_spans(self) -> list[Span]:
return []
@abstractmethod
def get_external_specified_spans(self) -> list[Span]:
return []
def get_specified_spans(self) -> list[Span]:
spans = list(it.chain(
self.internal_specified_spans,
self.external_specified_spans,
self.find_substrs(self.isolate)
))
shrinked_spans = list(filter(
lambda span: span[0] < span[1] and not any([
entity_span[0] < index < entity_span[1]
for index in span
for entity_span in self.entity_spans
]),
[self.shrink_span(span) for span in spans]
))
return remove_list_redundancies(shrinked_spans)
@abstractmethod
def get_label_span_list(self) -> list[Span]:
return []
def check_overlapping(self) -> None:
for span_0, span_1 in it.product(self.label_span_list, repeat=2):
if not span_0[0] < span_1[0] < span_0[1] < span_1[1]:
continue
raise ValueError(
"Partially overlapping substrings detected: "
f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'"
)
@abstractmethod
def get_content(self, use_plain_file: bool) -> str:
return ""
@abstractmethod
def has_predefined_local_colors(self) -> bool:
return False
# Post-parsing
def get_labelled_submobjects(self) -> list[VMobject]:
return [submob for _, submob in self.labelled_submobject_items]
def get_cleaned_substr(self, span: Span) -> str:
span_repl_dict = dict.fromkeys(self.command_spans, "")
return self.get_replaced_substr(span, span_repl_dict)
def get_specified_substrs(self) -> list[str]:
return remove_list_redundancies([
self.get_cleaned_substr(span)
for span in self.specified_spans
])
def get_group_items(self) -> list[tuple[str, VGroup]]:
if not self.labelled_submobject_items:
return []
labels, labelled_submobjects = zip(*self.labelled_submobject_items)
group_labels, labelled_submob_spans = zip(
*self.compress_neighbours(labels)
)
ordered_spans = [
self.label_span_list[label] if label != -1 else self.full_span
for label in group_labels
]
interval_spans = [
(
next_span[0]
if self.span_contains(prev_span, next_span)
else prev_span[1],
prev_span[1]
if self.span_contains(next_span, prev_span)
else next_span[0]
)
for prev_span, next_span in self.get_neighbouring_pairs(
ordered_spans
)
]
shrinked_spans = [
self.shrink_span(span)
for span in self.get_complement_spans(
interval_spans, (ordered_spans[0][0], ordered_spans[-1][1])
)
]
group_substrs = [
self.get_cleaned_substr(span) if span[0] < span[1] else ""
for span in shrinked_spans
]
submob_groups = VGroup(*[
VGroup(*labelled_submobjects[slice(*submob_span)])
for submob_span in labelled_submob_spans
])
return list(zip(group_substrs, submob_groups))
def get_group_substrs(self) -> list[str]:
return [group_substr for group_substr, _ in self.group_items]
def get_submob_groups(self) -> list[VGroup]:
return [submob_group for _, submob_group in self.group_items]
def get_parts_by_group_substr(self, substr: str) -> VGroup:
return VGroup(*[
group
for group_substr, group in self.group_items
if group_substr == substr
])
# Selector
def find_span_components(
self, custom_span: Span, substring: bool = True
) -> list[Span]:
shrinked_span = self.shrink_span(custom_span)
if shrinked_span[0] >= shrinked_span[1]:
return []
if substring:
indices = remove_list_redundancies(list(it.chain(
self.full_span,
*self.label_span_list
)))
span_begin = self.take_nearest_value(
indices, shrinked_span[0], 0
)
span_end = self.take_nearest_value(
indices, shrinked_span[1] - 1, 1
)
else:
span_begin, span_end = shrinked_span
span_choices = sorted(filter(
lambda span: self.span_contains((span_begin, span_end), span),
self.label_span_list
))
# Choose spans that reach the farthest.
span_choices_dict = dict(span_choices)
result = []
while span_begin < span_end:
if span_begin not in span_choices_dict.keys():
span_begin += 1
continue
next_begin = span_choices_dict[span_begin]
result.append((span_begin, next_begin))
span_begin = next_begin
return result
def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup:
labels = [
label for label, span in enumerate(self.label_span_list)
if any([
self.span_contains(span_component, span)
for span_component in self.find_span_components(
custom_span, **kwargs
)
])
]
return VGroup(*[
submob for label, submob in self.labelled_submobject_items
if label in labels
])
def get_parts_by_string(
self, substr: str,
case_sensitive: bool = True, regex: bool = False, **kwargs
) -> VGroup:
flags = 0
if not case_sensitive:
flags |= re.I
pattern = substr if regex else re.escape(substr)
return VGroup(*[
self.get_part_by_custom_span(span, **kwargs)
for span in self.find_spans(pattern, flags=flags)
if span[0] < span[1]
])
def get_part_by_string(
self, substr: str, index: int = 0, **kwargs
) -> VMobject:
return self.get_parts_by_string(substr, **kwargs)[index]
def set_color_by_string(self, substr: str, color: ManimColor, **kwargs):
self.get_parts_by_string(substr, **kwargs).set_color(color)
return self
def set_color_by_string_to_color_map(
self, string_to_color_map: dict[str, ManimColor], **kwargs
):
for substr, color in string_to_color_map.items():
self.set_color_by_string(substr, color, **kwargs)
return self
def get_string(self) -> str:
return self.string

View file

@ -1,28 +1,37 @@
from __future__ import annotations
import itertools as it
import colour
from typing import Union, Sequence
from manimlib.mobject.svg.labelled_string import LabelledString
from manimlib.utils.tex_file_writing import tex_to_svg_file
from manimlib.utils.tex_file_writing import get_tex_config
from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.utils.tex_file_writing import display_during_execution
from manimlib.utils.tex_file_writing import get_tex_config
from manimlib.utils.tex_file_writing import tex_to_svg_file
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.types.vectorized_mobject import VMobject
from colour import Color
import re
from typing import Iterable, Union
from manimlib.mobject.types.vectorized_mobject import VGroup
ManimColor = Union[str, colour.Color, Sequence[float]]
ManimColor = Union[str, Color]
Span = tuple[int, int]
Selector = Union[
str,
re.Pattern,
tuple[Union[int, None], Union[int, None]],
Iterable[Union[
str,
re.Pattern,
tuple[Union[int, None], Union[int, None]]
]]
]
SCALE_FACTOR_PER_FONT_POINT = 0.001
class MTex(LabelledString):
class MTex(StringMobject):
CONFIG = {
"font_size": 48,
"alignment": "\\centering",
@ -32,7 +41,7 @@ class MTex(LabelledString):
def __init__(self, tex_string: str, **kwargs):
# Prevent from passing an empty string.
if not tex_string:
if not tex_string.strip():
tex_string = "\\\\"
self.tex_string = tex_string
super().__init__(tex_string, **kwargs)
@ -47,7 +56,6 @@ class MTex(LabelledString):
self.svg_default,
self.path_string_config,
self.base_color,
self.use_plain_file,
self.isolate,
self.tex_string,
self.alignment,
@ -61,270 +69,103 @@ class MTex(LabelledString):
tex_config["text_to_replace"],
content
)
with display_during_execution(f"Writing \"{self.tex_string}\""):
with display_during_execution(f"Writing \"{self.string}\""):
file_path = tex_to_svg_file(full_tex)
return file_path
def pre_parse(self) -> None:
super().pre_parse()
self.backslash_indices = self.get_backslash_indices()
self.brace_index_pairs = self.get_brace_index_pairs()
self.script_char_spans = self.get_script_char_spans()
self.script_content_spans = self.get_script_content_spans()
self.script_spans = self.get_script_spans()
# Toolkits
@staticmethod
def get_color_command_str(rgb_int: int) -> str:
rgb_tuple = MTex.int_to_rgb(rgb_int)
return "".join([
"\\color[RGB]",
"{",
",".join(map(str, rgb_tuple)),
"}"
])
# Pre-parsing
def get_backslash_indices(self) -> list[int]:
# The latter of `\\` doesn't count.
return list(it.chain(*[
range(span[0], span[1], 2)
for span in self.find_spans(r"\\+")
]))
def get_unescaped_char_spans(self, chars: str):
return sorted(filter(
lambda span: span[0] - 1 not in self.backslash_indices,
self.find_substrs(list(chars))
))
def get_brace_index_pairs(self) -> list[Span]:
left_brace_indices = []
right_brace_indices = []
left_brace_indices_stack = []
for span in self.get_unescaped_char_spans("{}"):
index = span[0]
if self.get_substr(span) == "{":
left_brace_indices_stack.append(index)
else:
if not left_brace_indices_stack:
raise ValueError("Missing '{' inserted")
left_brace_index = left_brace_indices_stack.pop()
left_brace_indices.append(left_brace_index)
right_brace_indices.append(index)
if left_brace_indices_stack:
raise ValueError("Missing '}' inserted")
return list(zip(left_brace_indices, right_brace_indices))
def get_script_char_spans(self) -> list[int]:
return self.get_unescaped_char_spans("_^")
def get_script_content_spans(self) -> list[Span]:
result = []
brace_indices_dict = dict(self.brace_index_pairs)
script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+"
for script_char_span in self.script_char_spans:
span_begin = self.match(r"\s*", pos=script_char_span[1]).end()
if span_begin in brace_indices_dict.keys():
span_end = brace_indices_dict[span_begin] + 1
else:
match_obj = self.match(script_pattern, pos=span_begin)
if not match_obj:
script_name = {
"_": "subscript",
"^": "superscript"
}[script_char]
raise ValueError(
f"Unclear {script_name} detected while parsing. "
"Please use braces to clarify"
)
span_end = match_obj.end()
result.append((span_begin, span_end))
return result
def get_script_spans(self) -> list[Span]:
return [
(
self.search(r"\s*$", endpos=script_char_span[0]).start(),
script_content_span[1]
)
for script_char_span, script_content_span in zip(
self.script_char_spans, self.script_content_spans
)
]
# Parsing
def get_command_repl_items(self) -> list[tuple[Span, str]]:
color_related_command_dict = {
"color": (1, False),
"textcolor": (1, False),
"pagecolor": (1, True),
"colorbox": (1, True),
"fcolorbox": (2, True),
}
result = []
backslash_indices = self.backslash_indices
right_brace_indices = [
right_index
for left_index, right_index in self.brace_index_pairs
def get_cmd_spans(self) -> list[Span]:
return self.find_spans(r"\\(?:[a-zA-Z]+|\s|\S)|[_^{}]")
def get_substr_flag(self, substr: str) -> int:
return {"{": 1, "}": -1}.get(substr, 0)
def get_repl_substr_for_content(self, substr: str) -> str:
return substr
def get_repl_substr_for_matching(self, substr: str) -> str:
return substr if substr.startswith("\\") else ""
def get_specified_items(
self, cmd_span_pairs: list[tuple[Span, Span]]
) -> list[tuple[Span, dict[str, str]]]:
cmd_content_spans = [
(span_begin, span_end)
for (_, span_begin), (span_end, _) in cmd_span_pairs
]
pattern = "".join([
r"\\",
"(",
"|".join(color_related_command_dict.keys()),
")",
r"(?![a-zA-Z])"
])
for match_obj in self.finditer(pattern):
span_begin, cmd_end = match_obj.span()
if span_begin not in backslash_indices:
continue
cmd_name = match_obj.group(1)
n_braces, substitute_cmd = color_related_command_dict[cmd_name]
span_end = self.take_nearest_value(
right_brace_indices, cmd_end, n_braces
) + 1
if substitute_cmd:
repl_str = "\\" + cmd_name + n_braces * "{black}"
else:
repl_str = ""
result.append(((span_begin, span_end), repl_str))
return result
def get_extra_entity_spans(self) -> list[Span]:
return [
self.match(r"\\([a-zA-Z]+|.)", pos=index).span()
for index in self.backslash_indices
]
def get_extra_ignored_spans(self) -> list[int]:
return self.script_char_spans.copy()
def get_internal_specified_spans(self) -> list[Span]:
# Match paired double braces (`{{...}}`).
result = []
reversed_brace_indices_dict = dict([
pair[::-1] for pair in self.brace_index_pairs
])
skip = False
for prev_right_index, right_index in self.get_neighbouring_pairs(
list(reversed_brace_indices_dict.keys())
):
if skip:
skip = False
continue
if right_index != prev_right_index + 1:
continue
left_index = reversed_brace_indices_dict[right_index]
prev_left_index = reversed_brace_indices_dict[prev_right_index]
if left_index != prev_left_index - 1:
continue
result.append((left_index, right_index + 1))
skip = True
return result
def get_external_specified_spans(self) -> list[Span]:
return self.find_substrs(list(self.tex_to_color_map.keys()))
def get_label_span_list(self) -> list[Span]:
result = self.script_content_spans.copy()
for span_begin, span_end in self.specified_spans:
shrinked_end = self.lslide(span_end, self.script_spans)
if span_begin >= shrinked_end:
continue
shrinked_span = (span_begin, shrinked_end)
if shrinked_span in result:
continue
result.append(shrinked_span)
return result
def get_content(self, use_plain_file: bool) -> str:
if use_plain_file:
span_repl_dict = {}
else:
extended_label_span_list = [
specified_spans = [
*[
cmd_content_spans[range_begin]
for _, (range_begin, range_end) in self.compress_neighbours([
(span_begin + index, span_end - index)
for index, (span_begin, span_end) in enumerate(
cmd_content_spans
)
])
if range_end - range_begin >= 2
],
*[
span
if span in self.script_content_spans
else (span[0], self.rslide(span[1], self.script_spans))
for span in self.label_span_list
]
inserted_string_pairs = [
(span, (
"{{" + self.get_color_command_str(label + 1),
"}}"
))
for label, span in enumerate(extended_label_span_list)
]
span_repl_dict = self.generate_span_repl_dict(
inserted_string_pairs,
self.command_repl_items
)
result = self.get_replaced_substr(self.full_span, span_repl_dict)
for selector in self.tex_to_color_map
for span in self.find_spans_by_selector(selector)
],
*self.find_spans_by_selector(self.isolate)
]
return [(span, {}) for span in specified_spans]
if self.tex_environment:
result = "\n".join([
f"\\begin{{{self.tex_environment}}}",
result,
f"\\end{{{self.tex_environment}}}"
])
@staticmethod
def get_color_cmd_str(rgb_hex: str) -> str:
rgb = MTex.hex_to_int(rgb_hex)
rg, b = divmod(rgb, 256)
r, g = divmod(rg, 256)
return f"\\color[RGB]{{{r}, {g}, {b}}}"
@staticmethod
def get_cmd_str_pair(
attr_dict: dict[str, str], label_hex: str | None
) -> tuple[str, str]:
if label_hex is None:
return "", ""
return "{{" + MTex.get_color_cmd_str(label_hex), "}}"
def get_content_prefix_and_suffix(
self, is_labelled: bool
) -> tuple[str, str]:
prefix_lines = []
suffix_lines = []
if not is_labelled:
prefix_lines.append(self.get_color_cmd_str(self.base_color_hex))
if self.alignment:
result = "\n".join([self.alignment, result])
if use_plain_file:
result = "\n".join([
self.get_color_command_str(self.hex_to_int(self.base_color)),
result
])
return result
@property
def has_predefined_local_colors(self) -> bool:
return bool(self.command_repl_items)
# Post-parsing
def get_cleaned_substr(self, span: Span) -> str:
substr = super().get_cleaned_substr(span)
if not self.brace_index_pairs:
return substr
# Balance braces.
left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs)
unclosed_left_braces = 0
unclosed_right_braces = 0
for index in range(*span):
if index in left_brace_indices:
unclosed_left_braces += 1
elif index in right_brace_indices:
if unclosed_left_braces == 0:
unclosed_right_braces += 1
else:
unclosed_left_braces -= 1
return "".join([
unclosed_right_braces * "{",
substr,
unclosed_left_braces * "}"
])
prefix_lines.append(self.alignment)
if self.tex_environment:
if isinstance(self.tex_environment, str):
env_prefix = f"\\begin{{{self.tex_environment}}}"
env_suffix = f"\\end{{{self.tex_environment}}}"
else:
env_prefix, env_suffix = self.tex_environment
prefix_lines.append(env_prefix)
suffix_lines.append(env_suffix)
return (
"".join([line + "\n" for line in prefix_lines]),
"".join(["\n" + line for line in suffix_lines])
)
# Method alias
def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup:
return self.get_parts_by_string(tex, **kwargs)
def get_parts_by_tex(self, selector: Selector) -> VGroup:
return self.select_parts(selector)
def get_part_by_tex(self, tex: str, **kwargs) -> VMobject:
return self.get_part_by_string(tex, **kwargs)
def get_part_by_tex(self, selector: Selector) -> VGroup:
return self.select_part(selector)
def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs):
return self.set_color_by_string(tex, color, **kwargs)
def set_color_by_tex(self, selector: Selector, color: ManimColor):
return self.set_parts_color(selector, color)
def set_color_by_tex_to_color_map(
self, tex_to_color_map: dict[str, ManimColor], **kwargs
self, color_map: dict[Selector, ManimColor]
):
return self.set_color_by_string_to_color_map(
tex_to_color_map, **kwargs
)
return self.set_parts_color_by_dict(color_map)
def get_tex(self) -> str:
return self.get_string()

View file

@ -0,0 +1,532 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import itertools as it
import re
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from manimlib.constants import WHITE
from manimlib.logger import log
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_rgb
from manimlib.utils.color import rgb_to_hex
from manimlib.utils.config_ops import digest_config
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from colour import Color
from typing import Iterable, Sequence, TypeVar, Union
ManimColor = Union[str, Color]
Span = tuple[int, int]
Selector = Union[
str,
re.Pattern,
tuple[Union[int, None], Union[int, None]],
Iterable[Union[
str,
re.Pattern,
tuple[Union[int, None], Union[int, None]]
]]
]
T = TypeVar("T")
class StringMobject(SVGMobject, ABC):
"""
An abstract base class for `MTex` and `MarkupText`
This class aims to optimize the logic of "slicing submobjects
via substrings". This could be much clearer and more user-friendly
than slicing through numerical indices explicitly.
Users are expected to specify substrings in `isolate` parameter
if they want to do anything with their corresponding submobjects.
`isolate` parameter can be either a string, a `re.Pattern` object,
or a 2-tuple containing integers or None, or a collection of the above.
Note, substrings specified cannot *partially* overlap with each other.
Each instance of `StringMobject` generates 2 svg files.
The additional one is generated with some color commands inserted,
so that each submobject of the original `SVGMobject` will be labelled
by the color of its paired submobject from the additional `SVGMobject`.
"""
CONFIG = {
"height": None,
"stroke_width": 0,
"stroke_color": WHITE,
"path_string_config": {
"should_subdivide_sharp_curves": True,
"should_remove_null_curves": True,
},
"base_color": WHITE,
"isolate": (),
}
def __init__(self, string: str, **kwargs):
self.string = string
digest_config(self, kwargs)
if self.base_color is None:
self.base_color = WHITE
self.base_color_hex = self.color_to_hex(self.base_color)
self.full_span = (0, len(self.string))
self.parse()
super().__init__(**kwargs)
self.labels = [submob.label for submob in self.submobjects]
def get_file_path(self) -> str:
original_content = self.get_content(is_labelled=False)
return self.get_file_path_by_content(original_content)
@abstractmethod
def get_file_path_by_content(self, content: str) -> str:
return ""
def generate_mobject(self) -> None:
super().generate_mobject()
labels_count = len(self.labelled_spans)
if not labels_count:
for submob in self.submobjects:
submob.label = -1
return
labelled_content = self.get_content(is_labelled=True)
file_path = self.get_file_path_by_content(labelled_content)
labelled_svg = SVGMobject(file_path)
if len(self.submobjects) != len(labelled_svg.submobjects):
log.warning(
"Cannot align submobjects of the labelled svg "
"to the original svg. Skip the labelling process."
)
for submob in self.submobjects:
submob.label = -1
return
self.rearrange_submobjects_by_positions(labelled_svg)
unrecognizable_colors = []
for submob, labelled_svg_submob in zip(
self.submobjects, labelled_svg.submobjects
):
color_int = self.hex_to_int(self.color_to_hex(
labelled_svg_submob.get_fill_color()
))
if color_int > labels_count:
unrecognizable_colors.append(color_int)
color_int = 0
submob.label = color_int - 1
if unrecognizable_colors:
log.warning(
"Unrecognizable color labels detected (%s, etc). "
"The result could be unexpected.",
self.int_to_hex(unrecognizable_colors[0])
)
def rearrange_submobjects_by_positions(
self, labelled_svg: SVGMobject
) -> None:
# Rearrange submobjects of `labelled_svg` so that
# each submobject is labelled by the nearest one of `labelled_svg`.
# The correctness cannot be ensured, since the svg may
# change significantly after inserting color commands.
if not labelled_svg.submobjects:
return
bb_0 = self.get_bounding_box()
bb_1 = labelled_svg.get_bounding_box()
scale_factor = abs((bb_0[2] - bb_0[0]) / (bb_1[2] - bb_1[0]))
labelled_svg.move_to(self).scale(scale_factor)
distance_matrix = cdist(
[submob.get_center() for submob in self.submobjects],
[submob.get_center() for submob in labelled_svg.submobjects]
)
_, indices = linear_sum_assignment(distance_matrix)
labelled_svg.set_submobjects([
labelled_svg.submobjects[index]
for index in indices
])
# Toolkits
def get_substr(self, span: Span) -> str:
return self.string[slice(*span)]
def find_spans(self, pattern: str | re.Pattern) -> list[Span]:
return [
match_obj.span()
for match_obj in re.finditer(pattern, self.string)
]
def find_spans_by_selector(self, selector: Selector) -> list[Span]:
def find_spans_by_single_selector(sel):
if isinstance(sel, str):
return self.find_spans(re.escape(sel))
if isinstance(sel, re.Pattern):
return self.find_spans(sel)
if isinstance(sel, tuple) and len(sel) == 2 and all(
isinstance(index, int) or index is None
for index in sel
):
l = self.full_span[1]
span = tuple(
min(index, l) if index >= 0 else max(index + l, 0)
if index is not None else default_index
for index, default_index in zip(sel, self.full_span)
)
return [span]
return None
result = find_spans_by_single_selector(selector)
if result is None:
result = []
for sel in selector:
spans = find_spans_by_single_selector(sel)
if spans is None:
raise TypeError(f"Invalid selector: '{sel}'")
result.extend(spans)
return result
@staticmethod
def get_neighbouring_pairs(vals: Sequence[T]) -> list[tuple[T, T]]:
return list(zip(vals[:-1], vals[1:]))
@staticmethod
def compress_neighbours(vals: Sequence[T]) -> list[tuple[T, Span]]:
if not vals:
return []
unique_vals = [vals[0]]
indices = [0]
for index, val in enumerate(vals):
if val == unique_vals[-1]:
continue
unique_vals.append(val)
indices.append(index)
indices.append(len(vals))
val_ranges = StringMobject.get_neighbouring_pairs(indices)
return list(zip(unique_vals, val_ranges))
@staticmethod
def span_contains(span_0: Span, span_1: Span) -> bool:
return span_0[0] <= span_1[0] and span_0[1] >= span_1[1]
@staticmethod
def get_complement_spans(
universal_span: Span, interval_spans: list[Span]
) -> list[Span]:
if not interval_spans:
return [universal_span]
span_ends, span_begins = zip(*interval_spans)
return list(zip(
(universal_span[0], *span_begins),
(*span_ends, universal_span[1])
))
def replace_substr(self, span: Span, repl_items: list[Span, str]):
if not repl_items:
return self.get_substr(span)
repl_spans, repl_strs = zip(*sorted(repl_items, key=lambda t: t[0]))
pieces = [
self.get_substr(piece_span)
for piece_span in self.get_complement_spans(span, repl_spans)
]
repl_strs = [*repl_strs, ""]
return "".join(it.chain(*zip(pieces, repl_strs)))
@staticmethod
def color_to_hex(color: ManimColor) -> str:
return rgb_to_hex(color_to_rgb(color))
@staticmethod
def hex_to_int(rgb_hex: str) -> int:
return int(rgb_hex[1:], 16)
@staticmethod
def int_to_hex(rgb_int: int) -> str:
return f"#{rgb_int:06x}".upper()
# Parsing
def parse(self) -> None:
cmd_spans = self.get_cmd_spans()
cmd_substrs = [self.get_substr(span) for span in cmd_spans]
flags = [self.get_substr_flag(substr) for substr in cmd_substrs]
specified_items = self.get_specified_items(
self.get_cmd_span_pairs(cmd_spans, flags)
)
split_items = [
(span, attr_dict)
for specified_span, attr_dict in specified_items
for span in self.split_span_by_levels(
specified_span, cmd_spans, flags
)
]
self.specified_spans = [span for span, _ in specified_items]
self.split_items = split_items
self.labelled_spans = [span for span, _ in split_items]
self.cmd_repl_items_for_content = [
(span, self.get_repl_substr_for_content(substr))
for span, substr in zip(cmd_spans, cmd_substrs)
]
self.cmd_repl_items_for_matching = [
(span, self.get_repl_substr_for_matching(substr))
for span, substr in zip(cmd_spans, cmd_substrs)
]
self.check_overlapping()
@abstractmethod
def get_cmd_spans(self) -> list[Span]:
return []
@abstractmethod
def get_substr_flag(self, substr: str) -> int:
return 0
@abstractmethod
def get_repl_substr_for_content(self, substr: str) -> str:
return ""
@abstractmethod
def get_repl_substr_for_matching(self, substr: str) -> str:
return ""
@staticmethod
def get_cmd_span_pairs(
cmd_spans: list[Span], flags: list[int]
) -> list[tuple[Span, Span]]:
result = []
begin_cmd_spans_stack = []
for cmd_span, flag in zip(cmd_spans, flags):
if flag == 1:
begin_cmd_spans_stack.append(cmd_span)
elif flag == -1:
if not begin_cmd_spans_stack:
raise ValueError("Missing open command")
begin_cmd_span = begin_cmd_spans_stack.pop()
result.append((begin_cmd_span, cmd_span))
if begin_cmd_spans_stack:
raise ValueError("Missing close command")
return result
@abstractmethod
def get_specified_items(
self, cmd_span_pairs: list[tuple[Span, Span]]
) -> list[tuple[Span, dict[str, str]]]:
return []
def split_span_by_levels(
self, arbitrary_span: Span, cmd_spans: list[Span], flags: list[int]
) -> list[Span]:
cmd_range = (
sum([
arbitrary_span[0] > interval_begin
for interval_begin, _ in cmd_spans
]),
sum([
arbitrary_span[1] >= interval_end
for _, interval_end in cmd_spans
])
)
complement_spans = self.get_complement_spans(
self.full_span, cmd_spans
)
adjusted_span = (
max(arbitrary_span[0], complement_spans[cmd_range[0]][0]),
min(arbitrary_span[1], complement_spans[cmd_range[1]][1])
)
if adjusted_span[0] > adjusted_span[1]:
return []
upward_cmd_spans = []
downward_cmd_spans = []
for cmd_span, flag in list(zip(cmd_spans, flags))[slice(*cmd_range)]:
if flag == 1:
upward_cmd_spans.append(cmd_span)
elif flag == -1:
if upward_cmd_spans:
upward_cmd_spans.pop()
else:
downward_cmd_spans.append(cmd_span)
return list(filter(
lambda span: self.get_substr(span).strip(),
self.get_complement_spans(
adjusted_span, downward_cmd_spans + upward_cmd_spans
)
))
def check_overlapping(self) -> None:
labelled_spans = self.labelled_spans
if len(labelled_spans) >= 16777216:
raise ValueError("Cannot handle that many substrings")
for span_0, span_1 in it.product(labelled_spans, repeat=2):
if not span_0[0] < span_1[0] < span_0[1] < span_1[1]:
continue
raise ValueError(
"Partially overlapping substrings detected: "
f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'"
)
@staticmethod
@abstractmethod
def get_cmd_str_pair(
attr_dict: dict[str, str], label_hex: str | None
) -> tuple[str, str]:
return "", ""
@abstractmethod
def get_content_prefix_and_suffix(
self, is_labelled: bool
) -> tuple[str, str]:
return "", ""
def get_content(self, is_labelled: bool) -> str:
inserted_str_pairs = [
(span, self.get_cmd_str_pair(
attr_dict,
label_hex=self.int_to_hex(label + 1) if is_labelled else None
))
for label, (span, attr_dict) in enumerate(self.split_items)
]
inserted_str_items = sorted([
(index, s)
for (index, _), s in [
*sorted([
(span[::-1], end_str)
for span, (_, end_str) in reversed(inserted_str_pairs)
], key=lambda t: (t[0][0], -t[0][1])),
*sorted([
(span, begin_str)
for span, (begin_str, _) in inserted_str_pairs
], key=lambda t: (t[0][0], -t[0][1]))
]
], key=lambda t: t[0])
repl_items = self.cmd_repl_items_for_content + [
((index, index), inserted_str)
for index, inserted_str in inserted_str_items
]
prefix, suffix = self.get_content_prefix_and_suffix(is_labelled)
return "".join([
prefix,
self.replace_substr(self.full_span, repl_items),
suffix
])
# Selector
def get_submob_indices_list_by_span(
self, arbitrary_span: Span
) -> list[int]:
return [
submob_index
for submob_index, label in enumerate(self.labels)
if label != -1 and self.span_contains(
arbitrary_span, self.labelled_spans[label]
)
]
def get_specified_part_items(self) -> list[tuple[str, list[int]]]:
return [
(
self.get_substr(span),
self.get_submob_indices_list_by_span(span)
)
for span in self.specified_spans
]
def get_group_part_items(self) -> list[tuple[str, list[int]]]:
if not self.labels:
return []
group_labels, labelled_submob_ranges = zip(
*self.compress_neighbours(self.labels)
)
ordered_spans = [
self.labelled_spans[label] if label != -1 else self.full_span
for label in group_labels
]
interval_spans = [
(
next_span[0]
if self.span_contains(prev_span, next_span)
else prev_span[1],
prev_span[1]
if self.span_contains(next_span, prev_span)
else next_span[0]
)
for prev_span, next_span in self.get_neighbouring_pairs(
ordered_spans
)
]
group_substrs = [
re.sub(r"\s+", "", self.replace_substr(
span, [
(cmd_span, repl_str)
for cmd_span, repl_str in self.cmd_repl_items_for_matching
if self.span_contains(span, cmd_span)
]
))
for span in self.get_complement_spans(
(ordered_spans[0][0], ordered_spans[-1][1]), interval_spans
)
]
submob_indices_lists = [
list(range(*submob_range))
for submob_range in labelled_submob_ranges
]
return list(zip(group_substrs, submob_indices_lists))
def get_submob_indices_lists_by_selector(
self, selector: Selector
) -> list[list[int]]:
return list(filter(
lambda indices_list: indices_list,
[
self.get_submob_indices_list_by_span(span)
for span in self.find_spans_by_selector(selector)
]
))
def build_parts_from_indices_lists(
self, indices_lists: list[list[int]]
) -> VGroup:
return VGroup(*[
VGroup(*[
self.submobjects[submob_index]
for submob_index in indices_list
])
for indices_list in indices_lists
])
def build_groups(self) -> VGroup:
return self.build_parts_from_indices_lists([
indices_list
for _, indices_list in self.get_group_part_items()
])
def select_parts(self, selector: Selector) -> VGroup:
return self.build_parts_from_indices_lists(
self.get_submob_indices_lists_by_selector(selector)
)
def select_part(self, selector: Selector, index: int = 0) -> VGroup:
return self.select_parts(selector)[index]
def set_parts_color(self, selector: Selector, color: ManimColor):
self.select_parts(selector).set_color(color)
return self
def set_parts_color_by_dict(self, color_map: dict[Selector, ManimColor]):
for selector, color in color_map.items():
self.set_parts_color(selector, color)
return self
def get_string(self) -> str:
return self.string

View file

@ -1,103 +1,52 @@
from __future__ import annotations
import os
import re
import itertools as it
from pathlib import Path
from contextlib import contextmanager
import typing
from typing import Iterable, Sequence, Union
import os
from pathlib import Path
import re
import manimpango
import pygments
import pygments.formatters
import pygments.lexers
from manimpango import MarkupUtils
from manimlib.constants import DEFAULT_PIXEL_WIDTH, FRAME_WIDTH
from manimlib.constants import NORMAL
from manimlib.logger import log
from manimlib.constants import *
from manimlib.mobject.svg.labelled_string import LabelledString
from manimlib.utils.customization import get_customization
from manimlib.utils.tex_file_writing import tex_hash
from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.utils.config_ops import digest_config
from manimlib.utils.customization import get_customization
from manimlib.utils.directories import get_downloads_dir
from manimlib.utils.directories import get_text_dir
from manimlib.utils.iterables import remove_list_redundancies
from manimlib.utils.tex_file_writing import tex_hash
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.types.vectorized_mobject import VMobject
from colour import Color
from typing import Iterable, Union
from manimlib.mobject.types.vectorized_mobject import VGroup
ManimColor = Union[str, colour.Color, Sequence[float]]
ManimColor = Union[str, Color]
Span = tuple[int, int]
Selector = Union[
str,
re.Pattern,
tuple[Union[int, None], Union[int, None]],
Iterable[Union[
str,
re.Pattern,
tuple[Union[int, None], Union[int, None]]
]]
]
TEXT_MOB_SCALE_FACTOR = 0.0076
DEFAULT_LINE_SPACING_SCALE = 0.6
# See https://docs.gtk.org/Pango/pango_markup.html
# A tag containing two aliases will cause warning,
# so only use the first key of each group of aliases.
SPAN_ATTR_KEY_ALIAS_LIST = (
("font", "font_desc"),
("font_family", "face"),
("font_size", "size"),
("font_style", "style"),
("font_weight", "weight"),
("font_variant", "variant"),
("font_stretch", "stretch"),
("font_features",),
("foreground", "fgcolor", "color"),
("background", "bgcolor"),
("alpha", "fgalpha"),
("background_alpha", "bgalpha"),
("underline",),
("underline_color",),
("overline",),
("overline_color",),
("rise",),
("baseline_shift",),
("font_scale",),
("strikethrough",),
("strikethrough_color",),
("fallback",),
("lang",),
("letter_spacing",),
("gravity",),
("gravity_hint",),
("show",),
("insert_hyphens",),
("allow_breaks",),
("line_height",),
("text_transform",),
("segment",),
)
COLOR_RELATED_KEYS = (
"foreground",
"background",
"underline_color",
"overline_color",
"strikethrough_color"
)
SPAN_ATTR_KEY_CONVERSION = {
key: key_alias_list[0]
for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST
for key in key_alias_list
}
TAG_TO_ATTR_DICT = {
"b": {"font_weight": "bold"},
"big": {"font_size": "larger"},
"i": {"font_style": "italic"},
"s": {"strikethrough": "true"},
"sub": {"baseline_shift": "subscript", "font_scale": "subscript"},
"sup": {"baseline_shift": "superscript", "font_scale": "superscript"},
"small": {"font_size": "smaller"},
"tt": {"font_family": "monospace"},
"u": {"underline": "single"},
}
# Ensure the canvas is large enough to hold all glyphs.
DEFAULT_CANVAS_WIDTH = 16384
DEFAULT_CANVAS_HEIGHT = 16384
# Temporary handler
@ -112,7 +61,7 @@ class _Alignment:
self.value = _Alignment.VAL_DICT[s.upper()]
class MarkupText(LabelledString):
class MarkupText(StringMobject):
CONFIG = {
"is_markup": True,
"font_size": 48,
@ -120,7 +69,7 @@ class MarkupText(LabelledString):
"justify": False,
"indent": 0,
"alignment": "LEFT",
"line_width_factor": None,
"line_width": None,
"font": "",
"slant": NORMAL,
"weight": NORMAL,
@ -132,6 +81,31 @@ class MarkupText(LabelledString):
"t2w": {},
"global_config": {},
"local_configs": {},
# For backward compatibility
"isolate": (re.compile(r"[a-zA-Z]+"), re.compile(r"\S+")),
}
# See https://docs.gtk.org/Pango/pango_markup.html
MARKUP_COLOR_KEYS = {
"foreground": False,
"fgcolor": False,
"color": False,
"background": True,
"bgcolor": True,
"underline_color": True,
"overline_color": True,
"strikethrough_color": True,
}
MARKUP_TAGS = {
"b": {"font_weight": "bold"},
"big": {"font_size": "larger"},
"i": {"font_style": "italic"},
"s": {"strikethrough": "true"},
"sub": {"baseline_shift": "subscript", "font_scale": "subscript"},
"sup": {"baseline_shift": "superscript", "font_scale": "superscript"},
"small": {"font_size": "smaller"},
"tt": {"font_family": "monospace"},
"u": {"underline": "single"},
}
def __init__(self, text: str, **kwargs):
@ -141,9 +115,7 @@ class MarkupText(LabelledString):
if not self.font:
self.font = get_customization()["style"]["font"]
if self.is_markup:
validate_error = MarkupUtils.validate(text)
if validate_error:
raise ValueError(validate_error)
self.validate_markup_string(text)
self.text = text
super().__init__(text, **kwargs)
@ -165,7 +137,6 @@ class MarkupText(LabelledString):
self.svg_default,
self.path_string_config,
self.base_color,
self.use_plain_file,
self.isolate,
self.text,
self.is_markup,
@ -174,7 +145,7 @@ class MarkupText(LabelledString):
self.justify,
self.indent,
self.alignment,
self.line_width_factor,
self.line_width,
self.font,
self.slant,
self.weight,
@ -201,23 +172,32 @@ class MarkupText(LabelledString):
kwargs[short_name] = kwargs.pop(long_name)
def get_file_path_by_content(self, content: str) -> str:
hash_content = str((
content,
self.justify,
self.indent,
self.alignment,
self.line_width
))
svg_file = os.path.join(
get_text_dir(), tex_hash(content) + ".svg"
get_text_dir(), tex_hash(hash_content) + ".svg"
)
if not os.path.exists(svg_file):
self.markup_to_svg(content, svg_file)
return svg_file
def markup_to_svg(self, markup_str: str, file_name: str) -> str:
self.validate_markup_string(markup_str)
# `manimpango` is under construction,
# so the following code is intended to suit its interface
alignment = _Alignment(self.alignment)
if self.line_width_factor is None:
if self.line_width is None:
pango_width = -1
else:
pango_width = self.line_width_factor * DEFAULT_PIXEL_WIDTH
pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH
return MarkupUtils.text2svg(
return manimpango.MarkupUtils.text2svg(
text=markup_str,
font="", # Already handled
slant="NORMAL", # Already handled
@ -228,8 +208,8 @@ class MarkupText(LabelledString):
file_name=file_name,
START_X=0,
START_Y=0,
width=DEFAULT_PIXEL_WIDTH,
height=DEFAULT_PIXEL_HEIGHT,
width=DEFAULT_CANVAS_WIDTH,
height=DEFAULT_CANVAS_HEIGHT,
justify=self.justify,
indent=self.indent,
line_spacing=None, # Already handled
@ -237,294 +217,173 @@ class MarkupText(LabelledString):
pango_width=pango_width
)
def pre_parse(self) -> None:
super().pre_parse()
self.tag_items_from_markup = self.get_tag_items_from_markup()
self.global_dict_from_config = self.get_global_dict_from_config()
self.local_dicts_from_markup = self.get_local_dicts_from_markup()
self.local_dicts_from_config = self.get_local_dicts_from_config()
self.predefined_attr_dicts = self.get_predefined_attr_dicts()
# Toolkits
@staticmethod
def get_attr_dict_str(attr_dict: dict[str, str]) -> str:
return " ".join([
f"{key}='{val}'"
for key, val in attr_dict.items()
])
@staticmethod
def merge_attr_dicts(
attr_dict_items: list[Span, str, typing.Any]
) -> list[tuple[Span, dict[str, str]]]:
index_seq = [0]
attr_dict_list = [{}]
for span, attr_dict in attr_dict_items:
if span[0] >= span[1]:
continue
region_indices = [
MarkupText.find_region_index(index_seq, index)
for index in span
]
for flag in (1, 0):
if index_seq[region_indices[flag]] == span[flag]:
continue
region_index = region_indices[flag]
index_seq.insert(region_index + 1, span[flag])
attr_dict_list.insert(
region_index + 1, attr_dict_list[region_index].copy()
)
region_indices[flag] += 1
if flag == 0:
region_indices[1] += 1
for key, val in attr_dict.items():
if not key:
continue
for mid_dict in attr_dict_list[slice(*region_indices)]:
mid_dict[key] = val
return list(zip(
MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1]
))
def find_substr_or_span(
self, substr_or_span: str | tuple[int | None, int | None]
) -> list[Span]:
if isinstance(substr_or_span, str):
return self.find_substr(substr_or_span)
span = tuple([
(
min(index, self.string_len)
if index >= 0
else max(index + self.string_len, 0)
)
if index is not None else default_index
for index, default_index in zip(substr_or_span, self.full_span)
])
if span[0] >= span[1]:
return []
return [span]
# Pre-parsing
def get_tag_items_from_markup(
self
) -> list[tuple[Span, Span, dict[str, str]]]:
if not self.is_markup:
return []
tag_pattern = r"""<(/?)(\w+)\s*((?:\w+\s*\=\s*(['"]).*?\4\s*)*)>"""
attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2"""
begin_match_obj_stack = []
match_obj_pairs = []
for match_obj in self.finditer(tag_pattern):
if not match_obj.group(1):
begin_match_obj_stack.append(match_obj)
else:
match_obj_pairs.append(
(begin_match_obj_stack.pop(), match_obj)
)
if begin_match_obj_stack:
raise ValueError("Unclosed tag(s) detected")
result = []
for begin_match_obj, end_match_obj in match_obj_pairs:
tag_name = begin_match_obj.group(2)
if tag_name != end_match_obj.group(2):
raise ValueError("Unmatched tag names")
if end_match_obj.group(3):
raise ValueError("Attributes shan't exist in ending tags")
if tag_name == "span":
attr_dict = {
match.group(1): match.group(3)
for match in re.finditer(
attr_pattern, begin_match_obj.group(3)
)
}
elif tag_name in TAG_TO_ATTR_DICT.keys():
if begin_match_obj.group(3):
raise ValueError(
f"Attributes shan't exist in tag '{tag_name}'"
)
attr_dict = TAG_TO_ATTR_DICT[tag_name].copy()
else:
raise ValueError(f"Unknown tag: '{tag_name}'")
result.append(
(begin_match_obj.span(), end_match_obj.span(), attr_dict)
)
return result
def get_global_dict_from_config(self) -> dict[str, typing.Any]:
result = {
"line_height": (
(self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1
) * 0.6,
"font_family": self.font,
"font_size": self.font_size * 1024,
"font_style": self.slant,
"font_weight": self.weight
}
result.update(self.global_config)
return result
def get_local_dicts_from_markup(
self
) -> list[Span, dict[str, str]]:
return sorted([
((begin_tag_span[0], end_tag_span[1]), attr_dict)
for begin_tag_span, end_tag_span, attr_dict
in self.tag_items_from_markup
])
def get_local_dicts_from_config(
self
) -> list[Span, dict[str, typing.Any]]:
return [
(span, {key: val})
for t2x_dict, key in (
(self.t2c, "foreground"),
(self.t2f, "font_family"),
(self.t2s, "font_style"),
(self.t2w, "font_weight")
)
for substr_or_span, val in t2x_dict.items()
for span in self.find_substr_or_span(substr_or_span)
] + [
(span, local_config)
for substr_or_span, local_config in self.local_configs.items()
for span in self.find_substr_or_span(substr_or_span)
]
def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]:
attr_dict_items = [
(self.full_span, self.global_dict_from_config),
*self.local_dicts_from_markup,
*self.local_dicts_from_config
]
return [
(span, {
SPAN_ATTR_KEY_CONVERSION[key.lower()]: str(val)
for key, val in attr_dict.items()
})
for span, attr_dict in attr_dict_items
]
def validate_markup_string(markup_str: str) -> None:
validate_error = manimpango.MarkupUtils.validate(markup_str)
if not validate_error:
return
raise ValueError(
f"Invalid markup string \"{markup_str}\"\n"
f"{validate_error}"
)
# Parsing
def get_command_repl_items(self) -> list[tuple[Span, str]]:
result = [
(tag_span, "")
for begin_tag, end_tag, _ in self.tag_items_from_markup
for tag_span in (begin_tag, end_tag)
]
def get_cmd_spans(self) -> list[Span]:
if not self.is_markup:
result += [
(span, escaped)
for char, escaped in (
("&", "&amp;"),
(">", "&gt;"),
("<", "&lt;")
)
for span in self.find_substr(char)
]
return result
return self.find_spans(r"""[<>&"']""")
def get_extra_entity_spans(self) -> list[Span]:
if not self.is_markup:
return []
return self.find_spans(r"&.*?;")
def get_extra_ignored_spans(self) -> list[int]:
return []
def get_internal_specified_spans(self) -> list[Span]:
return [span for span, _ in self.local_dicts_from_markup]
def get_external_specified_spans(self) -> list[Span]:
return [span for span, _ in self.local_dicts_from_config]
def get_label_span_list(self) -> list[Span]:
breakup_indices = remove_list_redundancies(list(it.chain(*it.chain(
self.find_spans(r"\s+"),
self.find_spans(r"\b"),
self.specified_spans
))))
breakup_indices = sorted(filter(
lambda index: not any([
span[0] < index < span[1]
for span in self.entity_spans
]),
breakup_indices
))
return list(filter(
lambda span: self.get_substr(span).strip(),
self.get_neighbouring_pairs(breakup_indices)
))
def get_content(self, use_plain_file: bool) -> str:
if use_plain_file:
attr_dict_items = [
(self.full_span, {"foreground": self.base_color}),
*self.predefined_attr_dicts,
*[
(span, {})
for span in self.label_span_list
]
]
else:
attr_dict_items = [
(self.full_span, {"foreground": BLACK}),
*[
(span, {
key: BLACK if key in COLOR_RELATED_KEYS else val
for key, val in attr_dict.items()
})
for span, attr_dict in self.predefined_attr_dicts
],
*[
(span, {"foreground": self.int_to_hex(label + 1)})
for label, span in enumerate(self.label_span_list)
]
]
inserted_string_pairs = [
(span, (
f"<span {self.get_attr_dict_str(attr_dict)}>",
"</span>"
))
for span, attr_dict in self.merge_attr_dicts(attr_dict_items)
]
span_repl_dict = self.generate_span_repl_dict(
inserted_string_pairs, self.command_repl_items
# Unsupported passthroughs:
# "<?...?>", "<!--...-->", "<![CDATA[...]]>", "<!DOCTYPE...>"
# See https://gitlab.gnome.org/GNOME/glib/-/blob/main/glib/gmarkup.c
return self.find_spans(
r"""&[\s\S]*?;|[>"']|</?\w+(?:\s*\w+\s*\=\s*(["'])[\s\S]*?\1)*/?>"""
)
return self.get_replaced_substr(self.full_span, span_repl_dict)
@property
def has_predefined_local_colors(self) -> bool:
return any([
key in COLOR_RELATED_KEYS
for _, attr_dict in self.predefined_attr_dicts
for key in attr_dict.keys()
def get_substr_flag(self, substr: str) -> int:
if re.fullmatch(r"<\w[\s\S]*[^/]>", substr):
return 1
if substr.startswith("</"):
return -1
return 0
def get_repl_substr_for_content(self, substr: str) -> str:
if substr.startswith("<") and substr.endswith(">"):
return ""
return {
"<": "&lt;",
">": "&gt;",
"&": "&amp;",
"\"": "&quot;",
"'": "&apos;"
}.get(substr, substr)
def get_repl_substr_for_matching(self, substr: str) -> str:
if substr.startswith("<") and substr.endswith(">"):
return ""
if substr.startswith("&#") and substr.endswith(";"):
if substr.startswith("&#x"):
char_reference = int(substr[3:-1], 16)
else:
char_reference = int(substr[2:-1], 10)
return chr(char_reference)
return {
"&lt;": "<",
"&gt;": ">",
"&amp;": "&",
"&quot;": "\"",
"&apos;": "'"
}.get(substr, substr)
def get_specified_items(
self, cmd_span_pairs: list[tuple[Span, Span]]
) -> list[tuple[Span, dict[str, str]]]:
attr_pattern = r"""(\w+)\s*\=\s*(["'])([\s\S]*?)\2"""
internal_items = []
for begin_cmd_span, end_cmd_span in cmd_span_pairs:
begin_tag = self.get_substr(begin_cmd_span)
tag_name = re.match(r"<(\w+)", begin_tag).group(1)
if tag_name == "span":
attr_dict = {
attr_match_obj.group(1): attr_match_obj.group(3)
for attr_match_obj in re.finditer(attr_pattern, begin_tag)
}
else:
attr_dict = MarkupText.MARKUP_TAGS.get(tag_name, {})
internal_items.append(
((begin_cmd_span[1], end_cmd_span[0]), attr_dict)
)
return [
*internal_items,
*[
(span, {key: val})
for t2x_dict, key in (
(self.t2c, "foreground"),
(self.t2f, "font_family"),
(self.t2s, "font_style"),
(self.t2w, "font_weight")
)
for selector, val in t2x_dict.items()
for span in self.find_spans_by_selector(selector)
],
*[
(span, local_config)
for selector, local_config in self.local_configs.items()
for span in self.find_spans_by_selector(selector)
],
*[
(span, {})
for span in self.find_spans_by_selector(self.isolate)
]
]
@staticmethod
def get_cmd_str_pair(
attr_dict: dict[str, str], label_hex: str | None
) -> tuple[str, str]:
if label_hex is not None:
converted_attr_dict = {"foreground": label_hex}
for key, val in attr_dict.items():
substitute_key = MarkupText.MARKUP_COLOR_KEYS.get(key, None)
if substitute_key is None:
converted_attr_dict[key] = val
elif substitute_key:
converted_attr_dict[key] = "black"
else:
converted_attr_dict = attr_dict.copy()
attrs_str = " ".join([
f"{key}='{val}'"
for key, val in converted_attr_dict.items()
])
return f"<span {attrs_str}>", "</span>"
def get_content_prefix_and_suffix(
self, is_labelled: bool
) -> tuple[str, str]:
global_attr_dict = {
"foreground": self.base_color_hex,
"font_family": self.font,
"font_style": self.slant,
"font_weight": self.weight,
"font_size": str(self.font_size * 1024),
}
global_attr_dict.update(self.global_config)
# `line_height` attribute is supported since Pango 1.50.
pango_version = manimpango.pango_version()
if tuple(map(int, pango_version.split("."))) < (1, 50):
if self.lsh is not None:
log.warning(
"Pango version %s found (< 1.50), "
"unable to set `line_height` attribute",
pango_version
)
else:
line_spacing_scale = self.lsh or DEFAULT_LINE_SPACING_SCALE
global_attr_dict["line_height"] = str(
((line_spacing_scale) + 1) * 0.6
)
return self.get_cmd_str_pair(
global_attr_dict,
label_hex=self.int_to_hex(0) if is_labelled else None
)
# Method alias
def get_parts_by_text(self, text: str, **kwargs) -> VGroup:
return self.get_parts_by_string(text, **kwargs)
def get_parts_by_text(self, selector: Selector) -> VGroup:
return self.select_parts(selector)
def get_part_by_text(self, text: str, **kwargs) -> VMobject:
return self.get_part_by_string(text, **kwargs)
def get_part_by_text(self, selector: Selector) -> VGroup:
return self.select_part(selector)
def set_color_by_text(self, text: str, color: ManimColor, **kwargs):
return self.set_color_by_string(text, color, **kwargs)
def set_color_by_text(self, selector: Selector, color: ManimColor):
return self.set_parts_color(selector, color)
def set_color_by_text_to_color_map(
self, text_to_color_map: dict[str, ManimColor], **kwargs
self, color_map: dict[Selector, ManimColor]
):
return self.set_color_by_string_to_color_map(
text_to_color_map, **kwargs
)
return self.set_parts_color_by_dict(color_map)
def get_text(self) -> str:
return self.get_string()

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import itertools as it
import numpy as np
import pyperclip

View file

@ -13,7 +13,7 @@ if TYPE_CHECKING:
S = TypeVar("S")
def remove_list_redundancies(l: Iterable[T]) -> list[T]:
def remove_list_redundancies(l: Sequence[T]) -> list[T]:
"""
Used instead of list(set(l)) to maintain order
Keeps the last occurrence of each element
@ -40,14 +40,14 @@ def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]:
return [e for e in l1 if e not in l2]
def adjacent_n_tuples(objects: Iterable[T], n: int) -> zip[tuple[T, T]]:
def adjacent_n_tuples(objects: Sequence[T], n: int) -> zip[tuple[T, T]]:
return zip(*[
[*objects[k:], *objects[:k]]
for k in range(n)
])
def adjacent_pairs(objects: Iterable[T]) -> zip[tuple[T, T]]:
def adjacent_pairs(objects: Sequence[T]) -> zip[tuple[T, T]]:
return adjacent_n_tuples(objects, 2)