Up through chain rule in SimplestNetworkExample of nn/part3

This commit is contained in:
Grant Sanderson 2017-10-27 15:12:29 -07:00
parent 76ea80230d
commit ac079f182a

View file

@ -2002,7 +2002,7 @@ class SimplestNetworkExample(PreviewLearning):
"z_color" : GREEN,
"cost_color" : RED,
"desired_output_color" : YELLOW,
"derivative_scale_vale" : 0.7,
"derivative_scale_val" : 0.85,
}
def construct(self):
self.force_skipping()
@ -2022,6 +2022,8 @@ class SimplestNetworkExample(PreviewLearning):
self.show_derivative_wrt_w()
self.show_chain_of_events()
self.show_chain_rule()
self.name_chain_rule()
self.indicate_everything_on_screen()
self.compute_derivatives()
self.fire_together_wire_together()
self.show_derivative_wrt_b()
@ -2301,6 +2303,7 @@ class SimplestNetworkExample(PreviewLearning):
terms = self.weighted_sum_terms
terms.generate_target()
terms.target.next_to(self.formula, UP, aligned_edge = RIGHT)
terms.target.shift(MED_LARGE_BUFF*RIGHT)
equals = TexMobject("=")
equals.next_to(terms.target[0][0], LEFT)
@ -2486,6 +2489,10 @@ class SimplestNetworkExample(PreviewLearning):
mob.number_line.next_to(mob, UP, MED_LARGE_BUFF, LEFT)
else:
mob.number_line.next_to(mob, RIGHT)
if mob is C0:
mob.number_line.x_max = 0.5
for tick_mark in mob.number_line.tick_marks[1::2]:
mob.number_line.tick_marks.remove(tick_mark)
mob.dot = Dot(color = mob.get_color())
mob.dot.move_to(
mob.number_line.number_to_point(mob.val)
@ -2494,7 +2501,7 @@ class SimplestNetworkExample(PreviewLearning):
path_arc = 0
dot_spot = mob.dot.get_bottom()
else:
path_arc = -0.8*np.pi
path_arc = -0.7*np.pi
dot_spot = mob.dot.get_top()
if mob is C0:
mob_spot = mob[0].get_corner(UP+RIGHT)
@ -2555,17 +2562,27 @@ class SimplestNetworkExample(PreviewLearning):
]
]
def shake_dot(run_time = 2, rate_func = there_and_back):
self.play(
ApplyMethod(
wL.dot.shift, LEFT,
rate_func = rate_func,
run_time = run_time
),
*dot_update_anims
)
wL_line = Line(wL.dot.get_center(), wL.dot.get_center()+LEFT)
del_wL = TexMobject("\\partial w^{(L)}")
del_wL.scale(self.derivative_scale_vale)
del_wL.brace = Brace(wL_line, UP)
del_wL.scale(self.derivative_scale_val)
del_wL.brace = Brace(wL_line, UP, buff = SMALL_BUFF)
del_wL.highlight(wL.get_color())
del_wL.next_to(del_wL.brace, UP, SMALL_BUFF)
C0_line = Line(C0.dot.get_center(), C0.dot.get_center()+MED_SMALL_BUFF*RIGHT)
del_C0 = TexMobject("\\partial C_0")
del_C0.scale(self.derivative_scale_vale)
del_C0.brace = Brace(C0_line, UP)
del_C0.scale(self.derivative_scale_val)
del_C0.brace = Brace(C0_line, UP, buff = SMALL_BUFF)
del_C0.highlight(C0.get_color())
del_C0.next_to(del_C0.brace, UP, SMALL_BUFF)
@ -2574,32 +2591,189 @@ class SimplestNetworkExample(PreviewLearning):
GrowFromCenter(sym.brace),
Write(sym, run_time = 1)
)
self.play(
ApplyMethod(
wL.dot.shift, LEFT,
run_time = 2,
rate_func = there_and_back
),
*dot_update_anims
)
shake_dot()
self.dither()
self.set_variables_as_attrs(
dot_update_anims, del_wL, del_C0,
shake_dot, del_wL, del_C0,
)
def show_derivative_wrt_w(self):
pass
del_wL = self.del_wL
del_C0 = self.del_C0
cost_word = self.cost_word
cost_arrow = self.cost_arrow
shake_dot = self.shake_dot
wL = self.comp_graph.wL
dC_dw = TexMobject(
"{\\partial C_0", "\\over", "\\partial w^{(L)} }"
)
dC_dw[0].highlight(del_C0.get_color())
dC_dw[2].highlight(del_wL.get_color())
dC_dw.scale(self.derivative_scale_val)
dC_dw.to_edge(UP, buff = MED_SMALL_BUFF)
dC_dw.shift(3.5*LEFT)
full_rect = SurroundingRectangle(dC_dw)
full_rect_copy = full_rect.copy()
words = TextMobject("What we want")
words.next_to(full_rect, RIGHT)
words.highlight(YELLOW)
denom_rect = SurroundingRectangle(dC_dw[2])
numer_rect = SurroundingRectangle(dC_dw[0])
self.play(
ReplacementTransform(del_C0.copy(), dC_dw[0]),
ReplacementTransform(del_wL.copy(), dC_dw[2]),
Write(dC_dw[1], run_time = 1)
)
self.play(
FadeOut(cost_word),
FadeOut(cost_arrow),
ShowCreation(full_rect),
Write(words, run_time = 1),
)
self.dither(2)
self.play(
FadeOut(words),
ReplacementTransform(full_rect, denom_rect)
)
self.play(Transform(dC_dw[2].copy(), del_wL, remover = True))
shake_dot()
self.play(ReplacementTransform(denom_rect, numer_rect))
self.play(Transform(dC_dw[0].copy(), del_C0, remover = True))
shake_dot()
self.dither()
self.play(ReplacementTransform(numer_rect, full_rect_copy))
self.play(FadeOut(full_rect_copy))
self.dither()
self.dC_dw = dC_dw
def show_chain_of_events(self):
pass
comp_graph = self.comp_graph
wL, zL, aL, C0 = [
getattr(comp_graph, attr)
for attr in ["wL", "zL", "aL", "C0"]
]
del_wL = self.del_wL
del_C0 = self.del_C0
zL_line = Line(ORIGIN, MED_LARGE_BUFF*LEFT)
zL_line.shift(zL.dot.get_center())
del_zL = TexMobject("\\partial z^{(L)}")
del_zL.highlight(zL.get_color())
del_zL.brace = Brace(zL_line, DOWN, buff = SMALL_BUFF)
aL_line = Line(ORIGIN, MED_SMALL_BUFF*LEFT)
aL_line.shift(aL.dot.get_center())
del_aL = TexMobject("\\partial a^{(L)}")
del_aL.highlight(aL.get_color())
del_aL.brace = Brace(aL_line, DOWN, buff = SMALL_BUFF)
for sym in del_zL, del_aL:
sym.scale(self.derivative_scale_val)
sym.brace.stretch_about_point(
0.5, 1, sym.brace.get_top(),
)
sym.shift(
sym.brace.get_bottom()+SMALL_BUFF*DOWN \
-sym[0].get_corner(UP+RIGHT)
)
syms = [del_wL, del_zL, del_aL, del_C0]
for s1, s2 in zip(syms, syms[1:]):
self.play(
ReplacementTransform(s1.copy(), s2),
ReplacementTransform(s1.brace.copy(), s2.brace),
)
self.shake_dot(run_time = 1.5)
self.dither(0.5)
self.dither()
self.set_variables_as_attrs(del_zL, del_aL)
def show_chain_rule(self):
pass
dC_dw = self.dC_dw
dz_dw = TexMobject(
"{\\partial z^{(L)}", "\\over", "\\partial w^{(L)}}"
)
da_dz = TexMobject(
"{\\partial a^{(L)}", "\\over", "\\partial z^{(L)}}"
)
dC_da = TexMobject(
"{\\partial C0}", "\\over", "\\partial a^{(L)}}"
)
dz_dw[2].highlight(self.del_wL.get_color())
VGroup(dz_dw[0], da_dz[2]).highlight(self.z_color)
dC_da[0].highlight(self.cost_color)
equals = TexMobject("=")
group = VGroup(equals, dz_dw, da_dz, dC_da)
group.arrange_submobjects(RIGHT, SMALL_BUFF)
group.scale(self.derivative_scale_val)
group.next_to(dC_dw, RIGHT)
for mob in group[1:]:
target_y = equals.get_center()[1]
y = mob[1].get_center()[1]
mob.shift((target_y - y)*UP)
last_sym = dC_dw[2]
self.play(Write(equals, run_time = 1))
for fraction in group[1:]:
self.play(LaggedStart(
FadeIn, VGroup(*fraction[:2]),
lag_ratio = 0.75,
run_time = 1
))
self.play(ReplacementTransform(
last_sym.copy(), fraction[2]
))
self.dither()
last_sym = fraction[0]
self.shake_dot()
self.dither()
self.chain_rule_equation = VGroup(dC_dw, *group)
def name_chain_rule(self):
graph_parts = self.get_all_comp_graph_parts()
equation = self.chain_rule_equation
rect = SurroundingRectangle(equation)
group = VGroup(equation, rect)
group.generate_target()
group.target.to_corner(UP+LEFT)
words = TextMobject("Chain rule")
words.highlight(YELLOW)
words.next_to(group.target, DOWN)
self.play(ShowCreation(rect))
self.play(
MoveToTarget(group),
Write(words, run_time = 1),
graph_parts.scale, 0.7, graph_parts.get_bottom()
)
self.dither(2)
self.play(*map(FadeOut, [rect, words]))
def indicate_everything_on_screen(self):
everything = VGroup(*self.get_top_level_mobjects())
everything = VGroup(*filter(
lambda m : not m.is_subpath,
everything.family_members_with_points()
))
self.play(LaggedStart(
Indicate, everything,
rate_func = wiggle,
lag_ratio = 0.2,
run_time = 5
))
self.dither()
def compute_derivatives(self):
pass
self.play(FadeOut(self.all_comp_graph_parts))
def fire_together_wire_together(self):
pass
@ -2623,6 +2797,40 @@ class SimplestNetworkExample(PreviewLearning):
decimal.set_fill(BLACK)
decimal.move_to(neuron)
return decimal
def get_all_comp_graph_parts(self):
comp_graph = self.comp_graph
result = VGroup(comp_graph)
for attr in "wL", "zL", "aL", "C0":
sym = getattr(comp_graph, attr)
comp_graph.add(
sym.arrow, sym.number_line, sym.dot
)
del_sym = getattr(self, "del_" + attr)
comp_graph.add(del_sym, del_sym.brace)
self.all_comp_graph_parts = result
return result