mirror of
https://github.com/3b1b/manim.git
synced 2025-09-01 00:48:45 +00:00
Finished nn/part3
This commit is contained in:
parent
b07b9507d0
commit
6093780215
1 changed files with 121 additions and 19 deletions
140
nn/part3.py
140
nn/part3.py
|
@ -3669,8 +3669,6 @@ class GeneralFormulas(SimplestNetworkExample):
|
||||||
self.setup_bases()
|
self.setup_bases()
|
||||||
|
|
||||||
def construct(self):
|
def construct(self):
|
||||||
self.force_skipping()
|
|
||||||
|
|
||||||
self.setup_network_mob()
|
self.setup_network_mob()
|
||||||
self.show_all_a_labels()
|
self.show_all_a_labels()
|
||||||
self.only_show_abstract_a_labels()
|
self.only_show_abstract_a_labels()
|
||||||
|
@ -3925,22 +3923,34 @@ class GeneralFormulas(SimplestNetworkExample):
|
||||||
|
|
||||||
def show_values_between_weight_and_cost(self):
|
def show_values_between_weight_and_cost(self):
|
||||||
z_formula = TexMobject(
|
z_formula = TexMobject(
|
||||||
"z^{(L)}_j", "=", "\\cdots", "+"
|
"z^{(L)}_j", "=",
|
||||||
"w^{(L)}_{jk}", "a^{(L-1)}_k", "+", "\\cdots"
|
"w^{(L)}_{j0}", "a^{(L-1)}_0", "+",
|
||||||
|
"w^{(L)}_{j1}", "a^{(L-1)}_1", "+",
|
||||||
|
"w^{(L)}_{j2}", "a^{(L-1)}_2", "+",
|
||||||
|
"b^{(L)}_j"
|
||||||
)
|
)
|
||||||
z_formula.to_corner(UP+RIGHT)
|
compact_z_formula = TexMobject(
|
||||||
z_formula.highlight_by_tex_to_color_map({
|
"z^{(L)}_j", "=",
|
||||||
"z^" : self.z_color,
|
"\\cdots", "", "+"
|
||||||
"w^" : self.w_label.get_color()
|
"w^{(L)}_{jk}", "a^{(L-1)}_k", "+",
|
||||||
})
|
"\\cdots", "", "", "",
|
||||||
w_part = z_formula.get_part_by_tex("w^")
|
)
|
||||||
aLm1_part = z_formula.get_part_by_tex("a^{(L-1)}")
|
for expression in z_formula, compact_z_formula:
|
||||||
|
expression.to_corner(UP+RIGHT)
|
||||||
|
expression.highlight_by_tex_to_color_map({
|
||||||
|
"z^" : self.z_color,
|
||||||
|
"w^" : self.w_label.get_color(),
|
||||||
|
"b^" : MAROON_B,
|
||||||
|
})
|
||||||
|
w_part = z_formula.get_parts_by_tex("w^")[1]
|
||||||
|
aLm1_part = z_formula.get_parts_by_tex("a^{(L-1)}")[1]
|
||||||
|
|
||||||
a_formula = TexMobject(
|
a_formula = TexMobject(
|
||||||
"a^{(L)}_j", "=", "\\sigma(", "z^{(L)}_j", ")"
|
"a^{(L)}_j", "=", "\\sigma(", "z^{(L)}_j", ")"
|
||||||
)
|
)
|
||||||
a_formula.highlight_by_tex("z^", self.z_color)
|
a_formula.highlight_by_tex("z^", self.z_color)
|
||||||
a_formula.next_to(z_formula, DOWN, MED_LARGE_BUFF)
|
a_formula.next_to(z_formula, DOWN, MED_LARGE_BUFF)
|
||||||
|
a_formula.align_to(self.cost_equation, LEFT)
|
||||||
aL_part = a_formula[0]
|
aL_part = a_formula[0]
|
||||||
|
|
||||||
to_fade = VGroup(
|
to_fade = VGroup(
|
||||||
|
@ -3973,10 +3983,16 @@ class GeneralFormulas(SimplestNetworkExample):
|
||||||
self.chosen_neurons[1].label.copy(),
|
self.chosen_neurons[1].label.copy(),
|
||||||
aL_part
|
aL_part
|
||||||
))
|
))
|
||||||
self.play(Write(VGroup(*a_formula[1:])))
|
self.play(
|
||||||
|
Write(VGroup(*a_formula[1:3] + [a_formula[-1]])),
|
||||||
|
ReplacementTransform(
|
||||||
|
z_formula[0].copy(),
|
||||||
|
a_formula.get_part_by_tex("z^")
|
||||||
|
)
|
||||||
|
)
|
||||||
self.dither()
|
self.dither()
|
||||||
|
|
||||||
self.set_variables_as_attrs(z_formula, a_formula)
|
self.set_variables_as_attrs(z_formula, compact_z_formula, a_formula)
|
||||||
|
|
||||||
def show_weight_chain_rule(self):
|
def show_weight_chain_rule(self):
|
||||||
chain_rule = self.get_chain_rule(
|
chain_rule = self.get_chain_rule(
|
||||||
|
@ -3996,6 +4012,9 @@ class GeneralFormulas(SimplestNetworkExample):
|
||||||
])
|
])
|
||||||
rects.gradient_highlight(GREEN, WHITE, RED)
|
rects.gradient_highlight(GREEN, WHITE, RED)
|
||||||
|
|
||||||
|
self.play(Transform(
|
||||||
|
self.z_formula, self.compact_z_formula
|
||||||
|
))
|
||||||
self.play(Write(chain_rule))
|
self.play(Write(chain_rule))
|
||||||
self.dither()
|
self.dither()
|
||||||
self.play(LaggedStart(
|
self.play(LaggedStart(
|
||||||
|
@ -4151,7 +4170,6 @@ class GeneralFormulas(SimplestNetworkExample):
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
self.revert_to_original_skipping_status()
|
|
||||||
self.play(ShowCreation(deriv_rect))
|
self.play(ShowCreation(deriv_rect))
|
||||||
self.play(LaggedStart(
|
self.play(LaggedStart(
|
||||||
ShowCreationThenDestruction,
|
ShowCreationThenDestruction,
|
||||||
|
@ -4225,6 +4243,70 @@ class ThatsALotToThinkAbout(TeacherStudentsScene):
|
||||||
self.change_student_modes(*["thinking"]*3)
|
self.change_student_modes(*["thinking"]*3)
|
||||||
self.dither(4)
|
self.dither(4)
|
||||||
|
|
||||||
|
class LayersOfComplexity(Scene):
|
||||||
|
def construct(self):
|
||||||
|
chain_rule_equations = self.get_chain_rule_equations()
|
||||||
|
chain_rule_equations.to_corner(UP+RIGHT)
|
||||||
|
|
||||||
|
brace = Brace(chain_rule_equations, LEFT)
|
||||||
|
arrow = Vector(LEFT, color = RED)
|
||||||
|
arrow.next_to(brace, LEFT)
|
||||||
|
gradient = TexMobject("\\nabla C")
|
||||||
|
gradient.scale(2)
|
||||||
|
gradient.highlight(RED)
|
||||||
|
gradient.next_to(arrow, LEFT)
|
||||||
|
|
||||||
|
self.play(LaggedStart(FadeIn, chain_rule_equations))
|
||||||
|
self.play(GrowFromCenter(brace))
|
||||||
|
self.play(GrowArrow(arrow))
|
||||||
|
self.play(Write(gradient))
|
||||||
|
self.dither()
|
||||||
|
|
||||||
|
|
||||||
|
def get_chain_rule_equations(self):
|
||||||
|
w_deriv = TexMobject(
|
||||||
|
"{\\partial C", "\\over", "\\partial w^{(l)}_{jk}}",
|
||||||
|
"=",
|
||||||
|
"a^{(l-1)}_k",
|
||||||
|
"\\sigma'(z^{(l)}_j)",
|
||||||
|
"{\\partial C", "\\over", "\\partial a^{(l)}_j}",
|
||||||
|
)
|
||||||
|
lil_rect = SurroundingRectangle(
|
||||||
|
VGroup(*w_deriv[-3:]),
|
||||||
|
buff = 0.5*SMALL_BUFF
|
||||||
|
)
|
||||||
|
a_deriv = TexMobject(
|
||||||
|
"\\sum_{j = 0}^{n_{l+1} - 1}",
|
||||||
|
"w^{(l+1)}_{jk}",
|
||||||
|
"\\sigma'(z^{(l+1)}_j)",
|
||||||
|
"{\\partial C", "\\over", "\\partial a^{(l+1)}_j}",
|
||||||
|
)
|
||||||
|
or_word = TextMobject("or")
|
||||||
|
last_a_deriv = TexMobject("2(a^{(L)}_j - y_j)")
|
||||||
|
|
||||||
|
a_deriv.next_to(w_deriv, DOWN, LARGE_BUFF)
|
||||||
|
or_word.next_to(a_deriv, DOWN)
|
||||||
|
last_a_deriv.next_to(or_word, DOWN, MED_LARGE_BUFF)
|
||||||
|
|
||||||
|
big_rect = SurroundingRectangle(VGroup(a_deriv, last_a_deriv))
|
||||||
|
arrow = Arrow(
|
||||||
|
lil_rect.get_corner(DOWN+LEFT),
|
||||||
|
big_rect.get_top(),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = VGroup(
|
||||||
|
w_deriv, lil_rect, arrow,
|
||||||
|
big_rect, a_deriv, or_word, last_a_deriv
|
||||||
|
)
|
||||||
|
for expression in w_deriv, a_deriv, last_a_deriv:
|
||||||
|
expression.highlight_by_tex_to_color_map({
|
||||||
|
"C" : RED,
|
||||||
|
"z^" : GREEN,
|
||||||
|
"w^" : BLUE,
|
||||||
|
"b^" : MAROON_B,
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
|
||||||
class SponsorFrame(PiCreatureScene):
|
class SponsorFrame(PiCreatureScene):
|
||||||
def construct(self):
|
def construct(self):
|
||||||
morty = self.pi_creature
|
morty = self.pi_creature
|
||||||
|
@ -4363,11 +4445,16 @@ class Thumbnail(PreviewLearning):
|
||||||
},
|
},
|
||||||
"stroke_width_exp" : 1,
|
"stroke_width_exp" : 1,
|
||||||
"max_stroke_width" : 5,
|
"max_stroke_width" : 5,
|
||||||
|
"title" : "Backpropagation",
|
||||||
|
"network_scale_val" : 0.8,
|
||||||
}
|
}
|
||||||
def construct(self):
|
def construct(self):
|
||||||
self.color_network_edges()
|
self.color_network_edges()
|
||||||
network_mob = self.network_mob
|
network_mob = self.network_mob
|
||||||
network_mob.scale(0.8, about_point = network_mob.get_bottom())
|
network_mob.scale(
|
||||||
|
self.network_scale_val,
|
||||||
|
about_point = network_mob.get_bottom()
|
||||||
|
)
|
||||||
network_mob.activate_layers(np.random.random(self.layer_sizes[0]))
|
network_mob.activate_layers(np.random.random(self.layer_sizes[0]))
|
||||||
|
|
||||||
for edge in it.chain(*network_mob.edge_groups):
|
for edge in it.chain(*network_mob.edge_groups):
|
||||||
|
@ -4377,7 +4464,7 @@ class Thumbnail(PreviewLearning):
|
||||||
tip_length = 0.1,
|
tip_length = 0.1,
|
||||||
color = edge.get_color()
|
color = edge.get_color()
|
||||||
)
|
)
|
||||||
self.add(arrow.tip)
|
network_mob.add(arrow.tip)
|
||||||
|
|
||||||
arrow = Vector(
|
arrow = Vector(
|
||||||
3*LEFT,
|
3*LEFT,
|
||||||
|
@ -4387,14 +4474,29 @@ class Thumbnail(PreviewLearning):
|
||||||
)
|
)
|
||||||
arrow.next_to(network_mob.edge_groups[1], UP, MED_LARGE_BUFF)
|
arrow.next_to(network_mob.edge_groups[1], UP, MED_LARGE_BUFF)
|
||||||
|
|
||||||
self.add(arrow)
|
network_mob.add(arrow)
|
||||||
|
self.add(network_mob)
|
||||||
|
|
||||||
title = TextMobject("Backpropagation")
|
title = TextMobject(self.title)
|
||||||
title.scale(2)
|
title.scale(2)
|
||||||
title.to_edge(UP)
|
title.to_edge(UP)
|
||||||
self.add(title)
|
self.add(title)
|
||||||
|
|
||||||
|
class SupplementThumbnail(Thumbnail):
|
||||||
|
CONFIG = {
|
||||||
|
"title" : "Backpropagation \\\\ calculus",
|
||||||
|
"network_scale_val" : 0.7,
|
||||||
|
}
|
||||||
|
def construct(self):
|
||||||
|
Thumbnail.construct(self)
|
||||||
|
self.network_mob.to_edge(DOWN, buff = MED_SMALL_BUFF)
|
||||||
|
|
||||||
|
for layer in self.network_mob.layers:
|
||||||
|
for neuron in layer.neurons:
|
||||||
|
partial = TexMobject("\\partial")
|
||||||
|
partial.move_to(neuron)
|
||||||
|
self.remove(neuron)
|
||||||
|
self.add(partial)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue