mirror of
https://github.com/3b1b/manim.git
synced 2025-08-19 21:08:53 +00:00
Some refactors
This commit is contained in:
parent
69db53d612
commit
065900c6ac
4 changed files with 71 additions and 80 deletions
|
@ -2,9 +2,10 @@ from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import itertools as it
|
import itertools as it
|
||||||
import numpy as np
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from manimlib.constants import WHITE
|
from manimlib.constants import WHITE
|
||||||
from manimlib.mobject.svg.svg_mobject import SVGMobject
|
from manimlib.mobject.svg.svg_mobject import SVGMobject
|
||||||
from manimlib.mobject.types.vectorized_mobject import VGroup
|
from manimlib.mobject.types.vectorized_mobject import VGroup
|
||||||
|
@ -138,32 +139,16 @@ class LabelledString(SVGMobject, ABC):
|
||||||
def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]:
|
def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]:
|
||||||
return [index for index, _ in self.find_spans(pattern, **kwargs)]
|
return [index for index, _ in self.find_spans(pattern, **kwargs)]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_single_selector(selector: Selector) -> bool:
|
|
||||||
if isinstance(selector, str):
|
|
||||||
return True
|
|
||||||
if isinstance(selector, re.Pattern):
|
|
||||||
return True
|
|
||||||
if isinstance(selector, tuple):
|
|
||||||
if len(selector) == 2 and all([
|
|
||||||
isinstance(index, int) or index is None
|
|
||||||
for index in selector
|
|
||||||
]):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def find_spans_by_selector(self, selector: Selector) -> list[Span]:
|
def find_spans_by_selector(self, selector: Selector) -> list[Span]:
|
||||||
if self.is_single_selector(selector):
|
def find_spans_by_single_selector(sel):
|
||||||
selector = (selector,)
|
|
||||||
result = []
|
|
||||||
for sel in selector:
|
|
||||||
if not self.is_single_selector(sel):
|
|
||||||
raise TypeError(f"Invalid selector: '{sel}'")
|
|
||||||
if isinstance(sel, str):
|
if isinstance(sel, str):
|
||||||
spans = self.find_spans(re.escape(sel))
|
return self.find_spans(re.escape(sel))
|
||||||
elif isinstance(sel, re.Pattern):
|
if isinstance(sel, re.Pattern):
|
||||||
spans = self.find_spans(sel)
|
return self.find_spans(sel)
|
||||||
else:
|
if isinstance(sel, tuple) and len(sel) == 2 and all([
|
||||||
|
isinstance(index, int) or index is None
|
||||||
|
for index in sel
|
||||||
|
]):
|
||||||
string_len = self.full_span[1]
|
string_len = self.full_span[1]
|
||||||
span = tuple([
|
span = tuple([
|
||||||
(
|
(
|
||||||
|
@ -174,7 +159,16 @@ class LabelledString(SVGMobject, ABC):
|
||||||
if index is not None else default_index
|
if index is not None else default_index
|
||||||
for index, default_index in zip(sel, self.full_span)
|
for index, default_index in zip(sel, self.full_span)
|
||||||
])
|
])
|
||||||
spans = [span]
|
return [span]
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = find_spans_by_single_selector(selector)
|
||||||
|
if result is None:
|
||||||
|
result = []
|
||||||
|
for sel in selector:
|
||||||
|
spans = find_spans_by_single_selector(sel)
|
||||||
|
if spans is None:
|
||||||
|
raise TypeError(f"Invalid selector: '{sel}'")
|
||||||
result.extend(spans)
|
result.extend(spans)
|
||||||
return sorted(filter(
|
return sorted(filter(
|
||||||
lambda span: span[0] < span[1],
|
lambda span: span[0] < span[1],
|
||||||
|
@ -206,8 +200,8 @@ class LabelledString(SVGMobject, ABC):
|
||||||
unique_vals.append(val)
|
unique_vals.append(val)
|
||||||
indices.append(index)
|
indices.append(index)
|
||||||
indices.append(len(vals))
|
indices.append(len(vals))
|
||||||
spans = LabelledString.get_neighbouring_pairs(indices)
|
val_ranges = LabelledString.get_neighbouring_pairs(indices)
|
||||||
return list(zip(unique_vals, spans))
|
return list(zip(unique_vals, val_ranges))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def span_contains(span_0: Span, span_1: Span) -> bool:
|
def span_contains(span_0: Span, span_1: Span) -> bool:
|
||||||
|
@ -233,26 +227,23 @@ class LabelledString(SVGMobject, ABC):
|
||||||
if not inserted_string_pairs:
|
if not inserted_string_pairs:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
spans = [
|
indices, *_, inserted_strings = zip(*sorted([
|
||||||
span for span, _ in inserted_string_pairs
|
(
|
||||||
]
|
span[flag],
|
||||||
sorted_index_flag_pairs = sorted(
|
np.sign(span[1 - flag] - span[flag]),
|
||||||
it.product(range(len(spans)), range(2)),
|
-span[1 - flag],
|
||||||
key=lambda t: (
|
flag,
|
||||||
spans[t[0]][t[1]],
|
(1, -1)[flag] * item_index,
|
||||||
np.sign(spans[t[0]][1 - t[1]] - spans[t[0]][t[1]]),
|
str_pair[flag]
|
||||||
-spans[t[0]][1 - t[1]],
|
|
||||||
t[1],
|
|
||||||
(1, -1)[t[1]] * t[0]
|
|
||||||
)
|
)
|
||||||
|
for item_index, (span, str_pair) in enumerate(
|
||||||
|
inserted_string_pairs
|
||||||
)
|
)
|
||||||
indices, inserted_strings = zip(*[
|
for flag in range(2)
|
||||||
list(zip(*inserted_string_pairs[item_index]))[flag]
|
]))
|
||||||
for item_index, flag in sorted_index_flag_pairs
|
|
||||||
])
|
|
||||||
return [
|
return [
|
||||||
(index, "".join(inserted_strings[slice(*item_span)]))
|
(index, "".join(inserted_strings[slice(*index_range)]))
|
||||||
for index, item_span
|
for index, index_range
|
||||||
in LabelledString.compress_neighbours(indices)
|
in LabelledString.compress_neighbours(indices)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -262,8 +253,7 @@ class LabelledString(SVGMobject, ABC):
|
||||||
if not repl_items:
|
if not repl_items:
|
||||||
return self.get_substr(span)
|
return self.get_substr(span)
|
||||||
|
|
||||||
sorted_repl_items = sorted(repl_items, key=lambda t: t[0])
|
repl_spans, repl_strs = zip(*sorted(repl_items))
|
||||||
repl_spans, repl_strs = zip(*sorted_repl_items)
|
|
||||||
pieces = [
|
pieces = [
|
||||||
self.get_substr(piece_span)
|
self.get_substr(piece_span)
|
||||||
for piece_span in self.get_complement_spans(repl_spans, span)
|
for piece_span in self.get_complement_spans(repl_spans, span)
|
||||||
|
@ -335,7 +325,7 @@ class LabelledString(SVGMobject, ABC):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
labels, labelled_submobjects = zip(*self.labelled_submobject_items)
|
labels, labelled_submobjects = zip(*self.labelled_submobject_items)
|
||||||
group_labels, labelled_submob_spans = zip(
|
group_labels, labelled_submob_ranges = zip(
|
||||||
*self.compress_neighbours(labels)
|
*self.compress_neighbours(labels)
|
||||||
)
|
)
|
||||||
ordered_spans = [
|
ordered_spans = [
|
||||||
|
@ -362,8 +352,8 @@ class LabelledString(SVGMobject, ABC):
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
submob_groups = VGroup(*[
|
submob_groups = VGroup(*[
|
||||||
VGroup(*labelled_submobjects[slice(*submob_span)])
|
VGroup(*labelled_submobjects[slice(*submob_range)])
|
||||||
for submob_span in labelled_submob_spans
|
for submob_range in labelled_submob_ranges
|
||||||
])
|
])
|
||||||
return list(zip(group_substrs, submob_groups))
|
return list(zip(group_substrs, submob_groups))
|
||||||
|
|
||||||
|
@ -377,13 +367,9 @@ class LabelledString(SVGMobject, ABC):
|
||||||
]
|
]
|
||||||
|
|
||||||
def select_part_by_span(self, custom_span: Span) -> VGroup:
|
def select_part_by_span(self, custom_span: Span) -> VGroup:
|
||||||
labels = [
|
|
||||||
label for label, span in enumerate(self.label_span_list)
|
|
||||||
if self.span_contains(custom_span, span)
|
|
||||||
]
|
|
||||||
return VGroup(*[
|
return VGroup(*[
|
||||||
submob for label, submob in self.labelled_submobject_items
|
submob for label, submob in self.labelled_submobject_items
|
||||||
if label in labels
|
if self.span_contains(custom_span, self.label_span_list[label])
|
||||||
])
|
])
|
||||||
|
|
||||||
def select_parts(self, selector: Selector) -> VGroup:
|
def select_parts(self, selector: Selector) -> VGroup:
|
||||||
|
|
|
@ -209,17 +209,19 @@ class MTex(LabelledString):
|
||||||
# Match paired double braces (`{{...}}`).
|
# Match paired double braces (`{{...}}`).
|
||||||
sorted_brace_spans = sorted(self.brace_spans, key=lambda t: t[1])
|
sorted_brace_spans = sorted(self.brace_spans, key=lambda t: t[1])
|
||||||
inner_brace_spans = [
|
inner_brace_spans = [
|
||||||
sorted_brace_spans[span_span[0]]
|
sorted_brace_spans[range_begin]
|
||||||
for _, span_span in self.compress_neighbours([
|
for _, (range_begin, range_end) in self.compress_neighbours([
|
||||||
(brace_span[0] + index, brace_span[1] - index)
|
(span_begin + index, span_end - index)
|
||||||
for index, brace_span in enumerate(sorted_brace_spans)
|
for index, (span_begin, span_end) in enumerate(
|
||||||
|
sorted_brace_spans
|
||||||
|
)
|
||||||
])
|
])
|
||||||
if span_span[1] - span_span[0] >= 2
|
if range_end - range_begin >= 2
|
||||||
]
|
]
|
||||||
inner_brace_content_spans = [
|
inner_brace_content_spans = [
|
||||||
(span[0] + 1, span[1] - 1)
|
(span_begin + 1, span_end - 1)
|
||||||
for span in inner_brace_spans
|
for span_begin, span_end in inner_brace_spans
|
||||||
if span[1] - span[0] > 2
|
if span_end - span_begin > 2
|
||||||
]
|
]
|
||||||
|
|
||||||
result = self.remove_redundancies(self.chain(
|
result = self.remove_redundancies(self.chain(
|
||||||
|
@ -303,12 +305,14 @@ class MTex(LabelledString):
|
||||||
# Selector
|
# Selector
|
||||||
|
|
||||||
def get_cleaned_substr(self, span: Span) -> str:
|
def get_cleaned_substr(self, span: Span) -> str:
|
||||||
if not self.brace_spans:
|
left_brace_indices = [
|
||||||
brace_begins, brace_ends = [], []
|
span_begin
|
||||||
else:
|
for span_begin, _ in self.brace_spans
|
||||||
brace_begins, brace_ends = zip(*self.brace_spans)
|
]
|
||||||
left_brace_indices = list(brace_begins)
|
right_brace_indices = [
|
||||||
right_brace_indices = [index - 1 for index in brace_ends]
|
span_end - 1
|
||||||
|
for _, span_end in self.brace_spans
|
||||||
|
]
|
||||||
skippable_indices = self.chain(
|
skippable_indices = self.chain(
|
||||||
self.find_indices(r"\s"),
|
self.find_indices(r"\s"),
|
||||||
self.script_char_indices,
|
self.script_char_indices,
|
||||||
|
|
|
@ -300,18 +300,17 @@ class MarkupText(LabelledString):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_tag_spans(self) -> list[Span]:
|
def get_tag_spans(self) -> list[Span]:
|
||||||
return [
|
return self.chain(
|
||||||
tag_span
|
(begin_tag_span, end_tag_span)
|
||||||
for begin_tag, end_tag, _ in self.tag_pairs_from_markup
|
for begin_tag_span, end_tag_span, _ in self.tag_pairs_from_markup
|
||||||
for tag_span in (begin_tag, end_tag)
|
)
|
||||||
]
|
|
||||||
|
|
||||||
def get_items_from_markup(self) -> list[Span]:
|
def get_items_from_markup(self) -> list[tuple[Span, dict[str, str]]]:
|
||||||
return [
|
return [
|
||||||
((begin_tag_span[1], end_tag_span[0]), attr_dict)
|
((span_begin, span_end), attr_dict)
|
||||||
for begin_tag_span, end_tag_span, attr_dict
|
for (_, span_begin), (span_end, _), attr_dict
|
||||||
in self.tag_pairs_from_markup
|
in self.tag_pairs_from_markup
|
||||||
if begin_tag_span[1] < end_tag_span[0]
|
if span_begin < span_end
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]:
|
def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]:
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import itertools as it
|
import itertools as it
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyperclip
|
import pyperclip
|
||||||
|
|
Loading…
Add table
Reference in a new issue