mirror of
https://github.com/3b1b/manim.git
synced 2025-09-01 00:48:45 +00:00
ShowAveragingCost of nn/part3
This commit is contained in:
parent
8a5d63917f
commit
20318cb9d2
1 changed files with 180 additions and 12 deletions
192
nn/part3.py
192
nn/part3.py
|
@ -114,8 +114,10 @@ class InterpretGradientComponents(GradientNudging):
|
||||||
cost = self.cost
|
cost = self.cost
|
||||||
|
|
||||||
for x in range(self.n_steps):
|
for x in range(self.n_steps):
|
||||||
self.move_grad_terms_into_position(grad_terms.copy())
|
self.move_grad_terms_into_position(
|
||||||
self.play(*self.get_weight_adjustment_anims(edges, cost))
|
grad_terms.copy(),
|
||||||
|
*self.get_weight_adjustment_anims(edges, cost)
|
||||||
|
)
|
||||||
self.play(*self.get_decimal_change_anims(decimals))
|
self.play(*self.get_decimal_change_anims(decimals))
|
||||||
|
|
||||||
def ask_about_high_dimensions(self):
|
def ask_about_high_dimensions(self):
|
||||||
|
@ -149,7 +151,7 @@ class InterpretGradientComponents(GradientNudging):
|
||||||
def circle_magnitudes(self):
|
def circle_magnitudes(self):
|
||||||
rects = VGroup()
|
rects = VGroup()
|
||||||
for decimal in self.grad_vect.decimals:
|
for decimal in self.grad_vect.decimals:
|
||||||
rects.add(SurroundingRectangle(VGroup(*decimal[-4:])))
|
rects.add(SurroundingRectangle(VGroup(*decimal[-5:])))
|
||||||
rects.highlight(WHITE)
|
rects.highlight(WHITE)
|
||||||
|
|
||||||
self.play(LaggedStart(ShowCreation, rects))
|
self.play(LaggedStart(ShowCreation, rects))
|
||||||
|
@ -239,6 +241,9 @@ class InterpretGradientComponents(GradientNudging):
|
||||||
term.dot.save_state()
|
term.dot.save_state()
|
||||||
term.dot.move_to(term)
|
term.dot.move_to(term)
|
||||||
term.dot.set_fill(opacity = 0)
|
term.dot.set_fill(opacity = 0)
|
||||||
|
term.words = TextMobject("Nudge this weight")
|
||||||
|
term.words.scale(0.7)
|
||||||
|
term.words.next_to(term.number_line, UP, MED_SMALL_BUFF)
|
||||||
|
|
||||||
groups = [
|
groups = [
|
||||||
VGroup(d, d.arrow, edge, w)
|
VGroup(d, d.arrow, edge, w)
|
||||||
|
@ -265,6 +270,7 @@ class InterpretGradientComponents(GradientNudging):
|
||||||
GrowFromCenter(w.brace),
|
GrowFromCenter(w.brace),
|
||||||
ShowCreation(w.number_line),
|
ShowCreation(w.number_line),
|
||||||
w.dot.restore,
|
w.dot.restore,
|
||||||
|
Write(w.words, run_time = 1),
|
||||||
*added_anims
|
*added_anims
|
||||||
)
|
)
|
||||||
for x in range(2):
|
for x in range(2):
|
||||||
|
@ -280,12 +286,14 @@ class InterpretGradientComponents(GradientNudging):
|
||||||
run_time = 2,
|
run_time = 2,
|
||||||
)
|
)
|
||||||
self.dither()
|
self.dither()
|
||||||
self.play(*map(FadeOut, [w.dot, w.brace, w.number_line]))
|
self.play(*map(FadeOut, [
|
||||||
|
w.dot, w.brace, w.number_line, w.words
|
||||||
|
]))
|
||||||
|
|
||||||
|
|
||||||
######
|
######
|
||||||
|
|
||||||
def move_grad_terms_into_position(self, grad_terms):
|
def move_grad_terms_into_position(self, grad_terms, *added_anims):
|
||||||
cost_expression = self.cost_expression
|
cost_expression = self.cost_expression
|
||||||
w_terms = self.cost_expression[1]
|
w_terms = self.cost_expression[1]
|
||||||
points = VGroup(*[
|
points = VGroup(*[
|
||||||
|
@ -316,7 +324,8 @@ class InterpretGradientComponents(GradientNudging):
|
||||||
submobject_mode = "lagged_start",
|
submobject_mode = "lagged_start",
|
||||||
run_time = 1
|
run_time = 1
|
||||||
),
|
),
|
||||||
FadeOut(words)
|
FadeOut(words),
|
||||||
|
*added_anims
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_weight_adjustment_anims(self, edges, cost):
|
def get_weight_adjustment_anims(self, edges, cost):
|
||||||
|
@ -406,17 +415,176 @@ class GetLostInNotation(PiCreatureScene):
|
||||||
)
|
)
|
||||||
self.dither()
|
self.dither()
|
||||||
|
|
||||||
|
|
||||||
class TODOInsertPreviewLearning(TODOStub):
    """
    Placeholder scene marking where the PreviewLearning footage belongs.

    Only overrides the "message" config entry; how that message is
    displayed is up to TODOStub (defined elsewhere).
    """
    CONFIG = {"message" : "Insert PreviewLearning"}
