mirror of
https://github.com/3b1b/videos.git
synced 2025-08-31 21:58:59 +00:00
Add Softmax scene
This commit is contained in:
parent
ba0743e330
commit
20b3a1b4aa
2 changed files with 244 additions and 17 deletions
|
@ -1965,10 +1965,11 @@ class SoftmaxBreakdown(InteractiveScene):
|
|||
def construct(self):
|
||||
# Show example probability distribution
|
||||
word_strs = ['Dumbledore', 'Flitwick', 'Mcgonagall', 'Quirrell', 'Snape', 'Sprout', 'Trelawney']
|
||||
words = VGroup(*map(Text, word_strs))
|
||||
values = np.array([0.3, -1, 0.5, 1.5, 3.4, -1, 2.5])
|
||||
words = VGroup(*(Text(word_str, font_size=30) for word_str in word_strs))
|
||||
values = np.array([-0.8, -5.0, 0.5, 1.5, 3.4, -2.3, 2.5])
|
||||
prob_values = softmax(values)
|
||||
chart = BarChart(prob_values, width=10)
|
||||
chart.bars.set_stroke(width=1)
|
||||
|
||||
probs = VGroup(*(DecimalNumber(pv) for pv in prob_values))
|
||||
probs.arrange(DOWN, buff=0.25)
|
||||
|
@ -2002,9 +2003,10 @@ class SoftmaxBreakdown(InteractiveScene):
|
|||
self.wait()
|
||||
|
||||
# Show constraint between 0 and 1
|
||||
bar = chart.bars[0]
|
||||
index = 3
|
||||
bar = chart.bars[index]
|
||||
bar.save_state()
|
||||
prob = probs[0]
|
||||
prob = probs[index]
|
||||
prob.bar = bar
|
||||
max_height = chart.y_axis.get_y(UP) - chart.x_axis.get_y()
|
||||
prob.add_updater(lambda p: p.set_value(p.bar.get_height() / max_height))
|
||||
|
@ -2048,7 +2050,6 @@ class SoftmaxBreakdown(InteractiveScene):
|
|||
Write(equals),
|
||||
FadeOut(one_line),
|
||||
)
|
||||
globals().update(locals())
|
||||
self.play(
|
||||
LaggedStart(*(
|
||||
FadeTransform(pc.copy(), rhs)
|
||||
|
@ -2082,12 +2083,12 @@ class SoftmaxBreakdown(InteractiveScene):
|
|||
vector.to_edge(UP).set_x(2.5)
|
||||
matrix.next_to(vector, LEFT)
|
||||
|
||||
self.play(
|
||||
self.play(LaggedStart(
|
||||
chart_group.animate.scale(0.35).to_corner(DL),
|
||||
FadeOut(sum_group),
|
||||
FadeIn(matrix, lag_ratio=0.01),
|
||||
FadeIn(vector, lag_ratio=0.01),
|
||||
)
|
||||
FadeOut(sum_group, UP),
|
||||
FadeIn(matrix, UP),
|
||||
FadeIn(vector, UP),
|
||||
))
|
||||
eq, rhs = show_matrix_vector_product(self, matrix, vector, x_max=9)
|
||||
self.wait()
|
||||
|
||||
|
@ -2156,10 +2157,7 @@ class SoftmaxBreakdown(InteractiveScene):
|
|||
rhs.target.to_edge(LEFT, buff=1.5)
|
||||
rhs.target.set_y(0)
|
||||
|
||||
softmax_box = Rectangle(
|
||||
width=5,
|
||||
height=rhs.get_height() + 1,
|
||||
)
|
||||
softmax_box = Rectangle(width=5, height=6.5)
|
||||
softmax_box.set_stroke(BLUE, 2)
|
||||
softmax_box.set_fill(BLUE_E, 0.5)
|
||||
in_arrow, out_arrow = Vector(RIGHT).replicate(2)
|
||||
|
@ -2188,22 +2186,250 @@ class SoftmaxBreakdown(InteractiveScene):
|
|||
FadeOut(vector, 3 * LEFT),
|
||||
FadeOut(eq, 3.5 * LEFT),
|
||||
FadeOut(chart_group, DL),
|
||||
TransformFromCopy(chart.bars, bars),
|
||||
GrowArrow(in_arrow),
|
||||
FadeIn(softmax_box, RIGHT),
|
||||
FadeIn(softmax_label, RIGHT),
|
||||
MoveToTarget(rhs),
|
||||
GrowArrow(out_arrow),
|
||||
FadeIn(output, RIGHT),
|
||||
TransformFromCopy(chart.bars, bars),
|
||||
), lag_ratio=0.2, run_time=2)
|
||||
self.wait()
|
||||
|
||||
# Highlight larger and smaller parts
|
||||
rhs_entries = rhs.get_entries()
|
||||
changer = VGroup(rhs_entries, output.get_entries(), bars)
|
||||
changer.save_state()
|
||||
for index in range(4, 0, -1):
|
||||
changer.target = changer.saved_state.copy()
|
||||
changer.target.set_fill(border_width=0)
|
||||
for group in changer.target:
|
||||
for j, elem in enumerate(group):
|
||||
if j != index:
|
||||
elem.fade(0.8)
|
||||
self.play(MoveToTarget(changer))
|
||||
self.wait()
|
||||
self.play(Restore(changer))
|
||||
self.remove(changer)
|
||||
self.add(rhs, output, bars)
|
||||
self.wait()
|
||||
|
||||
# Swap out for variables
|
||||
variables = VGroup(*(
|
||||
Tex(f"x_{{{n}}}", font_size=48).move_to(elem)
|
||||
for n, elem in enumerate(rhs_entries, start=1)
|
||||
))
|
||||
|
||||
self.remove(rhs_entries)
|
||||
self.play(
|
||||
LaggedStart(*(
|
||||
TransformFromCopy(entry, variable, path_arc=PI / 2)
|
||||
for entry, variable in zip(rhs_entries, variables)
|
||||
), lag_ratio=0.1, run_time=1.0)
|
||||
)
|
||||
self.wait()
|
||||
|
||||
# Exponentiate each part
|
||||
exp_parts = VGroup(*(
|
||||
Tex(f"e^{{{var.get_tex()}}}", font_size=48).move_to(var)
|
||||
for var in variables
|
||||
))
|
||||
exp_parts.align_to(softmax_box, LEFT)
|
||||
exp_parts.shift(0.75 * RIGHT)
|
||||
exp_parts.space_out_submobjects(1.5)
|
||||
|
||||
self.play(
|
||||
softmax_label.animate.next_to(softmax_box, UP, buff=0.15),
|
||||
LaggedStart(*(
|
||||
TransformMatchingStrings(var.copy(), exp_part)
|
||||
for var, exp_part in zip(variables, exp_parts)
|
||||
), run_time=1, lag_ratio=0.01)
|
||||
)
|
||||
self.wait()
|
||||
|
||||
# Compute the sum
|
||||
exp_sum = Tex(R"\sum_{n=0}^{N-1} e^{x_{n}}", font_size=42)
|
||||
exp_sum[R"e^{x_{n}}"].scale(1.5, about_edge=LEFT)
|
||||
exp_sum.next_to(softmax_box.get_right(), LEFT, buff=0.75)
|
||||
|
||||
globals().update(locals())
|
||||
lines = VGroup(*(Line(exp_part.get_right(), exp_sum.get_left(), buff=0.1) for exp_part in exp_parts))
|
||||
lines.set_stroke(TEAL, 2)
|
||||
|
||||
self.play(
|
||||
LaggedStart(*(
|
||||
FadeTransform(exp_part.copy(), exp_sum)
|
||||
for exp_part in exp_parts
|
||||
), lag_ratio=0.01),
|
||||
LaggedStartMap(ShowCreation, lines, lag_ratio=0.01),
|
||||
run_time=1
|
||||
)
|
||||
self.wait()
|
||||
self.play(FadeOut(lines))
|
||||
|
||||
# Divide each part by the sum
|
||||
lil_denoms = VGroup()
|
||||
for exp_part in exp_parts:
|
||||
slash = Tex("/").match_height(exp_sum)
|
||||
slash.next_to(exp_sum, LEFT, buff=0)
|
||||
denom = VGroup(slash, exp_sum).copy()
|
||||
denom.set_height(exp_part.get_height() * 1.5)
|
||||
denom.next_to(exp_part, RIGHT, buff=0)
|
||||
lil_denoms.add(denom)
|
||||
lil_denoms.align_to(softmax_box.get_center(), LEFT)
|
||||
|
||||
lines = VGroup(*(Line(exp_sum.get_left(), denom.get_center()) for denom in lil_denoms))
|
||||
lines.set_stroke(TEAL, 1)
|
||||
|
||||
self.remove(exp_sum)
|
||||
self.play(
|
||||
exp_parts.animate.next_to(lil_denoms, LEFT, buff=0),
|
||||
LaggedStart(*(
|
||||
FadeTransform(exp_sum.copy(), denom)
|
||||
for denom in lil_denoms
|
||||
), lag_ratio=0.01),
|
||||
)
|
||||
self.wait()
|
||||
|
||||
# Resize box
|
||||
sm_terms = VGroup(*(
|
||||
VGroup(exp_part, denom)
|
||||
for exp_part, denom in zip(exp_parts, lil_denoms)
|
||||
))
|
||||
sm_terms.generate_target()
|
||||
|
||||
target_height = 5.0
|
||||
full_output = Group(output, bars)
|
||||
full_output.generate_target()
|
||||
full_output.target.set_height(target_height, about_edge=RIGHT)
|
||||
full_output.target.shift(1.5 * LEFT)
|
||||
equals = Tex("=")
|
||||
equals.next_to(full_output.target, LEFT)
|
||||
|
||||
softmax_box.generate_target()
|
||||
softmax_box.target.set_width(3.0, stretch=True)
|
||||
VGroup(softmax_box.target, sm_terms.target).set_height(target_height + 0.5).next_to(equals, LEFT)
|
||||
|
||||
rhs.generate_target()
|
||||
rhs_entries.become(variables)
|
||||
self.remove(variables)
|
||||
rhs.target.set_height(target_height)
|
||||
rhs.target.next_to(softmax_box.target, LEFT, buff=1.5)
|
||||
|
||||
self.play(
|
||||
softmax_label.animate.next_to(softmax_box.target, UP),
|
||||
MoveToTarget(softmax_box),
|
||||
MoveToTarget(sm_terms),
|
||||
MoveToTarget(full_output),
|
||||
MoveToTarget(rhs),
|
||||
FadeTransform(out_arrow, equals),
|
||||
in_arrow.animate.become(
|
||||
Arrow(rhs.target, softmax_box.target).match_style(in_arrow)
|
||||
),
|
||||
)
|
||||
self.wait()
|
||||
|
||||
# Set up updaters
|
||||
output_entries = output.get_entries()
|
||||
bar_width_ratio = bars.get_width() / max(o.get_value() for o in output_entries)
|
||||
temp_tracker = ValueTracker(1)
|
||||
|
||||
def update_outs(output_entries):
|
||||
inputs = [entry.get_value() for entry in rhs_entries]
|
||||
outputs = softmax(inputs, temp_tracker.get_value())
|
||||
for entry, output in zip(output_entries, outputs):
|
||||
entry.set_value(output)
|
||||
|
||||
def update_bars(bars):
|
||||
for bar, entry in zip(bars, output_entries):
|
||||
width = max(bar_width_ratio * entry.get_value(), 1e-3)
|
||||
bar.set_width(width, about_edge=LEFT, stretch=True)
|
||||
|
||||
output_entries.add_updater(update_outs)
|
||||
bars.add_updater(update_bars)
|
||||
|
||||
self.add(bars, output_entries)
|
||||
|
||||
# Tweak values
|
||||
for index, value in [(6, 4.0), (4, 4.2), (1, 0.0), (0, 6.0), (4, 9.9)]:
|
||||
entry = rhs_entries[index]
|
||||
rect = SurroundingRectangle(entry)
|
||||
rect.set_stroke(BLUE if value > entry.get_value() else RED, 3)
|
||||
self.play(
|
||||
ChangeDecimalToValue(entry, value),
|
||||
FadeIn(rect, time_span=(0, 1)),
|
||||
run_time=2
|
||||
)
|
||||
self.play(FadeOut(rect))
|
||||
|
||||
# Add temperature
|
||||
frame = self.frame
|
||||
temp_color = RED
|
||||
new_title = Text("softmax with temperature")
|
||||
new_title["temperature"].set_color(temp_color)
|
||||
get_t = temp_tracker.get_value
|
||||
t_line = NumberLine(
|
||||
(0, 10, 0.2),
|
||||
tick_size=0.025,
|
||||
big_tick_spacing=1,
|
||||
longer_tick_multiple=2.0,
|
||||
width=4
|
||||
)
|
||||
t_line.set_stroke(width=1.5)
|
||||
t_line.next_to(softmax_box, UP)
|
||||
t_tri = ArrowTip(angle=-90 * DEGREES)
|
||||
t_tri.set_color(temp_color)
|
||||
t_tri.set_height(0.2)
|
||||
t_label = Tex("T = 0.00", font_size=36)
|
||||
t_label.rhs = t_label.make_number_changable("0.00")
|
||||
t_label["T"].set_color(temp_color)
|
||||
globals().update(locals())
|
||||
t_tri.add_updater(lambda m: m.move_to(t_line.n2p(get_t()), DOWN))
|
||||
t_label.add_updater(lambda m: m.rhs.set_value(get_t()))
|
||||
t_label.add_updater(lambda m: m.next_to(t_tri, UP, buff=0.1, aligned_edge=LEFT))
|
||||
|
||||
new_title.next_to(t_label, UP, buff=0.5).match_x(softmax_box)
|
||||
|
||||
self.play(
|
||||
frame.animate.move_to(0.75 * UP),
|
||||
TransformMatchingStrings(softmax_label, new_title),
|
||||
FadeIn(t_line),
|
||||
FadeIn(t_tri),
|
||||
FadeIn(t_label),
|
||||
run_time=1
|
||||
)
|
||||
|
||||
# Change formula
|
||||
template = Tex(R"e^{x_{0} / T} / \sum_{n=0}^{N - 1} e^{x_n / T}")
|
||||
template["T"].set_color(temp_color)
|
||||
template["/"][1].scale(1.9, about_edge=LEFT)
|
||||
template[R"\sum_{n=0}^{N - 1}"][0].scale(0.7, about_edge=RIGHT)
|
||||
index_part = template.make_number_changable("0")
|
||||
|
||||
new_sm_terms = VGroup()
|
||||
all_Ts = VGroup()
|
||||
for n, term in enumerate(sm_terms, start=1):
|
||||
template.replace(term, dim_to_match=1)
|
||||
index_part.set_value(n)
|
||||
new_term = template.copy()
|
||||
all_Ts.add(*new_term["T"])
|
||||
new_sm_terms.add(new_term)
|
||||
|
||||
self.play(
|
||||
LaggedStart(*(
|
||||
FadeTransform(old_term, new_term)
|
||||
for old_term, new_term in zip(sm_terms, new_sm_terms)
|
||||
)),
|
||||
LaggedStart(*(
|
||||
TransformFromCopy(t_label[0], t_mob[0])
|
||||
for t_mob in all_Ts
|
||||
)),
|
||||
)
|
||||
self.wait()
|
||||
|
||||
# Oscilate between values
|
||||
self.play(temp_tracker.animate.set_value(3), run_time=3)
|
||||
self.wait()
|
||||
self.play(temp_tracker.animate.set_value(10), run_time=3)
|
||||
self.wait()
|
||||
|
||||
# Comment on largest values
|
||||
|
|
|
@ -84,9 +84,10 @@ def matrix_row_vector_product(scene, row, vector, entry, to_fade):
|
|||
ShowIncreasingSubsets(row_rects),
|
||||
ShowIncreasingSubsets(vect_rects),
|
||||
UpdateFromAlphaFunc(entry, lambda m, a: m.set_value(
|
||||
partial_values[min(int(a * n_values), n_values - 1)]
|
||||
partial_values[min(int(np.round(a * n_values)), n_values - 1)]
|
||||
)),
|
||||
FadeOut(to_fade),
|
||||
rate_func=linear,
|
||||
)
|
||||
|
||||
return VGroup(row_rects, vect_rects)
|
||||
|
|
Loading…
Add table
Reference in a new issue