mirror of
https://github.com/3b1b/manim.git
synced 2025-09-01 00:48:45 +00:00
Up to derivatives in self.revert_to_original_skipping_status()
This commit is contained in:
parent
589fd5e028
commit
34b4ea089c
1 changed files with 283 additions and 21 deletions
304
nn/part3.py
304
nn/part3.py
|
@ -2002,6 +2002,7 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
"z_color" : GREEN,
|
||||
"cost_color" : RED,
|
||||
"desired_output_color" : YELLOW,
|
||||
"derivative_scale_vale" : 0.7,
|
||||
}
|
||||
def construct(self):
|
||||
self.force_skipping()
|
||||
|
@ -2101,10 +2102,10 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
self.dither()
|
||||
|
||||
def label_neurons(self):
|
||||
neurons = [
|
||||
neurons = VGroup(*[
|
||||
self.network_mob.layers[i].neurons[0]
|
||||
for i in -1, -2
|
||||
]
|
||||
])
|
||||
decimals = VGroup()
|
||||
a_labels = VGroup()
|
||||
a_label_arrows = VGroup()
|
||||
|
@ -2167,7 +2168,8 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
self.play(*map(FadeOut, [not_exponents, superscript_rects]))
|
||||
|
||||
self.set_variables_as_attrs(
|
||||
a_labels, a_label_arrows, decimals
|
||||
a_labels, a_label_arrows, decimals,
|
||||
last_neurons = neurons
|
||||
)
|
||||
|
||||
def show_desired_output(self):
|
||||
|
@ -2290,41 +2292,301 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
self.play(sigma_group.restore)
|
||||
self.dither()
|
||||
|
||||
weighted_sum_terms = VGroup(wL, aLm1, bL)
|
||||
weighted_sum_terms = VGroup(wL, aLm1, plus, bL)
|
||||
self.set_variables_as_attrs(
|
||||
formula, weighted_sum_terms
|
||||
)
|
||||
|
||||
def introduce_z(self):
|
||||
terms = self.weighted_sum_terms
|
||||
brace = Brace(terms, UP, buff = SMALL_BUFF)
|
||||
terms.generate_target()
|
||||
terms.target.next_to(self.formula, UP, aligned_edge = RIGHT)
|
||||
equals = TexMobject("=")
|
||||
equals.next_to(terms.target[0][0], LEFT)
|
||||
|
||||
z_label = TexMobject("z^{(L)}")
|
||||
z_label.next_to(brace, UP, buff = SMALL_BUFF)
|
||||
z_label.next_to(equals, LEFT)
|
||||
z_label.align_to(terms.target, DOWN)
|
||||
z_label.highlight(self.z_color)
|
||||
rect = SurroundingRectangle(terms)
|
||||
rect.highlight(GREEN)
|
||||
z_label2 = z_label.copy()
|
||||
|
||||
aL_start = VGroup(*self.formula[:4])
|
||||
aL_start.generate_target()
|
||||
aL_start.target.align_to(z_label, LEFT)
|
||||
z_label2.next_to(aL_start.target, RIGHT, SMALL_BUFF)
|
||||
z_label2.align_to(aL_start.target[0], DOWN)
|
||||
rp = self.formula[-1]
|
||||
rp.generate_target()
|
||||
rp.target.next_to(z_label2, RIGHT, SMALL_BUFF)
|
||||
rp.target.align_to(aL_start.target, DOWN)
|
||||
|
||||
self.play(MoveToTarget(terms))
|
||||
self.play(Write(z_label), Write(equals))
|
||||
self.play(
|
||||
ReplacementTransform(z_label.copy(), z_label2),
|
||||
MoveToTarget(aL_start),
|
||||
MoveToTarget(rp),
|
||||
)
|
||||
self.dither()
|
||||
|
||||
zL_formula = VGroup(z_label, equals, terms)
|
||||
aL_formula = VGroup(aL_start, z_label2, rp)
|
||||
self.set_variables_as_attrs(z_label, zL_formula, aL_formula)
|
||||
|
||||
def break_into_computational_graph(self):
|
||||
network_early_layers = VGroup(*it.chain(
|
||||
self.network_mob.layers[:2],
|
||||
self.network_mob.edge_groups[:2]
|
||||
))
|
||||
|
||||
wL, aL, plus, bL = self.weighted_sum_terms
|
||||
top_terms = VGroup(wL, aL, bL).copy()
|
||||
zL = self.z_label.copy()
|
||||
aL = self.formula[0].copy()
|
||||
y = self.y_label.copy()
|
||||
C0 = self.cost_equation[0].copy()
|
||||
targets = VGroup()
|
||||
for mob in top_terms, zL, aL, C0:
|
||||
mob.generate_target()
|
||||
targets.add(mob.target)
|
||||
y.generate_target()
|
||||
top_terms.target.arrange_submobjects(RIGHT, buff = MED_LARGE_BUFF)
|
||||
targets.arrange_submobjects(DOWN, buff = LARGE_BUFF)
|
||||
targets.center().to_corner(DOWN+LEFT)
|
||||
y.target.next_to(aL.target, LEFT, LARGE_BUFF, DOWN)
|
||||
|
||||
top_lines = VGroup(*[
|
||||
Line(
|
||||
term.get_bottom(),
|
||||
zL.target.get_top(),
|
||||
buff = SMALL_BUFF
|
||||
)
|
||||
for term in top_terms.target
|
||||
])
|
||||
z_to_a_line, a_to_c_line, y_to_c_line = all_lines = [
|
||||
Line(
|
||||
m1.target.get_bottom(),
|
||||
m2.target.get_top(),
|
||||
buff = SMALL_BUFF
|
||||
)
|
||||
for m1, m2 in [
|
||||
(zL, aL),
|
||||
(aL, C0),
|
||||
(y, C0)
|
||||
]
|
||||
]
|
||||
for mob in [top_lines] + all_lines:
|
||||
yellow_copy = mob.copy().highlight(YELLOW)
|
||||
mob.flash = ShowCreationThenDestruction(yellow_copy)
|
||||
|
||||
self.play(MoveToTarget(top_terms))
|
||||
self.dither()
|
||||
self.play(MoveToTarget(zL))
|
||||
self.play(
|
||||
ShowCreation(top_lines, submobject_mode = "all_at_once"),
|
||||
top_lines.flash
|
||||
)
|
||||
self.dither()
|
||||
self.play(MoveToTarget(aL))
|
||||
self.play(
|
||||
FadeOut(network_early_layers),
|
||||
ShowCreation(z_to_a_line),
|
||||
z_to_a_line.flash
|
||||
)
|
||||
self.dither()
|
||||
self.play(MoveToTarget(y))
|
||||
self.play(MoveToTarget(C0))
|
||||
self.play(*it.chain(*[
|
||||
[ShowCreation(line), line.flash]
|
||||
for line in a_to_c_line, y_to_c_line
|
||||
]))
|
||||
self.dither(2)
|
||||
|
||||
comp_graph = VGroup()
|
||||
comp_graph.wL, comp_graph.aLm1, comp_graph.bL = top_terms
|
||||
comp_graph.top_lines = top_lines
|
||||
comp_graph.zL = zL
|
||||
comp_graph.z_to_a_line = z_to_a_line
|
||||
comp_graph.aL = aL
|
||||
comp_graph.y = y
|
||||
comp_graph.a_to_c_line = a_to_c_line
|
||||
comp_graph.y_to_c_line = y_to_c_line
|
||||
comp_graph.C0 = C0
|
||||
comp_graph.digest_mobject_attrs()
|
||||
self.comp_graph = comp_graph
|
||||
|
||||
def show_preceding_layer_in_computational_graph(self):
|
||||
shift_vect = DOWN
|
||||
comp_graph = self.comp_graph
|
||||
comp_graph.save_state()
|
||||
comp_graph.generate_target()
|
||||
comp_graph.target.shift(shift_vect)
|
||||
rect = SurroundingRectangle(comp_graph.aLm1)
|
||||
|
||||
attrs = ["wL", "aLm1", "bL", "zL"]
|
||||
new_terms = VGroup()
|
||||
for attr in attrs:
|
||||
term = getattr(comp_graph, attr)
|
||||
tex = term.get_tex_string()
|
||||
if "L-1" in tex:
|
||||
tex = tex.replace("L-1", "L-2")
|
||||
else:
|
||||
tex = tex.replace("L", "L-1")
|
||||
new_term = TexMobject(tex)
|
||||
new_term.highlight(term.get_color())
|
||||
new_term.move_to(term)
|
||||
new_terms.add(new_term)
|
||||
new_edges = VGroup(
|
||||
comp_graph.top_lines.copy(),
|
||||
comp_graph.z_to_a_line.copy(),
|
||||
)
|
||||
new_subgraph = VGroup(new_terms, new_edges)
|
||||
|
||||
self.play(ShowCreation(rect))
|
||||
self.play(
|
||||
GrowFromCenter(brace),
|
||||
Write(z_label),
|
||||
new_subgraph.next_to, comp_graph.target, UP, SMALL_BUFF,
|
||||
UpdateFromAlphaFunc(
|
||||
new_terms,
|
||||
lambda m, a : m.set_fill(opacity = a)
|
||||
),
|
||||
MoveToTarget(comp_graph),
|
||||
rect.shift, shift_vect
|
||||
)
|
||||
self.play(FadeOut(rect))
|
||||
self.dither(2)
|
||||
self.play(
|
||||
FadeOut(new_subgraph),
|
||||
comp_graph.restore,
|
||||
rect.shift, -shift_vect,
|
||||
rect.set_stroke, BLACK, 0
|
||||
)
|
||||
self.remove(rect)
|
||||
self.dither()
|
||||
|
||||
self.set_variables_as_attrs(z_label, z_brace = brace)
|
||||
|
||||
def break_into_computational_graph(self):
|
||||
pass
|
||||
|
||||
def show_preceding_layer_in_computational_graph(self):
|
||||
pass
|
||||
|
||||
def show_number_lines(self):
|
||||
pass
|
||||
comp_graph = self.comp_graph
|
||||
wL, aLm1, bL, zL, aL, C0 = [
|
||||
getattr(comp_graph, attr)
|
||||
for attr in ["wL", "aLm1", "bL", "zL", "aL", "C0"]
|
||||
]
|
||||
wL.val = self.network.weights[-1][0][0]
|
||||
aL.val = self.decimals[0].number
|
||||
zL.val = sigmoid_inverse(aL.val)
|
||||
C0.val = (aL.val - 1)**2
|
||||
|
||||
number_line = UnitInterval(
|
||||
unit_size = 2,
|
||||
stroke_width = 2,
|
||||
tick_size = 0.075,
|
||||
color = LIGHT_GREY,
|
||||
)
|
||||
|
||||
for mob in wL, zL, aL, C0:
|
||||
mob.number_line = number_line.deepcopy()
|
||||
if mob is wL:
|
||||
mob.number_line.next_to(mob, UP, MED_LARGE_BUFF, LEFT)
|
||||
else:
|
||||
mob.number_line.next_to(mob, RIGHT)
|
||||
mob.dot = Dot(color = mob.get_color())
|
||||
mob.dot.move_to(
|
||||
mob.number_line.number_to_point(mob.val)
|
||||
)
|
||||
if mob is wL:
|
||||
path_arc = 0
|
||||
dot_spot = mob.dot.get_bottom()
|
||||
else:
|
||||
path_arc = -0.8*np.pi
|
||||
dot_spot = mob.dot.get_top()
|
||||
if mob is C0:
|
||||
mob_spot = mob[0].get_corner(UP+RIGHT)
|
||||
tip_length = 0.15
|
||||
else:
|
||||
mob_spot = mob.get_corner(UP+RIGHT)
|
||||
tip_length = 0.2
|
||||
mob.arrow = Arrow(
|
||||
mob_spot, dot_spot,
|
||||
use_rectangular_stem = False,
|
||||
path_arc = path_arc,
|
||||
tip_length = tip_length,
|
||||
buff = SMALL_BUFF,
|
||||
)
|
||||
mob.arrow.highlight(mob.get_color())
|
||||
mob.arrow.set_stroke(width = 5)
|
||||
|
||||
self.play(ShowCreation(
|
||||
mob.number_line,
|
||||
submobject_mode = "lagged_start"
|
||||
))
|
||||
self.play(
|
||||
ShowCreation(mob.arrow),
|
||||
ReplacementTransform(
|
||||
mob.copy(), mob.dot,
|
||||
path_arc = path_arc
|
||||
)
|
||||
)
|
||||
self.dither()
|
||||
|
||||
def ask_about_w_sensitivity(self):
|
||||
pass
|
||||
wL, aLm1, bL, zL, aL, C0 = [
|
||||
getattr(self.comp_graph, attr)
|
||||
for attr in ["wL", "aLm1", "bL", "zL", "aL", "C0"]
|
||||
]
|
||||
aLm1_val = self.last_neurons[1].get_fill_opacity()
|
||||
bL_val = self.network.biases[-1][0]
|
||||
|
||||
get_wL_val = lambda : wL.number_line.point_to_number(
|
||||
wL.dot.get_center()
|
||||
)
|
||||
get_zL_val = lambda : get_wL_val()*aLm1_val+bL_val
|
||||
get_aL_val = lambda : sigmoid(get_zL_val())
|
||||
get_C0_val = lambda : (get_aL_val() - 1)**2
|
||||
|
||||
def generate_dot_update(term, val_func):
|
||||
def update_dot(dot):
|
||||
dot.move_to(term.number_line.number_to_point(val_func()))
|
||||
return dot
|
||||
return update_dot
|
||||
|
||||
dot_update_anims = [
|
||||
UpdateFromFunc(term.dot, generate_dot_update(term, val_func))
|
||||
for term, val_func in [
|
||||
(zL, get_zL_val),
|
||||
(aL, get_aL_val),
|
||||
(C0, get_C0_val),
|
||||
]
|
||||
]
|
||||
|
||||
wL_line = Line(wL.dot.get_center(), wL.dot.get_center()+LEFT)
|
||||
del_wL = TexMobject("\\partial w^{(L)}")
|
||||
del_wL.scale(self.derivative_scale_vale)
|
||||
del_wL.brace = Brace(wL_line, UP)
|
||||
del_wL.highlight(wL.get_color())
|
||||
del_wL.next_to(del_wL.brace, UP, SMALL_BUFF)
|
||||
|
||||
C0_line = Line(C0.dot.get_center(), C0.dot.get_center()+MED_SMALL_BUFF*RIGHT)
|
||||
del_C0 = TexMobject("\\partial C_0")
|
||||
del_C0.scale(self.derivative_scale_vale)
|
||||
del_C0.brace = Brace(C0_line, UP)
|
||||
del_C0.highlight(C0.get_color())
|
||||
del_C0.next_to(del_C0.brace, UP, SMALL_BUFF)
|
||||
|
||||
for sym in del_wL, del_C0:
|
||||
self.play(
|
||||
GrowFromCenter(sym.brace),
|
||||
Write(sym, run_time = 1)
|
||||
)
|
||||
self.play(
|
||||
ApplyMethod(
|
||||
wL.dot.shift, LEFT,
|
||||
run_time = 2,
|
||||
rate_func = there_and_back
|
||||
),
|
||||
*dot_update_anims
|
||||
)
|
||||
self.dither()
|
||||
|
||||
self.set_variables_as_attrs(
|
||||
dot_update_anims, del_wL, del_C0,
|
||||
)
|
||||
|
||||
def show_derivative_wrt_w(self):
|
||||
pass
|
||||
|
|
Loading…
Add table
Reference in a new issue