mirror of
https://github.com/3b1b/manim.git
synced 2025-08-05 16:49:03 +00:00
Tweaks to GradientNudging in nn/part2
This commit is contained in:
parent
666fe16e8c
commit
c5891ce25e
1 changed files with 32 additions and 8 deletions
40
nn/part2.py
40
nn/part2.py
|
@ -77,6 +77,8 @@ def get_decimal_vector(nums, with_dots = True):
|
|||
result.brackets = brackets
|
||||
result.decimals = decimals
|
||||
result.contents = contents
|
||||
if with_dots:
|
||||
result.dots = dots
|
||||
return result
|
||||
|
||||
|
||||
|
@ -2606,7 +2608,8 @@ class TODOInsertEmphasizeComplexityOfCostFunctionCopy(TODOStub):
|
|||
|
||||
class GradientNudging(PreviewLearning):
|
||||
CONFIG = {
|
||||
"n_steps" : 10
|
||||
"n_steps" : 10,
|
||||
"n_decimals" : 8,
|
||||
}
|
||||
def construct(self):
|
||||
self.setup_network_mob()
|
||||
|
@ -2623,18 +2626,24 @@ class GradientNudging(PreviewLearning):
|
|||
lhs = TexMobject(
|
||||
"-", "\\nabla", "C(", "\\dots", ")", "="
|
||||
)
|
||||
lhs.to_edge(LEFT)
|
||||
brace = Brace(lhs.get_part_by_tex("dots"), DOWN)
|
||||
words = brace.get_text("All weights \\\\ and biases")
|
||||
words.scale(0.8, about_point = words.get_top())
|
||||
np.random.seed(3)
|
||||
nums = 4*(np.random.random(8)-0.5)
|
||||
nums = 4*(np.random.random(self.n_decimals)-0.5)
|
||||
vect = get_decimal_vector(nums)
|
||||
vect.next_to(lhs, RIGHT)
|
||||
group = VGroup(lhs, brace, words, vect)
|
||||
group.to_corner(UP+LEFT)
|
||||
|
||||
self.add(lhs, brace, words, vect)
|
||||
self.add(*group)
|
||||
|
||||
self.grad_vect = vect
|
||||
self.set_variables_as_attrs(
|
||||
grad_lhs = lhs,
|
||||
grad_vect = vect,
|
||||
grad_arg_words = words,
|
||||
grad_arg_brace = brace
|
||||
)
|
||||
|
||||
def change_weights_repeatedly(self):
|
||||
network_mob = self.network_mob
|
||||
|
@ -2659,14 +2668,17 @@ class GradientNudging(PreviewLearning):
|
|||
])
|
||||
|
||||
mover = VGroup(*decimals.family_members_with_points()).copy()
|
||||
mover.set_fill(opacity = 0)
|
||||
mover.set_stroke(width = 1)
|
||||
target = VGroup(*self.network_mob.edge_groups.family_members_with_points())
|
||||
# mover.set_fill(opacity = 0)
|
||||
mover.set_stroke(width = 0)
|
||||
target = VGroup(*self.network_mob.edge_groups.family_members_with_points()).copy()
|
||||
target.set_fill(opacity = 0)
|
||||
ApplyMethod(target.set_stroke, YELLOW, 2).update(0.3)
|
||||
self.play(
|
||||
ReplacementTransform(mover, target),
|
||||
FadeIn(words),
|
||||
LaggedStart(GrowArrow, arrows, run_time = 1)
|
||||
)
|
||||
self.play(FadeOut(target))
|
||||
self.play(self.get_edge_change_anim(edges))
|
||||
self.play(*self.get_decimal_change_anims(decimals))
|
||||
for x in range(self.n_steps):
|
||||
|
@ -2698,6 +2710,16 @@ class GradientNudging(PreviewLearning):
|
|||
)
|
||||
|
||||
def get_decimal_change_anims(self, decimals):
|
||||
words = TextMobject("Recompute \\\\ gradient")
|
||||
words.next_to(decimals, DOWN, MED_LARGE_BUFF)
|
||||
def wrf(t):
|
||||
if t < 1./3:
|
||||
return smooth(3*t)
|
||||
elif t < 2./3:
|
||||
return 1
|
||||
else:
|
||||
return smooth(3 - 3*t)
|
||||
|
||||
changes = 0.2*(np.random.random(len(decimals))-0.5)
|
||||
def generate_change_func(x, dx):
|
||||
return lambda a : interpolate(x, x+dx, a)
|
||||
|
@ -2707,6 +2729,8 @@ class GradientNudging(PreviewLearning):
|
|||
generate_change_func(decimal.number, change)
|
||||
)
|
||||
for decimal, change in zip(decimals, changes)
|
||||
] + [
|
||||
FadeIn(words, rate_func = wrf, run_time = 1.5, remover = True)
|
||||
]
|
||||
|
||||
class BackPropWrapper(PiCreatureScene):
|
||||
|
|
Loading…
Add table
Reference in a new issue