diff --git a/nn/part3.py b/nn/part3.py index ae4d5c35..cea4dff3 100644 --- a/nn/part3.py +++ b/nn/part3.py @@ -2002,6 +2002,7 @@ class SimplestNetworkExample(PreviewLearning): "z_color" : GREEN, "cost_color" : RED, "desired_output_color" : YELLOW, + "derivative_scale_vale" : 0.7, } def construct(self): self.force_skipping() @@ -2101,10 +2102,10 @@ class SimplestNetworkExample(PreviewLearning): self.dither() def label_neurons(self): - neurons = [ + neurons = VGroup(*[ self.network_mob.layers[i].neurons[0] for i in -1, -2 - ] + ]) decimals = VGroup() a_labels = VGroup() a_label_arrows = VGroup() @@ -2167,7 +2168,8 @@ class SimplestNetworkExample(PreviewLearning): self.play(*map(FadeOut, [not_exponents, superscript_rects])) self.set_variables_as_attrs( - a_labels, a_label_arrows, decimals + a_labels, a_label_arrows, decimals, + last_neurons = neurons ) def show_desired_output(self): @@ -2290,41 +2292,301 @@ class SimplestNetworkExample(PreviewLearning): self.play(sigma_group.restore) self.dither() - weighted_sum_terms = VGroup(wL, aLm1, bL) + weighted_sum_terms = VGroup(wL, aLm1, plus, bL) self.set_variables_as_attrs( formula, weighted_sum_terms ) def introduce_z(self): terms = self.weighted_sum_terms - brace = Brace(terms, UP, buff = SMALL_BUFF) + terms.generate_target() + terms.target.next_to(self.formula, UP, aligned_edge = RIGHT) + equals = TexMobject("=") + equals.next_to(terms.target[0][0], LEFT) + z_label = TexMobject("z^{(L)}") - z_label.next_to(brace, UP, buff = SMALL_BUFF) + z_label.next_to(equals, LEFT) + z_label.align_to(terms.target, DOWN) z_label.highlight(self.z_color) - rect = SurroundingRectangle(terms) - rect.highlight(GREEN) + z_label2 = z_label.copy() + + aL_start = VGroup(*self.formula[:4]) + aL_start.generate_target() + aL_start.target.align_to(z_label, LEFT) + z_label2.next_to(aL_start.target, RIGHT, SMALL_BUFF) + z_label2.align_to(aL_start.target[0], DOWN) + rp = self.formula[-1] + rp.generate_target() + rp.target.next_to(z_label2, RIGHT, SMALL_BUFF) + rp.target.align_to(aL_start.target, DOWN) + + self.play(MoveToTarget(terms)) + self.play(Write(z_label), Write(equals)) + self.play( + ReplacementTransform(z_label.copy(), z_label2), + MoveToTarget(aL_start), + MoveToTarget(rp), + ) + self.dither() + + zL_formula = VGroup(z_label, equals, terms) + aL_formula = VGroup(aL_start, z_label2, rp) + self.set_variables_as_attrs(z_label, zL_formula, aL_formula) + + def break_into_computational_graph(self): + network_early_layers = VGroup(*it.chain( + self.network_mob.layers[:2], + self.network_mob.edge_groups[:2] + )) + + wL, aL, plus, bL = self.weighted_sum_terms + top_terms = VGroup(wL, aL, bL).copy() + zL = self.z_label.copy() + aL = self.formula[0].copy() + y = self.y_label.copy() + C0 = self.cost_equation[0].copy() + targets = VGroup() + for mob in top_terms, zL, aL, C0: + mob.generate_target() + targets.add(mob.target) + y.generate_target() + top_terms.target.arrange_submobjects(RIGHT, buff = MED_LARGE_BUFF) + targets.arrange_submobjects(DOWN, buff = LARGE_BUFF) + targets.center().to_corner(DOWN+LEFT) + y.target.next_to(aL.target, LEFT, LARGE_BUFF, DOWN) + + top_lines = VGroup(*[ + Line( + term.get_bottom(), + zL.target.get_top(), + buff = SMALL_BUFF + ) + for term in top_terms.target + ]) + z_to_a_line, a_to_c_line, y_to_c_line = all_lines = [ + Line( + m1.target.get_bottom(), + m2.target.get_top(), + buff = SMALL_BUFF + ) + for m1, m2 in [ + (zL, aL), + (aL, C0), + (y, C0) + ] + ] + for mob in [top_lines] + all_lines: + yellow_copy = mob.copy().highlight(YELLOW) + mob.flash = ShowCreationThenDestruction(yellow_copy) + + self.play(MoveToTarget(top_terms)) + self.dither() + self.play(MoveToTarget(zL)) + self.play( + ShowCreation(top_lines, submobject_mode = "all_at_once"), + top_lines.flash + ) + self.dither() + self.play(MoveToTarget(aL)) + self.play( + FadeOut(network_early_layers), + ShowCreation(z_to_a_line), + z_to_a_line.flash + ) + self.dither() + self.play(MoveToTarget(y)) + self.play(MoveToTarget(C0)) + self.play(*it.chain(*[ + [ShowCreation(line), line.flash] + for line in a_to_c_line, y_to_c_line + ])) + self.dither(2) + + comp_graph = VGroup() + comp_graph.wL, comp_graph.aLm1, comp_graph.bL = top_terms + comp_graph.top_lines = top_lines + comp_graph.zL = zL + comp_graph.z_to_a_line = z_to_a_line + comp_graph.aL = aL + comp_graph.y = y + comp_graph.a_to_c_line = a_to_c_line + comp_graph.y_to_c_line = y_to_c_line + comp_graph.C0 = C0 + comp_graph.digest_mobject_attrs() + self.comp_graph = comp_graph + + def show_preceding_layer_in_computational_graph(self): + shift_vect = DOWN + comp_graph = self.comp_graph + comp_graph.save_state() + comp_graph.generate_target() + comp_graph.target.shift(shift_vect) + rect = SurroundingRectangle(comp_graph.aLm1) + + attrs = ["wL", "aLm1", "bL", "zL"] + new_terms = VGroup() + for attr in attrs: + term = getattr(comp_graph, attr) + tex = term.get_tex_string() + if "L-1" in tex: + tex = tex.replace("L-1", "L-2") + else: + tex = tex.replace("L", "L-1") + new_term = TexMobject(tex) + new_term.highlight(term.get_color()) + new_term.move_to(term) + new_terms.add(new_term) + new_edges = VGroup( + comp_graph.top_lines.copy(), + comp_graph.z_to_a_line.copy(), + ) + new_subgraph = VGroup(new_terms, new_edges) self.play(ShowCreation(rect)) self.play( - GrowFromCenter(brace), - Write(z_label), + new_subgraph.next_to, comp_graph.target, UP, SMALL_BUFF, + UpdateFromAlphaFunc( + new_terms, + lambda m, a : m.set_fill(opacity = a) + ), + MoveToTarget(comp_graph), + rect.shift, shift_vect ) - self.play(FadeOut(rect)) + self.dither(2) + self.play( + FadeOut(new_subgraph), + comp_graph.restore, + rect.shift, -shift_vect, + rect.set_stroke, BLACK, 0 + ) + self.remove(rect) self.dither() - self.set_variables_as_attrs(z_label, z_brace = brace) - - def break_into_computational_graph(self): - pass - - def show_preceding_layer_in_computational_graph(self): - pass - def show_number_lines(self): - pass + comp_graph = self.comp_graph + wL, aLm1, bL, zL, aL, C0 = [ + getattr(comp_graph, attr) + for attr in ["wL", "aLm1", "bL", "zL", "aL", "C0"] + ] + wL.val = self.network.weights[-1][0][0] + aL.val = self.decimals[0].number + zL.val = sigmoid_inverse(aL.val) + C0.val = (aL.val - 1)**2 + + number_line = UnitInterval( + unit_size = 2, + stroke_width = 2, + tick_size = 0.075, + color = LIGHT_GREY, + ) + + for mob in wL, zL, aL, C0: + mob.number_line = number_line.deepcopy() + if mob is wL: + mob.number_line.next_to(mob, UP, MED_LARGE_BUFF, LEFT) + else: + mob.number_line.next_to(mob, RIGHT) + mob.dot = Dot(color = mob.get_color()) + mob.dot.move_to( + mob.number_line.number_to_point(mob.val) + ) + if mob is wL: + path_arc = 0 + dot_spot = mob.dot.get_bottom() + else: + path_arc = -0.8*np.pi + dot_spot = mob.dot.get_top() + if mob is C0: + mob_spot = mob[0].get_corner(UP+RIGHT) + tip_length = 0.15 + else: + mob_spot = mob.get_corner(UP+RIGHT) + tip_length = 0.2 + mob.arrow = Arrow( + mob_spot, dot_spot, + use_rectangular_stem = False, + path_arc = path_arc, + tip_length = tip_length, + buff = SMALL_BUFF, + ) + mob.arrow.highlight(mob.get_color()) + mob.arrow.set_stroke(width = 5) + + self.play(ShowCreation( + mob.number_line, + submobject_mode = "lagged_start" + )) + self.play( + ShowCreation(mob.arrow), + ReplacementTransform( + mob.copy(), mob.dot, + path_arc = path_arc + ) + ) + self.dither() def ask_about_w_sensitivity(self): - pass + wL, aLm1, bL, zL, aL, C0 = [ + getattr(self.comp_graph, attr) + for attr in ["wL", "aLm1", "bL", "zL", "aL", "C0"] + ] + aLm1_val = self.last_neurons[1].get_fill_opacity() + bL_val = self.network.biases[-1][0] + + get_wL_val = lambda : wL.number_line.point_to_number( + wL.dot.get_center() + ) + get_zL_val = lambda : get_wL_val()*aLm1_val+bL_val + get_aL_val = lambda : sigmoid(get_zL_val()) + get_C0_val = lambda : (get_aL_val() - 1)**2 + + def generate_dot_update(term, val_func): + def update_dot(dot): + dot.move_to(term.number_line.number_to_point(val_func())) + return dot + return update_dot + + dot_update_anims = [ + UpdateFromFunc(term.dot, generate_dot_update(term, val_func)) + for term, val_func in [ + (zL, get_zL_val), + (aL, get_aL_val), + (C0, get_C0_val), + ] + ] + + wL_line = Line(wL.dot.get_center(), wL.dot.get_center()+LEFT) + del_wL = TexMobject("\\partial w^{(L)}") + del_wL.scale(self.derivative_scale_vale) + del_wL.brace = Brace(wL_line, UP) + del_wL.highlight(wL.get_color()) + del_wL.next_to(del_wL.brace, UP, SMALL_BUFF) + + C0_line = Line(C0.dot.get_center(), C0.dot.get_center()+MED_SMALL_BUFF*RIGHT) + del_C0 = TexMobject("\\partial C_0") + del_C0.scale(self.derivative_scale_vale) + del_C0.brace = Brace(C0_line, UP) + del_C0.highlight(C0.get_color()) + del_C0.next_to(del_C0.brace, UP, SMALL_BUFF) + + for sym in del_wL, del_C0: + self.play( + GrowFromCenter(sym.brace), + Write(sym, run_time = 1) + ) + self.play( + ApplyMethod( + wL.dot.shift, LEFT, + run_time = 2, + rate_func = there_and_back + ), + *dot_update_anims + ) + self.dither() + + self.set_variables_as_attrs( + dot_update_anims, del_wL, del_C0, + ) def show_derivative_wrt_w(self): pass