From 20b3a1b4aa2cd5a6af5d6f20155db63d2e86bae4 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Thu, 22 Feb 2024 11:46:00 -0800 Subject: [PATCH] Add Softmax scene --- _2024/transformers/embedding.py | 258 ++++++++++++++++++++++++++++++-- _2024/transformers/objects.py | 3 +- 2 files changed, 244 insertions(+), 17 deletions(-) diff --git a/_2024/transformers/embedding.py b/_2024/transformers/embedding.py index edf17bd..ec47dba 100644 --- a/_2024/transformers/embedding.py +++ b/_2024/transformers/embedding.py @@ -1965,10 +1965,11 @@ class SoftmaxBreakdown(InteractiveScene): def construct(self): # Show example probability distribution word_strs = ['Dumbledore', 'Flitwick', 'Mcgonagall', 'Quirrell', 'Snape', 'Sprout', 'Trelawney'] - words = VGroup(*map(Text, word_strs)) - values = np.array([0.3, -1, 0.5, 1.5, 3.4, -1, 2.5]) + words = VGroup(*(Text(word_str, font_size=30) for word_str in word_strs)) + values = np.array([-0.8, -5.0, 0.5, 1.5, 3.4, -2.3, 2.5]) prob_values = softmax(values) chart = BarChart(prob_values, width=10) + chart.bars.set_stroke(width=1) probs = VGroup(*(DecimalNumber(pv) for pv in prob_values)) probs.arrange(DOWN, buff=0.25) @@ -2002,9 +2003,10 @@ class SoftmaxBreakdown(InteractiveScene): self.wait() # Show constraint between 0 and 1 - bar = chart.bars[0] + index = 3 + bar = chart.bars[index] bar.save_state() - prob = probs[0] + prob = probs[index] prob.bar = bar max_height = chart.y_axis.get_y(UP) - chart.x_axis.get_y() prob.add_updater(lambda p: p.set_value(p.bar.get_height() / max_height)) @@ -2048,7 +2050,6 @@ class SoftmaxBreakdown(InteractiveScene): Write(equals), FadeOut(one_line), ) - globals().update(locals()) self.play( LaggedStart(*( FadeTransform(pc.copy(), rhs) @@ -2082,12 +2083,12 @@ class SoftmaxBreakdown(InteractiveScene): vector.to_edge(UP).set_x(2.5) matrix.next_to(vector, LEFT) - self.play( + self.play(LaggedStart( chart_group.animate.scale(0.35).to_corner(DL), - FadeOut(sum_group), - FadeIn(matrix, lag_ratio=0.01), - FadeIn(vector, lag_ratio=0.01), - ) + FadeOut(sum_group, UP), + FadeIn(matrix, UP), + FadeIn(vector, UP), + )) eq, rhs = show_matrix_vector_product(self, matrix, vector, x_max=9) self.wait() @@ -2156,10 +2157,7 @@ class SoftmaxBreakdown(InteractiveScene): rhs.target.to_edge(LEFT, buff=1.5) rhs.target.set_y(0) - softmax_box = Rectangle( - width=5, - height=rhs.get_height() + 1, - ) + softmax_box = Rectangle(width=5, height=6.5) softmax_box.set_stroke(BLUE, 2) softmax_box.set_fill(BLUE_E, 0.5) in_arrow, out_arrow = Vector(RIGHT).replicate(2) @@ -2188,22 +2186,250 @@ class SoftmaxBreakdown(InteractiveScene): FadeOut(vector, 3 * LEFT), FadeOut(eq, 3.5 * LEFT), FadeOut(chart_group, DL), - TransformFromCopy(chart.bars, bars), GrowArrow(in_arrow), FadeIn(softmax_box, RIGHT), FadeIn(softmax_label, RIGHT), MoveToTarget(rhs), GrowArrow(out_arrow), FadeIn(output, RIGHT), + TransformFromCopy(chart.bars, bars), ), lag_ratio=0.2, run_time=2) self.wait() # Highlight larger and smaller parts + rhs_entries = rhs.get_entries() + changer = VGroup(rhs_entries, output.get_entries(), bars) + changer.save_state() + for index in range(4, 0, -1): + changer.target = changer.saved_state.copy() + changer.target.set_fill(border_width=0) + for group in changer.target: + for j, elem in enumerate(group): + if j != index: + elem.fade(0.8) + self.play(MoveToTarget(changer)) + self.wait() + self.play(Restore(changer)) + self.remove(changer) + self.add(rhs, output, bars) + self.wait() + + # Swap out for variables + variables = VGroup(*( + Tex(f"x_{{{n}}}", font_size=48).move_to(elem) + for n, elem in enumerate(rhs_entries, start=1) + )) + + self.remove(rhs_entries) + self.play( + LaggedStart(*( + TransformFromCopy(entry, variable, path_arc=PI / 2) + for entry, variable in zip(rhs_entries, variables) + ), lag_ratio=0.1, run_time=1.0) + ) + self.wait() # Exponentiate each part + exp_parts = VGroup(*( + Tex(f"e^{{{var.get_tex()}}}", font_size=48).move_to(var) + for var in variables + )) + exp_parts.align_to(softmax_box, LEFT) + exp_parts.shift(0.75 * RIGHT) + exp_parts.space_out_submobjects(1.5) + + self.play( + softmax_label.animate.next_to(softmax_box, UP, buff=0.15), + LaggedStart(*( + TransformMatchingStrings(var.copy(), exp_part) + for var, exp_part in zip(variables, exp_parts) + ), run_time=1, lag_ratio=0.01) + ) + self.wait() # Compute the sum + exp_sum = Tex(R"\sum_{n=0}^{N-1} e^{x_{n}}", font_size=42) + exp_sum[R"e^{x_{n}}"].scale(1.5, about_edge=LEFT) + exp_sum.next_to(softmax_box.get_right(), LEFT, buff=0.75) + + globals().update(locals()) + lines = VGroup(*(Line(exp_part.get_right(), exp_sum.get_left(), buff=0.1) for exp_part in exp_parts)) + lines.set_stroke(TEAL, 2) + + self.play( + LaggedStart(*( + FadeTransform(exp_part.copy(), exp_sum) + for exp_part in exp_parts + ), lag_ratio=0.01), + LaggedStartMap(ShowCreation, lines, lag_ratio=0.01), + run_time=1 + ) + self.wait() + self.play(FadeOut(lines)) # Divide each part by the sum + lil_denoms = VGroup() + for exp_part in exp_parts: + slash = Tex("/").match_height(exp_sum) + slash.next_to(exp_sum, LEFT, buff=0) + denom = VGroup(slash, exp_sum).copy() + denom.set_height(exp_part.get_height() * 1.5) + denom.next_to(exp_part, RIGHT, buff=0) + lil_denoms.add(denom) + lil_denoms.align_to(softmax_box.get_center(), LEFT) + + lines = VGroup(*(Line(exp_sum.get_left(), denom.get_center()) for denom in lil_denoms)) + lines.set_stroke(TEAL, 1) + + self.remove(exp_sum) + self.play( + exp_parts.animate.next_to(lil_denoms, LEFT, buff=0), + LaggedStart(*( + FadeTransform(exp_sum.copy(), denom) + for denom in lil_denoms + ), lag_ratio=0.01), + ) + self.wait() + + # Resize box + sm_terms = VGroup(*( + VGroup(exp_part, denom) + for exp_part, denom in zip(exp_parts, lil_denoms) + )) + sm_terms.generate_target() + + target_height = 5.0 + full_output = Group(output, bars) + full_output.generate_target() + full_output.target.set_height(target_height, about_edge=RIGHT) + full_output.target.shift(1.5 * LEFT) + equals = Tex("=") + equals.next_to(full_output.target, LEFT) + + softmax_box.generate_target() + softmax_box.target.set_width(3.0, stretch=True) + VGroup(softmax_box.target, sm_terms.target).set_height(target_height + 0.5).next_to(equals, LEFT) + + rhs.generate_target() + rhs_entries.become(variables) + self.remove(variables) + rhs.target.set_height(target_height) + rhs.target.next_to(softmax_box.target, LEFT, buff=1.5) + + self.play( + softmax_label.animate.next_to(softmax_box.target, UP), + MoveToTarget(softmax_box), + MoveToTarget(sm_terms), + MoveToTarget(full_output), + MoveToTarget(rhs), + FadeTransform(out_arrow, equals), + in_arrow.animate.become( + Arrow(rhs.target, softmax_box.target).match_style(in_arrow) + ), + ) + self.wait() + + # Set up updaters + output_entries = output.get_entries() + bar_width_ratio = bars.get_width() / max(o.get_value() for o in output_entries) + temp_tracker = ValueTracker(1) + + def update_outs(output_entries): + inputs = [entry.get_value() for entry in rhs_entries] + outputs = softmax(inputs, temp_tracker.get_value()) + for entry, output in zip(output_entries, outputs): + entry.set_value(output) + + def update_bars(bars): + for bar, entry in zip(bars, output_entries): + width = max(bar_width_ratio * entry.get_value(), 1e-3) + bar.set_width(width, about_edge=LEFT, stretch=True) + + output_entries.add_updater(update_outs) + bars.add_updater(update_bars) + + self.add(bars, output_entries) + + # Tweak values + for index, value in [(6, 4.0), (4, 4.2), (1, 0.0), (0, 6.0), (4, 9.9)]: + entry = rhs_entries[index] + rect = SurroundingRectangle(entry) + rect.set_stroke(BLUE if value > entry.get_value() else RED, 3) + self.play( + ChangeDecimalToValue(entry, value), + FadeIn(rect, time_span=(0, 1)), + run_time=2 + ) + self.play(FadeOut(rect)) + + # Add temperature + frame = self.frame + temp_color = RED + new_title = Text("softmax with temperature") + new_title["temperature"].set_color(temp_color) + get_t = temp_tracker.get_value + t_line = NumberLine( + (0, 10, 0.2), + tick_size=0.025, + big_tick_spacing=1, + longer_tick_multiple=2.0, + width=4 + ) + t_line.set_stroke(width=1.5) + t_line.next_to(softmax_box, UP) + t_tri = ArrowTip(angle=-90 * DEGREES) + t_tri.set_color(temp_color) + t_tri.set_height(0.2) + t_label = Tex("T = 0.00", font_size=36) + t_label.rhs = t_label.make_number_changable("0.00") + t_label["T"].set_color(temp_color) + globals().update(locals()) + t_tri.add_updater(lambda m: m.move_to(t_line.n2p(get_t()), DOWN)) + t_label.add_updater(lambda m: m.rhs.set_value(get_t())) + t_label.add_updater(lambda m: m.next_to(t_tri, UP, buff=0.1, aligned_edge=LEFT)) + + new_title.next_to(t_label, UP, buff=0.5).match_x(softmax_box) + + self.play( + frame.animate.move_to(0.75 * UP), + TransformMatchingStrings(softmax_label, new_title), + FadeIn(t_line), + FadeIn(t_tri), + FadeIn(t_label), + run_time=1 + ) + + # Change formula + template = Tex(R"e^{x_{0} / T} / \sum_{n=0}^{N - 1} e^{x_n / T}") + template["T"].set_color(temp_color) + template["/"][1].scale(1.9, about_edge=LEFT) + template[R"\sum_{n=0}^{N - 1}"][0].scale(0.7, about_edge=RIGHT) + index_part = template.make_number_changable("0") + + new_sm_terms = VGroup() + all_Ts = VGroup() + for n, term in enumerate(sm_terms, start=1): + template.replace(term, dim_to_match=1) + index_part.set_value(n) + new_term = template.copy() + all_Ts.add(*new_term["T"]) + new_sm_terms.add(new_term) + + self.play( + LaggedStart(*( + FadeTransform(old_term, new_term) + for old_term, new_term in zip(sm_terms, new_sm_terms) + )), + LaggedStart(*( + TransformFromCopy(t_label[0], t_mob[0]) + for t_mob in all_Ts + )), + ) + self.wait() + + # Oscilate between values + self.play(temp_tracker.animate.set_value(3), run_time=3) + self.wait() + self.play(temp_tracker.animate.set_value(10), run_time=3) + self.wait() - # Comment on largest values diff --git a/_2024/transformers/objects.py b/_2024/transformers/objects.py index 3ac63ff..2c05131 100644 --- a/_2024/transformers/objects.py +++ b/_2024/transformers/objects.py @@ -84,9 +84,10 @@ def matrix_row_vector_product(scene, row, vector, entry, to_fade): ShowIncreasingSubsets(row_rects), ShowIncreasingSubsets(vect_rects), UpdateFromAlphaFunc(entry, lambda m, a: m.set_value( - partial_values[min(int(a * n_values), n_values - 1)] + partial_values[min(int(np.round(a * n_values)), n_values - 1)] )), FadeOut(to_fade), + rate_func=linear, ) return VGroup(row_rects, vect_rects)