diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 1c3f0afd..cb62df9b 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -2,9 +2,10 @@ from __future__ import annotations from abc import ABC, abstractmethod import itertools as it -import numpy as np import re +import numpy as np + from manimlib.constants import WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup @@ -138,32 +139,16 @@ class LabelledString(SVGMobject, ABC): def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: return [index for index, _ in self.find_spans(pattern, **kwargs)] - @staticmethod - def is_single_selector(selector: Selector) -> bool: - if isinstance(selector, str): - return True - if isinstance(selector, re.Pattern): - return True - if isinstance(selector, tuple): - if len(selector) == 2 and all([ - isinstance(index, int) or index is None - for index in selector - ]): - return True - return False - def find_spans_by_selector(self, selector: Selector) -> list[Span]: - if self.is_single_selector(selector): - selector = (selector,) - result = [] - for sel in selector: - if not self.is_single_selector(sel): - raise TypeError(f"Invalid selector: '{sel}'") + def find_spans_by_single_selector(sel): if isinstance(sel, str): - spans = self.find_spans(re.escape(sel)) - elif isinstance(sel, re.Pattern): - spans = self.find_spans(sel) - else: + return self.find_spans(re.escape(sel)) + if isinstance(sel, re.Pattern): + return self.find_spans(sel) + if isinstance(sel, tuple) and len(sel) == 2 and all([ + isinstance(index, int) or index is None + for index in sel + ]): string_len = self.full_span[1] span = tuple([ ( @@ -174,8 +159,17 @@ class LabelledString(SVGMobject, ABC): if index is not None else default_index for index, default_index in zip(sel, self.full_span) ]) - spans = [span] - result.extend(spans) + return [span] + return None + + result = find_spans_by_single_selector(selector) + if result is None: + result = [] + for sel in selector: + spans = find_spans_by_single_selector(sel) + if spans is None: + raise TypeError(f"Invalid selector: '{sel}'") + result.extend(spans) return sorted(filter( lambda span: span[0] < span[1], self.remove_redundancies(result) @@ -206,8 +200,8 @@ class LabelledString(SVGMobject, ABC): unique_vals.append(val) indices.append(index) indices.append(len(vals)) - spans = LabelledString.get_neighbouring_pairs(indices) - return list(zip(unique_vals, spans)) + val_ranges = LabelledString.get_neighbouring_pairs(indices) + return list(zip(unique_vals, val_ranges)) @staticmethod def span_contains(span_0: Span, span_1: Span) -> bool: @@ -233,26 +227,23 @@ class LabelledString(SVGMobject, ABC): if not inserted_string_pairs: return [] - spans = [ - span for span, _ in inserted_string_pairs - ] - sorted_index_flag_pairs = sorted( - it.product(range(len(spans)), range(2)), - key=lambda t: ( - spans[t[0]][t[1]], - np.sign(spans[t[0]][1 - t[1]] - spans[t[0]][t[1]]), - -spans[t[0]][1 - t[1]], - t[1], - (1, -1)[t[1]] * t[0] + indices, *_, inserted_strings = zip(*sorted([ + ( + span[flag], + np.sign(span[1 - flag] - span[flag]), + -span[1 - flag], + flag, + (1, -1)[flag] * item_index, + str_pair[flag] ) - ) - indices, inserted_strings = zip(*[ - list(zip(*inserted_string_pairs[item_index]))[flag] - for item_index, flag in sorted_index_flag_pairs - ]) + for item_index, (span, str_pair) in enumerate( + inserted_string_pairs + ) + for flag in range(2) + ])) return [ - (index, "".join(inserted_strings[slice(*item_span)])) - for index, item_span + (index, "".join(inserted_strings[slice(*index_range)])) + for index, index_range in LabelledString.compress_neighbours(indices) ] @@ -262,8 +253,7 @@ class LabelledString(SVGMobject, ABC): if not repl_items: return self.get_substr(span) - sorted_repl_items = sorted(repl_items, key=lambda t: t[0]) - repl_spans, repl_strs = zip(*sorted_repl_items) + repl_spans, repl_strs = zip(*sorted(repl_items)) pieces = [ self.get_substr(piece_span) for piece_span in self.get_complement_spans(repl_spans, span) @@ -335,7 +325,7 @@ class LabelledString(SVGMobject, ABC): return [] labels, labelled_submobjects = zip(*self.labelled_submobject_items) - group_labels, labelled_submob_spans = zip( + group_labels, labelled_submob_ranges = zip( *self.compress_neighbours(labels) ) ordered_spans = [ @@ -362,8 +352,8 @@ class LabelledString(SVGMobject, ABC): ) ] submob_groups = VGroup(*[ - VGroup(*labelled_submobjects[slice(*submob_span)]) - for submob_span in labelled_submob_spans + VGroup(*labelled_submobjects[slice(*submob_range)]) + for submob_range in labelled_submob_ranges ]) return list(zip(group_substrs, submob_groups)) @@ -377,13 +367,9 @@ class LabelledString(SVGMobject, ABC): ] def select_part_by_span(self, custom_span: Span) -> VGroup: - labels = [ - label for label, span in enumerate(self.label_span_list) - if self.span_contains(custom_span, span) - ] return VGroup(*[ submob for label, submob in self.labelled_submobject_items - if label in labels + if self.span_contains(custom_span, self.label_span_list[label]) ]) def select_parts(self, selector: Selector) -> VGroup: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index ed7273ee..3edb10b1 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -209,17 +209,19 @@ class MTex(LabelledString): # Match paired double braces (`{{...}}`). sorted_brace_spans = sorted(self.brace_spans, key=lambda t: t[1]) inner_brace_spans = [ - sorted_brace_spans[span_span[0]] - for _, span_span in self.compress_neighbours([ - (brace_span[0] + index, brace_span[1] - index) - for index, brace_span in enumerate(sorted_brace_spans) + sorted_brace_spans[range_begin] + for _, (range_begin, range_end) in self.compress_neighbours([ + (span_begin + index, span_end - index) + for index, (span_begin, span_end) in enumerate( + sorted_brace_spans + ) ]) - if span_span[1] - span_span[0] >= 2 + if range_end - range_begin >= 2 ] inner_brace_content_spans = [ - (span[0] + 1, span[1] - 1) - for span in inner_brace_spans - if span[1] - span[0] > 2 + (span_begin + 1, span_end - 1) + for span_begin, span_end in inner_brace_spans + if span_end - span_begin > 2 ] result = self.remove_redundancies(self.chain( @@ -303,12 +305,14 @@ class MTex(LabelledString): # Selector def get_cleaned_substr(self, span: Span) -> str: - if not self.brace_spans: - brace_begins, brace_ends = [], [] - else: - brace_begins, brace_ends = zip(*self.brace_spans) - left_brace_indices = list(brace_begins) - right_brace_indices = [index - 1 for index in brace_ends] + left_brace_indices = [ + span_begin + for span_begin, _ in self.brace_spans + ] + right_brace_indices = [ + span_end - 1 + for _, span_end in self.brace_spans + ] skippable_indices = self.chain( self.find_indices(r"\s"), self.script_char_indices, diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 740eeee6..0fe113ff 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -300,18 +300,17 @@ class MarkupText(LabelledString): return result def get_tag_spans(self) -> list[Span]: - return [ - tag_span - for begin_tag, end_tag, _ in self.tag_pairs_from_markup - for tag_span in (begin_tag, end_tag) - ] + return self.chain( + (begin_tag_span, end_tag_span) + for begin_tag_span, end_tag_span, _ in self.tag_pairs_from_markup + ) - def get_items_from_markup(self) -> list[Span]: + def get_items_from_markup(self) -> list[tuple[Span, dict[str, str]]]: return [ - ((begin_tag_span[1], end_tag_span[0]), attr_dict) - for begin_tag_span, end_tag_span, attr_dict + ((span_begin, span_end), attr_dict) + for (_, span_begin), (span_end, _), attr_dict in self.tag_pairs_from_markup - if begin_tag_span[1] < end_tag_span[0] + if span_begin < span_end ] def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]: diff --git a/manimlib/scene/interactive_scene.py b/manimlib/scene/interactive_scene.py index 40ec18c5..3b7397b8 100644 --- a/manimlib/scene/interactive_scene.py +++ b/manimlib/scene/interactive_scene.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools as it import numpy as np import pyperclip