|
||||||
|
|
||||||
|
class ShowAveragingCost(PreviewLearning):
    """
    Flash through many training examples next to the network, under the
    heading "Average over all training examples".

    Each example shows its image above the first layer, the network's
    activations for it, and the desired output lit up on a copy of the
    last layer.  After an initial run of examples every edge is wiggled,
    then the scene alternates between runs of examples and an edge
    "adjustment" animation.
    """
    CONFIG = {
        # Factor handed to network_mob.scale in setup_network
        "network_scale_val" : 0.8,
        "stroke_width_exp" : 1,
        # Length of the first run of examples, in the same unit as
        # time_per_example
        "start_examples_time" : 5,
        # Length of the run of examples preceding each adjustment
        "examples_per_adjustment_time" : 2,
        # Number of (examples, adjustment) rounds
        "n_adjustments" : 5,
        # Pause passed to self.dither after each example
        "time_per_example" : 1./15,
        # Height each example image is scaled to
        "image_height" : 1.2,
    }
    def construct(self):
        self.setup_network()
        self.setup_diff_words()
        self.show_many_examples()

    def setup_network(self):
        """Scale the network, move it down and to the left, color its edges."""
        self.network_mob.scale(self.network_scale_val)
        self.network_mob.to_edge(DOWN)
        self.network_mob.shift(LEFT)
        self.color_network_edges()

    def setup_diff_words(self):
        """
        Add a shifted copy of the last layer (with output labels), a red
        double arrow between the real labels and the copy, and a brace
        labeled "Cost of one example" over both.
        """
        last_layer_copy = self.network_mob.layers[-1].deepcopy()
        last_layer_copy.add(self.network_mob.output_labels.copy())
        last_layer_copy.shift(1.5*RIGHT)

        double_arrow = DoubleArrow(
            self.network_mob.output_labels,
            last_layer_copy,
            color = RED
        )
        brace = Brace(
            VGroup(self.network_mob.layers[-1], last_layer_copy),
            UP
        )
        cost_words = brace.get_text("Cost of \\\\ one example")
        cost_words.highlight(RED)

        self.add(last_layer_copy, double_arrow, brace, cost_words)
        self.set_variables_as_attrs(
            last_layer_copy, double_arrow, brace, cost_words
        )
        # Kept explicit: show_one_example reads self.last_layer_copy
        self.last_layer_copy = last_layer_copy

    def show_many_examples(self):
        """
        Run through training examples, wiggle every edge once, then
        alternate runs of examples with an edge-adjustment animation
        n_adjustments times.
        """
        # Only the training split is used here
        training_data = load_data_wrapper()[0]
        training_data_iter = iter(training_data)

        average_words = TextMobject("Average over all training examples")
        average_words.next_to(LEFT, RIGHT)
        average_words.to_edge(UP)
        self.add(average_words)

        self._show_example_run(
            training_data_iter, self.start_examples_time
        )

        #Wiggle all edges
        edges = VGroup(*it.chain(*self.network_mob.edge_groups))
        reversed_edges = VGroup(*reversed(edges))
        self.play(LaggedStart(
            ApplyFunction, edges,
            # Per-edge arguments for ApplyFunction: (function, mobject)
            lambda edge : (
                lambda m : m.rotate_in_place(np.pi/12).highlight(YELLOW),
                edge,
            ),
            rate_func = lambda t : wiggle(t, 4),
            run_time = 3,
        ))

        #Show all, then adjust
        words = TextMobject(
            "Each step \\\\ uses every \\\\ example\\\\",
            "$\\dots$theoretically",
            alignment = ""
        )
        words.highlight(YELLOW)
        words.scale(0.8)
        words.to_corner(UP+LEFT)

        for x in xrange(self.n_adjustments):
            self._show_example_run(
                training_data_iter, self.examples_per_adjustment_time
            )
            self.play(LaggedStart(
                ApplyMethod, reversed_edges,
                lambda m : (m.rotate_in_place, np.pi),
                run_time = 1,
                lag_ratio = 0.2,
            ))
            # The two parts of `words` fade in after the first two
            # adjustments; later rounds just pause.
            if x < 2:
                self.play(FadeIn(words[x]))
            else:
                self.dither()

    ####

    def _show_example_run(self, example_iter, duration):
        """
        Show int(duration/time_per_example) consecutive examples drawn
        from example_iter, pausing time_per_example after each one.
        """
        n_examples = int(duration/self.time_per_example)
        for i in xrange(n_examples):
            # Builtin next() instead of the Python-2-only .next() method
            train_in, train_out = next(example_iter)
            self.show_one_example(train_in, train_out)
            self.dither(self.time_per_example)

    def show_one_example(self, train_in, train_out):
        """
        Replace the currently shown example with (train_in, train_out).

        train_in is handed to MNistMobject and to
        network_mob.activate_layers; the argmax of train_out selects
        which neuron of the last-layer copy is filled white.
        """
        # Drop the previous example's image so images do not pile up
        if hasattr(self, "curr_image"):
            self.remove(self.curr_image)
        image = MNistMobject(train_in)
        image.scale_to_fit_height(self.image_height)
        image.next_to(
            self.network_mob.layers[0].neurons, UP,
            aligned_edge = LEFT
        )
        self.add(image)
        self.network_mob.activate_layers(train_in)

        index = np.argmax(train_out)
        self.last_layer_copy.neurons.set_fill(opacity = 0)
        self.last_layer_copy.neurons[index].set_fill(WHITE, opacity = 1)
        self.add(self.last_layer_copy)

        # Must match the attribute name checked at the top of this
        # method (it was previously misspelled, so the check never fired)
        self.curr_image = image
|
||||||
|
|
||||||
|
|
||||||
|
class WalkThroughTwoExample(ShowAveragingCost):
    """
    Walk through a single training example (an image of a 2) using the
    network layout inherited from ShowAveragingCost.

    Most steps after show_single_example are still empty stubs.
    """
    def construct(self):
        # Skipping is switched on here and reverted inside
        # show_single_example, just before feed_forward
        self.force_skipping()

        self.setup_network()
        self.setup_diff_words()
        self.show_single_example()
        self.expand_last_layer()
        self.cannot_directly_affect_activations()
        self.show_desired_activation_nudges()
        self.focus_on_one_neuron()
        self.show_activation_formula()
        self.three_ways_to_increase()
        self.note_connections_to_brightest_neurons()

    def show_single_example(self):
        """
        Show one image of a 2 with its desired output, clear every
        layer's neuron fill, then feed the image forward.
        """
        two_vect = get_organized_images()[2][0]
        # Desired output: length-10 vector with a 1.0 at index 2.
        # np.zeros is the NumPy constructor; "np.zeroes" does not exist.
        two_out = np.zeros(10)
        two_out[2] = 1.0
        self.show_one_example(two_vect, two_out)
        for layer in self.network_mob.layers:
            layer.neurons.set_fill(opacity = 0)

        self.revert_to_original_skipping_status()
        self.feed_forward(two_vect)

    def expand_last_layer(self):
        """Stub; not implemented yet."""
        pass

    def cannot_directly_affect_activations(self):
        """Stub; not implemented yet."""
        pass

    def show_desired_activation_nudges(self):
        """Stub; not implemented yet."""
        pass

    def focus_on_one_neuron(self):
        """Stub; not implemented yet."""
        pass

    def show_activation_formula(self):
        """Stub; not implemented yet."""
        pass

    def three_ways_to_increase(self):
        """Stub; not implemented yet."""
        pass

    def note_connections_to_brightest_neurons(self):
        """Stub; not implemented yet."""
        pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue