diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py
index 90ffa76f..486007dd 100644
--- a/manimlib/animation/transform_matching_parts.py
+++ b/manimlib/animation/transform_matching_parts.py
@@ -12,7 +12,7 @@ from manimlib.animation.transform import ReplacementTransform
from manimlib.animation.transform import Transform
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
-from manimlib.mobject.svg.mtex_mobject import MTex
+from manimlib.mobject.svg.mtex_mobject import LabelledString
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config
@@ -153,15 +153,16 @@ class TransformMatchingTex(TransformMatchingParts):
return mobject.get_tex()
-class TransformMatchingMTex(AnimationGroup):
+class TransformMatchingString(AnimationGroup):
CONFIG = {
"key_map": dict(),
+ "transform_mismatches_class": None,
}
- def __init__(self, source_mobject: MTex, target_mobject: MTex, **kwargs):
+ def __init__(self, source_mobject: LabelledString, target_mobject: LabelledString, **kwargs):
digest_config(self, kwargs)
- assert isinstance(source_mobject, MTex)
- assert isinstance(target_mobject, MTex)
+ assert isinstance(source_mobject, LabelledString)
+ assert isinstance(target_mobject, LabelledString)
anims = []
rest_source_submobs = source_mobject.submobjects.copy()
rest_target_submobs = target_mobject.submobjects.copy()
@@ -207,7 +208,7 @@ class TransformMatchingMTex(AnimationGroup):
elif isinstance(key, range):
indices.extend(key)
elif isinstance(key, str):
- all_parts = mobject.get_parts_by_tex(key)
+ all_parts = mobject.get_parts_by_string(key)
indices.extend(it.chain(*[
mobject.indices_of_part(part) for part in all_parts
]))
@@ -228,31 +229,34 @@ class TransformMatchingMTex(AnimationGroup):
target_mobject.get_specified_substrings()
)
), key=len, reverse=True)
- for part_tex_string in common_specified_substrings:
+ for part_string in common_specified_substrings:
add_anim_from(
- FadeTransformPieces, MTex.get_parts_by_tex, part_tex_string
+ FadeTransformPieces, LabelledString.get_parts_by_string, part_string
)
- common_submob_tex_strings = {
- source_submob.get_tex() for source_submob in source_mobject
+ common_submob_strings = {
+ source_submob.get_string() for source_submob in source_mobject
}.intersection({
- target_submob.get_tex() for target_submob in target_mobject
+ target_submob.get_string() for target_submob in target_mobject
})
- for tex_string in common_submob_tex_strings:
+ for substr in common_submob_strings:
add_anim_from(
FadeTransformPieces,
lambda mobject, attr: VGroup(*[
VGroup(mob) for mob in mobject
- if mob.get_tex() == attr
+ if mob.get_string() == attr
]),
- tex_string
+ substr
)
- anims.append(FadeOutToPoint(
- VGroup(*rest_source_submobs), target_mobject.get_center(), **kwargs
- ))
- anims.append(FadeInFromPoint(
- VGroup(*rest_target_submobs), source_mobject.get_center(), **kwargs
- ))
+ if self.transform_mismatches_class is not None:
+ anims.append(self.transform_mismatches_class(fade_source, fade_target, **kwargs))
+ else:
+ anims.append(FadeOutToPoint(
+ VGroup(*rest_source_submobs), target_mobject.get_center(), **kwargs
+ ))
+ anims.append(FadeInFromPoint(
+ VGroup(*rest_target_submobs), source_mobject.get_center(), **kwargs
+ ))
super().__init__(*anims)
diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py
index a14004e7..efc542b2 100644
--- a/manimlib/mobject/svg/mtex_mobject.py
+++ b/manimlib/mobject/svg/mtex_mobject.py
@@ -5,8 +5,9 @@ import colour
import itertools as it
from types import MethodType
from typing import Iterable, Union, Sequence
+from abc import abstractmethod
-from manimlib.constants import BLACK, WHITE
+from manimlib.constants import WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_int_rgb
@@ -24,17 +25,15 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.types.vectorized_mobject import VMobject
ManimColor = Union[str, colour.Color, Sequence[float]]
+ Span = tuple[int, int]
SCALE_FACTOR_PER_FONT_POINT = 0.001
-class _TexSVG(SVGMobject):
+class _StringSVG(SVGMobject):
CONFIG = {
"height": None,
- "svg_default": {
- "fill_color": BLACK,
- },
"stroke_width": 0,
"stroke_color": WHITE,
"path_string_config": {
@@ -44,75 +43,29 @@ class _TexSVG(SVGMobject):
}
-class MTex(_TexSVG):
+class LabelledString(_StringSVG):
+ """
+ An abstract base class for `MTex` and `MarkupText`
+ """
CONFIG = {
"base_color": WHITE,
- "font_size": 48,
- "alignment": "\\centering",
- "tex_environment": "align*",
- "isolate": [],
- "tex_to_color_map": {},
- "use_plain_tex": False,
+ "use_plain_file": False,
}
def __init__(self, string: str, **kwargs):
- digest_config(self, kwargs)
- string = string.strip()
- # Prevent from passing an empty string.
- if not string:
- string = "\\quad"
- self.tex_string = string
self.string = string
super().__init__(**kwargs)
- self.set_color_by_tex_to_color_map(self.tex_to_color_map)
- self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
-
- @property
- def hash_seed(self) -> tuple:
- return (
- self.__class__.__name__,
- self.svg_default,
- self.path_string_config,
- self.string,
- self.base_color,
- self.alignment,
- self.tex_environment,
- self.isolate,
- self.tex_to_color_map,
- self.use_plain_tex
- )
-
def get_file_path(self, use_plain_file: bool = False) -> str:
if use_plain_file:
content = self.plain_string
else:
content = self.labelled_string
+ return self.get_file_path_by_content(content)
- full_tex = self.get_tex_file_body(content)
- with display_during_execution(f"Writing \"{self.string}\""):
- file_path = self.tex_to_svg_file_path(full_tex)
- return file_path
-
- def get_tex_file_body(self, content: str) -> str:
- if self.tex_environment:
- content = "\n".join([
- f"\\begin{{{self.tex_environment}}}",
- content,
- f"\\end{{{self.tex_environment}}}"
- ])
- if self.alignment:
- content = "\n".join([self.alignment, content])
-
- tex_config = get_tex_config()
- return tex_config["tex_body"].replace(
- tex_config["text_to_replace"],
- content
- )
-
- @staticmethod
- def tex_to_svg_file_path(tex_file_content: str) -> str:
- return tex_to_svg_file(tex_file_content)
+ @abstractmethod
+ def get_file_path_by_content(self, content: str) -> str:
+ return ""
def generate_mobject(self) -> None:
super().generate_mobject()
@@ -125,13 +78,9 @@ class MTex(_TexSVG):
for glyph in self.submobjects
]
- if any([
- self.use_plain_tex,
- self.color_cmd_repl_items,
- self.base_color in (BLACK, WHITE)
- ]):
+ if self.use_plain_file or self.has_predefined_colors:
file_path = self.get_file_path(use_plain_file=True)
- glyphs = _TexSVG(file_path).submobjects
+ glyphs = _StringSVG(file_path).submobjects
for glyph, plain_glyph in zip(self.submobjects, glyphs):
glyph.set_fill(plain_glyph.get_fill_color())
else:
@@ -142,21 +91,19 @@ class MTex(_TexSVG):
submob_labels, glyphs_lists = self.group_neighbours(
glyph_labels, glyphs
)
- submobjects = [
- VGroup(*glyph_list)
- for glyph_list in glyphs_lists
- ]
- submob_tex_strings = self.get_submob_tex_strings(submob_labels)
- for submob, label, submob_tex in zip(
- submobjects, submob_labels, submob_tex_strings
+ submob_strings = self.get_submob_strings(submob_labels)
+ submobjects = []
+ for glyph_list, label, submob_string in zip(
+ glyphs_lists, submob_labels, submob_strings
):
+ submob = VGroup(*glyph_list)
submob.label = label
- submob.tex_string = submob_tex
- # Support `get_tex()` method here.
- submob.get_tex = MethodType(lambda inst: inst.tex_string, submob)
+ submob.string = submob_string
+ submob.get_string = MethodType(lambda inst: inst.string, submob)
+ submobjects.append(submob)
self.set_submobjects(submobjects)
- ## Static methods
+ # Toolkits
@staticmethod
def color_to_label(color: ManimColor) -> int:
@@ -167,23 +114,14 @@ class MTex(_TexSVG):
return -1
return rgb
- @staticmethod
- def get_color_command(label: int) -> str:
- if label == -1:
- label = 16777215 # white
- rg, b = divmod(label, 256)
- r, g = divmod(rg, 256)
- return "".join([
- "\\color[RGB]",
- "{",
- ",".join(map(str, (r, g, b))),
- "}"
- ])
-
@staticmethod
def get_neighbouring_pairs(iterable: Iterable) -> list:
return list(adjacent_pairs(iterable))[:-1]
+ @staticmethod
+ def span_contains(span_0: Span, span_1: Span) -> bool:
+ return span_0[0] <= span_1[0] and span_0[1] >= span_1[1]
+
@staticmethod
def group_neighbours(
labels: Iterable[object],
@@ -211,79 +149,412 @@ class MTex(_TexSVG):
@staticmethod
def find_region_index(val: int, seq: list[int]) -> int:
- # Returns an integer in `range(len(seq) + 1)` satisfying
- # `seq[result - 1] <= val < seq[result]`
- if not seq:
- return 0
- if val >= seq[-1]:
- return len(seq)
- result = 0
- while val >= seq[result]:
- result += 1
+ # Returns an integer in `range(-1, len(seq))` satisfying
+ # `seq[result] <= val < seq[result + 1]`.
+ # `seq` should be sorted in ascending order.
+ if not seq or val < seq[0]:
+ return -1
+ result = len(seq) - 1
+ while val < seq[result]:
+ result -= 1
return result
@staticmethod
- def lstrip(index: int, skipped_spans: list[tuple[int, int]]) -> int:
- index_seq = list(it.chain(*skipped_spans))
- region_index = MTex.find_region_index(index, index_seq)
- if region_index % 2 == 1:
+ def replace_str_by_spans(
+ substr: str, span_repl_dict: dict[Span, str]
+ ) -> str:
+ if not span_repl_dict:
+ return substr
+
+ spans = sorted(span_repl_dict.keys())
+ if not all(
+ span_0[1] <= span_1[0]
+ for span_0, span_1 in LabelledString.get_neighbouring_pairs(spans)
+ ):
+ raise ValueError("Overlapping replacement")
+
+ span_ends, span_begins = zip(*spans)
+ pieces = [
+ substr[slice(*span)]
+ for span in zip(
+ (0, *span_begins),
+ (*span_ends, len(substr))
+ )
+ ]
+ repl_strs = [*[span_repl_dict[span] for span in spans], ""]
+ return "".join(it.chain(*zip(pieces, repl_strs)))
+
+ @staticmethod
+ def get_span_replacement_dict(
+ inserted_string_pairs: list[tuple[Span, tuple[str, str]]],
+ other_repl_items: list[tuple[Span, str]]
+ ) -> dict[Span, str]:
+ if not inserted_string_pairs:
+ return other_repl_items.copy()
+
+ indices, _, _, inserted_strings = zip(*sorted([
+ (
+ span[flag],
+ -flag,
+ -span[1 - flag],
+ str_pair[flag]
+ )
+ for span, str_pair in inserted_string_pairs
+ for flag in range(2)
+ ]))
+ result = {
+ (index, index): "".join(inserted_strs)
+ for index, inserted_strs in zip(*LabelledString.group_neighbours(
+ indices, inserted_strings
+ ))
+ }
+ result.update(other_repl_items)
+ return result
+
+ @property
+ def skipped_spans(self) -> list[Span]:
+ return []
+
+ def lstrip(self, index: int) -> int:
+ index_seq = list(it.chain(*self.skipped_spans))
+ region_index = self.find_region_index(index, index_seq)
+ if region_index % 2 == 0:
+ return index_seq[region_index + 1]
+ return index
+
+ def rstrip(self, index: int) -> int:
+ index_seq = list(it.chain(*self.skipped_spans))
+ region_index = self.find_region_index(index - 1, index_seq)
+ if region_index % 2 == 0:
return index_seq[region_index]
return index
- @staticmethod
- def rstrip(index: int, skipped_spans: list[tuple[int, int]]) -> int:
- index_seq = list(it.chain(*skipped_spans))
- region_index = MTex.find_region_index(index - 1, index_seq)
- if region_index % 2 == 1:
- return index_seq[region_index - 1]
- return index
-
- @staticmethod
- def strip(
- tex_span: tuple[int, int], skipped_spans: list[tuple[int, int]]
- ) -> tuple[int, int] | None:
+ def strip(self, span: Span) -> Span | None:
result = (
- MTex.lstrip(tex_span[0], skipped_spans),
- MTex.rstrip(tex_span[1], skipped_spans)
+ self.lstrip(span[0]),
+ self.rstrip(span[1])
)
if result[0] >= result[1]:
return None
return result
@staticmethod
- def lslide(index: int, slid_spans: list[tuple[int, int]]) -> int:
- slide_dict = dict(slid_spans)
+ def lslide(index: int, slid_spans: list[Span]) -> int:
+ slide_dict = dict(sorted(slid_spans))
while index in slide_dict.keys():
index = slide_dict[index]
return index
@staticmethod
- def rslide(index: int, slid_spans: list[tuple[int, int]]) -> int:
- slide_dict = dict([
+ def rslide(index: int, slid_spans: list[Span]) -> int:
+ slide_dict = dict(sorted([
slide_span[::-1] for slide_span in slid_spans
- ])
+ ], reverse=True))
while index in slide_dict.keys():
index = slide_dict[index]
return index
@staticmethod
- def slide(
- tex_span: tuple[int, int], slid_spans: list[tuple[int, int]]
- ) -> tuple[int, int] | None:
+ def slide(span: Span, slid_spans: list[Span]) -> Span | None:
result = (
- MTex.lslide(tex_span[0], slid_spans),
- MTex.rslide(tex_span[1], slid_spans)
+ LabelledString.lslide(span[0], slid_spans),
+ LabelledString.rslide(span[1], slid_spans)
)
if result[0] >= result[1]:
return None
return result
- ## Parser
+ # Parser
@property
- def full_span(self) -> tuple[int, int]:
+ def full_span(self) -> Span:
return (0, len(self.string))
+ def get_substrs_to_isolate(self, substrs: list[str]) -> list[str]:
+ result = list(filter(
+ lambda s: s in self.string,
+ remove_list_redundancies(substrs)
+ ))
+ if "" in result:
+ result.remove("")
+ return result
+
+ @property
+ def label_span_list(self) -> list[Span]:
+ return []
+
+ @property
+ def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]:
+ return []
+
+ @property
+ def command_repl_items(self) -> list[tuple[Span, str]]:
+ return []
+
+ @abstractmethod
+ def has_predefined_colors(self) -> bool:
+ return False
+
+ @property
+ def plain_string(self) -> str:
+ return self.string
+
+ @property
+ def labelled_string(self) -> str:
+ return self.replace_str_by_spans(
+ self.string, self.get_span_replacement_dict(
+ self.inserted_string_pairs,
+ self.command_repl_items
+ )
+ )
+
+ @property
+ def ignored_indices_for_submob_strings(self) -> list[int]:
+ return []
+
+ def handle_submob_string(self, substr: str, string_span: Span) -> str:
+ return substr
+
+ def get_submob_strings(self, submob_labels: list[int]) -> list[str]:
+ ordered_spans = [
+ self.label_span_list[label] if label != -1 else self.full_span
+ for label in submob_labels
+ ]
+ ordered_containing_labels = [
+ self.containing_labels_dict[span]
+ for span in ordered_spans
+ ]
+ ordered_span_begins, ordered_span_ends = zip(*ordered_spans)
+ string_span_begins = [
+ prev_end if prev_label in containing_labels else curr_begin
+ for prev_end, prev_label, containing_labels, curr_begin in zip(
+ ordered_span_ends[:-1], submob_labels[:-1],
+ ordered_containing_labels[1:], ordered_span_begins[1:]
+ )
+ ]
+ string_span_ends = [
+ next_begin if next_label in containing_labels else curr_end
+ for next_begin, next_label, containing_labels, curr_end in zip(
+ ordered_span_begins[1:], submob_labels[1:],
+ ordered_containing_labels[:-1], ordered_span_ends[:-1]
+ )
+ ]
+ string_spans = list(zip(
+ (ordered_span_begins[0], *string_span_begins),
+ (*string_span_ends, ordered_span_ends[-1])
+ ))
+
+ command_spans = [span for span, _ in self.command_repl_items]
+ slid_spans = list(it.chain(
+ self.skipped_spans,
+ command_spans,
+ [
+ (index, index + 1)
+ for index in self.ignored_indices_for_submob_strings
+ ]
+ ))
+ result = []
+ for string_span in string_spans:
+ string_span = self.slide(string_span, slid_spans)
+ if string_span is None:
+ result.append("")
+ continue
+
+ span_repl_dict = {
+ tuple([index - string_span[0] for index in cmd_span]): ""
+ for cmd_span in command_spans
+ if self.span_contains(string_span, cmd_span)
+ }
+ substr = self.string[slice(*string_span)]
+ substr = self.replace_str_by_spans(substr, span_repl_dict)
+ substr = self.handle_submob_string(substr, string_span)
+ result.append(substr)
+ return result
+
+ # Selector
+
+ @property
+ def containing_labels_dict(self) -> dict[Span, list[int]]:
+ label_span_list = self.label_span_list
+ result = {
+ span: []
+ for span in label_span_list
+ }
+ for span_0 in label_span_list:
+ for span_index, span_1 in enumerate(label_span_list):
+ if self.span_contains(span_0, span_1):
+ result[span_0].append(span_index)
+ elif span_0[0] < span_1[0] < span_0[1] < span_1[1]:
+ string_0, string_1 = [
+ self.string[slice(*span)]
+ for span in [span_0, span_1]
+ ]
+ raise ValueError(
+ "Partially overlapping substrings detected: "
+ f"'{string_0}' and '{string_1}'"
+ )
+ result[self.full_span] = list(range(-1, len(label_span_list)))
+ return result
+
+ def find_span_components_of_custom_span(
+ self, custom_span: Span
+ ) -> list[Span] | None:
+ span_choices = sorted(filter(
+ lambda span: self.span_contains(custom_span, span),
+ self.label_span_list
+ ))
+ # Choose spans that reach the farthest.
+ span_choices_dict = dict(span_choices)
+
+ result = []
+ span_begin, span_end = custom_span
+ span_begin = self.rstrip(span_begin)
+ span_end = self.rstrip(span_end)
+ while span_begin != span_end:
+ span_begin = self.lstrip(span_begin)
+ if span_begin not in span_choices_dict.keys():
+ return None
+ next_begin = span_choices_dict[span_begin]
+ result.append((span_begin, next_begin))
+ span_begin = next_begin
+ return result
+
+ def get_part_by_custom_span(self, custom_span: Span) -> VGroup:
+ spans = self.find_span_components_of_custom_span(custom_span)
+ if spans is None:
+ substr = self.string[slice(*custom_span)]
+ raise ValueError(f"Failed to match mobjects from \"{substr}\"")
+
+ labels = set(it.chain(*[
+ self.containing_labels_dict[span]
+ for span in spans
+ ]))
+ return VGroup(*filter(
+ lambda submob: submob.label in labels,
+ self.submobjects
+ ))
+
+ def get_parts_by_string(self, substr: str) -> VGroup:
+ return VGroup(*[
+ self.get_part_by_custom_span(match_obj.span())
+ for match_obj in re.finditer(re.escape(substr), self.string)
+ ])
+
+ def get_part_by_string(self, substr: str, index: int = 0) -> VMobject:
+ all_parts = self.get_parts_by_string(substr)
+ return all_parts[index]
+
+ def set_color_by_string(self, substr: str, color: ManimColor):
+ self.get_parts_by_string(substr).set_color(color)
+ return self
+
+ def set_color_by_string_to_color_map(
+ self, string_to_color_map: dict[str, ManimColor]
+ ):
+ for substr, color in string_to_color_map.items():
+ self.set_color_by_string(substr, color)
+ return self
+
+ def indices_of_part(self, part: Iterable[VMobject]) -> list[int]:
+ indices = [
+ index for index, submob in enumerate(self.submobjects)
+ if submob in part
+ ]
+ if not indices:
+ raise ValueError("Failed to find part")
+ return indices
+
+ def indices_of_part_by_string(
+ self, substr: str, index: int = 0
+ ) -> list[int]:
+ part = self.get_part_by_string(substr, index=index)
+ return self.indices_of_part(part)
+
+ @property
+ def specified_substrings(self) -> list[str]:
+ return []
+
+ def get_specified_substrings(self) -> list[str]:
+ return self.specified_substrings
+
+ @property
+ def isolated_substrings(self) -> list[str]:
+ return remove_list_redundancies([
+ self.string[slice(*span)]
+ for span in self.label_span_list
+ ])
+
+ def get_isolated_substrings(self) -> list[str]:
+ return self.isolated_substrings
+
+ def get_string(self) -> str:
+ return self.string
+
+
+class MTex(LabelledString):
+ CONFIG = {
+ "font_size": 48,
+ "alignment": "\\centering",
+ "tex_environment": "align*",
+ "isolate": [],
+ "tex_to_color_map": {},
+ "use_plain_file": False,
+ }
+
+ def __init__(self, tex_string: str, **kwargs):
+ tex_string = tex_string.strip()
+ # Prevent from passing an empty string.
+ if not tex_string:
+ tex_string = "\\quad"
+ self.tex_string = tex_string
+ super().__init__(tex_string, **kwargs)
+
+ self.set_color_by_tex_to_color_map(self.tex_to_color_map)
+ self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)
+
+ @property
+ def hash_seed(self) -> tuple:
+ return (
+ self.__class__.__name__,
+ self.svg_default,
+ self.path_string_config,
+ self.tex_string,
+ self.base_color,
+ self.alignment,
+ self.tex_environment,
+ self.isolate,
+ self.tex_to_color_map,
+ self.use_plain_file
+ )
+
+ def get_file_path_by_content(self, content: str) -> str:
+ full_tex = self.get_tex_file_body(content)
+ with display_during_execution(f"Writing \"{self.string}\""):
+ file_path = self.tex_to_svg_file_path(full_tex)
+ return file_path
+
+ def get_tex_file_body(self, content: str) -> str:
+ if self.tex_environment:
+ content = "\n".join([
+ f"\\begin{{{self.tex_environment}}}",
+ content,
+ f"\\end{{{self.tex_environment}}}"
+ ])
+ if self.alignment:
+ content = "\n".join([self.alignment, content])
+
+ tex_config = get_tex_config()
+ return tex_config["tex_body"].replace(
+ tex_config["text_to_replace"],
+ content
+ )
+
+ @staticmethod
+ def tex_to_svg_file_path(tex_file_content: str) -> str:
+ return tex_to_svg_file(tex_file_content)
+
+ # Parser
+
@property
def backslash_indices(self) -> list[int]:
# Newlines (`\\`) don't count.
@@ -293,9 +564,7 @@ class MTex(_TexSVG):
if len(match_obj.group()) % 2 == 1
]
- def get_left_and_right_brace_indices(
- self
- ) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]:
+ def get_brace_indices_lists(self) -> tuple[list[Span], list[Span]]:
string = self.string
indices = list(filter(
lambda index: index - 1 not in self.backslash_indices,
@@ -322,15 +591,15 @@ class MTex(_TexSVG):
return left_brace_indices, right_brace_indices
@property
- def left_brace_indices(self) -> list[tuple[int, int]]:
- return self.get_left_and_right_brace_indices()[0]
+ def left_brace_indices(self) -> list[Span]:
+ return self.get_brace_indices_lists()[0]
@property
- def right_brace_indices(self) -> list[tuple[int, int]]:
- return self.get_left_and_right_brace_indices()[1]
+ def right_brace_indices(self) -> list[Span]:
+ return self.get_brace_indices_lists()[1]
@property
- def skipped_spans(self) -> list[tuple[int, int]]:
+ def skipped_spans(self) -> list[Span]:
return [
match_obj.span()
for match_obj in re.finditer(r"\s*([_^])\s*|(\s+)", self.string)
@@ -338,24 +607,15 @@ class MTex(_TexSVG):
or match_obj.start(1) - 1 not in self.backslash_indices
]
- def lstrip_span(self, index: int) -> int:
- return self.lstrip(index, self.skipped_spans)
-
- def rstrip_span(self, index: int) -> int:
- return self.rstrip(index, self.skipped_spans)
-
- def strip_span(self, index: int) -> int:
- return self.strip(index, self.skipped_spans)
-
@property
- def script_char_spans(self) -> list[tuple[int, int]]:
+ def script_char_spans(self) -> list[Span]:
return list(filter(
- lambda tex_span: self.string[slice(*tex_span)].strip(),
+ lambda span: self.string[slice(*span)].strip(),
self.skipped_spans
))
@property
- def script_content_spans(self) -> list[tuple[int, int]]:
+ def script_content_spans(self) -> list[Span]:
result = []
brace_indices_dict = dict(zip(
self.left_brace_indices, self.right_brace_indices
@@ -380,7 +640,7 @@ class MTex(_TexSVG):
return result
@property
- def double_braces_spans(self) -> list[tuple[int, int]]:
+ def double_braces_spans(self) -> list[Span]:
# Match paired double braces (`{{...}}`).
result = []
reversed_brace_indices_dict = dict(zip(
@@ -405,19 +665,12 @@ class MTex(_TexSVG):
@property
def additional_substrings(self) -> list[str]:
- result = remove_list_redundancies(list(it.chain(
+ return self.get_substrs_to_isolate(list(it.chain(
self.tex_to_color_map.keys(),
self.isolate
)))
- if "" in result:
- result.remove("")
- return result
- def get_tex_span_lists(
- self
- ) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]:
- result = []
- extended_result = []
+ def get_label_span_list(self, extended: bool) -> list[Span]:
script_content_spans = self.script_content_spans
script_spans = [
(script_char_span[0], script_content_span[1])
@@ -425,65 +678,54 @@ class MTex(_TexSVG):
self.script_char_spans, script_content_spans
)
]
- tex_spans = remove_list_redundancies([
+ spans = remove_list_redundancies([
self.full_span,
*self.double_braces_spans,
*filter(lambda stripped_span: stripped_span is not None, [
- self.strip_span(match_obj.span())
+ self.strip(match_obj.span())
for substr in self.additional_substrings
for match_obj in re.finditer(re.escape(substr), self.string)
]),
*script_content_spans
])
- for tex_span in tex_spans:
- if tex_span in script_content_spans:
- result.append(tex_span)
- extended_result.append(tex_span)
+ result = []
+ for span in spans:
+ if span in script_content_spans:
continue
- span_begin, span_end = tex_span
- shrinked_span = (span_begin, self.rslide(span_end, script_spans))
- extended_span = (span_begin, self.lslide(span_end, script_spans))
- if shrinked_span[0] >= shrinked_span[1]:
+ span_begin, span_end = span
+ shrinked_end = self.rslide(span_end, script_spans)
+ if span_begin >= shrinked_end:
continue
+ shrinked_span = (span_begin, shrinked_end)
if shrinked_span in result:
continue
result.append(shrinked_span)
- extended_result.append(extended_span)
- return result, extended_result
+
+ if extended:
+ result = [
+ (span_begin, self.lslide(span_end, script_spans))
+ for span_begin, span_end in result
+ ]
+ return script_content_spans + result
@property
- def tex_span_list(self) -> list[tuple[int, int]]:
- return self.get_tex_span_lists()[0]
+ def label_span_list(self) -> list[Span]:
+ return self.get_label_span_list(extended=False)
@property
- def extended_tex_span_list(self) -> list[tuple[int, int]]:
- return self.get_tex_span_lists()[1]
+ def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]:
+ return [
+ (span, (
+ "{{" + self.get_color_command_by_label(label),
+ "}}"
+ ))
+ for label, span in enumerate(
+ self.get_label_span_list(extended=True)
+ )
+ ]
@property
- def containing_labels_dict(self) -> dict[tuple[int, int], list[int]]:
- tex_span_list = self.tex_span_list
- result = {
- tex_span: []
- for tex_span in tex_span_list
- }
- for span_0 in tex_span_list:
- for span_index, span_1 in enumerate(tex_span_list):
- if span_0[0] <= span_1[0] and span_1[1] <= span_0[1]:
- result[span_0].append(span_index)
- elif span_0[0] < span_1[0] < span_0[1] < span_1[1]:
- string_0, string_1 = [
- self.string[slice(*tex_span)]
- for tex_span in [span_0, span_1]
- ]
- raise ValueError(
- "Partially overlapping substrings detected: "
- f"'{string_0}' and '{string_1}'"
- )
- result[self.full_span] = list(range(-1, len(tex_span_list)))
- return result
-
- @property
- def color_cmd_repl_items(self) -> list[tuple[tuple[int, int], str]]:
+ def command_repl_items(self) -> list[tuple[Span, str]]:
color_related_command_dict = {
"color": (1, False),
"textcolor": (1, False),
@@ -509,7 +751,7 @@ class MTex(_TexSVG):
n_braces, substitute_cmd = color_related_command_dict[cmd_name]
span_end = right_brace_indices[self.find_region_index(
cmd_end, right_brace_indices
- ) + n_braces - 1] + 1
+ ) + n_braces] + 1
if substitute_cmd:
repl_str = "\\" + cmd_name + n_braces * "{white}"
else:
@@ -518,237 +760,84 @@ class MTex(_TexSVG):
return result
@property
- def span_repl_dict(self) -> dict[tuple[int, int], str]:
- indices, _, _, cmd_strings = zip(*sorted([
- (
- tex_span[flag],
- -flag,
- -tex_span[1 - flag],
- ("{{" + self.get_color_command(label), "}}")[flag]
- )
- for label, tex_span in enumerate(self.extended_tex_span_list)
- for flag in range(2)
- ]))
- result = {
- (index, index): "".join(cmd_strs)
- for index, cmd_strs in zip(*self.group_neighbours(
- indices, cmd_strings
- ))
- }
- result.update(self.color_cmd_repl_items)
- return result
+ def has_predefined_colors(self) -> bool:
+ return bool(self.command_repl_items)
+
+ @staticmethod
+ def get_color_command_by_label(label: int) -> str:
+ if label == -1:
+ label = 16777215 # white
+ rg, b = divmod(label, 256)
+ r, g = divmod(rg, 256)
+ return "".join([
+ "\\color[RGB]",
+ "{",
+ ",".join(map(str, (r, g, b))),
+ "}"
+ ])
@property
def plain_string(self) -> str:
return "".join([
"{{",
- self.get_color_command(self.color_to_label(self.base_color)),
+ self.get_color_command_by_label(
+ self.color_to_label(self.base_color)
+ ),
self.string,
"}}"
])
@property
- def labelled_string(self) -> str:
- if not self.span_repl_dict:
- return self.string
+ def ignored_indices_for_submob_strings(self) -> list[int]:
+ return self.left_brace_indices + self.right_brace_indices
- spans = sorted(self.span_repl_dict.keys())
- if not all(
- span_0[1] <= span_1[0]
- for span_0, span_1 in self.get_neighbouring_pairs(spans)
- ):
- raise ValueError("Failed to generate the labelled string")
-
- span_ends, span_begins = zip(*spans)
- string_pieces = [
- self.string[slice(*span)]
- for span in zip(
- (0, *span_begins),
- (*span_ends, len(self.string))
- )
- ]
- repl_strs = [
- self.span_repl_dict[span]
- for span in spans
- ]
- repl_strs.append("")
- return "".join(it.chain(*zip(string_pieces, repl_strs)))
-
- def get_submob_tex_strings(self, submob_labels: list[int]) -> list[str]:
- ordered_tex_spans = [
- self.tex_span_list[label] if label != -1 else self.full_span
- for label in submob_labels
- ]
- ordered_containing_labels = [
- self.containing_labels_dict[tex_span]
- for tex_span in ordered_tex_spans
- ]
- ordered_span_begins, ordered_span_ends = zip(*ordered_tex_spans)
- string_span_begins = [
- prev_end if prev_label in containing_labels else curr_begin
- for prev_end, prev_label, containing_labels, curr_begin in zip(
- ordered_span_ends[:-1], submob_labels[:-1],
- ordered_containing_labels[1:], ordered_span_begins[1:]
- )
- ]
- string_span_ends = [
- next_begin if next_label in containing_labels else curr_end
- for next_begin, next_label, containing_labels, curr_end in zip(
- ordered_span_begins[1:], submob_labels[1:],
- ordered_containing_labels[:-1], ordered_span_ends[:-1]
- )
- ]
- string_spans = list(zip(
- (ordered_span_begins[0], *string_span_begins),
- (*string_span_ends, ordered_span_ends[-1])
- ))
-
- string = self.string
- left_brace_indices = self.left_brace_indices
- right_brace_indices = self.right_brace_indices
- slid_spans = self.skipped_spans + [
- (index, index + 1)
- for index in left_brace_indices + right_brace_indices
- ]
- result = []
- for str_span in string_spans:
- str_span = self.strip_span(str_span)
- if str_span is None:
- continue
- str_span = self.slide(str_span, slid_spans)
- if str_span is None:
- continue
- unclosed_left_braces = 0
- unclosed_right_braces = 0
- for index in range(*str_span):
- if index in left_brace_indices:
- unclosed_left_braces += 1
- elif index in right_brace_indices:
- if unclosed_left_braces == 0:
- unclosed_right_braces += 1
- else:
- unclosed_left_braces -= 1
- result.append("".join([
- unclosed_right_braces * "{",
- string[slice(*str_span)],
- unclosed_left_braces * "}"
- ]))
- return result
+ def handle_submob_string(self, substr: str, string_span: Span) -> str:
+ unclosed_left_braces = 0
+ unclosed_right_braces = 0
+ for index in range(*string_span):
+ if index in self.left_brace_indices:
+ unclosed_left_braces += 1
+ elif index in self.right_brace_indices:
+ if unclosed_left_braces == 0:
+ unclosed_right_braces += 1
+ else:
+ unclosed_left_braces -= 1
+ return "".join([
+ unclosed_right_braces * "{",
+ substr,
+ unclosed_left_braces * "}"
+ ])
@property
def specified_substrings(self) -> list[str]:
return remove_list_redundancies([
self.string[slice(*double_braces_span)]
for double_braces_span in self.double_braces_spans
- ] + list(filter(
- lambda s: s in self.string,
- self.additional_substrings
- )))
+ ] + self.additional_substrings)
- def get_specified_substrings(self) -> list[str]:
- return self.specified_substrings
+ # Method alias
- @property
- def isolated_substrings(self) -> list[str]:
- return remove_list_redundancies([
- self.string[slice(*tex_span)]
- for tex_span in self.tex_span_list
- ])
+ def get_parts_by_tex(self, substr: str) -> VGroup:
+ return self.get_parts_by_string(substr)
- def get_isolated_substrings(self) -> list[str]:
- return self.isolated_substrings
+ def get_part_by_tex(self, substr: str, index: int = 0) -> VMobject:
+ return self.get_part_by_string(substr, index)
- ## Selector
-
- def find_span_components_of_custom_span(
- self,
- custom_span: tuple[int, int]
- ) -> list[tuple[int, int]] | None:
- tex_span_choices = sorted(filter(
- lambda tex_span: all([
- tex_span[0] >= custom_span[0],
- tex_span[1] <= custom_span[1]
- ]),
- self.tex_span_list
- ))
- # Choose spans that reach the farthest.
- tex_span_choices_dict = dict(tex_span_choices)
-
- result = []
- span_begin, span_end = custom_span
- span_begin = self.rstrip_span(span_begin)
- span_end = self.rstrip_span(span_end)
- while span_begin != span_end:
- span_begin = self.lstrip_span(span_begin)
- if span_begin not in tex_span_choices_dict.keys():
- return None
- next_begin = tex_span_choices_dict[span_begin]
- result.append((span_begin, next_begin))
- span_begin = next_begin
- return result
-
- def get_part_by_custom_span(self, custom_span: tuple[int, int]) -> VGroup:
- tex_spans = self.find_span_components_of_custom_span(
- custom_span
- )
- if tex_spans is None:
- tex = self.string[slice(*custom_span)]
- raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"")
-
- labels = set(it.chain(*[
- self.containing_labels_dict[tex_span]
- for tex_span in tex_spans
- ]))
- return VGroup(*filter(
- lambda submob: submob.label in labels,
- self.submobjects
- ))
-
- def get_parts_by_tex(self, tex: str) -> VGroup:
- return VGroup(*[
- self.get_part_by_custom_span(match_obj.span())
- for match_obj in re.finditer(
- re.escape(tex), self.string
- )
- ])
-
- def get_part_by_tex(self, tex: str, index: int = 0) -> VMobject:
- all_parts = self.get_parts_by_tex(tex)
- return all_parts[index]
-
- def set_color_by_tex(self, tex: str, color: ManimColor):
- self.get_parts_by_tex(tex).set_color(color)
- return self
+ def set_color_by_tex(self, substr: str, color: ManimColor):
+ return self.set_color_by_string(substr, color)
def set_color_by_tex_to_color_map(
- self,
- tex_to_color_map: dict[str, ManimColor]
+ self, tex_to_color_map: dict[str, ManimColor]
):
- for tex, color in tex_to_color_map.items():
- self.set_color_by_tex(tex, color)
- return self
+ return self.set_color_by_string_to_color_map(tex_to_color_map)
- def indices_of_part(self, part: Iterable[VMobject]) -> list[int]:
- indices = [
- index for index, submob in enumerate(self.submobjects)
- if submob in part
- ]
- if not indices:
- raise ValueError("Failed to find part in tex")
- return indices
-
- def indices_of_part_by_tex(self, tex: str, index: int = 0) -> list[int]:
- part = self.get_part_by_tex(tex, index=index)
- return self.indices_of_part(part)
+ def indices_of_part_by_tex(
+ self, substr: str, index: int = 0
+ ) -> list[int]:
+ return self.indices_of_part_by_string(substr, index)
def get_tex(self) -> str:
- return self.string
-
- def get_submob_tex(self) -> list[str]:
- return [
- submob.get_tex()
- for submob in self.submobjects
- ]
+ return self.get_string()
class MTexText(MTex):
diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py
index a13d1d80..24a5b111 100644
--- a/manimlib/mobject/svg/text_mobject.py
+++ b/manimlib/mobject/svg/text_mobject.py
@@ -2,11 +2,11 @@ from __future__ import annotations
import os
import re
-import typing
-from pathlib import Path
-
+import itertools as it
import xml.sax.saxutils as saxutils
+from pathlib import Path
from contextlib import contextmanager
+import typing
from typing import Iterable, Sequence, Union
import pygments
@@ -17,198 +17,87 @@ from manimpango import MarkupUtils
from manimlib.logger import log
from manimlib.constants import *
-from manimlib.mobject.geometry import Dot
-from manimlib.mobject.svg.svg_mobject import SVGMobject
+from manimlib.mobject.svg.mtex_mobject import LabelledString
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.customization import get_customization
from manimlib.utils.tex_file_writing import tex_hash
from manimlib.utils.config_ops import digest_config
from manimlib.utils.directories import get_downloads_dir
from manimlib.utils.directories import get_text_dir
+from manimlib.utils.iterables import remove_list_redundancies
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.types.vectorized_mobject import VMobject
+ ManimColor = Union[str, colour.Color, Sequence[float]]
+ Span = tuple[int, int]
TEXT_MOB_SCALE_FACTOR = 0.0076
DEFAULT_LINE_SPACING_SCALE = 0.6
-class _TextParser(object):
- # See https://docs.gtk.org/Pango/pango_markup.html
- # A tag containing two aliases will cause warning,
- # so only use the first key of each group of aliases.
- SPAN_ATTR_KEY_ALIAS_LIST = (
- ("font", "font_desc"),
- ("font_family", "face"),
- ("font_size", "size"),
- ("font_style", "style"),
- ("font_weight", "weight"),
- ("font_variant", "variant"),
- ("font_stretch", "stretch"),
- ("font_features",),
- ("foreground", "fgcolor", "color"),
- ("background", "bgcolor"),
- ("alpha", "fgalpha"),
- ("background_alpha", "bgalpha"),
- ("underline",),
- ("underline_color",),
- ("overline",),
- ("overline_color",),
- ("rise",),
- ("baseline_shift",),
- ("font_scale",),
- ("strikethrough",),
- ("strikethrough_color",),
- ("fallback",),
- ("lang",),
- ("letter_spacing",),
- ("gravity",),
- ("gravity_hint",),
- ("show",),
- ("insert_hyphens",),
- ("allow_breaks",),
- ("line_height",),
- ("text_transform",),
- ("segment",),
- )
- SPAN_ATTR_KEY_CONVERSION = {
- key: key_alias_list[0]
- for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST
- for key in key_alias_list
- }
-
- TAG_TO_ATTR_DICT = {
- "b": {"font_weight": "bold"},
- "big": {"font_size": "larger"},
- "i": {"font_style": "italic"},
- "s": {"strikethrough": "true"},
- "sub": {"baseline_shift": "subscript", "font_scale": "subscript"},
- "sup": {"baseline_shift": "superscript", "font_scale": "superscript"},
- "small": {"font_size": "smaller"},
- "tt": {"font_family": "monospace"},
- "u": {"underline": "single"},
- }
-
- def __init__(self, text: str = "", is_markup: bool = True):
- self.text = text
- self.is_markup = is_markup
- self.global_attrs = {}
- self.local_attrs = {(0, len(self.text)): {}}
- self.tag_strings = set()
- if is_markup:
- self.parse_markup()
-
- def parse_markup(self) -> None:
- tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*('[^']*'|"[^"]*")\s*)*)>"""
- attr_pattern = r"""(\w+)\s*\=\s*(?:(?:'([^']*)')|(?:"([^"]*)"))"""
- start_match_obj_stack = []
- match_obj_pairs = []
- for match_obj in re.finditer(tag_pattern, self.text):
- if not match_obj.group(1):
- start_match_obj_stack.append(match_obj)
- else:
- match_obj_pairs.append((start_match_obj_stack.pop(), match_obj))
- self.tag_strings.add(match_obj.group())
- assert not start_match_obj_stack, "Unclosed tag(s) detected"
-
- for start_match_obj, end_match_obj in match_obj_pairs:
- tag_name = start_match_obj.group(2)
- assert tag_name == end_match_obj.group(2), "Unmatched tag names"
- assert not end_match_obj.group(3), "Attributes shan't exist in ending tags"
- if tag_name == "span":
- attr_dict = {
- match.group(1): match.group(2) or match.group(3)
- for match in re.finditer(attr_pattern, start_match_obj.group(3))
- }
- elif tag_name in _TextParser.TAG_TO_ATTR_DICT.keys():
- assert not start_match_obj.group(3), f"Attributes shan't exist in tag '{tag_name}'"
- attr_dict = _TextParser.TAG_TO_ATTR_DICT[tag_name]
- else:
- raise AssertionError(f"Unknown tag: '{tag_name}'")
-
- text_span = (start_match_obj.end(), end_match_obj.start())
- self.update_local_attrs(text_span, attr_dict)
-
- @staticmethod
- def convert_key_alias(key: str) -> str:
- return _TextParser.SPAN_ATTR_KEY_CONVERSION[key]
-
- @staticmethod
- def update_attr_dict(attr_dict: dict[str, str], key: str, value: typing.Any) -> None:
- converted_key = _TextParser.convert_key_alias(key)
- attr_dict[converted_key] = str(value)
-
- def update_global_attr(self, key: str, value: typing.Any) -> None:
- _TextParser.update_attr_dict(self.global_attrs, key, value)
-
- def update_global_attrs(self, attr_dict: dict[str, typing.Any]) -> None:
- for key, value in attr_dict.items():
- self.update_global_attr(key, value)
-
- def update_local_attr(self, span: tuple[int, int], key: str, value: typing.Any) -> None:
- if span[0] >= span[1]:
- log.warning(f"Span {span} doesn't match any part of the string")
- return
-
- if span in self.local_attrs.keys():
- _TextParser.update_attr_dict(self.local_attrs[span], key, value)
- return
-
- span_triplets = []
- for sp, attr_dict in self.local_attrs.items():
- if sp[1] <= span[0] or span[1] <= sp[0]:
- continue
- span_to_become = (max(sp[0], span[0]), min(sp[1], span[1]))
- spans_to_add = []
- if sp[0] < span[0]:
- spans_to_add.append((sp[0], span[0]))
- if span[1] < sp[1]:
- spans_to_add.append((span[1], sp[1]))
- span_triplets.append((sp, span_to_become, spans_to_add))
- for span_to_remove, span_to_become, spans_to_add in span_triplets:
- attr_dict = self.local_attrs.pop(span_to_remove)
- for span_to_add in spans_to_add:
- self.local_attrs[span_to_add] = attr_dict.copy()
- self.local_attrs[span_to_become] = attr_dict
- _TextParser.update_attr_dict(self.local_attrs[span_to_become], key, value)
-
- def update_local_attrs(self, text_span: tuple[int, int], attr_dict: dict[str, typing.Any]) -> None:
- for key, value in attr_dict.items():
- self.update_local_attr(text_span, key, value)
-
- def remove_tags(self, string: str) -> str:
- for tag_string in self.tag_strings:
- string = string.replace(tag_string, "")
- return string
-
- def get_text_pieces(self) -> list[tuple[str, dict[str, str]]]:
- result = []
- for span in sorted(self.local_attrs.keys()):
- text_piece = self.remove_tags(self.text[slice(*span)])
- if not text_piece:
- continue
- if not self.is_markup:
- text_piece = saxutils.escape(text_piece)
- attr_dict = self.global_attrs.copy()
- attr_dict.update(self.local_attrs[span])
- result.append((text_piece, attr_dict))
- return result
-
- def get_markup_str_with_attrs(self) -> str:
- return "".join([
- f"{text_piece}"
- for text_piece, attr_dict in self.get_text_pieces()
- ])
-
- @staticmethod
- def get_attr_dict_str(attr_dict: dict[str, str]) -> str:
- return " ".join([
- f"{key}='{value}'"
- for key, value in attr_dict.items()
- ])
+# See https://docs.gtk.org/Pango/pango_markup.html
+# A tag containing two aliases will cause warning,
+# so only use the first key of each group of aliases.
+SPAN_ATTR_KEY_ALIAS_LIST = (
+ ("font", "font_desc"),
+ ("font_family", "face"),
+ ("font_size", "size"),
+ ("font_style", "style"),
+ ("font_weight", "weight"),
+ ("font_variant", "variant"),
+ ("font_stretch", "stretch"),
+ ("font_features",),
+ ("foreground", "fgcolor", "color"),
+ ("background", "bgcolor"),
+ ("alpha", "fgalpha"),
+ ("background_alpha", "bgalpha"),
+ ("underline",),
+ ("underline_color",),
+ ("overline",),
+ ("overline_color",),
+ ("rise",),
+ ("baseline_shift",),
+ ("font_scale",),
+ ("strikethrough",),
+ ("strikethrough_color",),
+ ("fallback",),
+ ("lang",),
+ ("letter_spacing",),
+ ("gravity",),
+ ("gravity_hint",),
+ ("show",),
+ ("insert_hyphens",),
+ ("allow_breaks",),
+ ("line_height",),
+ ("text_transform",),
+ ("segment",),
+)
+COLOR_RELATED_KEYS = (
+ "foreground",
+ "background",
+ "underline_color",
+ "overline_color",
+ "strikethrough_color"
+)
+SPAN_ATTR_KEY_CONVERSION = {
+ key: key_alias_list[0]
+ for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST
+ for key in key_alias_list
+}
+TAG_TO_ATTR_DICT = {
+ "b": {"font_weight": "bold"},
+ "big": {"font_size": "larger"},
+ "i": {"font_style": "italic"},
+ "s": {"strikethrough": "true"},
+ "sub": {"baseline_shift": "subscript", "font_scale": "subscript"},
+ "sup": {"baseline_shift": "superscript", "font_scale": "superscript"},
+ "small": {"font_size": "smaller"},
+ "tt": {"font_family": "monospace"},
+ "u": {"underline": "single"},
+}
# Temporary handler
@@ -223,16 +112,9 @@ class _Alignment:
self.value = _Alignment.VAL_DICT[s.upper()]
-class Text(SVGMobject):
+class MarkupText(LabelledString):
CONFIG = {
- # Mobject
- "stroke_width": 0,
- "svg_default": {
- "color": WHITE,
- },
- "height": None,
- # Text
- "is_markup": False,
+ "is_markup": True,
"font_size": 48,
"lsh": None,
"justify": False,
@@ -240,8 +122,6 @@ class Text(SVGMobject):
"alignment": "LEFT",
"line_width_factor": None,
"font": "",
- "disable_ligatures": True,
- "apply_space_chars": True,
"slant": NORMAL,
"weight": NORMAL,
"gradient": None,
@@ -252,6 +132,7 @@ class Text(SVGMobject):
"t2w": {},
"global_config": {},
"local_configs": {},
+ "isolate": [],
}
def __init__(self, text: str, **kwargs):
@@ -260,10 +141,15 @@ class Text(SVGMobject):
validate_error = MarkupUtils.validate(text)
if validate_error:
raise ValueError(validate_error)
- self.text = text
- self.parser = _TextParser(text, is_markup=self.is_markup)
- super().__init__(**kwargs)
+ self.text = text
+ super().__init__(text, **kwargs)
+
+ if self.t2g:
+ log.warning(
+ "Manim currently cannot parse gradient from svg. "
+ "Please set gradient via `set_color_by_gradient`.",
+ )
if self.gradient:
self.set_color_by_gradient(*self.gradient)
if self.height is None:
@@ -284,8 +170,6 @@ class Text(SVGMobject):
self.alignment,
self.line_width_factor,
self.font,
- self.disable_ligatures,
- self.apply_space_chars,
self.slant,
self.weight,
self.t2c,
@@ -293,71 +177,32 @@ class Text(SVGMobject):
self.t2s,
self.t2w,
self.global_config,
- self.local_configs
+ self.local_configs,
+ self.isolate
)
- def get_file_path(self) -> str:
- full_markup = self.get_full_markup_str()
+ def full2short(self, config: dict) -> None:
+ conversion_dict = {
+ "line_spacing_height": "lsh",
+ "text2color": "t2c",
+ "text2font": "t2f",
+ "text2gradient": "t2g",
+ "text2slant": "t2s",
+ "text2weight": "t2w"
+ }
+ for kwargs in [config, self.CONFIG]:
+ for long_name, short_name in conversion_dict.items():
+ if long_name in kwargs:
+ kwargs[short_name] = kwargs.pop(long_name)
+
+ def get_file_path_by_content(self, content: str) -> str:
svg_file = os.path.join(
- get_text_dir(), tex_hash(full_markup) + ".svg"
+ get_text_dir(), tex_hash(content) + ".svg"
)
if not os.path.exists(svg_file):
- self.markup_to_svg(full_markup, svg_file)
+ self.markup_to_svg(content, svg_file)
return svg_file
- def get_full_markup_str(self) -> str:
- if self.t2g:
- log.warning(
- "Manim currently cannot parse gradient from svg. "
- "Please set gradient via `set_color_by_gradient`.",
- )
-
- config_style_dict = self.generate_config_style_dict()
- global_attr_dict = {
- "line_height": ((self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1) * 0.6,
- "font_family": self.font or get_customization()["style"]["font"],
- "font_size": self.font_size * 1024,
- "font_style": self.slant,
- "font_weight": self.weight,
- # TODO, it seems this doesn't work
- "font_features": "liga=0,dlig=0,clig=0,hlig=0" if self.disable_ligatures else None,
- "foreground": config_style_dict.get("fill", None),
- "alpha": config_style_dict.get("fill-opacity", None)
- }
- global_attr_dict = {
- k: v
- for k, v in global_attr_dict.items()
- if v is not None
- }
- global_attr_dict.update(self.global_config)
- self.parser.update_global_attrs(global_attr_dict)
-
- local_attr_items = [
- (word_or_text_span, {key: value})
- for t2x_dict, key in (
- (self.t2c, "foreground"),
- (self.t2f, "font_family"),
- (self.t2s, "font_style"),
- (self.t2w, "font_weight")
- )
- for word_or_text_span, value in t2x_dict.items()
- ]
- local_attr_items.extend(self.local_configs.items())
- for word_or_text_span, local_config in local_attr_items:
- for text_span in self.find_indexes(word_or_text_span):
- self.parser.update_local_attrs(text_span, local_config)
-
- return self.parser.get_markup_str_with_attrs()
-
- def find_indexes(self, word_or_text_span: str | tuple[int, int]) -> list[tuple[int, int]]:
- if isinstance(word_or_text_span, tuple):
- return [word_or_text_span]
-
- return [
- match_obj.span()
- for match_obj in re.finditer(re.escape(word_or_text_span), self.text)
- ]
-
def markup_to_svg(self, markup_str: str, file_name: str) -> str:
# `manimpango` is under construction,
# so the following code is intended to suit its interface
@@ -374,7 +219,7 @@ class Text(SVGMobject):
weight="NORMAL", # Already handled
size=1, # Already handled
_=0, # Empty parameter
- disable_liga=False, # Already handled
+ disable_liga=False, # Need not to handle
file_name=file_name,
START_X=0,
START_Y=0,
@@ -387,63 +232,318 @@ class Text(SVGMobject):
pango_width=pango_width
)
- def generate_mobject(self) -> None:
- super().generate_mobject()
+ # Toolkits
- # Remove empty paths
- submobjects = list(filter(lambda submob: submob.has_points(), self))
+ @staticmethod
+ def get_attr_dict_str(attr_dict: dict[str, str]) -> str:
+ return " ".join([
+ f"{key}='{value}'"
+ for key, value in attr_dict.items()
+ ])
- # Apply space characters
- if self.apply_space_chars:
- content_str = self.parser.remove_tags(self.text)
- if self.is_markup:
- content_str = saxutils.unescape(content_str)
- for match_obj in re.finditer(r"\s", content_str):
- char_index = match_obj.start()
- space = Dot(radius=0, fill_opacity=0, stroke_opacity=0)
- space.move_to(submobjects[max(char_index - 1, 0)].get_center())
- submobjects.insert(char_index, space)
- self.set_submobjects(submobjects)
+ @staticmethod
+ def get_begin_tag_str(attr_dict: dict[str, str]) -> str:
+ return f""
- def full2short(self, config: dict) -> None:
- conversion_dict = {
- "line_spacing_height": "lsh",
- "text2color": "t2c",
- "text2font": "t2f",
- "text2gradient": "t2g",
- "text2slant": "t2s",
- "text2weight": "t2w"
- }
- for kwargs in [config, self.CONFIG]:
- for long_name, short_name in conversion_dict.items():
- if long_name in kwargs:
- kwargs[short_name] = kwargs.pop(long_name)
+ @staticmethod
+ def get_end_tag_str() -> str:
+ return ""
- def get_parts_by_text(self, word: str) -> VGroup:
- if self.is_markup:
- log.warning(
- "Slicing MarkupText via `get_parts_by_text`, "
- "the result could be unexpected."
- )
- elif not self.apply_space_chars:
- log.warning(
- "Slicing Text via `get_parts_by_text` without applying spaces, "
- "the result could be unexpected."
- )
- return VGroup(*(
- self[i:j]
- for i, j in self.find_indexes(word)
+ @staticmethod
+ def convert_attr_key(key: str) -> str:
+ return SPAN_ATTR_KEY_CONVERSION[key.lower()]
+
+ @staticmethod
+ def convert_attr_val(val: typing.Any) -> str:
+ return str(val).lower()
+
+ @staticmethod
+ def merge_attr_items(
+ attr_items: list[Span, str, str]
+ ) -> list[tuple[Span, dict[str, str]]]:
+ index_seq = [0]
+ attr_dict_list = [{}]
+ for span, key, value in attr_items:
+ if span[0] >= span[1]:
+ continue
+ region_indices = [
+ MarkupText.find_region_index(index, index_seq)
+ for index in span
+ ]
+ for flag in (1, 0):
+ if index_seq[region_indices[flag]] == span[flag]:
+ continue
+ region_index = region_indices[flag]
+ index_seq.insert(region_index + 1, span[flag])
+ attr_dict_list.insert(
+ region_index + 1, attr_dict_list[region_index].copy()
+ )
+ region_indices[flag] += 1
+ if flag == 0:
+ region_indices[1] += 1
+ for attr_dict in attr_dict_list[slice(*region_indices)]:
+ attr_dict[key] = value
+ return list(zip(
+ MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1]
))
- def get_part_by_text(self, word: str) -> VMobject | None:
- parts = self.get_parts_by_text(word)
- return parts[0] if parts else None
+ # Parser
+
+ @property
+ def tag_items_from_markup(
+ self
+ ) -> list[tuple[Span, Span, dict[str, str]]]:
+ if not self.is_markup:
+ return []
+
+ tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*('.*?'|".*?")\s*)*)>"""
+ attr_pattern = r"""(\w+)\s*\=\s*(?:(?:'(.*?)')|(?:"(.*?)"))"""
+ begin_match_obj_stack = []
+ match_obj_pairs = []
+ for match_obj in re.finditer(tag_pattern, self.string):
+ if not match_obj.group(1):
+ begin_match_obj_stack.append(match_obj)
+ else:
+ match_obj_pairs.append(
+ (begin_match_obj_stack.pop(), match_obj)
+ )
+ if begin_match_obj_stack:
+ raise ValueError("Unclosed tag(s) detected")
+
+ result = []
+ for begin_match_obj, end_match_obj in match_obj_pairs:
+ tag_name = begin_match_obj.group(2)
+ if tag_name != end_match_obj.group(2):
+ raise ValueError("Unmatched tag names")
+ if end_match_obj.group(3):
+ raise ValueError("Attributes shan't exist in ending tags")
+ if tag_name == "span":
+ attr_dict = dict([
+ (
+ MarkupText.convert_attr_key(match.group(1)),
+ MarkupText.convert_attr_val(
+ match.group(2) or match.group(3)
+ )
+ )
+ for match in re.finditer(
+ attr_pattern, begin_match_obj.group(3)
+ )
+ ])
+ elif tag_name in TAG_TO_ATTR_DICT.keys():
+ if begin_match_obj.group(3):
+ raise ValueError(
+ f"Attributes shan't exist in tag '{tag_name}'"
+ )
+ attr_dict = TAG_TO_ATTR_DICT[tag_name].copy()
+ else:
+ raise ValueError(f"Unknown tag: '{tag_name}'")
+
+ result.append(
+ (begin_match_obj.span(), end_match_obj.span(), attr_dict)
+ )
+ return result
+
+ @property
+ def global_attr_items_from_config(self) -> list[str, str]:
+ global_attr_dict = {
+ "line_height": (
+ (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1
+ ) * 0.6,
+ "font_family": self.font or get_customization()["style"]["font"],
+ "font_size": self.font_size * 1024,
+ "font_style": self.slant,
+ "font_weight": self.weight
+ }
+ global_attr_dict = {
+ k: v
+ for k, v in global_attr_dict.items()
+ if v is not None
+ }
+ result = list(it.chain(
+ global_attr_dict.items(),
+ self.global_config.items()
+ ))
+ return [
+ (
+ self.convert_attr_key(key),
+ self.convert_attr_val(val)
+ )
+ for key, val in result
+ ]
+
+ @property
+ def local_attr_items_from_config(self) -> list[tuple[Span, str, str]]:
+ result = [
+ (text_span, key, val)
+ for t2x_dict, key in (
+ (self.t2c, "foreground"),
+ (self.t2f, "font_family"),
+ (self.t2s, "font_style"),
+ (self.t2w, "font_weight")
+ )
+ for word_or_span, val in t2x_dict.items()
+ for text_span in self.find_spans(word_or_span)
+ ] + [
+ (text_span, key, val)
+ for word_or_span, local_config in self.local_configs.items()
+ for text_span in self.find_spans(word_or_span)
+ for key, val in local_config.items()
+ ]
+ return [
+ (
+ text_span,
+ self.convert_attr_key(key),
+ self.convert_attr_val(val)
+ )
+ for text_span, key, val in result
+ ]
+
+ def find_spans(self, word_or_span: str | Span) -> list[Span]:
+ if isinstance(word_or_span, tuple):
+ return [word_or_span]
+
+ return [
+ match_obj.span()
+ for match_obj in re.finditer(re.escape(word_or_span), self.string)
+ ]
+
+ @property
+ def skipped_spans(self) -> list[Span]:
+ return [
+ match_obj.span()
+ for match_obj in re.finditer(r"\s+", self.string)
+ ]
+
+ @property
+ def label_span_list(self) -> list[Span]:
+ breakup_indices = [
+ index
+ for pattern in [
+ r"\s+",
+ r"\b",
+ *[
+ re.escape(substr)
+ for substr in self.get_substrs_to_isolate(self.isolate)
+ ]
+ ]
+ for match_obj in re.finditer(pattern, self.string)
+ for index in match_obj.span()
+ ]
+ breakup_indices = sorted(filter(
+ lambda index: not any([
+ span[0] < index < span[1]
+ for span, _ in self.command_repl_items
+ ]),
+ remove_list_redundancies([
+ *self.full_span, *breakup_indices
+ ])
+ ))
+ return list(filter(
+ lambda span: self.string[slice(*span)].strip(),
+ self.get_neighbouring_pairs(breakup_indices)
+ ))
+
+ @property
+ def predefined_items(self) -> list[Span, str, str]:
+ return list(it.chain(
+ [
+ (self.full_span, key, val)
+ for key, val in self.global_attr_items_from_config
+ ],
+ sorted([
+ ((begin_tag_span[0], end_tag_span[1]), key, val)
+ for begin_tag_span, end_tag_span, attr_dict
+ in self.tag_items_from_markup
+ for key, val in attr_dict.items()
+ ]),
+ self.local_attr_items_from_config
+ ))
+
+ def get_inserted_string_pairs(
+ self, use_label: bool
+ ) -> list[tuple[Span, tuple[str, str]]]:
+ attr_items = self.predefined_items
+ if use_label:
+ attr_items = [
+ (span, key, WHITE if key in COLOR_RELATED_KEYS else val)
+ for span, key, val in attr_items
+ ] + [
+ (span, "foreground", "#{:06x}".format(label))
+ for label, span in enumerate(self.label_span_list)
+ ]
+ return [
+ (span, (
+ self.get_begin_tag_str(attr_dict),
+ self.get_end_tag_str()
+ ))
+ for span, attr_dict in self.merge_attr_items(attr_items)
+ ]
+
+ @property
+ def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]:
+ return self.get_inserted_string_pairs(use_label=True)
+
+ @property
+ def command_repl_items(self) -> list[tuple[Span, str]]:
+ return [
+ (tag_span, "")
+ for begin_tag, end_tag, _ in self.tag_items_from_markup
+ for tag_span in (begin_tag, end_tag)
+ ]
+
+ @property
+ def has_predefined_colors(self) -> bool:
+ return any([
+ key in COLOR_RELATED_KEYS
+ for _, key, _ in self.predefined_items
+ ])
+
+ @property
+ def plain_string(self) -> str:
+ return "".join([
+ self.get_begin_tag_str({"foreground": self.base_color}),
+ self.replace_str_by_spans(
+ self.string, self.get_span_replacement_dict(
+ self.get_inserted_string_pairs(use_label=False),
+ self.command_repl_items
+ )
+ ),
+ self.get_end_tag_str()
+ ])
+
+ def handle_submob_string(self, substr: str, string_span: Span) -> str:
+ if self.is_markup:
+ substr = saxutils.unescape(substr)
+ return substr
+
+ # Method alias
+
+ def get_parts_by_text(self, substr: str) -> VGroup:
+ return self.get_parts_by_string(substr)
+
+ def get_part_by_text(self, substr: str, index: int = 0) -> VMobject:
+ return self.get_part_by_string(substr, index)
+
+ def set_color_by_text(self, substr: str, color: ManimColor):
+ return self.set_color_by_string(substr, color)
+
+ def set_color_by_text_to_color_map(
+ self, text_to_color_map: dict[str, ManimColor]
+ ):
+ return self.set_color_by_string_to_color_map(text_to_color_map)
+
+ def indices_of_part_by_text(
+ self, substr: str, index: int = 0
+ ) -> list[int]:
+ return self.indices_of_part_by_string(substr, index)
+
+ def get_text(self) -> str:
+ return self.get_string()
-class MarkupText(Text):
+class Text(MarkupText):
CONFIG = {
- "is_markup": True,
- "apply_space_chars": False,
+ "is_markup": False,
}
@@ -461,7 +561,9 @@ class Code(MarkupText):
digest_config(self, kwargs)
self.code = code
lexer = pygments.lexers.get_lexer_by_name(self.language)
- formatter = pygments.formatters.PangoMarkupFormatter(style=self.code_style)
+ formatter = pygments.formatters.PangoMarkupFormatter(
+ style=self.code_style
+ )
markup = pygments.highlight(code, lexer, formatter)
markup = re.sub(r"?tt>", "", markup)
super().__init__(markup, **kwargs)