Some refactors

This commit is contained in:
YishiMichael 2022-04-27 23:04:24 +08:00
parent 69db53d612
commit 065900c6ac
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
4 changed files with 71 additions and 80 deletions

View file

@ -2,9 +2,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod
import itertools as it
import numpy as np
import re
import numpy as np
from manimlib.constants import WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
@ -138,32 +139,16 @@ class LabelledString(SVGMobject, ABC):
def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]:
return [index for index, _ in self.find_spans(pattern, **kwargs)]
@staticmethod
def is_single_selector(selector: Selector) -> bool:
if isinstance(selector, str):
return True
if isinstance(selector, re.Pattern):
return True
if isinstance(selector, tuple):
if len(selector) == 2 and all([
isinstance(index, int) or index is None
for index in selector
]):
return True
return False
def find_spans_by_selector(self, selector: Selector) -> list[Span]:
if self.is_single_selector(selector):
selector = (selector,)
result = []
for sel in selector:
if not self.is_single_selector(sel):
raise TypeError(f"Invalid selector: '{sel}'")
def find_spans_by_single_selector(sel):
if isinstance(sel, str):
spans = self.find_spans(re.escape(sel))
elif isinstance(sel, re.Pattern):
spans = self.find_spans(sel)
else:
return self.find_spans(re.escape(sel))
if isinstance(sel, re.Pattern):
return self.find_spans(sel)
if isinstance(sel, tuple) and len(sel) == 2 and all([
isinstance(index, int) or index is None
for index in sel
]):
string_len = self.full_span[1]
span = tuple([
(
@ -174,8 +159,17 @@ class LabelledString(SVGMobject, ABC):
if index is not None else default_index
for index, default_index in zip(sel, self.full_span)
])
spans = [span]
result.extend(spans)
return [span]
return None
result = find_spans_by_single_selector(selector)
if result is None:
result = []
for sel in selector:
spans = find_spans_by_single_selector(sel)
if spans is None:
raise TypeError(f"Invalid selector: '{sel}'")
result.extend(spans)
return sorted(filter(
lambda span: span[0] < span[1],
self.remove_redundancies(result)
@ -206,8 +200,8 @@ class LabelledString(SVGMobject, ABC):
unique_vals.append(val)
indices.append(index)
indices.append(len(vals))
spans = LabelledString.get_neighbouring_pairs(indices)
return list(zip(unique_vals, spans))
val_ranges = LabelledString.get_neighbouring_pairs(indices)
return list(zip(unique_vals, val_ranges))
@staticmethod
def span_contains(span_0: Span, span_1: Span) -> bool:
@ -233,26 +227,23 @@ class LabelledString(SVGMobject, ABC):
if not inserted_string_pairs:
return []
spans = [
span for span, _ in inserted_string_pairs
]
sorted_index_flag_pairs = sorted(
it.product(range(len(spans)), range(2)),
key=lambda t: (
spans[t[0]][t[1]],
np.sign(spans[t[0]][1 - t[1]] - spans[t[0]][t[1]]),
-spans[t[0]][1 - t[1]],
t[1],
(1, -1)[t[1]] * t[0]
indices, *_, inserted_strings = zip(*sorted([
(
span[flag],
np.sign(span[1 - flag] - span[flag]),
-span[1 - flag],
flag,
(1, -1)[flag] * item_index,
str_pair[flag]
)
)
indices, inserted_strings = zip(*[
list(zip(*inserted_string_pairs[item_index]))[flag]
for item_index, flag in sorted_index_flag_pairs
])
for item_index, (span, str_pair) in enumerate(
inserted_string_pairs
)
for flag in range(2)
]))
return [
(index, "".join(inserted_strings[slice(*item_span)]))
for index, item_span
(index, "".join(inserted_strings[slice(*index_range)]))
for index, index_range
in LabelledString.compress_neighbours(indices)
]
@ -262,8 +253,7 @@ class LabelledString(SVGMobject, ABC):
if not repl_items:
return self.get_substr(span)
sorted_repl_items = sorted(repl_items, key=lambda t: t[0])
repl_spans, repl_strs = zip(*sorted_repl_items)
repl_spans, repl_strs = zip(*sorted(repl_items))
pieces = [
self.get_substr(piece_span)
for piece_span in self.get_complement_spans(repl_spans, span)
@ -335,7 +325,7 @@ class LabelledString(SVGMobject, ABC):
return []
labels, labelled_submobjects = zip(*self.labelled_submobject_items)
group_labels, labelled_submob_spans = zip(
group_labels, labelled_submob_ranges = zip(
*self.compress_neighbours(labels)
)
ordered_spans = [
@ -362,8 +352,8 @@ class LabelledString(SVGMobject, ABC):
)
]
submob_groups = VGroup(*[
VGroup(*labelled_submobjects[slice(*submob_span)])
for submob_span in labelled_submob_spans
VGroup(*labelled_submobjects[slice(*submob_range)])
for submob_range in labelled_submob_ranges
])
return list(zip(group_substrs, submob_groups))
@ -377,13 +367,9 @@ class LabelledString(SVGMobject, ABC):
]
def select_part_by_span(self, custom_span: Span) -> VGroup:
labels = [
label for label, span in enumerate(self.label_span_list)
if self.span_contains(custom_span, span)
]
return VGroup(*[
submob for label, submob in self.labelled_submobject_items
if label in labels
if self.span_contains(custom_span, self.label_span_list[label])
])
def select_parts(self, selector: Selector) -> VGroup:

View file

@ -209,17 +209,19 @@ class MTex(LabelledString):
# Match paired double braces (`{{...}}`).
sorted_brace_spans = sorted(self.brace_spans, key=lambda t: t[1])
inner_brace_spans = [
sorted_brace_spans[span_span[0]]
for _, span_span in self.compress_neighbours([
(brace_span[0] + index, brace_span[1] - index)
for index, brace_span in enumerate(sorted_brace_spans)
sorted_brace_spans[range_begin]
for _, (range_begin, range_end) in self.compress_neighbours([
(span_begin + index, span_end - index)
for index, (span_begin, span_end) in enumerate(
sorted_brace_spans
)
])
if span_span[1] - span_span[0] >= 2
if range_end - range_begin >= 2
]
inner_brace_content_spans = [
(span[0] + 1, span[1] - 1)
for span in inner_brace_spans
if span[1] - span[0] > 2
(span_begin + 1, span_end - 1)
for span_begin, span_end in inner_brace_spans
if span_end - span_begin > 2
]
result = self.remove_redundancies(self.chain(
@ -303,12 +305,14 @@ class MTex(LabelledString):
# Selector
def get_cleaned_substr(self, span: Span) -> str:
if not self.brace_spans:
brace_begins, brace_ends = [], []
else:
brace_begins, brace_ends = zip(*self.brace_spans)
left_brace_indices = list(brace_begins)
right_brace_indices = [index - 1 for index in brace_ends]
left_brace_indices = [
span_begin
for span_begin, _ in self.brace_spans
]
right_brace_indices = [
span_end - 1
for _, span_end in self.brace_spans
]
skippable_indices = self.chain(
self.find_indices(r"\s"),
self.script_char_indices,

View file

@ -300,18 +300,17 @@ class MarkupText(LabelledString):
return result
def get_tag_spans(self) -> list[Span]:
return [
tag_span
for begin_tag, end_tag, _ in self.tag_pairs_from_markup
for tag_span in (begin_tag, end_tag)
]
return self.chain(
(begin_tag_span, end_tag_span)
for begin_tag_span, end_tag_span, _ in self.tag_pairs_from_markup
)
def get_items_from_markup(self) -> list[Span]:
def get_items_from_markup(self) -> list[tuple[Span, dict[str, str]]]:
return [
((begin_tag_span[1], end_tag_span[0]), attr_dict)
for begin_tag_span, end_tag_span, attr_dict
((span_begin, span_end), attr_dict)
for (_, span_begin), (span_end, _), attr_dict
in self.tag_pairs_from_markup
if begin_tag_span[1] < end_tag_span[0]
if span_begin < span_end
]
def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]:

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import itertools as it
import numpy as np
import pyperclip