mirror of
https://github.com/3b1b/manim.git
synced 2025-11-14 22:57:44 +00:00
Refactor StringMobject and relevant classes
This commit is contained in:
parent
f434eb93e2
commit
28e4240475
3 changed files with 153 additions and 124 deletions
|
|
@ -106,16 +106,16 @@ class MTex(StringMobject):
|
|||
|
||||
@staticmethod
|
||||
def get_internal_specified_items(
|
||||
cmd_match_pairs: list[tuple[re.Match, re.Match]]
|
||||
command_match_pairs: list[tuple[re.Match, re.Match]]
|
||||
) -> list[tuple[Span, dict[str, str]]]:
|
||||
cmd_content_spans = [
|
||||
command_content_spans = [
|
||||
(start_match.end(), end_match.start())
|
||||
for start_match, end_match in cmd_match_pairs
|
||||
for start_match, end_match in command_match_pairs
|
||||
]
|
||||
return [
|
||||
(span, {})
|
||||
for span, next_span
|
||||
in MTex.get_neighbouring_pairs(cmd_content_spans)
|
||||
in MTex.get_neighbouring_pairs(command_content_spans)
|
||||
if span[0] == next_span[0] + 1 and span[1] == next_span[1] - 1
|
||||
]
|
||||
|
||||
|
|
@ -129,7 +129,7 @@ class MTex(StringMobject):
|
|||
]
|
||||
|
||||
@staticmethod
|
||||
def get_color_cmd_str(rgb_hex: str) -> str:
|
||||
def get_color_command(rgb_hex: str) -> str:
|
||||
rgb = MTex.hex_to_int(rgb_hex)
|
||||
rg, b = divmod(rgb, 256)
|
||||
r, g = divmod(rg, 256)
|
||||
|
|
@ -143,7 +143,7 @@ class MTex(StringMobject):
|
|||
return ""
|
||||
if is_end:
|
||||
return "}}"
|
||||
return "{{" + MTex.get_color_cmd_str(label_hex)
|
||||
return "{{" + MTex.get_color_command(label_hex)
|
||||
|
||||
def get_content_prefix_and_suffix(
|
||||
self, is_labelled: bool
|
||||
|
|
@ -151,7 +151,9 @@ class MTex(StringMobject):
|
|||
prefix_lines = []
|
||||
suffix_lines = []
|
||||
if not is_labelled:
|
||||
prefix_lines.append(self.get_color_cmd_str(self.base_color_hex))
|
||||
prefix_lines.append(self.get_color_command(
|
||||
self.color_to_hex(self.base_color)
|
||||
))
|
||||
if self.alignment:
|
||||
prefix_lines.append(self.alignment)
|
||||
if self.tex_environment:
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from manimlib.utils.color import color_to_rgb
|
|||
from manimlib.utils.color import rgb_to_hex
|
||||
from manimlib.utils.config_ops import digest_config
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from colour import Color
|
||||
|
|
@ -71,9 +71,8 @@ class StringMobject(SVGMobject, ABC):
|
|||
digest_config(self, kwargs)
|
||||
if self.base_color is None:
|
||||
self.base_color = WHITE
|
||||
self.base_color_hex = self.color_to_hex(self.base_color)
|
||||
#self.base_color_hex = self.color_to_hex(self.base_color)
|
||||
|
||||
self.full_len = len(self.string)
|
||||
self.parse()
|
||||
super().__init__(**kwargs)
|
||||
self.labels = [submob.label for submob in self.submobjects]
|
||||
|
|
@ -169,7 +168,7 @@ class StringMobject(SVGMobject, ABC):
|
|||
isinstance(index, int) or index is None
|
||||
for index in sel
|
||||
):
|
||||
l = self.full_len
|
||||
l = len(self.string)
|
||||
span = tuple(
|
||||
default_index if index is None else
|
||||
min(index, l) if index >= 0 else max(index + l, 0)
|
||||
|
|
@ -188,35 +187,15 @@ class StringMobject(SVGMobject, ABC):
|
|||
result.extend(spans)
|
||||
return list(filter(lambda span: span[0] < span[1], result))
|
||||
|
||||
def get_substr(self, span: Span) -> str:
|
||||
return self.string[slice(*span)]
|
||||
|
||||
@staticmethod
|
||||
def get_neighbouring_pairs(vals: Iterable[T]) -> list[tuple[T, T]]:
|
||||
val_list = list(vals)
|
||||
return list(zip(val_list[:-1], val_list[1:]))
|
||||
|
||||
@staticmethod
|
||||
def join_strs(strs, inserted_strs):
|
||||
return "".join(it.chain(*zip(strs, (*inserted_strs, ""))))
|
||||
|
||||
@staticmethod
|
||||
def span_contains(span_0: Span, span_1: Span) -> bool:
|
||||
return span_0[0] <= span_1[0] and span_0[1] >= span_1[1]
|
||||
|
||||
@staticmethod
|
||||
def get_complement_spans(
|
||||
universal_span: Span, interval_spans: list[Span]
|
||||
) -> list[Span]:
|
||||
if not interval_spans:
|
||||
return [universal_span]
|
||||
|
||||
span_ends, span_starts = zip(*interval_spans)
|
||||
return list(zip(
|
||||
(universal_span[0], *span_starts),
|
||||
(*span_ends, universal_span[1])
|
||||
))
|
||||
|
||||
@staticmethod
|
||||
def color_to_hex(color: ManimColor) -> str:
|
||||
return rgb_to_hex(color_to_rgb(color))
|
||||
|
|
@ -232,6 +211,23 @@ class StringMobject(SVGMobject, ABC):
|
|||
# Parsing
|
||||
|
||||
def parse(self) -> None:
|
||||
def get_substr(span: Span) -> str:
|
||||
return self.string[slice(*span)]
|
||||
|
||||
def get_complement_spans(
|
||||
universal_span: Span, interval_spans: list[Span]
|
||||
) -> list[Span]:
|
||||
if not interval_spans:
|
||||
return [universal_span]
|
||||
|
||||
span_ends, span_starts = zip(*interval_spans)
|
||||
return list(zip(
|
||||
(universal_span[0], *span_starts),
|
||||
(*span_ends, universal_span[1])
|
||||
))
|
||||
|
||||
full_len = len(self.string)
|
||||
|
||||
command_matches = list(re.finditer(
|
||||
self.get_command_pattern(), self.string, re.X | re.S
|
||||
))
|
||||
|
|
@ -251,11 +247,11 @@ class StringMobject(SVGMobject, ABC):
|
|||
]
|
||||
]
|
||||
command_spans = [match_obj.span() for match_obj in command_matches]
|
||||
region_spans = self.get_complement_spans(
|
||||
(0, self.full_len), command_spans
|
||||
region_spans = get_complement_spans(
|
||||
(0, full_len), command_spans
|
||||
)
|
||||
|
||||
def get_region_index(index):
|
||||
def get_region_index(index: int) -> int:
|
||||
for region_index, (start, end) in enumerate(region_spans):
|
||||
if start <= index <= end:
|
||||
return region_index
|
||||
|
|
@ -273,7 +269,7 @@ class StringMobject(SVGMobject, ABC):
|
|||
]):
|
||||
log.warning(
|
||||
"Cannot handle substring '%s', ignored",
|
||||
self.get_substr(span)
|
||||
get_substr(span)
|
||||
)
|
||||
continue
|
||||
overlapped_spans = [
|
||||
|
|
@ -285,14 +281,14 @@ class StringMobject(SVGMobject, ABC):
|
|||
if overlapped_spans:
|
||||
log.warning(
|
||||
"Substring '%s' partly overlaps with '%s', ignored",
|
||||
self.get_substr(span),
|
||||
self.get_substr(overlapped_spans[0])
|
||||
get_substr(span),
|
||||
get_substr(overlapped_spans[0])
|
||||
)
|
||||
continue
|
||||
labelled_spans.append(span)
|
||||
attr_dicts.append(attr_dict)
|
||||
|
||||
insertion_items = [
|
||||
inserted_items = [
|
||||
label_flag_pair
|
||||
for _, label_flag_pair in sorted(it.chain(*(
|
||||
sorted([
|
||||
|
|
@ -302,36 +298,127 @@ class StringMobject(SVGMobject, ABC):
|
|||
for flag in (-1, 1)
|
||||
)), key=lambda t: t[0][0])
|
||||
]
|
||||
insertion_interval_items = [
|
||||
#inserted_indices = [0, *(
|
||||
# labelled_spans[label][flag < 0]
|
||||
# for label, flag in inserted_items
|
||||
#), full_len]
|
||||
inserted_interval_items = [
|
||||
tuple(zip(*pair))
|
||||
for pair in self.get_neighbouring_pairs([
|
||||
(index, get_region_index(index))
|
||||
for index in [0, *(
|
||||
labelled_spans[label][flag < 0]
|
||||
for label, flag in insertion_items
|
||||
), self.full_len]
|
||||
for label, flag in inserted_items
|
||||
), full_len]
|
||||
])
|
||||
]
|
||||
|
||||
def get_replaced_pieces(replace_func):
|
||||
def join_strs(strs: list[str], inserted_strs: list[str]) -> str:
|
||||
return "".join(it.chain(*zip(strs, (*inserted_strs, ""))))
|
||||
|
||||
def get_replaced_pieces(replace_func: Callable[[re.Match], str]) -> list[str]:
|
||||
return [
|
||||
self.join_strs([
|
||||
self.get_substr(s)
|
||||
for s in self.get_complement_spans(
|
||||
join_strs([
|
||||
get_substr(s)
|
||||
for s in get_complement_spans(
|
||||
span, command_spans[slice(*region_range)]
|
||||
)
|
||||
], [
|
||||
replace_func(command_match)
|
||||
for command_match in command_matches[slice(*region_range)]
|
||||
])
|
||||
for span, region_range in insertion_interval_items
|
||||
for span, region_range in inserted_interval_items
|
||||
]
|
||||
|
||||
content_pieces = get_replaced_pieces(self.replace_for_content)
|
||||
matching_pieces = get_replaced_pieces(self.replace_for_matching)
|
||||
|
||||
def get_content(is_labelled: bool) -> str:
|
||||
inserted_strings = [
|
||||
self.get_command_string(
|
||||
attr_dicts[label],
|
||||
is_end=flag < 0,
|
||||
label_hex=self.int_to_hex(label + 1) if is_labelled else None
|
||||
)
|
||||
for label, flag in inserted_items
|
||||
]
|
||||
prefix, suffix = self.get_content_prefix_and_suffix(
|
||||
is_labelled=is_labelled
|
||||
)
|
||||
return "".join([
|
||||
prefix,
|
||||
join_strs(content_pieces, inserted_strings),
|
||||
suffix
|
||||
])
|
||||
|
||||
def get_group_substrs(group_labels: list[int]) -> list[str]:
|
||||
if not group_labels:
|
||||
return []
|
||||
|
||||
def get_index(label, flag):
|
||||
if label == -1:
|
||||
return 0 if flag == 1 else len(inserted_items) + 1
|
||||
return inserted_items.index((label, flag)) + 1
|
||||
|
||||
def get_labelled_span(label):
|
||||
if label == -1:
|
||||
return (0, full_len)
|
||||
return labelled_spans[label]
|
||||
|
||||
def label_contains(label_0, label_1):
|
||||
return self.span_contains(
|
||||
get_labelled_span(label_0), get_labelled_span(label_1)
|
||||
)
|
||||
|
||||
#piece_starts = [
|
||||
# get_index(group_labels[0], 1),
|
||||
# *(
|
||||
# get_index(curr_label, 1)
|
||||
# if label_contains(prev_label, curr_label)
|
||||
# else get_index(prev_label, -1)
|
||||
# for prev_label, curr_label in self.get_neighbouring_pairs(
|
||||
# group_labels
|
||||
# )
|
||||
# )
|
||||
#]
|
||||
#piece_ends = [
|
||||
# *(
|
||||
# get_index(curr_label, -1)
|
||||
# if label_contains(next_label, curr_label)
|
||||
# else get_index(next_label, 1)
|
||||
# for curr_label, next_label in self.get_neighbouring_pairs(
|
||||
# group_labels
|
||||
# )
|
||||
# ),
|
||||
# get_index(group_labels[-1], -1)
|
||||
#]
|
||||
|
||||
piece_ranges = get_complement_spans(
|
||||
(get_index(group_labels[0], 1), get_index(group_labels[-1], -1)),
|
||||
[
|
||||
(
|
||||
get_index(next_label, 1)
|
||||
if label_contains(prev_label, next_label)
|
||||
else get_index(prev_label, -1),
|
||||
get_index(prev_label, -1)
|
||||
if label_contains(next_label, prev_label)
|
||||
else get_index(next_label, 1)
|
||||
)
|
||||
for prev_label, next_label in self.get_neighbouring_pairs(
|
||||
group_labels
|
||||
)
|
||||
]
|
||||
)
|
||||
return [
|
||||
re.sub(r"\s+", "", "".join(
|
||||
matching_pieces[slice(*piece_ranges)]
|
||||
))
|
||||
for piece_ranges in piece_ranges
|
||||
]
|
||||
|
||||
self.labelled_spans = labelled_spans
|
||||
self.attr_dicts = attr_dicts
|
||||
self.insertion_items = insertion_items
|
||||
self.content_pieces = get_replaced_pieces(self.replace_for_content)
|
||||
self.matching_pieces = get_replaced_pieces(self.replace_for_matching)
|
||||
self.get_content = get_content
|
||||
self.get_group_substrs = get_group_substrs
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
|
|
@ -361,7 +448,9 @@ class StringMobject(SVGMobject, ABC):
|
|||
return []
|
||||
|
||||
@abstractmethod
|
||||
def get_external_specified_items(self) -> list[tuple[Span, dict[str, str]]]:
|
||||
def get_external_specified_items(
|
||||
self
|
||||
) -> list[tuple[Span, dict[str, str]]]:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -382,83 +471,21 @@ class StringMobject(SVGMobject, ABC):
|
|||
command_matches: list[re.Match], command_flags: list[int]
|
||||
) -> list[tuple[re.Match, re.Match]]:
|
||||
result = []
|
||||
start_cmd_matches_stack = []
|
||||
for cmd_match, command_flag in zip(command_matches, command_flags):
|
||||
if command_flag == 1:
|
||||
start_cmd_matches_stack.append(cmd_match)
|
||||
elif command_flag == -1:
|
||||
if not start_cmd_matches_stack:
|
||||
open_stack = []
|
||||
for command_match, flag in zip(command_matches, command_flags):
|
||||
if flag == 1:
|
||||
open_stack.append(command_match)
|
||||
elif flag == -1:
|
||||
if not open_stack:
|
||||
raise ValueError("Missing open command")
|
||||
start_cmd_match = start_cmd_matches_stack.pop()
|
||||
open_command_match = open_stack.pop()
|
||||
result.append(
|
||||
(start_cmd_match, cmd_match)
|
||||
(open_command_match, command_match)
|
||||
)
|
||||
if start_cmd_matches_stack:
|
||||
if open_stack:
|
||||
raise ValueError("Missing close command")
|
||||
return result
|
||||
|
||||
def get_content(self, is_labelled: bool) -> str:
|
||||
insertion_strings = [
|
||||
self.get_command_string(
|
||||
self.attr_dicts[label],
|
||||
is_end=flag < 0,
|
||||
label_hex=self.int_to_hex(label + 1) if is_labelled else None
|
||||
)
|
||||
for label, flag in self.insertion_items
|
||||
]
|
||||
prefix, suffix = self.get_content_prefix_and_suffix(
|
||||
is_labelled=is_labelled
|
||||
)
|
||||
return "".join([
|
||||
prefix,
|
||||
self.join_strs(self.content_pieces, insertion_strings),
|
||||
suffix
|
||||
])
|
||||
|
||||
def get_group_substrs(self, group_labels: list[int]) -> list[str]:
|
||||
if not group_labels:
|
||||
return []
|
||||
|
||||
insertion_items = self.insertion_items
|
||||
|
||||
def get_index(label, flag):
|
||||
if label == -1:
|
||||
return 0 if flag == 1 else len(insertion_items) + 1
|
||||
return insertion_items.index((label, flag)) + 1
|
||||
|
||||
def get_labelled_span(label):
|
||||
if label == -1:
|
||||
return (0, self.full_len)
|
||||
return self.labelled_spans[label]
|
||||
|
||||
def label_contains(label_0, label_1):
|
||||
return self.span_contains(
|
||||
get_labelled_span(label_0), get_labelled_span(label_1)
|
||||
)
|
||||
|
||||
piece_ranges = self.get_complement_spans(
|
||||
(get_index(group_labels[0], 1), get_index(group_labels[-1], -1)),
|
||||
[
|
||||
(
|
||||
get_index(next_label, 1)
|
||||
if label_contains(prev_label, next_label)
|
||||
else get_index(prev_label, -1),
|
||||
get_index(prev_label, -1)
|
||||
if label_contains(next_label, prev_label)
|
||||
else get_index(next_label, 1)
|
||||
)
|
||||
for prev_label, next_label in self.get_neighbouring_pairs(
|
||||
group_labels
|
||||
)
|
||||
]
|
||||
)
|
||||
return [
|
||||
re.sub(r"\s+", "", "".join(
|
||||
self.matching_pieces[slice(*piece_range)]
|
||||
))
|
||||
for piece_range in piece_ranges
|
||||
]
|
||||
|
||||
# Selector
|
||||
|
||||
def get_submob_indices_list_by_span(
|
||||
|
|
@ -475,7 +502,7 @@ class StringMobject(SVGMobject, ABC):
|
|||
def get_specified_part_items(self) -> list[tuple[str, list[int]]]:
|
||||
return [
|
||||
(
|
||||
self.get_substr(span),
|
||||
self.string[slice(*span)],
|
||||
self.get_submob_indices_list_by_span(span)
|
||||
)
|
||||
for span in self.labelled_spans
|
||||
|
|
|
|||
|
|
@ -363,7 +363,7 @@ class MarkupText(StringMobject):
|
|||
self, is_labelled: bool
|
||||
) -> tuple[str, str]:
|
||||
global_attr_dict = {
|
||||
"foreground": self.base_color_hex,
|
||||
"foreground": self.color_to_hex(self.base_color),
|
||||
"font_family": self.font,
|
||||
"font_style": self.slant,
|
||||
"font_weight": self.weight,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue