From d30a4f430dc6af1e6da5f1999bae601b1f0c6853 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Sun, 22 Aug 2021 14:57:08 -0700 Subject: [PATCH] Make RootCoefScene interactive --- _2021/quintic.py | 240 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 215 insertions(+), 25 deletions(-) diff --git a/_2021/quintic.py b/_2021/quintic.py index d1cbd1f..0ee9b9a 100644 --- a/_2021/quintic.py +++ b/_2021/quintic.py @@ -46,8 +46,8 @@ def coefficients_to_roots(coefs): class RootCoefScene(Scene): root_plane_config = { - "x_range": (-2, 2), - "y_range": (-2, 2), + "x_range": (-2.0, 2.0), + "y_range": (-2.0, 2.0), } coef_plane_config = { "x_range": (-4, 4), @@ -71,6 +71,10 @@ class RootCoefScene(Scene): "draw_stroke_behind_fill": True, } + def setup(self): + self.root_dots = VGroup() + self.coef_dots = VGroup() + def add_planes(self): # Planes planes = VGroup( @@ -126,17 +130,15 @@ class RootCoefScene(Scene): self.coef_plane = coef_plane self.root_plane_label = root_plane_label self.coef_plane_label = coef_plane_label + self.root_poly = root_poly + self.coef_poly = coef_poly self.poly_equal_sign = equals def get_r_symbols(self, root_poly): - return VGroup(*it.chain(*( - part[3:5] for part in root_poly - ))) + return VGroup(*(part[3:5] for part in root_poly)) def get_c_symbols(self, coef_poly): - return VGroup(*(it.chain(*( - part[1:3] for part in coef_poly[1:] - )))) + return VGroup(*(part[1:3] for part in coef_poly[:0:-1])) def get_random_root(self): return complex( @@ -150,36 +152,36 @@ class RootCoefScene(Scene): def get_roots_of_unity(self): return [np.exp(complex(0, TAU * n / self.degree)) for n in range(self.degree)] - def create_root_dots(self, roots): - return VGroup(*( + def set_roots(self, roots): + self.root_dots.set_submobjects( Dot( self.root_plane.n2p(root), color=self.root_color, **self.dot_style, ) for root in roots - )) + ) - def create_coef_dots(self, coefs): - return VGroup(*( + def set_coefs(self, coefs): + self.coef_dots.set_submobjects( Dot( self.coef_plane.n2p(coef), color=self.coef_color, **self.dot_style, ) for coef in coefs - )) + ) def add_root_dots(self, roots=None): if roots is None: roots = self.get_roots_of_unity() - self.root_dots = self.create_root_dots(roots) + self.set_roots(roots) self.add(self.root_dots) def add_coef_dots(self, coefs=None): if coefs is None: coefs = [0] * self.degree - self.coef_dots = self.create_coef_dots(coefs) + self.set_coefs(coefs) self.add(self.coef_dots) def get_roots(self): @@ -194,35 +196,223 @@ class RootCoefScene(Scene): for coef_dot in self.coef_dots ] - def tie_coefs_to_roots(self): + def tie_coefs_to_roots(self, clear_updaters=True): + if clear_updaters: + self.root_dots.clear_updaters() + self.coef_dots.clear_updaters() + def update_coef_dots(cdots): coefs = roots_to_coefficients(self.get_roots()) for dot, coef in zip(cdots, coefs): dot.move_to(self.coef_plane.n2p(coef)) self.coef_dots.add_updater(update_coef_dots) + self.add(self.coef_dots) + + def tie_roots_to_coefs(self, clear_updaters=True): + if clear_updaters: + self.root_dots.clear_updaters() + self.coef_dots.clear_updaters() - def tie_roots_to_coefs(self): def update_root_dots(rdots): - roots = coefficients_to_roots(self.get_coefs()) + old_roots = self.get_roots() + unordered_roots = coefficients_to_roots(self.get_coefs()) + # Sort them to match the old_roots + roots = [] + for old_root in old_roots: + root = unordered_roots[np.argmin([ + abs(old_root - ur) + for ur in unordered_roots + ])] + unordered_roots.remove(root) + roots.append(root) for dot, root in zip(rdots, roots): dot.move_to(self.root_plane.n2p(root)) self.root_dots.add_updater(update_root_dots) + self.add(self.root_dots) + + def add_tracers(self, time_traced=2.0, **kwargs): + self.tracers = VGroup(*( + TracingTail( + dot, + stroke_color=dot.get_fill_color(), + time_traced=time_traced, + **kwargs + ) + for dot in (*self.root_dots, *self.coef_dots) + )) + self.add(self.tracers) + + def get_tracking_lines(self, dots, syms, stroke_width=1, stroke_opacity=0.5): + lines = VGroup(*( + Line( + stroke_color=root.get_fill_color(), + stroke_width=stroke_width, + stroke_opacity=stroke_opacity, + ) + for root in dots + )) + + def update_lines(lines): + for sym, dot, line in zip(syms, dots, lines): + line.put_start_and_end_on( + sym.get_bottom(), + dot.get_center() + ) + + lines.add_updater(update_lines) + return lines + + def add_root_lines(self, **kwargs): + self.root_lines = self.get_tracking_lines( + self.root_dots, + self.get_r_symbols(self.root_poly), + **kwargs + ) + self.add(self.root_lines) + + def add_coef_lines(self, **kwargs): + self.coef_lines = self.get_tracking_lines( + self.coef_dots, + self.get_c_symbols(self.coef_poly), + **kwargs + ) + self.add(self.coef_lines) + + def add_dot_labels(self, sym, dots, buff=0.05): + labels = VGroup() + for i, dot in enumerate(dots): + label = Tex(f"{sym}_{i}", font_size=24) + label.set_fill(dot.get_fill_color()) + label.set_stroke(BLACK, 3, background=True) + label.dot = dot + label.add_updater(lambda m: m.next_to(m.dot, UR, buff=buff)) + labels.add(label) + self.add(labels) + return labels + + def add_r_labels(self): + self.r_dot_labels = self.add_dot_labels("r", self.root_dots) + + def add_c_labels(self): + self.c_dot_labels = self.add_dot_labels("c", self.coef_dots) + + def add_value_label(self): + pass # TODO + + # Animations + def play(self, *anims, **kwargs): + movers = list(it.chain(*(anim.mobject.get_family() for anim in anims))) + roots_move = any(rd in movers for rd in self.root_dots) + coefs_move = any(cd in movers for cd in self.coef_dots) + if roots_move and not coefs_move: + self.tie_coefs_to_roots() + elif coefs_move and not roots_move: + self.tie_roots_to_coefs() + super().play(*anims, **kwargs) + + def get_root_swap_arrows(self, i, j, + path_arc=90 * DEGREES, + stroke_width=5, + stroke_opacity=0.7, + buff=0.3, + **kwargs): + di = self.root_dots[i].get_center() + dj = self.root_dots[j].get_center() + kwargs["path_arc"] = path_arc + kwargs["stroke_width"] = stroke_width + kwargs["stroke_opacity"] = stroke_opacity + kwargs["buff"] = buff + return VGroup( + Arrow(di, dj, **kwargs), + Arrow(dj, di, **kwargs), + ) + + def swap_roots(self, i, j, run_time=2, wait_time=1): + self.play(Swap( + self.root_dots[i], + self.root_dots[j], + run_time=3 + )) + self.wait(wait_time) + + def rotate_coefs(self, indicies, center_z=0, run_time=5, wait_time=1): + self.play(*( + Rotate( + self.coef_dots[i], TAU, + about_point=self.coef_plane.n2p(center_z), + run_time=run_time + ) + for i in indicies + )) + self.wait(wait_time) + + def rotate_coef(self, i, **kwargs): + self.rotate_coefs([i], **kwargs) + + # Interaction + + def on_mouse_press(self, point, button, mods): + try: + super().on_mouse_press(point, button, mods) + mob = self.point_to_mobject( + point, + search_set=(*self.root_dots, *self.coef_dots), + buff=0.1 + ) + if mob is None: + return + if mob in self.root_dots: + self.tie_coefs_to_roots() + self.add(*self.root_dots) + elif mob in self.coef_dots: + self.tie_roots_to_coefs() + self.add(*self.coef_dots) + self.mouse_drag_point.move_to(point) + mob.add_updater(lambda m: m.move_to(self.mouse_drag_point)) + self.unlock_mobject_data() + self.lock_static_mobject_data() + except Exception as e: + print(e) + + def on_mouse_release(self, point, button, mods): + super().on_mouse_release(point, button, mods) + self.root_dots.clear_updaters() + self.coef_dots.clear_updaters() -class Test(RootCoefScene): +class TestRootCoefScene(RootCoefScene): def construct(self): self.add_planes() - coefs = list(range(-2, 3, 1)) + + # Add dots + coefs = [3, 2, 1, 0, -1] roots = coefficients_to_roots(coefs) self.add_root_dots(roots) self.add_coef_dots(coefs) + self.add_tracers() - self.tie_coefs_to_roots() + self.add_r_labels() + self.add_c_labels() - self.root_dots[0].set_color(RED) + # Animate + # self.swap_roots(0, 1) + # self.rotate_coef(0) - self.play(Swap(*self.root_dots[:2]), run_time=3) + # self.add_root_lines() + # self.add_coef_lines() - self.embed() + # self.tie_roots_to_coefs() + + # Sample animations + # self.rotate_coef(0) + + # self.tie_coefs_to_roots() + # for i, j in [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2)]: + # arrows = self.get_root_swap_arrows(i, j) + # self.play(*map(ShowCreation, arrows)) + # self.swap_roots(i, j) + # self.play(FadeOut(arrows)) + + # self.embed()