diff --git a/nn/part3.py b/nn/part3.py index 8704625f..aec51355 100644 --- a/nn/part3.py +++ b/nn/part3.py @@ -559,6 +559,10 @@ class WalkThroughTwoExample(ShowAveragingCost): self.three_ways_to_increase() self.note_connections_to_brightest_neurons() self.fire_together_wire_together() + self.show_desired_increase_to_previous_neurons() + self.only_keeping_track_of_changes() + self.show_other_output_neurons() + self.show_recursion() def show_single_example(self): two_vect = get_organized_images()[2][0] @@ -867,7 +871,7 @@ class WalkThroughTwoExample(ShowAveragingCost): increase_words = VGroup( TextMobject("Increase", "$b$"), TextMobject("Increase", "$w_i$"), - TextMobject("Increase", "$a_i$"), + TextMobject("Change", "$a_i$"), ) for words in increase_words: words.highlight_by_tex_to_color_map({ @@ -962,10 +966,11 @@ class WalkThroughTwoExample(ShowAveragingCost): two_decimal = self.two_decimal two_activation = two_decimal.number - edge_animation = LaggedStart( - ShowCreationThenDestruction, bright_edges, - lag_ratio = 0.7 - ) + def get_edge_animation(): + return LaggedStart( + ShowCreationThenDestruction, bright_edges, + lag_ratio = 0.7 + ) neuron_arrows = VGroup(*[ Vector(MED_LARGE_BUFF*RIGHT).next_to(n, LEFT) for n in bright_neurons @@ -979,9 +984,11 @@ class WalkThroughTwoExample(ShowAveragingCost): )) two_neuron_rect = SurroundingRectangle(two_neuron) seeing_words = TextMobject("Seeing a 2") + seeing_words.scale(0.8) thinking_words = TextMobject("Thinking about a 2") + thinking_words.scale(0.8) seeing_words.next_to(neuron_rects, UP) - thinking_words.next_to(two_neuron_rect, UP, aligned_edge = LEFT) + thinking_words.next_to(two_neuron_arrow, RIGHT) morty = Mortimer() morty.scale(0.8) @@ -993,7 +1000,6 @@ class WalkThroughTwoExample(ShowAveragingCost): """) words.to_edge(RIGHT) - self.revert_to_original_skipping_status() self.play(FadeIn(morty)) self.play( Write(words), @@ -1001,20 +1007,20 @@ class WalkThroughTwoExample(ShowAveragingCost): ) self.play(Blink(morty)) self.play( - edge_animation, + get_edge_animation(), morty.change, "pondering", bright_edges ) - self.play(edge_animation) + self.play(get_edge_animation()) self.play( LaggedStart(GrowArrow, neuron_arrows), - edge_animation, + get_edge_animation(), ) self.play( GrowArrow(two_neuron_arrow), morty.change, "raise_right_hand", two_neuron ) self.play( - ApplyMethod(two_neuron.set_fill, WHITE, 1, **kwargs), + ApplyMethod(two_neuron.set_fill, WHITE, 1), ChangingDecimal( two_decimal, lambda a : interpolate(two_activation, 1, a), @@ -1024,7 +1030,8 @@ class WalkThroughTwoExample(ShowAveragingCost): two_decimal, lambda m : m.highlight(WHITE if m.number < 0.8 else BLACK), ), - ShowCreation(bright_edges), + LaggedStart(ShowCreation, bright_edges), + run_time = 2, ) self.dither() self.play( @@ -1032,6 +1039,7 @@ class WalkThroughTwoExample(ShowAveragingCost): Write(seeing_words, run_time = 2), morty.change, "thinking", seeing_words ) + self.dither() self.play( ShowCreation(two_neuron_rect), Write(thinking_words, run_time = 2), @@ -1041,8 +1049,133 @@ class WalkThroughTwoExample(ShowAveragingCost): self.play(LaggedStart(FadeOut, VGroup( neuron_rects, two_neuron_rect, seeing_words, thinking_words, - words, morty + words, morty, + neuron_arrows, two_neuron_arrow, + bright_edges, bright_neurons, ))) + self.play( + ApplyMethod(two_neuron.set_fill, WHITE, two_activation), + ChangingDecimal( + two_decimal, + lambda a : interpolate(1, two_activation, a), + num_decimal_points = 1, + ), + UpdateFromFunc( + two_decimal, + lambda m : m.highlight(WHITE if m.number < 0.8 else BLACK), + ), + ) + + def show_desired_increase_to_previous_neurons(self): + increase_words = self.increase_words + two_neuron = self.two_neuron + edges = two_neuron.edges_in + prev_neurons = self.network_mob.layers[-2].neurons + + positive_arrows = VGroup() + negative_arrows = VGroup() + positive_edges = VGroup() + negative_edges = VGroup() + positive_neurons = VGroup() + negative_neurons = VGroup() + for neuron, edge in zip(prev_neurons, edges): + value = edge.get_stroke_width() + if Color(edge.get_stroke_color()) == Color(self.negative_edge_color): + value *= -1 + arrow = Vector(0.25*value*UP, color = edge.get_color()) + arrow.stretch_to_fit_height(neuron.get_height()) + arrow.move_to(neuron.get_left()) + arrow.shift(SMALL_BUFF*LEFT) + if value > 0: + positive_arrows.add(arrow) + positive_edges.add(edge) + positive_neurons.add(neuron) + else: + negative_arrows.add(arrow) + negative_edges.add(edge) + negative_neurons.add(neuron) + + added_words = TextMobject("in proportion to $w_i$") + added_words.highlight(self.w_terms.get_color()) + added_words.next_to( + increase_words[-1], DOWN, + SMALL_BUFF, aligned_edge = LEFT + ) + + self.play(LaggedStart( + ApplyFunction, prev_neurons, + lambda neuron : ( + lambda m : m.scale_in_place(0.5).highlight(YELLOW), + neuron + ), + rate_func = wiggle + )) + self.dither() + for positive in [True, False]: + if positive: + arrows = positive_arrows + edges = positive_edges + neurons = positive_neurons + color = self.positive_edge_color + else: + arrows = negative_arrows + edges = negative_edges + neurons = negative_neurons + color = self.negative_edge_color + self.play( + LaggedStart( + Transform, edges, + lambda mob : ( + mob, + Dot( + mob.get_center(), + stroke_color = edges[0].get_color(), + stroke_width = 1, + radius = 0.25*SMALL_BUFF, + fill_opacity = 0 + ) + ), + rate_func = there_and_back + ), + neurons.set_stroke, color, 3, + ) + self.play( + LaggedStart(GrowArrow, arrows), + ApplyMethod( + neurons.set_fill, color, 1, + rate_func = there_and_back, + ) + ) + self.dither() + self.play(Write(added_words, run_time = 1)) + + self.set_variables_as_attrs( + in_proportion_to_w = added_words, + prev_neuron_arrows = VGroup(positive_arrows, negative_arrows), + ) + + def only_keeping_track_of_changes(self): + arrows = self.prev_neuron_arrows + prev_neurons = self.network_mob.layers[-2].neurons + rect = SurroundingRectangle(VGroup(arrows, prev_neurons)) + + words = TextMobject("No direct influence") + words.next_to(rect, UP) + + self.revert_to_original_skipping_status() + self.play(ShowCreation(rect)) + self.play(Write(words)) + self.dither() + self.play(FadeOut(VGroup(words, rect))) + + def show_other_output_neurons(self): + two_neuron = self.two_neuron + two_decimal = self.two_decimal + two_edges = two_neuron.edges_in + + + def show_recursion(self): + pass