Add Softmax scene

This commit is contained in:
Grant Sanderson 2024-02-22 11:46:00 -08:00
parent ba0743e330
commit 20b3a1b4aa
2 changed files with 244 additions and 17 deletions

View file

@ -1965,10 +1965,11 @@ class SoftmaxBreakdown(InteractiveScene):
def construct(self):
# Show example probability distribution
word_strs = ['Dumbledore', 'Flitwick', 'Mcgonagall', 'Quirrell', 'Snape', 'Sprout', 'Trelawney']
words = VGroup(*map(Text, word_strs))
values = np.array([0.3, -1, 0.5, 1.5, 3.4, -1, 2.5])
words = VGroup(*(Text(word_str, font_size=30) for word_str in word_strs))
values = np.array([-0.8, -5.0, 0.5, 1.5, 3.4, -2.3, 2.5])
prob_values = softmax(values)
chart = BarChart(prob_values, width=10)
chart.bars.set_stroke(width=1)
probs = VGroup(*(DecimalNumber(pv) for pv in prob_values))
probs.arrange(DOWN, buff=0.25)
@ -2002,9 +2003,10 @@ class SoftmaxBreakdown(InteractiveScene):
self.wait()
# Show constraint between 0 and 1
bar = chart.bars[0]
index = 3
bar = chart.bars[index]
bar.save_state()
prob = probs[0]
prob = probs[index]
prob.bar = bar
max_height = chart.y_axis.get_y(UP) - chart.x_axis.get_y()
prob.add_updater(lambda p: p.set_value(p.bar.get_height() / max_height))
@ -2048,7 +2050,6 @@ class SoftmaxBreakdown(InteractiveScene):
Write(equals),
FadeOut(one_line),
)
globals().update(locals())
self.play(
LaggedStart(*(
FadeTransform(pc.copy(), rhs)
@ -2082,12 +2083,12 @@ class SoftmaxBreakdown(InteractiveScene):
vector.to_edge(UP).set_x(2.5)
matrix.next_to(vector, LEFT)
self.play(
self.play(LaggedStart(
chart_group.animate.scale(0.35).to_corner(DL),
FadeOut(sum_group),
FadeIn(matrix, lag_ratio=0.01),
FadeIn(vector, lag_ratio=0.01),
)
FadeOut(sum_group, UP),
FadeIn(matrix, UP),
FadeIn(vector, UP),
))
eq, rhs = show_matrix_vector_product(self, matrix, vector, x_max=9)
self.wait()
@ -2156,10 +2157,7 @@ class SoftmaxBreakdown(InteractiveScene):
rhs.target.to_edge(LEFT, buff=1.5)
rhs.target.set_y(0)
softmax_box = Rectangle(
width=5,
height=rhs.get_height() + 1,
)
softmax_box = Rectangle(width=5, height=6.5)
softmax_box.set_stroke(BLUE, 2)
softmax_box.set_fill(BLUE_E, 0.5)
in_arrow, out_arrow = Vector(RIGHT).replicate(2)
@ -2188,22 +2186,250 @@ class SoftmaxBreakdown(InteractiveScene):
FadeOut(vector, 3 * LEFT),
FadeOut(eq, 3.5 * LEFT),
FadeOut(chart_group, DL),
TransformFromCopy(chart.bars, bars),
GrowArrow(in_arrow),
FadeIn(softmax_box, RIGHT),
FadeIn(softmax_label, RIGHT),
MoveToTarget(rhs),
GrowArrow(out_arrow),
FadeIn(output, RIGHT),
TransformFromCopy(chart.bars, bars),
), lag_ratio=0.2, run_time=2)
self.wait()
# Highlight larger and smaller parts
rhs_entries = rhs.get_entries()
changer = VGroup(rhs_entries, output.get_entries(), bars)
changer.save_state()
for index in range(4, 0, -1):
changer.target = changer.saved_state.copy()
changer.target.set_fill(border_width=0)
for group in changer.target:
for j, elem in enumerate(group):
if j != index:
elem.fade(0.8)
self.play(MoveToTarget(changer))
self.wait()
self.play(Restore(changer))
self.remove(changer)
self.add(rhs, output, bars)
self.wait()
# Swap out for variables
variables = VGroup(*(
Tex(f"x_{{{n}}}", font_size=48).move_to(elem)
for n, elem in enumerate(rhs_entries, start=1)
))
self.remove(rhs_entries)
self.play(
LaggedStart(*(
TransformFromCopy(entry, variable, path_arc=PI / 2)
for entry, variable in zip(rhs_entries, variables)
), lag_ratio=0.1, run_time=1.0)
)
self.wait()
# Exponentiate each part
exp_parts = VGroup(*(
Tex(f"e^{{{var.get_tex()}}}", font_size=48).move_to(var)
for var in variables
))
exp_parts.align_to(softmax_box, LEFT)
exp_parts.shift(0.75 * RIGHT)
exp_parts.space_out_submobjects(1.5)
self.play(
softmax_label.animate.next_to(softmax_box, UP, buff=0.15),
LaggedStart(*(
TransformMatchingStrings(var.copy(), exp_part)
for var, exp_part in zip(variables, exp_parts)
), run_time=1, lag_ratio=0.01)
)
self.wait()
# Compute the sum
exp_sum = Tex(R"\sum_{n=0}^{N-1} e^{x_{n}}", font_size=42)
exp_sum[R"e^{x_{n}}"].scale(1.5, about_edge=LEFT)
exp_sum.next_to(softmax_box.get_right(), LEFT, buff=0.75)
globals().update(locals())
lines = VGroup(*(Line(exp_part.get_right(), exp_sum.get_left(), buff=0.1) for exp_part in exp_parts))
lines.set_stroke(TEAL, 2)
self.play(
LaggedStart(*(
FadeTransform(exp_part.copy(), exp_sum)
for exp_part in exp_parts
), lag_ratio=0.01),
LaggedStartMap(ShowCreation, lines, lag_ratio=0.01),
run_time=1
)
self.wait()
self.play(FadeOut(lines))
# Divide each part by the sum
lil_denoms = VGroup()
for exp_part in exp_parts:
slash = Tex("/").match_height(exp_sum)
slash.next_to(exp_sum, LEFT, buff=0)
denom = VGroup(slash, exp_sum).copy()
denom.set_height(exp_part.get_height() * 1.5)
denom.next_to(exp_part, RIGHT, buff=0)
lil_denoms.add(denom)
lil_denoms.align_to(softmax_box.get_center(), LEFT)
lines = VGroup(*(Line(exp_sum.get_left(), denom.get_center()) for denom in lil_denoms))
lines.set_stroke(TEAL, 1)
self.remove(exp_sum)
self.play(
exp_parts.animate.next_to(lil_denoms, LEFT, buff=0),
LaggedStart(*(
FadeTransform(exp_sum.copy(), denom)
for denom in lil_denoms
), lag_ratio=0.01),
)
self.wait()
# Resize box
sm_terms = VGroup(*(
VGroup(exp_part, denom)
for exp_part, denom in zip(exp_parts, lil_denoms)
))
sm_terms.generate_target()
target_height = 5.0
full_output = Group(output, bars)
full_output.generate_target()
full_output.target.set_height(target_height, about_edge=RIGHT)
full_output.target.shift(1.5 * LEFT)
equals = Tex("=")
equals.next_to(full_output.target, LEFT)
softmax_box.generate_target()
softmax_box.target.set_width(3.0, stretch=True)
VGroup(softmax_box.target, sm_terms.target).set_height(target_height + 0.5).next_to(equals, LEFT)
rhs.generate_target()
rhs_entries.become(variables)
self.remove(variables)
rhs.target.set_height(target_height)
rhs.target.next_to(softmax_box.target, LEFT, buff=1.5)
self.play(
softmax_label.animate.next_to(softmax_box.target, UP),
MoveToTarget(softmax_box),
MoveToTarget(sm_terms),
MoveToTarget(full_output),
MoveToTarget(rhs),
FadeTransform(out_arrow, equals),
in_arrow.animate.become(
Arrow(rhs.target, softmax_box.target).match_style(in_arrow)
),
)
self.wait()
# Set up updaters
output_entries = output.get_entries()
bar_width_ratio = bars.get_width() / max(o.get_value() for o in output_entries)
temp_tracker = ValueTracker(1)
def update_outs(output_entries):
inputs = [entry.get_value() for entry in rhs_entries]
outputs = softmax(inputs, temp_tracker.get_value())
for entry, output in zip(output_entries, outputs):
entry.set_value(output)
def update_bars(bars):
for bar, entry in zip(bars, output_entries):
width = max(bar_width_ratio * entry.get_value(), 1e-3)
bar.set_width(width, about_edge=LEFT, stretch=True)
output_entries.add_updater(update_outs)
bars.add_updater(update_bars)
self.add(bars, output_entries)
# Tweak values
for index, value in [(6, 4.0), (4, 4.2), (1, 0.0), (0, 6.0), (4, 9.9)]:
entry = rhs_entries[index]
rect = SurroundingRectangle(entry)
rect.set_stroke(BLUE if value > entry.get_value() else RED, 3)
self.play(
ChangeDecimalToValue(entry, value),
FadeIn(rect, time_span=(0, 1)),
run_time=2
)
self.play(FadeOut(rect))
# Add temperature
frame = self.frame
temp_color = RED
new_title = Text("softmax with temperature")
new_title["temperature"].set_color(temp_color)
get_t = temp_tracker.get_value
t_line = NumberLine(
(0, 10, 0.2),
tick_size=0.025,
big_tick_spacing=1,
longer_tick_multiple=2.0,
width=4
)
t_line.set_stroke(width=1.5)
t_line.next_to(softmax_box, UP)
t_tri = ArrowTip(angle=-90 * DEGREES)
t_tri.set_color(temp_color)
t_tri.set_height(0.2)
t_label = Tex("T = 0.00", font_size=36)
t_label.rhs = t_label.make_number_changable("0.00")
t_label["T"].set_color(temp_color)
globals().update(locals())
t_tri.add_updater(lambda m: m.move_to(t_line.n2p(get_t()), DOWN))
t_label.add_updater(lambda m: m.rhs.set_value(get_t()))
t_label.add_updater(lambda m: m.next_to(t_tri, UP, buff=0.1, aligned_edge=LEFT))
new_title.next_to(t_label, UP, buff=0.5).match_x(softmax_box)
self.play(
frame.animate.move_to(0.75 * UP),
TransformMatchingStrings(softmax_label, new_title),
FadeIn(t_line),
FadeIn(t_tri),
FadeIn(t_label),
run_time=1
)
# Change formula
template = Tex(R"e^{x_{0} / T} / \sum_{n=0}^{N - 1} e^{x_n / T}")
template["T"].set_color(temp_color)
template["/"][1].scale(1.9, about_edge=LEFT)
template[R"\sum_{n=0}^{N - 1}"][0].scale(0.7, about_edge=RIGHT)
index_part = template.make_number_changable("0")
new_sm_terms = VGroup()
all_Ts = VGroup()
for n, term in enumerate(sm_terms, start=1):
template.replace(term, dim_to_match=1)
index_part.set_value(n)
new_term = template.copy()
all_Ts.add(*new_term["T"])
new_sm_terms.add(new_term)
self.play(
LaggedStart(*(
FadeTransform(old_term, new_term)
for old_term, new_term in zip(sm_terms, new_sm_terms)
)),
LaggedStart(*(
TransformFromCopy(t_label[0], t_mob[0])
for t_mob in all_Ts
)),
)
self.wait()
# Oscilate between values
self.play(temp_tracker.animate.set_value(3), run_time=3)
self.wait()
self.play(temp_tracker.animate.set_value(10), run_time=3)
self.wait()
# Comment on largest values

View file

@ -84,9 +84,10 @@ def matrix_row_vector_product(scene, row, vector, entry, to_fade):
ShowIncreasingSubsets(row_rects),
ShowIncreasingSubsets(vect_rects),
UpdateFromAlphaFunc(entry, lambda m, a: m.set_value(
partial_values[min(int(a * n_values), n_values - 1)]
partial_values[min(int(np.round(a * n_values)), n_values - 1)]
)),
FadeOut(to_fade),
rate_func=linear,
)
return VGroup(row_rects, vect_rects)