From 20318cb9d2d9f32b4483442cb477653379d874d5 Mon Sep 17 00:00:00 2001
From: Grant Sanderson
Date: Mon, 23 Oct 2017 15:48:05 -0700
Subject: [PATCH] ShowAveragingCost of nn/part3

---
Review notes (text in this area is not part of the diff):
  * The line structure of this patch was restored from a copy whose line
    breaks and indentation had been collapsed. Indentation and the
    placement of blank context lines were re-inferred from the hunk line
    counts and 4-space Python indentation; verify the context whitespace
    against nn/part3.py before applying.
  * Two typos in the added code are fixed relative to the original patch
    text, so the hashes above and below refer to the unfixed commit:
      - ShowAveragingCost.show_one_example stored the image as
        `self.curr_iamge` while the removal guard checks `curr_image`,
        so the previous image was never removed from the scene.
      - WalkThroughTwoExample.show_single_example called `np.zeroes`;
        the NumPy function is `np.zeros`.

 nn/part3.py | 192 ++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 180 insertions(+), 12 deletions(-)

diff --git a/nn/part3.py b/nn/part3.py
index ba4ca842..2ab5a008 100644
--- a/nn/part3.py
+++ b/nn/part3.py
@@ -114,8 +114,10 @@ class InterpretGradientComponents(GradientNudging):
         cost = self.cost
 
         for x in range(self.n_steps):
-            self.move_grad_terms_into_position(grad_terms.copy())
-            self.play(*self.get_weight_adjustment_anims(edges, cost))
+            self.move_grad_terms_into_position(
+                grad_terms.copy(),
+                *self.get_weight_adjustment_anims(edges, cost)
+            )
             self.play(*self.get_decimal_change_anims(decimals))
 
     def ask_about_high_dimensions(self):
@@ -149,7 +151,7 @@ class InterpretGradientComponents(GradientNudging):
     def circle_magnitudes(self):
         rects = VGroup()
         for decimal in self.grad_vect.decimals:
-            rects.add(SurroundingRectangle(VGroup(*decimal[-4:])))
+            rects.add(SurroundingRectangle(VGroup(*decimal[-5:])))
         rects.highlight(WHITE)
 
         self.play(LaggedStart(ShowCreation, rects))
@@ -239,6 +241,9 @@ class InterpretGradientComponents(GradientNudging):
             term.dot.save_state()
             term.dot.move_to(term)
             term.dot.set_fill(opacity = 0)
+            term.words = TextMobject("Nudge this weight")
+            term.words.scale(0.7)
+            term.words.next_to(term.number_line, UP, MED_SMALL_BUFF)
 
         groups = [
             VGroup(d, d.arrow, edge, w)
@@ -265,6 +270,7 @@ class InterpretGradientComponents(GradientNudging):
                 GrowFromCenter(w.brace),
                 ShowCreation(w.number_line),
                 w.dot.restore,
+                Write(w.words, run_time = 1),
                 *added_anims
             )
             for x in range(2):
@@ -280,12 +286,14 @@ class InterpretGradientComponents(GradientNudging):
                     run_time = 2,
                 )
             self.dither()
-            self.play(*map(FadeOut, [w.dot, w.brace, w.number_line]))
+            self.play(*map(FadeOut, [
+                w.dot, w.brace, w.number_line, w.words
+            ]))
 
 
     ######
 
-    def move_grad_terms_into_position(self, grad_terms):
+    def move_grad_terms_into_position(self, grad_terms, *added_anims):
         cost_expression = self.cost_expression
         w_terms = self.cost_expression[1]
         points = VGroup(*[
@@ -316,7 +324,8 @@ class InterpretGradientComponents(GradientNudging):
                 submobject_mode = "lagged_start",
                 run_time = 1
             ),
-            FadeOut(words)
+            FadeOut(words),
+            *added_anims
         )
 
     def get_weight_adjustment_anims(self, edges, cost):
@@ -406,17 +415,176 @@ class GetLostInNotation(PiCreatureScene):
         )
         self.dither()
 
-
 class TODOInsertPreviewLearning(TODOStub):
     CONFIG = {
         "message" : "Insert PreviewLearning"
     }
 
