3b1b-manim/manimlib/mobject/svg/mtex_mobject.py
2022-03-29 23:38:06 +08:00

923 lines
29 KiB
Python

from __future__ import annotations
import re
import colour
import itertools as it
from types import MethodType
from typing import Iterable, Union, Sequence
from abc import abstractmethod
from manimlib.constants import WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_int_rgb
from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import adjacent_pairs
from manimlib.utils.iterables import remove_list_redundancies
from manimlib.utils.tex_file_writing import tex_to_svg_file
from manimlib.utils.tex_file_writing import get_tex_config
from manimlib.utils.tex_file_writing import display_during_execution
from manimlib.logger import log
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.types.vectorized_mobject import VMobject
# Anything accepted where a color is expected: a string, a
# `colour.Color`, or a sequence of floats.
ManimColor = Union[str, colour.Color, Sequence[float]]
# Half-open `(begin, end)` index range into a string, as produced by
# `re.Match.span()` and consumed via `string[slice(*span)]`.
Span = tuple[int, int]
# `MTex` is scaled by this factor times its `font_size`.
SCALE_FACTOR_PER_FONT_POINT = 1e-3
class _StringSVG(SVGMobject):
    """
    `SVGMobject` used to load SVG files rendered from strings.

    Glyphs are drawn without stroke (`stroke_width` is 0), `height` is left
    as None (presumably so the file is not rescaled to a fixed height —
    confirm against `SVGMobject`), and path strings are parsed with sharp
    curves subdivided and null curves removed.
    """
    CONFIG = dict(
        height=None,
        stroke_width=0,
        stroke_color=WHITE,
        path_string_config=dict(
            should_subdivide_sharp_curves=True,
            should_remove_null_curves=True,
        ),
    )
