mirror of
https://github.com/3b1b/manim.git
synced 2025-09-19 04:41:56 +00:00
Finished SimplestNetworkExample in nn/part3
This commit is contained in:
parent
ac079f182a
commit
065de1af0c
1 changed file with 589 additions and 14 deletions
603
nn/part3.py
603
nn/part3.py
|
@ -2005,8 +2005,6 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
"derivative_scale_val" : 0.85,
|
||||
}
|
||||
def construct(self):
|
||||
self.force_skipping()
|
||||
|
||||
self.seed_random_libraries()
|
||||
self.collapse_ordinary_network()
|
||||
self.focus_just_on_last_two_layers()
|
||||
|
@ -2024,11 +2022,18 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
self.show_chain_rule()
|
||||
self.name_chain_rule()
|
||||
self.indicate_everything_on_screen()
|
||||
self.prepare_for_derivatives()
|
||||
self.compute_derivatives()
|
||||
self.get_lost_in_formulas()
|
||||
self.fire_together_wire_together()
|
||||
self.organize_chain_rule_rhs()
|
||||
self.show_average_derivative()
|
||||
self.show_gradient()
|
||||
self.transition_to_derivative_wrt_b()
|
||||
self.show_derivative_wrt_b()
|
||||
self.show_derivative_wrt_a()
|
||||
self.show_previous_weight_and_bias()
|
||||
self.animate_long_path()
|
||||
|
||||
def seed_random_libraries(self):
|
||||
np.random.seed(self.random_seed)
|
||||
|
@ -2103,6 +2108,8 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
))
|
||||
self.dither()
|
||||
|
||||
self.prev_layers = to_fade
|
||||
|
||||
def label_neurons(self):
|
||||
neurons = VGroup(*[
|
||||
self.network_mob.layers[i].neurons[0]
|
||||
|
@ -2302,7 +2309,11 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
def introduce_z(self):
|
||||
terms = self.weighted_sum_terms
|
||||
terms.generate_target()
|
||||
terms.target.next_to(self.formula, UP, aligned_edge = RIGHT)
|
||||
terms.target.next_to(
|
||||
self.formula, UP,
|
||||
buff = MED_LARGE_BUFF,
|
||||
aligned_edge = RIGHT
|
||||
)
|
||||
terms.target.shift(MED_LARGE_BUFF*RIGHT)
|
||||
equals = TexMobject("=")
|
||||
equals.next_to(terms.target[0][0], LEFT)
|
||||
|
@ -2332,8 +2343,8 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
)
|
||||
self.dither()
|
||||
|
||||
zL_formula = VGroup(z_label, equals, terms)
|
||||
aL_formula = VGroup(aL_start, z_label2, rp)
|
||||
zL_formula = VGroup(z_label, equals, *terms)
|
||||
aL_formula = VGroup(*list(aL_start) + [z_label2, rp])
|
||||
self.set_variables_as_attrs(z_label, zL_formula, aL_formula)
|
||||
|
||||
def break_into_computational_graph(self):
|
||||
|
@ -2392,7 +2403,7 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
self.dither()
|
||||
self.play(MoveToTarget(aL))
|
||||
self.play(
|
||||
FadeOut(network_early_layers),
|
||||
network_early_layers.fade, 1,
|
||||
ShowCreation(z_to_a_line),
|
||||
z_to_a_line.flash
|
||||
)
|
||||
|
@ -2444,6 +2455,8 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
comp_graph.z_to_a_line.copy(),
|
||||
)
|
||||
new_subgraph = VGroup(new_terms, new_edges)
|
||||
self.wLm1 = new_terms[0]
|
||||
self.zLm1 = new_terms[-1]
|
||||
|
||||
self.play(ShowCreation(rect))
|
||||
self.play(
|
||||
|
@ -2465,6 +2478,8 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
self.remove(rect)
|
||||
self.dither()
|
||||
|
||||
self.prev_comp_subgraph = new_subgraph
|
||||
|
||||
def show_number_lines(self):
|
||||
comp_graph = self.comp_graph
|
||||
wL, aLm1, bL, zL, aL, C0 = [
|
||||
|
@ -2771,21 +2786,581 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
))
|
||||
self.dither()
|
||||
|
||||
def compute_derivatives(self):
|
||||
def prepare_for_derivatives(self):
    """
    Clear the computational graph and the desired-output words, and move
    the z/a formulas and the cost equation to their targets (formulas at
    the right edge, cost equation shifted horizontally so its entry at
    `index` lines up with the z formula's second entry).
    """
    zL_formula = self.zL_formula
    aL_formula = self.aL_formula
    az_formulas = VGroup(zL_formula, aL_formula)
    cost_equation = self.cost_equation
    desired_output_words = self.desired_output_words

    # The computational graph used to be faded out here in a play of its
    # own and then faded out a second time in the combined play below.
    # It is now faded exactly once, together with everything else.
    az_formulas.generate_target()
    az_formulas.target.to_edge(RIGHT)

    index = 4
    cost_eq = cost_equation[index]
    z_eq = az_formulas.target[0][1]
    # Horizontal-only offset between the two reference entries.
    x_shift = (z_eq.get_center() - cost_eq.get_center())[0]*RIGHT
    cost_equation.generate_target()
    # Collapse entries 1..index-1 of the target onto a single point at the
    # left of the `index` entry by running a Transform straight to its end.
    Transform(
        VGroup(*cost_equation.target[1:index]),
        VectorizedPoint(cost_eq.get_left())
    ).update(1)
    cost_equation.target[0].next_to(cost_eq, LEFT, SMALL_BUFF)
    cost_equation.target.shift(x_shift)
    # NOTE(review): this shifts the on-screen equation itself, not its
    # target -- confirm that is intended rather than `.target.shift`.
    cost_equation.shift(MED_SMALL_BUFF*DOWN)

    self.play(
        FadeOut(self.all_comp_graph_parts),
        FadeOut(desired_output_words),
        MoveToTarget(az_formulas),
        MoveToTarget(cost_equation)
    )
