Further SimplestNetworkExample practice

This commit is contained in:
Grant Sanderson 2017-10-26 23:26:12 -07:00
parent a92e03a9b8
commit e5490913a8

View file

@ -1778,7 +1778,6 @@ class OrganizeDataIntoMiniBatches(Scene):
"random_seed" : 0, "random_seed" : 0,
} }
def construct(self): def construct(self):
self.frame_duration = 1./24
self.seed_random_libraries() self.seed_random_libraries()
self.add_examples() self.add_examples()
self.shuffle_examples() self.shuffle_examples()
@ -1998,6 +1997,12 @@ class EOCWrapper(Scene):
self.dither() self.dither()
class SimplestNetworkExample(PreviewLearning): class SimplestNetworkExample(PreviewLearning):
CONFIG = {
"random_seed" : 6,
"z_color" : GREEN,
"cost_color" : RED,
"desired_output_color" : YELLOW,
}
def construct(self): def construct(self):
self.force_skipping() self.force_skipping()
@ -2023,8 +2028,8 @@ class SimplestNetworkExample(PreviewLearning):
self.show_previous_weight_and_bias() self.show_previous_weight_and_bias()
def seed_random_libraries(self): def seed_random_libraries(self):
np.random.seed(0) np.random.seed(self.random_seed)
random.seed(0) random.seed(self.random_seed)
def collapse_ordinary_network(self): def collapse_ordinary_network(self):
network_mob = self.network_mob network_mob = self.network_mob
@ -2174,7 +2179,6 @@ class SimplestNetworkExample(PreviewLearning):
rect = SurroundingRectangle(neuron) rect = SurroundingRectangle(neuron)
words = TextMobject("Desired \\\\ output") words = TextMobject("Desired \\\\ output")
words.next_to(rect, UP) words.next_to(rect, UP)
VGroup(words, rect).highlight(YELLOW)
y_label = TexMobject("y") y_label = TexMobject("y")
y_label.next_to(neuron, DOWN, LARGE_BUFF) y_label.next_to(neuron, DOWN, LARGE_BUFF)
@ -2184,6 +2188,7 @@ class SimplestNetworkExample(PreviewLearning):
color = WHITE, color = WHITE,
buff = SMALL_BUFF buff = SMALL_BUFF
) )
VGroup(words, rect, y_label).highlight(self.desired_output_color)
self.play(*map(FadeIn, [neuron, decimal])) self.play(*map(FadeIn, [neuron, decimal]))
self.play( self.play(
@ -2206,13 +2211,108 @@ class SimplestNetworkExample(PreviewLearning):
) )
def show_cost(self): def show_cost(self):
pass pre_a = self.a_labels[0].copy()
pre_y = self.y_label.copy()
cost_equation = TexMobject(
"C_0", "(", "\\dots", ")", "=",
"(", "a^{(L)}", "-", "y", ")", "^2"
)
cost_equation.to_corner(UP+RIGHT)
C0, a, y = [
cost_equation.get_part_by_tex(tex)
for tex in "C_0", "a^{(L)}", "y"
]
y.highlight(YELLOW)
cost_word = TextMobject("Cost")
cost_word.next_to(C0[0], LEFT, LARGE_BUFF)
cost_arrow = Arrow(
cost_word, C0,
buff = SMALL_BUFF
)
VGroup(C0, cost_word, cost_arrow).highlight(self.cost_color)
self.play(
ReplacementTransform(pre_a, a),
ReplacementTransform(pre_y, y),
)
self.play(LaggedStart(
FadeIn, VGroup(*filter(
lambda m : m not in [a, y],
cost_equation
))
))
self.dither()
self.play(
Write(cost_word, run_time = 1),
GrowArrow(cost_arrow)
)
self.play(C0.shift, MED_SMALL_BUFF*UP, rate_func = wiggle)
self.dither()
self.set_variables_as_attrs(
cost_equation, cost_word, cost_arrow
)
def show_activation_formula(self): def show_activation_formula(self):
pass neuron = self.network_mob.layers[-1].neurons[0]
edge = self.network_mob.edge_groups[-1][0]
pre_aL, pre_aLm1 = self.a_labels.copy()
formula = TexMobject(
"a^{(L)}", "=", "\\sigma", "(",
"w^{(L)}", "a^{(L-1)}", "+", "b^{(L)}", ")"
)
formula.next_to(neuron, UP, MED_LARGE_BUFF, RIGHT)
aL, equals, sigma, lp, wL, aLm1, plus, bL, rp = formula
wL.highlight(edge.get_color())
weight_label = wL.copy()
bL.highlight(MAROON_B)
bias_label = bL.copy()
sigma_group = VGroup(sigma, lp, rp)
sigma_group.save_state()
sigma_group.set_fill(opacity = 0)
sigma_group.shift(DOWN)
self.play(
ReplacementTransform(pre_aL, aL),
Write(equals)
)
self.play(ReplacementTransform(
edge.copy(), wL
))
self.dither()
self.play(ReplacementTransform(pre_aLm1, aLm1))
self.dither()
self.play(Write(VGroup(plus, bL), run_time = 1))
self.dither()
self.play(sigma_group.restore)
self.dither()
weighted_sum_terms = VGroup(wL, aLm1, bL)
self.set_variables_as_attrs(
formula, weighted_sum_terms
)
def introduce_z(self): def introduce_z(self):
pass terms = self.weighted_sum_terms
brace = Brace(terms, UP, buff = SMALL_BUFF)
z_label = TexMobject("z^{(L)}")
z_label.next_to(brace, UP, buff = SMALL_BUFF)
z_label.highlight(self.z_color)
rect = SurroundingRectangle(terms)
rect.highlight(GREEN)
self.play(ShowCreation(rect))
self.play(
GrowFromCenter(brace),
Write(z_label),
)
self.play(FadeOut(rect))
self.dither()
self.set_variables_as_attrs(z_label, z_brace = brace)
def break_into_computational_graph(self): def break_into_computational_graph(self):
pass pass