-
-
-
-
-
+class ShowAveragingCost(PreviewLearning):
+    CONFIG = {
+        "network_scale_val" : 0.8,
+        "stroke_width_exp" : 1,
+        "start_examples_time" : 5,
+        "examples_per_adjustment_time" : 2,
+        "n_adjustments" : 5,
+        "time_per_example" : 1./15,
+        "image_height" : 1.2,
+    }
+    def construct(self):
+        self.setup_network()
+        self.setup_diff_words()
+        self.show_many_examples()
+
+    def setup_network(self):
+        self.network_mob.scale(self.network_scale_val)
+        self.network_mob.to_edge(DOWN)
+        self.network_mob.shift(LEFT)
+        self.color_network_edges()
+
+    def setup_diff_words(self):
+        last_layer_copy = self.network_mob.layers[-1].deepcopy()
+        last_layer_copy.add(self.network_mob.output_labels.copy())
+        last_layer_copy.shift(1.5*RIGHT)
+
+        double_arrow = DoubleArrow(
+            self.network_mob.output_labels,
+            last_layer_copy,
+            color = RED
+        )
+        brace = Brace(
+            VGroup(self.network_mob.layers[-1], last_layer_copy),
+            UP
+        )
+        cost_words = brace.get_text("Cost of \\\\ one example")
+        cost_words.highlight(RED)
+
+        self.add(last_layer_copy, double_arrow, brace, cost_words)
+        self.set_variables_as_attrs(
+            last_layer_copy, double_arrow, brace, cost_words
+        )
+        self.last_layer_copy = last_layer_copy
+
+    def show_many_examples(self):
+        training_data, validation_data, test_data = load_data_wrapper()
+        training_data_iter = iter(training_data)
+
+        average_words = TextMobject("Average over all training examples")
+        average_words.next_to(LEFT, RIGHT)
+        average_words.to_edge(UP)
+        self.add(average_words)
+
+        for x in xrange(int(self.start_examples_time/self.time_per_example)):
+            train_in, train_out = training_data_iter.next()
+            self.show_one_example(train_in, train_out)
+            self.dither(self.time_per_example)
+
+        #Wiggle all edges
+        edges = VGroup(*it.chain(*self.network_mob.edge_groups))
+        reversed_edges = VGroup(*reversed(edges))
+        self.play(LaggedStart(
+            ApplyFunction, edges,
+            lambda edge : (
+                lambda m : m.rotate_in_place(np.pi/12).highlight(YELLOW),
+                edge,
+            ),
+            rate_func = lambda t : wiggle(t, 4),
+            run_time = 3,
+        ))
+
+        #Show all, then adjust
+        words = TextMobject(
+            "Each step \\\\ uses every \\\\ example\\\\",
+            "$\\dots$theoretically",
+            alignment = ""
+        )
+        words.highlight(YELLOW)
+        words.scale(0.8)
+        words.to_corner(UP+LEFT)
+
+        for x in xrange(self.n_adjustments):
+            for y in xrange(int(self.examples_per_adjustment_time/self.time_per_example)):
+                train_in, train_out = training_data_iter.next()
+                self.show_one_example(train_in, train_out)
+                self.dither(self.time_per_example)
+            self.play(LaggedStart(
+                ApplyMethod, reversed_edges,
+                lambda m : (m.rotate_in_place, np.pi),
+                run_time = 1,
+                lag_ratio = 0.2,
+            ))
+            if x < 2:
+                self.play(FadeIn(words[x]))
+            else:
+                self.dither()
+
+    ####
+
+    def show_one_example(self, train_in, train_out):
+        if hasattr(self, "curr_image"):
+            self.remove(self.curr_image)
+        image = MNistMobject(train_in)
+        image.scale_to_fit_height(self.image_height)
+        image.next_to(
+            self.network_mob.layers[0].neurons, UP,
+            aligned_edge = LEFT
+        )
+        self.add(image)
+        self.network_mob.activate_layers(train_in)
+
+        index = np.argmax(train_out)
+        self.last_layer_copy.neurons.set_fill(opacity = 0)
+        self.last_layer_copy.neurons[index].set_fill(WHITE, opacity = 1)
+        self.add(self.last_layer_copy)
+
+        self.curr_image = image
+
+
+class WalkThroughTwoExample(ShowAveragingCost):
+    def construct(self):
+        self.force_skipping()
+
+        self.setup_network()
+        self.setup_diff_words()
+        self.show_single_example()
+        self.expand_last_layer()
+        self.cannot_directly_affect_activations()
+        self.show_desired_activation_nudges()
+        self.focus_on_one_neuron()
+        self.show_activation_formula()
+        self.three_ways_to_increase()
+        self.note_connections_to_brightest_neurons()
+
+    def show_single_example(self):
+        two_vect = get_organized_images()[2][0]
+        two_out = np.zeros(10)
+        two_out[2] = 1.0
+        self.show_one_example(two_vect, two_out)
+        for layer in self.network_mob.layers:
+            layer.neurons.set_fill(opacity = 0)
+
+        self.revert_to_original_skipping_status()
+        self.feed_forward(two_vect)
+
+    def expand_last_layer(self):
+        pass
+
+    def cannot_directly_affect_activations(self):
+        pass
+
+    def show_desired_activation_nudges(self):
+        pass
+
+    def focus_on_one_neuron(self):
+        pass
+
+    def show_activation_formula(self):
+        pass
+
+    def three_ways_to_increase(self):
+        pass
+
+    def note_connections_to_brightest_neurons(self):
+        pass
 
 
 