|
||||
|
||||
def compute_derivatives(self):
    """
    Pull the three factors of a copy of the chain rule equation into a
    column and, one at a time, write "= <right-hand side>" next to each,
    building every right-hand side from the matching formula on screen.

    Stores the three factors (each carrying `.equals` and `.rhs`) as
    `self.derivative_equations`, in the order dC_da, da_dz, dz_dw.
    """
    cost_equation = self.cost_equation
    zL_formula = self.zL_formula
    aL_formula = self.aL_formula
    # Work on a copy so self.chain_rule_equation itself is left in place.
    dC_dw, _, dz_dw, da_dz, dC_da = self.chain_rule_equation.copy()

    derivs = VGroup(dC_da, da_dz, dz_dw)
    for deriv in derivs:
        deriv.generate_target()
    deriv_targets = VGroup(*[deriv.target for deriv in derivs])
    deriv_targets.arrange_submobjects(DOWN, buff=MED_LARGE_BUFF)
    deriv_targets.next_to(dC_dw, DOWN, LARGE_BUFF)
    for deriv in derivs:
        deriv.equals = TexMobject("=")
        deriv.equals.next_to(deriv.target, RIGHT)

    def pull_out(deriv):
        # Move one factor to its slot in the column and write its "=".
        self.play(
            MoveToTarget(deriv),
            Write(deriv.equals)
        )

    def derive_from(formula, source, rhs, **transform_kwargs):
        # Slide `formula` under `rhs`, morph a copy of `source` into
        # `rhs`, then put `formula` back where it was.
        formula.save_state()
        self.play(
            formula.next_to, rhs,
            DOWN, MED_LARGE_BUFF, LEFT
        )
        self.dither()
        self.play(ReplacementTransform(
            source.copy(), rhs, **transform_kwargs
        ))
        self.dither()
        self.play(formula.restore)
        self.dither()

    # dC_da
    pull_out(dC_da)
    index = 4
    cost_rhs = VGroup(*cost_equation[index+1:])
    dC_da.rhs = cost_rhs.copy()
    # The last entry of the copied right-hand side is enlarged and moved
    # in front of the first entry.
    two = dC_da.rhs[-1]
    two.scale(1.5)
    two.next_to(dC_da.rhs[0], LEFT, SMALL_BUFF)
    dC_da.rhs.next_to(dC_da.equals, RIGHT)
    dC_da.rhs.shift(0.7*SMALL_BUFF*UP)
    derive_from(cost_equation, cost_rhs, dC_da.rhs, path_arc=np.pi/2)

    # show_difference
    neuron = self.last_neurons[0]
    decimal = self.decimals[0]
    double_arrow = DoubleArrow(
        neuron.get_right(),
        self.desired_output_neuron.get_left(),
        buff=SMALL_BUFF,
        color=RED
    )
    self.play(ReplacementTransform(
        dC_da.rhs.copy(), double_arrow
    ))
    opacity = neuron.get_fill_opacity()
    # Drop the neuron's fill to 0, then bring it back, with the decimal
    # tracking the fill opacity throughout.
    for target_o in (0, opacity):
        self.dither(2)
        self.play(
            neuron.set_fill, None, target_o,
            ChangingDecimal(
                decimal, lambda a: neuron.get_fill_opacity()
            )
        )
    self.play(FadeOut(double_arrow))

    # da_dz
    pull_out(da_dz)
    a_rhs = VGroup(*aL_formula[2:])
    da_dz.rhs = a_rhs.copy()
    prime = TexMobject("'")
    prime.move_to(da_dz.rhs[0].get_corner(UP+RIGHT))
    da_dz.rhs[0].shift(0.5*SMALL_BUFF*LEFT)
    da_dz.rhs.add_to_back(prime)
    da_dz.rhs.next_to(da_dz.equals, RIGHT)
    da_dz.rhs.shift(0.5*SMALL_BUFF*UP)
    derive_from(aL_formula, a_rhs, da_dz.rhs)

    # dz_dw
    pull_out(dz_dw)
    z_rhs = VGroup(*zL_formula[2:])
    dz_dw.rhs = z_rhs[1].copy()
    dz_dw.rhs.next_to(dz_dw.equals, RIGHT)
    dz_dw.rhs.shift(SMALL_BUFF*UP)
    derive_from(zL_formula, z_rhs[1], dz_dw.rhs)

    self.derivative_equations = VGroup(dC_da, da_dz, dz_dw)
