Improve TransformMatchingString to match longest common substrings by default

This commit is contained in:
Grant Sanderson 2023-02-03 17:28:27 -08:00
parent b8fe7b0172
commit fab917ccee

View file

@ -1,19 +1,15 @@
from __future__ import annotations
import itertools as it
import numpy as np
from difflib import SequenceMatcher
from manimlib.animation.composition import AnimationGroup
from manimlib.animation.fading import FadeInFromPoint
from manimlib.animation.fading import FadeOutToPoint
from manimlib.animation.fading import FadeTransformPieces
from manimlib.animation.transform import Transform
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.svg.string_mobject import StringMobject
from typing import TYPE_CHECKING
@ -131,28 +127,64 @@ class TransformMatchingStrings(TransformMatchingParts):
target: StringMobject,
matched_keys: Iterable[str] = [],
key_map: dict[str, str] = dict(),
matched_pairs: Iterable[tuple[Mobject, Mobject]] = [],
matched_pairs: Iterable[tuple[VMobject, VMobject]] = [],
**kwargs,
):
matched_pairs = list(matched_pairs) + [
*[(source[key], target[key]) for key in matched_keys],
*[(source[key1], target[key2]) for key1, key2 in key_map.items()],
*[
(source[substr], target[substr])
for substr in [
*source.get_specified_substrings(),
*target.get_specified_substrings(),
*source.get_symbol_substrings(),
*target.get_symbol_substrings(),
]
]
matched_pairs = [
*matched_pairs,
*self.matching_blocks(source, target, matched_keys, key_map),
]
super().__init__(
source, target,
matched_pairs=matched_pairs,
**kwargs,
)
def matching_blocks(
self,
source: StringMobject,
target: StringMobject,
matched_keys: Iterable[str],
key_map: dict[str, str]
) -> list[tuple[VMobject, VMobject]]:
syms1 = source.get_symbol_substrings()
syms2 = target.get_symbol_substrings()
counts1 = list(map(source.substr_to_path_count, syms1))
counts2 = list(map(target.substr_to_path_count, syms2))
# Start with user specified matches
blocks = [(source[key], target[key]) for key in matched_keys]
blocks += [(source[key1], target[key2]) for key1, key2 in key_map.items()]
# Nullify any intersections with those matches in the two symbol lists
for sub_source, sub_target in blocks:
for i in range(len(syms1)):
if source[i] in sub_source.family_members_with_points():
syms1[i] = "Null1"
for j in range(len(syms2)):
if target[j] in sub_target.family_members_with_points():
syms2[j] = "Null2"
# Group together longest matching substrings
while True:
matcher = SequenceMatcher(None, syms1, syms2)
match = matcher.find_longest_match(0, len(syms1), 0, len(syms2))
if match.size == 0:
break
i1 = sum(counts1[:match.a])
i2 = sum(counts2[:match.b])
size = sum(counts1[match.a:match.a + match.size])
blocks.append((source[i1:i1 + size], target[i2:i2 + size]))
for i in range(match.size):
syms1[match.a + i] = "Null1"
syms2[match.b + i] = "Null2"
return blocks
class TransformMatchingTex(TransformMatchingStrings):
"""Alias for TransformMatchingStrings"""