mirror of
https://github.com/3b1b/manim.git
synced 2025-08-21 05:44:04 +00:00
Further SimplestNetworkExample practice
This commit is contained in:
parent
a92e03a9b8
commit
e5490913a8
1 changed files with 107 additions and 7 deletions
114
nn/part3.py
114
nn/part3.py
|
@ -1778,7 +1778,6 @@ class OrganizeDataIntoMiniBatches(Scene):
|
|||
"random_seed" : 0,
|
||||
}
|
||||
def construct(self):
|
||||
self.frame_duration = 1./24
|
||||
self.seed_random_libraries()
|
||||
self.add_examples()
|
||||
self.shuffle_examples()
|
||||
|
@ -1998,6 +1997,12 @@ class EOCWrapper(Scene):
|
|||
self.dither()
|
||||
|
||||
class SimplestNetworkExample(PreviewLearning):
|
||||
CONFIG = {
|
||||
"random_seed" : 6,
|
||||
"z_color" : GREEN,
|
||||
"cost_color" : RED,
|
||||
"desired_output_color" : YELLOW,
|
||||
}
|
||||
def construct(self):
|
||||
self.force_skipping()
|
||||
|
||||
|
@ -2023,8 +2028,8 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
self.show_previous_weight_and_bias()
|
||||
|
||||
def seed_random_libraries(self):
|
||||
np.random.seed(0)
|
||||
random.seed(0)
|
||||
np.random.seed(self.random_seed)
|
||||
random.seed(self.random_seed)
|
||||
|
||||
def collapse_ordinary_network(self):
|
||||
network_mob = self.network_mob
|
||||
|
@ -2174,7 +2179,6 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
rect = SurroundingRectangle(neuron)
|
||||
words = TextMobject("Desired \\\\ output")
|
||||
words.next_to(rect, UP)
|
||||
VGroup(words, rect).highlight(YELLOW)
|
||||
|
||||
y_label = TexMobject("y")
|
||||
y_label.next_to(neuron, DOWN, LARGE_BUFF)
|
||||
|
@ -2184,6 +2188,7 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
color = WHITE,
|
||||
buff = SMALL_BUFF
|
||||
)
|
||||
VGroup(words, rect, y_label).highlight(self.desired_output_color)
|
||||
|
||||
self.play(*map(FadeIn, [neuron, decimal]))
|
||||
self.play(
|
||||
|
@ -2206,13 +2211,108 @@ class SimplestNetworkExample(PreviewLearning):
|
|||
)
|
||||
|
||||
def show_cost(self):
|
||||
pass
|
||||
pre_a = self.a_labels[0].copy()
|
||||
pre_y = self.y_label.copy()
|
||||
|
||||
cost_equation = TexMobject(
|
||||
"C_0", "(", "\\dots", ")", "=",
|
||||
"(", "a^{(L)}", "-", "y", ")", "^2"
|
||||
)
|
||||
cost_equation.to_corner(UP+RIGHT)
|
||||
C0, a, y = [
|
||||
cost_equation.get_part_by_tex(tex)
|
||||
for tex in "C_0", "a^{(L)}", "y"
|
||||
]
|
||||
y.highlight(YELLOW)
|
||||
|
||||
cost_word = TextMobject("Cost")
|
||||
cost_word.next_to(C0[0], LEFT, LARGE_BUFF)
|
||||
cost_arrow = Arrow(
|
||||
cost_word, C0,
|
||||
buff = SMALL_BUFF
|
||||
)
|
||||
VGroup(C0, cost_word, cost_arrow).highlight(self.cost_color)
|
||||
|
||||
self.play(
|
||||
ReplacementTransform(pre_a, a),
|
||||
ReplacementTransform(pre_y, y),
|
||||
)
|
||||
self.play(LaggedStart(
|
||||
FadeIn, VGroup(*filter(
|
||||
lambda m : m not in [a, y],
|
||||
cost_equation
|
||||
))
|
||||
))
|
||||
self.dither()
|
||||
self.play(
|
||||
Write(cost_word, run_time = 1),
|
||||
GrowArrow(cost_arrow)
|
||||
)
|
||||
self.play(C0.shift, MED_SMALL_BUFF*UP, rate_func = wiggle)
|
||||
self.dither()
|
||||
|
||||
self.set_variables_as_attrs(
|
||||
cost_equation, cost_word, cost_arrow
|
||||
)
|
||||
|
||||
def show_activation_formula(self):
|
||||
pass
|
||||
neuron = self.network_mob.layers[-1].neurons[0]
|
||||
edge = self.network_mob.edge_groups[-1][0]
|
||||
pre_aL, pre_aLm1 = self.a_labels.copy()
|
||||
|
||||
formula = TexMobject(
|
||||
"a^{(L)}", "=", "\\sigma", "(",
|
||||
"w^{(L)}", "a^{(L-1)}", "+", "b^{(L)}", ")"
|
||||
)
|
||||
formula.next_to(neuron, UP, MED_LARGE_BUFF, RIGHT)
|
||||
aL, equals, sigma, lp, wL, aLm1, plus, bL, rp = formula
|
||||
wL.highlight(edge.get_color())
|
||||
weight_label = wL.copy()
|
||||
bL.highlight(MAROON_B)
|
||||
bias_label = bL.copy()
|
||||
sigma_group = VGroup(sigma, lp, rp)
|
||||
sigma_group.save_state()
|
||||
sigma_group.set_fill(opacity = 0)
|
||||
sigma_group.shift(DOWN)
|
||||
|
||||
self.play(
|
||||
ReplacementTransform(pre_aL, aL),
|
||||
Write(equals)
|
||||
)
|
||||
self.play(ReplacementTransform(
|
||||
edge.copy(), wL
|
||||
))
|
||||
self.dither()
|
||||
self.play(ReplacementTransform(pre_aLm1, aLm1))
|
||||
self.dither()
|
||||
self.play(Write(VGroup(plus, bL), run_time = 1))
|
||||
self.dither()
|
||||
self.play(sigma_group.restore)
|
||||
self.dither()
|
||||
|
||||
weighted_sum_terms = VGroup(wL, aLm1, bL)
|
||||
self.set_variables_as_attrs(
|
||||
formula, weighted_sum_terms
|
||||
)
|
||||
|
||||
def introduce_z(self):
|
||||
pass
|
||||
terms = self.weighted_sum_terms
|
||||
brace = Brace(terms, UP, buff = SMALL_BUFF)
|
||||
z_label = TexMobject("z^{(L)}")
|
||||
z_label.next_to(brace, UP, buff = SMALL_BUFF)
|
||||
z_label.highlight(self.z_color)
|
||||
rect = SurroundingRectangle(terms)
|
||||
rect.highlight(GREEN)
|
||||
|
||||
self.play(ShowCreation(rect))
|
||||
self.play(
|
||||
GrowFromCenter(brace),
|
||||
Write(z_label),
|
||||
)
|
||||
self.play(FadeOut(rect))
|
||||
self.dither()
|
||||
|
||||
self.set_variables_as_attrs(z_label, z_brace = brace)
|
||||
|
||||
def break_into_computational_graph(self):
|
||||
pass
|
||||
|
|
Loading…
Add table
Reference in a new issue