mirror of
https://github.com/3b1b/manim.git
synced 2025-08-21 05:44:04 +00:00
Further SimplestNetworkExample practice
This commit is contained in:
parent
a92e03a9b8
commit
e5490913a8
1 changed files with 107 additions and 7 deletions
114
nn/part3.py
114
nn/part3.py
|
@ -1778,7 +1778,6 @@ class OrganizeDataIntoMiniBatches(Scene):
|
||||||
"random_seed" : 0,
|
"random_seed" : 0,
|
||||||
}
|
}
|
||||||
def construct(self):
|
def construct(self):
|
||||||
self.frame_duration = 1./24
|
|
||||||
self.seed_random_libraries()
|
self.seed_random_libraries()
|
||||||
self.add_examples()
|
self.add_examples()
|
||||||
self.shuffle_examples()
|
self.shuffle_examples()
|
||||||
|
@ -1998,6 +1997,12 @@ class EOCWrapper(Scene):
|
||||||
self.dither()
|
self.dither()
|
||||||
|
|
||||||
class SimplestNetworkExample(PreviewLearning):
|
class SimplestNetworkExample(PreviewLearning):
|
||||||
|
CONFIG = {
|
||||||
|
"random_seed" : 6,
|
||||||
|
"z_color" : GREEN,
|
||||||
|
"cost_color" : RED,
|
||||||
|
"desired_output_color" : YELLOW,
|
||||||
|
}
|
||||||
def construct(self):
|
def construct(self):
|
||||||
self.force_skipping()
|
self.force_skipping()
|
||||||
|
|
||||||
|
@ -2023,8 +2028,8 @@ class SimplestNetworkExample(PreviewLearning):
|
||||||
self.show_previous_weight_and_bias()
|
self.show_previous_weight_and_bias()
|
||||||
|
|
||||||
def seed_random_libraries(self):
|
def seed_random_libraries(self):
|
||||||
np.random.seed(0)
|
np.random.seed(self.random_seed)
|
||||||
random.seed(0)
|
random.seed(self.random_seed)
|
||||||
|
|
||||||
def collapse_ordinary_network(self):
|
def collapse_ordinary_network(self):
|
||||||
network_mob = self.network_mob
|
network_mob = self.network_mob
|
||||||
|
@ -2174,7 +2179,6 @@ class SimplestNetworkExample(PreviewLearning):
|
||||||
rect = SurroundingRectangle(neuron)
|
rect = SurroundingRectangle(neuron)
|
||||||
words = TextMobject("Desired \\\\ output")
|
words = TextMobject("Desired \\\\ output")
|
||||||
words.next_to(rect, UP)
|
words.next_to(rect, UP)
|
||||||
VGroup(words, rect).highlight(YELLOW)
|
|
||||||
|
|
||||||
y_label = TexMobject("y")
|
y_label = TexMobject("y")
|
||||||
y_label.next_to(neuron, DOWN, LARGE_BUFF)
|
y_label.next_to(neuron, DOWN, LARGE_BUFF)
|
||||||
|
@ -2184,6 +2188,7 @@ class SimplestNetworkExample(PreviewLearning):
|
||||||
color = WHITE,
|
color = WHITE,
|
||||||
buff = SMALL_BUFF
|
buff = SMALL_BUFF
|
||||||
)
|
)
|
||||||
|
VGroup(words, rect, y_label).highlight(self.desired_output_color)
|
||||||
|
|
||||||
self.play(*map(FadeIn, [neuron, decimal]))
|
self.play(*map(FadeIn, [neuron, decimal]))
|
||||||
self.play(
|
self.play(
|
||||||
|
@ -2206,13 +2211,108 @@ class SimplestNetworkExample(PreviewLearning):
|
||||||
)
|
)
|
||||||
|
|
||||||
def show_cost(self):
|
def show_cost(self):
|
||||||
pass
|
pre_a = self.a_labels[0].copy()
|
||||||
|
pre_y = self.y_label.copy()
|
||||||
|
|
||||||
|
cost_equation = TexMobject(
|
||||||
|
"C_0", "(", "\\dots", ")", "=",
|
||||||
|
"(", "a^{(L)}", "-", "y", ")", "^2"
|
||||||
|
)
|
||||||
|
cost_equation.to_corner(UP+RIGHT)
|
||||||
|
C0, a, y = [
|
||||||
|
cost_equation.get_part_by_tex(tex)
|
||||||
|
for tex in "C_0", "a^{(L)}", "y"
|
||||||
|
]
|
||||||
|
y.highlight(YELLOW)
|
||||||
|
|
||||||
|
cost_word = TextMobject("Cost")
|
||||||
|
cost_word.next_to(C0[0], LEFT, LARGE_BUFF)
|
||||||
|
cost_arrow = Arrow(
|
||||||
|
cost_word, C0,
|
||||||
|
buff = SMALL_BUFF
|
||||||
|
)
|
||||||
|
VGroup(C0, cost_word, cost_arrow).highlight(self.cost_color)
|
||||||
|
|
||||||
|
self.play(
|
||||||
|
ReplacementTransform(pre_a, a),
|
||||||
|
ReplacementTransform(pre_y, y),
|
||||||
|
)
|
||||||
|
self.play(LaggedStart(
|
||||||
|
FadeIn, VGroup(*filter(
|
||||||
|
lambda m : m not in [a, y],
|
||||||
|
cost_equation
|
||||||
|
))
|
||||||
|
))
|
||||||
|
self.dither()
|
||||||
|
self.play(
|
||||||
|
Write(cost_word, run_time = 1),
|
||||||
|
GrowArrow(cost_arrow)
|
||||||
|
)
|
||||||
|
self.play(C0.shift, MED_SMALL_BUFF*UP, rate_func = wiggle)
|
||||||
|
self.dither()
|
||||||
|
|
||||||
|
self.set_variables_as_attrs(
|
||||||
|
cost_equation, cost_word, cost_arrow
|
||||||
|
)
|
||||||
|
|
||||||
def show_activation_formula(self):
|
def show_activation_formula(self):
|
||||||
pass
|
neuron = self.network_mob.layers[-1].neurons[0]
|
||||||
|
edge = self.network_mob.edge_groups[-1][0]
|
||||||
|
pre_aL, pre_aLm1 = self.a_labels.copy()
|
||||||
|
|
||||||
|
formula = TexMobject(
|
||||||
|
"a^{(L)}", "=", "\\sigma", "(",
|
||||||
|
"w^{(L)}", "a^{(L-1)}", "+", "b^{(L)}", ")"
|
||||||
|
)
|
||||||
|
formula.next_to(neuron, UP, MED_LARGE_BUFF, RIGHT)
|
||||||
|
aL, equals, sigma, lp, wL, aLm1, plus, bL, rp = formula
|
||||||
|
wL.highlight(edge.get_color())
|
||||||
|
weight_label = wL.copy()
|
||||||
|
bL.highlight(MAROON_B)
|
||||||
|
bias_label = bL.copy()
|
||||||
|
sigma_group = VGroup(sigma, lp, rp)
|
||||||
|
sigma_group.save_state()
|
||||||
|
sigma_group.set_fill(opacity = 0)
|
||||||
|
sigma_group.shift(DOWN)
|
||||||
|
|
||||||
|
self.play(
|
||||||
|
ReplacementTransform(pre_aL, aL),
|
||||||
|
Write(equals)
|
||||||
|
)
|
||||||
|
self.play(ReplacementTransform(
|
||||||
|
edge.copy(), wL
|
||||||
|
))
|
||||||
|
self.dither()
|
||||||
|
self.play(ReplacementTransform(pre_aLm1, aLm1))
|
||||||
|
self.dither()
|
||||||
|
self.play(Write(VGroup(plus, bL), run_time = 1))
|
||||||
|
self.dither()
|
||||||
|
self.play(sigma_group.restore)
|
||||||
|
self.dither()
|
||||||
|
|
||||||
|
weighted_sum_terms = VGroup(wL, aLm1, bL)
|
||||||
|
self.set_variables_as_attrs(
|
||||||
|
formula, weighted_sum_terms
|
||||||
|
)
|
||||||
|
|
||||||
def introduce_z(self):
|
def introduce_z(self):
|
||||||
pass
|
terms = self.weighted_sum_terms
|
||||||
|
brace = Brace(terms, UP, buff = SMALL_BUFF)
|
||||||
|
z_label = TexMobject("z^{(L)}")
|
||||||
|
z_label.next_to(brace, UP, buff = SMALL_BUFF)
|
||||||
|
z_label.highlight(self.z_color)
|
||||||
|
rect = SurroundingRectangle(terms)
|
||||||
|
rect.highlight(GREEN)
|
||||||
|
|
||||||
|
self.play(ShowCreation(rect))
|
||||||
|
self.play(
|
||||||
|
GrowFromCenter(brace),
|
||||||
|
Write(z_label),
|
||||||
|
)
|
||||||
|
self.play(FadeOut(rect))
|
||||||
|
self.dither()
|
||||||
|
|
||||||
|
self.set_variables_as_attrs(z_label, z_brace = brace)
|
||||||
|
|
||||||
def break_into_computational_graph(self):
|
def break_into_computational_graph(self):
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Add table
Reference in a new issue