|
||||
|
||||
def get_lost_in_formulas(self):
    """
    Fade in a flipped, scaled-down Randolph near the bottom of the frame,
    run it through two mode changes (each followed by a blink), then fade
    it out again.
    """
    pi_creature = Randolph()
    pi_creature.flip()
    pi_creature.scale(0.7)
    pi_creature.to_edge(DOWN)
    pi_creature.shift(LEFT)

    # Arguments handed to `change`, in order: the first change also takes
    # the chain rule equation as an extra argument.
    reactions = [
        ("pleading", self.chain_rule_equation),
        ("maybe",),
    ]

    self.play(FadeIn(pi_creature))
    for change_args in reactions:
        self.play(pi_creature.change, *change_args)
        self.play(Blink(pi_creature))
    self.play(FadeOut(pi_creature))
|
||||
|
||||
def fire_together_wire_together(self):
    """
    Highlight the dz/dw equation, wiggle the last edge of the network and
    the equation's right-hand side, then vary the fill of the second of
    `self.last_neurons` while its decimal follows along.
    """
    dz_dw = self.derivative_equations[2]
    rhs = dz_dw.rhs
    # Copies are taken up front, before anything has been animated.
    rhs_copy = rhs.copy()
    del_wL = dz_dw[2].copy()
    rect = SurroundingRectangle(VGroup(dz_dw, dz_dw.rhs))
    edge = self.network_mob.edge_groups[-1][0]
    edge.save_state()
    neuron = self.last_neurons[1]
    decimal = self.decimals[1]

    def read_opacity(alpha):
        # The decimal always displays the neuron's current fill opacity.
        return neuron.get_fill_opacity()

    def recolor(mob):
        # White text below 0.8 fill opacity, black text otherwise.
        if neuron.get_fill_opacity() < 0.8:
            return mob.highlight(WHITE)
        return mob.highlight(BLACK)

    def get_decimal_anims():
        # Fresh animations are built for every play call.
        return [
            ChangingDecimal(decimal, read_opacity),
            UpdateFromFunc(decimal, recolor),
        ]

    self.play(ShowCreation(rect))
    self.play(FadeOut(rect))
    self.play(del_wL.next_to, edge, UP, SMALL_BUFF)
    self.play(
        edge.set_stroke, None, 10,
        rate_func=wiggle,
        run_time=3,
    )
    self.dither()
    self.play(rhs.shift, MED_LARGE_BUFF*UP, rate_func=wiggle)
    self.play(
        rhs_copy.move_to, neuron,
        rhs_copy.set_fill, None, 0
    )
    self.remove(rhs_copy)
    self.play(
        neuron.set_fill, None, 0,
        *get_decimal_anims(),
        run_time=3,
        rate_func=there_and_back
    )
    self.dither()

    # Fire together wire together
    original_opacity = neuron.get_fill_opacity()
    self.play(
        neuron.set_fill, None, 0.99,
        *get_decimal_anims()
    )
    self.play(edge.set_stroke, None, 8)
    self.play(
        neuron.set_fill, None, original_opacity,
        *get_decimal_anims()
    )
    self.play(edge.restore, FadeOut(del_wL))
    self.dither(3)
