ShowAveragingCost of nn/part3

This commit is contained in:
Grant Sanderson 2017-10-23 15:48:05 -07:00
parent 8a5d63917f
commit 20318cb9d2

View file

@ -114,8 +114,10 @@ class InterpretGradientComponents(GradientNudging):
cost = self.cost
for x in range(self.n_steps):
self.move_grad_terms_into_position(grad_terms.copy())
self.play(*self.get_weight_adjustment_anims(edges, cost))
self.move_grad_terms_into_position(
grad_terms.copy(),
*self.get_weight_adjustment_anims(edges, cost)
)
self.play(*self.get_decimal_change_anims(decimals))
def ask_about_high_dimensions(self):
@ -149,7 +151,7 @@ class InterpretGradientComponents(GradientNudging):
def circle_magnitudes(self):
rects = VGroup()
for decimal in self.grad_vect.decimals:
rects.add(SurroundingRectangle(VGroup(*decimal[-4:])))
rects.add(SurroundingRectangle(VGroup(*decimal[-5:])))
rects.highlight(WHITE)
self.play(LaggedStart(ShowCreation, rects))
@ -239,6 +241,9 @@ class InterpretGradientComponents(GradientNudging):
term.dot.save_state()
term.dot.move_to(term)
term.dot.set_fill(opacity = 0)
term.words = TextMobject("Nudge this weight")
term.words.scale(0.7)
term.words.next_to(term.number_line, UP, MED_SMALL_BUFF)
groups = [
VGroup(d, d.arrow, edge, w)
@ -265,6 +270,7 @@ class InterpretGradientComponents(GradientNudging):
GrowFromCenter(w.brace),
ShowCreation(w.number_line),
w.dot.restore,
Write(w.words, run_time = 1),
*added_anims
)
for x in range(2):
@ -280,12 +286,14 @@ class InterpretGradientComponents(GradientNudging):
run_time = 2,
)
self.dither()
self.play(*map(FadeOut, [w.dot, w.brace, w.number_line]))
self.play(*map(FadeOut, [
w.dot, w.brace, w.number_line, w.words
]))
######
def move_grad_terms_into_position(self, grad_terms):
def move_grad_terms_into_position(self, grad_terms, *added_anims):
cost_expression = self.cost_expression
w_terms = self.cost_expression[1]
points = VGroup(*[
@ -316,7 +324,8 @@ class InterpretGradientComponents(GradientNudging):
submobject_mode = "lagged_start",
run_time = 1
),
FadeOut(words)
FadeOut(words),
*added_anims
)
def get_weight_adjustment_anims(self, edges, cost):
@ -406,17 +415,176 @@ class GetLostInNotation(PiCreatureScene):
)
self.dither()
class TODOInsertPreviewLearning(TODOStub):
CONFIG = {
"message" : "Insert PreviewLearning"
}
class ShowAveragingCost(PreviewLearning):
CONFIG = {
"network_scale_val" : 0.8,
"stroke_width_exp" : 1,
"start_examples_time" : 5,
"examples_per_adjustment_time" : 2,
"n_adjustments" : 5,
"time_per_example" : 1./15,
"image_height" : 1.2,
}
def construct(self):
self.setup_network()
self.setup_diff_words()
self.show_many_examples()
def setup_network(self):
self.network_mob.scale(self.network_scale_val)
self.network_mob.to_edge(DOWN)
self.network_mob.shift(LEFT)
self.color_network_edges()
def setup_diff_words(self):
last_layer_copy = self.network_mob.layers[-1].deepcopy()
last_layer_copy.add(self.network_mob.output_labels.copy())
last_layer_copy.shift(1.5*RIGHT)
double_arrow = DoubleArrow(
self.network_mob.output_labels,
last_layer_copy,
color = RED
)
brace = Brace(
VGroup(self.network_mob.layers[-1], last_layer_copy),
UP
)
cost_words = brace.get_text("Cost of \\\\ one example")
cost_words.highlight(RED)
self.add(last_layer_copy, double_arrow, brace, cost_words)
self.set_variables_as_attrs(
last_layer_copy, double_arrow, brace, cost_words
)
self.last_layer_copy = last_layer_copy
def show_many_examples(self):
training_data, validation_data, test_data = load_data_wrapper()
training_data_iter = iter(training_data)
average_words = TextMobject("Average over all training examples")
average_words.next_to(LEFT, RIGHT)
average_words.to_edge(UP)
self.add(average_words)
for x in xrange(int(self.start_examples_time/self.time_per_example)):
train_in, train_out = training_data_iter.next()
self.show_one_example(train_in, train_out)
self.dither(self.time_per_example)
#Wiggle all edges
edges = VGroup(*it.chain(*self.network_mob.edge_groups))
reversed_edges = VGroup(*reversed(edges))
self.play(LaggedStart(
ApplyFunction, edges,
lambda edge : (
lambda m : m.rotate_in_place(np.pi/12).highlight(YELLOW),
edge,
),
rate_func = lambda t : wiggle(t, 4),
run_time = 3,
))
#Show all, then adjust
words = TextMobject(
"Each step \\\\ uses every \\\\ example\\\\",
"$\\dots$theoretically",
alignment = ""
)
words.highlight(YELLOW)
words.scale(0.8)
words.to_corner(UP+LEFT)
for x in xrange(self.n_adjustments):
for y in xrange(int(self.examples_per_adjustment_time/self.time_per_example)):
train_in, train_out = training_data_iter.next()
self.show_one_example(train_in, train_out)
self.dither(self.time_per_example)
self.play(LaggedStart(
ApplyMethod, reversed_edges,
lambda m : (m.rotate_in_place, np.pi),
run_time = 1,
lag_ratio = 0.2,
))
if x < 2:
self.play(FadeIn(words[x]))
else:
self.dither()
####
def show_one_example(self, train_in, train_out):
if hasattr(self, "curr_image"):
self.remove(self.curr_image)
image = MNistMobject(train_in)
image.scale_to_fit_height(self.image_height)
image.next_to(
self.network_mob.layers[0].neurons, UP,
aligned_edge = LEFT
)
self.add(image)
self.network_mob.activate_layers(train_in)
index = np.argmax(train_out)
self.last_layer_copy.neurons.set_fill(opacity = 0)
self.last_layer_copy.neurons[index].set_fill(WHITE, opacity = 1)
self.add(self.last_layer_copy)
self.curr_iamge = image
class WalkThroughTwoExample(ShowAveragingCost):
def construct(self):
self.force_skipping()
self.setup_network()
self.setup_diff_words()
self.show_single_example()
self.expand_last_layer()
self.cannot_directly_affect_activations()
self.show_desired_activation_nudges()
self.focus_on_one_neuron()
self.show_activation_formula()
self.three_ways_to_increase()
self.note_connections_to_brightest_neurons()
def show_single_example(self):
two_vect = get_organized_images()[2][0]
two_out = np.zeroes(10)
two_out[2] = 1.0
self.show_one_example(two_vect, two_out)
for layer in self.network_mob.layers:
layer.neurons.set_fill(opacity = 0)
self.revert_to_original_skipping_status()
self.feed_forward(two_vect)
def expand_last_layer(self):
pass
def cannot_directly_affect_activations(self):
pass
def show_desired_activation_nudges(self):
pass
def focus_on_one_neuron(self):
pass
def show_activation_formula(self):
pass
def three_ways_to_increase(self):
pass
def note_connections_to_brightest_neurons(self):
pass