Some refactors

This commit is contained in:
YishiMichael 2022-04-27 23:04:24 +08:00
parent 69db53d612
commit 065900c6ac
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
4 changed files with 71 additions and 80 deletions

View file

@ -2,9 +2,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import itertools as it import itertools as it
import numpy as np
import re import re
import numpy as np
from manimlib.constants import WHITE from manimlib.constants import WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
@ -138,32 +139,16 @@ class LabelledString(SVGMobject, ABC):
def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]:
return [index for index, _ in self.find_spans(pattern, **kwargs)] return [index for index, _ in self.find_spans(pattern, **kwargs)]
@staticmethod
def is_single_selector(selector: Selector) -> bool:
if isinstance(selector, str):
return True
if isinstance(selector, re.Pattern):
return True
if isinstance(selector, tuple):
if len(selector) == 2 and all([
isinstance(index, int) or index is None
for index in selector
]):
return True
return False
def find_spans_by_selector(self, selector: Selector) -> list[Span]: def find_spans_by_selector(self, selector: Selector) -> list[Span]:
if self.is_single_selector(selector): def find_spans_by_single_selector(sel):
selector = (selector,)
result = []
for sel in selector:
if not self.is_single_selector(sel):
raise TypeError(f"Invalid selector: '{sel}'")
if isinstance(sel, str): if isinstance(sel, str):
spans = self.find_spans(re.escape(sel)) return self.find_spans(re.escape(sel))
elif isinstance(sel, re.Pattern): if isinstance(sel, re.Pattern):
spans = self.find_spans(sel) return self.find_spans(sel)
else: if isinstance(sel, tuple) and len(sel) == 2 and all([
isinstance(index, int) or index is None
for index in sel
]):
string_len = self.full_span[1] string_len = self.full_span[1]
span = tuple([ span = tuple([
( (
@ -174,7 +159,16 @@ class LabelledString(SVGMobject, ABC):
if index is not None else default_index if index is not None else default_index
for index, default_index in zip(sel, self.full_span) for index, default_index in zip(sel, self.full_span)
]) ])
spans = [span] return [span]
return None
result = find_spans_by_single_selector(selector)
if result is None:
result = []
for sel in selector:
spans = find_spans_by_single_selector(sel)
if spans is None:
raise TypeError(f"Invalid selector: '{sel}'")
result.extend(spans) result.extend(spans)
return sorted(filter( return sorted(filter(
lambda span: span[0] < span[1], lambda span: span[0] < span[1],
@ -206,8 +200,8 @@ class LabelledString(SVGMobject, ABC):
unique_vals.append(val) unique_vals.append(val)
indices.append(index) indices.append(index)
indices.append(len(vals)) indices.append(len(vals))
spans = LabelledString.get_neighbouring_pairs(indices) val_ranges = LabelledString.get_neighbouring_pairs(indices)
return list(zip(unique_vals, spans)) return list(zip(unique_vals, val_ranges))
@staticmethod @staticmethod
def span_contains(span_0: Span, span_1: Span) -> bool: def span_contains(span_0: Span, span_1: Span) -> bool:
@ -233,26 +227,23 @@ class LabelledString(SVGMobject, ABC):
if not inserted_string_pairs: if not inserted_string_pairs:
return [] return []
spans = [ indices, *_, inserted_strings = zip(*sorted([
span for span, _ in inserted_string_pairs (
] span[flag],
sorted_index_flag_pairs = sorted( np.sign(span[1 - flag] - span[flag]),
it.product(range(len(spans)), range(2)), -span[1 - flag],
key=lambda t: ( flag,
spans[t[0]][t[1]], (1, -1)[flag] * item_index,
np.sign(spans[t[0]][1 - t[1]] - spans[t[0]][t[1]]), str_pair[flag]
-spans[t[0]][1 - t[1]],
t[1],
(1, -1)[t[1]] * t[0]
) )
for item_index, (span, str_pair) in enumerate(
inserted_string_pairs
) )
indices, inserted_strings = zip(*[ for flag in range(2)
list(zip(*inserted_string_pairs[item_index]))[flag] ]))
for item_index, flag in sorted_index_flag_pairs
])
return [ return [
(index, "".join(inserted_strings[slice(*item_span)])) (index, "".join(inserted_strings[slice(*index_range)]))
for index, item_span for index, index_range
in LabelledString.compress_neighbours(indices) in LabelledString.compress_neighbours(indices)
] ]
@ -262,8 +253,7 @@ class LabelledString(SVGMobject, ABC):
if not repl_items: if not repl_items:
return self.get_substr(span) return self.get_substr(span)
sorted_repl_items = sorted(repl_items, key=lambda t: t[0]) repl_spans, repl_strs = zip(*sorted(repl_items))
repl_spans, repl_strs = zip(*sorted_repl_items)
pieces = [ pieces = [
self.get_substr(piece_span) self.get_substr(piece_span)
for piece_span in self.get_complement_spans(repl_spans, span) for piece_span in self.get_complement_spans(repl_spans, span)
@ -335,7 +325,7 @@ class LabelledString(SVGMobject, ABC):
return [] return []
labels, labelled_submobjects = zip(*self.labelled_submobject_items) labels, labelled_submobjects = zip(*self.labelled_submobject_items)
group_labels, labelled_submob_spans = zip( group_labels, labelled_submob_ranges = zip(
*self.compress_neighbours(labels) *self.compress_neighbours(labels)
) )
ordered_spans = [ ordered_spans = [
@ -362,8 +352,8 @@ class LabelledString(SVGMobject, ABC):
) )
] ]
submob_groups = VGroup(*[ submob_groups = VGroup(*[
VGroup(*labelled_submobjects[slice(*submob_span)]) VGroup(*labelled_submobjects[slice(*submob_range)])
for submob_span in labelled_submob_spans for submob_range in labelled_submob_ranges
]) ])
return list(zip(group_substrs, submob_groups)) return list(zip(group_substrs, submob_groups))
@ -377,13 +367,9 @@ class LabelledString(SVGMobject, ABC):
] ]
def select_part_by_span(self, custom_span: Span) -> VGroup: def select_part_by_span(self, custom_span: Span) -> VGroup:
labels = [
label for label, span in enumerate(self.label_span_list)
if self.span_contains(custom_span, span)
]
return VGroup(*[ return VGroup(*[
submob for label, submob in self.labelled_submobject_items submob for label, submob in self.labelled_submobject_items
if label in labels if self.span_contains(custom_span, self.label_span_list[label])
]) ])
def select_parts(self, selector: Selector) -> VGroup: def select_parts(self, selector: Selector) -> VGroup:

View file

@ -209,17 +209,19 @@ class MTex(LabelledString):
# Match paired double braces (`{{...}}`). # Match paired double braces (`{{...}}`).
sorted_brace_spans = sorted(self.brace_spans, key=lambda t: t[1]) sorted_brace_spans = sorted(self.brace_spans, key=lambda t: t[1])
inner_brace_spans = [ inner_brace_spans = [
sorted_brace_spans[span_span[0]] sorted_brace_spans[range_begin]
for _, span_span in self.compress_neighbours([ for _, (range_begin, range_end) in self.compress_neighbours([
(brace_span[0] + index, brace_span[1] - index) (span_begin + index, span_end - index)
for index, brace_span in enumerate(sorted_brace_spans) for index, (span_begin, span_end) in enumerate(
sorted_brace_spans
)
]) ])
if span_span[1] - span_span[0] >= 2 if range_end - range_begin >= 2
] ]
inner_brace_content_spans = [ inner_brace_content_spans = [
(span[0] + 1, span[1] - 1) (span_begin + 1, span_end - 1)
for span in inner_brace_spans for span_begin, span_end in inner_brace_spans
if span[1] - span[0] > 2 if span_end - span_begin > 2
] ]
result = self.remove_redundancies(self.chain( result = self.remove_redundancies(self.chain(
@ -303,12 +305,14 @@ class MTex(LabelledString):
# Selector # Selector
def get_cleaned_substr(self, span: Span) -> str: def get_cleaned_substr(self, span: Span) -> str:
if not self.brace_spans: left_brace_indices = [
brace_begins, brace_ends = [], [] span_begin
else: for span_begin, _ in self.brace_spans
brace_begins, brace_ends = zip(*self.brace_spans) ]
left_brace_indices = list(brace_begins) right_brace_indices = [
right_brace_indices = [index - 1 for index in brace_ends] span_end - 1
for _, span_end in self.brace_spans
]
skippable_indices = self.chain( skippable_indices = self.chain(
self.find_indices(r"\s"), self.find_indices(r"\s"),
self.script_char_indices, self.script_char_indices,

View file

@ -300,18 +300,17 @@ class MarkupText(LabelledString):
return result return result
def get_tag_spans(self) -> list[Span]: def get_tag_spans(self) -> list[Span]:
return [ return self.chain(
tag_span (begin_tag_span, end_tag_span)
for begin_tag, end_tag, _ in self.tag_pairs_from_markup for begin_tag_span, end_tag_span, _ in self.tag_pairs_from_markup
for tag_span in (begin_tag, end_tag) )
]
def get_items_from_markup(self) -> list[Span]: def get_items_from_markup(self) -> list[tuple[Span, dict[str, str]]]:
return [ return [
((begin_tag_span[1], end_tag_span[0]), attr_dict) ((span_begin, span_end), attr_dict)
for begin_tag_span, end_tag_span, attr_dict for (_, span_begin), (span_end, _), attr_dict
in self.tag_pairs_from_markup in self.tag_pairs_from_markup
if begin_tag_span[1] < end_tag_span[0] if span_begin < span_end
] ]
def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]: def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]:

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import itertools as it import itertools as it
import numpy as np import numpy as np
import pyperclip import pyperclip