Finished SimplestNetworkExample in nn/part3

Grant Sanderson 2017-10-30 15:28:48 -07:00
parent ac079f182a
commit 065de1af0c


@@ -2005,8 +2005,6 @@ class SimplestNetworkExample(PreviewLearning):
"derivative_scale_val" : 0.85,
}
def construct(self):
- self.force_skipping()
self.seed_random_libraries()
self.collapse_ordinary_network()
self.focus_just_on_last_two_layers()
@@ -2024,11 +2022,18 @@ class SimplestNetworkExample(PreviewLearning):
self.show_chain_rule()
self.name_chain_rule()
self.indicate_everything_on_screen()
self.prepare_for_derivatives()
self.compute_derivatives()
self.get_lost_in_formulas()
self.fire_together_wire_together()
self.organize_chain_rule_rhs()
self.show_average_derivative()
self.show_gradient()
self.transition_to_derivative_wrt_b()
self.show_derivative_wrt_b()
self.show_derivative_wrt_a()
self.show_previous_weight_and_bias()
self.animate_long_path()
def seed_random_libraries(self):
np.random.seed(self.random_seed)
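
The methods newly sequenced in construct() walk through the last-layer chain rule that the scene animates:

∂C0/∂w^(L) = (∂z^(L)/∂w^(L)) · (∂a^(L)/∂z^(L)) · (∂C0/∂a^(L))

with z^(L) = w^(L) a^(L-1) + b^(L), a^(L) = σ(z^(L)), and C0 = (a^(L) - y)².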
@@ -2103,6 +2108,8 @@ class SimplestNetworkExample(PreviewLearning):
))
self.dither()
self.prev_layers = to_fade
def label_neurons(self):
neurons = VGroup(*[
self.network_mob.layers[i].neurons[0]
@@ -2302,7 +2309,11 @@ class SimplestNetworkExample(PreviewLearning):
def introduce_z(self):
terms = self.weighted_sum_terms
terms.generate_target()
- terms.target.next_to(self.formula, UP, aligned_edge = RIGHT)
+ terms.target.next_to(
+ self.formula, UP,
+ buff = MED_LARGE_BUFF,
+ aligned_edge = RIGHT
+ )
terms.target.shift(MED_LARGE_BUFF*RIGHT)
equals = TexMobject("=")
equals.next_to(terms.target[0][0], LEFT)
@@ -2332,8 +2343,8 @@ class SimplestNetworkExample(PreviewLearning):
)
self.dither()
- zL_formula = VGroup(z_label, equals, terms)
- aL_formula = VGroup(aL_start, z_label2, rp)
+ zL_formula = VGroup(z_label, equals, *terms)
+ aL_formula = VGroup(*list(aL_start) + [z_label2, rp])
self.set_variables_as_attrs(z_label, zL_formula, aL_formula)
def break_into_computational_graph(self):
@@ -2392,7 +2403,7 @@ class SimplestNetworkExample(PreviewLearning):
self.dither()
self.play(MoveToTarget(aL))
self.play(
- FadeOut(network_early_layers),
+ network_early_layers.fade, 1,
ShowCreation(z_to_a_line),
z_to_a_line.flash
)
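
Note on the change above: swapping FadeOut(network_early_layers) for network_early_layers.fade, 1 leaves the early layers in the scene at zero opacity rather than removing them, presumably so that show_previous_weight_and_bias can restore() and redisplay them later.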
@@ -2444,6 +2455,8 @@ class SimplestNetworkExample(PreviewLearning):
comp_graph.z_to_a_line.copy(),
)
new_subgraph = VGroup(new_terms, new_edges)
self.wLm1 = new_terms[0]
self.zLm1 = new_terms[-1]
self.play(ShowCreation(rect))
self.play(
@@ -2465,6 +2478,8 @@ class SimplestNetworkExample(PreviewLearning):
self.remove(rect)
self.dither()
self.prev_comp_subgraph = new_subgraph
def show_number_lines(self):
comp_graph = self.comp_graph
wL, aLm1, bL, zL, aL, C0 = [
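
For reference, the computational-graph nodes used from here on read left to right as w^(L), a^(L-1), and b^(L) feeding z^(L), which feeds a^(L), which feeds the cost C0.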
@@ -2771,21 +2786,581 @@ class SimplestNetworkExample(PreviewLearning):
))
self.dither()
- def compute_derivatives(self):
+ def prepare_for_derivatives(self):
zL_formula = self.zL_formula
aL_formula = self.aL_formula
az_formulas = VGroup(zL_formula, aL_formula)
cost_equation = self.cost_equation
desired_output_words = self.desired_output_words
self.play(FadeOut(self.all_comp_graph_parts))
az_formulas.generate_target()
az_formulas.target.to_edge(RIGHT)
index = 4
cost_eq = cost_equation[index]
z_eq = az_formulas.target[0][1]
x_shift = (z_eq.get_center() - cost_eq.get_center())[0]*RIGHT
cost_equation.generate_target()
# Build the transform and apply it instantly with update(1), so the middle
# of the cost equation collapses to a point without being animated.
Transform(
VGroup(*cost_equation.target[1:index]),
VectorizedPoint(cost_eq.get_left())
).update(1)
cost_equation.target[0].next_to(cost_eq, LEFT, SMALL_BUFF)
cost_equation.target.shift(x_shift)
cost_equation.shift(MED_SMALL_BUFF*DOWN)
self.play(
FadeOut(self.all_comp_graph_parts),
FadeOut(self.desired_output_words),
MoveToTarget(az_formulas),
MoveToTarget(cost_equation)
)
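# compute_derivatives reads each chain-rule factor off the on-screen formulas:
#   dC0/da^(L)    = 2 (a^(L) - y)     from C0 = (a^(L) - y)^2
#   da^(L)/dz^(L) = sigma'(z^(L))     from a^(L) = sigma(z^(L))
#   dz^(L)/dw^(L) = a^(L-1)           from z^(L) = w^(L) a^(L-1) + b^(L)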
def compute_derivatives(self):
cost_equation = self.cost_equation
zL_formula = self.zL_formula
aL_formula = self.aL_formula
chain_rule_equation = self.chain_rule_equation.copy()
dC_dw, equals, dz_dw, da_dz, dC_da = chain_rule_equation
derivs = VGroup(dC_da, da_dz, dz_dw)
deriv_targets = VGroup()
for deriv in derivs:
deriv.generate_target()
deriv_targets.add(deriv.target)
deriv_targets.arrange_submobjects(DOWN, buff = MED_LARGE_BUFF)
deriv_targets.next_to(dC_dw, DOWN, LARGE_BUFF)
for deriv in derivs:
deriv.equals = TexMobject("=")
deriv.equals.next_to(deriv.target, RIGHT)
#dC_da
self.play(
MoveToTarget(dC_da),
Write(dC_da.equals)
)
index = 4
cost_rhs = VGroup(*cost_equation[index+1:])
dC_da.rhs = cost_rhs.copy()
two = dC_da.rhs[-1]
two.scale(1.5)
two.next_to(dC_da.rhs[0], LEFT, SMALL_BUFF)
dC_da.rhs.next_to(dC_da.equals, RIGHT)
dC_da.rhs.shift(0.7*SMALL_BUFF*UP)
cost_equation.save_state()
self.play(
cost_equation.next_to, dC_da.rhs,
DOWN, MED_LARGE_BUFF, LEFT
)
self.dither()
self.play(ReplacementTransform(
cost_rhs.copy(), dC_da.rhs,
path_arc = np.pi/2,
))
self.dither()
self.play(cost_equation.restore)
self.dither()
#show_difference
neuron = self.last_neurons[0]
decimal = self.decimals[0]
double_arrow = DoubleArrow(
neuron.get_right(),
self.desired_output_neuron.get_left(),
buff = SMALL_BUFF,
color = RED
)
self.play(ReplacementTransform(
dC_da.rhs.copy(), double_arrow
))
opacity = neuron.get_fill_opacity()
for target_o in 0, opacity:
self.dither(2)
self.play(
neuron.set_fill, None, target_o,
ChangingDecimal(
decimal, lambda a : neuron.get_fill_opacity()
)
)
self.play(FadeOut(double_arrow))
#da_dz
self.play(
MoveToTarget(da_dz),
Write(da_dz.equals)
)
a_rhs = VGroup(*aL_formula[2:])
da_dz.rhs = a_rhs.copy()
prime = TexMobject("'")
prime.move_to(da_dz.rhs[0].get_corner(UP+RIGHT))
da_dz.rhs[0].shift(0.5*SMALL_BUFF*LEFT)
da_dz.rhs.add_to_back(prime)
da_dz.rhs.next_to(da_dz.equals, RIGHT)
da_dz.rhs.shift(0.5*SMALL_BUFF*UP)
aL_formula.save_state()
self.play(
aL_formula.next_to, da_dz.rhs,
DOWN, MED_LARGE_BUFF, LEFT
)
self.dither()
self.play(ReplacementTransform(
a_rhs.copy(), da_dz.rhs,
))
self.dither()
self.play(aL_formula.restore)
self.dither()
#dz_dw
self.play(
MoveToTarget(dz_dw),
Write(dz_dw.equals)
)
z_rhs = VGroup(*zL_formula[2:])
dz_dw.rhs = z_rhs[1].copy()
dz_dw.rhs.next_to(dz_dw.equals, RIGHT)
dz_dw.rhs.shift(SMALL_BUFF*UP)
zL_formula.save_state()
self.play(
zL_formula.next_to, dz_dw.rhs,
DOWN, MED_LARGE_BUFF, LEFT,
)
self.dither()
self.play(ReplacementTransform(
z_rhs[1].copy(), dz_dw.rhs,
))
self.dither()
self.play(zL_formula.restore)
self.dither()
self.derivative_equations = VGroup(dC_da, da_dz, dz_dw)
def get_lost_in_formulas(self):
randy = Randolph()
randy.flip()
randy.scale(0.7)
randy.to_edge(DOWN)
randy.shift(LEFT)
self.play(FadeIn(randy))
self.play(randy.change, "pleading", self.chain_rule_equation)
self.play(Blink(randy))
self.play(randy.change, "maybe")
self.play(Blink(randy))
self.play(FadeOut(randy))
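# dz^(L)/dw^(L) = a^(L-1): the weight's influence on the cost scales with the
# activation feeding through its edge, the Hebbian "neurons that fire together
# wire together" slogan this method plays on.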
def fire_together_wire_together(self):
dz_dw = self.derivative_equations[2]
rhs = dz_dw.rhs
rhs_copy = rhs.copy()
del_wL = dz_dw[2].copy()
rect = SurroundingRectangle(VGroup(dz_dw, dz_dw.rhs))
edge = self.network_mob.edge_groups[-1][0]
edge.save_state()
neuron = self.last_neurons[1]
decimal = self.decimals[1]
def get_decimal_anims():
return [
ChangingDecimal(decimal, lambda a : neuron.get_fill_opacity()),
UpdateFromFunc(
decimal, lambda m : m.highlight(
WHITE if neuron.get_fill_opacity() < 0.8 \
else BLACK
)
)
]
self.play(ShowCreation(rect))
self.play(FadeOut(rect))
self.play(
del_wL.next_to, edge, UP, SMALL_BUFF
)
self.play(
edge.set_stroke, None, 10,
rate_func = wiggle,
run_time = 3,
)
self.dither()
self.play(rhs.shift, MED_LARGE_BUFF*UP, rate_func = wiggle)
self.play(
rhs_copy.move_to, neuron,
rhs_copy.set_fill, None, 0
)
self.remove(rhs_copy)
self.play(
neuron.set_fill, None, 0,
*get_decimal_anims(),
run_time = 3,
rate_func = there_and_back
)
self.dither()
#Fire together wire together
opacity = neuron.get_fill_opacity()
self.play(
neuron.set_fill, None, 0.99,
*get_decimal_anims()
)
self.play(edge.set_stroke, None, 8)
self.play(
neuron.set_fill, None, opacity,
*get_decimal_anims()
)
self.play(edge.restore, FadeOut(del_wL))
self.dither(3)
def organize_chain_rule_rhs(self):
fracs = self.derivative_equations
equals_group = VGroup(*[frac.equals for frac in fracs])
rhs_group = VGroup(*[frac.rhs for frac in reversed(fracs)])
chain_rule_equation = self.chain_rule_equation
equals = TexMobject("=")
equals.next_to(chain_rule_equation, RIGHT)
rhs_group.generate_target()
rhs_group.target.arrange_submobjects(RIGHT, buff = SMALL_BUFF)
rhs_group.target.next_to(equals, RIGHT)
rhs_group.target.shift(SMALL_BUFF*UP)
right_group = VGroup(
self.cost_equation, self.zL_formula, self.aL_formula,
self.network_mob, self.decimals,
self.a_labels, self.a_label_arrows,
self.y_label, self.y_label_arrow,
self.desired_output_neuron,
self.desired_output_rect,
self.desired_output_decimal,
)
self.play(
MoveToTarget(rhs_group, path_arc = np.pi/2),
Write(equals),
FadeOut(fracs),
FadeOut(equals_group),
right_group.to_corner, DOWN+RIGHT
)
self.dither()
rhs_group.add(equals)
self.chain_rule_rhs = rhs_group
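# Everything so far is the derivative for a single training example C0.
# show_average_derivative introduces the full-cost version, averaged over
# all n examples: dC/dw^(L) = (1/n) * sum_{k=0}^{n-1} dC_k/dw^(L)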
def show_average_derivative(self):
dC0_dw = self.chain_rule_equation[0]
full_derivative = TexMobject(
"{\\partial C", "\\over", "\\partial w^{(L)}}",
"=", "\\frac{1}{n}", "\\sum_{k=0}^{n-1}",
"{\\partial C_k", "\\over", "\\partial w^{(L)}}"
)
full_derivative.highlight_by_tex_to_color_map({
"partial C" : self.cost_color,
"partial w" : self.del_wL.get_color()
})
full_derivative.to_edge(LEFT)
dCk_dw = VGroup(*full_derivative[-3:])
lhs = VGroup(*full_derivative[:3])
rhs = VGroup(*full_derivative[4:])
lhs_brace = Brace(lhs, DOWN)
lhs_text = lhs_brace.get_text("Derivative of \\\\ full cost function")
rhs_brace = Brace(rhs, UP)
rhs_text = rhs_brace.get_text("Average of all \\\\ training examples")
VGroup(
full_derivative, lhs_brace, lhs_text, rhs_brace, rhs_text
).to_corner(DOWN+LEFT)
mover = dC0_dw.copy()
self.play(Transform(mover, dCk_dw))
self.play(Write(full_derivative, run_time = 2))
self.remove(mover)
self.play(
GrowFromCenter(lhs_brace),
GrowFromCenter(rhs_brace),
Write(lhs_text, run_time = 2),
Write(rhs_text, run_time = 2),
)
self.cycle_through_alternate_training_examples()
self.play(*map(FadeOut, [
VGroup(*full_derivative[3:]),
lhs_brace, lhs_text,
rhs_brace, rhs_text,
]))
self.dC_dw = lhs
def cycle_through_alternate_training_examples(self):
neurons = VGroup(
self.desired_output_neuron, *self.last_neurons
)
decimals = VGroup(
self.desired_output_decimal, *self.decimals
)
group = VGroup(neurons, decimals)
group.save_state()
for x in range(20):
for n, d in zip(neurons, decimals):
o = np.random.random()
if n is self.desired_output_neuron:
o = np.round(o)
n.set_fill(opacity = o)
Transform(
d, self.get_neuron_activation_decimal(n)
).update(1)
self.dither(0.2)
self.play(group.restore, run_time = 0.2)
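# show_gradient stacks all such partials (the w's and b's of every layer)
# into the single gradient vector nabla C displayed between brackets.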
def show_gradient(self):
dC_dw = self.dC_dw
dC_dw.generate_target()
terms = VGroup(
TexMobject("{\\partial C", "\\over", "\\partial w^{(1)}"),
TexMobject("{\\partial C", "\\over", "\\partial b^{(1)}"),
TexMobject("\\vdots"),
dC_dw.target,
TexMobject("{\\partial C", "\\over", "\\partial b^{(L)}"),
)
for term in terms:
if isinstance(term, TexMobject):
term.highlight_by_tex_to_color_map({
"partial C" : RED,
"partial w" : BLUE,
"partial b" : MAROON_B,
})
terms.arrange_submobjects(DOWN, buff = MED_LARGE_BUFF)
lb, rb = brackets = TexMobject("[]")
brackets.scale(3)
brackets.stretch_to_fit_height(1.1*terms.get_height())
lb.next_to(terms, LEFT, buff = SMALL_BUFF)
rb.next_to(terms, RIGHT, buff = SMALL_BUFF)
vect = VGroup(lb, terms, rb)
vect.scale_to_fit_height(5)
lhs = TexMobject("\\nabla C", "=")
lhs[0].highlight(RED)
lhs.next_to(vect, LEFT)
VGroup(lhs, vect).to_corner(DOWN+LEFT, buff = LARGE_BUFF)
terms.remove(dC_dw.target)
self.play(
MoveToTarget(dC_dw),
Write(vect, run_time = 1)
)
terms.add(dC_dw)
self.play(Write(lhs))
self.dither(2)
self.play(FadeOut(VGroup(lhs, vect)))
def transition_to_derivative_wrt_b(self):
all_comp_graph_parts = self.all_comp_graph_parts
all_comp_graph_parts.scale(
1.3, about_point = all_comp_graph_parts.get_bottom()
)
comp_graph = self.comp_graph
wL, bL, zL, aL, C0 = [
getattr(comp_graph, attr)
for attr in ["wL", "bL", "zL", "aL", "C0"]
]
path_to_C = VGroup(wL, zL, aL, C0)
top_expression = VGroup(
self.chain_rule_equation,
self.chain_rule_rhs
)
rect = SurroundingRectangle(top_expression)
self.play(ShowCreation(rect))
self.play(FadeIn(comp_graph), FadeOut(rect))
for x in range(2):
self.play(LaggedStart(
Indicate, path_to_C,
rate_func = there_and_back,
run_time = 1.5,
lag_ratio = 0.7,
))
self.dither()
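# Same chain rule with the bias in place of the weight: since
# z^(L) = w^(L) a^(L-1) + b^(L), the factor dz^(L)/db^(L) is just 1,
# which is the lone "1" this method morphs the fraction into.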
def show_derivative_wrt_b(self):
comp_graph = self.comp_graph
dC0_dw = self.chain_rule_equation[0]
dz_dw = self.chain_rule_equation[2]
aLm1 = self.chain_rule_rhs[0]
left_term_group = VGroup(dz_dw, aLm1)
del_w = dC0_dw[2]
del_b = TexMobject("\\partial b^{(L)}")
del_b.highlight(MAROON_B)
del_b.replace(del_w)
dz_db = TexMobject(
"{\\partial z^{(L)}", "\\over", "\\partial b^{(L)}}"
)
dz_db.highlight_by_tex_to_color_map({
"partial z" : self.z_color,
"partial b" : MAROON_B
})
dz_db.replace(dz_dw)
one = TexMobject("1")
one.move_to(aLm1, RIGHT)
arrow = Arrow(
dz_db.get_bottom(),
one.get_bottom(),
use_rectangular_stem = False,
path_arc = np.pi/2,
color = WHITE,
)
arrow.set_stroke(width = 2)
wL, bL, zL, aL, C0 = [
getattr(comp_graph, attr)
for attr in ["wL", "bL", "zL", "aL", "C0"]
]
path_to_C = VGroup(bL, zL, aL, C0)
def get_path_animation():
return LaggedStart(
Indicate, path_to_C,
rate_func = there_and_back,
run_time = 1.5,
lag_ratio = 0.7,
)
self.play(get_path_animation())
self.play(
left_term_group.shift, DOWN,
left_term_group.fade, 1,
)
self.remove(left_term_group)
self.chain_rule_equation.remove(dz_dw)
self.chain_rule_rhs.remove(aLm1)
self.play(
Transform(del_w, del_b),
FadeIn(dz_db)
)
self.play(get_path_animation())
self.play(
ShowCreation(arrow),
ReplacementTransform(
dz_db.copy(), one,
path_arc = arrow.path_arc
)
)
self.dither(2)
self.play(*map(FadeOut, [dz_db, arrow, one]))
self.dz_db = dz_db
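# Differentiating with respect to the previous activation instead:
# dz^(L)/da^(L-1) = w^(L), so sensitivity propagates backward through the
# weight, the hook for extending the chain rule to earlier layers.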
def show_derivative_wrt_a(self):
denom = self.chain_rule_equation[0][2]
numer = VGroup(*self.chain_rule_equation[0][:2])
del_aLm1 = TexMobject("\\partial a^{(L-1)}")
del_aLm1.scale(0.8)
del_aLm1.move_to(denom)
dz_daLm1 = TexMobject(
"{\\partial z^{(L)}", "\\over", "\\partial a^{(L-1)}}"
)
dz_daLm1.scale(0.8)
dz_daLm1.next_to(self.chain_rule_equation[1], RIGHT, SMALL_BUFF)
dz_daLm1.shift(0.7*SMALL_BUFF*UP)
dz_daLm1[0].highlight(self.z_color)
wL = self.zL_formula[2].copy()
wL.next_to(self.chain_rule_rhs, LEFT, SMALL_BUFF)
arrow = Arrow(
dz_daLm1.get_bottom(), wL.get_bottom(),
path_arc = np.pi/2,
use_rectangular_stem = False,
color = WHITE,
)
comp_graph = self.comp_graph
path_to_C = VGroup(*[
getattr(comp_graph, attr)
for attr in ["aLm1", "zL", "aL", "C0"]
])
def get_path_animation():
return LaggedStart(
Indicate, path_to_C,
rate_func = there_and_back,
run_time = 1.5,
lag_ratio = 0.7,
)
self.play(get_path_animation())
self.play(
numer.shift, SMALL_BUFF*UP,
Transform(denom, del_aLm1),
FadeIn(dz_daLm1),
VGroup(*self.chain_rule_equation[-2:]).shift, SMALL_BUFF*RIGHT,
)
self.dither()
self.play(
ShowCreation(arrow),
ReplacementTransform(
dz_daLm1.copy(), wL,
path_arc = arrow.path_arc
)
)
self.dither(2)
self.chain_rule_rhs.add(wL, arrow)
self.chain_rule_equation.add(dz_daLm1)
def show_previous_weight_and_bias(self):
to_fade = self.chain_rule_rhs
comp_graph = self.comp_graph
prev_comp_subgraph = self.prev_comp_subgraph
prev_comp_subgraph.scale(0.8)
prev_comp_subgraph.next_to(comp_graph, UP, SMALL_BUFF)
prev_layer = VGroup(
self.network_mob.layers[1],
self.network_mob.edge_groups[1],
)
for mob in prev_layer:
mob.restore()
prev_layer.next_to(self.last_neurons, LEFT, buff = 0)
self.remove(prev_layer)
self.play(LaggedStart(FadeOut, to_fade, run_time = 1))
self.play(
ShowCreation(prev_comp_subgraph, run_time = 1),
self.chain_rule_equation.to_edge, RIGHT
)
self.play(FadeIn(prev_layer))
###
neuron = self.network_mob.layers[1].neurons[0]
decimal = self.get_neuron_activation_decimal(neuron)
a_label = TexMobject("a^{(L-2)}")
a_label.replace(self.a_labels[1])
arrow = self.a_label_arrows[1].copy()
VGroup(a_label, arrow).shift(
neuron.get_center() - self.last_neurons[1].get_center()
)
self.play(
Write(a_label, run_time = 1),
Write(decimal, run_time = 1),
GrowArrow(arrow),
)
def animate_long_path(self):
comp_graph = self.comp_graph
path_to_C = VGroup(
self.wLm1, self.zLm1,
*[
getattr(comp_graph, attr)
for attr in ["aLm1", "zL", "aL", "C0"]
]
)
for x in range(2):
self.play(LaggedStart(
Indicate, path_to_C,
rate_func = there_and_back,
run_time = 1.5,
lag_ratio = 0.4,
))
self.dither(2)
###
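
As a sanity check on the algebra these scenes animate, here is a minimal NumPy sketch (toy values of my own choosing, not taken from the video or this file) comparing the three-factor chain-rule derivative with a numerical finite-difference derivative:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def cost(w, a_prev, b, y):
    # C0 = (sigma(w * a_prev + b) - y)^2
    return (sigmoid(w * a_prev + b) - y) ** 2

w, a_prev, b, y = 2.0, 0.66, -1.0, 1.0
z = w * a_prev + b
a = sigmoid(z)

# dz/dw * da/dz * dC/da, exactly the three factors from the scene
analytic = a_prev * (sigmoid(z) * (1.0 - sigmoid(z))) * 2.0 * (a - y)

h = 1e-6
numeric = (cost(w + h, a_prev, b, y) - cost(w - h, a_prev, b, y)) / (2.0 * h)
print(analytic, numeric)  # the two should agree to many decimal places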
@@ -2803,11 +3378,11 @@ class SimplestNetworkExample(PreviewLearning):
result = VGroup(comp_graph)
for attr in "wL", "zL", "aL", "C0":
sym = getattr(comp_graph, attr)
- comp_graph.add(
+ result.add(
sym.arrow, sym.number_line, sym.dot
)
del_sym = getattr(self, "del_" + attr)
- comp_graph.add(del_sym, del_sym.brace)
+ result.add(del_sym, del_sym.brace)
self.all_comp_graph_parts = result
return result