mirror of
https://github.com/3b1b/manim.git
synced 2025-08-05 16:49:03 +00:00
Tweaks to GradientNudging in nn/part2
This commit is contained in:
parent
666fe16e8c
commit
c5891ce25e
1 changed files with 32 additions and 8 deletions
40
nn/part2.py
40
nn/part2.py
|
@ -77,6 +77,8 @@ def get_decimal_vector(nums, with_dots = True):
|
||||||
result.brackets = brackets
|
result.brackets = brackets
|
||||||
result.decimals = decimals
|
result.decimals = decimals
|
||||||
result.contents = contents
|
result.contents = contents
|
||||||
|
if with_dots:
|
||||||
|
result.dots = dots
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@ -2606,7 +2608,8 @@ class TODOInsertEmphasizeComplexityOfCostFunctionCopy(TODOStub):
|
||||||
|
|
||||||
class GradientNudging(PreviewLearning):
|
class GradientNudging(PreviewLearning):
|
||||||
CONFIG = {
|
CONFIG = {
|
||||||
"n_steps" : 10
|
"n_steps" : 10,
|
||||||
|
"n_decimals" : 8,
|
||||||
}
|
}
|
||||||
def construct(self):
|
def construct(self):
|
||||||
self.setup_network_mob()
|
self.setup_network_mob()
|
||||||
|
@ -2623,18 +2626,24 @@ class GradientNudging(PreviewLearning):
|
||||||
lhs = TexMobject(
|
lhs = TexMobject(
|
||||||
"-", "\\nabla", "C(", "\\dots", ")", "="
|
"-", "\\nabla", "C(", "\\dots", ")", "="
|
||||||
)
|
)
|
||||||
lhs.to_edge(LEFT)
|
|
||||||
brace = Brace(lhs.get_part_by_tex("dots"), DOWN)
|
brace = Brace(lhs.get_part_by_tex("dots"), DOWN)
|
||||||
words = brace.get_text("All weights \\\\ and biases")
|
words = brace.get_text("All weights \\\\ and biases")
|
||||||
words.scale(0.8, about_point = words.get_top())
|
words.scale(0.8, about_point = words.get_top())
|
||||||
np.random.seed(3)
|
np.random.seed(3)
|
||||||
nums = 4*(np.random.random(8)-0.5)
|
nums = 4*(np.random.random(self.n_decimals)-0.5)
|
||||||
vect = get_decimal_vector(nums)
|
vect = get_decimal_vector(nums)
|
||||||
vect.next_to(lhs, RIGHT)
|
vect.next_to(lhs, RIGHT)
|
||||||
|
group = VGroup(lhs, brace, words, vect)
|
||||||
|
group.to_corner(UP+LEFT)
|
||||||
|
|
||||||
self.add(lhs, brace, words, vect)
|
self.add(*group)
|
||||||
|
|
||||||
self.grad_vect = vect
|
self.set_variables_as_attrs(
|
||||||
|
grad_lhs = lhs,
|
||||||
|
grad_vect = vect,
|
||||||
|
grad_arg_words = words,
|
||||||
|
grad_arg_brace = brace
|
||||||
|
)
|
||||||
|
|
||||||
def change_weights_repeatedly(self):
|
def change_weights_repeatedly(self):
|
||||||
network_mob = self.network_mob
|
network_mob = self.network_mob
|
||||||
|
@ -2659,14 +2668,17 @@ class GradientNudging(PreviewLearning):
|
||||||
])
|
])
|
||||||
|
|
||||||
mover = VGroup(*decimals.family_members_with_points()).copy()
|
mover = VGroup(*decimals.family_members_with_points()).copy()
|
||||||
mover.set_fill(opacity = 0)
|
# mover.set_fill(opacity = 0)
|
||||||
mover.set_stroke(width = 1)
|
mover.set_stroke(width = 0)
|
||||||
target = VGroup(*self.network_mob.edge_groups.family_members_with_points())
|
target = VGroup(*self.network_mob.edge_groups.family_members_with_points()).copy()
|
||||||
|
target.set_fill(opacity = 0)
|
||||||
|
ApplyMethod(target.set_stroke, YELLOW, 2).update(0.3)
|
||||||
self.play(
|
self.play(
|
||||||
ReplacementTransform(mover, target),
|
ReplacementTransform(mover, target),
|
||||||
FadeIn(words),
|
FadeIn(words),
|
||||||
LaggedStart(GrowArrow, arrows, run_time = 1)
|
LaggedStart(GrowArrow, arrows, run_time = 1)
|
||||||
)
|
)
|
||||||
|
self.play(FadeOut(target))
|
||||||
self.play(self.get_edge_change_anim(edges))
|
self.play(self.get_edge_change_anim(edges))
|
||||||
self.play(*self.get_decimal_change_anims(decimals))
|
self.play(*self.get_decimal_change_anims(decimals))
|
||||||
for x in range(self.n_steps):
|
for x in range(self.n_steps):
|
||||||
|
@ -2698,6 +2710,16 @@ class GradientNudging(PreviewLearning):
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_decimal_change_anims(self, decimals):
|
def get_decimal_change_anims(self, decimals):
|
||||||
|
words = TextMobject("Recompute \\\\ gradient")
|
||||||
|
words.next_to(decimals, DOWN, MED_LARGE_BUFF)
|
||||||
|
def wrf(t):
|
||||||
|
if t < 1./3:
|
||||||
|
return smooth(3*t)
|
||||||
|
elif t < 2./3:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return smooth(3 - 3*t)
|
||||||
|
|
||||||
changes = 0.2*(np.random.random(len(decimals))-0.5)
|
changes = 0.2*(np.random.random(len(decimals))-0.5)
|
||||||
def generate_change_func(x, dx):
|
def generate_change_func(x, dx):
|
||||||
return lambda a : interpolate(x, x+dx, a)
|
return lambda a : interpolate(x, x+dx, a)
|
||||||
|
@ -2707,6 +2729,8 @@ class GradientNudging(PreviewLearning):
|
||||||
generate_change_func(decimal.number, change)
|
generate_change_func(decimal.number, change)
|
||||||
)
|
)
|
||||||
for decimal, change in zip(decimals, changes)
|
for decimal, change in zip(decimals, changes)
|
||||||
|
] + [
|
||||||
|
FadeIn(words, rate_func = wrf, run_time = 1.5, remover = True)
|
||||||
]
|
]
|
||||||
|
|
||||||
class BackPropWrapper(PiCreatureScene):
|
class BackPropWrapper(PiCreatureScene):
|
||||||
|
|
Loading…
Add table
Reference in a new issue