diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index b832cf0e..7e30d2ad 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -1,19 +1,15 @@ from __future__ import annotations import itertools as it - -import numpy as np +from difflib import SequenceMatcher from manimlib.animation.composition import AnimationGroup from manimlib.animation.fading import FadeInFromPoint from manimlib.animation.fading import FadeOutToPoint -from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject -from manimlib.mobject.mobject import Group -from manimlib.mobject.svg.string_mobject import StringMobject -from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject +from manimlib.mobject.svg.string_mobject import StringMobject from typing import TYPE_CHECKING @@ -131,28 +127,64 @@ class TransformMatchingStrings(TransformMatchingParts): target: StringMobject, matched_keys: Iterable[str] = [], key_map: dict[str, str] = dict(), - matched_pairs: Iterable[tuple[Mobject, Mobject]] = [], + matched_pairs: Iterable[tuple[VMobject, VMobject]] = [], **kwargs, ): - matched_pairs = list(matched_pairs) + [ - *[(source[key], target[key]) for key in matched_keys], - *[(source[key1], target[key2]) for key1, key2 in key_map.items()], - *[ - (source[substr], target[substr]) - for substr in [ - *source.get_specified_substrings(), - *target.get_specified_substrings(), - *source.get_symbol_substrings(), - *target.get_symbol_substrings(), - ] - ] + matched_pairs = [ + *matched_pairs, + *self.matching_blocks(source, target, matched_keys, key_map), ] + super().__init__( source, target, matched_pairs=matched_pairs, **kwargs, ) + def matching_blocks( + self, + source: StringMobject, + target: StringMobject, + matched_keys: Iterable[str], + key_map: dict[str, str] + ) -> list[tuple[VMobject, VMobject]]: + syms1 = source.get_symbol_substrings() + syms2 = target.get_symbol_substrings() + counts1 = list(map(source.substr_to_path_count, syms1)) + counts2 = list(map(target.substr_to_path_count, syms2)) + + # Start with user specified matches + blocks = [(source[key], target[key]) for key in matched_keys] + blocks += [(source[key1], target[key2]) for key1, key2 in key_map.items()] + + # Nullify any intersections with those matches in the two symbol lists + for sub_source, sub_target in blocks: + for i in range(len(syms1)): + if source[i] in sub_source.family_members_with_points(): + syms1[i] = "Null1" + for j in range(len(syms2)): + if target[j] in sub_target.family_members_with_points(): + syms2[j] = "Null2" + + # Group together longest matching substrings + while True: + matcher = SequenceMatcher(None, syms1, syms2) + match = matcher.find_longest_match(0, len(syms1), 0, len(syms2)) + if match.size == 0: + break + + i1 = sum(counts1[:match.a]) + i2 = sum(counts2[:match.b]) + size = sum(counts1[match.a:match.a + match.size]) + + blocks.append((source[i1:i1 + size], target[i2:i2 + size])) + + for i in range(match.size): + syms1[match.a + i] = "Null1" + syms2[match.b + i] = "Null2" + + return blocks + class TransformMatchingTex(TransformMatchingStrings): """Alias for TransformMatchingStrings"""