mirror of
https://github.com/3b1b/manim.git
synced 2025-08-29 20:42:02 +00:00
Finished GeneralFormulas of nn/part3
This commit is contained in:
parent
cfa9aba0ce
commit
109bff8a27
1 changed files with 322 additions and 19 deletions
341
nn/part3.py
341
nn/part3.py
|
@ -3413,16 +3413,15 @@ class GeneralFormulas(SimplestNetworkExample):
|
|||
"neuron_to_neuron_buff" : LARGE_BUFF,
|
||||
"neuron_radius" : 0.3,
|
||||
},
|
||||
"stroke_width_exp" : 0.5,
|
||||
"random_seed" : 1,
|
||||
"edge_stroke_width" : 4,
|
||||
"stroke_width_exp" : 0.2,
|
||||
"random_seed" : 9,
|
||||
}
|
||||
def setup(self):
|
||||
self.seed_random_libraries()
|
||||
self.setup_bases()
|
||||
|
||||
def construct(self):
|
||||
self.force_skipping()
|
||||
|
||||
self.setup_network_mob()
|
||||
self.show_all_a_labels()
|
||||
self.only_show_abstract_a_labels()
|
||||
|
@ -3431,12 +3430,14 @@ class GeneralFormulas(SimplestNetworkExample):
|
|||
self.show_example_weight()
|
||||
self.show_values_between_weight_and_cost()
|
||||
self.show_weight_chain_rule()
|
||||
self.show_multiple_paths_from_prev_layer_neuron()
|
||||
self.show_derivative_wrt_prev_activation()
|
||||
self.show_multiple_paths_from_prev_layer_neuron()
|
||||
self.show_previous_layer()
|
||||
|
||||
def setup_network_mob(self):
|
||||
self.color_network_edges()
|
||||
self.network_mob.to_edge(LEFT)
|
||||
self.network_mob.shift(DOWN)
|
||||
in_vect = np.random.random(self.layer_sizes[0])
|
||||
self.network_mob.activate_layers(in_vect)
|
||||
self.remove(self.network_mob.layers[0])
|
||||
|
@ -3462,7 +3463,7 @@ class GeneralFormulas(SimplestNetworkExample):
|
|||
arrow.next_to(neuron, -vect)
|
||||
arrow.set_fill(WHITE)
|
||||
label = TexMobject("a^{(%s)}_%d"%(s, i))
|
||||
label.next_to(arrow, -vect)
|
||||
label.next_to(arrow, -vect, SMALL_BUFF)
|
||||
rect = SurroundingRectangle(label[-1], buff = 0.5*SMALL_BUFF)
|
||||
decimal = self.get_neuron_activation_decimal(neuron)
|
||||
neuron.arrow = arrow
|
||||
|
@ -3515,7 +3516,7 @@ class GeneralFormulas(SimplestNetworkExample):
|
|||
rects = VGroup()
|
||||
for x, layer in enumerate(self.network_mob.layers[-2:]):
|
||||
for y, neuron in enumerate(layer.neurons):
|
||||
if (x == 0 and y == 2) or (x == 1 and y == 0):
|
||||
if (x == 0 and y == 1) or (x == 1 and y == 0):
|
||||
tex = "k" if x == 0 else "j"
|
||||
neuron.label.generate_target()
|
||||
self.replace_subscript(neuron.label.target, tex)
|
||||
|
@ -3544,38 +3545,340 @@ class GeneralFormulas(SimplestNetworkExample):
|
|||
self.dither()
|
||||
|
||||
def add_desired_output(self):
|
||||
pass
|
||||
layer = self.network_mob.layers[-1]
|
||||
desired_output = layer.deepcopy()
|
||||
desired_output.shift(3*RIGHT)
|
||||
desired_output_decimals = VGroup()
|
||||
arrows = VGroup()
|
||||
labels = VGroup()
|
||||
for i, neuron in enumerate(desired_output.neurons):
|
||||
neuron.set_fill(opacity = i)
|
||||
decimal = self.get_neuron_activation_decimal(neuron)
|
||||
neuron.decimal = decimal
|
||||
neuron.arrow = Arrow(ORIGIN, LEFT, color = WHITE)
|
||||
neuron.arrow.next_to(neuron, RIGHT)
|
||||
neuron.label = TexMobject("y_%d"%i)
|
||||
neuron.label.next_to(neuron.arrow, RIGHT)
|
||||
neuron.label.highlight(self.desired_output_color)
|
||||
|
||||
desired_output_decimals.add(decimal)
|
||||
arrows.add(neuron.arrow)
|
||||
labels.add(neuron.label)
|
||||
rect = SurroundingRectangle(desired_output, buff = 0.5*SMALL_BUFF)
|
||||
words = TextMobject("Desired output")
|
||||
words.next_to(rect, DOWN)
|
||||
VGroup(words, rect).highlight(self.desired_output_color)
|
||||
|
||||
self.play(
|
||||
FadeIn(rect),
|
||||
FadeIn(words),
|
||||
ReplacementTransform(layer.copy(), desired_output),
|
||||
FadeIn(labels),
|
||||
*[
|
||||
ReplacementTransform(n1.decimal.copy(), n2.decimal)
|
||||
for n1, n2 in zip(layer.neurons, desired_output.neurons)
|
||||
] + map(GrowArrow, arrows)
|
||||
)
|
||||
self.dither()
|
||||
|
||||
self.set_variables_as_attrs(
|
||||
desired_output,
|
||||
desired_output_decimals,
|
||||
desired_output_rect = rect,
|
||||
desired_output_words = words,
|
||||
)
|
||||
|
||||
def show_cost(self):
|
||||
pass
|
||||
aj = self.chosen_neurons[1].label.copy()
|
||||
yj = self.desired_output.neurons[0].label.copy()
|
||||
|
||||
cost_equation = TexMobject(
|
||||
"C_0", "=", "\\sum_{j = 0}^{n_L - 1}",
|
||||
"(", "a^{(L)}_j", "-", "y_j", ")", "^2"
|
||||
)
|
||||
cost_equation.to_corner(UP+RIGHT)
|
||||
cost_equation[0].highlight(self.cost_color)
|
||||
aj.target = cost_equation.get_part_by_tex("a^{(L)}_j")
|
||||
yj.target = cost_equation.get_part_by_tex("y_j")
|
||||
yj.target.highlight(self.desired_output_color)
|
||||
to_fade_in = VGroup(*filter(
|
||||
lambda m : m not in [aj.target, yj.target],
|
||||
cost_equation
|
||||
))
|
||||
|
||||
self.play(*[
|
||||
ReplacementTransform(mob, mob.target)
|
||||
for mob in aj, yj
|
||||
])
|
||||
self.play(LaggedStart(FadeIn, to_fade_in))
|
||||
self.dither(2)
|
||||
|
||||
self.set_variables_as_attrs(cost_equation)
|
||||
|
||||
def show_example_weight(self):
|
||||
pass
|
||||
edges = self.network_mob.edge_groups[-1]
|
||||
edge = self.chosen_neurons[1].edges_in[1]
|
||||
faded_edges = VGroup(*filter(
|
||||
lambda e : e is not edge,
|
||||
edges
|
||||
))
|
||||
faded_edges.save_state()
|
||||
for faded_edge in faded_edges:
|
||||
faded_edge.save_state()
|
||||
|
||||
w_label = TexMobject("w^{(L)}_{jk}")
|
||||
w_label.scale(1.2)
|
||||
w_label.add_background_rectangle()
|
||||
w_label.next_to(ORIGIN, UP, SMALL_BUFF)
|
||||
w_label.rotate(edge.get_angle())
|
||||
w_label.shift(edge.get_center())
|
||||
w_label.highlight(BLUE)
|
||||
|
||||
self.play(faded_edges.fade, 0.9)
|
||||
self.play(Write(w_label))
|
||||
self.dither()
|
||||
|
||||
self.set_variables_as_attrs(faded_edges, w_label)
|
||||
|
||||
def show_values_between_weight_and_cost(self):
|
||||
pass
|
||||
z_formula = TexMobject(
|
||||
"z^{(L)}_j", "=", "\\cdots", "+"
|
||||
"w^{(L)}_{jk}", "a^{(L-1)}_k", "+", "\\cdots"
|
||||
)
|
||||
z_formula.to_corner(UP+RIGHT)
|
||||
z_formula.highlight_by_tex_to_color_map({
|
||||
"z^" : self.z_color,
|
||||
"w^" : self.w_label.get_color()
|
||||
})
|
||||
w_part = z_formula.get_part_by_tex("w^")
|
||||
aLm1_part = z_formula.get_part_by_tex("a^{(L-1)}")
|
||||
|
||||
a_formula = TexMobject(
|
||||
"a^{(L)}_j", "=", "\\sigma(", "z^{(L)}_j", ")"
|
||||
)
|
||||
a_formula.highlight_by_tex("z^", self.z_color)
|
||||
a_formula.next_to(z_formula, DOWN, MED_LARGE_BUFF)
|
||||
aL_part = a_formula[0]
|
||||
|
||||
to_fade = VGroup(
|
||||
self.desired_output,
|
||||
self.desired_output_decimals,
|
||||
self.desired_output_rect,
|
||||
self.desired_output_words,
|
||||
*[
|
||||
VGroup(n.arrow, n.label)
|
||||
for n in self.desired_output.neurons
|
||||
]
|
||||
)
|
||||
|
||||
self.play(
|
||||
FadeOut(to_fade),
|
||||
self.cost_equation.next_to, a_formula, DOWN, MED_LARGE_BUFF,
|
||||
self.cost_equation.to_edge, RIGHT,
|
||||
ReplacementTransform(self.w_label[1].copy(), w_part),
|
||||
ReplacementTransform(
|
||||
self.chosen_neurons[0].label.copy(),
|
||||
aLm1_part
|
||||
),
|
||||
)
|
||||
self.play(Write(VGroup(*filter(
|
||||
lambda m : m not in [w_part, aLm1_part],
|
||||
z_formula
|
||||
))))
|
||||
self.dither()
|
||||
self.play(ReplacementTransform(
|
||||
self.chosen_neurons[1].label.copy(),
|
||||
aL_part
|
||||
))
|
||||
self.play(Write(VGroup(*a_formula[1:])))
|
||||
self.dither()
|
||||
|
||||
self.set_variables_as_attrs(z_formula, a_formula)
|
||||
|
||||
def show_weight_chain_rule(self):
|
||||
pass
|
||||
chain_rule = self.get_chain_rule(
|
||||
"{\\partial C_0", "\\over", "\\partial w^{(L)}_{jk}}",
|
||||
"=",
|
||||
"{\\partial z^{(L)}_j", "\\over", "\\partial w^{(L)}_{jk}}",
|
||||
"{\\partial a^{(L)}_j", "\\over", "\\partial z^{(L)}_j}",
|
||||
"{\\partial C_0", "\\over", "\\partial a^{(L)}_j}",
|
||||
)
|
||||
terms = VGroup(*[
|
||||
VGroup(*chain_rule[i:i+3])
|
||||
for i in range(4,len(chain_rule), 3)
|
||||
])
|
||||
rects = VGroup(*[
|
||||
SurroundingRectangle(term, buff = 0.5*SMALL_BUFF)
|
||||
for term in terms
|
||||
])
|
||||
rects.gradient_highlight(GREEN, WHITE, RED)
|
||||
|
||||
def show_multiple_paths_from_prev_layer_neuron(self):
|
||||
pass
|
||||
self.play(Write(chain_rule))
|
||||
self.dither()
|
||||
self.play(LaggedStart(
|
||||
ShowCreationThenDestruction, rects,
|
||||
lag_ratio = 0.7,
|
||||
run_time = 3
|
||||
))
|
||||
self.dither()
|
||||
|
||||
self.set_variables_as_attrs(chain_rule)
|
||||
|
||||
def show_derivative_wrt_prev_activation(self):
|
||||
pass
|
||||
chain_rule = self.get_chain_rule(
|
||||
"{\\partial C_0", "\\over", "\\partial a^{(L-1)}_k}",
|
||||
"=",
|
||||
"\\sum_{j=0}^{n_L - 1}",
|
||||
"{\\partial z^{(L)}_j", "\\over", "\\partial a^{(L-1)}_k}",
|
||||
"{\\partial a^{(L)}_j", "\\over", "\\partial z^{(L)}_j}",
|
||||
"{\\partial C_0", "\\over", "\\partial a^{(L)}_j}",
|
||||
)
|
||||
formulas = VGroup(self.z_formula, self.a_formula, self.cost_equation)
|
||||
|
||||
n = chain_rule.index_of_part_by_tex("sum")
|
||||
self.play(ReplacementTransform(
|
||||
self.chain_rule, VGroup(*chain_rule[:n] + chain_rule[n+1:])
|
||||
))
|
||||
self.play(Write(chain_rule[n], run_time = 1))
|
||||
self.dither()
|
||||
|
||||
self.set_variables_as_attrs(chain_rule)
|
||||
|
||||
def show_multiple_paths_from_prev_layer_neuron(self):
|
||||
neurons = self.network_mob.layers[-1].neurons
|
||||
labels, arrows, decimals = [
|
||||
VGroup(*[getattr(n, attr) for n in neurons])
|
||||
for attr in "label", "arrow", "decimal"
|
||||
]
|
||||
edges = VGroup(*[n.edges_in[1] for n in neurons])
|
||||
labels[0].generate_target()
|
||||
self.replace_subscript(labels[0].target, "0")
|
||||
|
||||
paths = [
|
||||
VGroup(
|
||||
self.chosen_neurons[0].label,
|
||||
self.chosen_neurons[0].arrow,
|
||||
self.chosen_neurons[0],
|
||||
self.chosen_neurons[0].decimal,
|
||||
edges[i],
|
||||
neurons[i],
|
||||
decimals[i],
|
||||
arrows[i],
|
||||
labels[i],
|
||||
)
|
||||
for i in range(2)
|
||||
]
|
||||
path_lines = VGroup()
|
||||
for path in paths:
|
||||
points = [path[0].get_center()]
|
||||
for mob in path[1:]:
|
||||
if isinstance(mob, DecimalNumber):
|
||||
continue
|
||||
points.append(mob.get_center())
|
||||
path_line = VMobject()
|
||||
path_line.set_points_as_corners(points)
|
||||
path_lines.add(path_line)
|
||||
path_lines.highlight(YELLOW)
|
||||
|
||||
chain_rule = self.chain_rule
|
||||
n = chain_rule.index_of_part_by_tex("sum")
|
||||
brace = Brace(VGroup(*chain_rule[n:]), DOWN, buff = SMALL_BUFF)
|
||||
words = brace.get_text("Sum over layer L", buff = SMALL_BUFF)
|
||||
|
||||
cost_aL = self.cost_equation.get_part_by_tex("a^{(L)}")
|
||||
|
||||
self.play(
|
||||
MoveToTarget(labels[0]),
|
||||
FadeIn(labels[1]),
|
||||
GrowArrow(arrows[1]),
|
||||
edges[1].restore,
|
||||
FadeOut(self.w_label),
|
||||
)
|
||||
for x in range(5):
|
||||
anims = [
|
||||
ShowCreationThenDestruction(
|
||||
path_line,
|
||||
run_time = 1.5,
|
||||
time_width = 0.5,
|
||||
)
|
||||
for path_line in path_lines
|
||||
]
|
||||
if x == 2:
|
||||
anims += [
|
||||
FadeIn(words),
|
||||
GrowFromCenter(brace)
|
||||
]
|
||||
self.play(*anims)
|
||||
self.dither()
|
||||
for path, path_line in zip(paths, path_lines):
|
||||
label = path[-1]
|
||||
self.play(
|
||||
LaggedStart(
|
||||
Indicate, path,
|
||||
rate_func = wiggle,
|
||||
run_time = 1,
|
||||
),
|
||||
ShowCreation(path_line),
|
||||
Animation(label)
|
||||
)
|
||||
self.dither()
|
||||
group = VGroup(label, cost_aL)
|
||||
self.play(
|
||||
group.shift, MED_SMALL_BUFF*UP,
|
||||
rate_func = wiggle
|
||||
)
|
||||
self.play(FadeOut(path_line))
|
||||
self.dither()
|
||||
|
||||
def show_previous_layer(self):
|
||||
layer = self.network_mob.layers[0]
|
||||
edges = self.network_mob.edge_groups[0]
|
||||
faded_edges = self.faded_edges
|
||||
to_fade = VGroup(
|
||||
self.chosen_neurons[0].label,
|
||||
self.chosen_neurons[0].arrow,
|
||||
)
|
||||
|
||||
self.play(faded_edges.restore)
|
||||
self.play(
|
||||
LaggedStart(
|
||||
GrowFromCenter, layer.neurons,
|
||||
run_time = 1
|
||||
),
|
||||
FadeOut(to_fade)
|
||||
)
|
||||
self.play(LaggedStart(ShowCreation, edges))
|
||||
self.dither()
|
||||
|
||||
####
|
||||
|
||||
def replace_subscript(self, label, tex):
|
||||
subscript = label[-1]
|
||||
new_subscript = TexMobject(tex)
|
||||
new_subscript.replace(subscript)
|
||||
new_subscript = TexMobject(tex)[0]
|
||||
new_subscript.replace(subscript, dim_to_match = 1)
|
||||
label.remove(subscript)
|
||||
label.add(new_subscript)
|
||||
return label
|
||||
|
||||
|
||||
|
||||
def get_chain_rule(self, *tex):
|
||||
chain_rule = TexMobject(*tex)
|
||||
chain_rule.scale(0.8)
|
||||
chain_rule.to_corner(UP+LEFT)
|
||||
chain_rule.highlight_by_tex_to_color_map({
|
||||
"C_0" : self.cost_color,
|
||||
"z^" : self.z_color,
|
||||
"w^" : self.w_label.get_color()
|
||||
})
|
||||
return chain_rule
|
||||
|
||||
class PatYourselfOnTheBack(TeacherStudentsScene):
|
||||
def construct(self):
|
||||
self.teacher_says(
|
||||
"Pat yourself on \\\\ the back!",
|
||||
target_mode = "hooray"
|
||||
)
|
||||
self.change_student_modes(*["hooray"]*3)
|
||||
self.dither(3)
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue