mirror of
https://github.com/3b1b/manim.git
synced 2025-08-31 10:48:55 +00:00
Finished GeneralFormulas of nn/part3
This commit is contained in:
parent
cfa9aba0ce
commit
109bff8a27
1 changed files with 322 additions and 19 deletions
341
nn/part3.py
341
nn/part3.py
|
@ -3413,16 +3413,15 @@ class GeneralFormulas(SimplestNetworkExample):
|
||||||
"neuron_to_neuron_buff" : LARGE_BUFF,
|
"neuron_to_neuron_buff" : LARGE_BUFF,
|
||||||
"neuron_radius" : 0.3,
|
"neuron_radius" : 0.3,
|
||||||
},
|
},
|
||||||
"stroke_width_exp" : 0.5,
|
"edge_stroke_width" : 4,
|
||||||
"random_seed" : 1,
|
"stroke_width_exp" : 0.2,
|
||||||
|
"random_seed" : 9,
|
||||||
}
|
}
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.seed_random_libraries()
|
self.seed_random_libraries()
|
||||||
self.setup_bases()
|
self.setup_bases()
|
||||||
|
|
||||||
def construct(self):
|
def construct(self):
|
||||||
self.force_skipping()
|
|
||||||
|
|
||||||
self.setup_network_mob()
|
self.setup_network_mob()
|
||||||
self.show_all_a_labels()
|
self.show_all_a_labels()
|
||||||
self.only_show_abstract_a_labels()
|
self.only_show_abstract_a_labels()
|
||||||
|
@ -3431,12 +3430,14 @@ class GeneralFormulas(SimplestNetworkExample):
|
||||||
self.show_example_weight()
|
self.show_example_weight()
|
||||||
self.show_values_between_weight_and_cost()
|
self.show_values_between_weight_and_cost()
|
||||||
self.show_weight_chain_rule()
|
self.show_weight_chain_rule()
|
||||||
self.show_multiple_paths_from_prev_layer_neuron()
|
|
||||||
self.show_derivative_wrt_prev_activation()
|
self.show_derivative_wrt_prev_activation()
|
||||||
|
self.show_multiple_paths_from_prev_layer_neuron()
|
||||||
|
self.show_previous_layer()
|
||||||
|
|
||||||
def setup_network_mob(self):
|
def setup_network_mob(self):
|
||||||
self.color_network_edges()
|
self.color_network_edges()
|
||||||
self.network_mob.to_edge(LEFT)
|
self.network_mob.to_edge(LEFT)
|
||||||
|
self.network_mob.shift(DOWN)
|
||||||
in_vect = np.random.random(self.layer_sizes[0])
|
in_vect = np.random.random(self.layer_sizes[0])
|
||||||
self.network_mob.activate_layers(in_vect)
|
self.network_mob.activate_layers(in_vect)
|
||||||
self.remove(self.network_mob.layers[0])
|
self.remove(self.network_mob.layers[0])
|
||||||
|
@ -3462,7 +3463,7 @@ class GeneralFormulas(SimplestNetworkExample):
|
||||||
arrow.next_to(neuron, -vect)
|
arrow.next_to(neuron, -vect)
|
||||||
arrow.set_fill(WHITE)
|
arrow.set_fill(WHITE)
|
||||||
label = TexMobject("a^{(%s)}_%d"%(s, i))
|
label = TexMobject("a^{(%s)}_%d"%(s, i))
|
||||||
label.next_to(arrow, -vect)
|
label.next_to(arrow, -vect, SMALL_BUFF)
|
||||||
rect = SurroundingRectangle(label[-1], buff = 0.5*SMALL_BUFF)
|
rect = SurroundingRectangle(label[-1], buff = 0.5*SMALL_BUFF)
|
||||||
decimal = self.get_neuron_activation_decimal(neuron)
|
decimal = self.get_neuron_activation_decimal(neuron)
|
||||||
neuron.arrow = arrow
|
neuron.arrow = arrow
|
||||||
|
@ -3515,7 +3516,7 @@ class GeneralFormulas(SimplestNetworkExample):
|
||||||
rects = VGroup()
|
rects = VGroup()
|
||||||
for x, layer in enumerate(self.network_mob.layers[-2:]):
|
for x, layer in enumerate(self.network_mob.layers[-2:]):
|
||||||
for y, neuron in enumerate(layer.neurons):
|
for y, neuron in enumerate(layer.neurons):
|
||||||
if (x == 0 and y == 2) or (x == 1 and y == 0):
|
if (x == 0 and y == 1) or (x == 1 and y == 0):
|
||||||
tex = "k" if x == 0 else "j"
|
tex = "k" if x == 0 else "j"
|
||||||
neuron.label.generate_target()
|
neuron.label.generate_target()
|
||||||
self.replace_subscript(neuron.label.target, tex)
|
self.replace_subscript(neuron.label.target, tex)
|
||||||
|
@ -3544,38 +3545,340 @@ class GeneralFormulas(SimplestNetworkExample):
|
||||||
self.dither()
|
self.dither()
|
||||||
|
|
||||||
def add_desired_output(self):
|
def add_desired_output(self):
|
||||||
pass
|
layer = self.network_mob.layers[-1]
|
||||||
|
desired_output = layer.deepcopy()
|
||||||
|
desired_output.shift(3*RIGHT)
|
||||||
|
desired_output_decimals = VGroup()
|
||||||
|
arrows = VGroup()
|
||||||
|
labels = VGroup()
|
||||||
|
for i, neuron in enumerate(desired_output.neurons):
|
||||||
|
neuron.set_fill(opacity = i)
|
||||||
|
decimal = self.get_neuron_activation_decimal(neuron)
|
||||||
|
neuron.decimal = decimal
|
||||||
|
neuron.arrow = Arrow(ORIGIN, LEFT, color = WHITE)
|
||||||
|
neuron.arrow.next_to(neuron, RIGHT)
|
||||||
|
neuron.label = TexMobject("y_%d"%i)
|
||||||
|
neuron.label.next_to(neuron.arrow, RIGHT)
|
||||||
|
neuron.label.highlight(self.desired_output_color)
|
||||||
|
|
||||||
|
desired_output_decimals.add(decimal)
|
||||||
|
arrows.add(neuron.arrow)
|
||||||
|
labels.add(neuron.label)
|
||||||
|
rect = SurroundingRectangle(desired_output, buff = 0.5*SMALL_BUFF)
|
||||||
|
words = TextMobject("Desired output")
|
||||||
|
words.next_to(rect, DOWN)
|
||||||
|
VGroup(words, rect).highlight(self.desired_output_color)
|
||||||
|
|
||||||
|
self.play(
|
||||||
|
FadeIn(rect),
|
||||||
|
FadeIn(words),
|
||||||
|
ReplacementTransform(layer.copy(), desired_output),
|
||||||
|
FadeIn(labels),
|
||||||
|
*[
|
||||||
|
ReplacementTransform(n1.decimal.copy(), n2.decimal)
|
||||||
|
for n1, n2 in zip(layer.neurons, desired_output.neurons)
|
||||||
|
] + map(GrowArrow, arrows)
|
||||||
|
)
|
||||||
|
self.dither()
|
||||||
|
|
||||||
|
self.set_variables_as_attrs(
|
||||||
|
desired_output,
|
||||||
|
desired_output_decimals,
|
||||||
|
desired_output_rect = rect,
|
||||||
|
desired_output_words = words,
|
||||||
|
)
|
||||||
|
|
||||||
def show_cost(self):
|
def show_cost(self):
|
||||||
pass
|
aj = self.chosen_neurons[1].label.copy()
|
||||||
|
yj = self.desired_output.neurons[0].label.copy()
|
||||||
|
|
||||||
|
cost_equation = TexMobject(
|
||||||
|
"C_0", "=", "\\sum_{j = 0}^{n_L - 1}",
|
||||||
|
"(", "a^{(L)}_j", "-", "y_j", ")", "^2"
|
||||||
|
)
|
||||||
|
cost_equation.to_corner(UP+RIGHT)
|
||||||
|
cost_equation[0].highlight(self.cost_color)
|
||||||
|
aj.target = cost_equation.get_part_by_tex("a^{(L)}_j")
|
||||||
|
yj.target = cost_equation.get_part_by_tex("y_j")
|
||||||
|
yj.target.highlight(self.desired_output_color)
|
||||||
|
to_fade_in = VGroup(*filter(
|
||||||
|
lambda m : m not in [aj.target, yj.target],
|
||||||
|
cost_equation
|
||||||
|
))
|
||||||
|
|
||||||
|
self.play(*[
|
||||||
|
ReplacementTransform(mob, mob.target)
|
||||||
|
for mob in aj, yj
|
||||||
|
])
|
||||||
|
self.play(LaggedStart(FadeIn, to_fade_in))
|
||||||
|
self.dither(2)
|
||||||
|
|
||||||
|
self.set_variables_as_attrs(cost_equation)
|
||||||
|
|
||||||
def show_example_weight(self):
|
def show_example_weight(self):
|
||||||
pass
|
edges = self.network_mob.edge_groups[-1]
|
||||||
|
edge = self.chosen_neurons[1].edges_in[1]
|
||||||
|
faded_edges = VGroup(*filter(
|
||||||
|
lambda e : e is not edge,
|
||||||
|
edges
|
||||||
|
))
|
||||||
|
faded_edges.save_state()
|
||||||
|
for faded_edge in faded_edges:
|
||||||
|
faded_edge.save_state()
|
||||||
|
|
||||||
|
w_label = TexMobject("w^{(L)}_{jk}")
|
||||||
|
w_label.scale(1.2)
|
||||||
|
w_label.add_background_rectangle()
|
||||||
|
w_label.next_to(ORIGIN, UP, SMALL_BUFF)
|
||||||
|
w_label.rotate(edge.get_angle())
|
||||||
|
w_label.shift(edge.get_center())
|
||||||
|
w_label.highlight(BLUE)
|
||||||
|
|
||||||
|
self.play(faded_edges.fade, 0.9)
|
||||||
|
self.play(Write(w_label))
|
||||||
|
self.dither()
|
||||||
|
|
||||||
|
self.set_variables_as_attrs(faded_edges, w_label)
|
||||||
|
|
||||||
def show_values_between_weight_and_cost(self):
|
def show_values_between_weight_and_cost(self):
|
||||||
pass
|
z_formula = TexMobject(
|
||||||
|
"z^{(L)}_j", "=", "\\cdots", "+"
|
||||||
|
"w^{(L)}_{jk}", "a^{(L-1)}_k", "+", "\\cdots"
|
||||||
|
)
|
||||||
|
z_formula.to_corner(UP+RIGHT)
|
||||||
|
z_formula.highlight_by_tex_to_color_map({
|
||||||
|
"z^" : self.z_color,
|
||||||
|
"w^" : self.w_label.get_color()
|
||||||
|
})
|
||||||
|
w_part = z_formula.get_part_by_tex("w^")
|
||||||
|
aLm1_part = z_formula.get_part_by_tex("a^{(L-1)}")
|
||||||
|
|
||||||
|
a_formula = TexMobject(
|
||||||
|
"a^{(L)}_j", "=", "\\sigma(", "z^{(L)}_j", ")"
|
||||||
|
)
|
||||||
|
a_formula.highlight_by_tex("z^", self.z_color)
|
||||||
|
a_formula.next_to(z_formula, DOWN, MED_LARGE_BUFF)
|
||||||
|
aL_part = a_formula[0]
|
||||||
|
|
||||||
|
to_fade = VGroup(
|
||||||
|
self.desired_output,
|
||||||
|
self.desired_output_decimals,
|
||||||
|
self.desired_output_rect,
|
||||||
|
self.desired_output_words,
|
||||||
|
*[
|
||||||
|
VGroup(n.arrow, n.label)
|
||||||
|
for n in self.desired_output.neurons
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.play(
|
||||||
|
FadeOut(to_fade),
|
||||||
|
self.cost_equation.next_to, a_formula, DOWN, MED_LARGE_BUFF,
|
||||||
|
self.cost_equation.to_edge, RIGHT,
|
||||||
|
ReplacementTransform(self.w_label[1].copy(), w_part),
|
||||||
|
ReplacementTransform(
|
||||||
|
self.chosen_neurons[0].label.copy(),
|
||||||
|
aLm1_part
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.play(Write(VGroup(*filter(
|
||||||
|
lambda m : m not in [w_part, aLm1_part],
|
||||||
|
z_formula
|
||||||
|
))))
|
||||||
|
self.dither()
|
||||||
|
self.play(ReplacementTransform(
|
||||||
|
self.chosen_neurons[1].label.copy(),
|
||||||
|
aL_part
|
||||||
|
))
|
||||||
|
self.play(Write(VGroup(*a_formula[1:])))
|
||||||
|
self.dither()
|
||||||
|
|
||||||
|
self.set_variables_as_attrs(z_formula, a_formula)
|
||||||
|
|
||||||
def show_weight_chain_rule(self):
|
def show_weight_chain_rule(self):
|
||||||
pass
|
chain_rule = self.get_chain_rule(
|
||||||
|
"{\\partial C_0", "\\over", "\\partial w^{(L)}_{jk}}",
|
||||||
|
"=",
|
||||||
|
"{\\partial z^{(L)}_j", "\\over", "\\partial w^{(L)}_{jk}}",
|
||||||
|
"{\\partial a^{(L)}_j", "\\over", "\\partial z^{(L)}_j}",
|
||||||
|
"{\\partial C_0", "\\over", "\\partial a^{(L)}_j}",
|
||||||
|
)
|
||||||
|
terms = VGroup(*[
|
||||||
|
VGroup(*chain_rule[i:i+3])
|
||||||
|
for i in range(4,len(chain_rule), 3)
|
||||||
|
])
|
||||||
|
rects = VGroup(*[
|
||||||
|
SurroundingRectangle(term, buff = 0.5*SMALL_BUFF)
|
||||||
|
for term in terms
|
||||||
|
])
|
||||||
|
rects.gradient_highlight(GREEN, WHITE, RED)
|
||||||
|
|
||||||
def show_multiple_paths_from_prev_layer_neuron(self):
|
self.play(Write(chain_rule))
|
||||||
pass
|
self.dither()
|
||||||
|
self.play(LaggedStart(
|
||||||
|
ShowCreationThenDestruction, rects,
|
||||||
|
lag_ratio = 0.7,
|
||||||
|
run_time = 3
|
||||||
|
))
|
||||||
|
self.dither()
|
||||||
|
|
||||||
|
self.set_variables_as_attrs(chain_rule)
|
||||||
|
|
||||||
def show_derivative_wrt_prev_activation(self):
|
def show_derivative_wrt_prev_activation(self):
|
||||||
pass
|
chain_rule = self.get_chain_rule(
|
||||||
|
"{\\partial C_0", "\\over", "\\partial a^{(L-1)}_k}",
|
||||||
|
"=",
|
||||||
|
"\\sum_{j=0}^{n_L - 1}",
|
||||||
|
"{\\partial z^{(L)}_j", "\\over", "\\partial a^{(L-1)}_k}",
|
||||||
|
"{\\partial a^{(L)}_j", "\\over", "\\partial z^{(L)}_j}",
|
||||||
|
"{\\partial C_0", "\\over", "\\partial a^{(L)}_j}",
|
||||||
|
)
|
||||||
|
formulas = VGroup(self.z_formula, self.a_formula, self.cost_equation)
|
||||||
|
|
||||||
|
n = chain_rule.index_of_part_by_tex("sum")
|
||||||
|
self.play(ReplacementTransform(
|
||||||
|
self.chain_rule, VGroup(*chain_rule[:n] + chain_rule[n+1:])
|
||||||
|
))
|
||||||
|
self.play(Write(chain_rule[n], run_time = 1))
|
||||||
|
self.dither()
|
||||||
|
|
||||||
|
self.set_variables_as_attrs(chain_rule)
|
||||||
|
|
||||||
|
def show_multiple_paths_from_prev_layer_neuron(self):
|
||||||
|
neurons = self.network_mob.layers[-1].neurons
|
||||||
|
labels, arrows, decimals = [
|
||||||
|
VGroup(*[getattr(n, attr) for n in neurons])
|
||||||
|
for attr in "label", "arrow", "decimal"
|
||||||
|
]
|
||||||
|
edges = VGroup(*[n.edges_in[1] for n in neurons])
|
||||||
|
labels[0].generate_target()
|
||||||
|
self.replace_subscript(labels[0].target, "0")
|
||||||
|
|
||||||
|
paths = [
|
||||||
|
VGroup(
|
||||||
|
self.chosen_neurons[0].label,
|
||||||
|
self.chosen_neurons[0].arrow,
|
||||||
|
self.chosen_neurons[0],
|
||||||
|
self.chosen_neurons[0].decimal,
|
||||||
|
edges[i],
|
||||||
|
neurons[i],
|
||||||
|
decimals[i],
|
||||||
|
arrows[i],
|
||||||
|
labels[i],
|
||||||
|
)
|
||||||
|
for i in range(2)
|
||||||
|
]
|
||||||
|
path_lines = VGroup()
|
||||||
|
for path in paths:
|
||||||
|
points = [path[0].get_center()]
|
||||||
|
for mob in path[1:]:
|
||||||
|
if isinstance(mob, DecimalNumber):
|
||||||
|
continue
|
||||||
|
points.append(mob.get_center())
|
||||||
|
path_line = VMobject()
|
||||||
|
path_line.set_points_as_corners(points)
|
||||||
|
path_lines.add(path_line)
|
||||||
|
path_lines.highlight(YELLOW)
|
||||||
|
|
||||||
|
chain_rule = self.chain_rule
|
||||||
|
n = chain_rule.index_of_part_by_tex("sum")
|
||||||
|
brace = Brace(VGroup(*chain_rule[n:]), DOWN, buff = SMALL_BUFF)
|
||||||
|
words = brace.get_text("Sum over layer L", buff = SMALL_BUFF)
|
||||||
|
|
||||||
|
cost_aL = self.cost_equation.get_part_by_tex("a^{(L)}")
|
||||||
|
|
||||||
|
self.play(
|
||||||
|
MoveToTarget(labels[0]),
|
||||||
|
FadeIn(labels[1]),
|
||||||
|
GrowArrow(arrows[1]),
|
||||||
|
edges[1].restore,
|
||||||
|
FadeOut(self.w_label),
|
||||||
|
)
|
||||||
|
for x in range(5):
|
||||||
|
anims = [
|
||||||
|
ShowCreationThenDestruction(
|
||||||
|
path_line,
|
||||||
|
run_time = 1.5,
|
||||||
|
time_width = 0.5,
|
||||||
|
)
|
||||||
|
for path_line in path_lines
|
||||||
|
]
|
||||||
|
if x == 2:
|
||||||
|
anims += [
|
||||||
|
FadeIn(words),
|
||||||
|
GrowFromCenter(brace)
|
||||||
|
]
|
||||||
|
self.play(*anims)
|
||||||
|
self.dither()
|
||||||
|
for path, path_line in zip(paths, path_lines):
|
||||||
|
label = path[-1]
|
||||||
|
self.play(
|
||||||
|
LaggedStart(
|
||||||
|
Indicate, path,
|
||||||
|
rate_func = wiggle,
|
||||||
|
run_time = 1,
|
||||||
|
),
|
||||||
|
ShowCreation(path_line),
|
||||||
|
Animation(label)
|
||||||
|
)
|
||||||
|
self.dither()
|
||||||
|
group = VGroup(label, cost_aL)
|
||||||
|
self.play(
|
||||||
|
group.shift, MED_SMALL_BUFF*UP,
|
||||||
|
rate_func = wiggle
|
||||||
|
)
|
||||||
|
self.play(FadeOut(path_line))
|
||||||
|
self.dither()
|
||||||
|
|
||||||
|
def show_previous_layer(self):
|
||||||
|
layer = self.network_mob.layers[0]
|
||||||
|
edges = self.network_mob.edge_groups[0]
|
||||||
|
faded_edges = self.faded_edges
|
||||||
|
to_fade = VGroup(
|
||||||
|
self.chosen_neurons[0].label,
|
||||||
|
self.chosen_neurons[0].arrow,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.play(faded_edges.restore)
|
||||||
|
self.play(
|
||||||
|
LaggedStart(
|
||||||
|
GrowFromCenter, layer.neurons,
|
||||||
|
run_time = 1
|
||||||
|
),
|
||||||
|
FadeOut(to_fade)
|
||||||
|
)
|
||||||
|
self.play(LaggedStart(ShowCreation, edges))
|
||||||
|
self.dither()
|
||||||
|
|
||||||
####
|
####
|
||||||
|
|
||||||
def replace_subscript(self, label, tex):
|
def replace_subscript(self, label, tex):
|
||||||
subscript = label[-1]
|
subscript = label[-1]
|
||||||
new_subscript = TexMobject(tex)
|
new_subscript = TexMobject(tex)[0]
|
||||||
new_subscript.replace(subscript)
|
new_subscript.replace(subscript, dim_to_match = 1)
|
||||||
label.remove(subscript)
|
label.remove(subscript)
|
||||||
label.add(new_subscript)
|
label.add(new_subscript)
|
||||||
return label
|
return label
|
||||||
|
|
||||||
|
def get_chain_rule(self, *tex):
|
||||||
|
chain_rule = TexMobject(*tex)
|
||||||
|
chain_rule.scale(0.8)
|
||||||
|
chain_rule.to_corner(UP+LEFT)
|
||||||
|
chain_rule.highlight_by_tex_to_color_map({
|
||||||
|
"C_0" : self.cost_color,
|
||||||
|
"z^" : self.z_color,
|
||||||
|
"w^" : self.w_label.get_color()
|
||||||
|
})
|
||||||
|
return chain_rule
|
||||||
|
|
||||||
|
class PatYourselfOnTheBack(TeacherStudentsScene):
|
||||||
|
def construct(self):
|
||||||
|
self.teacher_says(
|
||||||
|
"Pat yourself on \\\\ the back!",
|
||||||
|
target_mode = "hooray"
|
||||||
|
)
|
||||||
|
self.change_student_modes(*["hooray"]*3)
|
||||||
|
self.dither(3)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue