Mirror of https://github.com/3b1b/manim.git
Synced 2025-08-31 23:58:32 +00:00

Commit 20318cb9d2: "ShowAveragingCost of nn/part3"
Parent: 8a5d63917f

1 changed file with 180 additions and 12 deletions: nn/part3.py (192 lines changed)
@@ -114,8 +114,10 @@ class InterpretGradientComponents(GradientNudging):
         cost = self.cost

         for x in range(self.n_steps):
-            self.move_grad_terms_into_position(grad_terms.copy())
-            self.play(*self.get_weight_adjustment_anims(edges, cost))
+            self.move_grad_terms_into_position(
+                grad_terms.copy(),
+                *self.get_weight_adjustment_anims(edges, cost)
+            )
             self.play(*self.get_decimal_change_anims(decimals))

     def ask_about_high_dimensions(self):
@@ -149,7 +151,7 @@ class InterpretGradientComponents(GradientNudging):
     def circle_magnitudes(self):
         rects = VGroup()
         for decimal in self.grad_vect.decimals:
-            rects.add(SurroundingRectangle(VGroup(*decimal[-4:])))
+            rects.add(SurroundingRectangle(VGroup(*decimal[-5:])))
         rects.highlight(WHITE)

         self.play(LaggedStart(ShowCreation, rects))
@@ -239,6 +241,9 @@ class InterpretGradientComponents(GradientNudging):
             term.dot.save_state()
             term.dot.move_to(term)
             term.dot.set_fill(opacity = 0)
+            term.words = TextMobject("Nudge this weight")
+            term.words.scale(0.7)
+            term.words.next_to(term.number_line, UP, MED_SMALL_BUFF)

         groups = [
             VGroup(d, d.arrow, edge, w)
@@ -265,6 +270,7 @@ class InterpretGradientComponents(GradientNudging):
             GrowFromCenter(w.brace),
             ShowCreation(w.number_line),
             w.dot.restore,
+            Write(w.words, run_time = 1),
             *added_anims
         )
         for x in range(2):
@@ -280,12 +286,14 @@ class InterpretGradientComponents(GradientNudging):
                 run_time = 2,
             )
             self.dither()
-        self.play(*map(FadeOut, [w.dot, w.brace, w.number_line]))
+        self.play(*map(FadeOut, [
+            w.dot, w.brace, w.number_line, w.words
+        ]))


     ######

-    def move_grad_terms_into_position(self, grad_terms):
+    def move_grad_terms_into_position(self, grad_terms, *added_anims):
         cost_expression = self.cost_expression
         w_terms = self.cost_expression[1]
         points = VGroup(*[
@@ -316,7 +324,8 @@ class InterpretGradientComponents(GradientNudging):
                 submobject_mode = "lagged_start",
                 run_time = 1
             ),
-            FadeOut(words)
+            FadeOut(words),
+            *added_anims
         )

     def get_weight_adjustment_anims(self, edges, cost):
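
The hunks above fold the weight-adjustment animations into the same play call that moves the gradient terms, by letting move_grad_terms_into_position accept extra animations via *added_anims. A minimal, framework-free sketch of that pattern follows; Scene, play, and the animation strings are simplified stand-ins rather than manim's actual API.

    # Sketch of the "*added_anims" pattern used above: a helper builds its own
    # animations, but lets callers pass extras that should run in the same
    # self.play call. Scene and play are bare stand-ins, not manim classes.
    class Scene(object):
        def play(self, *anims):
            # manim would render these together; the stand-in just records the batch
            print("playing together: " + ", ".join(anims))

        def move_grad_terms_into_position(self, grad_terms, *added_anims):
            move_anims = ["move " + term for term in grad_terms]
            # one call, so the gradient terms move while the weights get nudged
            self.play(*(move_anims + list(added_anims)))

    scene = Scene()
    scene.move_grad_terms_into_position(
        ["dC/dw1", "dC/dw2"],
        "nudge w1", "nudge w2",   # the caller's extra animations
    )
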
@@ -406,17 +415,176 @@ class GetLostInNotation(PiCreatureScene):
         )
         self.dither()


 class TODOInsertPreviewLearning(TODOStub):
     CONFIG = {
         "message" : "Insert PreviewLearning"
     }


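
ShowAveragingCost, added below, is configured through the CONFIG class dictionaries these scenes use (the TODOInsertPreviewLearning stub above shows the same convention). In manim of this era, CONFIG entries are merged down the class hierarchy into instance attributes before construct runs. The sketch below is only an approximation of that idea; digest_config and the class bodies here are simplified stand-ins, not manimlib's actual implementation.

    # Rough sketch of how CONFIG-style class dicts become instance attributes.
    # An approximation of the idea, not manimlib's actual digest_config.
    def digest_config(obj, kwargs):
        config = {}
        # Walk from base classes to the concrete class so subclasses override bases.
        for cls in reversed(type(obj).__mro__):
            config.update(getattr(cls, "CONFIG", {}))
        config.update(kwargs)  # per-instance overrides win over everything
        for key, value in config.items():
            setattr(obj, key, value)

    class PreviewLearning(object):
        CONFIG = {"network_scale_val": 1.0, "image_height": 1.0}
        def __init__(self, **kwargs):
            digest_config(self, kwargs)

    class ShowAveragingCost(PreviewLearning):
        CONFIG = {"network_scale_val": 0.8, "n_adjustments": 5}

    scene = ShowAveragingCost()
    print(scene.network_scale_val)  # 0.8 (subclass value overrides the base)
    print(scene.image_height)       # 1.0 (inherited from the base CONFIG)
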
+class ShowAveragingCost(PreviewLearning):
+    CONFIG = {
+        "network_scale_val" : 0.8,
+        "stroke_width_exp" : 1,
+        "start_examples_time" : 5,
+        "examples_per_adjustment_time" : 2,
+        "n_adjustments" : 5,
+        "time_per_example" : 1./15,
+        "image_height" : 1.2,
+    }
+    def construct(self):
+        self.setup_network()
+        self.setup_diff_words()
+        self.show_many_examples()
+
+    def setup_network(self):
+        self.network_mob.scale(self.network_scale_val)
+        self.network_mob.to_edge(DOWN)
+        self.network_mob.shift(LEFT)
+        self.color_network_edges()
+
+    def setup_diff_words(self):
+        last_layer_copy = self.network_mob.layers[-1].deepcopy()
+        last_layer_copy.add(self.network_mob.output_labels.copy())
+        last_layer_copy.shift(1.5*RIGHT)
+
+        double_arrow = DoubleArrow(
+            self.network_mob.output_labels,
+            last_layer_copy,
+            color = RED
+        )
+        brace = Brace(
+            VGroup(self.network_mob.layers[-1], last_layer_copy),
+            UP
+        )
+        cost_words = brace.get_text("Cost of \\\\ one example")
+        cost_words.highlight(RED)
+
+        self.add(last_layer_copy, double_arrow, brace, cost_words)
+        self.set_variables_as_attrs(
+            last_layer_copy, double_arrow, brace, cost_words
+        )
+        self.last_layer_copy = last_layer_copy
+
+    def show_many_examples(self):
+        training_data, validation_data, test_data = load_data_wrapper()
+        training_data_iter = iter(training_data)
+
+        average_words = TextMobject("Average over all training examples")
+        average_words.next_to(LEFT, RIGHT)
+        average_words.to_edge(UP)
+        self.add(average_words)
+
+        for x in xrange(int(self.start_examples_time/self.time_per_example)):
+            train_in, train_out = training_data_iter.next()
+            self.show_one_example(train_in, train_out)
+            self.dither(self.time_per_example)
+
+        #Wiggle all edges
+        edges = VGroup(*it.chain(*self.network_mob.edge_groups))
+        reversed_edges = VGroup(*reversed(edges))
+        self.play(LaggedStart(
+            ApplyFunction, edges,
+            lambda edge : (
+                lambda m : m.rotate_in_place(np.pi/12).highlight(YELLOW),
+                edge,
+            ),
+            rate_func = lambda t : wiggle(t, 4),
+            run_time = 3,
+        ))
+
+        #Show all, then adjust
+        words = TextMobject(
+            "Each step \\\\ uses every \\\\ example\\\\",
+            "$\\dots$theoretically",
+            alignment = ""
+        )
+        words.highlight(YELLOW)
+        words.scale(0.8)
+        words.to_corner(UP+LEFT)
+
+        for x in xrange(self.n_adjustments):
+            for y in xrange(int(self.examples_per_adjustment_time/self.time_per_example)):
+                train_in, train_out = training_data_iter.next()
+                self.show_one_example(train_in, train_out)
+                self.dither(self.time_per_example)
+            self.play(LaggedStart(
+                ApplyMethod, reversed_edges,
+                lambda m : (m.rotate_in_place, np.pi),
+                run_time = 1,
+                lag_ratio = 0.2,
+            ))
+            if x < 2:
+                self.play(FadeIn(words[x]))
+            else:
+                self.dither()
+
+    ####
+
+    def show_one_example(self, train_in, train_out):
+        if hasattr(self, "curr_image"):
+            self.remove(self.curr_image)
+        image = MNistMobject(train_in)
+        image.scale_to_fit_height(self.image_height)
+        image.next_to(
+            self.network_mob.layers[0].neurons, UP,
+            aligned_edge = LEFT
+        )
+        self.add(image)
+        self.network_mob.activate_layers(train_in)
+
+        index = np.argmax(train_out)
+        self.last_layer_copy.neurons.set_fill(opacity = 0)
+        self.last_layer_copy.neurons[index].set_fill(WHITE, opacity = 1)
+        self.add(self.last_layer_copy)
+
+        self.curr_image = image
+
+
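
The show_many_examples method above converts the CONFIG timing values into loop counts: with time_per_example = 1/15, the opening phase runs for roughly 75 training examples and each of the n_adjustments phases for roughly 30 more. A small, manim-free sketch of that arithmetic, with the values copied from the CONFIG above:

    # Sketch of the timing arithmetic behind show_many_examples; the frame
    # pacing itself is handled by manim, only the counts are computed here.
    start_examples_time = 5
    examples_per_adjustment_time = 2
    n_adjustments = 5
    time_per_example = 1. / 15

    initial_examples = int(start_examples_time / time_per_example)
    examples_per_adjustment = int(examples_per_adjustment_time / time_per_example)
    total = initial_examples + n_adjustments * examples_per_adjustment

    print("examples before the first adjustment: %d" % initial_examples)      # ~75
    print("examples between adjustments: %d" % examples_per_adjustment)       # ~30
    print("training examples shown in total: %d" % total)
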
+class WalkThroughTwoExample(ShowAveragingCost):
+    def construct(self):
+        self.force_skipping()
+
+        self.setup_network()
+        self.setup_diff_words()
+        self.show_single_example()
+        self.expand_last_layer()
+        self.cannot_directly_affect_activations()
+        self.show_desired_activation_nudges()
+        self.focus_on_one_neuron()
+        self.show_activation_formula()
+        self.three_ways_to_increase()
+        self.note_connections_to_brightest_neurons()
+
+    def show_single_example(self):
+        two_vect = get_organized_images()[2][0]
+        two_out = np.zeros(10)
+        two_out[2] = 1.0
+        self.show_one_example(two_vect, two_out)
+        for layer in self.network_mob.layers:
+            layer.neurons.set_fill(opacity = 0)
+
+        self.revert_to_original_skipping_status()
+        self.feed_forward(two_vect)
+
+    def expand_last_layer(self):
+        pass
+
+    def cannot_directly_affect_activations(self):
+        pass
+
+    def show_desired_activation_nudges(self):
+        pass
+
+    def focus_on_one_neuron(self):
+        pass
+
+    def show_activation_formula(self):
+        pass
+
+    def three_ways_to_increase(self):
+        pass
+
+    def note_connections_to_brightest_neurons(self):
+        pass
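
Both show_one_example and show_single_example above treat the training label as a one-hot vector: np.argmax picks which neuron of the copied output layer to light up, and WalkThroughTwoExample builds its own target for the digit 2 the same way. A standalone numpy sketch of that bookkeeping; the array names here are illustrative, not taken from the file:

    import numpy as np

    # Build a one-hot target for the digit 2, as show_single_example does.
    two_out = np.zeros(10)
    two_out[2] = 1.0

    # show_one_example recovers the labeled digit with argmax, then fills only
    # that neuron of the copied output layer.
    index = np.argmax(two_out)
    neuron_opacities = np.zeros(10)   # stand-in for set_fill(opacity = 0) on all neurons
    neuron_opacities[index] = 1.0     # stand-in for lighting the desired neuron

    print(index)              # 2
    print(neuron_opacities)   # [0. 0. 1. 0. ...]
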