diff --git a/_2022/convolutions/continuous.py b/_2022/convolutions/continuous.py index 3ecdb5b..ed8933b 100644 --- a/_2022/convolutions/continuous.py +++ b/_2022/convolutions/continuous.py @@ -1,8 +1,10 @@ from manim_imports_ext import * -from _2022.convolutions.main import * +from _2022.convolutions.discrete import * class ConvolveDiscreteDistributions(InteractiveScene): + long_form = True + def construct(self): # Set up two distributions dist1 = np.array([np.exp(-0.25 * (x - 3)**2) for x in range(6)]) @@ -38,7 +40,6 @@ class ConvolveDiscreteDistributions(InteractiveScene): v_line.set_y(0) v_lines.add(v_line) v_lines.add(v_lines[-1].copy().set_x(top_bars.get_right()[0])) - # v_lines.set_stroke(opacity=0) # Set up new distribution conv_dist = np.convolve(dist1, dist2) @@ -74,30 +75,41 @@ class ConvolveDiscreteDistributions(InteractiveScene): PXY.next_to(conv_bars, UP, LARGE_BUFF) # Add distributions + frame = self.camera.frame + frame.set_height(6).move_to(top_bars) + self.play( - FadeIn(top_bars, lag_ratio=0.1), - FadeIn(v_lines, lag_ratio=0.2), + self.show_bars_creation(top_bars, lag_ratio=0.05), Write(PX), ) self.wait() self.play( - FadeIn(low_bars, lag_ratio=0.1), + self.show_bars_creation(low_bars, lag_ratio=0.1), Write(PY), + FadeIn(v_lines, lag_ratio=0.2), + frame.animate.set_height(FRAME_HEIGHT).set_y(0).set_anim_args(run_time=1), ) self.wait() self.play( - FadeIn(conv_bars), + self.show_bars_creation(conv_bars), FadeTransform(PX.copy(), PXY), FadeTransform(PY.copy(), PXY), + frame.animate.center().set_anim_args(run_time=1) + ) + self.wait() + + # Flip + self.play( + low_bars.animate.arrange(LEFT, aligned_edge=DOWN, buff=0).move_to(low_bars), + path_arc=PI / 3, + lag_ratio=0.005 ) self.wait() # March! - self.play(low_bars.animate.arrange(LEFT, aligned_edge=DOWN, buff=0).move_to(low_bars)) - - last_rects = VGroup() - for n in range(2, 13): + last_pair_rects = VGroup() + for n in [7, *range(2, 13)]: conv_bars.generate_target() conv_bars.target.set_opacity(0.35) conv_bars.target[n - 2].set_opacity(1.0) @@ -105,36 +117,78 @@ class ConvolveDiscreteDistributions(InteractiveScene): self.play( get_row_shift(top_bars, low_bars, n), MaintainPositionRelativeTo(PY, low_bars), - FadeOut(last_rects), + FadeOut(last_pair_rects), MoveToTarget(conv_bars), ) pairs = get_aligned_pairs(top_bars, low_bars, n) label_pairs = VGroup(*(VGroup(m1.value_label, m2.value_label) for m1, m2 in pairs)) - rects = VGroup(*( - SurroundingRectangle(lp, buff=0.05).set_stroke(YELLOW, 2).round_corners() - for lp in label_pairs + die_pairs = VGroup(*(VGroup(m1.die, m2.die) for m1, m2 in pairs)) + pair_rects = VGroup(*( + SurroundingRectangle(pair, buff=0.05).set_stroke(YELLOW, 2).round_corners() + for pair in pairs )) - rects.set_stroke(YELLOW, 2) + pair_rects.set_stroke(YELLOW, 2) + for rect in pair_rects: + rect.set_width(label_pairs[0].get_width() + 0.125, stretch=True) - self.play( - FadeIn(rects, lag_ratio=0.5), - # Restore(bar[0], time_span=(0.5, 1.0)), - # Write(bar[2], time_span=(0.5, 1.0)), - ) + self.play(FadeIn(pair_rects, lag_ratio=0.5)) - self.play(*( - FadeTransform(label.copy(), conv_bars[n - 2].value_label) - for lp in label_pairs - for label in lp - )) - self.wait(0.5) + fade_anims = [] + if self.long_form: + # Spell out the full dot product + products = MTex(R"P(O) \cdot P(O)", isolate="O", font_size=36).replicate(len(pairs)) + products.arrange(DOWN, buff=MED_LARGE_BUFF) + products.next_to(conv_bars, LEFT, MED_LARGE_BUFF) + products.to_edge(UP, buff=LARGE_BUFF) + plusses = Tex("+", font_size=36).replicate(len(pairs) - 1) + for plus, lp1, lp2 in zip(plusses, products, products[1:]): + plus.move_to(VGroup(lp1, lp2)) - last_rects = rects + die_targets = die_pairs.copy() + for dp, dt_pair, product in zip(die_pairs, die_targets, products): + for die, O in zip(dt_pair, product.select_parts("O")): + die.match_width(O) + die.move_to(O) + O.set_opacity(0) + product.save_state() + product[:len(product) // 2].replace(dp[0]) + product[len(product) // 2:].replace(dp[1]) + product.set_opacity(0) + + self.play( + LaggedStart(*( + TransformFromCopy(dp, dt) + for dp, dt in zip(die_pairs, die_targets) + ), lag_ratio=0.5), + LaggedStartMap(Restore, products, lag_ratio=0.5), + LaggedStartMap(Write, plusses, lag_ratio=0.5), + run_time=3 + ) + self.wait() + prod_group = VGroup(*products, *die_targets, *plusses) + prod_group.generate_target() + prod_group.target.set_opacity(0) + for mob in prod_group.target: + mob.replace(conv_bars[n - 2].value_label, stretch=True) + self.play(MoveToTarget(prod_group, remover=True)) + self.wait() + else: + self.play( + *( + FadeTransform(label.copy(), conv_bars[n - 2].value_label) + for lp in label_pairs + for label in lp + ), + *fade_anims, + ) + self.wait(0.5) + + last_pair_rects = pair_rects conv_bars.target.set_opacity(1.0) self.play( - FadeOut(last_rects), + FadeOut(last_pair_rects), get_row_shift(top_bars, low_bars, 7), MaintainPositionRelativeTo(PY, low_bars), MoveToTarget(conv_bars), @@ -222,6 +276,24 @@ class ConvolveDiscreteDistributions(InteractiveScene): self.play(Write(conv_def[10:])) self.wait() + def show_bars_creation(self, bars, lag_ratio=0.05, run_time=3): + anims = [] + for bar in bars: + rect, num, face = bar + num.rect = rect + rect.save_state() + rect.stretch(0, 1, about_edge=DOWN) + rect.set_opacity(0) + + anims.extend([ + FadeIn(face), + rect.animate.restore(), + CountInFrom(num, 0), + UpdateFromAlphaFunc(num, lambda m, a: m.next_to(m.rect, UP, SMALL_BUFF).set_opacity(a)), + ]) + + return LaggedStart(*anims, lag_ratio=lag_ratio, run_time=run_time) + # Continuous case @@ -231,6 +303,7 @@ class TransitionToContinuousProbability(InteractiveScene): # Setup axes and initial graph axes = Axes((0, 12), (0, 1, 0.2), width=14, height=5) axes.to_edge(LEFT, LARGE_BUFF) + axes.to_edge(DOWN, buff=1.25) def pd(x): return (x**4) * np.exp(-x) / 8.0 @@ -240,8 +313,9 @@ class TransitionToContinuousProbability(InteractiveScene): bars = axes.get_riemann_rectangles(graph, dx=1, x_range=(0, 6), input_sample_type="right") bars.set_stroke(WHITE, 3) - y_label = Text("Probability", font_size=24) + y_label = Text("Probability", font_size=48) y_label.next_to(axes.y_axis, UP, SMALL_BUFF) + y_label.shift_onto_screen() self.add(axes) self.add(y_label) @@ -297,12 +371,13 @@ class TransitionToContinuousProbability(InteractiveScene): density.match_height(y_label) density.move_to(y_label, LEFT) cross = Cross(y_label) + cross.set_stroke(RED, width=(0, 8, 8, 8, 0)) self.play(Write(x_label)) self.wait() self.play(ShowCreation(cross)) self.play( - VGroup(y_label, cross).animate.shift(0.5 * UP), + VGroup(y_label, cross).animate.shift(0.75 * UP), FadeIn(density) ) self.wait() @@ -357,8 +432,10 @@ class Convolutions(InteractiveScene): f_graph_x_step = 0.1 g_graph_x_step = 0.1 f_label_tex = "f(x)" - g_label_tex = "g(t - x)" - fg_label_tex = R"f(x) \cdot g(t - x)" + g_label_tex = "g(s - x)" + fg_label_tex = R"f(x) \cdot g(s - x)" + conv_label_tex = R"(f * g)(s) = \int_{-\infty}^\infty f(x) \cdot g(s - x) dx" + label_config = dict(font_size=36) t_color = TEAL area_line_dx = 0.05 jagged_product = True @@ -367,7 +444,8 @@ class Convolutions(InteractiveScene): def setup(self): super().setup() if self.g_is_rect: - k_tracker = self.k_tracker = ValueTracker(1) + k1_tracker = self.k1_tracker = ValueTracker(1) + k2_tracker = self.k2_tracker = ValueTracker(1) # Add axes all_axes = self.all_axes = self.get_all_axes() @@ -380,24 +458,25 @@ class Convolutions(InteractiveScene): # Add f(x) f_graph = self.f_graph = f_axes.get_graph(self.f, x_range=(x_min, x_max, self.f_graph_x_step)) f_graph.set_style(**self.f_graph_style) - f_label = self.get_label(self.f_label_tex, f_axes) + f_label = self.f_label = self.get_label(self.f_label_tex, f_axes) if self.jagged_product: f_graph.make_jagged() self.add(f_graph) self.add(f_label) - # Add g(t - x) + # Add g(s - x) self.toggle_selection_mode() # So triangle is highlighted - t_indicator = self.t_indicator = ArrowTip().rotate(90 * DEGREES) - t_indicator.set_height(0.15) - t_indicator.set_fill(self.t_color, 0.8) - t_indicator.move_to(g_axes.get_origin(), UP) - t_indicator.add_updater(lambda m: m.align_to(g_axes.get_origin(), UP)) + s_indicator = self.s_indicator = ArrowTip().rotate(90 * DEGREES) + s_indicator.set_height(0.15) + s_indicator.set_fill(self.t_color, 0.8) + s_indicator.move_to(g_axes.get_origin(), UP) + s_indicator.add_updater(lambda m: m.align_to(g_axes.get_origin(), UP)) - def get_t(): - return g_axes.x_axis.p2n(t_indicator.get_center()) + def get_s(): + return g_axes.x_axis.p2n(s_indicator.get_center()) + self.get_s = get_s g_graph = self.g_graph = g_axes.get_graph(lambda x: 0, x_range=(x_min, x_max, self.g_graph_x_step)) g_graph.set_style(**self.g_graph_style) if self.g_is_rect: @@ -405,34 +484,36 @@ class Convolutions(InteractiveScene): x_max = g_axes.x_axis.x_max g_graph.add_updater(lambda m: m.set_points_as_corners([ g_axes.c2p(x, y) - for t in [get_t()] - for k in [k_tracker.get_value()] + for s in [get_s()] + for k1 in [k1_tracker.get_value()] + for k2 in [k2_tracker.get_value()] for x, y in [ - (x_min, 0), (-0.5 / k + t, 0), (-0.5 / k + t, k), (0.5 / k + t, k), (0.5 / k + t, 0), (x_max, 0) + (x_min, 0), (-0.5 / k1 + s, 0), (-0.5 / k1 + s, k2), (0.5 / k1 + s, k2), (0.5 / k1 + s, 0), (x_max, 0) ] ])) else: - g_axes.bind_graph_to_func(g_graph, lambda x: self.g(get_t() - x), jagged=self.jagged_product) + g_axes.bind_graph_to_func(g_graph, lambda x: self.g(get_s() - x), jagged=self.jagged_product) g_label = self.g_label = self.get_label(self.g_label_tex, g_axes) - t_label = VGroup(*Tex("t = ")[0], DecimalNumber()) - t_label.arrange(RIGHT, buff=SMALL_BUFF) - t_label.scale(0.5) - t_label.set_backstroke(width=8) - t_label.add_updater(lambda m: m.next_to(t_indicator, DOWN, submobject_to_align=m[0], buff=0.15)) - t_label.add_updater(lambda m: m.shift(m.get_width() * LEFT / 2)) - t_label.add_updater(lambda m: m[-1].set_value(get_t())) + s_label = self.s_label = VGroup(*Tex("s = ")[0], DecimalNumber()) + s_label.arrange(RIGHT, buff=SMALL_BUFF) + s_label.scale(0.5) + s_label.set_backstroke(width=8) + s_label.add_updater(lambda m: m.next_to(s_indicator, DOWN, submobject_to_align=m[0], buff=0.15)) + s_label.add_updater(lambda m: m.shift(m.get_width() * LEFT / 2)) + s_label.add_updater(lambda m: m[-1].set_value(get_s())) self.add(g_graph) self.add(g_label) - self.add(t_indicator) - self.add(t_label) + self.add(s_indicator) + self.add(s_label) - # Show integral of f(x) * g(t - x) + # Show integral of f(x) * g(s - x) def prod_func(x): - k = self.k_tracker.get_value() if self.g_is_rect else 1 - return self.f(x) * self.g((get_t() - x) * k) * k + k1 = self.k1_tracker.get_value() if self.g_is_rect else 1 + k2 = self.k2_tracker.get_value() if self.g_is_rect else 1 + return self.f(x) * self.g((get_s() - x) * k1) * k2 fg_graph, pos_graph, neg_graph = ( fg_axes.get_graph(lambda x: 0, x_range=(x_min, x_max, self.g_graph_x_step)) @@ -446,8 +527,8 @@ class Convolutions(InteractiveScene): get_discontinuities = None if self.g_is_rect: def get_discontinuities(): - k = self.k_tracker.get_value() - return [get_t() - 0.5 / k, get_t() + 0.5 / k] + k1 = self.k1_tracker.get_value() + return [get_s() - 0.5 / k1, get_s() + 0.5 / k1] kw = dict( jagged=self.jagged_product, @@ -467,11 +548,11 @@ class Convolutions(InteractiveScene): # Show convolution conv_graph = self.conv_graph = self.get_conv_graph(conv_axes, self.f, self.g) - graph_dot = GlowDot(color=WHITE) + graph_dot = self.graph_dot = GlowDot(color=WHITE) graph_dot.add_updater(lambda d: d.move_to(conv_graph.quick_point_from_proportion( - inverse_interpolate(x_min, x_max, get_t()) + inverse_interpolate(x_min, x_max, get_s()) ))) - graph_line = Line(stroke_color=WHITE, stroke_width=1) + graph_line = self.graph_line = Line(stroke_color=WHITE, stroke_width=1) graph_line.add_updater(lambda l: l.put_start_and_end_on( graph_dot.get_center(), [graph_dot.get_x(), conv_axes.get_y(), 0], @@ -479,10 +560,7 @@ class Convolutions(InteractiveScene): self.conv_graph_dot = graph_dot self.conv_graph_line = graph_line - conv_label = Tex( - R"(f * g)(t) := \int_{-\infty}^\infty f(x) \cdot g(t - x) dx", - font_size=36 - ) + conv_label = self.conv_label = MTex(self.conv_label_tex, **self.label_config) conv_label.next_to(conv_axes, UP) self.add(conv_graph) @@ -490,8 +568,6 @@ class Convolutions(InteractiveScene): self.add(graph_line) self.add(conv_label) - # Now play! - def get_all_axes(self): all_axes = VGroup(*(Axes(**self.axes_config) for x in range(4))) all_axes[:3].arrange(DOWN, buff=0.75) @@ -501,7 +577,7 @@ class Convolutions(InteractiveScene): all_axes.to_edge(DOWN, buff=0.1) for i, axes in enumerate(all_axes): - x_label = Tex("x" if i < 3 else "t", font_size=24) + x_label = Tex("x" if i < 3 else "s", font_size=24) x_label.next_to(axes.x_axis.get_right(), UP, MED_SMALL_BUFF) axes.x_label = x_label axes.x_axis.add(x_label) @@ -511,7 +587,7 @@ class Convolutions(InteractiveScene): return all_axes def get_label(self, tex, axes): - label = Tex(tex, font_size=36) + label = MTex(tex, **self.label_config) label.move_to(midpoint(axes.get_origin(), axes.get_right())) label.match_y(axes.get_top()) return label @@ -536,6 +612,197 @@ class Convolutions(InteractiveScene): class ProbConvolutions(Convolutions): jagged_product = True + f_label_tex = "p_X(x)" + g_label_tex = "p_Y(s - x)" + fg_label_tex = R"p_X(x) \cdot p_Y(s - x)" + conv_label_tex = R"(p_X * p_Y)(s) := \int_{-\infty}^\infty p_X(x) \cdot p_Y(s - x) dx" + label_config = dict( + font_size=36, + tex_to_color_map={"X": BLUE, "Y": YELLOW} + ) + + def construct(self): + # Hit most of previous setup + f_axes, g_axes, fg_axes, conv_axes = self.all_axes + f_graph, g_graph, prod_graphs, conv_graph = self.f_graph, self.g_graph, self.prod_graphs, self.conv_graph + f_label, g_label, fg_label, conv_label = self.f_label, self.g_label, self.fg_label, self.conv_label + s_indicator = self.s_indicator + s_label = self.s_label + self.remove(s_indicator, s_label) + + f_axes.x_axis.add_numbers(font_size=16, buff=0.1) + self.add(f_axes) + + y_label = MTex("y").replace(g_axes.x_label) + g_label.shift(0.2 * UP) + gy_label = MTex("p_Y(y)", **self.label_config).replace(g_label, dim_to_match=1) + gmx_label = MTex("p_Y(-x)", **self.label_config).replace(g_label, dim_to_match=1) + g_axes.x_label.set_opacity(0) + self.remove(g_label) + self.add(y_label) + self.add(gy_label) + + alt_fg_label = MTex(R"p_X(x) \cdot p_Y(-x)", **self.label_config) + alt_fg_label.move_to(fg_label) + + conv_label.shift_onto_screen() + sum_label = MTex("p_{X + Y}(s)", **self.label_config) + sum_label.move_to(conv_label) + self.remove(fg_axes, prod_graphs, fg_label) + self.remove(conv_label) + self.remove(conv_axes, conv_graph, self.graph_dot, self.graph_line) + + # Show f and g + true_g_graph = g_graph.copy() + true_g_graph.clear_updaters() + true_g_graph.flip() + true_g_graph.reverse_points() + + self.remove(g_graph, f_graph) + self.play(LaggedStart(*( + AnimationGroup( + ShowCreation(graph), + VShowPassingFlash(graph.copy().set_stroke(width=5)), + run_time=2 + ) + for graph in (f_graph, true_g_graph) + ), lag_ratio=0.25)) + self.wait() + self.play( + Transform(f_graph.copy(), self.conv_graph.deepcopy(), remover=True), + Transform(true_g_graph.copy(), self.conv_graph.deepcopy(), remover=True), + FadeIn(conv_axes), + TransformMatchingShapes(VGroup(*f_label, *gy_label).copy(), sum_label), + ) + self.add(self.conv_graph) + + # Flip g + right_rect = FullScreenFadeRectangle() + right_rect.stretch(0.5, 0, about_edge=RIGHT) + g_axes_copy = g_axes.copy() + g_axes_copy.add(y_label) + true_group = VGroup(g_axes_copy, gy_label, true_g_graph) + + self.play( + true_group.animate.to_edge(DOWN, buff=MED_SMALL_BUFF), + FadeIn(right_rect) + ) + self.add(*true_group) + g_axes.generate_target() + g_axes.target.x_label.set_opacity(1), + self.play( + TransformMatchingShapes(gy_label.copy(), gmx_label), + true_g_graph.copy().animate.flip().move_to(g_graph).set_anim_args(remover=True), + MoveToTarget(g_axes), + ) + self.add(g_graph) + self.wait() + self.play(FadeOut(true_group)) + + # Show product + self.play( + FadeTransform(f_axes.copy(), fg_axes), + FadeTransform(g_axes.copy(), fg_axes), + Transform(f_graph.copy(), prod_graphs[0].copy(), remover=True), + Transform(g_graph.copy(), prod_graphs[0].copy(), remover=True), + TransformFromCopy( + VGroup(*f_label, *gmx_label), + alt_fg_label + ), + ) + self.add(fg_axes, prod_graphs[0]) + self.wait() + self.add(*prod_graphs) + self.play(DrawBorderThenFill(prod_graphs[1])) + self.wait() + + # Show constant sums + self.highlight_several_regions(reference=alt_fg_label) + self.play( + FadeIn(s_indicator), FadeIn(s_label), + FadeOut(gmx_label, 0.5 * UP), + FadeIn(g_label, 0.5 * UP), + FadeTransform(alt_fg_label, fg_label), + ) + self.wait() + self.play(s_indicator.animate.match_x(g_axes.c2p(-1, 0)), run_time=2) + self.highlight_several_regions(s=self.get_s(), reference=fg_label) + self.wait() + + # Show convolution + lhs = conv_label[:len("(px*py)(s)")] + rhs = conv_label[len("(px*py)(s)"):] + + self.play( + FadeOut(right_rect), + FadeIn(self.graph_dot), + FadeIn(self.graph_line), + ) + self.play(s_indicator.animate.match_x(g_axes.c2p(1.0, 0)), run_time=3) + self.wait() + + self.play( + sum_label.animate.move_to(lhs, RIGHT), + Write(rhs) + ) + self.wait() + self.play( + FlashAround(rhs[6:]), + FlashAround(fg_label), + time_width=2.0, + run_time=3, + ) + self.wait() + + # Move p_{X + Y} + equals = Tex("=").rotate(PI / 2) + equals.next_to(lhs, UP) + + self.play( + sum_label.animate.next_to(equals, UP, MED_SMALL_BUFF), + Write(lhs), + Write(equals), + ) + self.wait() + + # Slow panning + for s in [-2, 2]: + self.play(s_indicator.animate.match_x(g_axes.c2p(s, 0)), run_time=8) + self.wait() + + def highlight_several_regions(self, highlighted_xs=None, s=0, reference=None): + # Highlight a few regions + if highlighted_xs is None: + highlighted_xs = np.arange(-1, 1.1, 0.1) + + g_axes = self.all_axes[1] + highlight_rect = Rectangle(width=0.1, height=FRAME_HEIGHT / 2) + highlight_rect.set_stroke(width=0) + highlight_rect.set_fill(TEAL, 0.5) + highlight_rect.move_to(g_axes.get_origin(), DOWN) + highlight_rect.set_opacity(0.5) + self.add(highlight_rect) + + last_label = VMobject() + for x in highlighted_xs: + x_tex = f"{{{np.round(x, 1)}}}" + diff_tex = f"{{{np.round(s - x, 1)}}}" + label = MTex( + fR"p_X({x_tex}) \cdot p_Y({diff_tex})", + tex_to_color_map={diff_tex: YELLOW, x_tex: BLUE}, + font_size=36 + ) + if reference: + label.next_to(reference, UP, MED_LARGE_BUFF) + else: + label.next_to(ORIGIN, DOWN, LARGE_BUFF) + + highlight_rect.set_x(g_axes.c2p(x, 0)[0]), + self.add(label) + self.remove(last_label) + self.wait(0.25) + last_label = label + self.play(FadeOut(last_label), FadeOut(highlight_rect)) def f(self, x): return max(-abs(x) + 1, 0) @@ -549,13 +816,13 @@ class ProbConvolutionControlled(ProbConvolutions): initial_t = 0 def construct(self): - t_indicator = self.t_indicator + s_indicator = self.s_indicator g_axes = self.all_axes[1] def set_t(t): - return t_indicator.animate.set_x(g_axes.c2p(t, 0)[0]) + return s_indicator.animate.set_x(g_axes.c2p(t, 0)[0]) - t_indicator.set_x(g_axes.c2p(self.initial_t, 0)[0]) + s_indicator.set_x(g_axes.c2p(self.initial_t, 0)[0]) for t, time in self.t_time_pairs: self.play(set_t(t), run_time=time) self.wait() @@ -570,12 +837,12 @@ class AltConvolutions(Convolutions): jagged_product = True def construct(self): - t_indicator = self.t_indicator + s_indicator = self.s_indicator g_axes = self.all_axes[1] # Sample values for t in [3, -3, -1.0]: - self.play(t_indicator.animate.set_x(g_axes.c2p(t, 0)[0]), run_time=3) + self.play(s_indicator.animate.set_x(g_axes.c2p(t, 0)[0]), run_time=3) self.wait() def f(self, x): @@ -600,9 +867,10 @@ class MovingAverageAsConvolution(Convolutions): def construct(self): # Setup super().construct() - t_indicator = self.t_indicator - g_axes = self.all_axes[1] + s_indicator = self.s_indicator + f_axes, g_axes, fg_axes, conv_axes = self.all_axes self.g_label.shift(0.25 * UP) + self.fg_label.shift(0.25 * UP) y_axes = VGroup(*(axes.y_axis for axes in self.all_axes[1:3])) fake_ys = y_axes.copy() @@ -610,9 +878,14 @@ class MovingAverageAsConvolution(Convolutions): fake_y.stretch(1.2, 1) self.add(*fake_ys, *self.mobjects) + conv_axes.y_axis.match_height(f_axes.y_axis) + VGroup(conv_axes).match_y(f_axes) + self.conv_graph.become(self.get_conv_graph(conv_axes, self.f, self.g)) + self.conv_label.next_to(conv_axes, DOWN, MED_LARGE_BUFF) + # Sample values def set_t(t): - return t_indicator.animate.set_x(g_axes.c2p(t, 0)[0]) + return s_indicator.animate.set_x(g_axes.c2p(t, 0)[0]) self.play(set_t(-2.5), run_time=2) self.play(set_t(2.5), run_time=8) @@ -639,20 +912,22 @@ class MovingAverageAsConvolution(Convolutions): self.play(FadeOut(fade_rects)) # Show rect dimensions - get_k = self.k_tracker.get_value + get_k1 = self.k1_tracker.get_value + get_k2 = self.k2_tracker.get_value top_label = DecimalNumber(1, font_size=24) - top_label.add_updater(lambda m: m.set_value(1 / get_k())) + top_label.add_updater(lambda m: m.set_value(1.0 / get_k1())) top_label.add_updater(lambda m: m.next_to(top_line, UP, SMALL_BUFF)) side_label = DecimalNumber(1, font_size=24) - side_label.add_updater(lambda m: m.set_value(get_k())) + side_label.add_updater(lambda m: m.set_value(get_k2())) side_label.add_updater(lambda m: m.next_to(side_line, LEFT, SMALL_BUFF)) - def change_k(k, run_time=3): + def change_ks(k1, k2, run_time=3): new_conv_graph = self.get_conv_graph( - self.all_axes[3], self.f, lambda x: self.g(k * x) * k, + self.all_axes[3], self.f, lambda x: self.g(k1 * x) * k2, ) self.play( - self.k_tracker.animate.set_value(k), + self.k1_tracker.animate.set_value(k1), + self.k2_tracker.animate.set_value(k2), Transform(self.conv_graph, new_conv_graph), run_time=run_time ) @@ -669,13 +944,20 @@ class MovingAverageAsConvolution(Convolutions): VFadeIn(top_label), ) self.wait() - change_k(0.5) + + # Change dimensions + change_ks(0.5, 1) self.wait() - self.play(set_t(-1.5), run_time=3) + change_ks(0.5, 0.5) + self.play(set_t(-1.5), run_time=2) + self.play(set_t(-0.25), run_time=2) self.wait() - change_k(2) + change_ks(2, 0.5) self.wait() - change_k(1) + change_ks(2, 2) + self.wait() + change_ks(4, 4) + change_ks(1, 1) self.play(*map(FadeOut, [top_label, top_line, side_label, side_line])) # Show area @@ -683,9 +965,9 @@ class MovingAverageAsConvolution(Convolutions): rect.set_fill(YELLOW, 0.5) rect.set_stroke(width=0) rect.set_gloss(1) - rect.add_updater(lambda m: m.set_width(g_axes.x_axis.unit_size / get_k(), stretch=True)) - rect.add_updater(lambda m: m.set_height(g_axes.y_axis.unit_size * get_k(), stretch=True)) - rect.add_updater(lambda m: m.set_x(t_indicator.get_x())) + rect.add_updater(lambda m: m.set_width(g_axes.x_axis.unit_size / get_k1(), stretch=True)) + rect.add_updater(lambda m: m.set_height(g_axes.y_axis.unit_size * get_k2(), stretch=True)) + rect.add_updater(lambda m: m.set_x(s_indicator.get_x())) rect.add_updater(lambda m: m.set_y(g_axes.get_origin()[1], DOWN)) area_label = Tex(R"\text{Area } = 1", font_size=36) @@ -710,13 +992,20 @@ class MovingAverageAsConvolution(Convolutions): ShowCreation(arrow2) ) self.wait() - for k in [1.4, 0.8, 1.0]: - change_k(k) + for k in [1.4, 0.8, 1.0, 4.0, 10.0, 1.0]: + change_ks(k, k) self.play(*map(FadeOut, [area_label, arrow, avg_label, arrow2])) - # Slide once more + # More ambient variation self.play(set_t(-2.5), run_time=3) self.play(set_t(2.5), run_time=8) + self.play(set_t(0), run_time=4) + change_ks(20, 20) + self.wait() + change_ks(10, 10) + self.wait() + change_ks(0.2, 0.2, run_time=12) + self.wait() def f(self, x): return kinked_function(x) @@ -770,8 +1059,12 @@ class DiagonalSlices(ProbConvolutions): surface_mesh = SurfaceMesh(surface, resolution=(21, 21)) surface_mesh.set_stroke(WHITE, 0.5, 0.5) - func_name = Tex(R"f(x) \cdot g(y)") - func_name.to_corner(UL) + func_name = MTex( + R"p_X(x) \cdot p_Y(y)", + tex_to_color_map={"X": BLUE, "Y": YELLOW}, + font_size=42, + ) + func_name.to_corner(UL, buff=0.25) func_name.fix_in_frame() self.add(surface) @@ -789,9 +1082,9 @@ class DiagonalSlices(ProbConvolutions): equation.fix_in_frame() equation[1].add_updater(lambda m: m.set_value(t_tracker.get_value())) - set_label = MTex(R"\{(x, t - x): x \in \mathds{R}\}", tex_to_color_map={"t": YELLOW}, font_size=30) - set_label.next_to(equation, DOWN, MED_LARGE_BUFF, aligned_edge=RIGHT) - set_label.fix_in_frame() + ses_label = MTex(R"\{(x, s - x): x \in \mathds{R}\}", tex_to_color_map={"s": YELLOW}, font_size=30) + ses_label.next_to(equation, DOWN, MED_LARGE_BUFF, aligned_edge=RIGHT) + ses_label.fix_in_frame() self.play(frame.animate.reorient(20, 70), run_time=5) self.wait() @@ -808,7 +1101,7 @@ class DiagonalSlices(ProbConvolutions): ) self.wait() self.play( - FadeIn(set_label, 0.5 * DOWN), + FadeIn(ses_label, 0.5 * DOWN), MoveAlongPath(GlowDot(), slice_graph, run_time=5, remover=True) ) self.wait() @@ -879,7 +1172,12 @@ class DiagonalSlices(ProbConvolutions): class RepeatedConvolution(MovingAverageAsConvolution): resolution = 0.01 - n_iterations = 12 + n_iterations = 20 + when_to_renormalize = 6 + f_label_tex = "f_1(x)" + g_label_tex = "f_1(s - x)" + fg_label_tex = R"f_1(x) \cdot f_1(s - x)" + conv_label_tex = R"f_2(s) = [f_1 * f_1](s)" def construct(self): # Clean the board @@ -904,25 +1202,17 @@ class RepeatedConvolution(MovingAverageAsConvolution): # New f graph f_graph = g_graph.deepcopy() f_graph.clear_updaters() - f_graph.set_stroke(BLUE) + f_graph.set_stroke(BLUE, 3) f_graph.shift(axes1.get_origin() - axes2.get_origin()) self.add(f_graph) # New prod graph - t_indicator = self.t_indicator - - def get_t(): - return axes2.x_axis.p2n(t_indicator.get_center()) - - def set_t(t): - return t_indicator.animate.set_x(axes2.c2p(t)[0]) - def update_prod_graph(prod_graph): prod_samples = f_samples.copy() - t = get_t() - prod_samples[x_samples < t - 0.5] = 0 - prod_samples[x_samples > t + 0.5] = 0 + s = self.get_s() + prod_samples[x_samples < s - 0.5] = 0 + prod_samples[x_samples > s + 0.5] = 0 prod_graph.set_points_as_corners( axes3.c2p(x_samples, prod_samples) ) @@ -932,49 +1222,110 @@ class RepeatedConvolution(MovingAverageAsConvolution): prod_graph.set_fill(BLUE_E, 1) prod_graph.add_updater(update_prod_graph) + self.fg_label.shift(0.35 * UP) + self.g_label.shift(0.35 * UP) + self.add(prod_graph) self.add(self.fg_label) - # Convolution + # Move convolution axes + conv_axes.y_axis.match_height(axes1.y_axis) + conv_axes.match_y(axes1) + self.remove(self.conv_label) + conv_label = self.conv_label = self.get_conv_label(2) + self.add(conv_label) + + # Show repeated convolutions + self.n = 2 + for x in range(self.n_iterations): + conv_samples, conv_graph = self.create_convolution( + x_samples, f_samples, g_samples, conv_axes, + # TODO, account for renormalized version + ) + if self.n == self.when_to_renormalize: + self.renormalize() + self.swap_graphs(f_graph, conv_graph, axes1, conv_axes) + self.swap_labels() + f_samples[:] = conv_samples + self.n += 1 + + def create_convolution(self, x_samples, f_samples, g_samples, conv_axes): + # Test + self.set_s(-3, animate=False) + conv_samples, conv_graph = self.get_conv( x_samples, f_samples, g_samples, conv_axes ) endpoint_dot = GlowDot(color=WHITE) endpoint_dot.add_updater(lambda m: m.move_to(conv_graph.get_points()[-1])) - self.add(conv_graph) + self.play( + self.set_s(3), + ShowCreation(conv_graph), + UpdateFromAlphaFunc( + endpoint_dot, lambda m, a: m.move_to(conv_graph.get_end()).set_opacity(min(6 * a, 1)), + ), + run_time=5, + rate_func=bezier([0, 0, 1, 1]) + ) + self.play(FadeOut(endpoint_dot)) - # Show new convolutions - for n in range(self.n_iterations): - t_indicator.set_x(axes2.c2p(-3, 0)[0]) - self.play( - set_t(3), - ShowCreation(conv_graph), - UpdateFromAlphaFunc( - endpoint_dot, lambda m, a: m.set_opacity(a), - time_span=(0, 0.5), - ), - run_time=5, - rate_func=bezier([0, 0, 1, 1]) - ) - self.play(FadeOut(endpoint_dot)) - shift_value = axes1.get_origin() - conv_axes.get_origin() - cg_anim = conv_graph.animate.stretch(1 / 1.5, 1, about_point=conv_axes.get_origin()) - cg_anim.shift(shift_value) - cg_anim.match_style(f_graph) - self.play( - cg_anim, - FadeOut(f_graph, shift_value), - FadeOut(axes1, shift_value), - Transform(conv_axes.deepcopy(), axes1, remover=True) - ) - self.add(axes1, conv_graph) + return conv_samples, conv_graph - f_samples[:] = conv_samples - f_graph = conv_graph - conv_samples, conv_graph = self.get_conv( - x_samples, f_samples, g_samples, conv_axes - ) + def swap_graphs(self, f_graph, conv_graph, f_axes, conv_axes): + shift_value = f_axes.get_origin() - conv_axes.get_origin() + conv_axes_copy = conv_axes.deepcopy() + + f_label = self.f_label + new_f_label = MTex(f"f_{{{self.n}}}(x)", **self.label_config) + new_f_label.replace(self.conv_label[0]) + new_f_label[-2].set_opacity(0) + + f_group = VGroup(f_axes, f_graph, f_label) + self.add(conv_axes_copy, conv_graph) + self.play(LaggedStart( + Transform(conv_axes_copy, f_axes, remover=True), + conv_graph.animate.shift(shift_value).match_style(f_graph), + FadeOut(f_group, shift_value), + new_f_label.animate.replace(f_label, dim_to_match=1).set_opacity(1), + )) + self.remove(conv_graph, new_f_label) + f_graph.become(conv_graph) + f_label.become(new_f_label) + self.add(f_axes, f_graph, f_label) + + def swap_labels(self): + # Test + new_conv_label = self.get_conv_label(self.n + 1) + new_conv_label.replace(self.conv_label) + prod_rhs = self.fg_label[6:] + new_prod_rhs = MTex(f"f_{{{self.n}}}(s - x)") + new_prod_rhs.replace(prod_rhs, dim_to_match=1) + + to_remove = VGroup( + self.conv_label[0], + self.conv_label[3], + prod_rhs, + ) + to_add = VGroup( + new_conv_label[0], + new_conv_label[3], + new_prod_rhs, + ) + self.play( + LaggedStartMap(FadeOut, to_remove, shift=0.5 * UP), + LaggedStartMap(FadeIn, to_add, shift=0.5 * UP), + ) + + self.remove(self.conv_label) + self.remove(new_prod_rhs) + self.conv_label = new_conv_label + prod_rhs.become(new_prod_rhs) + self.add(self.conv_label) + self.add(prod_rhs) + + def renormalize(self): + pass def get_conv(self, x_samples, f_samples, g_samples, axes): """ @@ -983,16 +1334,45 @@ class RepeatedConvolution(MovingAverageAsConvolution): conv_samples = self.resolution * scipy.signal.fftconvolve( f_samples, g_samples, mode='same' ) - conv_graph = VMobject().set_points_as_corners( - axes.c2p(x_samples, conv_samples) - ) - conv_graph.set_stroke(TEAL, 2) + conv_graph = VMobject() + conv_graph.set_points_as_corners(axes.c2p(x_samples, conv_samples)) + conv_graph.set_stroke(TEAL, 3) return conv_samples, conv_graph + def get_s(self): + return self.all_axes[1].x_axis.p2n(self.s_indicator.get_center()) + + def set_s(self, s, animate=True): + if animate: + mob = self.s_indicator.animate + else: + mob = self.s_indicator + return mob.set_x(self.all_axes[1].c2p(s)[0]) + + def get_conv_label(self, n): + lhs = f"f_{{{n}}}(s)" + last = f"f_{{{n - 1}}}" + result = Tex(lhs, "=", R"\big[", last, "*", "f_1", R"\big]", "(s)") + result.set_height(0.5) + result.next_to(self.all_axes[3], DOWN, MED_LARGE_BUFF) + return result + def f(self, x): return rect_func(x) +# Supplements + +class AsideOnVariance(InteractiveScene): + def construct(self): + pass + + +class RotateXplusYLine(InteractiveScene): + def construct(self): + pass + + # Final class FunctionAverage(InteractiveScene): def construct(self): @@ -1011,13 +1391,13 @@ class MovingAverageOfRectFuncs(Convolutions): def construct(self): super().construct() - t_indicator = self.t_indicator + s_indicator = self.s_indicator g_axes = self.all_axes[1] self.all_axes[3].y_axis.match_height(g_axes.y_axis) self.conv_graph.set_height(0.5 * g_axes.y_axis.get_height(), about_edge=DOWN, stretch=True) for t in [3, -3, 0]: - self.play(t_indicator.animate.set_x(g_axes.c2p(t, 0)[0]), run_time=5) + self.play(s_indicator.animate.set_x(g_axes.c2p(t, 0)[0]), run_time=5) self.wait() def f(self, x): @@ -1102,7 +1482,7 @@ class RectConvolutionsNewNotation(MovingAverages): # Show the rest for n in range(2): left_graph = rect_graphs[n] if n == 0 else conv_graphs[n - 1] - left_label = rect_defs[n] if n == 0 else conv_labels[n - 1] + lefs_label = rect_defs[n] if n == 0 else conv_labels[n - 1] k = 2 * n + 5 new_rect = Rectangle(axes2.x_axis.unit_size / k, axes2.y_axis.unit_size * k) new_rect.set_stroke(width=0) @@ -1110,7 +1490,7 @@ class RectConvolutionsNewNotation(MovingAverages): new_rect.move_to(axes2.get_origin(), DOWN) self.play( FadeOut(left_graph, 1.5 * LEFT), - FadeOut(left_label, 1.5 * LEFT), + FadeOut(lefs_label, 1.5 * LEFT), FadeOut(rect_defs[n + 1]), FadeOut(rect_graphs[n + 1]), conv_labels[n].animate.match_x(axes1),