|
||||
|
||||
def organize_chain_rule_rhs(self):
    """
    Gather the right-hand sides of the three derivative equations into a
    row after a new "=" next to the chain rule equation, fading out what
    is left of the derivative equations and sending the remaining pieces
    to the lower-right corner.

    Stores the row, with the "=" appended, as `self.chain_rule_rhs`.
    """
    fracs = self.derivative_equations
    equals_group = VGroup(*[frac.equals for frac in fracs])
    # Right-hand sides are taken in reverse order of the equations.
    rhs_group = VGroup(*[frac.rhs for frac in reversed(fracs)])

    new_equals = TexMobject("=")
    new_equals.next_to(self.chain_rule_equation, RIGHT)

    rhs_group.generate_target()
    row = rhs_group.target
    row.arrange_submobjects(RIGHT, buff=SMALL_BUFF)
    row.next_to(new_equals, RIGHT)
    row.shift(SMALL_BUFF*UP)

    # Everything that moves to the corner as a single unit.
    corner_parts = [
        self.cost_equation, self.zL_formula, self.aL_formula,
        self.network_mob, self.decimals,
        self.a_labels, self.a_label_arrows,
        self.y_label, self.y_label_arrow,
        self.desired_output_neuron,
        self.desired_output_rect,
        self.desired_output_decimal,
    ]
    right_group = VGroup(*corner_parts)

    self.play(
        MoveToTarget(rhs_group, path_arc=np.pi/2),
        Write(new_equals),
        FadeOut(fracs),
        FadeOut(equals_group),
        right_group.to_corner, DOWN+RIGHT
    )
    self.dither()

    rhs_group.add(new_equals)
    self.chain_rule_rhs = rhs_group
|
||||
|
||||
def show_average_derivative(self):
    """
    Write out the derivative of C as (1/n) times a sum of per-example
    derivatives, label both sides with braces, cycle through alternate
    training examples, then fade out everything but the left-hand side.

    Stores the left-hand side as `self.dC_dw`.
    """
    dC0_dw = self.chain_rule_equation[0]
    full_derivative = TexMobject(
        "{\\partial C", "\\over", "\\partial w^{(L)}}",
        "=", "\\frac{1}{n}", "\\sum_{k=0}^{n-1}",
        "{\\partial C_k", "\\over", "\\partial w^{(L)}}"
    )
    tex_colors = {
        "partial C" : self.cost_color,
        "partial w" : self.del_wL.get_color()
    }
    full_derivative.highlight_by_tex_to_color_map(tex_colors)
    full_derivative.to_edge(LEFT)

    # Slices of the nine tex parts: 0-2 left-hand side, 3 the "=",
    # 4 onward the right-hand side, whose last three parts are dC_k/dw.
    dCk_dw = VGroup(*full_derivative[-3:])
    lhs = VGroup(*full_derivative[:3])
    rhs = VGroup(*full_derivative[4:])

    lhs_brace = Brace(lhs, DOWN)
    lhs_text = lhs_brace.get_text("Derivative of \\\\ full cost function")
    rhs_brace = Brace(rhs, UP)
    rhs_text = rhs_brace.get_text("Average of all \\\\ training examples")
    labelled_equation = VGroup(
        full_derivative, lhs_brace, lhs_text, rhs_brace, rhs_text
    )
    labelled_equation.to_corner(DOWN+LEFT)

    # A throwaway copy of the existing derivative morphs into the
    # per-example term and is removed once the full expression is written.
    mover = dC0_dw.copy()
    self.play(Transform(mover, dCk_dw))
    self.play(Write(full_derivative, run_time=2))
    self.remove(mover)
    self.play(
        GrowFromCenter(lhs_brace),
        GrowFromCenter(rhs_brace),
        Write(lhs_text, run_time=2),
        Write(rhs_text, run_time=2),
    )
    self.cycle_through_altnernate_training_examples()

    to_fade = [
        VGroup(*full_derivative[3:]),
        lhs_brace, lhs_text,
        rhs_brace, rhs_text,
    ]
    self.play(*[FadeOut(mob) for mob in to_fade])

    self.dC_dw = lhs
