From c5891ce25e26777c7fdb0e75b227c0b405ae1e15 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Fri, 20 Oct 2017 16:29:58 -0700 Subject: [PATCH] Tweaks to GradientNudging in nn/part2 --- nn/part2.py | 40 ++++++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/nn/part2.py b/nn/part2.py index 350675b7..c6f75105 100644 --- a/nn/part2.py +++ b/nn/part2.py @@ -77,6 +77,8 @@ def get_decimal_vector(nums, with_dots = True): result.brackets = brackets result.decimals = decimals result.contents = contents + if with_dots: + result.dots = dots return result @@ -2606,7 +2608,8 @@ class TODOInsertEmphasizeComplexityOfCostFunctionCopy(TODOStub): class GradientNudging(PreviewLearning): CONFIG = { - "n_steps" : 10 + "n_steps" : 10, + "n_decimals" : 8, } def construct(self): self.setup_network_mob() @@ -2623,18 +2626,24 @@ class GradientNudging(PreviewLearning): lhs = TexMobject( "-", "\\nabla", "C(", "\\dots", ")", "=" ) - lhs.to_edge(LEFT) brace = Brace(lhs.get_part_by_tex("dots"), DOWN) words = brace.get_text("All weights \\\\ and biases") words.scale(0.8, about_point = words.get_top()) np.random.seed(3) - nums = 4*(np.random.random(8)-0.5) + nums = 4*(np.random.random(self.n_decimals)-0.5) vect = get_decimal_vector(nums) vect.next_to(lhs, RIGHT) + group = VGroup(lhs, brace, words, vect) + group.to_corner(UP+LEFT) - self.add(lhs, brace, words, vect) + self.add(*group) - self.grad_vect = vect + self.set_variables_as_attrs( + grad_lhs = lhs, + grad_vect = vect, + grad_arg_words = words, + grad_arg_brace = brace + ) def change_weights_repeatedly(self): network_mob = self.network_mob @@ -2659,14 +2668,17 @@ class GradientNudging(PreviewLearning): ]) mover = VGroup(*decimals.family_members_with_points()).copy() - mover.set_fill(opacity = 0) - mover.set_stroke(width = 1) - target = VGroup(*self.network_mob.edge_groups.family_members_with_points()) + # mover.set_fill(opacity = 0) + mover.set_stroke(width = 0) + target = VGroup(*self.network_mob.edge_groups.family_members_with_points()).copy() + target.set_fill(opacity = 0) + ApplyMethod(target.set_stroke, YELLOW, 2).update(0.3) self.play( ReplacementTransform(mover, target), FadeIn(words), LaggedStart(GrowArrow, arrows, run_time = 1) ) + self.play(FadeOut(target)) self.play(self.get_edge_change_anim(edges)) self.play(*self.get_decimal_change_anims(decimals)) for x in range(self.n_steps): @@ -2698,6 +2710,16 @@ class GradientNudging(PreviewLearning): ) def get_decimal_change_anims(self, decimals): + words = TextMobject("Recompute \\\\ gradient") + words.next_to(decimals, DOWN, MED_LARGE_BUFF) + def wrf(t): + if t < 1./3: + return smooth(3*t) + elif t < 2./3: + return 1 + else: + return smooth(3 - 3*t) + changes = 0.2*(np.random.random(len(decimals))-0.5) def generate_change_func(x, dx): return lambda a : interpolate(x, x+dx, a) @@ -2707,6 +2729,8 @@ class GradientNudging(PreviewLearning): generate_change_func(decimal.number, change) ) for decimal, change in zip(decimals, changes) + ] + [ + FadeIn(words, rate_func = wrf, run_time = 1.5, remover = True) ] class BackPropWrapper(PiCreatureScene):