class LabelledString(_StringSVG):
    """
    An abstract base class for `MTex` and `MarkupText`.

    Subclasses describe where the labelled spans of `self.string` lie and
    how to wrap a span in a color command. The string is rendered with each
    labelled span colored by its label, so a glyph's fill color tells which
    label it belongs to. Adjacent glyphs sharing a label are then packed
    into submobjects that remember their source substring.
    """
    CONFIG = {
        "base_color": WHITE,
        "use_plain_file": False,
        "isolate": [],
    }

    def __init__(self, string: str, **kwargs):
        self.string = string
        super().__init__(**kwargs)

    def get_file_path(self, use_plain_file: bool = False) -> str:
        """
        Render the decorated string (see `get_decorated_string`) and return
        the path of the resulting SVG file.
        """
        content = self.get_decorated_string(use_plain_file=use_plain_file)
        return self.get_file_path_by_content(content)

    @abstractmethod
    def get_file_path_by_content(self, content: str) -> str:
        """Render `content` and return the path of the SVG file."""
        return ""

    def generate_mobject(self) -> None:
        """
        Build the glyphs, then regroup them by label.

        Each glyph's label is read off its fill color in the labelled
        render. The real colors are then applied. They are copied from a
        plain render when `use_plain_file` is set or the string defines its
        own colors; otherwise `base_color` is used. Finally, adjacent glyphs
        sharing a label are packed into one `VGroup` submobject carrying
        `label` and `string` attributes.
        """
        super().generate_mobject()
        if not self.submobjects:
            return

        glyphs = list(self.submobjects)
        glyph_labels = [
            self.color_to_label(glyph.get_fill_color())
            for glyph in glyphs
        ]
        if self.use_plain_file or self.has_predefined_colors:
            file_path = self.get_file_path(use_plain_file=True)
            plain_glyphs = _StringSVG(file_path).submobjects
            for glyph, plain_glyph in zip(glyphs, plain_glyphs):
                glyph.set_fill(plain_glyph.get_fill_color())
        else:
            self.set_fill(self.base_color)

        # Group this mobject's own glyphs (recolored above), which line up
        # one-to-one with `glyph_labels`.
        submob_labels, glyph_groups = self.group_neighbours(
            glyph_labels, glyphs
        )
        submob_strings = self.get_submob_strings(submob_labels)
        submobjects = []
        for glyph_list, label, submob_string in zip(
            glyph_groups, submob_labels, submob_strings
        ):
            submob = VGroup(*glyph_list)
            submob.label = label
            submob.string = submob_string
            submob.get_string = MethodType(lambda inst: inst.string, submob)
            submobjects.append(submob)
        self.set_submobjects(submobjects)

    # Toolkits

    def find_spans(self, *patterns: str) -> list[Span]:
        """Spans of all matches of each regex pattern in `self.string`."""
        return [
            match_obj.span()
            for pattern in patterns
            for match_obj in re.finditer(pattern, self.string)
        ]

    @staticmethod
    def get_neighbouring_pairs(iterable: Iterable) -> list:
        """Consecutive pairs: `[a, b, c] -> [(a, b), (b, c)]`."""
        items = list(iterable)
        return list(zip(items, items[1:]))

    @staticmethod
    def span_contains(span_0: Span, span_1: Span) -> bool:
        """Whether `span_1` lies within `span_0` (bounds inclusive)."""
        return span_0[0] <= span_1[0] and span_0[1] >= span_1[1]

    @staticmethod
    def group_neighbours(
        labels: Iterable[object],
        objs: Iterable[object]
    ) -> tuple[list[object], list[list[object]]]:
        """
        Pack together neighbouring objects sharing the same label.

        Returns the label of each group and the groups themselves.
        """
        if not labels:
            return [], []

        group_labels = []
        groups = []
        new_group = []
        current_label = labels[0]
        for label, obj in zip(labels, objs):
            if label == current_label:
                new_group.append(obj)
            else:
                group_labels.append(current_label)
                groups.append(new_group)
                new_group = [obj]
                current_label = label
        group_labels.append(current_label)
        groups.append(new_group)
        return group_labels, groups

    @staticmethod
    def find_region_index(val: int, seq: list[int]) -> int:
        """
        Return an integer in `range(-1, len(seq))` satisfying
        `seq[result] <= val < seq[result + 1]`.

        `seq` should be sorted in ascending order.
        """
        if not seq or val < seq[0]:
            return -1
        result = len(seq) - 1
        while val < seq[result]:
            result -= 1
        return result

    @staticmethod
    def replace_str_by_spans(
        substr: str, span_repl_dict: dict[Span, str]
    ) -> str:
        """
        Replace each span of `substr` by its string in `span_repl_dict`.

        Raises:
            ValueError: if two replaced spans overlap.
        """
        if not span_repl_dict:
            return substr

        spans = sorted(span_repl_dict.keys())
        if not all(
            span_0[1] <= span_1[0]
            for span_0, span_1 in LabelledString.get_neighbouring_pairs(spans)
        ):
            raise ValueError("Overlapping replacement")

        span_begins, span_ends = zip(*spans)
        # Pieces of `substr` lying before, between and after replaced spans.
        pieces = [
            substr[slice(*span)]
            for span in zip(
                (0, *span_ends),
                (*span_begins, len(substr))
            )
        ]
        repl_strs = [*[span_repl_dict[span] for span in spans], ""]
        return "".join(it.chain(*zip(pieces, repl_strs)))

    @staticmethod
    def get_span_replacement_dict(
        inserted_string_pairs: list[tuple[Span, tuple[str, str]]],
        other_repl_items: list[tuple[Span, str]]
    ) -> dict[Span, str]:
        """
        Merge string insertions and span replacements into one dict.

        Each `(span, (begin_str, end_str))` pair inserts `begin_str` at
        `span[0]` and `end_str` at `span[1]`. Insertions at the same index
        are concatenated under the empty span `(index, index)`. Closing
        strings come before opening ones. Among openings, the span reaching
        farther opens first; among closings, the span starting later closes
        first. This keeps nested wrappers properly nested.
        """
        if not inserted_string_pairs:
            # Always return a dict, as callers use `.keys()` on the result.
            return dict(other_repl_items)

        indices, _, _, inserted_strings = zip(*sorted([
            (
                span[flag],
                -flag,
                -span[1 - flag],
                str_pair[flag]
            )
            for span, str_pair in inserted_string_pairs
            for flag in range(2)
        ]))
        result = {
            (index, index): "".join(inserted_strs)
            for index, inserted_strs in zip(*LabelledString.group_neighbours(
                indices, inserted_strings
            ))
        }
        result.update(other_repl_items)
        return result

    @staticmethod
    def lstrip(index: int, skipped: list[Span]) -> int:
        """Move `index` forward past any skipped spans beginning at it."""
        # When several spans share a begin, the longest one wins. Empty spans
        # are ignored since they would never advance the index.
        transfer_dict = {
            begin: end
            for begin, end in sorted(skipped)
            if begin < end
        }
        while index in transfer_dict:
            index = transfer_dict[index]
        return index

    @staticmethod
    def rstrip(index: int, skipped: list[Span]) -> int:
        """Move `index` backward past any skipped spans ending at it."""
        # When several spans share an end, the longest one wins. Empty spans
        # are ignored since they would never move the index.
        transfer_dict = {
            end: begin
            for end, begin in sorted(
                [skipped_span[::-1] for skipped_span in skipped],
                reverse=True
            )
            if begin < end
        }
        while index in transfer_dict:
            index = transfer_dict[index]
        return index

    @staticmethod
    def strip(span: Span, skipped: list[Span]) -> Span | None:
        """Shrink `span` past skipped spans at both ends; None if emptied."""
        result = (
            LabelledString.lstrip(span[0], skipped),
            LabelledString.rstrip(span[1], skipped)
        )
        if result[0] >= result[1]:
            return None
        return result

    @staticmethod
    @abstractmethod
    def get_begin_color_command_str(r: int, g: int, b: int) -> str:
        """Markup opening a region colored with the given RGB integers."""
        return ""

    @staticmethod
    @abstractmethod
    def get_end_color_command_str() -> str:
        """Markup closing a region opened by `get_begin_color_command_str`."""
        return ""

    @staticmethod
    def color_to_label(color: ManimColor) -> int:
        """
        Pack a color into the integer label it encodes (`r * 256**2 +
        g * 256 + b`). Pure white maps to -1, the label of the full string.
        """
        r, g, b = color_to_int_rgb(color)
        rg = r * 256 + g
        rgb = rg * 256 + b
        if rgb == 16777215:  # white
            return -1
        return rgb

    # Parser

    @property
    def full_span(self) -> Span:
        return (0, len(self.string))

    @property
    def space_spans(self) -> list[Span]:
        return self.find_spans(r"\s+")

    @property
    @abstractmethod
    def internal_specified_spans(self) -> list[Span]:
        """Spans the markup itself marks as separate parts."""
        return []

    @property
    def external_specified_spans(self) -> list[Span]:
        """Spans of the substrings listed in `isolate`."""
        # Strip before filtering: a whitespace-only entry would otherwise
        # become an empty pattern matching at every position.
        substrs = remove_list_redundancies([
            substr.strip() for substr in self.isolate
        ])
        return self.find_spans(*[
            re.escape(substr) for substr in substrs if substr
        ])

    @property
    def specified_spans(self) -> list[Span]:
        return remove_list_redundancies([
            self.full_span,
            *self.internal_specified_spans,
            *self.external_specified_spans
        ])

    def get_specified_substrings(self) -> list[str]:
        return remove_list_redundancies([
            self.string[slice(*span)]
            for span in self.specified_spans
        ])

    @property
    @abstractmethod
    def label_span_list(self) -> list[Span]:
        """Spans receiving their own label; a span's label is its index."""
        return []

    @abstractmethod
    def get_inserted_string_pairs(
        self, use_plain_file: bool
    ) -> list[tuple[Span, tuple[str, str]]]:
        """`(span, (begin_str, end_str))` pairs wrapped around spans."""
        return []

    @property
    @abstractmethod
    def command_repl_items(self) -> list[tuple[Span, str]]:
        """`(span, replacement)` pairs for commands in `self.string`."""
        return []

    @property
    def command_spans(self) -> list[Span]:
        return [cmd_span for cmd_span, _ in self.command_repl_items]

    @property
    @abstractmethod
    def remove_commands_in_plain_file(self) -> bool:
        """Whether commands are left untouched in the plain render."""
        return True

    def get_decorated_string(self, use_plain_file: bool) -> str:
        """
        Return the string passed to the renderer.

        Inserted strings and command replacements are applied (command
        replacements are skipped for a plain render when
        `remove_commands_in_plain_file` is set). A plain render is
        additionally wrapped in the `base_color` color command.
        """
        if use_plain_file and self.remove_commands_in_plain_file:
            other_repl_items = []
        else:
            other_repl_items = self.command_repl_items
        span_repl_dict = self.get_span_replacement_dict(
            self.get_inserted_string_pairs(use_plain_file),
            other_repl_items
        )
        result = self.replace_str_by_spans(self.string, span_repl_dict)

        if not use_plain_file:
            return result
        return "".join([
            self.get_begin_color_command_str(
                *color_to_int_rgb(self.base_color)
            ),
            result,
            self.get_end_color_command_str()
        ])

    @property
    @abstractmethod
    def has_predefined_colors(self) -> bool:
        """Whether the string sets colors itself."""
        return False

    @property
    def additionally_ignored_indices(self) -> list[int]:
        return []

    @property
    def skipped_spans(self) -> list[Span]:
        """Spans ignored when stripping: spaces, commands, ignored chars."""
        return list(it.chain(
            self.space_spans,
            self.command_spans,
            [
                (index, index + 1)
                for index in self.additionally_ignored_indices
            ]
        ))

    @property
    def containing_labels_dict(self) -> dict[Span, list[int]]:
        """
        Map each labelled span (and the full span) to the labels of the
        spans it contains.

        Raises:
            ValueError: if two labelled spans partially overlap.
        """
        label_span_list = self.label_span_list
        result = {
            span: []
            for span in label_span_list
        }
        for span_0 in label_span_list:
            for span_index, span_1 in enumerate(label_span_list):
                if self.span_contains(span_0, span_1):
                    result[span_0].append(span_index)
                elif span_0[0] < span_1[0] < span_0[1] < span_1[1]:
                    string_0, string_1 = [
                        self.string[slice(*span)]
                        for span in [span_0, span_1]
                    ]
                    raise ValueError(
                        "Partially overlapping substrings detected: "
                        f"'{string_0}' and '{string_1}'"
                    )
        result[self.full_span] = list(range(-1, len(label_span_list)))
        return result

    def get_cleaned_substr(self, string_span: Span) -> str:
        """Substring at `string_span`, stripped and without commands."""
        span = self.strip(string_span, self.skipped_spans)
        if span is None:
            return ""

        span_repl_dict = {
            tuple([index - span[0] for index in cmd_span]): ""
            for cmd_span in self.command_spans
            if self.span_contains(span, cmd_span)
        }
        return self.replace_str_by_spans(
            self.string[slice(*span)], span_repl_dict
        )

    def get_submob_strings(self, submob_labels: list[int]) -> list[str]:
        """
        Return the source substring of each submobject, given their labels
        in order. A submobject's string starts after the previous one when
        its span contains the previous label, and ends before the next one
        when its span contains the next label.
        """
        # Computed once; both are properties that re-parse the string.
        label_span_list = self.label_span_list
        containing_labels_dict = self.containing_labels_dict

        ordered_spans = [
            label_span_list[label] if label != -1 else self.full_span
            for label in submob_labels
        ]
        ordered_containing_labels = [
            containing_labels_dict[span]
            for span in ordered_spans
        ]
        ordered_span_begins, ordered_span_ends = zip(*ordered_spans)
        string_span_begins = [
            prev_end if prev_label in containing_labels else curr_begin
            for prev_end, prev_label, containing_labels, curr_begin in zip(
                ordered_span_ends[:-1], submob_labels[:-1],
                ordered_containing_labels[1:], ordered_span_begins[1:]
            )
        ]
        string_span_ends = [
            next_begin if next_label in containing_labels else curr_end
            for next_begin, next_label, containing_labels, curr_end in zip(
                ordered_span_begins[1:], submob_labels[1:],
                ordered_containing_labels[:-1], ordered_span_ends[:-1]
            )
        ]
        string_spans = list(zip(
            (ordered_span_begins[0], *string_span_begins),
            (*string_span_ends, ordered_span_ends[-1])
        ))
        return [
            self.get_cleaned_substr(string_span)
            for string_span in string_spans
        ]

    # Selector

    def find_span_components_of_custom_span(
        self, custom_span: Span
    ) -> list[Span]:
        """
        Cover `custom_span` with a chain of labelled spans it contains,
        ignoring skipped spans between them. Returns [] if impossible.
        """
        skipped_spans = self.skipped_spans
        span_choices = sorted(filter(
            lambda span: self.span_contains(custom_span, span),
            self.label_span_list
        ))
        # Choose spans that reach the farthest.
        span_choices_dict = dict(span_choices)
        result = []
        span_begin, span_end = custom_span
        span_begin = self.rstrip(span_begin, skipped_spans)
        span_end = self.rstrip(span_end, skipped_spans)
        while span_begin != span_end:
            span_begin = self.lstrip(span_begin, skipped_spans)
            if span_begin not in span_choices_dict.keys():
                return []
            next_begin = span_choices_dict[span_begin]
            result.append((span_begin, next_begin))
            span_begin = next_begin
        return result

    def get_part_by_custom_span(self, custom_span: Span) -> VGroup:
        """Group of submobjects whose labels fall inside `custom_span`."""
        spans = self.find_span_components_of_custom_span(custom_span)
        if not spans:
            return VGroup()
        containing_labels_dict = self.containing_labels_dict
        labels = set(it.chain.from_iterable(
            containing_labels_dict[span] for span in spans
        ))
        return VGroup(*filter(
            lambda submob: submob.label in labels,
            self.submobjects
        ))

    def get_parts_by_string(self, substr: str) -> VGroup:
        return VGroup(*[
            self.get_part_by_custom_span(span)
            for span in self.find_spans(re.escape(substr.strip()))
        ])

    def get_part_by_string(self, substr: str, index: int = 0) -> VMobject:
        all_parts = self.get_parts_by_string(substr)
        return all_parts[index]

    def set_color_by_string(self, substr: str, color: ManimColor):
        self.get_parts_by_string(substr).set_color(color)
        return self

    def set_color_by_string_to_color_map(
        self, string_to_color_map: dict[str, ManimColor]
    ):
        for substr, color in string_to_color_map.items():
            self.set_color_by_string(substr, color)
        return self

    def indices_of_part(self, part: Iterable[VMobject]) -> list[int]:
        """
        Indices of the submobjects belonging to `part`.

        Raises:
            ValueError: if none belong to it.
        """
        indices = [
            index for index, submob in enumerate(self.submobjects)
            if submob in part
        ]
        if not indices:
            raise ValueError("Failed to find part")
        return indices

    def indices_of_part_by_string(
        self, substr: str, index: int = 0
    ) -> list[int]:
        part = self.get_part_by_string(substr, index=index)
        return self.indices_of_part(part)

    def get_string(self) -> str:
        return self.string