|
||||
|
||||
def cycle_through_altnernate_training_examples(self):
    """
    Flash 20 random sets of fill opacities on the desired-output neuron
    and the last neurons, updating their decimals to match, then restore
    the saved state.
    """
    target_neuron = self.desired_output_neuron
    neurons = VGroup(target_neuron, *self.last_neurons)
    decimals = VGroup(self.desired_output_decimal, *self.decimals)
    group = VGroup(neurons, decimals)
    group.save_state()

    for _ in range(20):
        for neuron, decimal in zip(neurons, decimals):
            opacity = np.random.random()
            if neuron is target_neuron:
                # The desired output is rounded to 0 or 1.
                opacity = np.round(opacity)
            neuron.set_fill(opacity=opacity)
            # Jump the decimal straight to the value for the new opacity
            # by running a Transform to its end.
            new_decimal = self.get_neuron_activation_decimal(neuron)
            Transform(decimal, new_decimal).update(1)
        self.dither(0.2)
    self.play(group.restore, run_time=0.2)
|
||||
|
||||
def show_gradient(self):
    """
    Move the stored dC/dw term into a bracketed column of partial
    derivatives labelled "nabla C =", pause, and fade the column out.
    """
    dC_dw = self.dC_dw
    dC_dw.generate_target()

    def partial_of_C(denominator):
        # One entry of the column, built from three tex parts.
        return TexMobject("{\\partial C", "\\over", denominator)

    terms = VGroup(
        partial_of_C("\\partial w^{(1)}"),
        partial_of_C("\\partial b^{(1)}"),
        TexMobject("\\vdots"),
        dC_dw.target,
        partial_of_C("\\partial b^{(L)}}"),
    )
    tex_colors = {
        "partial C" : RED,
        "partial w" : BLUE,
        "partial b" : MAROON_B,
    }
    for term in terms:
        # Only TexMobject entries are recolored here.
        if not isinstance(term, TexMobject):
            continue
        term.highlight_by_tex_to_color_map(tex_colors)
    terms.arrange_submobjects(DOWN, buff=MED_LARGE_BUFF)

    brackets = TexMobject("[]")
    lb, rb = brackets
    brackets.scale(3)
    brackets.stretch_to_fit_height(1.1*terms.get_height())
    lb.next_to(terms, LEFT, buff=SMALL_BUFF)
    rb.next_to(terms, RIGHT, buff=SMALL_BUFF)
    vect = VGroup(lb, terms, rb)
    vect.scale_to_fit_height(5)

    lhs = TexMobject("\\nabla C", "=")
    lhs[0].highlight(RED)
    lhs.next_to(vect, LEFT)
    VGroup(lhs, vect).to_corner(DOWN+LEFT, buff=LARGE_BUFF)

    # The target only served to reserve a slot in the layout; it is taken
    # out before writing, and the real term is added once it has arrived.
    terms.remove(dC_dw.target)
    self.play(
        MoveToTarget(dC_dw),
        Write(vect, run_time=1)
    )
    terms.add(dC_dw)
    self.play(Write(lhs))
    self.dither(2)
    self.play(FadeOut(VGroup(lhs, vect)))
|
||||
|
||||
def transition_to_derivative_wrt_b(self):
    """
    Enlarge the computational graph parts about their bottom, frame the
    chain rule equation with its right-hand side, bring the graph back
    in, and indicate the wL -> zL -> aL -> C0 path twice.
    """
    graph_parts = self.all_comp_graph_parts
    graph_parts.scale(1.3, about_point=graph_parts.get_bottom())

    comp_graph = self.comp_graph
    # bL is looked up along with the others but is not part of this path.
    wL, _bL, zL, aL, C0 = [
        getattr(comp_graph, name)
        for name in ("wL", "bL", "zL", "aL", "C0")
    ]
    path_to_C = VGroup(wL, zL, aL, C0)

    rect = SurroundingRectangle(VGroup(
        self.chain_rule_equation,
        self.chain_rule_rhs
    ))

    self.play(ShowCreation(rect))
    self.play(FadeIn(comp_graph), FadeOut(rect))
    for _ in range(2):
        self.play(LaggedStart(
            Indicate, path_to_C,
            rate_func=there_and_back,
            run_time=1.5,
            lag_ratio=0.7,
        ))
    # NOTE(review): single pause after both passes -- confirm it was not
    # meant to follow each pass.
    self.dither()
