mirror of
https://github.com/3b1b/manim.git
synced 2025-08-19 13:01:00 +00:00
Some refactors
This commit is contained in:
parent
69db53d612
commit
065900c6ac
4 changed files with 71 additions and 80 deletions
|
@ -2,9 +2,10 @@ from __future__ import annotations
|
|||
|
||||
from abc import ABC, abstractmethod
|
||||
import itertools as it
|
||||
import numpy as np
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
from manimlib.constants import WHITE
|
||||
from manimlib.mobject.svg.svg_mobject import SVGMobject
|
||||
from manimlib.mobject.types.vectorized_mobject import VGroup
|
||||
|
@ -138,32 +139,16 @@ class LabelledString(SVGMobject, ABC):
|
|||
def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]:
|
||||
return [index for index, _ in self.find_spans(pattern, **kwargs)]
|
||||
|
||||
@staticmethod
|
||||
def is_single_selector(selector: Selector) -> bool:
|
||||
if isinstance(selector, str):
|
||||
return True
|
||||
if isinstance(selector, re.Pattern):
|
||||
return True
|
||||
if isinstance(selector, tuple):
|
||||
if len(selector) == 2 and all([
|
||||
isinstance(index, int) or index is None
|
||||
for index in selector
|
||||
]):
|
||||
return True
|
||||
return False
|
||||
|
||||
def find_spans_by_selector(self, selector: Selector) -> list[Span]:
|
||||
if self.is_single_selector(selector):
|
||||
selector = (selector,)
|
||||
result = []
|
||||
for sel in selector:
|
||||
if not self.is_single_selector(sel):
|
||||
raise TypeError(f"Invalid selector: '{sel}'")
|
||||
def find_spans_by_single_selector(sel):
|
||||
if isinstance(sel, str):
|
||||
spans = self.find_spans(re.escape(sel))
|
||||
elif isinstance(sel, re.Pattern):
|
||||
spans = self.find_spans(sel)
|
||||
else:
|
||||
return self.find_spans(re.escape(sel))
|
||||
if isinstance(sel, re.Pattern):
|
||||
return self.find_spans(sel)
|
||||
if isinstance(sel, tuple) and len(sel) == 2 and all([
|
||||
isinstance(index, int) or index is None
|
||||
for index in sel
|
||||
]):
|
||||
string_len = self.full_span[1]
|
||||
span = tuple([
|
||||
(
|
||||
|
@ -174,8 +159,17 @@ class LabelledString(SVGMobject, ABC):
|
|||
if index is not None else default_index
|
||||
for index, default_index in zip(sel, self.full_span)
|
||||
])
|
||||
spans = [span]
|
||||
result.extend(spans)
|
||||
return [span]
|
||||
return None
|
||||
|
||||
result = find_spans_by_single_selector(selector)
|
||||
if result is None:
|
||||
result = []
|
||||
for sel in selector:
|
||||
spans = find_spans_by_single_selector(sel)
|
||||
if spans is None:
|
||||
raise TypeError(f"Invalid selector: '{sel}'")
|
||||
result.extend(spans)
|
||||
return sorted(filter(
|
||||
lambda span: span[0] < span[1],
|
||||
self.remove_redundancies(result)
|
||||
|
@ -206,8 +200,8 @@ class LabelledString(SVGMobject, ABC):
|
|||
unique_vals.append(val)
|
||||
indices.append(index)
|
||||
indices.append(len(vals))
|
||||
spans = LabelledString.get_neighbouring_pairs(indices)
|
||||
return list(zip(unique_vals, spans))
|
||||
val_ranges = LabelledString.get_neighbouring_pairs(indices)
|
||||
return list(zip(unique_vals, val_ranges))
|
||||
|
||||
@staticmethod
|
||||
def span_contains(span_0: Span, span_1: Span) -> bool:
|
||||
|
@ -233,26 +227,23 @@ class LabelledString(SVGMobject, ABC):
|
|||
if not inserted_string_pairs:
|
||||
return []
|
||||
|
||||
spans = [
|
||||
span for span, _ in inserted_string_pairs
|
||||
]
|
||||
sorted_index_flag_pairs = sorted(
|
||||
it.product(range(len(spans)), range(2)),
|
||||
key=lambda t: (
|
||||
spans[t[0]][t[1]],
|
||||
np.sign(spans[t[0]][1 - t[1]] - spans[t[0]][t[1]]),
|
||||
-spans[t[0]][1 - t[1]],
|
||||
t[1],
|
||||
(1, -1)[t[1]] * t[0]
|
||||
indices, *_, inserted_strings = zip(*sorted([
|
||||
(
|
||||
span[flag],
|
||||
np.sign(span[1 - flag] - span[flag]),
|
||||
-span[1 - flag],
|
||||
flag,
|
||||
(1, -1)[flag] * item_index,
|
||||
str_pair[flag]
|
||||
)
|
||||
)
|
||||
indices, inserted_strings = zip(*[
|
||||
list(zip(*inserted_string_pairs[item_index]))[flag]
|
||||
for item_index, flag in sorted_index_flag_pairs
|
||||
])
|
||||
for item_index, (span, str_pair) in enumerate(
|
||||
inserted_string_pairs
|
||||
)
|
||||
for flag in range(2)
|
||||
]))
|
||||
return [
|
||||
(index, "".join(inserted_strings[slice(*item_span)]))
|
||||
for index, item_span
|
||||
(index, "".join(inserted_strings[slice(*index_range)]))
|
||||
for index, index_range
|
||||
in LabelledString.compress_neighbours(indices)
|
||||
]
|
||||
|
||||
|
@ -262,8 +253,7 @@ class LabelledString(SVGMobject, ABC):
|
|||
if not repl_items:
|
||||
return self.get_substr(span)
|
||||
|
||||
sorted_repl_items = sorted(repl_items, key=lambda t: t[0])
|
||||
repl_spans, repl_strs = zip(*sorted_repl_items)
|
||||
repl_spans, repl_strs = zip(*sorted(repl_items))
|
||||
pieces = [
|
||||
self.get_substr(piece_span)
|
||||
for piece_span in self.get_complement_spans(repl_spans, span)
|
||||
|
@ -335,7 +325,7 @@ class LabelledString(SVGMobject, ABC):
|
|||
return []
|
||||
|
||||
labels, labelled_submobjects = zip(*self.labelled_submobject_items)
|
||||
group_labels, labelled_submob_spans = zip(
|
||||
group_labels, labelled_submob_ranges = zip(
|
||||
*self.compress_neighbours(labels)
|
||||
)
|
||||
ordered_spans = [
|
||||
|
@ -362,8 +352,8 @@ class LabelledString(SVGMobject, ABC):
|
|||
)
|
||||
]
|
||||
submob_groups = VGroup(*[
|
||||
VGroup(*labelled_submobjects[slice(*submob_span)])
|
||||
for submob_span in labelled_submob_spans
|
||||
VGroup(*labelled_submobjects[slice(*submob_range)])
|
||||
for submob_range in labelled_submob_ranges
|
||||
])
|
||||
return list(zip(group_substrs, submob_groups))
|
||||
|
||||
|
@ -377,13 +367,9 @@ class LabelledString(SVGMobject, ABC):
|
|||
]
|
||||
|
||||
def select_part_by_span(self, custom_span: Span) -> VGroup:
|
||||
labels = [
|
||||
label for label, span in enumerate(self.label_span_list)
|
||||
if self.span_contains(custom_span, span)
|
||||
]
|
||||
return VGroup(*[
|
||||
submob for label, submob in self.labelled_submobject_items
|
||||
if label in labels
|
||||
if self.span_contains(custom_span, self.label_span_list[label])
|
||||
])
|
||||
|
||||
def select_parts(self, selector: Selector) -> VGroup:
|
||||
|
|
|
@ -209,17 +209,19 @@ class MTex(LabelledString):
|
|||
# Match paired double braces (`{{...}}`).
|
||||
sorted_brace_spans = sorted(self.brace_spans, key=lambda t: t[1])
|
||||
inner_brace_spans = [
|
||||
sorted_brace_spans[span_span[0]]
|
||||
for _, span_span in self.compress_neighbours([
|
||||
(brace_span[0] + index, brace_span[1] - index)
|
||||
for index, brace_span in enumerate(sorted_brace_spans)
|
||||
sorted_brace_spans[range_begin]
|
||||
for _, (range_begin, range_end) in self.compress_neighbours([
|
||||
(span_begin + index, span_end - index)
|
||||
for index, (span_begin, span_end) in enumerate(
|
||||
sorted_brace_spans
|
||||
)
|
||||
])
|
||||
if span_span[1] - span_span[0] >= 2
|
||||
if range_end - range_begin >= 2
|
||||
]
|
||||
inner_brace_content_spans = [
|
||||
(span[0] + 1, span[1] - 1)
|
||||
for span in inner_brace_spans
|
||||
if span[1] - span[0] > 2
|
||||
(span_begin + 1, span_end - 1)
|
||||
for span_begin, span_end in inner_brace_spans
|
||||
if span_end - span_begin > 2
|
||||
]
|
||||
|
||||
result = self.remove_redundancies(self.chain(
|
||||
|
@ -303,12 +305,14 @@ class MTex(LabelledString):
|
|||
# Selector
|
||||
|
||||
def get_cleaned_substr(self, span: Span) -> str:
|
||||
if not self.brace_spans:
|
||||
brace_begins, brace_ends = [], []
|
||||
else:
|
||||
brace_begins, brace_ends = zip(*self.brace_spans)
|
||||
left_brace_indices = list(brace_begins)
|
||||
right_brace_indices = [index - 1 for index in brace_ends]
|
||||
left_brace_indices = [
|
||||
span_begin
|
||||
for span_begin, _ in self.brace_spans
|
||||
]
|
||||
right_brace_indices = [
|
||||
span_end - 1
|
||||
for _, span_end in self.brace_spans
|
||||
]
|
||||
skippable_indices = self.chain(
|
||||
self.find_indices(r"\s"),
|
||||
self.script_char_indices,
|
||||
|
|
|
@ -300,18 +300,17 @@ class MarkupText(LabelledString):
|
|||
return result
|
||||
|
||||
def get_tag_spans(self) -> list[Span]:
|
||||
return [
|
||||
tag_span
|
||||
for begin_tag, end_tag, _ in self.tag_pairs_from_markup
|
||||
for tag_span in (begin_tag, end_tag)
|
||||
]
|
||||
return self.chain(
|
||||
(begin_tag_span, end_tag_span)
|
||||
for begin_tag_span, end_tag_span, _ in self.tag_pairs_from_markup
|
||||
)
|
||||
|
||||
def get_items_from_markup(self) -> list[Span]:
|
||||
def get_items_from_markup(self) -> list[tuple[Span, dict[str, str]]]:
|
||||
return [
|
||||
((begin_tag_span[1], end_tag_span[0]), attr_dict)
|
||||
for begin_tag_span, end_tag_span, attr_dict
|
||||
((span_begin, span_end), attr_dict)
|
||||
for (_, span_begin), (span_end, _), attr_dict
|
||||
in self.tag_pairs_from_markup
|
||||
if begin_tag_span[1] < end_tag_span[0]
|
||||
if span_begin < span_end
|
||||
]
|
||||
|
||||
def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]:
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import itertools as it
|
||||
import numpy as np
|
||||
import pyperclip
|
||||
|
|
Loading…
Add table
Reference in a new issue