diff --git a/manimlib/__init__.py b/manimlib/__init__.py
index 2043738c..67703a8f 100644
--- a/manimlib/__init__.py
+++ b/manimlib/__init__.py
@@ -38,8 +38,8 @@ from manimlib.mobject.probability import *
from manimlib.mobject.shape_matchers import *
from manimlib.mobject.svg.brace import *
from manimlib.mobject.svg.drawings import *
-from manimlib.mobject.svg.labelled_string import *
from manimlib.mobject.svg.mtex_mobject import *
+from manimlib.mobject.svg.string_mobject import *
from manimlib.mobject.svg.svg_mobject import *
from manimlib.mobject.svg.tex_mobject import *
from manimlib.mobject.svg.text_mobject import *
diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py
index 27460899..86c5ca05 100644
--- a/manimlib/animation/creation.py
+++ b/manimlib/animation/creation.py
@@ -1,12 +1,12 @@
from __future__ import annotations
-import itertools as it
-from abc import abstractmethod
+from abc import ABC, abstractmethod
import numpy as np
from manimlib.animation.animation import Animation
-from manimlib.mobject.svg.labelled_string import LabelledString
+from manimlib.mobject.svg.string_mobject import StringMobject
+from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.bezier import integer_interpolate
from manimlib.utils.config_ops import digest_config
@@ -17,10 +17,10 @@ from manimlib.utils.rate_functions import smooth
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- from manimlib.mobject.mobject import Group
+ from manimlib.mobject.mobject import Mobject
-class ShowPartial(Animation):
+class ShowPartial(Animation, ABC):
"""
Abstract class for ShowCreation and ShowPassingFlash
"""
@@ -176,7 +176,7 @@ class ShowIncreasingSubsets(Animation):
"int_func": np.round,
}
- def __init__(self, group: Group, **kwargs):
+ def __init__(self, group: Mobject, **kwargs):
self.all_submobs = list(group.submobjects)
super().__init__(group, **kwargs)
@@ -212,8 +212,8 @@ class AddTextWordByWord(ShowIncreasingSubsets):
}
def __init__(self, string_mobject, **kwargs):
- assert isinstance(string_mobject, LabelledString)
- grouped_mobject = string_mobject.submob_groups
+ assert isinstance(string_mobject, StringMobject)
+ grouped_mobject = string_mobject.build_groups()
digest_config(self, kwargs)
if self.run_time is None:
self.run_time = self.time_per_word * len(grouped_mobject)
diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py
index dab88005..e82bafaf 100644
--- a/manimlib/animation/transform_matching_parts.py
+++ b/manimlib/animation/transform_matching_parts.py
@@ -5,24 +5,24 @@ import itertools as it
import numpy as np
from manimlib.animation.composition import AnimationGroup
-from manimlib.animation.fading import FadeTransformPieces
from manimlib.animation.fading import FadeInFromPoint
from manimlib.animation.fading import FadeOutToPoint
+from manimlib.animation.fading import FadeTransformPieces
from manimlib.animation.transform import ReplacementTransform
from manimlib.animation.transform import Transform
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
-from manimlib.mobject.svg.labelled_string import LabelledString
+from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config
-from manimlib.utils.iterables import remove_list_redundancies
from typing import TYPE_CHECKING
if TYPE_CHECKING:
+ from manimlib.mobject.svg.tex_mobject import SingleStringTex
+ from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.scene.scene import Scene
- from manimlib.mobject.svg.tex_mobject import Tex, SingleStringTex
class TransformMatchingParts(AnimationGroup):
@@ -155,92 +155,89 @@ class TransformMatchingTex(TransformMatchingParts):
class TransformMatchingStrings(AnimationGroup):
CONFIG = {
- "key_map": dict(),
+ "key_map": {},
"transform_mismatches": False,
}
def __init__(self,
- source: LabelledString,
- target: LabelledString,
+ source: StringMobject,
+ target: StringMobject,
**kwargs
):
digest_config(self, kwargs)
- assert isinstance(source, LabelledString)
- assert isinstance(target, LabelledString)
+ assert isinstance(source, StringMobject)
+ assert isinstance(target, StringMobject)
anims = []
- source_indices = list(range(len(source.labelled_submobjects)))
- target_indices = list(range(len(target.labelled_submobjects)))
+ source_indices = list(range(len(source.labels)))
+ target_indices = list(range(len(target.labels)))
- def get_indices_lists(mobject, parts):
- return [
- [
- mobject.labelled_submobjects.index(submob)
- for submob in part
- ]
- for part in parts
- ]
-
- def add_anims_from(anim_class, func, source_args, target_args=None):
- if target_args is None:
- target_args = source_args.copy()
- for source_arg, target_arg in zip(source_args, target_args):
- source_parts = func(source, source_arg)
- target_parts = func(target, target_arg)
- source_indices_lists = list(filter(
- lambda indices_list: all([
- index in source_indices
- for index in indices_list
- ]), get_indices_lists(source, source_parts)
- ))
- target_indices_lists = list(filter(
- lambda indices_list: all([
- index in target_indices
- for index in indices_list
- ]), get_indices_lists(target, target_parts)
- ))
- if not source_indices_lists or not target_indices_lists:
+ def get_filtered_indices_lists(indices_lists, rest_indices):
+ result = []
+ for indices_list in indices_lists:
+ if not indices_list:
continue
- anims.append(anim_class(source_parts, target_parts, **kwargs))
- for index in it.chain(*source_indices_lists):
- source_indices.remove(index)
- for index in it.chain(*target_indices_lists):
- target_indices.remove(index)
-
- def get_common_substrs(substrs_from_source, substrs_from_target):
- return sorted([
- substr for substr in substrs_from_source
- if substr and substr in substrs_from_target
- ], key=len, reverse=True)
-
- def get_parts_from_keys(mobject, keys):
- if isinstance(keys, str):
- keys = [keys]
- result = VGroup()
- for key in keys:
- if not isinstance(key, str):
- raise TypeError(key)
- result.add(*mobject.get_parts_by_string(key))
+ if not all(index in rest_indices for index in indices_list):
+ continue
+ result.append(indices_list)
+ for index in indices_list:
+ rest_indices.remove(index)
return result
- add_anims_from(
- ReplacementTransform, get_parts_from_keys,
- self.key_map.keys(), self.key_map.values()
+ def add_anims(anim_class, indices_lists_pairs):
+ for source_indices_lists, target_indices_lists in indices_lists_pairs:
+ source_indices_lists = get_filtered_indices_lists(
+ source_indices_lists, source_indices
+ )
+ target_indices_lists = get_filtered_indices_lists(
+ target_indices_lists, target_indices
+ )
+ if not source_indices_lists or not target_indices_lists:
+ continue
+ anims.append(anim_class(
+ source.build_parts_from_indices_lists(source_indices_lists),
+ target.build_parts_from_indices_lists(target_indices_lists),
+ **kwargs
+ ))
+
+ def get_substr_to_indices_lists_map(part_items):
+ result = {}
+ for substr, indices_list in part_items:
+ if substr not in result:
+ result[substr] = []
+ result[substr].append(indices_list)
+ return result
+
+ def add_anims_from(anim_class, func):
+ source_substr_map = get_substr_to_indices_lists_map(func(source))
+ target_substr_map = get_substr_to_indices_lists_map(func(target))
+ common_substrings = sorted([
+ s for s in source_substr_map if s and s in target_substr_map
+ ], key=len, reverse=True)
+ add_anims(
+ anim_class,
+ [
+ (source_substr_map[substr], target_substr_map[substr])
+ for substr in common_substrings
+ ]
+ )
+
+ add_anims(
+ ReplacementTransform,
+ [
+ (
+ source.get_submob_indices_lists_by_selector(k),
+ target.get_submob_indices_lists_by_selector(v)
+ )
+ for k, v in self.key_map.items()
+ ]
)
add_anims_from(
FadeTransformPieces,
- LabelledString.get_parts_by_string,
- get_common_substrs(
- source.specified_substrs,
- target.specified_substrs
- )
+ StringMobject.get_specified_part_items
)
add_anims_from(
FadeTransformPieces,
- LabelledString.get_parts_by_group_substr,
- get_common_substrs(
- source.group_substrs,
- target.group_substrs
- )
+ StringMobject.get_group_part_items
)
rest_source = VGroup(*[source[index] for index in source_indices])
diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py
deleted file mode 100644
index f1354f0c..00000000
--- a/manimlib/mobject/svg/labelled_string.py
+++ /dev/null
@@ -1,543 +0,0 @@
-from __future__ import annotations
-
-import re
-import colour
-import itertools as it
-from typing import Iterable, Union, Sequence
-from abc import ABC, abstractmethod
-
-from manimlib.constants import BLACK, WHITE
-from manimlib.mobject.svg.svg_mobject import SVGMobject
-from manimlib.mobject.types.vectorized_mobject import VGroup
-from manimlib.utils.color import color_to_int_rgb
-from manimlib.utils.color import color_to_rgb
-from manimlib.utils.color import rgb_to_hex
-from manimlib.utils.config_ops import digest_config
-from manimlib.utils.iterables import remove_list_redundancies
-
-
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from manimlib.mobject.types.vectorized_mobject import VMobject
- ManimColor = Union[str, colour.Color, Sequence[float]]
- Span = tuple[int, int]
-
-
-class _StringSVG(SVGMobject):
- CONFIG = {
- "height": None,
- "stroke_width": 0,
- "stroke_color": WHITE,
- "path_string_config": {
- "should_subdivide_sharp_curves": True,
- "should_remove_null_curves": True,
- },
- }
-
-
-class LabelledString(_StringSVG, ABC):
- """
- An abstract base class for `MTex` and `MarkupText`
- """
- CONFIG = {
- "base_color": WHITE,
- "use_plain_file": False,
- "isolate": [],
- }
-
- def __init__(self, string: str, **kwargs):
- self.string = string
- digest_config(self, kwargs)
-
- # Convert `base_color` to hex code.
- self.base_color = rgb_to_hex(color_to_rgb(
- self.base_color \
- or self.svg_default.get("color", None) \
- or self.svg_default.get("fill_color", None) \
- or WHITE
- ))
- self.svg_default["fill_color"] = BLACK
-
- self.pre_parse()
- self.parse()
- super().__init__()
- self.post_parse()
-
- def get_file_path(self) -> str:
- return self.get_file_path_(use_plain_file=False)
-
- def get_file_path_(self, use_plain_file: bool) -> str:
- content = self.get_content(use_plain_file)
- return self.get_file_path_by_content(content)
-
- @abstractmethod
- def get_file_path_by_content(self, content: str) -> str:
- return ""
-
- def generate_mobject(self) -> None:
- super().generate_mobject()
-
- submob_labels = [
- self.color_to_label(submob.get_fill_color())
- for submob in self.submobjects
- ]
- if self.use_plain_file or self.has_predefined_local_colors:
- file_path = self.get_file_path_(use_plain_file=True)
- plain_svg = _StringSVG(
- file_path,
- svg_default=self.svg_default,
- path_string_config=self.path_string_config
- )
- self.set_submobjects(plain_svg.submobjects)
- else:
- self.set_fill(self.base_color)
- for submob, label in zip(self.submobjects, submob_labels):
- submob.label = label
-
- def pre_parse(self) -> None:
- self.string_len = len(self.string)
- self.full_span = (0, self.string_len)
-
- def parse(self) -> None:
- self.command_repl_items = self.get_command_repl_items()
- self.command_spans = self.get_command_spans()
- self.extra_entity_spans = self.get_extra_entity_spans()
- self.entity_spans = self.get_entity_spans()
- self.extra_ignored_spans = self.get_extra_ignored_spans()
- self.skipped_spans = self.get_skipped_spans()
- self.internal_specified_spans = self.get_internal_specified_spans()
- self.external_specified_spans = self.get_external_specified_spans()
- self.specified_spans = self.get_specified_spans()
- self.label_span_list = self.get_label_span_list()
- self.check_overlapping()
-
- def post_parse(self) -> None:
- self.labelled_submobject_items = [
- (submob.label, submob)
- for submob in self.submobjects
- ]
- self.labelled_submobjects = self.get_labelled_submobjects()
- self.specified_substrs = self.get_specified_substrs()
- self.group_items = self.get_group_items()
- self.group_substrs = self.get_group_substrs()
- self.submob_groups = self.get_submob_groups()
-
- # Toolkits
-
- def get_substr(self, span: Span) -> str:
- return self.string[slice(*span)]
-
- def finditer(
- self, pattern: str, flags: int = 0, **kwargs
- ) -> Iterable[re.Match]:
- return re.compile(pattern, flags).finditer(self.string, **kwargs)
-
- def search(
- self, pattern: str, flags: int = 0, **kwargs
- ) -> re.Match | None:
- return re.compile(pattern, flags).search(self.string, **kwargs)
-
- def match(
- self, pattern: str, flags: int = 0, **kwargs
- ) -> re.Match | None:
- return re.compile(pattern, flags).match(self.string, **kwargs)
-
- def find_spans(self, pattern: str, **kwargs) -> list[Span]:
- return [
- match_obj.span()
- for match_obj in self.finditer(pattern, **kwargs)
- ]
-
- def find_substr(self, substr: str, **kwargs) -> list[Span]:
- if not substr:
- return []
- return self.find_spans(re.escape(substr), **kwargs)
-
- def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]:
- return list(it.chain(*[
- self.find_substr(substr, **kwargs)
- for substr in remove_list_redundancies(substrs)
- ]))
-
- @staticmethod
- def get_neighbouring_pairs(iterable: list) -> list[tuple]:
- return list(zip(iterable[:-1], iterable[1:]))
-
- @staticmethod
- def span_contains(span_0: Span, span_1: Span) -> bool:
- return span_0[0] <= span_1[0] and span_0[1] >= span_1[1]
-
- @staticmethod
- def get_complement_spans(
- interval_spans: list[Span], universal_span: Span
- ) -> list[Span]:
- if not interval_spans:
- return [universal_span]
-
- span_ends, span_begins = zip(*interval_spans)
- return list(zip(
- (universal_span[0], *span_begins),
- (*span_ends, universal_span[1])
- ))
-
- @staticmethod
- def compress_neighbours(vals: list[int]) -> list[tuple[int, Span]]:
- if not vals:
- return []
-
- unique_vals = [vals[0]]
- indices = [0]
- for index, val in enumerate(vals):
- if val == unique_vals[-1]:
- continue
- unique_vals.append(val)
- indices.append(index)
- indices.append(len(vals))
- spans = LabelledString.get_neighbouring_pairs(indices)
- return list(zip(unique_vals, spans))
-
- @staticmethod
- def find_region_index(seq: list[int], val: int) -> int:
- # Returns an integer in `range(-1, len(seq))` satisfying
- # `seq[result] <= val < seq[result + 1]`.
- # `seq` should be sorted in ascending order.
- if not seq or val < seq[0]:
- return -1
- result = len(seq) - 1
- while val < seq[result]:
- result -= 1
- return result
-
- @staticmethod
- def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int:
- sorted_seq = sorted(seq)
- index = LabelledString.find_region_index(sorted_seq, val)
- return sorted_seq[index + index_shift]
-
- @staticmethod
- def generate_span_repl_dict(
- inserted_string_pairs: list[tuple[Span, tuple[str, str]]],
- other_repl_items: list[tuple[Span, str]]
- ) -> dict[Span, str]:
- result = dict(other_repl_items)
- if not inserted_string_pairs:
- return result
-
- indices, _, _, inserted_strings = zip(*sorted([
- (
- span[flag],
- -flag,
- -span[1 - flag],
- str_pair[flag]
- )
- for span, str_pair in inserted_string_pairs
- for flag in range(2)
- ]))
- result.update({
- (index, index): "".join(inserted_strings[slice(*item_span)])
- for index, item_span
- in LabelledString.compress_neighbours(indices)
- })
- return result
-
- def get_replaced_substr(
- self, span: Span, span_repl_dict: dict[Span, str]
- ) -> str:
- repl_spans = sorted(filter(
- lambda repl_span: self.span_contains(span, repl_span),
- span_repl_dict.keys()
- ))
- if not all(
- span_0[1] <= span_1[0]
- for span_0, span_1 in self.get_neighbouring_pairs(repl_spans)
- ):
- raise ValueError("Overlapping replacement")
-
- pieces = [
- self.get_substr(piece_span)
- for piece_span in self.get_complement_spans(repl_spans, span)
- ]
- repl_strs = [span_repl_dict[repl_span] for repl_span in repl_spans]
- repl_strs.append("")
- return "".join(it.chain(*zip(pieces, repl_strs)))
-
- @staticmethod
- def rslide(index: int, skipped: list[Span]) -> int:
- transfer_dict = dict(sorted(skipped))
- while index in transfer_dict.keys():
- index = transfer_dict[index]
- return index
-
- @staticmethod
- def lslide(index: int, skipped: list[Span]) -> int:
- transfer_dict = dict(sorted([
- skipped_span[::-1] for skipped_span in skipped
- ], reverse=True))
- while index in transfer_dict.keys():
- index = transfer_dict[index]
- return index
-
- @staticmethod
- def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int:
- r, g, b = rgb_tuple
- rg = r * 256 + g
- return rg * 256 + b
-
- @staticmethod
- def int_to_rgb(rgb_int: int) -> tuple[int, int, int]:
- rg, b = divmod(rgb_int, 256)
- r, g = divmod(rg, 256)
- return r, g, b
-
- @staticmethod
- def int_to_hex(rgb_int: int) -> str:
- return "#{:06x}".format(rgb_int).upper()
-
- @staticmethod
- def hex_to_int(rgb_hex: str) -> int:
- return int(rgb_hex[1:], 16)
-
- @staticmethod
- def color_to_label(color: ManimColor) -> int:
- rgb_tuple = color_to_int_rgb(color)
- rgb = LabelledString.rgb_to_int(rgb_tuple)
- return rgb - 1
-
- # Parsing
-
- @abstractmethod
- def get_command_repl_items(self) -> list[tuple[Span, str]]:
- return []
-
- def get_command_spans(self) -> list[Span]:
- return [cmd_span for cmd_span, _ in self.command_repl_items]
-
- @abstractmethod
- def get_extra_entity_spans(self) -> list[Span]:
- return []
-
- def get_entity_spans(self) -> list[Span]:
- return list(it.chain(
- self.command_spans,
- self.extra_entity_spans
- ))
-
- @abstractmethod
- def get_extra_ignored_spans(self) -> list[int]:
- return []
-
- def get_skipped_spans(self) -> list[Span]:
- return list(it.chain(
- self.find_spans(r"\s"),
- self.command_spans,
- self.extra_ignored_spans
- ))
-
- def shrink_span(self, span: Span) -> Span:
- return (
- self.rslide(span[0], self.skipped_spans),
- self.lslide(span[1], self.skipped_spans)
- )
-
- @abstractmethod
- def get_internal_specified_spans(self) -> list[Span]:
- return []
-
- @abstractmethod
- def get_external_specified_spans(self) -> list[Span]:
- return []
-
- def get_specified_spans(self) -> list[Span]:
- spans = list(it.chain(
- self.internal_specified_spans,
- self.external_specified_spans,
- self.find_substrs(self.isolate)
- ))
- shrinked_spans = list(filter(
- lambda span: span[0] < span[1] and not any([
- entity_span[0] < index < entity_span[1]
- for index in span
- for entity_span in self.entity_spans
- ]),
- [self.shrink_span(span) for span in spans]
- ))
- return remove_list_redundancies(shrinked_spans)
-
- @abstractmethod
- def get_label_span_list(self) -> list[Span]:
- return []
-
- def check_overlapping(self) -> None:
- for span_0, span_1 in it.product(self.label_span_list, repeat=2):
- if not span_0[0] < span_1[0] < span_0[1] < span_1[1]:
- continue
- raise ValueError(
- "Partially overlapping substrings detected: "
- f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'"
- )
-
- @abstractmethod
- def get_content(self, use_plain_file: bool) -> str:
- return ""
-
- @abstractmethod
- def has_predefined_local_colors(self) -> bool:
- return False
-
- # Post-parsing
-
- def get_labelled_submobjects(self) -> list[VMobject]:
- return [submob for _, submob in self.labelled_submobject_items]
-
- def get_cleaned_substr(self, span: Span) -> str:
- span_repl_dict = dict.fromkeys(self.command_spans, "")
- return self.get_replaced_substr(span, span_repl_dict)
-
- def get_specified_substrs(self) -> list[str]:
- return remove_list_redundancies([
- self.get_cleaned_substr(span)
- for span in self.specified_spans
- ])
-
- def get_group_items(self) -> list[tuple[str, VGroup]]:
- if not self.labelled_submobject_items:
- return []
-
- labels, labelled_submobjects = zip(*self.labelled_submobject_items)
- group_labels, labelled_submob_spans = zip(
- *self.compress_neighbours(labels)
- )
- ordered_spans = [
- self.label_span_list[label] if label != -1 else self.full_span
- for label in group_labels
- ]
- interval_spans = [
- (
- next_span[0]
- if self.span_contains(prev_span, next_span)
- else prev_span[1],
- prev_span[1]
- if self.span_contains(next_span, prev_span)
- else next_span[0]
- )
- for prev_span, next_span in self.get_neighbouring_pairs(
- ordered_spans
- )
- ]
- shrinked_spans = [
- self.shrink_span(span)
- for span in self.get_complement_spans(
- interval_spans, (ordered_spans[0][0], ordered_spans[-1][1])
- )
- ]
- group_substrs = [
- self.get_cleaned_substr(span) if span[0] < span[1] else ""
- for span in shrinked_spans
- ]
- submob_groups = VGroup(*[
- VGroup(*labelled_submobjects[slice(*submob_span)])
- for submob_span in labelled_submob_spans
- ])
- return list(zip(group_substrs, submob_groups))
-
- def get_group_substrs(self) -> list[str]:
- return [group_substr for group_substr, _ in self.group_items]
-
- def get_submob_groups(self) -> list[VGroup]:
- return [submob_group for _, submob_group in self.group_items]
-
- def get_parts_by_group_substr(self, substr: str) -> VGroup:
- return VGroup(*[
- group
- for group_substr, group in self.group_items
- if group_substr == substr
- ])
-
- # Selector
-
- def find_span_components(
- self, custom_span: Span, substring: bool = True
- ) -> list[Span]:
- shrinked_span = self.shrink_span(custom_span)
- if shrinked_span[0] >= shrinked_span[1]:
- return []
-
- if substring:
- indices = remove_list_redundancies(list(it.chain(
- self.full_span,
- *self.label_span_list
- )))
- span_begin = self.take_nearest_value(
- indices, shrinked_span[0], 0
- )
- span_end = self.take_nearest_value(
- indices, shrinked_span[1] - 1, 1
- )
- else:
- span_begin, span_end = shrinked_span
-
- span_choices = sorted(filter(
- lambda span: self.span_contains((span_begin, span_end), span),
- self.label_span_list
- ))
- # Choose spans that reach the farthest.
- span_choices_dict = dict(span_choices)
-
- result = []
- while span_begin < span_end:
- if span_begin not in span_choices_dict.keys():
- span_begin += 1
- continue
- next_begin = span_choices_dict[span_begin]
- result.append((span_begin, next_begin))
- span_begin = next_begin
- return result
-
- def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup:
- labels = [
- label for label, span in enumerate(self.label_span_list)
- if any([
- self.span_contains(span_component, span)
- for span_component in self.find_span_components(
- custom_span, **kwargs
- )
- ])
- ]
- return VGroup(*[
- submob for label, submob in self.labelled_submobject_items
- if label in labels
- ])
-
- def get_parts_by_string(
- self, substr: str,
- case_sensitive: bool = True, regex: bool = False, **kwargs
- ) -> VGroup:
- flags = 0
- if not case_sensitive:
- flags |= re.I
- pattern = substr if regex else re.escape(substr)
- return VGroup(*[
- self.get_part_by_custom_span(span, **kwargs)
- for span in self.find_spans(pattern, flags=flags)
- if span[0] < span[1]
- ])
-
- def get_part_by_string(
- self, substr: str, index: int = 0, **kwargs
- ) -> VMobject:
- return self.get_parts_by_string(substr, **kwargs)[index]
-
- def set_color_by_string(self, substr: str, color: ManimColor, **kwargs):
- self.get_parts_by_string(substr, **kwargs).set_color(color)
- return self
-
- def set_color_by_string_to_color_map(
- self, string_to_color_map: dict[str, ManimColor], **kwargs
- ):
- for substr, color in string_to_color_map.items():
- self.set_color_by_string(substr, color, **kwargs)
- return self
-
- def get_string(self) -> str:
- return self.string
diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py
index fb7922e1..149f313f 100644
--- a/manimlib/mobject/svg/mtex_mobject.py
+++ b/manimlib/mobject/svg/mtex_mobject.py
@@ -1,28 +1,37 @@
from __future__ import annotations
-import itertools as it
-import colour
-from typing import Union, Sequence
-
-from manimlib.mobject.svg.labelled_string import LabelledString
-from manimlib.utils.tex_file_writing import tex_to_svg_file
-from manimlib.utils.tex_file_writing import get_tex_config
+from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.utils.tex_file_writing import display_during_execution
-
+from manimlib.utils.tex_file_writing import get_tex_config
+from manimlib.utils.tex_file_writing import tex_to_svg_file
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- from manimlib.mobject.types.vectorized_mobject import VMobject
+ from colour import Color
+ import re
+ from typing import Iterable, Union
+
from manimlib.mobject.types.vectorized_mobject import VGroup
- ManimColor = Union[str, colour.Color, Sequence[float]]
+
+ ManimColor = Union[str, Color]
Span = tuple[int, int]
+ Selector = Union[
+ str,
+ re.Pattern,
+ tuple[Union[int, None], Union[int, None]],
+ Iterable[Union[
+ str,
+ re.Pattern,
+ tuple[Union[int, None], Union[int, None]]
+ ]]
+ ]
SCALE_FACTOR_PER_FONT_POINT = 0.001
-class MTex(LabelledString):
+class MTex(StringMobject):
CONFIG = {
"font_size": 48,
"alignment": "\\centering",
@@ -32,7 +41,7 @@ class MTex(LabelledString):
def __init__(self, tex_string: str, **kwargs):
# Prevent from passing an empty string.
- if not tex_string:
+ if not tex_string.strip():
tex_string = "\\\\"
self.tex_string = tex_string
super().__init__(tex_string, **kwargs)
@@ -47,7 +56,6 @@ class MTex(LabelledString):
self.svg_default,
self.path_string_config,
self.base_color,
- self.use_plain_file,
self.isolate,
self.tex_string,
self.alignment,
@@ -61,270 +69,103 @@ class MTex(LabelledString):
tex_config["text_to_replace"],
content
)
- with display_during_execution(f"Writing \"{self.tex_string}\""):
+ with display_during_execution(f"Writing \"{self.string}\""):
file_path = tex_to_svg_file(full_tex)
return file_path
- def pre_parse(self) -> None:
- super().pre_parse()
- self.backslash_indices = self.get_backslash_indices()
- self.brace_index_pairs = self.get_brace_index_pairs()
- self.script_char_spans = self.get_script_char_spans()
- self.script_content_spans = self.get_script_content_spans()
- self.script_spans = self.get_script_spans()
-
- # Toolkits
-
- @staticmethod
- def get_color_command_str(rgb_int: int) -> str:
- rgb_tuple = MTex.int_to_rgb(rgb_int)
- return "".join([
- "\\color[RGB]",
- "{",
- ",".join(map(str, rgb_tuple)),
- "}"
- ])
-
- # Pre-parsing
-
- def get_backslash_indices(self) -> list[int]:
- # The latter of `\\` doesn't count.
- return list(it.chain(*[
- range(span[0], span[1], 2)
- for span in self.find_spans(r"\\+")
- ]))
-
- def get_unescaped_char_spans(self, chars: str):
- return sorted(filter(
- lambda span: span[0] - 1 not in self.backslash_indices,
- self.find_substrs(list(chars))
- ))
-
- def get_brace_index_pairs(self) -> list[Span]:
- left_brace_indices = []
- right_brace_indices = []
- left_brace_indices_stack = []
- for span in self.get_unescaped_char_spans("{}"):
- index = span[0]
- if self.get_substr(span) == "{":
- left_brace_indices_stack.append(index)
- else:
- if not left_brace_indices_stack:
- raise ValueError("Missing '{' inserted")
- left_brace_index = left_brace_indices_stack.pop()
- left_brace_indices.append(left_brace_index)
- right_brace_indices.append(index)
- if left_brace_indices_stack:
- raise ValueError("Missing '}' inserted")
- return list(zip(left_brace_indices, right_brace_indices))
-
- def get_script_char_spans(self) -> list[int]:
- return self.get_unescaped_char_spans("_^")
-
- def get_script_content_spans(self) -> list[Span]:
- result = []
- brace_indices_dict = dict(self.brace_index_pairs)
- script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+"
- for script_char_span in self.script_char_spans:
- span_begin = self.match(r"\s*", pos=script_char_span[1]).end()
- if span_begin in brace_indices_dict.keys():
- span_end = brace_indices_dict[span_begin] + 1
- else:
- match_obj = self.match(script_pattern, pos=span_begin)
- if not match_obj:
- script_name = {
- "_": "subscript",
- "^": "superscript"
- }[script_char]
- raise ValueError(
- f"Unclear {script_name} detected while parsing. "
- "Please use braces to clarify"
- )
- span_end = match_obj.end()
- result.append((span_begin, span_end))
- return result
-
- def get_script_spans(self) -> list[Span]:
- return [
- (
- self.search(r"\s*$", endpos=script_char_span[0]).start(),
- script_content_span[1]
- )
- for script_char_span, script_content_span in zip(
- self.script_char_spans, self.script_content_spans
- )
- ]
-
# Parsing
- def get_command_repl_items(self) -> list[tuple[Span, str]]:
- color_related_command_dict = {
- "color": (1, False),
- "textcolor": (1, False),
- "pagecolor": (1, True),
- "colorbox": (1, True),
- "fcolorbox": (2, True),
- }
- result = []
- backslash_indices = self.backslash_indices
- right_brace_indices = [
- right_index
- for left_index, right_index in self.brace_index_pairs
+ def get_cmd_spans(self) -> list[Span]:
+ return self.find_spans(r"\\(?:[a-zA-Z]+|\s|\S)|[_^{}]")
+
+ def get_substr_flag(self, substr: str) -> int:
+ return {"{": 1, "}": -1}.get(substr, 0)
+
+ def get_repl_substr_for_content(self, substr: str) -> str:
+ return substr
+
+ def get_repl_substr_for_matching(self, substr: str) -> str:
+ return substr if substr.startswith("\\") else ""
+
+ def get_specified_items(
+ self, cmd_span_pairs: list[tuple[Span, Span]]
+ ) -> list[tuple[Span, dict[str, str]]]:
+ cmd_content_spans = [
+ (span_begin, span_end)
+ for (_, span_begin), (span_end, _) in cmd_span_pairs
]
- pattern = "".join([
- r"\\",
- "(",
- "|".join(color_related_command_dict.keys()),
- ")",
- r"(?![a-zA-Z])"
- ])
- for match_obj in self.finditer(pattern):
- span_begin, cmd_end = match_obj.span()
- if span_begin not in backslash_indices:
- continue
- cmd_name = match_obj.group(1)
- n_braces, substitute_cmd = color_related_command_dict[cmd_name]
- span_end = self.take_nearest_value(
- right_brace_indices, cmd_end, n_braces
- ) + 1
- if substitute_cmd:
- repl_str = "\\" + cmd_name + n_braces * "{black}"
- else:
- repl_str = ""
- result.append(((span_begin, span_end), repl_str))
- return result
-
- def get_extra_entity_spans(self) -> list[Span]:
- return [
- self.match(r"\\([a-zA-Z]+|.)", pos=index).span()
- for index in self.backslash_indices
- ]
-
- def get_extra_ignored_spans(self) -> list[int]:
- return self.script_char_spans.copy()
-
- def get_internal_specified_spans(self) -> list[Span]:
- # Match paired double braces (`{{...}}`).
- result = []
- reversed_brace_indices_dict = dict([
- pair[::-1] for pair in self.brace_index_pairs
- ])
- skip = False
- for prev_right_index, right_index in self.get_neighbouring_pairs(
- list(reversed_brace_indices_dict.keys())
- ):
- if skip:
- skip = False
- continue
- if right_index != prev_right_index + 1:
- continue
- left_index = reversed_brace_indices_dict[right_index]
- prev_left_index = reversed_brace_indices_dict[prev_right_index]
- if left_index != prev_left_index - 1:
- continue
- result.append((left_index, right_index + 1))
- skip = True
- return result
-
- def get_external_specified_spans(self) -> list[Span]:
- return self.find_substrs(list(self.tex_to_color_map.keys()))
-
- def get_label_span_list(self) -> list[Span]:
- result = self.script_content_spans.copy()
- for span_begin, span_end in self.specified_spans:
- shrinked_end = self.lslide(span_end, self.script_spans)
- if span_begin >= shrinked_end:
- continue
- shrinked_span = (span_begin, shrinked_end)
- if shrinked_span in result:
- continue
- result.append(shrinked_span)
- return result
-
- def get_content(self, use_plain_file: bool) -> str:
- if use_plain_file:
- span_repl_dict = {}
- else:
- extended_label_span_list = [
+ specified_spans = [
+ *[
+ cmd_content_spans[range_begin]
+ for _, (range_begin, range_end) in self.compress_neighbours([
+ (span_begin + index, span_end - index)
+ for index, (span_begin, span_end) in enumerate(
+ cmd_content_spans
+ )
+ ])
+ if range_end - range_begin >= 2
+ ],
+ *[
span
- if span in self.script_content_spans
- else (span[0], self.rslide(span[1], self.script_spans))
- for span in self.label_span_list
- ]
- inserted_string_pairs = [
- (span, (
- "{{" + self.get_color_command_str(label + 1),
- "}}"
- ))
- for label, span in enumerate(extended_label_span_list)
- ]
- span_repl_dict = self.generate_span_repl_dict(
- inserted_string_pairs,
- self.command_repl_items
- )
- result = self.get_replaced_substr(self.full_span, span_repl_dict)
+ for selector in self.tex_to_color_map
+ for span in self.find_spans_by_selector(selector)
+ ],
+ *self.find_spans_by_selector(self.isolate)
+ ]
+ return [(span, {}) for span in specified_spans]
- if self.tex_environment:
- result = "\n".join([
- f"\\begin{{{self.tex_environment}}}",
- result,
- f"\\end{{{self.tex_environment}}}"
- ])
+ @staticmethod
+ def get_color_cmd_str(rgb_hex: str) -> str:
+ rgb = MTex.hex_to_int(rgb_hex)
+ rg, b = divmod(rgb, 256)
+ r, g = divmod(rg, 256)
+ return f"\\color[RGB]{{{r}, {g}, {b}}}"
+
+ @staticmethod
+ def get_cmd_str_pair(
+ attr_dict: dict[str, str], label_hex: str | None
+ ) -> tuple[str, str]:
+ if label_hex is None:
+ return "", ""
+ return "{{" + MTex.get_color_cmd_str(label_hex), "}}"
+
+ def get_content_prefix_and_suffix(
+ self, is_labelled: bool
+ ) -> tuple[str, str]:
+ prefix_lines = []
+ suffix_lines = []
+ if not is_labelled:
+ prefix_lines.append(self.get_color_cmd_str(self.base_color_hex))
if self.alignment:
- result = "\n".join([self.alignment, result])
- if use_plain_file:
- result = "\n".join([
- self.get_color_command_str(self.hex_to_int(self.base_color)),
- result
- ])
- return result
-
- @property
- def has_predefined_local_colors(self) -> bool:
- return bool(self.command_repl_items)
-
- # Post-parsing
-
- def get_cleaned_substr(self, span: Span) -> str:
- substr = super().get_cleaned_substr(span)
- if not self.brace_index_pairs:
- return substr
-
- # Balance braces.
- left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs)
- unclosed_left_braces = 0
- unclosed_right_braces = 0
- for index in range(*span):
- if index in left_brace_indices:
- unclosed_left_braces += 1
- elif index in right_brace_indices:
- if unclosed_left_braces == 0:
- unclosed_right_braces += 1
- else:
- unclosed_left_braces -= 1
- return "".join([
- unclosed_right_braces * "{",
- substr,
- unclosed_left_braces * "}"
- ])
+ prefix_lines.append(self.alignment)
+ if self.tex_environment:
+ if isinstance(self.tex_environment, str):
+ env_prefix = f"\\begin{{{self.tex_environment}}}"
+ env_suffix = f"\\end{{{self.tex_environment}}}"
+ else:
+ env_prefix, env_suffix = self.tex_environment
+ prefix_lines.append(env_prefix)
+ suffix_lines.append(env_suffix)
+ return (
+ "".join([line + "\n" for line in prefix_lines]),
+ "".join(["\n" + line for line in suffix_lines])
+ )
# Method alias
- def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup:
- return self.get_parts_by_string(tex, **kwargs)
+ def get_parts_by_tex(self, selector: Selector) -> VGroup:
+ return self.select_parts(selector)
- def get_part_by_tex(self, tex: str, **kwargs) -> VMobject:
- return self.get_part_by_string(tex, **kwargs)
+ def get_part_by_tex(self, selector: Selector) -> VGroup:
+ return self.select_part(selector)
- def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs):
- return self.set_color_by_string(tex, color, **kwargs)
+ def set_color_by_tex(self, selector: Selector, color: ManimColor):
+ return self.set_parts_color(selector, color)
def set_color_by_tex_to_color_map(
- self, tex_to_color_map: dict[str, ManimColor], **kwargs
+ self, color_map: dict[Selector, ManimColor]
):
- return self.set_color_by_string_to_color_map(
- tex_to_color_map, **kwargs
- )
+ return self.set_parts_color_by_dict(color_map)
def get_tex(self) -> str:
return self.get_string()
diff --git a/manimlib/mobject/svg/string_mobject.py b/manimlib/mobject/svg/string_mobject.py
new file mode 100644
index 00000000..5004960e
--- /dev/null
+++ b/manimlib/mobject/svg/string_mobject.py
@@ -0,0 +1,532 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+import itertools as it
+import re
+from scipy.optimize import linear_sum_assignment
+from scipy.spatial.distance import cdist
+
+from manimlib.constants import WHITE
+from manimlib.logger import log
+from manimlib.mobject.svg.svg_mobject import SVGMobject
+from manimlib.mobject.types.vectorized_mobject import VGroup
+from manimlib.utils.color import color_to_rgb
+from manimlib.utils.color import rgb_to_hex
+from manimlib.utils.config_ops import digest_config
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from colour import Color
+ from typing import Iterable, Sequence, TypeVar, Union
+
+ ManimColor = Union[str, Color]
+ Span = tuple[int, int]
+ Selector = Union[
+ str,
+ re.Pattern,
+ tuple[Union[int, None], Union[int, None]],
+ Iterable[Union[
+ str,
+ re.Pattern,
+ tuple[Union[int, None], Union[int, None]]
+ ]]
+ ]
+ T = TypeVar("T")
+
+
+class StringMobject(SVGMobject, ABC):
+ """
+ An abstract base class for `MTex` and `MarkupText`
+
+ This class aims to optimize the logic of "slicing submobjects
+ via substrings". This could be much clearer and more user-friendly
+ than slicing through numerical indices explicitly.
+
+ Users are expected to specify substrings in `isolate` parameter
+ if they want to do anything with their corresponding submobjects.
+ `isolate` parameter can be either a string, a `re.Pattern` object,
+ or a 2-tuple containing integers or None, or a collection of the above.
+ Note, substrings specified cannot *partially* overlap with each other.
+
+ Each instance of `StringMobject` generates 2 svg files.
+ The additional one is generated with some color commands inserted,
+ so that each submobject of the original `SVGMobject` will be labelled
+ by the color of its paired submobject from the additional `SVGMobject`.
+ """
+ CONFIG = {
+ "height": None,
+ "stroke_width": 0,
+ "stroke_color": WHITE,
+ "path_string_config": {
+ "should_subdivide_sharp_curves": True,
+ "should_remove_null_curves": True,
+ },
+ "base_color": WHITE,
+ "isolate": (),
+ }
+
+ def __init__(self, string: str, **kwargs):
+ self.string = string
+ digest_config(self, kwargs)
+ if self.base_color is None:
+ self.base_color = WHITE
+ self.base_color_hex = self.color_to_hex(self.base_color)
+
+ self.full_span = (0, len(self.string))
+ self.parse()
+ super().__init__(**kwargs)
+ self.labels = [submob.label for submob in self.submobjects]
+
+ def get_file_path(self) -> str:
+ original_content = self.get_content(is_labelled=False)
+ return self.get_file_path_by_content(original_content)
+
+ @abstractmethod
+ def get_file_path_by_content(self, content: str) -> str:
+ return ""
+
+ def generate_mobject(self) -> None:
+ super().generate_mobject()
+
+ labels_count = len(self.labelled_spans)
+ if not labels_count:
+ for submob in self.submobjects:
+ submob.label = -1
+ return
+
+ labelled_content = self.get_content(is_labelled=True)
+ file_path = self.get_file_path_by_content(labelled_content)
+ labelled_svg = SVGMobject(file_path)
+ if len(self.submobjects) != len(labelled_svg.submobjects):
+ log.warning(
+ "Cannot align submobjects of the labelled svg "
+ "to the original svg. Skip the labelling process."
+ )
+ for submob in self.submobjects:
+ submob.label = -1
+ return
+
+ self.rearrange_submobjects_by_positions(labelled_svg)
+ unrecognizable_colors = []
+ for submob, labelled_svg_submob in zip(
+ self.submobjects, labelled_svg.submobjects
+ ):
+ color_int = self.hex_to_int(self.color_to_hex(
+ labelled_svg_submob.get_fill_color()
+ ))
+ if color_int > labels_count:
+ unrecognizable_colors.append(color_int)
+ color_int = 0
+ submob.label = color_int - 1
+ if unrecognizable_colors:
+ log.warning(
+ "Unrecognizable color labels detected (%s, etc). "
+ "The result could be unexpected.",
+ self.int_to_hex(unrecognizable_colors[0])
+ )
+
+ def rearrange_submobjects_by_positions(
+ self, labelled_svg: SVGMobject
+ ) -> None:
+ # Rearrange submobjects of `labelled_svg` so that
+ # each submobject is labelled by the nearest one of `labelled_svg`.
+ # The correctness cannot be ensured, since the svg may
+ # change significantly after inserting color commands.
+ if not labelled_svg.submobjects:
+ return
+
+ bb_0 = self.get_bounding_box()
+ bb_1 = labelled_svg.get_bounding_box()
+ scale_factor = abs((bb_0[2] - bb_0[0]) / (bb_1[2] - bb_1[0]))
+ labelled_svg.move_to(self).scale(scale_factor)
+
+ distance_matrix = cdist(
+ [submob.get_center() for submob in self.submobjects],
+ [submob.get_center() for submob in labelled_svg.submobjects]
+ )
+ _, indices = linear_sum_assignment(distance_matrix)
+ labelled_svg.set_submobjects([
+ labelled_svg.submobjects[index]
+ for index in indices
+ ])
+
+ # Toolkits
+
+ def get_substr(self, span: Span) -> str:
+ return self.string[slice(*span)]
+
+ def find_spans(self, pattern: str | re.Pattern) -> list[Span]:
+ return [
+ match_obj.span()
+ for match_obj in re.finditer(pattern, self.string)
+ ]
+
+ def find_spans_by_selector(self, selector: Selector) -> list[Span]:
+ def find_spans_by_single_selector(sel):
+ if isinstance(sel, str):
+ return self.find_spans(re.escape(sel))
+ if isinstance(sel, re.Pattern):
+ return self.find_spans(sel)
+ if isinstance(sel, tuple) and len(sel) == 2 and all(
+ isinstance(index, int) or index is None
+ for index in sel
+ ):
+ l = self.full_span[1]
+ span = tuple(
+ min(index, l) if index >= 0 else max(index + l, 0)
+ if index is not None else default_index
+ for index, default_index in zip(sel, self.full_span)
+ )
+ return [span]
+ return None
+
+ result = find_spans_by_single_selector(selector)
+ if result is None:
+ result = []
+ for sel in selector:
+ spans = find_spans_by_single_selector(sel)
+ if spans is None:
+ raise TypeError(f"Invalid selector: '{sel}'")
+ result.extend(spans)
+ return result
+
+ @staticmethod
+ def get_neighbouring_pairs(vals: Sequence[T]) -> list[tuple[T, T]]:
+ return list(zip(vals[:-1], vals[1:]))
+
+ @staticmethod
+ def compress_neighbours(vals: Sequence[T]) -> list[tuple[T, Span]]:
+ if not vals:
+ return []
+
+ unique_vals = [vals[0]]
+ indices = [0]
+ for index, val in enumerate(vals):
+ if val == unique_vals[-1]:
+ continue
+ unique_vals.append(val)
+ indices.append(index)
+ indices.append(len(vals))
+ val_ranges = StringMobject.get_neighbouring_pairs(indices)
+ return list(zip(unique_vals, val_ranges))
+
+ @staticmethod
+ def span_contains(span_0: Span, span_1: Span) -> bool:
+ return span_0[0] <= span_1[0] and span_0[1] >= span_1[1]
+
+ @staticmethod
+ def get_complement_spans(
+ universal_span: Span, interval_spans: list[Span]
+ ) -> list[Span]:
+ if not interval_spans:
+ return [universal_span]
+
+ span_ends, span_begins = zip(*interval_spans)
+ return list(zip(
+ (universal_span[0], *span_begins),
+ (*span_ends, universal_span[1])
+ ))
+
+ def replace_substr(self, span: Span, repl_items: list[Span, str]):
+ if not repl_items:
+ return self.get_substr(span)
+
+ repl_spans, repl_strs = zip(*sorted(repl_items, key=lambda t: t[0]))
+ pieces = [
+ self.get_substr(piece_span)
+ for piece_span in self.get_complement_spans(span, repl_spans)
+ ]
+ repl_strs = [*repl_strs, ""]
+ return "".join(it.chain(*zip(pieces, repl_strs)))
+
+ @staticmethod
+ def color_to_hex(color: ManimColor) -> str:
+ return rgb_to_hex(color_to_rgb(color))
+
+ @staticmethod
+ def hex_to_int(rgb_hex: str) -> int:
+ return int(rgb_hex[1:], 16)
+
+ @staticmethod
+ def int_to_hex(rgb_int: int) -> str:
+ return f"#{rgb_int:06x}".upper()
+
+ # Parsing
+
+ def parse(self) -> None:
+ cmd_spans = self.get_cmd_spans()
+ cmd_substrs = [self.get_substr(span) for span in cmd_spans]
+ flags = [self.get_substr_flag(substr) for substr in cmd_substrs]
+ specified_items = self.get_specified_items(
+ self.get_cmd_span_pairs(cmd_spans, flags)
+ )
+ split_items = [
+ (span, attr_dict)
+ for specified_span, attr_dict in specified_items
+ for span in self.split_span_by_levels(
+ specified_span, cmd_spans, flags
+ )
+ ]
+
+ self.specified_spans = [span for span, _ in specified_items]
+ self.split_items = split_items
+ self.labelled_spans = [span for span, _ in split_items]
+ self.cmd_repl_items_for_content = [
+ (span, self.get_repl_substr_for_content(substr))
+ for span, substr in zip(cmd_spans, cmd_substrs)
+ ]
+ self.cmd_repl_items_for_matching = [
+ (span, self.get_repl_substr_for_matching(substr))
+ for span, substr in zip(cmd_spans, cmd_substrs)
+ ]
+ self.check_overlapping()
+
+ @abstractmethod
+ def get_cmd_spans(self) -> list[Span]:
+ return []
+
+ @abstractmethod
+ def get_substr_flag(self, substr: str) -> int:
+ return 0
+
+ @abstractmethod
+ def get_repl_substr_for_content(self, substr: str) -> str:
+ return ""
+
+ @abstractmethod
+ def get_repl_substr_for_matching(self, substr: str) -> str:
+ return ""
+
+ @staticmethod
+ def get_cmd_span_pairs(
+ cmd_spans: list[Span], flags: list[int]
+ ) -> list[tuple[Span, Span]]:
+ result = []
+ begin_cmd_spans_stack = []
+ for cmd_span, flag in zip(cmd_spans, flags):
+ if flag == 1:
+ begin_cmd_spans_stack.append(cmd_span)
+ elif flag == -1:
+ if not begin_cmd_spans_stack:
+ raise ValueError("Missing open command")
+ begin_cmd_span = begin_cmd_spans_stack.pop()
+ result.append((begin_cmd_span, cmd_span))
+ if begin_cmd_spans_stack:
+ raise ValueError("Missing close command")
+ return result
+
+ @abstractmethod
+ def get_specified_items(
+ self, cmd_span_pairs: list[tuple[Span, Span]]
+ ) -> list[tuple[Span, dict[str, str]]]:
+ return []
+
+ def split_span_by_levels(
+ self, arbitrary_span: Span, cmd_spans: list[Span], flags: list[int]
+ ) -> list[Span]:
+ cmd_range = (
+ sum([
+ arbitrary_span[0] > interval_begin
+ for interval_begin, _ in cmd_spans
+ ]),
+ sum([
+ arbitrary_span[1] >= interval_end
+ for _, interval_end in cmd_spans
+ ])
+ )
+ complement_spans = self.get_complement_spans(
+ self.full_span, cmd_spans
+ )
+ adjusted_span = (
+ max(arbitrary_span[0], complement_spans[cmd_range[0]][0]),
+ min(arbitrary_span[1], complement_spans[cmd_range[1]][1])
+ )
+ if adjusted_span[0] > adjusted_span[1]:
+ return []
+
+ upward_cmd_spans = []
+ downward_cmd_spans = []
+ for cmd_span, flag in list(zip(cmd_spans, flags))[slice(*cmd_range)]:
+ if flag == 1:
+ upward_cmd_spans.append(cmd_span)
+ elif flag == -1:
+ if upward_cmd_spans:
+ upward_cmd_spans.pop()
+ else:
+ downward_cmd_spans.append(cmd_span)
+ return list(filter(
+ lambda span: self.get_substr(span).strip(),
+ self.get_complement_spans(
+ adjusted_span, downward_cmd_spans + upward_cmd_spans
+ )
+ ))
+
+ def check_overlapping(self) -> None:
+ labelled_spans = self.labelled_spans
+ if len(labelled_spans) >= 16777216:
+ raise ValueError("Cannot handle that many substrings")
+ for span_0, span_1 in it.product(labelled_spans, repeat=2):
+ if not span_0[0] < span_1[0] < span_0[1] < span_1[1]:
+ continue
+ raise ValueError(
+ "Partially overlapping substrings detected: "
+ f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'"
+ )
+
+ @staticmethod
+ @abstractmethod
+ def get_cmd_str_pair(
+ attr_dict: dict[str, str], label_hex: str | None
+ ) -> tuple[str, str]:
+ return "", ""
+
+ @abstractmethod
+ def get_content_prefix_and_suffix(
+ self, is_labelled: bool
+ ) -> tuple[str, str]:
+ return "", ""
+
+ def get_content(self, is_labelled: bool) -> str:
+ inserted_str_pairs = [
+ (span, self.get_cmd_str_pair(
+ attr_dict,
+ label_hex=self.int_to_hex(label + 1) if is_labelled else None
+ ))
+ for label, (span, attr_dict) in enumerate(self.split_items)
+ ]
+ inserted_str_items = sorted([
+ (index, s)
+ for (index, _), s in [
+ *sorted([
+ (span[::-1], end_str)
+ for span, (_, end_str) in reversed(inserted_str_pairs)
+ ], key=lambda t: (t[0][0], -t[0][1])),
+ *sorted([
+ (span, begin_str)
+ for span, (begin_str, _) in inserted_str_pairs
+ ], key=lambda t: (t[0][0], -t[0][1]))
+ ]
+ ], key=lambda t: t[0])
+ repl_items = self.cmd_repl_items_for_content + [
+ ((index, index), inserted_str)
+ for index, inserted_str in inserted_str_items
+ ]
+ prefix, suffix = self.get_content_prefix_and_suffix(is_labelled)
+ return "".join([
+ prefix,
+ self.replace_substr(self.full_span, repl_items),
+ suffix
+ ])
+
+ # Selector
+
+ def get_submob_indices_list_by_span(
+ self, arbitrary_span: Span
+ ) -> list[int]:
+ return [
+ submob_index
+ for submob_index, label in enumerate(self.labels)
+ if label != -1 and self.span_contains(
+ arbitrary_span, self.labelled_spans[label]
+ )
+ ]
+
+ def get_specified_part_items(self) -> list[tuple[str, list[int]]]:
+ return [
+ (
+ self.get_substr(span),
+ self.get_submob_indices_list_by_span(span)
+ )
+ for span in self.specified_spans
+ ]
+
+ def get_group_part_items(self) -> list[tuple[str, list[int]]]:
+ if not self.labels:
+ return []
+
+ group_labels, labelled_submob_ranges = zip(
+ *self.compress_neighbours(self.labels)
+ )
+ ordered_spans = [
+ self.labelled_spans[label] if label != -1 else self.full_span
+ for label in group_labels
+ ]
+ interval_spans = [
+ (
+ next_span[0]
+ if self.span_contains(prev_span, next_span)
+ else prev_span[1],
+ prev_span[1]
+ if self.span_contains(next_span, prev_span)
+ else next_span[0]
+ )
+ for prev_span, next_span in self.get_neighbouring_pairs(
+ ordered_spans
+ )
+ ]
+ group_substrs = [
+ re.sub(r"\s+", "", self.replace_substr(
+ span, [
+ (cmd_span, repl_str)
+ for cmd_span, repl_str in self.cmd_repl_items_for_matching
+ if self.span_contains(span, cmd_span)
+ ]
+ ))
+ for span in self.get_complement_spans(
+ (ordered_spans[0][0], ordered_spans[-1][1]), interval_spans
+ )
+ ]
+ submob_indices_lists = [
+ list(range(*submob_range))
+ for submob_range in labelled_submob_ranges
+ ]
+ return list(zip(group_substrs, submob_indices_lists))
+
+ def get_submob_indices_lists_by_selector(
+ self, selector: Selector
+ ) -> list[list[int]]:
+ return list(filter(
+ lambda indices_list: indices_list,
+ [
+ self.get_submob_indices_list_by_span(span)
+ for span in self.find_spans_by_selector(selector)
+ ]
+ ))
+
+ def build_parts_from_indices_lists(
+ self, indices_lists: list[list[int]]
+ ) -> VGroup:
+ return VGroup(*[
+ VGroup(*[
+ self.submobjects[submob_index]
+ for submob_index in indices_list
+ ])
+ for indices_list in indices_lists
+ ])
+
+ def build_groups(self) -> VGroup:
+ return self.build_parts_from_indices_lists([
+ indices_list
+ for _, indices_list in self.get_group_part_items()
+ ])
+
+ def select_parts(self, selector: Selector) -> VGroup:
+ return self.build_parts_from_indices_lists(
+ self.get_submob_indices_lists_by_selector(selector)
+ )
+
+ def select_part(self, selector: Selector, index: int = 0) -> VGroup:
+ return self.select_parts(selector)[index]
+
+ def set_parts_color(self, selector: Selector, color: ManimColor):
+ self.select_parts(selector).set_color(color)
+ return self
+
+ def set_parts_color_by_dict(self, color_map: dict[Selector, ManimColor]):
+ for selector, color in color_map.items():
+ self.set_parts_color(selector, color)
+ return self
+
+ def get_string(self) -> str:
+ return self.string
diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py
index c3c3be19..93623c31 100644
--- a/manimlib/mobject/svg/text_mobject.py
+++ b/manimlib/mobject/svg/text_mobject.py
@@ -1,103 +1,52 @@
from __future__ import annotations
-import os
-import re
-import itertools as it
-from pathlib import Path
from contextlib import contextmanager
-import typing
-from typing import Iterable, Sequence, Union
+import os
+from pathlib import Path
+import re
+import manimpango
import pygments
import pygments.formatters
import pygments.lexers
-from manimpango import MarkupUtils
-
+from manimlib.constants import DEFAULT_PIXEL_WIDTH, FRAME_WIDTH
+from manimlib.constants import NORMAL
from manimlib.logger import log
-from manimlib.constants import *
-from manimlib.mobject.svg.labelled_string import LabelledString
-from manimlib.utils.customization import get_customization
-from manimlib.utils.tex_file_writing import tex_hash
+from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.utils.config_ops import digest_config
+from manimlib.utils.customization import get_customization
from manimlib.utils.directories import get_downloads_dir
from manimlib.utils.directories import get_text_dir
-from manimlib.utils.iterables import remove_list_redundancies
-
+from manimlib.utils.tex_file_writing import tex_hash
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- from manimlib.mobject.types.vectorized_mobject import VMobject
+ from colour import Color
+ from typing import Iterable, Union
+
from manimlib.mobject.types.vectorized_mobject import VGroup
- ManimColor = Union[str, colour.Color, Sequence[float]]
+
+ ManimColor = Union[str, Color]
Span = tuple[int, int]
+ Selector = Union[
+ str,
+ re.Pattern,
+ tuple[Union[int, None], Union[int, None]],
+ Iterable[Union[
+ str,
+ re.Pattern,
+ tuple[Union[int, None], Union[int, None]]
+ ]]
+ ]
TEXT_MOB_SCALE_FACTOR = 0.0076
DEFAULT_LINE_SPACING_SCALE = 0.6
-
-
-# See https://docs.gtk.org/Pango/pango_markup.html
-# A tag containing two aliases will cause warning,
-# so only use the first key of each group of aliases.
-SPAN_ATTR_KEY_ALIAS_LIST = (
- ("font", "font_desc"),
- ("font_family", "face"),
- ("font_size", "size"),
- ("font_style", "style"),
- ("font_weight", "weight"),
- ("font_variant", "variant"),
- ("font_stretch", "stretch"),
- ("font_features",),
- ("foreground", "fgcolor", "color"),
- ("background", "bgcolor"),
- ("alpha", "fgalpha"),
- ("background_alpha", "bgalpha"),
- ("underline",),
- ("underline_color",),
- ("overline",),
- ("overline_color",),
- ("rise",),
- ("baseline_shift",),
- ("font_scale",),
- ("strikethrough",),
- ("strikethrough_color",),
- ("fallback",),
- ("lang",),
- ("letter_spacing",),
- ("gravity",),
- ("gravity_hint",),
- ("show",),
- ("insert_hyphens",),
- ("allow_breaks",),
- ("line_height",),
- ("text_transform",),
- ("segment",),
-)
-COLOR_RELATED_KEYS = (
- "foreground",
- "background",
- "underline_color",
- "overline_color",
- "strikethrough_color"
-)
-SPAN_ATTR_KEY_CONVERSION = {
- key: key_alias_list[0]
- for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST
- for key in key_alias_list
-}
-TAG_TO_ATTR_DICT = {
- "b": {"font_weight": "bold"},
- "big": {"font_size": "larger"},
- "i": {"font_style": "italic"},
- "s": {"strikethrough": "true"},
- "sub": {"baseline_shift": "subscript", "font_scale": "subscript"},
- "sup": {"baseline_shift": "superscript", "font_scale": "superscript"},
- "small": {"font_size": "smaller"},
- "tt": {"font_family": "monospace"},
- "u": {"underline": "single"},
-}
+# Ensure the canvas is large enough to hold all glyphs.
+DEFAULT_CANVAS_WIDTH = 16384
+DEFAULT_CANVAS_HEIGHT = 16384
# Temporary handler
@@ -112,7 +61,7 @@ class _Alignment:
self.value = _Alignment.VAL_DICT[s.upper()]
-class MarkupText(LabelledString):
+class MarkupText(StringMobject):
CONFIG = {
"is_markup": True,
"font_size": 48,
@@ -120,7 +69,7 @@ class MarkupText(LabelledString):
"justify": False,
"indent": 0,
"alignment": "LEFT",
- "line_width_factor": None,
+ "line_width": None,
"font": "",
"slant": NORMAL,
"weight": NORMAL,
@@ -132,6 +81,31 @@ class MarkupText(LabelledString):
"t2w": {},
"global_config": {},
"local_configs": {},
+ # For backward compatibility
+ "isolate": (re.compile(r"[a-zA-Z]+"), re.compile(r"\S+")),
+ }
+
+ # See https://docs.gtk.org/Pango/pango_markup.html
+ MARKUP_COLOR_KEYS = {
+ "foreground": False,
+ "fgcolor": False,
+ "color": False,
+ "background": True,
+ "bgcolor": True,
+ "underline_color": True,
+ "overline_color": True,
+ "strikethrough_color": True,
+ }
+ MARKUP_TAGS = {
+ "b": {"font_weight": "bold"},
+ "big": {"font_size": "larger"},
+ "i": {"font_style": "italic"},
+ "s": {"strikethrough": "true"},
+ "sub": {"baseline_shift": "subscript", "font_scale": "subscript"},
+ "sup": {"baseline_shift": "superscript", "font_scale": "superscript"},
+ "small": {"font_size": "smaller"},
+ "tt": {"font_family": "monospace"},
+ "u": {"underline": "single"},
}
def __init__(self, text: str, **kwargs):
@@ -141,9 +115,7 @@ class MarkupText(LabelledString):
if not self.font:
self.font = get_customization()["style"]["font"]
if self.is_markup:
- validate_error = MarkupUtils.validate(text)
- if validate_error:
- raise ValueError(validate_error)
+ self.validate_markup_string(text)
self.text = text
super().__init__(text, **kwargs)
@@ -165,7 +137,6 @@ class MarkupText(LabelledString):
self.svg_default,
self.path_string_config,
self.base_color,
- self.use_plain_file,
self.isolate,
self.text,
self.is_markup,
@@ -174,7 +145,7 @@ class MarkupText(LabelledString):
self.justify,
self.indent,
self.alignment,
- self.line_width_factor,
+ self.line_width,
self.font,
self.slant,
self.weight,
@@ -201,23 +172,32 @@ class MarkupText(LabelledString):
kwargs[short_name] = kwargs.pop(long_name)
def get_file_path_by_content(self, content: str) -> str:
+ hash_content = str((
+ content,
+ self.justify,
+ self.indent,
+ self.alignment,
+ self.line_width
+ ))
svg_file = os.path.join(
- get_text_dir(), tex_hash(content) + ".svg"
+ get_text_dir(), tex_hash(hash_content) + ".svg"
)
if not os.path.exists(svg_file):
self.markup_to_svg(content, svg_file)
return svg_file
def markup_to_svg(self, markup_str: str, file_name: str) -> str:
+ self.validate_markup_string(markup_str)
+
# `manimpango` is under construction,
# so the following code is intended to suit its interface
alignment = _Alignment(self.alignment)
- if self.line_width_factor is None:
+ if self.line_width is None:
pango_width = -1
else:
- pango_width = self.line_width_factor * DEFAULT_PIXEL_WIDTH
+ pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH
- return MarkupUtils.text2svg(
+ return manimpango.MarkupUtils.text2svg(
text=markup_str,
font="", # Already handled
slant="NORMAL", # Already handled
@@ -228,8 +208,8 @@ class MarkupText(LabelledString):
file_name=file_name,
START_X=0,
START_Y=0,
- width=DEFAULT_PIXEL_WIDTH,
- height=DEFAULT_PIXEL_HEIGHT,
+ width=DEFAULT_CANVAS_WIDTH,
+ height=DEFAULT_CANVAS_HEIGHT,
justify=self.justify,
indent=self.indent,
line_spacing=None, # Already handled
@@ -237,294 +217,173 @@ class MarkupText(LabelledString):
pango_width=pango_width
)
- def pre_parse(self) -> None:
- super().pre_parse()
- self.tag_items_from_markup = self.get_tag_items_from_markup()
- self.global_dict_from_config = self.get_global_dict_from_config()
- self.local_dicts_from_markup = self.get_local_dicts_from_markup()
- self.local_dicts_from_config = self.get_local_dicts_from_config()
- self.predefined_attr_dicts = self.get_predefined_attr_dicts()
-
- # Toolkits
-
@staticmethod
- def get_attr_dict_str(attr_dict: dict[str, str]) -> str:
- return " ".join([
- f"{key}='{val}'"
- for key, val in attr_dict.items()
- ])
-
- @staticmethod
- def merge_attr_dicts(
- attr_dict_items: list[Span, str, typing.Any]
- ) -> list[tuple[Span, dict[str, str]]]:
- index_seq = [0]
- attr_dict_list = [{}]
- for span, attr_dict in attr_dict_items:
- if span[0] >= span[1]:
- continue
- region_indices = [
- MarkupText.find_region_index(index_seq, index)
- for index in span
- ]
- for flag in (1, 0):
- if index_seq[region_indices[flag]] == span[flag]:
- continue
- region_index = region_indices[flag]
- index_seq.insert(region_index + 1, span[flag])
- attr_dict_list.insert(
- region_index + 1, attr_dict_list[region_index].copy()
- )
- region_indices[flag] += 1
- if flag == 0:
- region_indices[1] += 1
- for key, val in attr_dict.items():
- if not key:
- continue
- for mid_dict in attr_dict_list[slice(*region_indices)]:
- mid_dict[key] = val
- return list(zip(
- MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1]
- ))
-
- def find_substr_or_span(
- self, substr_or_span: str | tuple[int | None, int | None]
- ) -> list[Span]:
- if isinstance(substr_or_span, str):
- return self.find_substr(substr_or_span)
-
- span = tuple([
- (
- min(index, self.string_len)
- if index >= 0
- else max(index + self.string_len, 0)
- )
- if index is not None else default_index
- for index, default_index in zip(substr_or_span, self.full_span)
- ])
- if span[0] >= span[1]:
- return []
- return [span]
-
- # Pre-parsing
-
- def get_tag_items_from_markup(
- self
- ) -> list[tuple[Span, Span, dict[str, str]]]:
- if not self.is_markup:
- return []
-
- tag_pattern = r"""<(/?)(\w+)\s*((?:\w+\s*\=\s*(['"]).*?\4\s*)*)>"""
- attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2"""
- begin_match_obj_stack = []
- match_obj_pairs = []
- for match_obj in self.finditer(tag_pattern):
- if not match_obj.group(1):
- begin_match_obj_stack.append(match_obj)
- else:
- match_obj_pairs.append(
- (begin_match_obj_stack.pop(), match_obj)
- )
- if begin_match_obj_stack:
- raise ValueError("Unclosed tag(s) detected")
-
- result = []
- for begin_match_obj, end_match_obj in match_obj_pairs:
- tag_name = begin_match_obj.group(2)
- if tag_name != end_match_obj.group(2):
- raise ValueError("Unmatched tag names")
- if end_match_obj.group(3):
- raise ValueError("Attributes shan't exist in ending tags")
- if tag_name == "span":
- attr_dict = {
- match.group(1): match.group(3)
- for match in re.finditer(
- attr_pattern, begin_match_obj.group(3)
- )
- }
- elif tag_name in TAG_TO_ATTR_DICT.keys():
- if begin_match_obj.group(3):
- raise ValueError(
- f"Attributes shan't exist in tag '{tag_name}'"
- )
- attr_dict = TAG_TO_ATTR_DICT[tag_name].copy()
- else:
- raise ValueError(f"Unknown tag: '{tag_name}'")
-
- result.append(
- (begin_match_obj.span(), end_match_obj.span(), attr_dict)
- )
- return result
-
- def get_global_dict_from_config(self) -> dict[str, typing.Any]:
- result = {
- "line_height": (
- (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1
- ) * 0.6,
- "font_family": self.font,
- "font_size": self.font_size * 1024,
- "font_style": self.slant,
- "font_weight": self.weight
- }
- result.update(self.global_config)
- return result
-
- def get_local_dicts_from_markup(
- self
- ) -> list[Span, dict[str, str]]:
- return sorted([
- ((begin_tag_span[0], end_tag_span[1]), attr_dict)
- for begin_tag_span, end_tag_span, attr_dict
- in self.tag_items_from_markup
- ])
-
- def get_local_dicts_from_config(
- self
- ) -> list[Span, dict[str, typing.Any]]:
- return [
- (span, {key: val})
- for t2x_dict, key in (
- (self.t2c, "foreground"),
- (self.t2f, "font_family"),
- (self.t2s, "font_style"),
- (self.t2w, "font_weight")
- )
- for substr_or_span, val in t2x_dict.items()
- for span in self.find_substr_or_span(substr_or_span)
- ] + [
- (span, local_config)
- for substr_or_span, local_config in self.local_configs.items()
- for span in self.find_substr_or_span(substr_or_span)
- ]
-
- def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]:
- attr_dict_items = [
- (self.full_span, self.global_dict_from_config),
- *self.local_dicts_from_markup,
- *self.local_dicts_from_config
- ]
- return [
- (span, {
- SPAN_ATTR_KEY_CONVERSION[key.lower()]: str(val)
- for key, val in attr_dict.items()
- })
- for span, attr_dict in attr_dict_items
- ]
+ def validate_markup_string(markup_str: str) -> None:
+ validate_error = manimpango.MarkupUtils.validate(markup_str)
+ if not validate_error:
+ return
+ raise ValueError(
+ f"Invalid markup string \"{markup_str}\"\n"
+ f"{validate_error}"
+ )
# Parsing
- def get_command_repl_items(self) -> list[tuple[Span, str]]:
- result = [
- (tag_span, "")
- for begin_tag, end_tag, _ in self.tag_items_from_markup
- for tag_span in (begin_tag, end_tag)
- ]
+ def get_cmd_spans(self) -> list[Span]:
if not self.is_markup:
- result += [
- (span, escaped)
- for char, escaped in (
- ("&", "&"),
- (">", ">"),
- ("<", "<")
- )
- for span in self.find_substr(char)
- ]
- return result
+ return self.find_spans(r"""[<>&"']""")
- def get_extra_entity_spans(self) -> list[Span]:
- if not self.is_markup:
- return []
- return self.find_spans(r"&.*?;")
-
- def get_extra_ignored_spans(self) -> list[int]:
- return []
-
- def get_internal_specified_spans(self) -> list[Span]:
- return [span for span, _ in self.local_dicts_from_markup]
-
- def get_external_specified_spans(self) -> list[Span]:
- return [span for span, _ in self.local_dicts_from_config]
-
- def get_label_span_list(self) -> list[Span]:
- breakup_indices = remove_list_redundancies(list(it.chain(*it.chain(
- self.find_spans(r"\s+"),
- self.find_spans(r"\b"),
- self.specified_spans
- ))))
- breakup_indices = sorted(filter(
- lambda index: not any([
- span[0] < index < span[1]
- for span in self.entity_spans
- ]),
- breakup_indices
- ))
- return list(filter(
- lambda span: self.get_substr(span).strip(),
- self.get_neighbouring_pairs(breakup_indices)
- ))
-
- def get_content(self, use_plain_file: bool) -> str:
- if use_plain_file:
- attr_dict_items = [
- (self.full_span, {"foreground": self.base_color}),
- *self.predefined_attr_dicts,
- *[
- (span, {})
- for span in self.label_span_list
- ]
- ]
- else:
- attr_dict_items = [
- (self.full_span, {"foreground": BLACK}),
- *[
- (span, {
- key: BLACK if key in COLOR_RELATED_KEYS else val
- for key, val in attr_dict.items()
- })
- for span, attr_dict in self.predefined_attr_dicts
- ],
- *[
- (span, {"foreground": self.int_to_hex(label + 1)})
- for label, span in enumerate(self.label_span_list)
- ]
- ]
- inserted_string_pairs = [
- (span, (
- f"",
- ""
- ))
- for span, attr_dict in self.merge_attr_dicts(attr_dict_items)
- ]
- span_repl_dict = self.generate_span_repl_dict(
- inserted_string_pairs, self.command_repl_items
+ # Unsupported passthroughs:
+ # "...?>", "", "", ""
+ # See https://gitlab.gnome.org/GNOME/glib/-/blob/main/glib/gmarkup.c
+ return self.find_spans(
+ r"""&[\s\S]*?;|[>"']|?\w+(?:\s*\w+\s*\=\s*(["'])[\s\S]*?\1)*/?>"""
)
- return self.get_replaced_substr(self.full_span, span_repl_dict)
- @property
- def has_predefined_local_colors(self) -> bool:
- return any([
- key in COLOR_RELATED_KEYS
- for _, attr_dict in self.predefined_attr_dicts
- for key in attr_dict.keys()
+ def get_substr_flag(self, substr: str) -> int:
+ if re.fullmatch(r"<\w[\s\S]*[^/]>", substr):
+ return 1
+ if substr.startswith(""):
+ return -1
+ return 0
+
+ def get_repl_substr_for_content(self, substr: str) -> str:
+ if substr.startswith("<") and substr.endswith(">"):
+ return ""
+ return {
+ "<": "<",
+ ">": ">",
+ "&": "&",
+ "\"": """,
+ "'": "'"
+ }.get(substr, substr)
+
+ def get_repl_substr_for_matching(self, substr: str) -> str:
+ if substr.startswith("<") and substr.endswith(">"):
+ return ""
+ if substr.startswith("") and substr.endswith(";"):
+ if substr.startswith(""):
+ char_reference = int(substr[3:-1], 16)
+ else:
+ char_reference = int(substr[2:-1], 10)
+ return chr(char_reference)
+ return {
+ "<": "<",
+ ">": ">",
+ "&": "&",
+ """: "\"",
+ "'": "'"
+ }.get(substr, substr)
+
+ def get_specified_items(
+ self, cmd_span_pairs: list[tuple[Span, Span]]
+ ) -> list[tuple[Span, dict[str, str]]]:
+ attr_pattern = r"""(\w+)\s*\=\s*(["'])([\s\S]*?)\2"""
+ internal_items = []
+ for begin_cmd_span, end_cmd_span in cmd_span_pairs:
+ begin_tag = self.get_substr(begin_cmd_span)
+ tag_name = re.match(r"<(\w+)", begin_tag).group(1)
+ if tag_name == "span":
+ attr_dict = {
+ attr_match_obj.group(1): attr_match_obj.group(3)
+ for attr_match_obj in re.finditer(attr_pattern, begin_tag)
+ }
+ else:
+ attr_dict = MarkupText.MARKUP_TAGS.get(tag_name, {})
+ internal_items.append(
+ ((begin_cmd_span[1], end_cmd_span[0]), attr_dict)
+ )
+
+ return [
+ *internal_items,
+ *[
+ (span, {key: val})
+ for t2x_dict, key in (
+ (self.t2c, "foreground"),
+ (self.t2f, "font_family"),
+ (self.t2s, "font_style"),
+ (self.t2w, "font_weight")
+ )
+ for selector, val in t2x_dict.items()
+ for span in self.find_spans_by_selector(selector)
+ ],
+ *[
+ (span, local_config)
+ for selector, local_config in self.local_configs.items()
+ for span in self.find_spans_by_selector(selector)
+ ],
+ *[
+ (span, {})
+ for span in self.find_spans_by_selector(self.isolate)
+ ]
+ ]
+
+ @staticmethod
+ def get_cmd_str_pair(
+ attr_dict: dict[str, str], label_hex: str | None
+ ) -> tuple[str, str]:
+ if label_hex is not None:
+ converted_attr_dict = {"foreground": label_hex}
+ for key, val in attr_dict.items():
+ substitute_key = MarkupText.MARKUP_COLOR_KEYS.get(key, None)
+ if substitute_key is None:
+ converted_attr_dict[key] = val
+ elif substitute_key:
+ converted_attr_dict[key] = "black"
+ else:
+ converted_attr_dict = attr_dict.copy()
+ attrs_str = " ".join([
+ f"{key}='{val}'"
+ for key, val in converted_attr_dict.items()
])
+ return f"", ""
+
+ def get_content_prefix_and_suffix(
+ self, is_labelled: bool
+ ) -> tuple[str, str]:
+ global_attr_dict = {
+ "foreground": self.base_color_hex,
+ "font_family": self.font,
+ "font_style": self.slant,
+ "font_weight": self.weight,
+ "font_size": str(self.font_size * 1024),
+ }
+ global_attr_dict.update(self.global_config)
+ # `line_height` attribute is supported since Pango 1.50.
+ pango_version = manimpango.pango_version()
+ if tuple(map(int, pango_version.split("."))) < (1, 50):
+ if self.lsh is not None:
+ log.warning(
+ "Pango version %s found (< 1.50), "
+ "unable to set `line_height` attribute",
+ pango_version
+ )
+ else:
+ line_spacing_scale = self.lsh or DEFAULT_LINE_SPACING_SCALE
+ global_attr_dict["line_height"] = str(
+ ((line_spacing_scale) + 1) * 0.6
+ )
+
+ return self.get_cmd_str_pair(
+ global_attr_dict,
+ label_hex=self.int_to_hex(0) if is_labelled else None
+ )
# Method alias
- def get_parts_by_text(self, text: str, **kwargs) -> VGroup:
- return self.get_parts_by_string(text, **kwargs)
+ def get_parts_by_text(self, selector: Selector) -> VGroup:
+ return self.select_parts(selector)
- def get_part_by_text(self, text: str, **kwargs) -> VMobject:
- return self.get_part_by_string(text, **kwargs)
+ def get_part_by_text(self, selector: Selector) -> VGroup:
+ return self.select_part(selector)
- def set_color_by_text(self, text: str, color: ManimColor, **kwargs):
- return self.set_color_by_string(text, color, **kwargs)
+ def set_color_by_text(self, selector: Selector, color: ManimColor):
+ return self.set_parts_color(selector, color)
def set_color_by_text_to_color_map(
- self, text_to_color_map: dict[str, ManimColor], **kwargs
+ self, color_map: dict[Selector, ManimColor]
):
- return self.set_color_by_string_to_color_map(
- text_to_color_map, **kwargs
- )
+ return self.set_parts_color_by_dict(color_map)
def get_text(self) -> str:
return self.get_string()
diff --git a/manimlib/scene/interactive_scene.py b/manimlib/scene/interactive_scene.py
index 40ec18c5..3b7397b8 100644
--- a/manimlib/scene/interactive_scene.py
+++ b/manimlib/scene/interactive_scene.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import itertools as it
import numpy as np
import pyperclip
diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py
index bdaa76c2..fa54e68d 100644
--- a/manimlib/utils/iterables.py
+++ b/manimlib/utils/iterables.py
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
S = TypeVar("S")
-def remove_list_redundancies(l: Iterable[T]) -> list[T]:
+def remove_list_redundancies(l: Sequence[T]) -> list[T]:
"""
Used instead of list(set(l)) to maintain order
Keeps the last occurrence of each element
@@ -40,14 +40,14 @@ def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]:
return [e for e in l1 if e not in l2]
-def adjacent_n_tuples(objects: Iterable[T], n: int) -> zip[tuple[T, T]]:
+def adjacent_n_tuples(objects: Sequence[T], n: int) -> zip[tuple[T, T]]:
return zip(*[
[*objects[k:], *objects[:k]]
for k in range(n)
])
-def adjacent_pairs(objects: Iterable[T]) -> zip[tuple[T, T]]:
+def adjacent_pairs(objects: Sequence[T]) -> zip[tuple[T, T]]:
return adjacent_n_tuples(objects, 2)