|
||||
|
||||
def show_derivative_wrt_b(self):
    """
    Rework the chain rule equation on screen for b^(L): drop the dz/dw
    factor and the first right-hand-side entry, turn the denominator of
    the leading derivative into "partial b", and briefly show dz/db
    being replaced by 1.

    Stores the dz/db mobject as `self.dz_db`.
    """
    comp_graph = self.comp_graph
    dC0_dw = self.chain_rule_equation[0]
    dz_dw = self.chain_rule_equation[2]
    aLm1 = self.chain_rule_rhs[0]
    left_term_group = VGroup(dz_dw, aLm1)

    # New denominator, sized and placed on top of the old one.
    del_w = dC0_dw[2]
    del_b = TexMobject("\\partial b^{(L)}")
    del_b.highlight(MAROON_B)
    del_b.replace(del_w)

    dz_db = TexMobject(
        "{\\partial z^{(L)}", "\\over", "\\partial b^{(L)}}"
    )
    dz_db.highlight_by_tex_to_color_map({
        "partial z" : self.z_color,
        "partial b" : MAROON_B
    })
    dz_db.replace(dz_dw)

    one = TexMobject("1")
    one.move_to(aLm1, RIGHT)
    arrow = Arrow(
        dz_db.get_bottom(),
        one.get_bottom(),
        use_rectangular_stem=False,
        path_arc=np.pi/2,
        color=WHITE,
    )
    arrow.set_stroke(width=2)

    # wL is looked up along with the others but is not part of this path.
    _wL, bL, zL, aL, C0 = [
        getattr(comp_graph, name)
        for name in ("wL", "bL", "zL", "aL", "C0")
    ]
    path_to_C = VGroup(bL, zL, aL, C0)

    def indicate_path():
        # A fresh animation is built each time the path is shown.
        return LaggedStart(
            Indicate, path_to_C,
            rate_func=there_and_back,
            run_time=1.5,
            lag_ratio=0.7,
        )

    self.play(indicate_path())
    self.play(
        left_term_group.shift, DOWN,
        left_term_group.fade, 1,
    )
    # Take the faded pieces out of the scene and out of their parents.
    self.remove(left_term_group)
    self.chain_rule_equation.remove(dz_dw)
    self.chain_rule_rhs.remove(aLm1)
    self.play(
        Transform(del_w, del_b),
        FadeIn(dz_db)
    )
    self.play(indicate_path())
    self.play(
        ShowCreation(arrow),
        ReplacementTransform(
            dz_db.copy(), one,
            path_arc=arrow.path_arc
        )
    )
    self.dither(2)
    self.play(*[FadeOut(mob) for mob in (dz_db, arrow, one)])

    self.dz_db = dz_db
|
||||
|
||||
def show_derivative_wrt_a(self):
    """
    Rework the chain rule equation on screen for a^(L-1): swap the
    denominator of the leading derivative, fade in dz/da^(L-1) after the
    second entry of the equation, and show it being replaced by a copy of
    the third entry of the z formula.

    The new factor is added to `self.chain_rule_equation`; the copied
    entry and the arrow are added to `self.chain_rule_rhs`.
    """
    chain_rule_equation = self.chain_rule_equation
    leading_deriv = chain_rule_equation[0]
    denom = leading_deriv[2]
    numer = VGroup(*leading_deriv[:2])

    del_aLm1 = TexMobject("\\partial a^{(L-1)}")
    del_aLm1.scale(0.8)
    del_aLm1.move_to(denom)

    dz_daLm1 = TexMobject(
        "{\\partial z^{(L)}", "\\over", "\\partial a^{(L-1)}}"
    )
    dz_daLm1.scale(0.8)
    dz_daLm1.next_to(chain_rule_equation[1], RIGHT, SMALL_BUFF)
    dz_daLm1.shift(0.7*SMALL_BUFF*UP)
    dz_daLm1[0].highlight(self.z_color)

    wL = self.zL_formula[2].copy()
    wL.next_to(self.chain_rule_rhs, LEFT, SMALL_BUFF)

    arrow = Arrow(
        dz_daLm1.get_bottom(), wL.get_bottom(),
        path_arc=np.pi/2,
        use_rectangular_stem=False,
        color=WHITE,
    )

    comp_graph = self.comp_graph
    path_to_C = VGroup(*[
        getattr(comp_graph, name)
        for name in ("aLm1", "zL", "aL", "C0")
    ])

    # The last two entries of the equation are nudged right to make room.
    trailing_terms = VGroup(*chain_rule_equation[-2:])

    self.play(LaggedStart(
        Indicate, path_to_C,
        rate_func=there_and_back,
        run_time=1.5,
        lag_ratio=0.7,
    ))
    self.play(
        numer.shift, SMALL_BUFF*UP,
        Transform(denom, del_aLm1),
        FadeIn(dz_daLm1),
        trailing_terms.shift, SMALL_BUFF*RIGHT,
    )
    self.dither()
    self.play(
        ShowCreation(arrow),
        ReplacementTransform(
            dz_daLm1.copy(), wL,
            path_arc=arrow.path_arc
        )
    )
    self.dither(2)

    self.chain_rule_rhs.add(wL, arrow)
    chain_rule_equation.add(dz_daLm1)
