refactor: refactor StringMobject

This commit is contained in:
YishiMichael 2022-08-22 16:55:46 +08:00
parent 19c757ec90
commit c2a75e15cc
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
3 changed files with 200 additions and 289 deletions

View file

@ -79,10 +79,9 @@ class MTex(StringMobject):
@staticmethod
def get_command_matches(string: str) -> list[re.Match]:
# Group together adjacent braces
# Lump together adjacent brace pairs
pattern = re.compile(r"""
(?P<command>\\(?:[a-zA-Z]+|.))
|(?P<script>[_^])
|(?P<open>{+)
|(?P<close>}+)
""", flags=re.X | re.S)
@ -131,9 +130,9 @@ class MTex(StringMobject):
@staticmethod
def replace_for_matching(match_obj: re.Match) -> str:
    """Return the text a command match contributes to the matching string.

    ``match_obj`` is expected to come from the pattern built in
    ``get_command_matches``, whose alternatives are the named groups
    ``command``, ``script``, ``open`` and ``close``.

    Only a ``command`` match is kept verbatim; every other alternative
    contributes the empty string.

    Parameters:
        match_obj: a match exposing a ``command`` named group.

    Returns:
        The full matched text for a command, otherwise ``""``.
    """
    # The span previously carried both the old ("drop only scripts") and
    # the new ("keep only commands") bodies back to back, which left the
    # second one unreachable. Only the post-refactor rule is kept.
    if match_obj.group("command"):
        return match_obj.group()
    return ""
