Tweaks to GradientNudging in nn/part2

Grant Sanderson 2017-10-20 16:29:58 -07:00
parent 666fe16e8c
commit c5891ce25e


@@ -77,6 +77,8 @@ def get_decimal_vector(nums, with_dots = True):
     result.brackets = brackets
     result.decimals = decimals
     result.contents = contents
+    if with_dots:
+        result.dots = dots
     return result
@@ -2606,7 +2608,8 @@ class TODOInsertEmphasizeComplexityOfCostFunctionCopy(TODOStub):
 class GradientNudging(PreviewLearning):
     CONFIG = {
-        "n_steps" : 10
+        "n_steps" : 10,
+        "n_decimals" : 8,
     }
     def construct(self):
         self.setup_network_mob()
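The new "n_decimals" entry works because of the CONFIG convention in this 2017-era manim codebase: CONFIG dicts are merged up the class hierarchy and copied onto the scene as attributes before construct() runs, which is why the hunk further down can read self.n_decimals. A minimal sketch of that pattern, using a hypothetical digest_config_sketch helper rather than manim's real digest_config:

def digest_config_sketch(obj, kwargs=None):
    # Hypothetical stand-in: walk base classes first so subclasses
    # override, then copy every entry onto the instance as an attribute.
    merged = {}
    for cls in reversed(type(obj).mro()):
        merged.update(getattr(cls, "CONFIG", {}))
    merged.update(kwargs or {})
    for name, value in merged.items():
        setattr(obj, name, value)

class PreviewLearningSketch(object):
    CONFIG = {"n_steps": 5}

class GradientNudgingSketch(PreviewLearningSketch):
    CONFIG = {"n_steps": 10, "n_decimals": 8}
    def __init__(self, **kwargs):
        digest_config_sketch(self, kwargs)

scene = GradientNudgingSketch()
print(scene.n_steps, scene.n_decimals)  # -> 10 8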
@@ -2623,18 +2626,24 @@ class GradientNudging(PreviewLearning):
         lhs = TexMobject(
             "-", "\\nabla", "C(", "\\dots", ")", "="
         )
-        lhs.to_edge(LEFT)
         brace = Brace(lhs.get_part_by_tex("dots"), DOWN)
         words = brace.get_text("All weights \\\\ and biases")
         words.scale(0.8, about_point = words.get_top())
         np.random.seed(3)
-        nums = 4*(np.random.random(8)-0.5)
+        nums = 4*(np.random.random(self.n_decimals)-0.5)
         vect = get_decimal_vector(nums)
         vect.next_to(lhs, RIGHT)
+        group = VGroup(lhs, brace, words, vect)
+        group.to_corner(UP+LEFT)
-        self.add(lhs, brace, words, vect)
+        self.add(*group)
-        self.grad_vect = vect
+        self.set_variables_as_attrs(
+            grad_lhs = lhs,
+            grad_vect = vect,
+            grad_arg_words = words,
+            grad_arg_brace = brace
+        )

     def change_weights_repeatedly(self):
         network_mob = self.network_mob
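The set_variables_as_attrs(...) call replaces the single self.grad_vect = vect assignment: it stores each keyword argument on the scene under the given name so later methods can reach the same mobjects. A rough sketch of what such a helper does (the real method in this codebase may also accept positional mobjects and inspect caller locals):

def set_variables_as_attrs(self, *objects, **named_objects):
    # Sketch only: attach each named object to the scene instance.
    for name, value in named_objects.items():
        setattr(self, name, value)
    return self

class SceneSketch(object):
    pass

scene = SceneSketch()
set_variables_as_attrs(scene, grad_lhs="lhs", grad_vect="vect")
print(scene.grad_vect)  # -> vect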
@@ -2659,14 +2668,17 @@
         ])
         mover = VGroup(*decimals.family_members_with_points()).copy()
-        mover.set_fill(opacity = 0)
-        mover.set_stroke(width = 1)
-        target = VGroup(*self.network_mob.edge_groups.family_members_with_points())
+        # mover.set_fill(opacity = 0)
+        mover.set_stroke(width = 0)
+        target = VGroup(*self.network_mob.edge_groups.family_members_with_points()).copy()
+        target.set_fill(opacity = 0)
+        ApplyMethod(target.set_stroke, YELLOW, 2).update(0.3)
         self.play(
             ReplacementTransform(mover, target),
             FadeIn(words),
             LaggedStart(GrowArrow, arrows, run_time = 1)
         )
+        self.play(FadeOut(target))
         self.play(self.get_edge_change_anim(edges))
         self.play(*self.get_decimal_change_anims(decimals))
         for x in range(self.n_steps):
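The edited block above follows a common manim idiom: copy the source mobjects (the gradient decimals), build a stroke-only copy of the destination (the network edges) to act as the transform target, ReplacementTransform one into the other so the originals stay put, then fade the target out. A rough, self-contained version of the same idea written against the present-day manim-community API, which differs from this 2017-era code:

from manim import *

class SendGradientToEdges(Scene):
    def construct(self):
        # Stand-ins for the gradient readout and the network edges.
        source = Text("gradient").to_edge(LEFT)
        edges = VGroup(*[
            Line(ORIGIN, RIGHT).shift(0.4 * i * DOWN)
            for i in range(3)
        ]).to_edge(RIGHT)
        self.add(source, edges)

        # Copy the source so the original stays on screen untouched.
        mover = source.copy()
        # Stroke-only copy of the edges acts as the transform target.
        target = edges.copy().set_fill(opacity=0).set_stroke(YELLOW, 2)

        self.play(ReplacementTransform(mover, target))
        self.play(FadeOut(target))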
@@ -2698,6 +2710,16 @@
         )

     def get_decimal_change_anims(self, decimals):
+        words = TextMobject("Recompute \\\\ gradient")
+        words.next_to(decimals, DOWN, MED_LARGE_BUFF)
+        def wrf(t):
+            if t < 1./3:
+                return smooth(3*t)
+            elif t < 2./3:
+                return 1
+            else:
+                return smooth(3 - 3*t)
         changes = 0.2*(np.random.random(len(decimals))-0.5)
         def generate_change_func(x, dx):
             return lambda a : interpolate(x, x+dx, a)
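Two details in the new get_decimal_change_anims body deserve a note. The nested wrf is a piecewise rate function: it ramps up over the first third, holds at 1 for the middle third, and ramps back down, so the "Recompute gradient" label fades in, lingers, and fades out inside a single animation (it is passed as rate_func in the hunk below). And generate_change_func is a factory rather than an inline lambda because Python closures late-bind loop variables; a bare lambda in the comprehension would leave every decimal animating toward the final (x, dx) pair. A standalone sketch of both, assuming a smoothstep-style smooth() and linear interpolate() (manim's own helpers differ in detail):

import numpy as np

def smooth(t):
    # Simple smoothstep stand-in; manim's smooth() uses its own curve.
    return 3 * t**2 - 2 * t**3

def interpolate(start, end, alpha):
    return (1 - alpha) * start + alpha * end

def wrf(t):
    # Fade in over [0, 1/3], hold at 1 over [1/3, 2/3], fade out over [2/3, 1].
    if t < 1. / 3:
        return smooth(3 * t)
    elif t < 2. / 3:
        return 1
    else:
        return smooth(3 - 3 * t)

def generate_change_func(x, dx):
    # Factory so each animation captures its own (x, dx); a bare lambda
    # in the comprehension would late-bind and reuse the final values.
    return lambda a: interpolate(x, x + dx, a)

np.random.seed(3)
changes = 0.2 * (np.random.random(4) - 0.5)
funcs = [generate_change_func(1.0, dx) for dx in changes]
print([round(f(1.0), 3) for f in funcs])                 # each lands on its own 1.0 + dx
print([round(wrf(t), 2) for t in (0.0, 0.2, 0.5, 0.9)])  # 0, rising, 1, falling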
@@ -2707,6 +2729,8 @@
                 generate_change_func(decimal.number, change)
             )
             for decimal, change in zip(decimals, changes)
-        ]
+        ] + [
+            FadeIn(words, rate_func = wrf, run_time = 1.5, remover = True)
+        ]

 class BackPropWrapper(PiCreatureScene):