class MTex(LabelledString):
    """
    A `LabelledString` rendered through LaTeX.

    Substrings listed in `isolate` or as keys of `tex_to_color_map`, parts
    wrapped in double braces (`{{...}}`) and sub/superscript contents get
    their own labels, so they can be addressed as separate parts.
    """
    CONFIG = {
        "font_size": 48,
        "alignment": "\\centering",
        "tex_environment": "align*",
        "tex_to_color_map": {},
    }

    def __init__(self, tex_string: str, **kwargs):
        digest_config(self, kwargs)
        tex_string = tex_string.strip()
        # Prevent from passing an empty string.
        if not tex_string:
            tex_string = "\\quad"
        self.tex_string = tex_string
        # Keys of `tex_to_color_map` must be isolated to be colorable. Build
        # a new list instead of calling `self.isolate.extend(...)`. The value
        # may be the shared `"isolate": []` CONFIG default or a list the
        # caller passed in, and mutating it leaks keys into other instances.
        # NOTE(review): assumes `digest_config` does not copy list values and
        # that parent initializers digest `kwargs` again; this is why the new
        # list is also written back into `kwargs`.
        self.isolate = [*self.isolate, *self.tex_to_color_map.keys()]
        kwargs["isolate"] = self.isolate
        super().__init__(tex_string, **kwargs)

        self.set_color_by_tex_to_color_map(self.tex_to_color_map)
        self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size)

    @property
    def hash_seed(self) -> tuple:
        return (
            self.__class__.__name__,
            self.svg_default,
            self.path_string_config,
            self.base_color,
            self.use_plain_file,
            self.isolate,
            self.tex_string,
            self.alignment,
            self.tex_environment,
            self.tex_to_color_map
        )

    def get_file_path_by_content(self, content: str) -> str:
        full_tex = self.get_tex_file_body(content)
        with display_during_execution(f"Writing \"{self.string}\""):
            file_path = self.tex_to_svg_file_path(full_tex)
        return file_path

    def get_tex_file_body(self, content: str) -> str:
        """
        Wrap `content` in `tex_environment` (if any), prefix it with
        `alignment` (if any) and place it into the configured tex template.
        """
        if self.tex_environment:
            content = "\n".join([
                f"\\begin{{{self.tex_environment}}}",
                content,
                f"\\end{{{self.tex_environment}}}"
            ])
        if self.alignment:
            content = "\n".join([self.alignment, content])

        tex_config = get_tex_config()
        return tex_config["tex_body"].replace(
            tex_config["text_to_replace"],
            content
        )

    @staticmethod
    def tex_to_svg_file_path(tex_file_content: str) -> str:
        return tex_to_svg_file(tex_file_content)

    # Toolkits

    @staticmethod
    def get_begin_color_command_str(r: int, g: int, b: int) -> str:
        """Open a group colored by `\\color[RGB]{r,g,b}`."""
        return "".join([
            "{{",
            "\\color[RGB]",
            "{",
            ",".join(map(str, (r, g, b))),
            "}"
        ])

    @staticmethod
    def get_end_color_command_str() -> str:
        return "}}"

    # Parser

    @property
    def backslash_indices(self) -> list[int]:
        # Index of the last backslash of each odd-length backslash run, i.e.
        # of each backslash that escapes the next character.
        # Newlines (`\\`) don't count.
        return [
            match_obj.end() - 1
            for match_obj in re.finditer(r"\\+", self.string)
            if len(match_obj.group()) % 2 == 1
        ]

    def get_unescaped_char_indices(self, *chars: str) -> list[int]:
        """
        Sorted indices of occurrences of `chars` in `self.string` that are
        not escaped by a backslash.
        """
        backslash_indices = set(self.backslash_indices)
        # Sorted by position: the brace matcher walks these in order.
        return sorted(
            match_obj.start()
            for char in chars
            for match_obj in re.finditer(re.escape(char), self.string)
            if match_obj.start() - 1 not in backslash_indices
        )

    def get_brace_indices_lists(self) -> tuple[list[int], list[int]]:
        """
        Match unescaped braces.

        Returns `(left_brace_indices, right_brace_indices)`; the i-th
        entries of both lists form a matching pair. Pairs are listed in the
        order their right braces appear.

        Raises:
            ValueError: if the braces are unbalanced.
        """
        string = self.string
        left_brace_indices = []
        right_brace_indices = []
        left_brace_indices_stack = []
        for index in self.get_unescaped_char_indices("{", "}"):
            if string[index] == "{":
                left_brace_indices_stack.append(index)
            else:
                if not left_brace_indices_stack:
                    raise ValueError("Missing '{' inserted")
                left_brace_index = left_brace_indices_stack.pop()
                left_brace_indices.append(left_brace_index)
                right_brace_indices.append(index)
        if left_brace_indices_stack:
            raise ValueError("Missing '}' inserted")
        # `right_brace_indices` is already sorted.
        return left_brace_indices, right_brace_indices

    @property
    def left_brace_indices(self) -> list[int]:
        return self.get_brace_indices_lists()[0]

    @property
    def right_brace_indices(self) -> list[int]:
        return self.get_brace_indices_lists()[1]

    @property
    def script_char_indices(self) -> list[int]:
        return self.get_unescaped_char_indices("_", "^")

    @property
    def script_content_spans(self) -> list[Span]:
        """
        Span of the content of each sub/superscript: a braced group, or a
        single alphanumeric character or control word.

        Raises:
            ValueError: if a script has no recognizable content.
        """
        left_brace_indices, right_brace_indices = \
            self.get_brace_indices_lists()
        brace_indices_dict = dict(zip(
            left_brace_indices, right_brace_indices
        ))
        space_spans = self.space_spans
        pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+")
        result = []
        for index in self.script_char_indices:
            # The content starts right after the `_` / `^` at `index`,
            # skipping any whitespace in between.
            span_begin = self.lstrip(index + 1, space_spans)
            if span_begin in brace_indices_dict:
                span_end = brace_indices_dict[span_begin] + 1
            else:
                match_obj = pattern.match(self.string, pos=span_begin)
                if not match_obj:
                    script_name = {
                        "_": "subscript",
                        "^": "superscript"
                    }[self.string[index]]
                    raise ValueError(
                        f"Unclear {script_name} detected while parsing. "
                        "Please use braces to clarify"
                    )
                span_end = match_obj.end()
            result.append((span_begin, span_end))
        return result

    @property
    def internal_specified_spans(self) -> list[Span]:
        """Spans of paired double braces (`{{...}}`), braces included."""
        left_brace_indices, right_brace_indices = \
            self.get_brace_indices_lists()
        reversed_brace_indices_dict = dict(zip(
            right_brace_indices, left_brace_indices
        ))
        result = []
        skip = False
        for prev_right_index, right_index in self.get_neighbouring_pairs(
            right_brace_indices
        ):
            if skip:
                skip = False
                continue
            # Need `}}` closing a pair whose opening braces are also `{{`.
            if right_index != prev_right_index + 1:
                continue
            left_index = reversed_brace_indices_dict[right_index]
            prev_left_index = reversed_brace_indices_dict[prev_right_index]
            if left_index != prev_left_index - 1:
                continue
            result.append((left_index, right_index + 1))
            skip = True
        return result

    @property
    def label_span_list(self) -> list[Span]:
        """
        Spans that receive their own label (their index in this list).

        Script contents come first. Every other specified span is extended
        over the scripts directly following it. It is dropped if trimming
        its trailing scripts leaves nothing.
        """
        space_spans = self.space_spans
        script_content_spans = self.script_content_spans
        # Each script span runs from its `_` / `^` (plus the whitespace
        # before it) to the end of its content.
        script_spans = [
            (self.rstrip(index, space_spans), script_content_span[1])
            for index, script_content_span in zip(
                self.script_char_indices, script_content_spans
            )
        ]
        spans = remove_list_redundancies([
            *self.specified_spans,
            *script_content_spans
        ])
        result = []
        for span in spans:
            if span in script_content_spans:
                continue
            span_begin, span_end = span
            shrinked_end = self.rstrip(span_end, script_spans)
            if span_begin >= shrinked_end:
                continue
            result.append((span_begin, self.lstrip(span_end, script_spans)))
        return script_content_spans + remove_list_redundancies(result)

    def get_inserted_string_pairs(
        self, use_plain_file: bool
    ) -> list[tuple[Span, tuple[str, str]]]:
        """
        Wrap every labelled span in a color command encoding its label as
        RGB (nothing is inserted for the plain render).
        """
        if use_plain_file:
            return []

        return [
            (span, (
                self.get_begin_color_command_str(
                    label // 256 // 256,
                    label // 256 % 256,
                    label % 256
                ),
                self.get_end_color_command_str()
            ))
            for label, span in enumerate(
                self.label_span_list
            )
        ]

    @property
    def command_repl_items(self) -> list[tuple[Span, str]]:
        """
        Color-related commands (with their brace arguments) and their
        replacements. `\\color` and `\\textcolor` are removed, keeping the
        argument of `\\textcolor` that follows its color. `\\pagecolor`,
        `\\colorbox` and `\\fcolorbox` get `{white}` arguments instead.
        """
        color_related_command_dict = {
            "color": (1, False),
            "textcolor": (1, False),
            "pagecolor": (1, True),
            "colorbox": (1, True),
            "fcolorbox": (2, True),
        }
        result = []
        backslash_indices = self.backslash_indices
        right_brace_indices = self.right_brace_indices
        pattern = "".join([
            r"\\",
            "(",
            "|".join(color_related_command_dict.keys()),
            ")",
            r"(?![a-zA-Z])"
        ])
        for match_obj in re.finditer(pattern, self.string):
            span_begin, cmd_end = match_obj.span()
            # Skip matches whose backslash is itself escaped.
            if span_begin not in backslash_indices:
                continue
            cmd_name = match_obj.group(1)
            n_braces, substitute_cmd = color_related_command_dict[cmd_name]
            # Extend over the `n_braces`-th right brace after the command.
            span_end = right_brace_indices[self.find_region_index(
                cmd_end, right_brace_indices
            ) + n_braces] + 1
            if substitute_cmd:
                repl_str = "\\" + cmd_name + n_braces * "{white}"
            else:
                repl_str = ""
            result.append(((span_begin, span_end), repl_str))
        return result

    @property
    def remove_commands_in_plain_file(self) -> bool:
        return True

    @property
    def has_predefined_colors(self) -> bool:
        return bool(self.command_repl_items)

    @property
    def additionally_ignored_indices(self) -> list[int]:
        left_brace_indices, right_brace_indices = \
            self.get_brace_indices_lists()
        return left_brace_indices + right_brace_indices

    def get_cleaned_substr(self, string_span: Span) -> str:
        """
        Cleaned substring (see `LabelledString.get_cleaned_substr`) with
        the braces left unbalanced by `string_span` closed again.
        """
        substr = super().get_cleaned_substr(string_span)
        # Parse braces once, not once per character.
        left_brace_indices, right_brace_indices = \
            self.get_brace_indices_lists()
        left_brace_index_set = set(left_brace_indices)
        right_brace_index_set = set(right_brace_indices)
        unclosed_left_braces = 0
        unclosed_right_braces = 0
        for index in range(*string_span):
            if index in left_brace_index_set:
                unclosed_left_braces += 1
            elif index in right_brace_index_set:
                if unclosed_left_braces == 0:
                    unclosed_right_braces += 1
                else:
                    unclosed_left_braces -= 1
        return "".join([
            unclosed_right_braces * "{",
            substr,
            unclosed_left_braces * "}"
        ])

    # Method alias

    def get_parts_by_tex(self, substr: str) -> VGroup:
        return self.get_parts_by_string(substr)

    def get_part_by_tex(self, substr: str, index: int = 0) -> VMobject:
        return self.get_part_by_string(substr, index)

    def set_color_by_tex(self, substr: str, color: ManimColor):
        return self.set_color_by_string(substr, color)

    def set_color_by_tex_to_color_map(
        self, tex_to_color_map: dict[str, ManimColor]
    ):
        return self.set_color_by_string_to_color_map(tex_to_color_map)

    def indices_of_part_by_tex(
        self, substr: str, index: int = 0
    ) -> list[int]:
        return self.indices_of_part_by_string(substr, index)

    def get_tex(self) -> str:
        return self.get_string()
class MTexText(MTex):
    """
    `MTex` without a surrounding tex environment: with `tex_environment`
    set to None, `get_tex_file_body` skips the `\\begin`/`\\end` wrapper.
    """
    CONFIG = dict(
        tex_environment=None,
    )