Some refactors

This commit is contained in:
YishiMichael 2022-04-11 23:44:33 +08:00
parent 36d62ae1a3
commit 12bfe88f40
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
3 changed files with 142 additions and 164 deletions

View file

@ -4,12 +4,14 @@ import re
import colour import colour
import itertools as it import itertools as it
from typing import Iterable, Union, Sequence from typing import Iterable, Union, Sequence
from abc import abstractmethod from abc import ABC, abstractmethod
from manimlib.constants import WHITE from manimlib.constants import BLACK, WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_int_rgb from manimlib.utils.color import color_to_int_rgb
from manimlib.utils.color import color_to_rgb
from manimlib.utils.color import rgb_to_hex
from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import remove_list_redundancies from manimlib.utils.iterables import remove_list_redundancies
@ -34,36 +36,39 @@ class _StringSVG(SVGMobject):
} }
class LabelledString(_StringSVG): class LabelledString(_StringSVG, ABC):
""" """
An abstract base class for `MTex` and `MarkupText` An abstract base class for `MTex` and `MarkupText`
""" """
CONFIG = { CONFIG = {
"base_color": None, "base_color": WHITE,
"use_plain_file": False, "use_plain_file": False,
"isolate": [], "isolate": [],
} }
def __init__(self, string: str, **kwargs): def __init__(self, string: str, **kwargs):
self.string = string self.string = string
reserved_svg_default = kwargs.pop("svg_default", {})
digest_config(self, kwargs) digest_config(self, kwargs)
self.reserved_svg_default = reserved_svg_default
self.base_color = self.base_color \ # Convert `base_color` to hex code.
or reserved_svg_default.get("color", None) \ self.base_color = rgb_to_hex(color_to_rgb(
or reserved_svg_default.get("fill_color", None) \ self.base_color \
or self.svg_default.get("color", None) \
or self.svg_default.get("fill_color", None) \
or WHITE or WHITE
))
self.svg_default["fill_color"] = BLACK
self.pre_parse() self.pre_parse()
self.parse() self.parse()
super().__init__(**kwargs) super().__init__()
self.post_parse() self.post_parse()
def get_file_path(self) -> str: def get_file_path(self) -> str:
return self.get_file_path_(use_plain_file=False) return self.get_file_path_(use_plain_file=False)
def get_file_path_(self, use_plain_file: bool) -> str: def get_file_path_(self, use_plain_file: bool) -> str:
content = self.get_decorated_string(use_plain_file=use_plain_file) content = self.get_content(use_plain_file)
return self.get_file_path_by_content(content) return self.get_file_path_by_content(content)
@abstractmethod @abstractmethod
@ -77,15 +82,11 @@ class LabelledString(_StringSVG):
self.color_to_label(submob.get_fill_color()) self.color_to_label(submob.get_fill_color())
for submob in self.submobjects for submob in self.submobjects
] ]
if any([ if self.use_plain_file or self.has_predefined_local_colors:
self.use_plain_file,
self.reserved_svg_default,
self.has_predefined_colors
]):
file_path = self.get_file_path_(use_plain_file=True) file_path = self.get_file_path_(use_plain_file=True)
plain_svg = _StringSVG( plain_svg = _StringSVG(
file_path, file_path,
svg_default=self.reserved_svg_default, svg_default=self.svg_default,
path_string_config=self.path_string_config path_string_config=self.path_string_config
) )
self.set_submobjects(plain_svg.submobjects) self.set_submobjects(plain_svg.submobjects)
@ -101,7 +102,9 @@ class LabelledString(_StringSVG):
def parse(self) -> None: def parse(self) -> None:
self.command_repl_items = self.get_command_repl_items() self.command_repl_items = self.get_command_repl_items()
self.command_spans = self.get_command_spans() self.command_spans = self.get_command_spans()
self.ignored_spans = self.get_ignored_spans() self.extra_entity_spans = self.get_extra_entity_spans()
self.entity_spans = self.get_entity_spans()
self.extra_ignored_spans = self.get_extra_ignored_spans()
self.skipped_spans = self.get_skipped_spans() self.skipped_spans = self.get_skipped_spans()
self.internal_specified_spans = self.get_internal_specified_spans() self.internal_specified_spans = self.get_internal_specified_spans()
self.external_specified_spans = self.get_external_specified_spans() self.external_specified_spans = self.get_external_specified_spans()
@ -216,7 +219,7 @@ class LabelledString(_StringSVG):
return sorted_seq[index + index_shift] return sorted_seq[index + index_shift]
@staticmethod @staticmethod
def get_span_replacement_dict( def generate_span_repl_dict(
inserted_string_pairs: list[tuple[Span, tuple[str, str]]], inserted_string_pairs: list[tuple[Span, tuple[str, str]]],
other_repl_items: list[tuple[Span, str]] other_repl_items: list[tuple[Span, str]]
) -> dict[Span, str]: ) -> dict[Span, str]:
@ -290,20 +293,20 @@ class LabelledString(_StringSVG):
r, g = divmod(rg, 256) r, g = divmod(rg, 256)
return r, g, b return r, g, b
@staticmethod
def int_to_hex(rgb_int: int) -> str:
return "#{:06x}".format(rgb_int).upper()
@staticmethod
def hex_to_int(rgb_hex: str) -> int:
return int(rgb_hex[1:], 16)
@staticmethod @staticmethod
def color_to_label(color: ManimColor) -> int: def color_to_label(color: ManimColor) -> int:
rgb_tuple = color_to_int_rgb(color) rgb_tuple = color_to_int_rgb(color)
rgb = LabelledString.rgb_to_int(rgb_tuple) rgb = LabelledString.rgb_to_int(rgb_tuple)
return rgb - 1 return rgb - 1
@abstractmethod
def get_begin_color_command_str(int_rgb: int) -> str:
return ""
@abstractmethod
def get_end_color_command_str() -> str:
return ""
# Parsing # Parsing
@abstractmethod @abstractmethod
@ -313,14 +316,25 @@ class LabelledString(_StringSVG):
def get_command_spans(self) -> list[Span]: def get_command_spans(self) -> list[Span]:
return [cmd_span for cmd_span, _ in self.command_repl_items] return [cmd_span for cmd_span, _ in self.command_repl_items]
def get_ignored_spans(self) -> list[int]: @abstractmethod
def get_extra_entity_spans(self) -> list[Span]:
return []
def get_entity_spans(self) -> list[Span]:
return list(it.chain(
self.command_spans,
self.extra_entity_spans
))
@abstractmethod
def get_extra_ignored_spans(self) -> list[int]:
return [] return []
def get_skipped_spans(self) -> list[Span]: def get_skipped_spans(self) -> list[Span]:
return list(it.chain( return list(it.chain(
self.find_spans(r"\s"), self.find_spans(r"\s"),
self.command_spans, self.command_spans,
self.ignored_spans self.extra_ignored_spans
)) ))
def shrink_span(self, span: Span) -> Span: def shrink_span(self, span: Span) -> Span:
@ -344,7 +358,11 @@ class LabelledString(_StringSVG):
self.find_substrs(self.isolate) self.find_substrs(self.isolate)
)) ))
shrinked_spans = list(filter( shrinked_spans = list(filter(
lambda span: span[0] < span[1], lambda span: span[0] < span[1] and not any([
entity_span[0] < index < entity_span[1]
for index in span
for entity_span in self.entity_spans
]),
[self.shrink_span(span) for span in spans] [self.shrink_span(span) for span in spans]
)) ))
return remove_list_redundancies(shrinked_spans) return remove_list_redundancies(shrinked_spans)
@ -363,36 +381,11 @@ class LabelledString(_StringSVG):
) )
@abstractmethod @abstractmethod
def get_inserted_string_pairs( def get_content(self, use_plain_file: bool) -> str:
self, use_plain_file: bool return ""
) -> list[tuple[Span, tuple[str, str]]]:
return []
@abstractmethod @abstractmethod
def get_other_repl_items( def has_predefined_local_colors(self) -> bool:
self, use_plain_file: bool
) -> list[tuple[Span, str]]:
return []
def get_decorated_string(self, use_plain_file: bool) -> str:
span_repl_dict = self.get_span_replacement_dict(
self.get_inserted_string_pairs(use_plain_file),
self.get_other_repl_items(use_plain_file)
)
result = self.get_replaced_substr(self.full_span, span_repl_dict)
if not use_plain_file:
return result
return "".join([
self.get_begin_color_command_str(
self.rgb_to_int(color_to_int_rgb(self.base_color))
),
result,
self.get_end_color_command_str()
])
@abstractmethod
def has_predefined_colors(self) -> bool:
return False return False
# Post-parsing # Post-parsing

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import itertools as it
import colour import colour
from typing import Union, Sequence from typing import Union, Sequence
@ -32,7 +33,7 @@ class MTex(LabelledString):
def __init__(self, tex_string: str, **kwargs): def __init__(self, tex_string: str, **kwargs):
# Prevent from passing an empty string. # Prevent from passing an empty string.
if not tex_string: if not tex_string:
tex_string = "\\quad" tex_string = "\\\\"
self.tex_string = tex_string self.tex_string = tex_string
super().__init__(tex_string, **kwargs) super().__init__(tex_string, **kwargs)
@ -55,30 +56,14 @@ class MTex(LabelledString):
) )
def get_file_path_by_content(self, content: str) -> str: def get_file_path_by_content(self, content: str) -> str:
full_tex = self.get_tex_file_body(content)
with display_during_execution(f"Writing \"{self.tex_string}\""):
file_path = self.tex_to_svg_file_path(full_tex)
return file_path
def get_tex_file_body(self, content: str) -> str:
if self.tex_environment:
content = "\n".join([
f"\\begin{{{self.tex_environment}}}",
content,
f"\\end{{{self.tex_environment}}}"
])
if self.alignment:
content = "\n".join([self.alignment, content])
tex_config = get_tex_config() tex_config = get_tex_config()
return tex_config["tex_body"].replace( full_tex = tex_config["tex_body"].replace(
tex_config["text_to_replace"], tex_config["text_to_replace"],
content content
) )
with display_during_execution(f"Writing \"{self.tex_string}\""):
@staticmethod file_path = tex_to_svg_file(full_tex)
def tex_to_svg_file_path(tex_file_content: str) -> str: return file_path
return tex_to_svg_file(tex_file_content)
def pre_parse(self) -> None: def pre_parse(self) -> None:
super().pre_parse() super().pre_parse()
@ -91,29 +76,23 @@ class MTex(LabelledString):
# Toolkits # Toolkits
@staticmethod @staticmethod
def get_begin_color_command_str(rgb_int: int) -> str: def get_color_command_str(rgb_int: int) -> str:
rgb_tuple = MTex.int_to_rgb(rgb_int) rgb_tuple = MTex.int_to_rgb(rgb_int)
return "".join([ return "".join([
"{{",
"\\color[RGB]", "\\color[RGB]",
"{", "{",
",".join(map(str, rgb_tuple)), ",".join(map(str, rgb_tuple)),
"}" "}"
]) ])
@staticmethod
def get_end_color_command_str() -> str:
return "}}"
# Pre-parsing # Pre-parsing
def get_backslash_indices(self) -> list[int]: def get_backslash_indices(self) -> list[int]:
# Newlines (`\\`) don't count. # The latter of `\\` doesn't count.
return [ return list(it.chain(*[
span[1] - 1 range(span[0], span[1], 2)
for span in self.find_spans(r"\\+") for span in self.find_spans(r"\\+")
if (span[1] - span[0]) % 2 == 1 ]))
]
def get_unescaped_char_spans(self, chars: str): def get_unescaped_char_spans(self, chars: str):
return sorted(filter( return sorted(filter(
@ -215,7 +194,13 @@ class MTex(LabelledString):
result.append(((span_begin, span_end), repl_str)) result.append(((span_begin, span_end), repl_str))
return result return result
def get_ignored_spans(self) -> list[int]: def get_extra_entity_spans(self) -> list[Span]:
return [
self.match(r"\\([a-zA-Z]+|.)", pos=index).span()
for index in self.backslash_indices
]
def get_extra_ignored_spans(self) -> list[int]:
return self.script_char_spans.copy() return self.script_char_spans.copy()
def get_internal_specified_spans(self) -> list[Span]: def get_internal_specified_spans(self) -> list[Span]:
@ -256,35 +241,46 @@ class MTex(LabelledString):
result.append(shrinked_span) result.append(shrinked_span)
return result return result
def get_inserted_string_pairs( def get_content(self, use_plain_file: bool) -> str:
self, use_plain_file: bool
) -> list[tuple[Span, tuple[str, str]]]:
if use_plain_file: if use_plain_file:
return [] span_repl_dict = {}
else:
extended_label_span_list = [
span
if span in self.script_content_spans
else (span[0], self.rslide(span[1], self.script_spans))
for span in self.label_span_list
]
inserted_string_pairs = [
(span, (
"{{" + self.get_color_command_str(label + 1),
"}}"
))
for label, span in enumerate(extended_label_span_list)
]
span_repl_dict = self.generate_span_repl_dict(
inserted_string_pairs,
self.command_repl_items
)
result = self.get_replaced_substr(self.full_span, span_repl_dict)
extended_label_span_list = [ if self.tex_environment:
span result = "\n".join([
if span in self.script_content_spans f"\\begin{{{self.tex_environment}}}",
else (span[0], self.rslide(span[1], self.script_spans)) result,
for span in self.label_span_list f"\\end{{{self.tex_environment}}}"
] ])
return [ if self.alignment:
(span, ( result = "\n".join([self.alignment, result])
self.get_begin_color_command_str(label + 1),
self.get_end_color_command_str()
))
for label, span in enumerate(extended_label_span_list)
]
def get_other_repl_items(
self, use_plain_file: bool
) -> list[tuple[Span, str]]:
if use_plain_file: if use_plain_file:
return [] result = "\n".join([
return self.command_repl_items.copy() self.get_color_command_str(self.hex_to_int(self.base_color)),
result
])
return result
@property @property
def has_predefined_colors(self) -> bool: def has_predefined_local_colors(self) -> bool:
return bool(self.command_repl_items) return bool(self.command_repl_items)
# Post-parsing # Post-parsing

View file

@ -254,27 +254,6 @@ class MarkupText(LabelledString):
for key, val in attr_dict.items() for key, val in attr_dict.items()
]) ])
@staticmethod
def get_begin_tag_str(attr_dict: dict[str, str]) -> str:
return f"<span {MarkupText.get_attr_dict_str(attr_dict)}>"
@staticmethod
def get_end_tag_str() -> str:
return "</span>"
@staticmethod
def rgb_int_to_hex(rgb_int: int) -> str:
return "#{:06x}".format(rgb_int).upper()
@staticmethod
def get_begin_color_command_str(rgb_int: int):
color_hex = MarkupText.rgb_int_to_hex(rgb_int)
return MarkupText.get_begin_tag_str({"foreground": color_hex})
@staticmethod
def get_end_color_command_str() -> str:
return MarkupText.get_end_tag_str()
@staticmethod @staticmethod
def merge_attr_dicts( def merge_attr_dicts(
attr_dict_items: list[Span, str, typing.Any] attr_dict_items: list[Span, str, typing.Any]
@ -452,6 +431,14 @@ class MarkupText(LabelledString):
] ]
return result return result
def get_extra_entity_spans(self) -> list[Span]:
if not self.is_markup:
return []
return self.find_spans(r"&.*?;")
def get_extra_ignored_spans(self) -> list[int]:
return []
def get_internal_specified_spans(self) -> list[Span]: def get_internal_specified_spans(self) -> list[Span]:
return [span for span, _ in self.local_dicts_from_markup] return [span for span, _ in self.local_dicts_from_markup]
@ -464,13 +451,10 @@ class MarkupText(LabelledString):
self.find_spans(r"\b"), self.find_spans(r"\b"),
self.specified_spans self.specified_spans
)))) ))))
entity_spans = self.command_spans.copy()
if self.is_markup:
entity_spans += self.find_spans(r"&.*?;")
breakup_indices = sorted(filter( breakup_indices = sorted(filter(
lambda index: not any([ lambda index: not any([
span[0] < index < span[1] span[0] < index < span[1]
for span in entity_spans for span in self.entity_spans
]), ]),
breakup_indices breakup_indices
)) ))
@ -479,40 +463,45 @@ class MarkupText(LabelledString):
self.get_neighbouring_pairs(breakup_indices) self.get_neighbouring_pairs(breakup_indices)
)) ))
def get_inserted_string_pairs( def get_content(self, use_plain_file: bool) -> str:
self, use_plain_file: bool if use_plain_file:
) -> list[tuple[Span, tuple[str, str]]]:
if not use_plain_file:
attr_dict_items = [ attr_dict_items = [
(span, { (self.full_span, {"foreground": self.base_color}),
key: BLACK if key in COLOR_RELATED_KEYS else val *self.predefined_attr_dicts,
for key, val in attr_dict.items() *[
}) (span, {})
for span, attr_dict in self.predefined_attr_dicts for span in self.label_span_list
] + [ ]
(span, {"foreground": self.rgb_int_to_hex(label + 1)})
for label, span in enumerate(self.label_span_list)
] ]
else: else:
attr_dict_items = self.predefined_attr_dicts + [ attr_dict_items = [
(span, {}) (self.full_span, {"foreground": BLACK}),
for span in self.label_span_list *[
(span, {
key: BLACK if key in COLOR_RELATED_KEYS else val
for key, val in attr_dict.items()
})
for span, attr_dict in self.predefined_attr_dicts
],
*[
(span, {"foreground": self.int_to_hex(label + 1)})
for label, span in enumerate(self.label_span_list)
]
] ]
return [ inserted_string_pairs = [
(span, ( (span, (
self.get_begin_tag_str(attr_dict), f"<span {self.get_attr_dict_str(attr_dict)}>",
self.get_end_tag_str() "</span>"
)) ))
for span, attr_dict in self.merge_attr_dicts(attr_dict_items) for span, attr_dict in self.merge_attr_dicts(attr_dict_items)
] ]
span_repl_dict = self.generate_span_repl_dict(
def get_other_repl_items( inserted_string_pairs, self.command_repl_items
self, use_plain_file: bool )
) -> list[tuple[Span, str]]: return self.get_replaced_substr(self.full_span, span_repl_dict)
return self.command_repl_items.copy()
@property @property
def has_predefined_colors(self) -> bool: def has_predefined_local_colors(self) -> bool:
return any([ return any([
key in COLOR_RELATED_KEYS key in COLOR_RELATED_KEYS
for _, attr_dict in self.predefined_attr_dicts for _, attr_dict in self.predefined_attr_dicts