@staticmethod
def get_attr_dict_from_command_pair(

View file

@ -14,11 +14,11 @@ from manimlib.utils.color import color_to_rgb
from manimlib.utils.color import rgb_to_hex
from manimlib.utils.config_ops import digest_config
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from colour import Color
from typing import Iterable, TypeVar, Union
from typing import Callable, Iterable, Union
ManimColor = Union[str, Color]
Span = tuple[int, int]
@ -32,7 +32,6 @@ if TYPE_CHECKING:
tuple[Union[int, None], Union[int, None]]
]]
]
T = TypeVar("T")
class StringMobject(SVGMobject, ABC):
@ -72,7 +71,6 @@ class StringMobject(SVGMobject, ABC):
digest_config(self, kwargs)
if self.base_color is None:
self.base_color = WHITE
#self.base_color_hex = self.color_to_hex(self.base_color)
self.parse()
super().__init__(**kwargs)
@ -98,14 +96,13 @@ class StringMobject(SVGMobject, ABC):
labelled_content = self.get_content(is_labelled=True)
file_path = self.get_file_path_by_content(labelled_content)
labelled_svg = SVGMobject(file_path)
#print(len(self.submobjects), len(labelled_svg.submobjects)) # ????
if len(self.submobjects) != len(labelled_svg.submobjects):
log.warning(
"Cannot align submobjects of the labelled svg "
"to the original svg. Skip the labelling process."
)
for submob in self.submobjects:
submob.label = labels_count - 1
submob.label = 0
return
self.rearrange_submobjects_by_positions(labelled_svg)
@ -113,13 +110,13 @@ class StringMobject(SVGMobject, ABC):
for submob, labelled_svg_submob in zip(
self.submobjects, labelled_svg.submobjects
):
color_int = self.hex_to_int(self.color_to_hex(
label = self.hex_to_int(self.color_to_hex(
labelled_svg_submob.get_fill_color()
))
if color_int >= labels_count:
unrecognizable_colors.append(color_int)
color_int = labels_count
submob.label = color_int - 1
if label >= labels_count:
unrecognizable_colors.append(label)
label = 0
submob.label = label
if unrecognizable_colors:
log.warning(
"Unrecognizable color labels detected (%s). "
@ -214,304 +211,168 @@ class StringMobject(SVGMobject, ABC):
def get_substr(span: Span) -> str:
return self.string[slice(*span)]
def get_neighbouring_pairs(vals: Iterable[T]) -> list[tuple[T, T]]:
val_list = list(vals)
return list(zip(val_list[:-1], val_list[1:]))
#def get_complement_spans(
# universal_span: Span, interval_spans: list[Span]
#) -> list[Span]:
# if not interval_spans:
# return [universal_span]
# span_ends, span_starts = zip(*interval_spans)
# return list(zip(
# (universal_span[0], *span_starts),
# (*span_ends, universal_span[1])
# ))
def join_strs(strs: list[str], inserted_strs: list[str]) -> str:
return "".join(it.chain(*zip(strs, (*inserted_strs, ""))))
command_matches = self.get_command_matches(self.string)
#command_spans = [match_obj.span() for match_obj in command_matches]
configured_items = self.get_configured_items()
#configured_spans = [span for span, _ in configured_items]
#configured_attr_dicts = [d for _, d in configured_items]
categorized_spans = [
[(0, len(self.string))], # TODO
[span for span, _ in configured_items],
self.find_spans_by_selector(self.isolate),
self.find_spans_by_selector(self.protect),
[match_obj.span() for match_obj in command_matches] # TODO
]
isolated_spans = self.find_spans_by_selector(self.isolate)
protected_spans = self.find_spans_by_selector(self.protect)
command_matches = self.get_command_matches(self.string)
sorted_items = sorted([
(category, category_index, flag, *span[::flag])
for category, spans in enumerate(categorized_spans)
for category_index, span in enumerate(spans)
def get_key(category, i, flag):
def get_span_by_category(category, i):
if category == 0:
return configured_items[i][0]
if category == 1:
return isolated_spans[i]
if category == 2:
return protected_spans[i]
return command_matches[i].span()
index, paired_index = get_span_by_category(category, i)[::flag]
return (
index,
flag * (2 if index != paired_index else -1),
-paired_index,
flag * category,
flag * i
)
index_items = sorted([
(category, i, flag)
for category, item_length in enumerate((
len(configured_items),
len(isolated_spans),
len(protected_spans),
len(command_matches)
))
for i in range(item_length)
for flag in (1, -1)
], key=lambda t: (
t[3], t[2] * (2 if t[3] != t[4] else -1), -t[4],
t[2] * t[0], t[2] * t[1]
)) # TODO
], key=lambda t: get_key(*t))
labelled_spans = []
attr_dicts = []
inserted_items = []
#labelled_items = []
labelled_items = []
count = 0
region_index = 0
label = 1
protect_level = 0
region_levels = [0]
bracket_stack = [0]
bracket_count = 0
open_command_stack = []
open_stack = []
#protect_level_stack = []
#bracket_level_stack = []
#inserted_position_stack = []
#index_items_len = 0 # count * 2
for category, i, flag, _, _ in sorted_items:
if category >= 3:
if flag == 1:
protect_level += 1
for category, i, flag in index_items:
if category >= 2:
protect_level += flag
if flag == 1 or category == 2:
continue
protect_level -= 1
if category == 3:
continue
region_index += 1
inserted_items.append((i, 0))
command_match = command_matches[i]
command_flag = self.get_command_flag(command_match)
region_levels.append(region_levels[-1] + command_flag)
if command_flag == 1:
open_command_stack.append(
(command_match, region_index, count)
)
bracket_count += 1
bracket_stack.append(bracket_count)
open_command_stack.append((len(inserted_items), i))
continue
elif command_flag == 0:
if command_flag == 0:
continue
command_match_, region_index_, count_ = open_command_stack.pop()
pos, i_ = open_command_stack.pop()
bracket_stack.pop()
open_command_match = command_matches[i_]
attr_dict = self.get_attr_dict_from_command_pair(
command_match_, command_match
open_command_match, command_match
)
if attr_dict is None:
continue
span = (command_match_.end(), command_match.start())
region_span = (region_index_, region_index - 1)
else:
if flag == 1:
open_stack.append(
(category, i, protect_level, region_index, count)
)
continue
category_, i_, protect_level_, region_index_, count_ \
= open_stack.pop()
span = categorized_spans[category][i]
if (category_, i_) != (category, i):
log.warning(
"Partly overlapping substrings detected: '%s' and '%s'",
get_substr(categorized_spans[category_][i_]),
get_substr(span)
)
continue
if protect_level_ or protect_level:
continue
ls = region_levels[region_index_:region_index + 1]
if ls and (any(ls[0] > l for l in ls) or ls[0] < ls[-1]):
log.warning(
"Cannot handle substring '%s'", get_substr(span)
)
continue
attr_dict = configured_items[i][1] if category == 1 else {}
region_span = (region_index_, region_index)
#labelled_items.append(
# (span, region_span, (count_, count), attr_dict)
#)
pos = count_ * 2
labelled_spans.append(span)
attr_dicts.append(attr_dict)
inserted_items.insert(pos, (count, 1, span[0], region_span[0]))
inserted_items.append((count, -1, span[1], region_span[1]))
count += 1
span = (open_command_match.end(), command_match.start())
labelled_items.append((span, attr_dict))
inserted_items.insert(pos, (label, 1))
inserted_items.insert(-1, (label, -1))
label += 1
continue
if flag == 1:
open_stack.append((
len(inserted_items), category, i,
protect_level, bracket_stack.copy()
))
continue
span, attr_dict = configured_items[i] \
if category == 0 else (isolated_spans[i], {})
pos, category_, i_, protect_level_, bracket_stack_ \
= open_stack.pop()
if category_ != category or i_ != i:
span_ = configured_items[i_][0] \
if category_ == 0 else isolated_spans[i_]
log.warning(
"Partly overlapping substrings detected: '%s' and '%s'",
get_substr(span_),
get_substr(span)
)
continue
if protect_level_ or protect_level:
continue
if bracket_stack_ != bracket_stack:
log.warning(
"Cannot handle substring '%s'", get_substr(span)
)
continue
labelled_items.append((span, attr_dict))
inserted_items.insert(pos, (label, 1))
inserted_items.append((label, -1))
label += 1
labelled_items.insert(0, ((0, len(self.string)), {}))
inserted_items.insert(0, (0, 1))
inserted_items.append((0, -1))
#labelled_spans = []
#attr_dicts = []
#inserted_items = []
def reconstruct_string(
start_item: tuple[int, int],
end_item: tuple[int, int],
command_replace_func: Callable[[re.Match], str],
command_insert_func: Callable[[int, int, dict[str, str]], str]
) -> str:
def get_edge_item(i: int, flag: int) -> tuple[Span, str]:
if flag == 0:
match_obj = command_matches[i]
return (
match_obj.span(),
command_replace_func(match_obj)
)
span, attr_dict = labelled_items[i]
index = span[flag < 0]
return (
(index, index),
command_insert_func(i, flag, attr_dict)
)
#inserted_items.insert(0, (-1, 1, 0, 0))
#inserted_items.append((-1, -1, len(self.string), region_index))
inserted_label_items = [
(label, flag)
for label, flag, _, _ in inserted_items
]
#inserted_interval_spans = []
#command_matches_lists = []
#subpieces_lists = []
content_pieces = []
matching_pieces = []
for (_, _, prev_index, prev_region_index), (_, _, next_index, next_region_index) in get_neighbouring_pairs(inserted_items):
region_matches = command_matches[prev_region_index:next_region_index]
#command_matches_lists.append(region_matches)
subpieces = [
items = [
get_edge_item(i, flag)
for i, flag in inserted_items[slice(
inserted_items.index(start_item),
inserted_items.index(end_item) + 1
)]
]
pieces = [
get_substr((start, end))
for start, end in zip(
[prev_index, *(m.end() for m in region_matches)],
[*(m.start() for m in region_matches), next_index]
[interval_end for (_, interval_end), _ in items[:-1]],
[interval_start for (interval_start, _), _ in items[1:]]
)
]
content_pieces.append(join_strs(subpieces, [
self.replace_for_content(m) for m in region_matches
]))
matching_pieces.append(join_strs(subpieces, [
self.replace_for_matching(m) for m in region_matches
]))
#subpieces_lists.append([
# get_substr(s)
# for s in get_complement_spans(
# (prev_index, next_index),
# [m.span() for m in region_matches]
# )
#])
interval_pieces = [piece for _, piece in items[1:-1]]
return "".join(it.chain(*zip(pieces, (*interval_pieces, ""))))
self.labelled_spans = [span for span, _ in labelled_items]
self.reconstruct_string = reconstruct_string
#inserted_interval_spans = get_neighbouring_pairs([
# index
# for _, _, index, _ in inserted_items
#])
#command_matches_lists = [
# command_matches[slice(*region_range)]
# for region_range in get_neighbouring_pairs([
# region_index
# for _, _, _, region_index in inserted_items
# ])
#]
#subpieces_lists = [
# [
# get_substr(s)
# for s in get_complement_spans(
# span, [m.span() for m in match_list]
# )
# ]
# for span, match_list in zip(inserted_interval_spans, command_matches_lists)
#]
#def get_replaced_pieces(replace_func: Callable[[re.Match], str]) -> list[str]:
# return [
# join_strs(subpieces, [
# replace_func(command_match)
# for command_match in match_list
# ])
# for subpieces, match_list in zip(subpieces_lists, command_matches_lists)
# ]
#content_pieces = get_replaced_pieces(self.replace_for_content)
#matching_pieces = get_replaced_pieces(self.replace_for_matching)
def get_content(is_labelled: bool) -> str:
inserted_strings = [
self.get_command_string(
attr_dicts[label],
is_end=flag < 0,
label_hex=self.int_to_hex(label + 1) if is_labelled else None
)
for label, flag in inserted_label_items[1:-1]
]
prefix, suffix = self.get_content_prefix_and_suffix(
is_labelled=is_labelled
def get_content(self, is_labelled: bool) -> str:
content = self.reconstruct_string(
(0, 1), (0, -1),
self.replace_for_content,
lambda label, flag, attr_dict: self.get_command_string(
attr_dict,
is_end=flag < 0,
label_hex=self.int_to_hex(label) if is_labelled else None
)
return "".join([
prefix,
join_strs(content_pieces, inserted_strings),
suffix
])
def get_group_part_items_by_labels(labels: list[int]) -> list[tuple[str, list[int]]]:
if not labels:
return []
range_lens, group_labels = zip(*(
(len(list(grouper)), val)
for val, grouper in it.groupby(labels)
))
submob_indices_lists = [
list(range(*submob_range))
for submob_range in get_neighbouring_pairs(
[0, *it.accumulate(range_lens)]
)
]
def get_region_index(label, flag):
#if label == -1:
# if flag == 1:
# return 0
# return len(inserted_label_items) - 1
return inserted_label_items.index((label, flag))
def get_labelled_span(label):
#if label == -1:
# return (0, len(self.string))
return labelled_spans[label]
def label_contains(label_0, label_1):
return self.span_contains(
get_labelled_span(label_0), get_labelled_span(label_1)
)
piece_starts = [
get_region_index(group_labels[0], 1),
*(
get_region_index(curr_label, 1)
if label_contains(prev_label, curr_label)
else get_region_index(prev_label, -1)
for prev_label, curr_label in get_neighbouring_pairs(
group_labels
)
)
]
piece_ends = [
*(
get_region_index(curr_label, -1)
if label_contains(next_label, curr_label)
else get_region_index(next_label, 1)
for curr_label, next_label in get_neighbouring_pairs(
group_labels
)
),
get_region_index(group_labels[-1], -1)
]
#piece_ranges = get_complement_spans(
# (get_region_index(group_labels[0], 1), get_region_index(group_labels[-1], -1)),
# [
# (
# get_region_index(next_label, 1)
# if label_contains(prev_label, next_label)
# else get_region_index(prev_label, -1),
# get_region_index(prev_label, -1)
# if label_contains(next_label, prev_label)
# else get_region_index(next_label, 1)
# )
# for prev_label, next_label in get_neighbouring_pairs(
# group_labels
# )
# ]
#)
group_substrs = [
re.sub(r"\s+", "", "".join(
matching_pieces[start:end]
))
for start, end in zip(piece_starts, piece_ends)
]
return list(zip(group_substrs, submob_indices_lists))
#print(labelled_spans)
self.labelled_spans = labelled_spans
self.get_content = get_content
self.get_group_part_items_by_labels = get_group_part_items_by_labels
)
prefix, suffix = self.get_content_prefix_and_suffix(
is_labelled=is_labelled
)
return "".join([prefix, content, suffix])
@staticmethod
@abstractmethod
@ -574,11 +435,62 @@ class StringMobject(SVGMobject, ABC):
self.string[slice(*span)],
self.get_submob_indices_list_by_span(span)
)
for span in self.labelled_spans[:-1]
for span in self.labelled_spans[1:]
]
def get_group_part_items(self) -> list[tuple[str, list[int]]]:
    """Group consecutive submobjects sharing a label and pair each group
    with the piece of string it corresponds to.

    Reads ``self.labels`` (one label per submobject) and
    ``self.labelled_spans`` (indexed by label), and relies on
    ``self.span_contains``, ``self.reconstruct_string`` and
    ``self.replace_for_matching``.

    Returns:
        A list of ``(substring, submobject_indices)`` tuples, one per run
        of equal consecutive labels, in order. The substring is rebuilt
        through ``self.reconstruct_string`` with all whitespace removed.
        An empty list is returned when there are no labels.
    """
    # The stale delegation to ``get_group_part_items_by_labels`` was
    # dropped: it returned before any of the logic below could run.
    if not self.labels:
        return []

    # Collapse the label sequence into runs of equal consecutive labels.
    runs = [
        (label, len(list(grouper)))
        for label, grouper in it.groupby(self.labels)
    ]
    group_labels = [label for label, _ in runs]

    # Running totals of the run lengths delimit each group's submobjects.
    boundaries = [0, *it.accumulate(length for _, length in runs)]
    submob_indices_lists = [
        list(range(start, stop))
        for start, stop in zip(boundaries[:-1], boundaries[1:])
    ]

    labelled_spans = self.labelled_spans

    def contains(outer_label: int, inner_label: int) -> bool:
        return self.span_contains(
            labelled_spans[outer_label], labelled_spans[inner_label]
        )

    neighbouring_labels = list(zip(group_labels[:-1], group_labels[1:]))

    # Items are (label, flag) pairs: flag 1 marks the start of a labelled
    # span, flag -1 its end. A group starts at its own span's start when
    # the previous group's span contains it, otherwise right where the
    # previous span ends.
    start_items = [
        (group_labels[0], 1),
        *(
            (curr_label, 1)
            if contains(prev_label, curr_label)
            else (prev_label, -1)
            for prev_label, curr_label in neighbouring_labels
        ),
    ]
    # Symmetrically, a group ends at its own span's end when the next
    # group's span contains it, otherwise where the next span starts.
    end_items = [
        *(
            (curr_label, -1)
            if contains(next_label, curr_label)
            else (next_label, 1)
            for curr_label, next_label in neighbouring_labels
        ),
        (group_labels[-1], -1),
    ]

    group_substrs = [
        re.sub(r"\s+", "", self.reconstruct_string(
            start_item, end_item,
            self.replace_for_matching,
            # Inserted label commands contribute nothing when matching.
            lambda label, flag, attr_dict: ""
        ))
        for start_item, end_item in zip(start_items, end_items)
    ]
    return list(zip(group_substrs, submob_indices_lists))
def get_submob_indices_lists_by_selector(
self, selector: Selector

View file

@ -247,7 +247,7 @@ class MarkupText(StringMobject):
<
(?P<close_slash>/)?
(?P<tag_name>\w+)\s*
(?P<attr_list>(?:\w+\s*\=\s*(?P<quot>["']).*?(?P=quot)\s*)*) # TODO: test wsp
(?P<attr_list>(?:\w+\s*\=\s*(?P<quot>["']).*?(?P=quot)\s*)*)
(?P<elision_slash>/)?
>
)