More WalkThroughTwoExample progress in nn/part3

This commit is contained in:
Grant Sanderson 2017-10-24 20:16:25 -07:00
parent f15f5d153e
commit 24d6c7bbef

View file

@ -559,6 +559,10 @@ class WalkThroughTwoExample(ShowAveragingCost):
self.three_ways_to_increase()
self.note_connections_to_brightest_neurons()
self.fire_together_wire_together()
self.show_desired_increase_to_previous_neurons()
self.only_keeping_track_of_changes()
self.show_other_output_neurons()
self.show_recursion()
def show_single_example(self):
two_vect = get_organized_images()[2][0]
@ -867,7 +871,7 @@ class WalkThroughTwoExample(ShowAveragingCost):
increase_words = VGroup(
TextMobject("Increase", "$b$"),
TextMobject("Increase", "$w_i$"),
TextMobject("Increase", "$a_i$"),
TextMobject("Change", "$a_i$"),
)
for words in increase_words:
words.highlight_by_tex_to_color_map({
@ -962,10 +966,11 @@ class WalkThroughTwoExample(ShowAveragingCost):
two_decimal = self.two_decimal
two_activation = two_decimal.number
edge_animation = LaggedStart(
ShowCreationThenDestruction, bright_edges,
lag_ratio = 0.7
)
def get_edge_animation():
return LaggedStart(
ShowCreationThenDestruction, bright_edges,
lag_ratio = 0.7
)
neuron_arrows = VGroup(*[
Vector(MED_LARGE_BUFF*RIGHT).next_to(n, LEFT)
for n in bright_neurons
@ -979,9 +984,11 @@ class WalkThroughTwoExample(ShowAveragingCost):
))
two_neuron_rect = SurroundingRectangle(two_neuron)
seeing_words = TextMobject("Seeing a 2")
seeing_words.scale(0.8)
thinking_words = TextMobject("Thinking about a 2")
thinking_words.scale(0.8)
seeing_words.next_to(neuron_rects, UP)
thinking_words.next_to(two_neuron_rect, UP, aligned_edge = LEFT)
thinking_words.next_to(two_neuron_arrow, RIGHT)
morty = Mortimer()
morty.scale(0.8)
@ -993,7 +1000,6 @@ class WalkThroughTwoExample(ShowAveragingCost):
""")
words.to_edge(RIGHT)
self.revert_to_original_skipping_status()
self.play(FadeIn(morty))
self.play(
Write(words),
@ -1001,20 +1007,20 @@ class WalkThroughTwoExample(ShowAveragingCost):
)
self.play(Blink(morty))
self.play(
edge_animation,
get_edge_animation(),
morty.change, "pondering", bright_edges
)
self.play(edge_animation)
self.play(get_edge_animation())
self.play(
LaggedStart(GrowArrow, neuron_arrows),
edge_animation,
get_edge_animation(),
)
self.play(
GrowArrow(two_neuron_arrow),
morty.change, "raise_right_hand", two_neuron
)
self.play(
ApplyMethod(two_neuron.set_fill, WHITE, 1, **kwargs),
ApplyMethod(two_neuron.set_fill, WHITE, 1),
ChangingDecimal(
two_decimal,
lambda a : interpolate(two_activation, 1, a),
@ -1024,7 +1030,8 @@ class WalkThroughTwoExample(ShowAveragingCost):
two_decimal,
lambda m : m.highlight(WHITE if m.number < 0.8 else BLACK),
),
ShowCreation(bright_edges),
LaggedStart(ShowCreation, bright_edges),
run_time = 2,
)
self.dither()
self.play(
@ -1032,6 +1039,7 @@ class WalkThroughTwoExample(ShowAveragingCost):
Write(seeing_words, run_time = 2),
morty.change, "thinking", seeing_words
)
self.dither()
self.play(
ShowCreation(two_neuron_rect),
Write(thinking_words, run_time = 2),
@ -1041,8 +1049,133 @@ class WalkThroughTwoExample(ShowAveragingCost):
self.play(LaggedStart(FadeOut, VGroup(
neuron_rects, two_neuron_rect,
seeing_words, thinking_words,
words, morty
words, morty,
neuron_arrows, two_neuron_arrow,
bright_edges, bright_neurons,
)))
self.play(
ApplyMethod(two_neuron.set_fill, WHITE, two_activation),
ChangingDecimal(
two_decimal,
lambda a : interpolate(1, two_activation, a),
num_decimal_points = 1,
),
UpdateFromFunc(
two_decimal,
lambda m : m.highlight(WHITE if m.number < 0.8 else BLACK),
),
)
def show_desired_increase_to_previous_neurons(self):
increase_words = self.increase_words
two_neuron = self.two_neuron
edges = two_neuron.edges_in
prev_neurons = self.network_mob.layers[-2].neurons
positive_arrows = VGroup()
negative_arrows = VGroup()
positive_edges = VGroup()
negative_edges = VGroup()
positive_neurons = VGroup()
negative_neurons = VGroup()
for neuron, edge in zip(prev_neurons, edges):
value = edge.get_stroke_width()
if Color(edge.get_stroke_color()) == Color(self.negative_edge_color):
value *= -1
arrow = Vector(0.25*value*UP, color = edge.get_color())
arrow.stretch_to_fit_height(neuron.get_height())
arrow.move_to(neuron.get_left())
arrow.shift(SMALL_BUFF*LEFT)
if value > 0:
positive_arrows.add(arrow)
positive_edges.add(edge)
positive_neurons.add(neuron)
else:
negative_arrows.add(arrow)
negative_edges.add(edge)
negative_neurons.add(neuron)
added_words = TextMobject("in proportion to $w_i$")
added_words.highlight(self.w_terms.get_color())
added_words.next_to(
increase_words[-1], DOWN,
SMALL_BUFF, aligned_edge = LEFT
)
self.play(LaggedStart(
ApplyFunction, prev_neurons,
lambda neuron : (
lambda m : m.scale_in_place(0.5).highlight(YELLOW),
neuron
),
rate_func = wiggle
))
self.dither()
for positive in [True, False]:
if positive:
arrows = positive_arrows
edges = positive_edges
neurons = positive_neurons
color = self.positive_edge_color
else:
arrows = negative_arrows
edges = negative_edges
neurons = negative_neurons
color = self.negative_edge_color
self.play(
LaggedStart(
Transform, edges,
lambda mob : (
mob,
Dot(
mob.get_center(),
stroke_color = edges[0].get_color(),
stroke_width = 1,
radius = 0.25*SMALL_BUFF,
fill_opacity = 0
)
),
rate_func = there_and_back
),
neurons.set_stroke, color, 3,
)
self.play(
LaggedStart(GrowArrow, arrows),
ApplyMethod(
neurons.set_fill, color, 1,
rate_func = there_and_back,
)
)
self.dither()
self.play(Write(added_words, run_time = 1))
self.set_variables_as_attrs(
in_proportion_to_w = added_words,
prev_neuron_arrows = VGroup(positive_arrows, negative_arrows),
)
def only_keeping_track_of_changes(self):
arrows = self.prev_neuron_arrows
prev_neurons = self.network_mob.layers[-2].neurons
rect = SurroundingRectangle(VGroup(arrows, prev_neurons))
words = TextMobject("No direct influence")
words.next_to(rect, UP)
self.revert_to_original_skipping_status()
self.play(ShowCreation(rect))
self.play(Write(words))
self.dither()
self.play(FadeOut(VGroup(words, rect)))
def show_other_output_neurons(self):
two_neuron = self.two_neuron
two_decimal = self.two_decimal
two_edges = two_neuron.edges_in
def show_recursion(self):
pass