diff --git a/nn/part3.py b/nn/part3.py index 67188de8..c35c4e91 100644 --- a/nn/part3.py +++ b/nn/part3.py @@ -2002,7 +2002,7 @@ class SimplestNetworkExample(PreviewLearning): "z_color" : GREEN, "cost_color" : RED, "desired_output_color" : YELLOW, - "derivative_scale_vale" : 0.7, + "derivative_scale_val" : 0.85, } def construct(self): self.force_skipping() @@ -2022,6 +2022,8 @@ class SimplestNetworkExample(PreviewLearning): self.show_derivative_wrt_w() self.show_chain_of_events() self.show_chain_rule() + self.name_chain_rule() + self.indicate_everything_on_screen() self.compute_derivatives() self.fire_together_wire_together() self.show_derivative_wrt_b() @@ -2301,6 +2303,7 @@ class SimplestNetworkExample(PreviewLearning): terms = self.weighted_sum_terms terms.generate_target() terms.target.next_to(self.formula, UP, aligned_edge = RIGHT) + terms.target.shift(MED_LARGE_BUFF*RIGHT) equals = TexMobject("=") equals.next_to(terms.target[0][0], LEFT) @@ -2486,6 +2489,10 @@ class SimplestNetworkExample(PreviewLearning): mob.number_line.next_to(mob, UP, MED_LARGE_BUFF, LEFT) else: mob.number_line.next_to(mob, RIGHT) + if mob is C0: + mob.number_line.x_max = 0.5 + for tick_mark in mob.number_line.tick_marks[1::2]: + mob.number_line.tick_marks.remove(tick_mark) mob.dot = Dot(color = mob.get_color()) mob.dot.move_to( mob.number_line.number_to_point(mob.val) @@ -2494,7 +2501,7 @@ class SimplestNetworkExample(PreviewLearning): path_arc = 0 dot_spot = mob.dot.get_bottom() else: - path_arc = -0.8*np.pi + path_arc = -0.7*np.pi dot_spot = mob.dot.get_top() if mob is C0: mob_spot = mob[0].get_corner(UP+RIGHT) @@ -2555,17 +2562,27 @@ class SimplestNetworkExample(PreviewLearning): ] ] + def shake_dot(run_time = 2, rate_func = there_and_back): + self.play( + ApplyMethod( + wL.dot.shift, LEFT, + rate_func = rate_func, + run_time = run_time + ), + *dot_update_anims + ) + wL_line = Line(wL.dot.get_center(), wL.dot.get_center()+LEFT) del_wL = TexMobject("\\partial w^{(L)}") - del_wL.scale(self.derivative_scale_vale) - del_wL.brace = Brace(wL_line, UP) + del_wL.scale(self.derivative_scale_val) + del_wL.brace = Brace(wL_line, UP, buff = SMALL_BUFF) del_wL.highlight(wL.get_color()) del_wL.next_to(del_wL.brace, UP, SMALL_BUFF) C0_line = Line(C0.dot.get_center(), C0.dot.get_center()+MED_SMALL_BUFF*RIGHT) del_C0 = TexMobject("\\partial C_0") - del_C0.scale(self.derivative_scale_vale) - del_C0.brace = Brace(C0_line, UP) + del_C0.scale(self.derivative_scale_val) + del_C0.brace = Brace(C0_line, UP, buff = SMALL_BUFF) del_C0.highlight(C0.get_color()) del_C0.next_to(del_C0.brace, UP, SMALL_BUFF) @@ -2574,32 +2591,189 @@ class SimplestNetworkExample(PreviewLearning): GrowFromCenter(sym.brace), Write(sym, run_time = 1) ) - self.play( - ApplyMethod( - wL.dot.shift, LEFT, - run_time = 2, - rate_func = there_and_back - ), - *dot_update_anims - ) + shake_dot() self.dither() self.set_variables_as_attrs( - dot_update_anims, del_wL, del_C0, + shake_dot, del_wL, del_C0, ) def show_derivative_wrt_w(self): - pass - + del_wL = self.del_wL + del_C0 = self.del_C0 + cost_word = self.cost_word + cost_arrow = self.cost_arrow + shake_dot = self.shake_dot + wL = self.comp_graph.wL + + dC_dw = TexMobject( + "{\\partial C_0", "\\over", "\\partial w^{(L)} }" + ) + dC_dw[0].highlight(del_C0.get_color()) + dC_dw[2].highlight(del_wL.get_color()) + dC_dw.scale(self.derivative_scale_val) + dC_dw.to_edge(UP, buff = MED_SMALL_BUFF) + dC_dw.shift(3.5*LEFT) + + full_rect = SurroundingRectangle(dC_dw) + full_rect_copy = full_rect.copy() + words = TextMobject("What we want") + words.next_to(full_rect, RIGHT) + words.highlight(YELLOW) + + denom_rect = SurroundingRectangle(dC_dw[2]) + numer_rect = SurroundingRectangle(dC_dw[0]) + + self.play( + ReplacementTransform(del_C0.copy(), dC_dw[0]), + ReplacementTransform(del_wL.copy(), dC_dw[2]), + Write(dC_dw[1], run_time = 1) + ) + self.play( + FadeOut(cost_word), + FadeOut(cost_arrow), + ShowCreation(full_rect), + Write(words, run_time = 1), + ) + self.dither(2) + self.play( + FadeOut(words), + ReplacementTransform(full_rect, denom_rect) + ) + self.play(Transform(dC_dw[2].copy(), del_wL, remover = True)) + shake_dot() + self.play(ReplacementTransform(denom_rect, numer_rect)) + self.play(Transform(dC_dw[0].copy(), del_C0, remover = True)) + shake_dot() + self.dither() + self.play(ReplacementTransform(numer_rect, full_rect_copy)) + self.play(FadeOut(full_rect_copy)) + self.dither() + + self.dC_dw = dC_dw def show_chain_of_events(self): - pass + comp_graph = self.comp_graph + wL, zL, aL, C0 = [ + getattr(comp_graph, attr) + for attr in ["wL", "zL", "aL", "C0"] + ] + del_wL = self.del_wL + del_C0 = self.del_C0 + + zL_line = Line(ORIGIN, MED_LARGE_BUFF*LEFT) + zL_line.shift(zL.dot.get_center()) + del_zL = TexMobject("\\partial z^{(L)}") + del_zL.highlight(zL.get_color()) + del_zL.brace = Brace(zL_line, DOWN, buff = SMALL_BUFF) + + aL_line = Line(ORIGIN, MED_SMALL_BUFF*LEFT) + aL_line.shift(aL.dot.get_center()) + del_aL = TexMobject("\\partial a^{(L)}") + del_aL.highlight(aL.get_color()) + del_aL.brace = Brace(aL_line, DOWN, buff = SMALL_BUFF) + + for sym in del_zL, del_aL: + sym.scale(self.derivative_scale_val) + sym.brace.stretch_about_point( + 0.5, 1, sym.brace.get_top(), + ) + sym.shift( + sym.brace.get_bottom()+SMALL_BUFF*DOWN \ + -sym[0].get_corner(UP+RIGHT) + ) + + syms = [del_wL, del_zL, del_aL, del_C0] + for s1, s2 in zip(syms, syms[1:]): + self.play( + ReplacementTransform(s1.copy(), s2), + ReplacementTransform(s1.brace.copy(), s2.brace), + ) + self.shake_dot(run_time = 1.5) + self.dither(0.5) + self.dither() + + self.set_variables_as_attrs(del_zL, del_aL) def show_chain_rule(self): - pass + dC_dw = self.dC_dw + dz_dw = TexMobject( + "{\\partial z^{(L)}", "\\over", "\\partial w^{(L)}}" + ) + da_dz = TexMobject( + "{\\partial a^{(L)}", "\\over", "\\partial z^{(L)}}" + ) + dC_da = TexMobject( + "{\\partial C0}", "\\over", "\\partial a^{(L)}}" + ) + dz_dw[2].highlight(self.del_wL.get_color()) + VGroup(dz_dw[0], da_dz[2]).highlight(self.z_color) + dC_da[0].highlight(self.cost_color) + equals = TexMobject("=") + group = VGroup(equals, dz_dw, da_dz, dC_da) + group.arrange_submobjects(RIGHT, SMALL_BUFF) + group.scale(self.derivative_scale_val) + group.next_to(dC_dw, RIGHT) + for mob in group[1:]: + target_y = equals.get_center()[1] + y = mob[1].get_center()[1] + mob.shift((target_y - y)*UP) + + last_sym = dC_dw[2] + self.play(Write(equals, run_time = 1)) + for fraction in group[1:]: + self.play(LaggedStart( + FadeIn, VGroup(*fraction[:2]), + lag_ratio = 0.75, + run_time = 1 + )) + self.play(ReplacementTransform( + last_sym.copy(), fraction[2] + )) + self.dither() + last_sym = fraction[0] + self.shake_dot() + self.dither() + + self.chain_rule_equation = VGroup(dC_dw, *group) + + def name_chain_rule(self): + graph_parts = self.get_all_comp_graph_parts() + equation = self.chain_rule_equation + rect = SurroundingRectangle(equation) + group = VGroup(equation, rect) + group.generate_target() + group.target.to_corner(UP+LEFT) + words = TextMobject("Chain rule") + words.highlight(YELLOW) + words.next_to(group.target, DOWN) + + self.play(ShowCreation(rect)) + self.play( + MoveToTarget(group), + Write(words, run_time = 1), + graph_parts.scale, 0.7, graph_parts.get_bottom() + ) + self.dither(2) + self.play(*map(FadeOut, [rect, words])) + + def indicate_everything_on_screen(self): + everything = VGroup(*self.get_top_level_mobjects()) + everything = VGroup(*filter( + lambda m : not m.is_subpath, + everything.family_members_with_points() + )) + self.play(LaggedStart( + Indicate, everything, + rate_func = wiggle, + lag_ratio = 0.2, + run_time = 5 + )) + self.dither() def compute_derivatives(self): - pass + + self.play(FadeOut(self.all_comp_graph_parts)) def fire_together_wire_together(self): pass @@ -2623,6 +2797,40 @@ class SimplestNetworkExample(PreviewLearning): decimal.set_fill(BLACK) decimal.move_to(neuron) return decimal + + def get_all_comp_graph_parts(self): + comp_graph = self.comp_graph + result = VGroup(comp_graph) + for attr in "wL", "zL", "aL", "C0": + sym = getattr(comp_graph, attr) + comp_graph.add( + sym.arrow, sym.number_line, sym.dot + ) + del_sym = getattr(self, "del_" + attr) + comp_graph.add(del_sym, del_sym.brace) + + self.all_comp_graph_parts = result + return result + + + + + + + + + + + + + + + + + + + +