mirror of
https://github.com/3b1b/manim.git
synced 2025-08-05 16:49:03 +00:00
Up through chain rule in SimplestNetworkExample of nn/part3
This commit is contained in:
parent
76ea80230d
commit
ac079f182a
1 changed files with 228 additions and 20 deletions
248
nn/part3.py
248
nn/part3.py
|
@ -2002,7 +2002,7 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
"z_color" : GREEN,
|
||||
"cost_color" : RED,
|
||||
"desired_output_color" : YELLOW,
|
||||
"derivative_scale_vale" : 0.7,
|
||||
"derivative_scale_val" : 0.85,
|
||||
}
|
||||
def construct(self):
|
||||
self.force_skipping()
|
||||
|
@ -2022,6 +2022,8 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
self.show_derivative_wrt_w()
|
||||
self.show_chain_of_events()
|
||||
self.show_chain_rule()
|
||||
self.name_chain_rule()
|
||||
self.indicate_everything_on_screen()
|
||||
self.compute_derivatives()
|
||||
self.fire_together_wire_together()
|
||||
self.show_derivative_wrt_b()
|
||||
|
@ -2301,6 +2303,7 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
terms = self.weighted_sum_terms
|
||||
terms.generate_target()
|
||||
terms.target.next_to(self.formula, UP, aligned_edge = RIGHT)
|
||||
terms.target.shift(MED_LARGE_BUFF*RIGHT)
|
||||
equals = TexMobject("=")
|
||||
equals.next_to(terms.target[0][0], LEFT)
|
||||
|
||||
|
@ -2486,6 +2489,10 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
mob.number_line.next_to(mob, UP, MED_LARGE_BUFF, LEFT)
|
||||
else:
|
||||
mob.number_line.next_to(mob, RIGHT)
|
||||
if mob is C0:
|
||||
mob.number_line.x_max = 0.5
|
||||
for tick_mark in mob.number_line.tick_marks[1::2]:
|
||||
mob.number_line.tick_marks.remove(tick_mark)
|
||||
mob.dot = Dot(color = mob.get_color())
|
||||
mob.dot.move_to(
|
||||
mob.number_line.number_to_point(mob.val)
|
||||
|
@ -2494,7 +2501,7 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
path_arc = 0
|
||||
dot_spot = mob.dot.get_bottom()
|
||||
else:
|
||||
path_arc = -0.8*np.pi
|
||||
path_arc = -0.7*np.pi
|
||||
dot_spot = mob.dot.get_top()
|
||||
if mob is C0:
|
||||
mob_spot = mob[0].get_corner(UP+RIGHT)
|
||||
|
@ -2555,17 +2562,27 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
]
|
||||
]
|
||||
|
||||
def shake_dot(run_time = 2, rate_func = there_and_back):
|
||||
self.play(
|
||||
ApplyMethod(
|
||||
wL.dot.shift, LEFT,
|
||||
rate_func = rate_func,
|
||||
run_time = run_time
|
||||
),
|
||||
*dot_update_anims
|
||||
)
|
||||
|
||||
wL_line = Line(wL.dot.get_center(), wL.dot.get_center()+LEFT)
|
||||
del_wL = TexMobject("\\partial w^{(L)}")
|
||||
del_wL.scale(self.derivative_scale_vale)
|
||||
del_wL.brace = Brace(wL_line, UP)
|
||||
del_wL.scale(self.derivative_scale_val)
|
||||
del_wL.brace = Brace(wL_line, UP, buff = SMALL_BUFF)
|
||||
del_wL.highlight(wL.get_color())
|
||||
del_wL.next_to(del_wL.brace, UP, SMALL_BUFF)
|
||||
|
||||
C0_line = Line(C0.dot.get_center(), C0.dot.get_center()+MED_SMALL_BUFF*RIGHT)
|
||||
del_C0 = TexMobject("\\partial C_0")
|
||||
del_C0.scale(self.derivative_scale_vale)
|
||||
del_C0.brace = Brace(C0_line, UP)
|
||||
del_C0.scale(self.derivative_scale_val)
|
||||
del_C0.brace = Brace(C0_line, UP, buff = SMALL_BUFF)
|
||||
del_C0.highlight(C0.get_color())
|
||||
del_C0.next_to(del_C0.brace, UP, SMALL_BUFF)
|
||||
|
||||
|
@ -2574,32 +2591,189 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
GrowFromCenter(sym.brace),
|
||||
Write(sym, run_time = 1)
|
||||
)
|
||||
self.play(
|
||||
ApplyMethod(
|
||||
wL.dot.shift, LEFT,
|
||||
run_time = 2,
|
||||
rate_func = there_and_back
|
||||
),
|
||||
*dot_update_anims
|
||||
)
|
||||
shake_dot()
|
||||
self.dither()
|
||||
|
||||
self.set_variables_as_attrs(
|
||||
dot_update_anims, del_wL, del_C0,
|
||||
shake_dot, del_wL, del_C0,
|
||||
)
|
||||
|
||||
def show_derivative_wrt_w(self):
|
||||
pass
|
||||
|
||||
del_wL = self.del_wL
|
||||
del_C0 = self.del_C0
|
||||
cost_word = self.cost_word
|
||||
cost_arrow = self.cost_arrow
|
||||
shake_dot = self.shake_dot
|
||||
wL = self.comp_graph.wL
|
||||
|
||||
dC_dw = TexMobject(
|
||||
"{\\partial C_0", "\\over", "\\partial w^{(L)} }"
|
||||
)
|
||||
dC_dw[0].highlight(del_C0.get_color())
|
||||
dC_dw[2].highlight(del_wL.get_color())
|
||||
dC_dw.scale(self.derivative_scale_val)
|
||||
dC_dw.to_edge(UP, buff = MED_SMALL_BUFF)
|
||||
dC_dw.shift(3.5*LEFT)
|
||||
|
||||
full_rect = SurroundingRectangle(dC_dw)
|
||||
full_rect_copy = full_rect.copy()
|
||||
words = TextMobject("What we want")
|
||||
words.next_to(full_rect, RIGHT)
|
||||
words.highlight(YELLOW)
|
||||
|
||||
denom_rect = SurroundingRectangle(dC_dw[2])
|
||||
numer_rect = SurroundingRectangle(dC_dw[0])
|
||||
|
||||
self.play(
|
||||
ReplacementTransform(del_C0.copy(), dC_dw[0]),
|
||||
ReplacementTransform(del_wL.copy(), dC_dw[2]),
|
||||
Write(dC_dw[1], run_time = 1)
|
||||
)
|
||||
self.play(
|
||||
FadeOut(cost_word),
|
||||
FadeOut(cost_arrow),
|
||||
ShowCreation(full_rect),
|
||||
Write(words, run_time = 1),
|
||||
)
|
||||
self.dither(2)
|
||||
self.play(
|
||||
FadeOut(words),
|
||||
ReplacementTransform(full_rect, denom_rect)
|
||||
)
|
||||
self.play(Transform(dC_dw[2].copy(), del_wL, remover = True))
|
||||
shake_dot()
|
||||
self.play(ReplacementTransform(denom_rect, numer_rect))
|
||||
self.play(Transform(dC_dw[0].copy(), del_C0, remover = True))
|
||||
shake_dot()
|
||||
self.dither()
|
||||
self.play(ReplacementTransform(numer_rect, full_rect_copy))
|
||||
self.play(FadeOut(full_rect_copy))
|
||||
self.dither()
|
||||
|
||||
self.dC_dw = dC_dw
|
||||
|
||||
def show_chain_of_events(self):
|
||||
pass
|
||||
comp_graph = self.comp_graph
|
||||
wL, zL, aL, C0 = [
|
||||
getattr(comp_graph, attr)
|
||||
for attr in ["wL", "zL", "aL", "C0"]
|
||||
]
|
||||
del_wL = self.del_wL
|
||||
del_C0 = self.del_C0
|
||||
|
||||
zL_line = Line(ORIGIN, MED_LARGE_BUFF*LEFT)
|
||||
zL_line.shift(zL.dot.get_center())
|
||||
del_zL = TexMobject("\\partial z^{(L)}")
|
||||
del_zL.highlight(zL.get_color())
|
||||
del_zL.brace = Brace(zL_line, DOWN, buff = SMALL_BUFF)
|
||||
|
||||
aL_line = Line(ORIGIN, MED_SMALL_BUFF*LEFT)
|
||||
aL_line.shift(aL.dot.get_center())
|
||||
del_aL = TexMobject("\\partial a^{(L)}")
|
||||
del_aL.highlight(aL.get_color())
|
||||
del_aL.brace = Brace(aL_line, DOWN, buff = SMALL_BUFF)
|
||||
|
||||
for sym in del_zL, del_aL:
|
||||
sym.scale(self.derivative_scale_val)
|
||||
sym.brace.stretch_about_point(
|
||||
0.5, 1, sym.brace.get_top(),
|
||||
)
|
||||
sym.shift(
|
||||
sym.brace.get_bottom()+SMALL_BUFF*DOWN \
|
||||
-sym[0].get_corner(UP+RIGHT)
|
||||
)
|
||||
|
||||
syms = [del_wL, del_zL, del_aL, del_C0]
|
||||
for s1, s2 in zip(syms, syms[1:]):
|
||||
self.play(
|
||||
ReplacementTransform(s1.copy(), s2),
|
||||
ReplacementTransform(s1.brace.copy(), s2.brace),
|
||||
)
|
||||
self.shake_dot(run_time = 1.5)
|
||||
self.dither(0.5)
|
||||
self.dither()
|
||||
|
||||
self.set_variables_as_attrs(del_zL, del_aL)
|
||||
|
||||
def show_chain_rule(self):
|
||||
pass
|
||||
dC_dw = self.dC_dw
|
||||
dz_dw = TexMobject(
|
||||
"{\\partial z^{(L)}", "\\over", "\\partial w^{(L)}}"
|
||||
)
|
||||
da_dz = TexMobject(
|
||||
"{\\partial a^{(L)}", "\\over", "\\partial z^{(L)}}"
|
||||
)
|
||||
dC_da = TexMobject(
|
||||
"{\\partial C0}", "\\over", "\\partial a^{(L)}}"
|
||||
)
|
||||
dz_dw[2].highlight(self.del_wL.get_color())
|
||||
VGroup(dz_dw[0], da_dz[2]).highlight(self.z_color)
|
||||
dC_da[0].highlight(self.cost_color)
|
||||
equals = TexMobject("=")
|
||||
group = VGroup(equals, dz_dw, da_dz, dC_da)
|
||||
group.arrange_submobjects(RIGHT, SMALL_BUFF)
|
||||
group.scale(self.derivative_scale_val)
|
||||
group.next_to(dC_dw, RIGHT)
|
||||
for mob in group[1:]:
|
||||
target_y = equals.get_center()[1]
|
||||
y = mob[1].get_center()[1]
|
||||
mob.shift((target_y - y)*UP)
|
||||
|
||||
last_sym = dC_dw[2]
|
||||
self.play(Write(equals, run_time = 1))
|
||||
for fraction in group[1:]:
|
||||
self.play(LaggedStart(
|
||||
FadeIn, VGroup(*fraction[:2]),
|
||||
lag_ratio = 0.75,
|
||||
run_time = 1
|
||||
))
|
||||
self.play(ReplacementTransform(
|
||||
last_sym.copy(), fraction[2]
|
||||
))
|
||||
self.dither()
|
||||
last_sym = fraction[0]
|
||||
self.shake_dot()
|
||||
self.dither()
|
||||
|
||||
self.chain_rule_equation = VGroup(dC_dw, *group)
|
||||
|
||||
def name_chain_rule(self):
|
||||
graph_parts = self.get_all_comp_graph_parts()
|
||||
equation = self.chain_rule_equation
|
||||
rect = SurroundingRectangle(equation)
|
||||
group = VGroup(equation, rect)
|
||||
group.generate_target()
|
||||
group.target.to_corner(UP+LEFT)
|
||||
words = TextMobject("Chain rule")
|
||||
words.highlight(YELLOW)
|
||||
words.next_to(group.target, DOWN)
|
||||
|
||||
self.play(ShowCreation(rect))
|
||||
self.play(
|
||||
MoveToTarget(group),
|
||||
Write(words, run_time = 1),
|
||||
graph_parts.scale, 0.7, graph_parts.get_bottom()
|
||||
)
|
||||
self.dither(2)
|
||||
self.play(*map(FadeOut, [rect, words]))
|
||||
|
||||
def indicate_everything_on_screen(self):
|
||||
everything = VGroup(*self.get_top_level_mobjects())
|
||||
everything = VGroup(*filter(
|
||||
lambda m : not m.is_subpath,
|
||||
everything.family_members_with_points()
|
||||
))
|
||||
self.play(LaggedStart(
|
||||
Indicate, everything,
|
||||
rate_func = wiggle,
|
||||
lag_ratio = 0.2,
|
||||
run_time = 5
|
||||
))
|
||||
self.dither()
|
||||
|
||||
def compute_derivatives(self):
|
||||
pass
|
||||
|
||||
self.play(FadeOut(self.all_comp_graph_parts))
|
||||
|
||||
def fire_together_wire_together(self):
|
||||
pass
|
||||
|
@ -2623,6 +2797,40 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
decimal.set_fill(BLACK)
|
||||
decimal.move_to(neuron)
|
||||
return decimal
|
||||
|
||||
def get_all_comp_graph_parts(self):
|
||||
comp_graph = self.comp_graph
|
||||
result = VGroup(comp_graph)
|
||||
for attr in "wL", "zL", "aL", "C0":
|
||||
sym = getattr(comp_graph, attr)
|
||||
comp_graph.add(
|
||||
sym.arrow, sym.number_line, sym.dot
|
||||
)
|
||||
del_sym = getattr(self, "del_" + attr)
|
||||
comp_graph.add(del_sym, del_sym.brace)
|
||||
|
||||
self.all_comp_graph_parts = result
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue