Up to derivatives in self.revert_to_original_skipping_status()

This commit is contained in:
Grant Sanderson 2017-10-27 12:59:54 -07:00
parent 589fd5e028
commit 34b4ea089c

View file

@ -2002,6 +2002,7 @@ class SimplestNetworkExample(PreviewLearning):
"z_color" : GREEN,
"cost_color" : RED,
"desired_output_color" : YELLOW,
"derivative_scale_vale" : 0.7,
}
def construct(self):
self.force_skipping()
@ -2101,10 +2102,10 @@ class SimplestNetworkExample(PreviewLearning):
self.dither()
def label_neurons(self):
neurons = [
neurons = VGroup(*[
self.network_mob.layers[i].neurons[0]
for i in -1, -2
]
])
decimals = VGroup()
a_labels = VGroup()
a_label_arrows = VGroup()
@ -2167,7 +2168,8 @@ class SimplestNetworkExample(PreviewLearning):
self.play(*map(FadeOut, [not_exponents, superscript_rects]))
self.set_variables_as_attrs(
a_labels, a_label_arrows, decimals
a_labels, a_label_arrows, decimals,
last_neurons = neurons
)
def show_desired_output(self):
@ -2290,41 +2292,301 @@ class SimplestNetworkExample(PreviewLearning):
self.play(sigma_group.restore)
self.dither()
weighted_sum_terms = VGroup(wL, aLm1, bL)
weighted_sum_terms = VGroup(wL, aLm1, plus, bL)
self.set_variables_as_attrs(
formula, weighted_sum_terms
)
def introduce_z(self):
# Animates the introduction of the z^{(L)} label: the weighted-sum terms are
# moved above self.formula, "z^{(L)} =" is written to their left, and a copy
# of the z label is transformed into place between the first four pieces of
# self.formula and its last piece.
#
# NOTE(review): this listing comes from a diff view with indentation and
# +/- markers stripped, so lines from the previous and the new version of
# this method may be interleaved here - confirm against the repository.
terms = self.weighted_sum_terms
# NOTE(review): `brace` is only used by the first z_label.next_to call below
# and is never passed to self.play in this method - possibly a leftover line.
brace = Brace(terms, UP, buff = SMALL_BUFF)
terms.generate_target()
terms.target.next_to(self.formula, UP, aligned_edge = RIGHT)
equals = TexMobject("=")
# Position "=" relative to the target (post-move) location of the terms,
# since the terms themselves have not moved yet.
equals.next_to(terms.target[0][0], LEFT)
z_label = TexMobject("z^{(L)}")
# NOTE(review): the next two calls both position z_label; the second one
# (relative to `equals`) is applied last.
z_label.next_to(brace, UP, buff = SMALL_BUFF)
z_label.next_to(equals, LEFT)
z_label.align_to(terms.target, DOWN)
z_label.highlight(self.z_color)
# NOTE(review): `rect` is created and colored but not used again in this
# method as listed.
rect = SurroundingRectangle(terms)
rect.highlight(GREEN)
z_label2 = z_label.copy()
# First four submobjects of the formula; they get a target whose left edge
# is aligned with z_label.
aL_start = VGroup(*self.formula[:4])
aL_start.generate_target()
aL_start.target.align_to(z_label, LEFT)
z_label2.next_to(aL_start.target, RIGHT, SMALL_BUFF)
z_label2.align_to(aL_start.target[0], DOWN)
# Last submobject of the formula; its target sits just right of z_label2.
rp = self.formula[-1]
rp.generate_target()
rp.target.next_to(z_label2, RIGHT, SMALL_BUFF)
rp.target.align_to(aL_start.target, DOWN)
self.play(MoveToTarget(terms))
self.play(Write(z_label), Write(equals))
self.play(
ReplacementTransform(z_label.copy(), z_label2),
MoveToTarget(aL_start),
MoveToTarget(rp),
)
self.dither()
zL_formula = VGroup(z_label, equals, terms)
aL_formula = VGroup(aL_start, z_label2, rp)
# Presumably stores these locals as attributes on self (self.z_label is read
# by break_into_computational_graph) - helper is defined outside this view.
self.set_variables_as_attrs(z_label, zL_formula, aL_formula)
def break_into_computational_graph(self):
# Builds a small "computational graph" out of copies of existing on-screen
# pieces (three weighted-sum terms, the z label, the first formula piece, the
# y label and the first piece of the cost equation), stacks their targets
# vertically toward the DOWN+LEFT corner, draws connecting lines between
# them, animates the whole thing and stores the result as self.comp_graph.
#
# NOTE(review): indentation was stripped from this listing, so the extent of
# the two `for` loops below is inferred from the statements, not shown.
# First two layers and first two edge groups of the network mobject; faded
# out further down.
network_early_layers = VGroup(*it.chain(
self.network_mob.layers[:2],
self.network_mob.edge_groups[:2]
))
# `plus` is unpacked but not used in this method. The `aL` unpacked here is
# the second weighted-sum term; it is stored later as comp_graph.aLm1.
wL, aL, plus, bL = self.weighted_sum_terms
top_terms = VGroup(wL, aL, bL).copy()
zL = self.z_label.copy()
# `aL` is rebound here to a copy of the first formula piece; top_terms above
# was already built from the earlier binding.
aL = self.formula[0].copy()
y = self.y_label.copy()
C0 = self.cost_equation[0].copy()
targets = VGroup()
for mob in top_terms, zL, aL, C0:
mob.generate_target()
targets.add(mob.target)
# y gets a target too, but it is not added to `targets`; it is placed
# relative to aL's target instead.
y.generate_target()
top_terms.target.arrange_submobjects(RIGHT, buff = MED_LARGE_BUFF)
targets.arrange_submobjects(DOWN, buff = LARGE_BUFF)
targets.center().to_corner(DOWN+LEFT)
y.target.next_to(aL.target, LEFT, LARGE_BUFF, DOWN)
# One line from the bottom of each top term's target to the top of zL's target.
top_lines = VGroup(*[
Line(
term.get_bottom(),
zL.target.get_top(),
buff = SMALL_BUFF
)
for term in top_terms.target
])
# Chained assignment: the three lines are available both individually and
# as the list `all_lines`.
z_to_a_line, a_to_c_line, y_to_c_line = all_lines = [
Line(
m1.target.get_bottom(),
m2.target.get_top(),
buff = SMALL_BUFF
)
for m1, m2 in [
(zL, aL),
(aL, C0),
(y, C0)
]
]
# Attach to each line object (and to the top_lines group) a `flash`
# animation built from a yellow copy of itself.
for mob in [top_lines] + all_lines:
yellow_copy = mob.copy().highlight(YELLOW)
mob.flash = ShowCreationThenDestruction(yellow_copy)
self.play(MoveToTarget(top_terms))
self.dither()
self.play(MoveToTarget(zL))
self.play(
ShowCreation(top_lines, submobject_mode = "all_at_once"),
top_lines.flash
)
self.dither()
self.play(MoveToTarget(aL))
self.play(
FadeOut(network_early_layers),
ShowCreation(z_to_a_line),
z_to_a_line.flash
)
self.dither()
self.play(MoveToTarget(y))
self.play(MoveToTarget(C0))
# NOTE: the unparenthesized tuple after `in` inside a list comprehension is
# Python 2 syntax; Python 3 requires parentheses around it.
self.play(*it.chain(*[
[ShowCreation(line), line.flash]
for line in a_to_c_line, y_to_c_line
]))
self.dither(2)
# Collect every piece as a named attribute on an initially empty VGroup.
comp_graph = VGroup()
comp_graph.wL, comp_graph.aLm1, comp_graph.bL = top_terms
comp_graph.top_lines = top_lines
comp_graph.zL = zL
comp_graph.z_to_a_line = z_to_a_line
comp_graph.aL = aL
comp_graph.y = y
comp_graph.a_to_c_line = a_to_c_line
comp_graph.y_to_c_line = y_to_c_line
comp_graph.C0 = C0
# Presumably registers the mobject-valued attributes above as submobjects of
# comp_graph - helper is defined outside this view, confirm.
comp_graph.digest_mobject_attrs()
self.comp_graph = comp_graph
def show_preceding_layer_in_computational_graph(self):
# Temporarily extends self.comp_graph upward: the graph is shifted by DOWN,
# re-labelled copies of wL, aLm1, bL and zL (layer index lowered by one) plus
# copies of the upper edges are shown above it, then the extra pieces are
# faded out and the graph is restored to its saved state.
#
# NOTE(review): indentation and +/- diff markers were stripped from this
# listing. `brace` and `z_label` are referenced below but never bound in
# this method; those lines look like leftovers from the removed side of the
# diff and would raise NameError if executed as shown - confirm against
# the repository.
shift_vect = DOWN
comp_graph = self.comp_graph
comp_graph.save_state()
comp_graph.generate_target()
comp_graph.target.shift(shift_vect)
rect = SurroundingRectangle(comp_graph.aLm1)
attrs = ["wL", "aLm1", "bL", "zL"]
new_terms = VGroup()
for attr in attrs:
term = getattr(comp_graph, attr)
tex = term.get_tex_string()
# Lower the layer index by one: "L-1" becomes "L-2", otherwise "L"
# becomes "L-1". The "L-1" case must be checked first because a plain
# replace of "L" would also match inside "L-1".
if "L-1" in tex:
tex = tex.replace("L-1", "L-2")
else:
tex = tex.replace("L", "L-1")
new_term = TexMobject(tex)
new_term.highlight(term.get_color())
new_term.move_to(term)
new_terms.add(new_term)
new_edges = VGroup(
comp_graph.top_lines.copy(),
comp_graph.z_to_a_line.copy(),
)
new_subgraph = VGroup(new_terms, new_edges)
self.play(ShowCreation(rect))
self.play(
# NOTE(review): `brace` / `z_label` are undefined here (see header note).
GrowFromCenter(brace),
Write(z_label),
# A bound method followed by its arguments is passed to self.play here and
# for `rect.shift` below; how play interprets that is defined elsewhere.
new_subgraph.next_to, comp_graph.target, UP, SMALL_BUFF,
# Fill opacity of the new terms is set to the animation's alpha value.
UpdateFromAlphaFunc(
new_terms,
lambda m, a : m.set_fill(opacity = a)
),
MoveToTarget(comp_graph),
rect.shift, shift_vect
)
self.play(FadeOut(rect))
self.dither(2)
self.play(
FadeOut(new_subgraph),
comp_graph.restore,
# rect is shifted back and given a black, zero-width stroke before being
# removed from the scene right after this play call.
rect.shift, -shift_vect,
rect.set_stroke, BLACK, 0
)
self.remove(rect)
self.dither()
# NOTE(review): `z_label` and `brace` are not bound in this method (see
# header note); this call likely belongs to the previous version of
# introduce_z.
self.set_variables_as_attrs(z_label, z_brace = brace)
def break_into_computational_graph(self):
# NOTE(review): no-op stub that shares its name with the fuller definition
# earlier in this listing. It looks like the removed side of the diff rather
# than live code; if both were in the same class body, this later definition
# would be the one bound to the name - confirm against the repository.
pass
def show_preceding_layer_in_computational_graph(self):
# NOTE(review): no-op stub that shares its name with the fuller definition
# earlier in this listing. It looks like the removed side of the diff rather
# than live code; if both were in the same class body, this later definition
# would be the one bound to the name - confirm against the repository.
pass
def show_number_lines(self):
# Gives four nodes of self.comp_graph (wL, zL, aL, C0) a numeric value, a
# number line, a dot placed at that value on the line, and a curved arrow
# from the node to the dot, and animates their creation.
#
# NOTE(review): indentation and +/- diff markers were stripped from this
# listing. The `pass` below looks like a leftover from the earlier stub, and
# it cannot be told from here whether the self.play / self.dither calls at
# the end sit inside the `for mob in ...` loop (they reference `mob`) -
# confirm against the repository.
pass
comp_graph = self.comp_graph
# aLm1 and bL are unpacked for symmetry but not used in this method.
wL, aLm1, bL, zL, aL, C0 = [
getattr(comp_graph, attr)
for attr in ["wL", "aLm1", "bL", "zL", "aL", "C0"]
]
# Values shown on the number lines: the weight is read from the network,
# the activation from the first decimal mobject, z is recovered from the
# activation via sigmoid_inverse, and the cost is the squared difference
# from 1 (presumably the desired output - confirm).
wL.val = self.network.weights[-1][0][0]
aL.val = self.decimals[0].number
zL.val = sigmoid_inverse(aL.val)
C0.val = (aL.val - 1)**2
# Template number line; each node receives its own deep copy below.
number_line = UnitInterval(
unit_size = 2,
stroke_width = 2,
tick_size = 0.075,
color = LIGHT_GREY,
)
for mob in wL, zL, aL, C0:
mob.number_line = number_line.deepcopy()
# wL's number line goes above it; every other one goes to the right.
if mob is wL:
mob.number_line.next_to(mob, UP, MED_LARGE_BUFF, LEFT)
else:
mob.number_line.next_to(mob, RIGHT)
mob.dot = Dot(color = mob.get_color())
mob.dot.move_to(
mob.number_line.number_to_point(mob.val)
)
# wL gets a straight arrow ending at the bottom of its dot; the others
# get an arc of -0.8*pi ending at the top of the dot.
if mob is wL:
path_arc = 0
dot_spot = mob.dot.get_bottom()
else:
path_arc = -0.8*np.pi
dot_spot = mob.dot.get_top()
# For C0 the arrow starts at the corner of its first submobject and uses
# a shorter tip; otherwise it starts at the corner of the whole mobject.
if mob is C0:
mob_spot = mob[0].get_corner(UP+RIGHT)
tip_length = 0.15
else:
mob_spot = mob.get_corner(UP+RIGHT)
tip_length = 0.2
mob.arrow = Arrow(
mob_spot, dot_spot,
use_rectangular_stem = False,
path_arc = path_arc,
tip_length = tip_length,
buff = SMALL_BUFF,
)
mob.arrow.highlight(mob.get_color())
mob.arrow.set_stroke(width = 5)
self.play(ShowCreation(
mob.number_line,
submobject_mode = "lagged_start"
))
# The dot appears by transforming a copy of the node into it, along the
# same arc as the arrow.
self.play(
ShowCreation(mob.arrow),
ReplacementTransform(
mob.copy(), mob.dot,
path_arc = path_arc
)
)
self.dither()
def ask_about_w_sensitivity(self):
# Labels a nudge to w^{(L)} and the resulting nudge to C_0 with braces and
# partial-derivative symbols, then shifts wL's dot LEFT while the zL, aL and
# C0 dots are updated from the value the wL dot currently marks on its
# number line.
#
# Relies on the `number_line` and `dot` attributes that show_number_lines
# attaches to the comp_graph nodes.
#
# NOTE(review): indentation and +/- diff markers were stripped from this
# listing. The `pass` below looks like a leftover from the earlier stub, and
# it cannot be told from here whether the final self.play / self.dither sit
# inside the `for sym in ...` loop - confirm against the repository.
pass
wL, aLm1, bL, zL, aL, C0 = [
getattr(self.comp_graph, attr)
for attr in ["wL", "aLm1", "bL", "zL", "aL", "C0"]
]
# The previous-layer activation is read from the fill opacity of the second
# entry of self.last_neurons; the bias comes from the network.
aLm1_val = self.last_neurons[1].get_fill_opacity()
bL_val = self.network.biases[-1][0]
# Each getter recomputes from the wL dot's current position, so the chain
# w -> z = w*a + b -> a = sigmoid(z) -> C0 = (a - 1)**2 follows the dot as
# it moves.
get_wL_val = lambda : wL.number_line.point_to_number(
wL.dot.get_center()
)
get_zL_val = lambda : get_wL_val()*aLm1_val+bL_val
get_aL_val = lambda : sigmoid(get_zL_val())
get_C0_val = lambda : (get_aL_val() - 1)**2
# Factory so that every updater closes over its own term / val_func pair
# rather than the comprehension's loop variables.
def generate_dot_update(term, val_func):
def update_dot(dot):
dot.move_to(term.number_line.number_to_point(val_func()))
return dot
return update_dot
dot_update_anims = [
UpdateFromFunc(term.dot, generate_dot_update(term, val_func))
for term, val_func in [
(zL, get_zL_val),
(aL, get_aL_val),
(C0, get_C0_val),
]
]
# The brace for the w nudge spans a line from the wL dot to one LEFT
# vector away from it.
wL_line = Line(wL.dot.get_center(), wL.dot.get_center()+LEFT)
del_wL = TexMobject("\\partial w^{(L)}")
# `derivative_scale_vale` (sic) matches the spelling of the config key in
# this listing; presumably "value" was intended, but renaming it here alone
# would break the lookup.
del_wL.scale(self.derivative_scale_vale)
del_wL.brace = Brace(wL_line, UP)
del_wL.highlight(wL.get_color())
del_wL.next_to(del_wL.brace, UP, SMALL_BUFF)
# The brace for the C0 nudge spans MED_SMALL_BUFF to the RIGHT of the C0 dot.
C0_line = Line(C0.dot.get_center(), C0.dot.get_center()+MED_SMALL_BUFF*RIGHT)
del_C0 = TexMobject("\\partial C_0")
del_C0.scale(self.derivative_scale_vale)
del_C0.brace = Brace(C0_line, UP)
del_C0.highlight(C0.get_color())
del_C0.next_to(del_C0.brace, UP, SMALL_BUFF)
for sym in del_wL, del_C0:
self.play(
GrowFromCenter(sym.brace),
Write(sym, run_time = 1)
)
# Shift the wL dot LEFT over two seconds (rate_func is there_and_back,
# presumably returning it to its start) while the dependent dots update.
self.play(
ApplyMethod(
wL.dot.shift, LEFT,
run_time = 2,
rate_func = there_and_back
),
*dot_update_anims
)
self.dither()
# Presumably stores these locals as attributes on self for later methods -
# helper is defined outside this view.
self.set_variables_as_attrs(
dot_update_anims, del_wL, del_C0,
)
def show_derivative_wrt_w(self):
pass