diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index 90ffa76f..486007dd 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -12,7 +12,7 @@ from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Group -from manimlib.mobject.svg.mtex_mobject import MTex +from manimlib.mobject.svg.mtex_mobject import LabelledString from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config @@ -153,15 +153,16 @@ class TransformMatchingTex(TransformMatchingParts): return mobject.get_tex() -class TransformMatchingMTex(AnimationGroup): +class TransformMatchingString(AnimationGroup): CONFIG = { "key_map": dict(), + "transform_mismatches_class": None, } - def __init__(self, source_mobject: MTex, target_mobject: MTex, **kwargs): + def __init__(self, source_mobject: LabelledString, target_mobject: LabelledString, **kwargs): digest_config(self, kwargs) - assert isinstance(source_mobject, MTex) - assert isinstance(target_mobject, MTex) + assert isinstance(source_mobject, LabelledString) + assert isinstance(target_mobject, LabelledString) anims = [] rest_source_submobs = source_mobject.submobjects.copy() rest_target_submobs = target_mobject.submobjects.copy() @@ -207,7 +208,7 @@ class TransformMatchingMTex(AnimationGroup): elif isinstance(key, range): indices.extend(key) elif isinstance(key, str): - all_parts = mobject.get_parts_by_tex(key) + all_parts = mobject.get_parts_by_string(key) indices.extend(it.chain(*[ mobject.indices_of_part(part) for part in all_parts ])) @@ -228,31 +229,34 @@ class TransformMatchingMTex(AnimationGroup): target_mobject.get_specified_substrings() ) ), key=len, reverse=True) - for part_tex_string in common_specified_substrings: + for part_string in common_specified_substrings: add_anim_from( - FadeTransformPieces, MTex.get_parts_by_tex, part_tex_string + FadeTransformPieces, LabelledString.get_parts_by_string, part_string ) - common_submob_tex_strings = { - source_submob.get_tex() for source_submob in source_mobject + common_submob_strings = { + source_submob.get_string() for source_submob in source_mobject }.intersection({ - target_submob.get_tex() for target_submob in target_mobject + target_submob.get_string() for target_submob in target_mobject }) - for tex_string in common_submob_tex_strings: + for substr in common_submob_strings: add_anim_from( FadeTransformPieces, lambda mobject, attr: VGroup(*[ VGroup(mob) for mob in mobject - if mob.get_tex() == attr + if mob.get_string() == attr ]), - tex_string + substr ) - anims.append(FadeOutToPoint( - VGroup(*rest_source_submobs), target_mobject.get_center(), **kwargs - )) - anims.append(FadeInFromPoint( - VGroup(*rest_target_submobs), source_mobject.get_center(), **kwargs - )) + if self.transform_mismatches_class is not None: + anims.append(self.transform_mismatches_class(fade_source, fade_target, **kwargs)) + else: + anims.append(FadeOutToPoint( + VGroup(*rest_source_submobs), target_mobject.get_center(), **kwargs + )) + anims.append(FadeInFromPoint( + VGroup(*rest_target_submobs), source_mobject.get_center(), **kwargs + )) super().__init__(*anims) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index a14004e7..efc542b2 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -5,8 +5,9 @@ import colour import itertools as it from types import MethodType from typing import Iterable, Union, Sequence +from abc import abstractmethod -from manimlib.constants import BLACK, WHITE +from manimlib.constants import WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_to_int_rgb @@ -24,17 +25,15 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from manimlib.mobject.types.vectorized_mobject import VMobject ManimColor = Union[str, colour.Color, Sequence[float]] + Span = tuple[int, int] SCALE_FACTOR_PER_FONT_POINT = 0.001 -class _TexSVG(SVGMobject): +class _StringSVG(SVGMobject): CONFIG = { "height": None, - "svg_default": { - "fill_color": BLACK, - }, "stroke_width": 0, "stroke_color": WHITE, "path_string_config": { @@ -44,75 +43,29 @@ class _TexSVG(SVGMobject): } -class MTex(_TexSVG): +class LabelledString(_StringSVG): + """ + An abstract base class for `MTex` and `MarkupText` + """ CONFIG = { "base_color": WHITE, - "font_size": 48, - "alignment": "\\centering", - "tex_environment": "align*", - "isolate": [], - "tex_to_color_map": {}, - "use_plain_tex": False, + "use_plain_file": False, } def __init__(self, string: str, **kwargs): - digest_config(self, kwargs) - string = string.strip() - # Prevent from passing an empty string. - if not string: - string = "\\quad" - self.tex_string = string self.string = string super().__init__(**kwargs) - self.set_color_by_tex_to_color_map(self.tex_to_color_map) - self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) - - @property - def hash_seed(self) -> tuple: - return ( - self.__class__.__name__, - self.svg_default, - self.path_string_config, - self.string, - self.base_color, - self.alignment, - self.tex_environment, - self.isolate, - self.tex_to_color_map, - self.use_plain_tex - ) - def get_file_path(self, use_plain_file: bool = False) -> str: if use_plain_file: content = self.plain_string else: content = self.labelled_string + return self.get_file_path_by_content(content) - full_tex = self.get_tex_file_body(content) - with display_during_execution(f"Writing \"{self.string}\""): - file_path = self.tex_to_svg_file_path(full_tex) - return file_path - - def get_tex_file_body(self, content: str) -> str: - if self.tex_environment: - content = "\n".join([ - f"\\begin{{{self.tex_environment}}}", - content, - f"\\end{{{self.tex_environment}}}" - ]) - if self.alignment: - content = "\n".join([self.alignment, content]) - - tex_config = get_tex_config() - return tex_config["tex_body"].replace( - tex_config["text_to_replace"], - content - ) - - @staticmethod - def tex_to_svg_file_path(tex_file_content: str) -> str: - return tex_to_svg_file(tex_file_content) + @abstractmethod + def get_file_path_by_content(self, content: str) -> str: + return "" def generate_mobject(self) -> None: super().generate_mobject() @@ -125,13 +78,9 @@ class MTex(_TexSVG): for glyph in self.submobjects ] - if any([ - self.use_plain_tex, - self.color_cmd_repl_items, - self.base_color in (BLACK, WHITE) - ]): + if self.use_plain_file or self.has_predefined_colors: file_path = self.get_file_path(use_plain_file=True) - glyphs = _TexSVG(file_path).submobjects + glyphs = _StringSVG(file_path).submobjects for glyph, plain_glyph in zip(self.submobjects, glyphs): glyph.set_fill(plain_glyph.get_fill_color()) else: @@ -142,21 +91,19 @@ class MTex(_TexSVG): submob_labels, glyphs_lists = self.group_neighbours( glyph_labels, glyphs ) - submobjects = [ - VGroup(*glyph_list) - for glyph_list in glyphs_lists - ] - submob_tex_strings = self.get_submob_tex_strings(submob_labels) - for submob, label, submob_tex in zip( - submobjects, submob_labels, submob_tex_strings + submob_strings = self.get_submob_strings(submob_labels) + submobjects = [] + for glyph_list, label, submob_string in zip( + glyphs_lists, submob_labels, submob_strings ): + submob = VGroup(*glyph_list) submob.label = label - submob.tex_string = submob_tex - # Support `get_tex()` method here. - submob.get_tex = MethodType(lambda inst: inst.tex_string, submob) + submob.string = submob_string + submob.get_string = MethodType(lambda inst: inst.string, submob) + submobjects.append(submob) self.set_submobjects(submobjects) - ## Static methods + # Toolkits @staticmethod def color_to_label(color: ManimColor) -> int: @@ -167,23 +114,14 @@ class MTex(_TexSVG): return -1 return rgb - @staticmethod - def get_color_command(label: int) -> str: - if label == -1: - label = 16777215 # white - rg, b = divmod(label, 256) - r, g = divmod(rg, 256) - return "".join([ - "\\color[RGB]", - "{", - ",".join(map(str, (r, g, b))), - "}" - ]) - @staticmethod def get_neighbouring_pairs(iterable: Iterable) -> list: return list(adjacent_pairs(iterable))[:-1] + @staticmethod + def span_contains(span_0: Span, span_1: Span) -> bool: + return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] + @staticmethod def group_neighbours( labels: Iterable[object], @@ -211,79 +149,412 @@ class MTex(_TexSVG): @staticmethod def find_region_index(val: int, seq: list[int]) -> int: - # Returns an integer in `range(len(seq) + 1)` satisfying - # `seq[result - 1] <= val < seq[result]` - if not seq: - return 0 - if val >= seq[-1]: - return len(seq) - result = 0 - while val >= seq[result]: - result += 1 + # Returns an integer in `range(-1, len(seq))` satisfying + # `seq[result] <= val < seq[result + 1]`. + # `seq` should be sorted in ascending order. + if not seq or val < seq[0]: + return -1 + result = len(seq) - 1 + while val < seq[result]: + result -= 1 return result @staticmethod - def lstrip(index: int, skipped_spans: list[tuple[int, int]]) -> int: - index_seq = list(it.chain(*skipped_spans)) - region_index = MTex.find_region_index(index, index_seq) - if region_index % 2 == 1: + def replace_str_by_spans( + substr: str, span_repl_dict: dict[Span, str] + ) -> str: + if not span_repl_dict: + return substr + + spans = sorted(span_repl_dict.keys()) + if not all( + span_0[1] <= span_1[0] + for span_0, span_1 in LabelledString.get_neighbouring_pairs(spans) + ): + raise ValueError("Overlapping replacement") + + span_ends, span_begins = zip(*spans) + pieces = [ + substr[slice(*span)] + for span in zip( + (0, *span_begins), + (*span_ends, len(substr)) + ) + ] + repl_strs = [*[span_repl_dict[span] for span in spans], ""] + return "".join(it.chain(*zip(pieces, repl_strs))) + + @staticmethod + def get_span_replacement_dict( + inserted_string_pairs: list[tuple[Span, tuple[str, str]]], + other_repl_items: list[tuple[Span, str]] + ) -> dict[Span, str]: + if not inserted_string_pairs: + return other_repl_items.copy() + + indices, _, _, inserted_strings = zip(*sorted([ + ( + span[flag], + -flag, + -span[1 - flag], + str_pair[flag] + ) + for span, str_pair in inserted_string_pairs + for flag in range(2) + ])) + result = { + (index, index): "".join(inserted_strs) + for index, inserted_strs in zip(*LabelledString.group_neighbours( + indices, inserted_strings + )) + } + result.update(other_repl_items) + return result + + @property + def skipped_spans(self) -> list[Span]: + return [] + + def lstrip(self, index: int) -> int: + index_seq = list(it.chain(*self.skipped_spans)) + region_index = self.find_region_index(index, index_seq) + if region_index % 2 == 0: + return index_seq[region_index + 1] + return index + + def rstrip(self, index: int) -> int: + index_seq = list(it.chain(*self.skipped_spans)) + region_index = self.find_region_index(index - 1, index_seq) + if region_index % 2 == 0: return index_seq[region_index] return index - @staticmethod - def rstrip(index: int, skipped_spans: list[tuple[int, int]]) -> int: - index_seq = list(it.chain(*skipped_spans)) - region_index = MTex.find_region_index(index - 1, index_seq) - if region_index % 2 == 1: - return index_seq[region_index - 1] - return index - - @staticmethod - def strip( - tex_span: tuple[int, int], skipped_spans: list[tuple[int, int]] - ) -> tuple[int, int] | None: + def strip(self, span: Span) -> Span | None: result = ( - MTex.lstrip(tex_span[0], skipped_spans), - MTex.rstrip(tex_span[1], skipped_spans) + self.lstrip(span[0]), + self.rstrip(span[1]) ) if result[0] >= result[1]: return None return result @staticmethod - def lslide(index: int, slid_spans: list[tuple[int, int]]) -> int: - slide_dict = dict(slid_spans) + def lslide(index: int, slid_spans: list[Span]) -> int: + slide_dict = dict(sorted(slid_spans)) while index in slide_dict.keys(): index = slide_dict[index] return index @staticmethod - def rslide(index: int, slid_spans: list[tuple[int, int]]) -> int: - slide_dict = dict([ + def rslide(index: int, slid_spans: list[Span]) -> int: + slide_dict = dict(sorted([ slide_span[::-1] for slide_span in slid_spans - ]) + ], reverse=True)) while index in slide_dict.keys(): index = slide_dict[index] return index @staticmethod - def slide( - tex_span: tuple[int, int], slid_spans: list[tuple[int, int]] - ) -> tuple[int, int] | None: + def slide(span: Span, slid_spans: list[Span]) -> Span | None: result = ( - MTex.lslide(tex_span[0], slid_spans), - MTex.rslide(tex_span[1], slid_spans) + LabelledString.lslide(span[0], slid_spans), + LabelledString.rslide(span[1], slid_spans) ) if result[0] >= result[1]: return None return result - ## Parser + # Parser @property - def full_span(self) -> tuple[int, int]: + def full_span(self) -> Span: return (0, len(self.string)) + def get_substrs_to_isolate(self, substrs: list[str]) -> list[str]: + result = list(filter( + lambda s: s in self.string, + remove_list_redundancies(substrs) + )) + if "" in result: + result.remove("") + return result + + @property + def label_span_list(self) -> list[Span]: + return [] + + @property + def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]: + return [] + + @property + def command_repl_items(self) -> list[tuple[Span, str]]: + return [] + + @abstractmethod + def has_predefined_colors(self) -> bool: + return False + + @property + def plain_string(self) -> str: + return self.string + + @property + def labelled_string(self) -> str: + return self.replace_str_by_spans( + self.string, self.get_span_replacement_dict( + self.inserted_string_pairs, + self.command_repl_items + ) + ) + + @property + def ignored_indices_for_submob_strings(self) -> list[int]: + return [] + + def handle_submob_string(self, substr: str, string_span: Span) -> str: + return substr + + def get_submob_strings(self, submob_labels: list[int]) -> list[str]: + ordered_spans = [ + self.label_span_list[label] if label != -1 else self.full_span + for label in submob_labels + ] + ordered_containing_labels = [ + self.containing_labels_dict[span] + for span in ordered_spans + ] + ordered_span_begins, ordered_span_ends = zip(*ordered_spans) + string_span_begins = [ + prev_end if prev_label in containing_labels else curr_begin + for prev_end, prev_label, containing_labels, curr_begin in zip( + ordered_span_ends[:-1], submob_labels[:-1], + ordered_containing_labels[1:], ordered_span_begins[1:] + ) + ] + string_span_ends = [ + next_begin if next_label in containing_labels else curr_end + for next_begin, next_label, containing_labels, curr_end in zip( + ordered_span_begins[1:], submob_labels[1:], + ordered_containing_labels[:-1], ordered_span_ends[:-1] + ) + ] + string_spans = list(zip( + (ordered_span_begins[0], *string_span_begins), + (*string_span_ends, ordered_span_ends[-1]) + )) + + command_spans = [span for span, _ in self.command_repl_items] + slid_spans = list(it.chain( + self.skipped_spans, + command_spans, + [ + (index, index + 1) + for index in self.ignored_indices_for_submob_strings + ] + )) + result = [] + for string_span in string_spans: + string_span = self.slide(string_span, slid_spans) + if string_span is None: + result.append("") + continue + + span_repl_dict = { + tuple([index - string_span[0] for index in cmd_span]): "" + for cmd_span in command_spans + if self.span_contains(string_span, cmd_span) + } + substr = self.string[slice(*string_span)] + substr = self.replace_str_by_spans(substr, span_repl_dict) + substr = self.handle_submob_string(substr, string_span) + result.append(substr) + return result + + # Selector + + @property + def containing_labels_dict(self) -> dict[Span, list[int]]: + label_span_list = self.label_span_list + result = { + span: [] + for span in label_span_list + } + for span_0 in label_span_list: + for span_index, span_1 in enumerate(label_span_list): + if self.span_contains(span_0, span_1): + result[span_0].append(span_index) + elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: + string_0, string_1 = [ + self.string[slice(*span)] + for span in [span_0, span_1] + ] + raise ValueError( + "Partially overlapping substrings detected: " + f"'{string_0}' and '{string_1}'" + ) + result[self.full_span] = list(range(-1, len(label_span_list))) + return result + + def find_span_components_of_custom_span( + self, custom_span: Span + ) -> list[Span] | None: + span_choices = sorted(filter( + lambda span: self.span_contains(custom_span, span), + self.label_span_list + )) + # Choose spans that reach the farthest. + span_choices_dict = dict(span_choices) + + result = [] + span_begin, span_end = custom_span + span_begin = self.rstrip(span_begin) + span_end = self.rstrip(span_end) + while span_begin != span_end: + span_begin = self.lstrip(span_begin) + if span_begin not in span_choices_dict.keys(): + return None + next_begin = span_choices_dict[span_begin] + result.append((span_begin, next_begin)) + span_begin = next_begin + return result + + def get_part_by_custom_span(self, custom_span: Span) -> VGroup: + spans = self.find_span_components_of_custom_span(custom_span) + if spans is None: + substr = self.string[slice(*custom_span)] + raise ValueError(f"Failed to match mobjects from \"{substr}\"") + + labels = set(it.chain(*[ + self.containing_labels_dict[span] + for span in spans + ])) + return VGroup(*filter( + lambda submob: submob.label in labels, + self.submobjects + )) + + def get_parts_by_string(self, substr: str) -> VGroup: + return VGroup(*[ + self.get_part_by_custom_span(match_obj.span()) + for match_obj in re.finditer(re.escape(substr), self.string) + ]) + + def get_part_by_string(self, substr: str, index: int = 0) -> VMobject: + all_parts = self.get_parts_by_string(substr) + return all_parts[index] + + def set_color_by_string(self, substr: str, color: ManimColor): + self.get_parts_by_string(substr).set_color(color) + return self + + def set_color_by_string_to_color_map( + self, string_to_color_map: dict[str, ManimColor] + ): + for substr, color in string_to_color_map.items(): + self.set_color_by_string(substr, color) + return self + + def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: + indices = [ + index for index, submob in enumerate(self.submobjects) + if submob in part + ] + if not indices: + raise ValueError("Failed to find part") + return indices + + def indices_of_part_by_string( + self, substr: str, index: int = 0 + ) -> list[int]: + part = self.get_part_by_string(substr, index=index) + return self.indices_of_part(part) + + @property + def specified_substrings(self) -> list[str]: + return [] + + def get_specified_substrings(self) -> list[str]: + return self.specified_substrings + + @property + def isolated_substrings(self) -> list[str]: + return remove_list_redundancies([ + self.string[slice(*span)] + for span in self.label_span_list + ]) + + def get_isolated_substrings(self) -> list[str]: + return self.isolated_substrings + + def get_string(self) -> str: + return self.string + + +class MTex(LabelledString): + CONFIG = { + "font_size": 48, + "alignment": "\\centering", + "tex_environment": "align*", + "isolate": [], + "tex_to_color_map": {}, + "use_plain_file": False, + } + + def __init__(self, tex_string: str, **kwargs): + tex_string = tex_string.strip() + # Prevent from passing an empty string. + if not tex_string: + tex_string = "\\quad" + self.tex_string = tex_string + super().__init__(tex_string, **kwargs) + + self.set_color_by_tex_to_color_map(self.tex_to_color_map) + self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) + + @property + def hash_seed(self) -> tuple: + return ( + self.__class__.__name__, + self.svg_default, + self.path_string_config, + self.tex_string, + self.base_color, + self.alignment, + self.tex_environment, + self.isolate, + self.tex_to_color_map, + self.use_plain_file + ) + + def get_file_path_by_content(self, content: str) -> str: + full_tex = self.get_tex_file_body(content) + with display_during_execution(f"Writing \"{self.string}\""): + file_path = self.tex_to_svg_file_path(full_tex) + return file_path + + def get_tex_file_body(self, content: str) -> str: + if self.tex_environment: + content = "\n".join([ + f"\\begin{{{self.tex_environment}}}", + content, + f"\\end{{{self.tex_environment}}}" + ]) + if self.alignment: + content = "\n".join([self.alignment, content]) + + tex_config = get_tex_config() + return tex_config["tex_body"].replace( + tex_config["text_to_replace"], + content + ) + + @staticmethod + def tex_to_svg_file_path(tex_file_content: str) -> str: + return tex_to_svg_file(tex_file_content) + + # Parser + @property def backslash_indices(self) -> list[int]: # Newlines (`\\`) don't count. @@ -293,9 +564,7 @@ class MTex(_TexSVG): if len(match_obj.group()) % 2 == 1 ] - def get_left_and_right_brace_indices( - self - ) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]: + def get_brace_indices_lists(self) -> tuple[list[Span], list[Span]]: string = self.string indices = list(filter( lambda index: index - 1 not in self.backslash_indices, @@ -322,15 +591,15 @@ class MTex(_TexSVG): return left_brace_indices, right_brace_indices @property - def left_brace_indices(self) -> list[tuple[int, int]]: - return self.get_left_and_right_brace_indices()[0] + def left_brace_indices(self) -> list[Span]: + return self.get_brace_indices_lists()[0] @property - def right_brace_indices(self) -> list[tuple[int, int]]: - return self.get_left_and_right_brace_indices()[1] + def right_brace_indices(self) -> list[Span]: + return self.get_brace_indices_lists()[1] @property - def skipped_spans(self) -> list[tuple[int, int]]: + def skipped_spans(self) -> list[Span]: return [ match_obj.span() for match_obj in re.finditer(r"\s*([_^])\s*|(\s+)", self.string) @@ -338,24 +607,15 @@ class MTex(_TexSVG): or match_obj.start(1) - 1 not in self.backslash_indices ] - def lstrip_span(self, index: int) -> int: - return self.lstrip(index, self.skipped_spans) - - def rstrip_span(self, index: int) -> int: - return self.rstrip(index, self.skipped_spans) - - def strip_span(self, index: int) -> int: - return self.strip(index, self.skipped_spans) - @property - def script_char_spans(self) -> list[tuple[int, int]]: + def script_char_spans(self) -> list[Span]: return list(filter( - lambda tex_span: self.string[slice(*tex_span)].strip(), + lambda span: self.string[slice(*span)].strip(), self.skipped_spans )) @property - def script_content_spans(self) -> list[tuple[int, int]]: + def script_content_spans(self) -> list[Span]: result = [] brace_indices_dict = dict(zip( self.left_brace_indices, self.right_brace_indices @@ -380,7 +640,7 @@ class MTex(_TexSVG): return result @property - def double_braces_spans(self) -> list[tuple[int, int]]: + def double_braces_spans(self) -> list[Span]: # Match paired double braces (`{{...}}`). result = [] reversed_brace_indices_dict = dict(zip( @@ -405,19 +665,12 @@ class MTex(_TexSVG): @property def additional_substrings(self) -> list[str]: - result = remove_list_redundancies(list(it.chain( + return self.get_substrs_to_isolate(list(it.chain( self.tex_to_color_map.keys(), self.isolate ))) - if "" in result: - result.remove("") - return result - def get_tex_span_lists( - self - ) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]: - result = [] - extended_result = [] + def get_label_span_list(self, extended: bool) -> list[Span]: script_content_spans = self.script_content_spans script_spans = [ (script_char_span[0], script_content_span[1]) @@ -425,65 +678,54 @@ class MTex(_TexSVG): self.script_char_spans, script_content_spans ) ] - tex_spans = remove_list_redundancies([ + spans = remove_list_redundancies([ self.full_span, *self.double_braces_spans, *filter(lambda stripped_span: stripped_span is not None, [ - self.strip_span(match_obj.span()) + self.strip(match_obj.span()) for substr in self.additional_substrings for match_obj in re.finditer(re.escape(substr), self.string) ]), *script_content_spans ]) - for tex_span in tex_spans: - if tex_span in script_content_spans: - result.append(tex_span) - extended_result.append(tex_span) + result = [] + for span in spans: + if span in script_content_spans: continue - span_begin, span_end = tex_span - shrinked_span = (span_begin, self.rslide(span_end, script_spans)) - extended_span = (span_begin, self.lslide(span_end, script_spans)) - if shrinked_span[0] >= shrinked_span[1]: + span_begin, span_end = span + shrinked_end = self.rslide(span_end, script_spans) + if span_begin >= shrinked_end: continue + shrinked_span = (span_begin, shrinked_end) if shrinked_span in result: continue result.append(shrinked_span) - extended_result.append(extended_span) - return result, extended_result + + if extended: + result = [ + (span_begin, self.lslide(span_end, script_spans)) + for span_begin, span_end in result + ] + return script_content_spans + result @property - def tex_span_list(self) -> list[tuple[int, int]]: - return self.get_tex_span_lists()[0] + def label_span_list(self) -> list[Span]: + return self.get_label_span_list(extended=False) @property - def extended_tex_span_list(self) -> list[tuple[int, int]]: - return self.get_tex_span_lists()[1] + def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]: + return [ + (span, ( + "{{" + self.get_color_command_by_label(label), + "}}" + )) + for label, span in enumerate( + self.get_label_span_list(extended=True) + ) + ] @property - def containing_labels_dict(self) -> dict[tuple[int, int], list[int]]: - tex_span_list = self.tex_span_list - result = { - tex_span: [] - for tex_span in tex_span_list - } - for span_0 in tex_span_list: - for span_index, span_1 in enumerate(tex_span_list): - if span_0[0] <= span_1[0] and span_1[1] <= span_0[1]: - result[span_0].append(span_index) - elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: - string_0, string_1 = [ - self.string[slice(*tex_span)] - for tex_span in [span_0, span_1] - ] - raise ValueError( - "Partially overlapping substrings detected: " - f"'{string_0}' and '{string_1}'" - ) - result[self.full_span] = list(range(-1, len(tex_span_list))) - return result - - @property - def color_cmd_repl_items(self) -> list[tuple[tuple[int, int], str]]: + def command_repl_items(self) -> list[tuple[Span, str]]: color_related_command_dict = { "color": (1, False), "textcolor": (1, False), @@ -509,7 +751,7 @@ class MTex(_TexSVG): n_braces, substitute_cmd = color_related_command_dict[cmd_name] span_end = right_brace_indices[self.find_region_index( cmd_end, right_brace_indices - ) + n_braces - 1] + 1 + ) + n_braces] + 1 if substitute_cmd: repl_str = "\\" + cmd_name + n_braces * "{white}" else: @@ -518,237 +760,84 @@ class MTex(_TexSVG): return result @property - def span_repl_dict(self) -> dict[tuple[int, int], str]: - indices, _, _, cmd_strings = zip(*sorted([ - ( - tex_span[flag], - -flag, - -tex_span[1 - flag], - ("{{" + self.get_color_command(label), "}}")[flag] - ) - for label, tex_span in enumerate(self.extended_tex_span_list) - for flag in range(2) - ])) - result = { - (index, index): "".join(cmd_strs) - for index, cmd_strs in zip(*self.group_neighbours( - indices, cmd_strings - )) - } - result.update(self.color_cmd_repl_items) - return result + def has_predefined_colors(self) -> bool: + return bool(self.command_repl_items) + + @staticmethod + def get_color_command_by_label(label: int) -> str: + if label == -1: + label = 16777215 # white + rg, b = divmod(label, 256) + r, g = divmod(rg, 256) + return "".join([ + "\\color[RGB]", + "{", + ",".join(map(str, (r, g, b))), + "}" + ]) @property def plain_string(self) -> str: return "".join([ "{{", - self.get_color_command(self.color_to_label(self.base_color)), + self.get_color_command_by_label( + self.color_to_label(self.base_color) + ), self.string, "}}" ]) @property - def labelled_string(self) -> str: - if not self.span_repl_dict: - return self.string + def ignored_indices_for_submob_strings(self) -> list[int]: + return self.left_brace_indices + self.right_brace_indices - spans = sorted(self.span_repl_dict.keys()) - if not all( - span_0[1] <= span_1[0] - for span_0, span_1 in self.get_neighbouring_pairs(spans) - ): - raise ValueError("Failed to generate the labelled string") - - span_ends, span_begins = zip(*spans) - string_pieces = [ - self.string[slice(*span)] - for span in zip( - (0, *span_begins), - (*span_ends, len(self.string)) - ) - ] - repl_strs = [ - self.span_repl_dict[span] - for span in spans - ] - repl_strs.append("") - return "".join(it.chain(*zip(string_pieces, repl_strs))) - - def get_submob_tex_strings(self, submob_labels: list[int]) -> list[str]: - ordered_tex_spans = [ - self.tex_span_list[label] if label != -1 else self.full_span - for label in submob_labels - ] - ordered_containing_labels = [ - self.containing_labels_dict[tex_span] - for tex_span in ordered_tex_spans - ] - ordered_span_begins, ordered_span_ends = zip(*ordered_tex_spans) - string_span_begins = [ - prev_end if prev_label in containing_labels else curr_begin - for prev_end, prev_label, containing_labels, curr_begin in zip( - ordered_span_ends[:-1], submob_labels[:-1], - ordered_containing_labels[1:], ordered_span_begins[1:] - ) - ] - string_span_ends = [ - next_begin if next_label in containing_labels else curr_end - for next_begin, next_label, containing_labels, curr_end in zip( - ordered_span_begins[1:], submob_labels[1:], - ordered_containing_labels[:-1], ordered_span_ends[:-1] - ) - ] - string_spans = list(zip( - (ordered_span_begins[0], *string_span_begins), - (*string_span_ends, ordered_span_ends[-1]) - )) - - string = self.string - left_brace_indices = self.left_brace_indices - right_brace_indices = self.right_brace_indices - slid_spans = self.skipped_spans + [ - (index, index + 1) - for index in left_brace_indices + right_brace_indices - ] - result = [] - for str_span in string_spans: - str_span = self.strip_span(str_span) - if str_span is None: - continue - str_span = self.slide(str_span, slid_spans) - if str_span is None: - continue - unclosed_left_braces = 0 - unclosed_right_braces = 0 - for index in range(*str_span): - if index in left_brace_indices: - unclosed_left_braces += 1 - elif index in right_brace_indices: - if unclosed_left_braces == 0: - unclosed_right_braces += 1 - else: - unclosed_left_braces -= 1 - result.append("".join([ - unclosed_right_braces * "{", - string[slice(*str_span)], - unclosed_left_braces * "}" - ])) - return result + def handle_submob_string(self, substr: str, string_span: Span) -> str: + unclosed_left_braces = 0 + unclosed_right_braces = 0 + for index in range(*string_span): + if index in self.left_brace_indices: + unclosed_left_braces += 1 + elif index in self.right_brace_indices: + if unclosed_left_braces == 0: + unclosed_right_braces += 1 + else: + unclosed_left_braces -= 1 + return "".join([ + unclosed_right_braces * "{", + substr, + unclosed_left_braces * "}" + ]) @property def specified_substrings(self) -> list[str]: return remove_list_redundancies([ self.string[slice(*double_braces_span)] for double_braces_span in self.double_braces_spans - ] + list(filter( - lambda s: s in self.string, - self.additional_substrings - ))) + ] + self.additional_substrings) - def get_specified_substrings(self) -> list[str]: - return self.specified_substrings + # Method alias - @property - def isolated_substrings(self) -> list[str]: - return remove_list_redundancies([ - self.string[slice(*tex_span)] - for tex_span in self.tex_span_list - ]) + def get_parts_by_tex(self, substr: str) -> VGroup: + return self.get_parts_by_string(substr) - def get_isolated_substrings(self) -> list[str]: - return self.isolated_substrings + def get_part_by_tex(self, substr: str, index: int = 0) -> VMobject: + return self.get_part_by_string(substr, index) - ## Selector - - def find_span_components_of_custom_span( - self, - custom_span: tuple[int, int] - ) -> list[tuple[int, int]] | None: - tex_span_choices = sorted(filter( - lambda tex_span: all([ - tex_span[0] >= custom_span[0], - tex_span[1] <= custom_span[1] - ]), - self.tex_span_list - )) - # Choose spans that reach the farthest. - tex_span_choices_dict = dict(tex_span_choices) - - result = [] - span_begin, span_end = custom_span - span_begin = self.rstrip_span(span_begin) - span_end = self.rstrip_span(span_end) - while span_begin != span_end: - span_begin = self.lstrip_span(span_begin) - if span_begin not in tex_span_choices_dict.keys(): - return None - next_begin = tex_span_choices_dict[span_begin] - result.append((span_begin, next_begin)) - span_begin = next_begin - return result - - def get_part_by_custom_span(self, custom_span: tuple[int, int]) -> VGroup: - tex_spans = self.find_span_components_of_custom_span( - custom_span - ) - if tex_spans is None: - tex = self.string[slice(*custom_span)] - raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"") - - labels = set(it.chain(*[ - self.containing_labels_dict[tex_span] - for tex_span in tex_spans - ])) - return VGroup(*filter( - lambda submob: submob.label in labels, - self.submobjects - )) - - def get_parts_by_tex(self, tex: str) -> VGroup: - return VGroup(*[ - self.get_part_by_custom_span(match_obj.span()) - for match_obj in re.finditer( - re.escape(tex), self.string - ) - ]) - - def get_part_by_tex(self, tex: str, index: int = 0) -> VMobject: - all_parts = self.get_parts_by_tex(tex) - return all_parts[index] - - def set_color_by_tex(self, tex: str, color: ManimColor): - self.get_parts_by_tex(tex).set_color(color) - return self + def set_color_by_tex(self, substr: str, color: ManimColor): + return self.set_color_by_string(substr, color) def set_color_by_tex_to_color_map( - self, - tex_to_color_map: dict[str, ManimColor] + self, tex_to_color_map: dict[str, ManimColor] ): - for tex, color in tex_to_color_map.items(): - self.set_color_by_tex(tex, color) - return self + return self.set_color_by_string_to_color_map(tex_to_color_map) - def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: - indices = [ - index for index, submob in enumerate(self.submobjects) - if submob in part - ] - if not indices: - raise ValueError("Failed to find part in tex") - return indices - - def indices_of_part_by_tex(self, tex: str, index: int = 0) -> list[int]: - part = self.get_part_by_tex(tex, index=index) - return self.indices_of_part(part) + def indices_of_part_by_tex( + self, substr: str, index: int = 0 + ) -> list[int]: + return self.indices_of_part_by_string(substr, index) def get_tex(self) -> str: - return self.string - - def get_submob_tex(self) -> list[str]: - return [ - submob.get_tex() - for submob in self.submobjects - ] + return self.get_string() class MTexText(MTex): diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index a13d1d80..24a5b111 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -2,11 +2,11 @@ from __future__ import annotations import os import re -import typing -from pathlib import Path - +import itertools as it import xml.sax.saxutils as saxutils +from pathlib import Path from contextlib import contextmanager +import typing from typing import Iterable, Sequence, Union import pygments @@ -17,198 +17,87 @@ from manimpango import MarkupUtils from manimlib.logger import log from manimlib.constants import * -from manimlib.mobject.geometry import Dot -from manimlib.mobject.svg.svg_mobject import SVGMobject +from manimlib.mobject.svg.mtex_mobject import LabelledString from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.customization import get_customization from manimlib.utils.tex_file_writing import tex_hash from manimlib.utils.config_ops import digest_config from manimlib.utils.directories import get_downloads_dir from manimlib.utils.directories import get_text_dir +from manimlib.utils.iterables import remove_list_redundancies from typing import TYPE_CHECKING if TYPE_CHECKING: from manimlib.mobject.types.vectorized_mobject import VMobject + ManimColor = Union[str, colour.Color, Sequence[float]] + Span = tuple[int, int] TEXT_MOB_SCALE_FACTOR = 0.0076 DEFAULT_LINE_SPACING_SCALE = 0.6 -class _TextParser(object): - # See https://docs.gtk.org/Pango/pango_markup.html - # A tag containing two aliases will cause warning, - # so only use the first key of each group of aliases. - SPAN_ATTR_KEY_ALIAS_LIST = ( - ("font", "font_desc"), - ("font_family", "face"), - ("font_size", "size"), - ("font_style", "style"), - ("font_weight", "weight"), - ("font_variant", "variant"), - ("font_stretch", "stretch"), - ("font_features",), - ("foreground", "fgcolor", "color"), - ("background", "bgcolor"), - ("alpha", "fgalpha"), - ("background_alpha", "bgalpha"), - ("underline",), - ("underline_color",), - ("overline",), - ("overline_color",), - ("rise",), - ("baseline_shift",), - ("font_scale",), - ("strikethrough",), - ("strikethrough_color",), - ("fallback",), - ("lang",), - ("letter_spacing",), - ("gravity",), - ("gravity_hint",), - ("show",), - ("insert_hyphens",), - ("allow_breaks",), - ("line_height",), - ("text_transform",), - ("segment",), - ) - SPAN_ATTR_KEY_CONVERSION = { - key: key_alias_list[0] - for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST - for key in key_alias_list - } - - TAG_TO_ATTR_DICT = { - "b": {"font_weight": "bold"}, - "big": {"font_size": "larger"}, - "i": {"font_style": "italic"}, - "s": {"strikethrough": "true"}, - "sub": {"baseline_shift": "subscript", "font_scale": "subscript"}, - "sup": {"baseline_shift": "superscript", "font_scale": "superscript"}, - "small": {"font_size": "smaller"}, - "tt": {"font_family": "monospace"}, - "u": {"underline": "single"}, - } - - def __init__(self, text: str = "", is_markup: bool = True): - self.text = text - self.is_markup = is_markup - self.global_attrs = {} - self.local_attrs = {(0, len(self.text)): {}} - self.tag_strings = set() - if is_markup: - self.parse_markup() - - def parse_markup(self) -> None: - tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*('[^']*'|"[^"]*")\s*)*)>""" - attr_pattern = r"""(\w+)\s*\=\s*(?:(?:'([^']*)')|(?:"([^"]*)"))""" - start_match_obj_stack = [] - match_obj_pairs = [] - for match_obj in re.finditer(tag_pattern, self.text): - if not match_obj.group(1): - start_match_obj_stack.append(match_obj) - else: - match_obj_pairs.append((start_match_obj_stack.pop(), match_obj)) - self.tag_strings.add(match_obj.group()) - assert not start_match_obj_stack, "Unclosed tag(s) detected" - - for start_match_obj, end_match_obj in match_obj_pairs: - tag_name = start_match_obj.group(2) - assert tag_name == end_match_obj.group(2), "Unmatched tag names" - assert not end_match_obj.group(3), "Attributes shan't exist in ending tags" - if tag_name == "span": - attr_dict = { - match.group(1): match.group(2) or match.group(3) - for match in re.finditer(attr_pattern, start_match_obj.group(3)) - } - elif tag_name in _TextParser.TAG_TO_ATTR_DICT.keys(): - assert not start_match_obj.group(3), f"Attributes shan't exist in tag '{tag_name}'" - attr_dict = _TextParser.TAG_TO_ATTR_DICT[tag_name] - else: - raise AssertionError(f"Unknown tag: '{tag_name}'") - - text_span = (start_match_obj.end(), end_match_obj.start()) - self.update_local_attrs(text_span, attr_dict) - - @staticmethod - def convert_key_alias(key: str) -> str: - return _TextParser.SPAN_ATTR_KEY_CONVERSION[key] - - @staticmethod - def update_attr_dict(attr_dict: dict[str, str], key: str, value: typing.Any) -> None: - converted_key = _TextParser.convert_key_alias(key) - attr_dict[converted_key] = str(value) - - def update_global_attr(self, key: str, value: typing.Any) -> None: - _TextParser.update_attr_dict(self.global_attrs, key, value) - - def update_global_attrs(self, attr_dict: dict[str, typing.Any]) -> None: - for key, value in attr_dict.items(): - self.update_global_attr(key, value) - - def update_local_attr(self, span: tuple[int, int], key: str, value: typing.Any) -> None: - if span[0] >= span[1]: - log.warning(f"Span {span} doesn't match any part of the string") - return - - if span in self.local_attrs.keys(): - _TextParser.update_attr_dict(self.local_attrs[span], key, value) - return - - span_triplets = [] - for sp, attr_dict in self.local_attrs.items(): - if sp[1] <= span[0] or span[1] <= sp[0]: - continue - span_to_become = (max(sp[0], span[0]), min(sp[1], span[1])) - spans_to_add = [] - if sp[0] < span[0]: - spans_to_add.append((sp[0], span[0])) - if span[1] < sp[1]: - spans_to_add.append((span[1], sp[1])) - span_triplets.append((sp, span_to_become, spans_to_add)) - for span_to_remove, span_to_become, spans_to_add in span_triplets: - attr_dict = self.local_attrs.pop(span_to_remove) - for span_to_add in spans_to_add: - self.local_attrs[span_to_add] = attr_dict.copy() - self.local_attrs[span_to_become] = attr_dict - _TextParser.update_attr_dict(self.local_attrs[span_to_become], key, value) - - def update_local_attrs(self, text_span: tuple[int, int], attr_dict: dict[str, typing.Any]) -> None: - for key, value in attr_dict.items(): - self.update_local_attr(text_span, key, value) - - def remove_tags(self, string: str) -> str: - for tag_string in self.tag_strings: - string = string.replace(tag_string, "") - return string - - def get_text_pieces(self) -> list[tuple[str, dict[str, str]]]: - result = [] - for span in sorted(self.local_attrs.keys()): - text_piece = self.remove_tags(self.text[slice(*span)]) - if not text_piece: - continue - if not self.is_markup: - text_piece = saxutils.escape(text_piece) - attr_dict = self.global_attrs.copy() - attr_dict.update(self.local_attrs[span]) - result.append((text_piece, attr_dict)) - return result - - def get_markup_str_with_attrs(self) -> str: - return "".join([ - f"{text_piece}" - for text_piece, attr_dict in self.get_text_pieces() - ]) - - @staticmethod - def get_attr_dict_str(attr_dict: dict[str, str]) -> str: - return " ".join([ - f"{key}='{value}'" - for key, value in attr_dict.items() - ]) +# See https://docs.gtk.org/Pango/pango_markup.html +# A tag containing two aliases will cause warning, +# so only use the first key of each group of aliases. +SPAN_ATTR_KEY_ALIAS_LIST = ( + ("font", "font_desc"), + ("font_family", "face"), + ("font_size", "size"), + ("font_style", "style"), + ("font_weight", "weight"), + ("font_variant", "variant"), + ("font_stretch", "stretch"), + ("font_features",), + ("foreground", "fgcolor", "color"), + ("background", "bgcolor"), + ("alpha", "fgalpha"), + ("background_alpha", "bgalpha"), + ("underline",), + ("underline_color",), + ("overline",), + ("overline_color",), + ("rise",), + ("baseline_shift",), + ("font_scale",), + ("strikethrough",), + ("strikethrough_color",), + ("fallback",), + ("lang",), + ("letter_spacing",), + ("gravity",), + ("gravity_hint",), + ("show",), + ("insert_hyphens",), + ("allow_breaks",), + ("line_height",), + ("text_transform",), + ("segment",), +) +COLOR_RELATED_KEYS = ( + "foreground", + "background", + "underline_color", + "overline_color", + "strikethrough_color" +) +SPAN_ATTR_KEY_CONVERSION = { + key: key_alias_list[0] + for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST + for key in key_alias_list +} +TAG_TO_ATTR_DICT = { + "b": {"font_weight": "bold"}, + "big": {"font_size": "larger"}, + "i": {"font_style": "italic"}, + "s": {"strikethrough": "true"}, + "sub": {"baseline_shift": "subscript", "font_scale": "subscript"}, + "sup": {"baseline_shift": "superscript", "font_scale": "superscript"}, + "small": {"font_size": "smaller"}, + "tt": {"font_family": "monospace"}, + "u": {"underline": "single"}, +} # Temporary handler @@ -223,16 +112,9 @@ class _Alignment: self.value = _Alignment.VAL_DICT[s.upper()] -class Text(SVGMobject): +class MarkupText(LabelledString): CONFIG = { - # Mobject - "stroke_width": 0, - "svg_default": { - "color": WHITE, - }, - "height": None, - # Text - "is_markup": False, + "is_markup": True, "font_size": 48, "lsh": None, "justify": False, @@ -240,8 +122,6 @@ class Text(SVGMobject): "alignment": "LEFT", "line_width_factor": None, "font": "", - "disable_ligatures": True, - "apply_space_chars": True, "slant": NORMAL, "weight": NORMAL, "gradient": None, @@ -252,6 +132,7 @@ class Text(SVGMobject): "t2w": {}, "global_config": {}, "local_configs": {}, + "isolate": [], } def __init__(self, text: str, **kwargs): @@ -260,10 +141,15 @@ class Text(SVGMobject): validate_error = MarkupUtils.validate(text) if validate_error: raise ValueError(validate_error) - self.text = text - self.parser = _TextParser(text, is_markup=self.is_markup) - super().__init__(**kwargs) + self.text = text + super().__init__(text, **kwargs) + + if self.t2g: + log.warning( + "Manim currently cannot parse gradient from svg. " + "Please set gradient via `set_color_by_gradient`.", + ) if self.gradient: self.set_color_by_gradient(*self.gradient) if self.height is None: @@ -284,8 +170,6 @@ class Text(SVGMobject): self.alignment, self.line_width_factor, self.font, - self.disable_ligatures, - self.apply_space_chars, self.slant, self.weight, self.t2c, @@ -293,71 +177,32 @@ class Text(SVGMobject): self.t2s, self.t2w, self.global_config, - self.local_configs + self.local_configs, + self.isolate ) - def get_file_path(self) -> str: - full_markup = self.get_full_markup_str() + def full2short(self, config: dict) -> None: + conversion_dict = { + "line_spacing_height": "lsh", + "text2color": "t2c", + "text2font": "t2f", + "text2gradient": "t2g", + "text2slant": "t2s", + "text2weight": "t2w" + } + for kwargs in [config, self.CONFIG]: + for long_name, short_name in conversion_dict.items(): + if long_name in kwargs: + kwargs[short_name] = kwargs.pop(long_name) + + def get_file_path_by_content(self, content: str) -> str: svg_file = os.path.join( - get_text_dir(), tex_hash(full_markup) + ".svg" + get_text_dir(), tex_hash(content) + ".svg" ) if not os.path.exists(svg_file): - self.markup_to_svg(full_markup, svg_file) + self.markup_to_svg(content, svg_file) return svg_file - def get_full_markup_str(self) -> str: - if self.t2g: - log.warning( - "Manim currently cannot parse gradient from svg. " - "Please set gradient via `set_color_by_gradient`.", - ) - - config_style_dict = self.generate_config_style_dict() - global_attr_dict = { - "line_height": ((self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1) * 0.6, - "font_family": self.font or get_customization()["style"]["font"], - "font_size": self.font_size * 1024, - "font_style": self.slant, - "font_weight": self.weight, - # TODO, it seems this doesn't work - "font_features": "liga=0,dlig=0,clig=0,hlig=0" if self.disable_ligatures else None, - "foreground": config_style_dict.get("fill", None), - "alpha": config_style_dict.get("fill-opacity", None) - } - global_attr_dict = { - k: v - for k, v in global_attr_dict.items() - if v is not None - } - global_attr_dict.update(self.global_config) - self.parser.update_global_attrs(global_attr_dict) - - local_attr_items = [ - (word_or_text_span, {key: value}) - for t2x_dict, key in ( - (self.t2c, "foreground"), - (self.t2f, "font_family"), - (self.t2s, "font_style"), - (self.t2w, "font_weight") - ) - for word_or_text_span, value in t2x_dict.items() - ] - local_attr_items.extend(self.local_configs.items()) - for word_or_text_span, local_config in local_attr_items: - for text_span in self.find_indexes(word_or_text_span): - self.parser.update_local_attrs(text_span, local_config) - - return self.parser.get_markup_str_with_attrs() - - def find_indexes(self, word_or_text_span: str | tuple[int, int]) -> list[tuple[int, int]]: - if isinstance(word_or_text_span, tuple): - return [word_or_text_span] - - return [ - match_obj.span() - for match_obj in re.finditer(re.escape(word_or_text_span), self.text) - ] - def markup_to_svg(self, markup_str: str, file_name: str) -> str: # `manimpango` is under construction, # so the following code is intended to suit its interface @@ -374,7 +219,7 @@ class Text(SVGMobject): weight="NORMAL", # Already handled size=1, # Already handled _=0, # Empty parameter - disable_liga=False, # Already handled + disable_liga=False, # Need not to handle file_name=file_name, START_X=0, START_Y=0, @@ -387,63 +232,318 @@ class Text(SVGMobject): pango_width=pango_width ) - def generate_mobject(self) -> None: - super().generate_mobject() + # Toolkits - # Remove empty paths - submobjects = list(filter(lambda submob: submob.has_points(), self)) + @staticmethod + def get_attr_dict_str(attr_dict: dict[str, str]) -> str: + return " ".join([ + f"{key}='{value}'" + for key, value in attr_dict.items() + ]) - # Apply space characters - if self.apply_space_chars: - content_str = self.parser.remove_tags(self.text) - if self.is_markup: - content_str = saxutils.unescape(content_str) - for match_obj in re.finditer(r"\s", content_str): - char_index = match_obj.start() - space = Dot(radius=0, fill_opacity=0, stroke_opacity=0) - space.move_to(submobjects[max(char_index - 1, 0)].get_center()) - submobjects.insert(char_index, space) - self.set_submobjects(submobjects) + @staticmethod + def get_begin_tag_str(attr_dict: dict[str, str]) -> str: + return f"" - def full2short(self, config: dict) -> None: - conversion_dict = { - "line_spacing_height": "lsh", - "text2color": "t2c", - "text2font": "t2f", - "text2gradient": "t2g", - "text2slant": "t2s", - "text2weight": "t2w" - } - for kwargs in [config, self.CONFIG]: - for long_name, short_name in conversion_dict.items(): - if long_name in kwargs: - kwargs[short_name] = kwargs.pop(long_name) + @staticmethod + def get_end_tag_str() -> str: + return "" - def get_parts_by_text(self, word: str) -> VGroup: - if self.is_markup: - log.warning( - "Slicing MarkupText via `get_parts_by_text`, " - "the result could be unexpected." - ) - elif not self.apply_space_chars: - log.warning( - "Slicing Text via `get_parts_by_text` without applying spaces, " - "the result could be unexpected." - ) - return VGroup(*( - self[i:j] - for i, j in self.find_indexes(word) + @staticmethod + def convert_attr_key(key: str) -> str: + return SPAN_ATTR_KEY_CONVERSION[key.lower()] + + @staticmethod + def convert_attr_val(val: typing.Any) -> str: + return str(val).lower() + + @staticmethod + def merge_attr_items( + attr_items: list[Span, str, str] + ) -> list[tuple[Span, dict[str, str]]]: + index_seq = [0] + attr_dict_list = [{}] + for span, key, value in attr_items: + if span[0] >= span[1]: + continue + region_indices = [ + MarkupText.find_region_index(index, index_seq) + for index in span + ] + for flag in (1, 0): + if index_seq[region_indices[flag]] == span[flag]: + continue + region_index = region_indices[flag] + index_seq.insert(region_index + 1, span[flag]) + attr_dict_list.insert( + region_index + 1, attr_dict_list[region_index].copy() + ) + region_indices[flag] += 1 + if flag == 0: + region_indices[1] += 1 + for attr_dict in attr_dict_list[slice(*region_indices)]: + attr_dict[key] = value + return list(zip( + MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] )) - def get_part_by_text(self, word: str) -> VMobject | None: - parts = self.get_parts_by_text(word) - return parts[0] if parts else None + # Parser + + @property + def tag_items_from_markup( + self + ) -> list[tuple[Span, Span, dict[str, str]]]: + if not self.is_markup: + return [] + + tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*('.*?'|".*?")\s*)*)>""" + attr_pattern = r"""(\w+)\s*\=\s*(?:(?:'(.*?)')|(?:"(.*?)"))""" + begin_match_obj_stack = [] + match_obj_pairs = [] + for match_obj in re.finditer(tag_pattern, self.string): + if not match_obj.group(1): + begin_match_obj_stack.append(match_obj) + else: + match_obj_pairs.append( + (begin_match_obj_stack.pop(), match_obj) + ) + if begin_match_obj_stack: + raise ValueError("Unclosed tag(s) detected") + + result = [] + for begin_match_obj, end_match_obj in match_obj_pairs: + tag_name = begin_match_obj.group(2) + if tag_name != end_match_obj.group(2): + raise ValueError("Unmatched tag names") + if end_match_obj.group(3): + raise ValueError("Attributes shan't exist in ending tags") + if tag_name == "span": + attr_dict = dict([ + ( + MarkupText.convert_attr_key(match.group(1)), + MarkupText.convert_attr_val( + match.group(2) or match.group(3) + ) + ) + for match in re.finditer( + attr_pattern, begin_match_obj.group(3) + ) + ]) + elif tag_name in TAG_TO_ATTR_DICT.keys(): + if begin_match_obj.group(3): + raise ValueError( + f"Attributes shan't exist in tag '{tag_name}'" + ) + attr_dict = TAG_TO_ATTR_DICT[tag_name].copy() + else: + raise ValueError(f"Unknown tag: '{tag_name}'") + + result.append( + (begin_match_obj.span(), end_match_obj.span(), attr_dict) + ) + return result + + @property + def global_attr_items_from_config(self) -> list[str, str]: + global_attr_dict = { + "line_height": ( + (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 + ) * 0.6, + "font_family": self.font or get_customization()["style"]["font"], + "font_size": self.font_size * 1024, + "font_style": self.slant, + "font_weight": self.weight + } + global_attr_dict = { + k: v + for k, v in global_attr_dict.items() + if v is not None + } + result = list(it.chain( + global_attr_dict.items(), + self.global_config.items() + )) + return [ + ( + self.convert_attr_key(key), + self.convert_attr_val(val) + ) + for key, val in result + ] + + @property + def local_attr_items_from_config(self) -> list[tuple[Span, str, str]]: + result = [ + (text_span, key, val) + for t2x_dict, key in ( + (self.t2c, "foreground"), + (self.t2f, "font_family"), + (self.t2s, "font_style"), + (self.t2w, "font_weight") + ) + for word_or_span, val in t2x_dict.items() + for text_span in self.find_spans(word_or_span) + ] + [ + (text_span, key, val) + for word_or_span, local_config in self.local_configs.items() + for text_span in self.find_spans(word_or_span) + for key, val in local_config.items() + ] + return [ + ( + text_span, + self.convert_attr_key(key), + self.convert_attr_val(val) + ) + for text_span, key, val in result + ] + + def find_spans(self, word_or_span: str | Span) -> list[Span]: + if isinstance(word_or_span, tuple): + return [word_or_span] + + return [ + match_obj.span() + for match_obj in re.finditer(re.escape(word_or_span), self.string) + ] + + @property + def skipped_spans(self) -> list[Span]: + return [ + match_obj.span() + for match_obj in re.finditer(r"\s+", self.string) + ] + + @property + def label_span_list(self) -> list[Span]: + breakup_indices = [ + index + for pattern in [ + r"\s+", + r"\b", + *[ + re.escape(substr) + for substr in self.get_substrs_to_isolate(self.isolate) + ] + ] + for match_obj in re.finditer(pattern, self.string) + for index in match_obj.span() + ] + breakup_indices = sorted(filter( + lambda index: not any([ + span[0] < index < span[1] + for span, _ in self.command_repl_items + ]), + remove_list_redundancies([ + *self.full_span, *breakup_indices + ]) + )) + return list(filter( + lambda span: self.string[slice(*span)].strip(), + self.get_neighbouring_pairs(breakup_indices) + )) + + @property + def predefined_items(self) -> list[Span, str, str]: + return list(it.chain( + [ + (self.full_span, key, val) + for key, val in self.global_attr_items_from_config + ], + sorted([ + ((begin_tag_span[0], end_tag_span[1]), key, val) + for begin_tag_span, end_tag_span, attr_dict + in self.tag_items_from_markup + for key, val in attr_dict.items() + ]), + self.local_attr_items_from_config + )) + + def get_inserted_string_pairs( + self, use_label: bool + ) -> list[tuple[Span, tuple[str, str]]]: + attr_items = self.predefined_items + if use_label: + attr_items = [ + (span, key, WHITE if key in COLOR_RELATED_KEYS else val) + for span, key, val in attr_items + ] + [ + (span, "foreground", "#{:06x}".format(label)) + for label, span in enumerate(self.label_span_list) + ] + return [ + (span, ( + self.get_begin_tag_str(attr_dict), + self.get_end_tag_str() + )) + for span, attr_dict in self.merge_attr_items(attr_items) + ] + + @property + def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]: + return self.get_inserted_string_pairs(use_label=True) + + @property + def command_repl_items(self) -> list[tuple[Span, str]]: + return [ + (tag_span, "") + for begin_tag, end_tag, _ in self.tag_items_from_markup + for tag_span in (begin_tag, end_tag) + ] + + @property + def has_predefined_colors(self) -> bool: + return any([ + key in COLOR_RELATED_KEYS + for _, key, _ in self.predefined_items + ]) + + @property + def plain_string(self) -> str: + return "".join([ + self.get_begin_tag_str({"foreground": self.base_color}), + self.replace_str_by_spans( + self.string, self.get_span_replacement_dict( + self.get_inserted_string_pairs(use_label=False), + self.command_repl_items + ) + ), + self.get_end_tag_str() + ]) + + def handle_submob_string(self, substr: str, string_span: Span) -> str: + if self.is_markup: + substr = saxutils.unescape(substr) + return substr + + # Method alias + + def get_parts_by_text(self, substr: str) -> VGroup: + return self.get_parts_by_string(substr) + + def get_part_by_text(self, substr: str, index: int = 0) -> VMobject: + return self.get_part_by_string(substr, index) + + def set_color_by_text(self, substr: str, color: ManimColor): + return self.set_color_by_string(substr, color) + + def set_color_by_text_to_color_map( + self, text_to_color_map: dict[str, ManimColor] + ): + return self.set_color_by_string_to_color_map(text_to_color_map) + + def indices_of_part_by_text( + self, substr: str, index: int = 0 + ) -> list[int]: + return self.indices_of_part_by_string(substr, index) + + def get_text(self) -> str: + return self.get_string() -class MarkupText(Text): +class Text(MarkupText): CONFIG = { - "is_markup": True, - "apply_space_chars": False, + "is_markup": False, } @@ -461,7 +561,9 @@ class Code(MarkupText): digest_config(self, kwargs) self.code = code lexer = pygments.lexers.get_lexer_by_name(self.language) - formatter = pygments.formatters.PangoMarkupFormatter(style=self.code_style) + formatter = pygments.formatters.PangoMarkupFormatter( + style=self.code_style + ) markup = pygments.highlight(code, lexer, formatter) markup = re.sub(r"", "", markup) super().__init__(markup, **kwargs)