|
||||
|
||||
def show_previous_weight_and_bias(self):
    """
    Fade out the chain rule right-hand side, draw the previous
    computational subgraph above the current graph, bring back layer 1 of
    the network with its edges, and label that layer's first neuron
    "a^(L-2)" with a decimal and an arrow.
    """
    comp_graph = self.comp_graph
    prev_comp_subgraph = self.prev_comp_subgraph
    prev_comp_subgraph.scale(0.8)
    prev_comp_subgraph.next_to(comp_graph, UP, SMALL_BUFF)

    layer = self.network_mob.layers[1]
    edges = self.network_mob.edge_groups[1]
    prev_layer = VGroup(layer, edges)
    # Both pieces are restored to a previously saved state, placed flush
    # against the last neurons, and kept out of the scene until faded in.
    layer.restore()
    edges.restore()
    prev_layer.next_to(self.last_neurons, LEFT, buff=0)
    self.remove(prev_layer)

    self.play(LaggedStart(FadeOut, self.chain_rule_rhs, run_time=1))
    self.play(
        ShowCreation(prev_comp_subgraph, run_time=1),
        self.chain_rule_equation.to_edge, RIGHT
    )
    self.play(FadeIn(prev_layer))

    ###
    neuron = self.network_mob.layers[1].neurons[0]
    decimal = self.get_neuron_activation_decimal(neuron)
    a_label = TexMobject("a^{(L-2)}")
    a_label.replace(self.a_labels[1])
    arrow = self.a_label_arrows[1].copy()
    # Label and arrow start where the existing ones sit and are shifted
    # by the offset between the two neurons.
    offset = neuron.get_center() - self.last_neurons[1].get_center()
    VGroup(a_label, arrow).shift(offset)

    self.play(
        Write(a_label, run_time=1),
        Write(decimal, run_time=1),
        GrowArrow(arrow),
    )
|
||||
|
||||
def animate_long_path(self):
    """
    Indicate, twice, the path running from the stored previous-layer
    terms (`self.wLm1`, `self.zLm1`) through aLm1, zL, aL and C0 of the
    computational graph, then pause.
    """
    comp_graph = self.comp_graph
    graph_terms = [
        getattr(comp_graph, name)
        for name in ("aLm1", "zL", "aL", "C0")
    ]
    path_to_C = VGroup(self.wLm1, self.zLm1, *graph_terms)

    for _ in range(2):
        self.play(LaggedStart(
            Indicate, path_to_C,
            rate_func=there_and_back,
            run_time=1.5,
            lag_ratio=0.4,
        ))
    # NOTE(review): single pause after both passes -- confirm it was not
    # meant to follow each pass.
    self.dither(2)
|
||||
|
||||
###
|
||||
|
||||
|
@ -2803,11 +3378,11 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
result = VGroup(comp_graph)
|
||||
for attr in "wL", "zL", "aL", "C0":
|
||||
sym = getattr(comp_graph, attr)
|
||||
comp_graph.add(
|
||||
result.add(
|
||||
sym.arrow, sym.number_line, sym.dot
|
||||
)
|
||||
del_sym = getattr(self, "del_" + attr)
|
||||
comp_graph.add(del_sym, del_sym.brace)
|
||||
result.add(del_sym, del_sym.brace)
|
||||
|
||||
self.all_comp_graph_parts = result
|
||||
return result
|
||||
|
|
Loading…
Add table
Reference in a new issue