mirror of
https://github.com/3b1b/videos.git
synced 2025-09-18 21:38:53 +00:00
Add Softmax scene
This commit is contained in:
parent
ba0743e330
commit
20b3a1b4aa
2 changed files with 244 additions and 17 deletions
|
@ -1965,10 +1965,11 @@ class SoftmaxBreakdown(InteractiveScene):
|
||||||
def construct(self):
|
def construct(self):
|
||||||
# Show example probability distribution
|
# Show example probability distribution
|
||||||
word_strs = ['Dumbledore', 'Flitwick', 'Mcgonagall', 'Quirrell', 'Snape', 'Sprout', 'Trelawney']
|
word_strs = ['Dumbledore', 'Flitwick', 'Mcgonagall', 'Quirrell', 'Snape', 'Sprout', 'Trelawney']
|
||||||
words = VGroup(*map(Text, word_strs))
|
words = VGroup(*(Text(word_str, font_size=30) for word_str in word_strs))
|
||||||
values = np.array([0.3, -1, 0.5, 1.5, 3.4, -1, 2.5])
|
values = np.array([-0.8, -5.0, 0.5, 1.5, 3.4, -2.3, 2.5])
|
||||||
prob_values = softmax(values)
|
prob_values = softmax(values)
|
||||||
chart = BarChart(prob_values, width=10)
|
chart = BarChart(prob_values, width=10)
|
||||||
|
chart.bars.set_stroke(width=1)
|
||||||
|
|
||||||
probs = VGroup(*(DecimalNumber(pv) for pv in prob_values))
|
probs = VGroup(*(DecimalNumber(pv) for pv in prob_values))
|
||||||
probs.arrange(DOWN, buff=0.25)
|
probs.arrange(DOWN, buff=0.25)
|
||||||
|
@ -2002,9 +2003,10 @@ class SoftmaxBreakdown(InteractiveScene):
|
||||||
self.wait()
|
self.wait()
|
||||||
|
|
||||||
# Show constraint between 0 and 1
|
# Show constraint between 0 and 1
|
||||||
bar = chart.bars[0]
|
index = 3
|
||||||
|
bar = chart.bars[index]
|
||||||
bar.save_state()
|
bar.save_state()
|
||||||
prob = probs[0]
|
prob = probs[index]
|
||||||
prob.bar = bar
|
prob.bar = bar
|
||||||
max_height = chart.y_axis.get_y(UP) - chart.x_axis.get_y()
|
max_height = chart.y_axis.get_y(UP) - chart.x_axis.get_y()
|
||||||
prob.add_updater(lambda p: p.set_value(p.bar.get_height() / max_height))
|
prob.add_updater(lambda p: p.set_value(p.bar.get_height() / max_height))
|
||||||
|
@ -2048,7 +2050,6 @@ class SoftmaxBreakdown(InteractiveScene):
|
||||||
Write(equals),
|
Write(equals),
|
||||||
FadeOut(one_line),
|
FadeOut(one_line),
|
||||||
)
|
)
|
||||||
globals().update(locals())
|
|
||||||
self.play(
|
self.play(
|
||||||
LaggedStart(*(
|
LaggedStart(*(
|
||||||
FadeTransform(pc.copy(), rhs)
|
FadeTransform(pc.copy(), rhs)
|
||||||
|
@ -2082,12 +2083,12 @@ class SoftmaxBreakdown(InteractiveScene):
|
||||||
vector.to_edge(UP).set_x(2.5)
|
vector.to_edge(UP).set_x(2.5)
|
||||||
matrix.next_to(vector, LEFT)
|
matrix.next_to(vector, LEFT)
|
||||||
|
|
||||||
self.play(
|
self.play(LaggedStart(
|
||||||
chart_group.animate.scale(0.35).to_corner(DL),
|
chart_group.animate.scale(0.35).to_corner(DL),
|
||||||
FadeOut(sum_group),
|
FadeOut(sum_group, UP),
|
||||||
FadeIn(matrix, lag_ratio=0.01),
|
FadeIn(matrix, UP),
|
||||||
FadeIn(vector, lag_ratio=0.01),
|
FadeIn(vector, UP),
|
||||||
)
|
))
|
||||||
eq, rhs = show_matrix_vector_product(self, matrix, vector, x_max=9)
|
eq, rhs = show_matrix_vector_product(self, matrix, vector, x_max=9)
|
||||||
self.wait()
|
self.wait()
|
||||||
|
|
||||||
|
@ -2156,10 +2157,7 @@ class SoftmaxBreakdown(InteractiveScene):
|
||||||
rhs.target.to_edge(LEFT, buff=1.5)
|
rhs.target.to_edge(LEFT, buff=1.5)
|
||||||
rhs.target.set_y(0)
|
rhs.target.set_y(0)
|
||||||
|
|
||||||
softmax_box = Rectangle(
|
softmax_box = Rectangle(width=5, height=6.5)
|
||||||
width=5,
|
|
||||||
height=rhs.get_height() + 1,
|
|
||||||
)
|
|
||||||
softmax_box.set_stroke(BLUE, 2)
|
softmax_box.set_stroke(BLUE, 2)
|
||||||
softmax_box.set_fill(BLUE_E, 0.5)
|
softmax_box.set_fill(BLUE_E, 0.5)
|
||||||
in_arrow, out_arrow = Vector(RIGHT).replicate(2)
|
in_arrow, out_arrow = Vector(RIGHT).replicate(2)
|
||||||
|
@ -2188,22 +2186,250 @@ class SoftmaxBreakdown(InteractiveScene):
|
||||||
FadeOut(vector, 3 * LEFT),
|
FadeOut(vector, 3 * LEFT),
|
||||||
FadeOut(eq, 3.5 * LEFT),
|
FadeOut(eq, 3.5 * LEFT),
|
||||||
FadeOut(chart_group, DL),
|
FadeOut(chart_group, DL),
|
||||||
TransformFromCopy(chart.bars, bars),
|
|
||||||
GrowArrow(in_arrow),
|
GrowArrow(in_arrow),
|
||||||
FadeIn(softmax_box, RIGHT),
|
FadeIn(softmax_box, RIGHT),
|
||||||
FadeIn(softmax_label, RIGHT),
|
FadeIn(softmax_label, RIGHT),
|
||||||
MoveToTarget(rhs),
|
MoveToTarget(rhs),
|
||||||
GrowArrow(out_arrow),
|
GrowArrow(out_arrow),
|
||||||
FadeIn(output, RIGHT),
|
FadeIn(output, RIGHT),
|
||||||
|
TransformFromCopy(chart.bars, bars),
|
||||||
), lag_ratio=0.2, run_time=2)
|
), lag_ratio=0.2, run_time=2)
|
||||||
self.wait()
|
self.wait()
|
||||||
|
|
||||||
# Highlight larger and smaller parts
|
# Highlight larger and smaller parts
|
||||||
|
rhs_entries = rhs.get_entries()
|
||||||
|
changer = VGroup(rhs_entries, output.get_entries(), bars)
|
||||||
|
changer.save_state()
|
||||||
|
for index in range(4, 0, -1):
|
||||||
|
changer.target = changer.saved_state.copy()
|
||||||
|
changer.target.set_fill(border_width=0)
|
||||||
|
for group in changer.target:
|
||||||
|
for j, elem in enumerate(group):
|
||||||
|
if j != index:
|
||||||
|
elem.fade(0.8)
|
||||||
|
self.play(MoveToTarget(changer))
|
||||||
|
self.wait()
|
||||||
|
self.play(Restore(changer))
|
||||||
|
self.remove(changer)
|
||||||
|
self.add(rhs, output, bars)
|
||||||
|
self.wait()
|
||||||
|
|
||||||
|
# Swap out for variables
|
||||||
|
variables = VGroup(*(
|
||||||
|
Tex(f"x_{{{n}}}", font_size=48).move_to(elem)
|
||||||
|
for n, elem in enumerate(rhs_entries, start=1)
|
||||||
|
))
|
||||||
|
|
||||||
|
self.remove(rhs_entries)
|
||||||
|
self.play(
|
||||||
|
LaggedStart(*(
|
||||||
|
TransformFromCopy(entry, variable, path_arc=PI / 2)
|
||||||
|
for entry, variable in zip(rhs_entries, variables)
|
||||||
|
), lag_ratio=0.1, run_time=1.0)
|
||||||
|
)
|
||||||
|
self.wait()
|
||||||
|
|
||||||
# Exponentiate each part
|
# Exponentiate each part
|
||||||
|
exp_parts = VGroup(*(
|
||||||
|
Tex(f"e^{{{var.get_tex()}}}", font_size=48).move_to(var)
|
||||||
|
for var in variables
|
||||||
|
))
|
||||||
|
exp_parts.align_to(softmax_box, LEFT)
|
||||||
|
exp_parts.shift(0.75 * RIGHT)
|
||||||
|
exp_parts.space_out_submobjects(1.5)
|
||||||
|
|
||||||
|
self.play(
|
||||||
|
softmax_label.animate.next_to(softmax_box, UP, buff=0.15),
|
||||||
|
LaggedStart(*(
|
||||||
|
TransformMatchingStrings(var.copy(), exp_part)
|
||||||
|
for var, exp_part in zip(variables, exp_parts)
|
||||||
|
), run_time=1, lag_ratio=0.01)
|
||||||
|
)
|
||||||
|
self.wait()
|
||||||
|
|
||||||
# Compute the sum
|
# Compute the sum
|
||||||
|
exp_sum = Tex(R"\sum_{n=0}^{N-1} e^{x_{n}}", font_size=42)
|
||||||
|
exp_sum[R"e^{x_{n}}"].scale(1.5, about_edge=LEFT)
|
||||||
|
exp_sum.next_to(softmax_box.get_right(), LEFT, buff=0.75)
|
||||||
|
|
||||||
|
globals().update(locals())
|
||||||
|
lines = VGroup(*(Line(exp_part.get_right(), exp_sum.get_left(), buff=0.1) for exp_part in exp_parts))
|
||||||
|
lines.set_stroke(TEAL, 2)
|
||||||
|
|
||||||
|
self.play(
|
||||||
|
LaggedStart(*(
|
||||||
|
FadeTransform(exp_part.copy(), exp_sum)
|
||||||
|
for exp_part in exp_parts
|
||||||
|
), lag_ratio=0.01),
|
||||||
|
LaggedStartMap(ShowCreation, lines, lag_ratio=0.01),
|
||||||
|
run_time=1
|
||||||
|
)
|
||||||
|
self.wait()
|
||||||
|
self.play(FadeOut(lines))
|
||||||
|
|
||||||
# Divide each part by the sum
|
# Divide each part by the sum
|
||||||
|
lil_denoms = VGroup()
|
||||||
|
for exp_part in exp_parts:
|
||||||
|
slash = Tex("/").match_height(exp_sum)
|
||||||
|
slash.next_to(exp_sum, LEFT, buff=0)
|
||||||
|
denom = VGroup(slash, exp_sum).copy()
|
||||||
|
denom.set_height(exp_part.get_height() * 1.5)
|
||||||
|
denom.next_to(exp_part, RIGHT, buff=0)
|
||||||
|
lil_denoms.add(denom)
|
||||||
|
lil_denoms.align_to(softmax_box.get_center(), LEFT)
|
||||||
|
|
||||||
|
lines = VGroup(*(Line(exp_sum.get_left(), denom.get_center()) for denom in lil_denoms))
|
||||||
|
lines.set_stroke(TEAL, 1)
|
||||||
|
|
||||||
|
self.remove(exp_sum)
|
||||||
|
self.play(
|
||||||
|
exp_parts.animate.next_to(lil_denoms, LEFT, buff=0),
|
||||||
|
LaggedStart(*(
|
||||||
|
FadeTransform(exp_sum.copy(), denom)
|
||||||
|
for denom in lil_denoms
|
||||||
|
), lag_ratio=0.01),
|
||||||
|
)
|
||||||
|
self.wait()
|
||||||
|
|
||||||
|
# Resize box
|
||||||
|
sm_terms = VGroup(*(
|
||||||
|
VGroup(exp_part, denom)
|
||||||
|
for exp_part, denom in zip(exp_parts, lil_denoms)
|
||||||
|
))
|
||||||
|
sm_terms.generate_target()
|
||||||
|
|
||||||
|
target_height = 5.0
|
||||||
|
full_output = Group(output, bars)
|
||||||
|
full_output.generate_target()
|
||||||
|
full_output.target.set_height(target_height, about_edge=RIGHT)
|
||||||
|
full_output.target.shift(1.5 * LEFT)
|
||||||
|
equals = Tex("=")
|
||||||
|
equals.next_to(full_output.target, LEFT)
|
||||||
|
|
||||||
|
softmax_box.generate_target()
|
||||||
|
softmax_box.target.set_width(3.0, stretch=True)
|
||||||
|
VGroup(softmax_box.target, sm_terms.target).set_height(target_height + 0.5).next_to(equals, LEFT)
|
||||||
|
|
||||||
|
rhs.generate_target()
|
||||||
|
rhs_entries.become(variables)
|
||||||
|
self.remove(variables)
|
||||||
|
rhs.target.set_height(target_height)
|
||||||
|
rhs.target.next_to(softmax_box.target, LEFT, buff=1.5)
|
||||||
|
|
||||||
|
self.play(
|
||||||
|
softmax_label.animate.next_to(softmax_box.target, UP),
|
||||||
|
MoveToTarget(softmax_box),
|
||||||
|
MoveToTarget(sm_terms),
|
||||||
|
MoveToTarget(full_output),
|
||||||
|
MoveToTarget(rhs),
|
||||||
|
FadeTransform(out_arrow, equals),
|
||||||
|
in_arrow.animate.become(
|
||||||
|
Arrow(rhs.target, softmax_box.target).match_style(in_arrow)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.wait()
|
||||||
|
|
||||||
|
# Set up updaters
|
||||||
|
output_entries = output.get_entries()
|
||||||
|
bar_width_ratio = bars.get_width() / max(o.get_value() for o in output_entries)
|
||||||
|
temp_tracker = ValueTracker(1)
|
||||||
|
|
||||||
|
def update_outs(output_entries):
|
||||||
|
inputs = [entry.get_value() for entry in rhs_entries]
|
||||||
|
outputs = softmax(inputs, temp_tracker.get_value())
|
||||||
|
for entry, output in zip(output_entries, outputs):
|
||||||
|
entry.set_value(output)
|
||||||
|
|
||||||
|
def update_bars(bars):
|
||||||
|
for bar, entry in zip(bars, output_entries):
|
||||||
|
width = max(bar_width_ratio * entry.get_value(), 1e-3)
|
||||||
|
bar.set_width(width, about_edge=LEFT, stretch=True)
|
||||||
|
|
||||||
|
output_entries.add_updater(update_outs)
|
||||||
|
bars.add_updater(update_bars)
|
||||||
|
|
||||||
|
self.add(bars, output_entries)
|
||||||
|
|
||||||
|
# Tweak values
|
||||||
|
for index, value in [(6, 4.0), (4, 4.2), (1, 0.0), (0, 6.0), (4, 9.9)]:
|
||||||
|
entry = rhs_entries[index]
|
||||||
|
rect = SurroundingRectangle(entry)
|
||||||
|
rect.set_stroke(BLUE if value > entry.get_value() else RED, 3)
|
||||||
|
self.play(
|
||||||
|
ChangeDecimalToValue(entry, value),
|
||||||
|
FadeIn(rect, time_span=(0, 1)),
|
||||||
|
run_time=2
|
||||||
|
)
|
||||||
|
self.play(FadeOut(rect))
|
||||||
|
|
||||||
|
# Add temperature
|
||||||
|
frame = self.frame
|
||||||
|
temp_color = RED
|
||||||
|
new_title = Text("softmax with temperature")
|
||||||
|
new_title["temperature"].set_color(temp_color)
|
||||||
|
get_t = temp_tracker.get_value
|
||||||
|
t_line = NumberLine(
|
||||||
|
(0, 10, 0.2),
|
||||||
|
tick_size=0.025,
|
||||||
|
big_tick_spacing=1,
|
||||||
|
longer_tick_multiple=2.0,
|
||||||
|
width=4
|
||||||
|
)
|
||||||
|
t_line.set_stroke(width=1.5)
|
||||||
|
t_line.next_to(softmax_box, UP)
|
||||||
|
t_tri = ArrowTip(angle=-90 * DEGREES)
|
||||||
|
t_tri.set_color(temp_color)
|
||||||
|
t_tri.set_height(0.2)
|
||||||
|
t_label = Tex("T = 0.00", font_size=36)
|
||||||
|
t_label.rhs = t_label.make_number_changable("0.00")
|
||||||
|
t_label["T"].set_color(temp_color)
|
||||||
|
globals().update(locals())
|
||||||
|
t_tri.add_updater(lambda m: m.move_to(t_line.n2p(get_t()), DOWN))
|
||||||
|
t_label.add_updater(lambda m: m.rhs.set_value(get_t()))
|
||||||
|
t_label.add_updater(lambda m: m.next_to(t_tri, UP, buff=0.1, aligned_edge=LEFT))
|
||||||
|
|
||||||
|
new_title.next_to(t_label, UP, buff=0.5).match_x(softmax_box)
|
||||||
|
|
||||||
|
self.play(
|
||||||
|
frame.animate.move_to(0.75 * UP),
|
||||||
|
TransformMatchingStrings(softmax_label, new_title),
|
||||||
|
FadeIn(t_line),
|
||||||
|
FadeIn(t_tri),
|
||||||
|
FadeIn(t_label),
|
||||||
|
run_time=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Change formula
|
||||||
|
template = Tex(R"e^{x_{0} / T} / \sum_{n=0}^{N - 1} e^{x_n / T}")
|
||||||
|
template["T"].set_color(temp_color)
|
||||||
|
template["/"][1].scale(1.9, about_edge=LEFT)
|
||||||
|
template[R"\sum_{n=0}^{N - 1}"][0].scale(0.7, about_edge=RIGHT)
|
||||||
|
index_part = template.make_number_changable("0")
|
||||||
|
|
||||||
|
new_sm_terms = VGroup()
|
||||||
|
all_Ts = VGroup()
|
||||||
|
for n, term in enumerate(sm_terms, start=1):
|
||||||
|
template.replace(term, dim_to_match=1)
|
||||||
|
index_part.set_value(n)
|
||||||
|
new_term = template.copy()
|
||||||
|
all_Ts.add(*new_term["T"])
|
||||||
|
new_sm_terms.add(new_term)
|
||||||
|
|
||||||
|
self.play(
|
||||||
|
LaggedStart(*(
|
||||||
|
FadeTransform(old_term, new_term)
|
||||||
|
for old_term, new_term in zip(sm_terms, new_sm_terms)
|
||||||
|
)),
|
||||||
|
LaggedStart(*(
|
||||||
|
TransformFromCopy(t_label[0], t_mob[0])
|
||||||
|
for t_mob in all_Ts
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
self.wait()
|
||||||
|
|
||||||
|
# Oscilate between values
|
||||||
|
self.play(temp_tracker.animate.set_value(3), run_time=3)
|
||||||
|
self.wait()
|
||||||
|
self.play(temp_tracker.animate.set_value(10), run_time=3)
|
||||||
|
self.wait()
|
||||||
|
|
||||||
# Comment on largest values
|
|
||||||
|
|
|
@ -84,9 +84,10 @@ def matrix_row_vector_product(scene, row, vector, entry, to_fade):
|
||||||
ShowIncreasingSubsets(row_rects),
|
ShowIncreasingSubsets(row_rects),
|
||||||
ShowIncreasingSubsets(vect_rects),
|
ShowIncreasingSubsets(vect_rects),
|
||||||
UpdateFromAlphaFunc(entry, lambda m, a: m.set_value(
|
UpdateFromAlphaFunc(entry, lambda m, a: m.set_value(
|
||||||
partial_values[min(int(a * n_values), n_values - 1)]
|
partial_values[min(int(np.round(a * n_values)), n_values - 1)]
|
||||||
)),
|
)),
|
||||||
FadeOut(to_fade),
|
FadeOut(to_fade),
|
||||||
|
rate_func=linear,
|
||||||
)
|
)
|
||||||
|
|
||||||
return VGroup(row_rects, vect_rects)
|
return VGroup(row_rects, vect_rects)
|
||||||
|
|
Loading…
Add table
Reference in a new issue