mirror of
https://github.com/3b1b/manim.git
synced 2025-04-13 09:47:07 +00:00
4496 lines
144 KiB
Python
4496 lines
144 KiB
Python
from nn.network import *
|
|
from nn.part1 import *
|
|
from nn.part2 import *
|
|
|
|
class LayOutPlan(Scene):
|
|
def construct(self):
|
|
title = TextMobject("Plan")
|
|
title.scale(1.5)
|
|
title.to_edge(UP)
|
|
h_line = Line(LEFT, RIGHT).scale(FRAME_X_RADIUS - 1)
|
|
h_line.next_to(title, DOWN)
|
|
|
|
items = BulletedList(
|
|
"Recap",
|
|
"Intuitive walkthrough",
|
|
"Derivatives in \\\\ computational graphs",
|
|
)
|
|
items.to_edge(LEFT, buff = LARGE_BUFF)
|
|
self.add(items)
|
|
|
|
rect = ScreenRectangle()
|
|
rect.set_width(FRAME_WIDTH - items.get_width() - 2)
|
|
rect.next_to(items, RIGHT, MED_LARGE_BUFF)
|
|
|
|
self.play(
|
|
Write(title),
|
|
ShowCreation(h_line),
|
|
ShowCreation(rect),
|
|
run_time = 2
|
|
)
|
|
for i in range(len(items)):
|
|
self.play(items.fade_all_but, i)
|
|
self.wait(2)
|
|
|
|
class TODOInsertFeedForwardAnimations(TODOStub):
|
|
CONFIG = {
|
|
"message" : "Insert Feed Forward Animations"
|
|
}
|
|
|
|
class TODOInsertStepsDownCostSurface(TODOStub):
|
|
CONFIG = {
|
|
"message" : "Insert Steps Down Cost Surface"
|
|
}
|
|
|
|
class TODOInsertDefinitionOfCostFunction(TODOStub):
|
|
CONFIG = {
|
|
"message" : "Insert Definition of cost function"
|
|
}
|
|
|
|
class TODOInsertGradientNudging(TODOStub):
|
|
CONFIG = {
|
|
"message" : "Insert GradientNudging"
|
|
}
|
|
|
|
class InterpretGradientComponents(GradientNudging):
|
|
CONFIG = {
|
|
"network_mob_config" : {
|
|
"layer_to_layer_buff" : 3,
|
|
},
|
|
"stroke_width_exp" : 2,
|
|
"n_decimals" : 6,
|
|
"n_steps" : 3,
|
|
"start_cost" : 3.48,
|
|
"delta_cost" : -0.21,
|
|
}
|
|
def construct(self):
|
|
self.setup_network()
|
|
self.add_cost()
|
|
self.add_gradient()
|
|
self.change_weights_repeatedly()
|
|
self.ask_about_high_dimensions()
|
|
self.circle_magnitudes()
|
|
self.isolate_particular_weights()
|
|
self.shift_cost_expression()
|
|
self.tweak_individual_weights()
|
|
|
|
def setup_network(self):
|
|
self.network_mob.scale(0.55)
|
|
self.network_mob.to_corner(UP+RIGHT)
|
|
self.color_network_edges()
|
|
|
|
def add_cost(self):
|
|
rect = SurroundingRectangle(self.network_mob)
|
|
rect.set_color(RED)
|
|
arrow = Vector(DOWN, color = RED)
|
|
arrow.shift(rect.get_bottom())
|
|
cost = DecimalNumber(self.start_cost)
|
|
cost.set_color(RED)
|
|
cost.next_to(arrow, DOWN)
|
|
|
|
cost_expression = TexMobject(
|
|
"C(", "w_0, w_1, \\dots, w_{13{,}001}", ")", "="
|
|
)
|
|
for tex in "()":
|
|
cost_expression.set_color_by_tex(tex, RED)
|
|
cost_expression.next_to(cost, DOWN)
|
|
cost_group = VGroup(cost_expression, cost)
|
|
cost_group.arrange(RIGHT)
|
|
cost_group.next_to(arrow, DOWN)
|
|
|
|
self.add(rect, arrow, cost_group)
|
|
|
|
self.set_variables_as_attrs(
|
|
cost, cost_expression, cost_group,
|
|
network_rect = rect
|
|
)
|
|
|
|
def change_weights_repeatedly(self):
|
|
decimals = self.grad_vect.decimals
|
|
grad_terms = self.grad_vect.contents
|
|
edges = VGroup(*reversed(list(
|
|
it.chain(*self.network_mob.edge_groups)
|
|
)))
|
|
cost = self.cost
|
|
|
|
for x in range(self.n_steps):
|
|
self.move_grad_terms_into_position(
|
|
grad_terms.copy(),
|
|
*self.get_weight_adjustment_anims(edges, cost)
|
|
)
|
|
self.play(*self.get_decimal_change_anims(decimals))
|
|
|
|
def ask_about_high_dimensions(self):
|
|
grad_vect = self.grad_vect
|
|
|
|
words = TextMobject(
|
|
"Direction in \\\\ ${13{,}002}$ dimensions?!?")
|
|
words.set_color(YELLOW)
|
|
words.move_to(grad_vect).to_edge(DOWN)
|
|
arrow = Arrow(
|
|
words.get_top(),
|
|
grad_vect.get_bottom(),
|
|
buff = SMALL_BUFF
|
|
)
|
|
|
|
randy = Randolph()
|
|
randy.scale(0.6)
|
|
randy.next_to(words, LEFT)
|
|
randy.shift_onto_screen()
|
|
|
|
self.play(
|
|
Write(words, run_time = 2),
|
|
GrowArrow(arrow),
|
|
)
|
|
self.play(FadeIn(randy))
|
|
self.play(randy.change, "confused", words)
|
|
self.play(Blink(randy))
|
|
self.wait()
|
|
self.play(*list(map(FadeOut, [randy, words, arrow])))
|
|
|
|
def circle_magnitudes(self):
|
|
rects = VGroup()
|
|
for decimal in self.grad_vect.decimals:
|
|
rects.add(SurroundingRectangle(VGroup(*decimal[-8:])))
|
|
rects.set_color(WHITE)
|
|
|
|
self.play(OldLaggedStart(ShowCreation, rects))
|
|
self.play(FadeOut(rects))
|
|
|
|
def isolate_particular_weights(self):
|
|
vect_contents = self.grad_vect.contents
|
|
w_terms = self.cost_expression[1]
|
|
|
|
edges = self.network_mob.edge_groups
|
|
edge1 = self.network_mob.layers[1].neurons[3].edges_in[0].copy()
|
|
edge2 = self.network_mob.layers[1].neurons[9].edges_in[15].copy()
|
|
VGroup(edge1, edge2).set_stroke(width = 4)
|
|
d1 = DecimalNumber(3.2)
|
|
d2 = DecimalNumber(0.1)
|
|
VGroup(edge1, d1).set_color(YELLOW)
|
|
VGroup(edge2, d2).set_color(MAROON_B)
|
|
new_vect_contents = VGroup(
|
|
TexMobject("\\vdots"),
|
|
d1, TexMobject("\\vdots"),
|
|
d2, TexMobject("\\vdots"),
|
|
)
|
|
new_vect_contents.arrange(DOWN)
|
|
new_vect_contents.move_to(vect_contents)
|
|
|
|
new_w_terms = TexMobject(
|
|
"\\dots", "w_n", "\\dots", "w_k", "\\dots"
|
|
)
|
|
new_w_terms.move_to(w_terms, DOWN)
|
|
new_w_terms[1].set_color(d1.get_color())
|
|
new_w_terms[3].set_color(d2.get_color())
|
|
|
|
for d, edge in (d1, edge1), (d2, edge2):
|
|
d.arrow = Arrow(
|
|
d.get_right(), edge.get_center(),
|
|
color = d.get_color()
|
|
)
|
|
|
|
self.play(
|
|
FadeOut(vect_contents),
|
|
FadeIn(new_vect_contents),
|
|
FadeOut(w_terms),
|
|
FadeIn(new_w_terms),
|
|
edges.set_stroke, LIGHT_GREY, 0.35,
|
|
)
|
|
self.play(GrowArrow(d1.arrow))
|
|
self.play(ShowCreation(edge1))
|
|
self.wait()
|
|
self.play(GrowArrow(d2.arrow))
|
|
self.play(ShowCreation(edge2))
|
|
self.wait(2)
|
|
|
|
self.cost_expression.remove(w_terms)
|
|
self.cost_expression.add(new_w_terms)
|
|
self.set_variables_as_attrs(
|
|
edge1, edge2, new_w_terms,
|
|
new_decimals = VGroup(d1, d2)
|
|
)
|
|
|
|
def shift_cost_expression(self):
|
|
self.play(self.cost_group.shift, DOWN+0.5*LEFT)
|
|
|
|
def tweak_individual_weights(self):
|
|
cost = self.cost
|
|
cost_num = cost.number
|
|
edges = VGroup(self.edge1, self.edge2)
|
|
decimals = self.new_decimals
|
|
changes = (1.0, 1./32)
|
|
wn = self.new_w_terms[1]
|
|
wk = self.new_w_terms[3]
|
|
|
|
number_line_template = NumberLine(
|
|
x_min = -1,
|
|
x_max = 1,
|
|
tick_frequency = 0.25,
|
|
numbers_with_elongated_ticks = [],
|
|
color = WHITE
|
|
)
|
|
for term in wn, wk, cost:
|
|
term.number_line = number_line_template.copy()
|
|
term.brace = Brace(term.number_line, DOWN, buff = SMALL_BUFF)
|
|
group = VGroup(term.number_line, term.brace)
|
|
group.next_to(term, UP)
|
|
term.dot = Dot()
|
|
term.dot.set_color(term.get_color())
|
|
term.dot.move_to(term.number_line.get_center())
|
|
term.dot.save_state()
|
|
term.dot.move_to(term)
|
|
term.dot.set_fill(opacity = 0)
|
|
term.words = TextMobject("Nudge this weight")
|
|
term.words.scale(0.7)
|
|
term.words.next_to(term.number_line, UP, MED_SMALL_BUFF)
|
|
|
|
groups = [
|
|
VGroup(d, d.arrow, edge, w)
|
|
for d, edge, w in zip(decimals, edges, [wn, wk])
|
|
]
|
|
for group in groups:
|
|
group.save_state()
|
|
|
|
for i in range(2):
|
|
group1, group2 = groups[i], groups[1-i]
|
|
change = changes[i]
|
|
edge = edges[i]
|
|
w = group1[-1]
|
|
added_anims = []
|
|
if i == 0:
|
|
added_anims = [
|
|
GrowFromCenter(cost.brace),
|
|
ShowCreation(cost.number_line),
|
|
cost.dot.restore
|
|
]
|
|
self.play(
|
|
group1.restore,
|
|
group2.fade, 0.7,
|
|
GrowFromCenter(w.brace),
|
|
ShowCreation(w.number_line),
|
|
w.dot.restore,
|
|
Write(w.words, run_time = 1),
|
|
*added_anims
|
|
)
|
|
for x in range(2):
|
|
func = lambda a : interpolate(
|
|
cost_num, cost_num-change, a
|
|
)
|
|
self.play(
|
|
ChangingDecimal(cost, func),
|
|
cost.dot.shift, change*RIGHT,
|
|
w.dot.shift, 0.25*RIGHT,
|
|
edge.set_stroke, None, 8,
|
|
rate_func = lambda t : wiggle(t, 4),
|
|
run_time = 2,
|
|
)
|
|
self.wait()
|
|
self.play(*list(map(FadeOut, [
|
|
w.dot, w.brace, w.number_line, w.words
|
|
])))
|
|
|
|
|
|
######
|
|
|
|
def move_grad_terms_into_position(self, grad_terms, *added_anims):
|
|
cost_expression = self.cost_expression
|
|
w_terms = self.cost_expression[1]
|
|
points = VGroup(*[
|
|
VectorizedPoint()
|
|
for term in grad_terms
|
|
])
|
|
points.arrange(RIGHT)
|
|
points.replace(w_terms, dim_to_match = 0)
|
|
|
|
grad_terms.generate_target()
|
|
grad_terms.target[len(grad_terms)/2].rotate(np.pi/2)
|
|
grad_terms.target.arrange(RIGHT)
|
|
grad_terms.target.set_width(cost_expression.get_width())
|
|
grad_terms.target.next_to(cost_expression, DOWN)
|
|
|
|
words = TextMobject("Nudge weights")
|
|
words.scale(0.8)
|
|
words.next_to(grad_terms.target, DOWN)
|
|
|
|
self.play(
|
|
MoveToTarget(grad_terms),
|
|
FadeIn(words)
|
|
)
|
|
self.play(
|
|
Transform(
|
|
grad_terms, points,
|
|
remover = True,
|
|
lag_ratio = 0.5,
|
|
run_time = 1
|
|
),
|
|
FadeOut(words),
|
|
*added_anims
|
|
)
|
|
|
|
def get_weight_adjustment_anims(self, edges, cost):
|
|
start_cost = cost.number
|
|
target_cost = start_cost + self.delta_cost
|
|
w_terms = self.cost_expression[1]
|
|
|
|
return [
|
|
self.get_edge_change_anim(edges),
|
|
OldLaggedStart(
|
|
Indicate, w_terms,
|
|
rate_func = there_and_back,
|
|
run_time = 1.5,
|
|
),
|
|
ChangingDecimal(
|
|
cost,
|
|
lambda a : interpolate(start_cost, target_cost, a),
|
|
run_time = 1.5
|
|
)
|
|
]
|
|
|
|
class GetLostInNotation(PiCreatureScene):
|
|
def construct(self):
|
|
morty = self.pi_creature
|
|
equations = VGroup(
|
|
TexMobject(
|
|
"\\delta", "^L", "=", "\\nabla_a", "C",
|
|
"\\odot \\sigma'(", "z", "^L)"
|
|
),
|
|
TexMobject(
|
|
"\\delta", "^l = ((", "w", "^{l+1})^T",
|
|
"\\delta", "^{l+1}) \\odot \\sigma'(", "z", "^l)"
|
|
),
|
|
TexMobject(
|
|
"{\\partial", "C", "\\over \\partial", "b",
|
|
"_j^l} =", "\\delta", "_j^l"
|
|
),
|
|
TexMobject(
|
|
"{\\partial", "C", " \\over \\partial",
|
|
"w", "_{jk}^l} = ", "a", "_k^{l-1}", "\\delta", "_j^l"
|
|
),
|
|
)
|
|
for equation in equations:
|
|
equation.set_color_by_tex_to_color_map({
|
|
"\\delta" : YELLOW,
|
|
"C" : RED,
|
|
"b" : MAROON_B,
|
|
"w" : BLUE,
|
|
"z" : TEAL,
|
|
})
|
|
equation.set_color_by_tex("nabla", WHITE)
|
|
equations.arrange(
|
|
DOWN, buff = MED_LARGE_BUFF, aligned_edge = LEFT
|
|
)
|
|
|
|
circle = Circle(radius = 3*FRAME_X_RADIUS)
|
|
circle.set_fill(WHITE, 0)
|
|
circle.set_stroke(WHITE, 0)
|
|
|
|
self.play(
|
|
Write(equations),
|
|
morty.change, "confused", equations
|
|
)
|
|
self.wait()
|
|
self.play(morty.change, "pleading")
|
|
self.wait(2)
|
|
|
|
##
|
|
movers = VGroup(*equations.family_members_with_points())
|
|
random.shuffle(movers.submobjects)
|
|
for mover in list(movers):
|
|
if mover.is_subpath:
|
|
movers.remove(mover)
|
|
continue
|
|
mover.set_stroke(WHITE, width = 0)
|
|
mover.target = Circle()
|
|
mover.target.scale(0.5)
|
|
mover.target.set_fill(mover.get_color(), opacity = 0)
|
|
mover.target.set_stroke(BLACK, width = 1)
|
|
mover.target.move_to(mover)
|
|
self.play(
|
|
OldLaggedStart(
|
|
MoveToTarget, movers,
|
|
run_time = 2,
|
|
),
|
|
morty.change, "pondering",
|
|
)
|
|
self.wait()
|
|
|
|
class TODOInsertPreviewLearning(TODOStub):
|
|
CONFIG = {
|
|
"message" : "Insert PreviewLearning"
|
|
}
|
|
|
|
class ShowAveragingCost(PreviewLearning):
|
|
CONFIG = {
|
|
"network_scale_val" : 0.8,
|
|
"stroke_width_exp" : 1,
|
|
"start_examples_time" : 5,
|
|
"examples_per_adjustment_time" : 2,
|
|
"n_adjustments" : 5,
|
|
"time_per_example" : 1./15,
|
|
"image_height" : 1.2,
|
|
}
|
|
def construct(self):
|
|
self.setup_network()
|
|
self.setup_diff_words()
|
|
self.show_many_examples()
|
|
|
|
def setup_network(self):
|
|
self.network_mob.scale(self.network_scale_val)
|
|
self.network_mob.to_edge(DOWN)
|
|
self.network_mob.shift(LEFT)
|
|
self.color_network_edges()
|
|
|
|
def setup_diff_words(self):
|
|
last_layer_copy = self.network_mob.layers[-1].deepcopy()
|
|
last_layer_copy.add(self.network_mob.output_labels.copy())
|
|
last_layer_copy.shift(1.5*RIGHT)
|
|
|
|
double_arrow = DoubleArrow(
|
|
self.network_mob.output_labels,
|
|
last_layer_copy,
|
|
color = RED
|
|
)
|
|
brace = Brace(
|
|
VGroup(self.network_mob.layers[-1], last_layer_copy),
|
|
UP
|
|
)
|
|
cost_words = brace.get_text("Cost of \\\\ one example")
|
|
cost_words.set_color(RED)
|
|
|
|
self.add(last_layer_copy, double_arrow, brace, cost_words)
|
|
self.set_variables_as_attrs(
|
|
last_layer_copy, double_arrow, brace, cost_words
|
|
)
|
|
self.last_layer_copy = last_layer_copy
|
|
|
|
def show_many_examples(self):
|
|
training_data, validation_data, test_data = load_data_wrapper()
|
|
|
|
average_words = TextMobject("Average over all training examples")
|
|
average_words.next_to(LEFT, RIGHT)
|
|
average_words.to_edge(UP)
|
|
self.add(average_words)
|
|
|
|
n_start_examples = int(self.start_examples_time/self.time_per_example)
|
|
n_examples_per_adjustment = int(self.examples_per_adjustment_time/self.time_per_example)
|
|
for train_in, train_out in training_data[:n_start_examples]:
|
|
self.show_one_example(train_in, train_out)
|
|
self.wait(self.time_per_example)
|
|
|
|
#Wiggle all edges
|
|
edges = VGroup(*it.chain(*self.network_mob.edge_groups))
|
|
reversed_edges = VGroup(*reversed(edges))
|
|
self.play(OldLaggedStart(
|
|
ApplyFunction, edges,
|
|
lambda edge : (
|
|
lambda m : m.rotate_in_place(np.pi/12).set_color(YELLOW),
|
|
edge,
|
|
),
|
|
rate_func = lambda t : wiggle(t, 4),
|
|
run_time = 3,
|
|
))
|
|
|
|
#Show all, then adjust
|
|
words = TextMobject(
|
|
"Each step \\\\ uses every \\\\ example\\\\",
|
|
"$\\dots$theoretically",
|
|
alignment = ""
|
|
)
|
|
words.set_color(YELLOW)
|
|
words.scale(0.8)
|
|
words.to_corner(UP+LEFT)
|
|
|
|
for x in range(self.n_adjustments):
|
|
if x < 2:
|
|
self.play(FadeIn(words[x]))
|
|
for train_in, train_out in training_data[:n_examples_per_adjustment]:
|
|
self.show_one_example(train_in, train_out)
|
|
self.wait(self.time_per_example)
|
|
self.play(OldLaggedStart(
|
|
ApplyMethod, reversed_edges,
|
|
lambda m : (m.rotate_in_place, np.pi),
|
|
run_time = 1,
|
|
lag_ratio = 0.2,
|
|
))
|
|
if x >= 2:
|
|
self.wait()
|
|
|
|
####
|
|
|
|
def show_one_example(self, train_in, train_out):
|
|
if hasattr(self, "curr_image"):
|
|
self.remove(self.curr_image)
|
|
image = MNistMobject(train_in)
|
|
image.set_height(self.image_height)
|
|
image.next_to(
|
|
self.network_mob.layers[0].neurons, UP,
|
|
aligned_edge = LEFT
|
|
)
|
|
self.add(image)
|
|
self.network_mob.activate_layers(train_in)
|
|
|
|
index = np.argmax(train_out)
|
|
self.last_layer_copy.neurons.set_fill(opacity = 0)
|
|
self.last_layer_copy.neurons[index].set_fill(WHITE, opacity = 1)
|
|
self.add(self.last_layer_copy)
|
|
|
|
self.curr_image = image
|
|
|
|
class FocusOnOneExample(TeacherStudentsScene):
|
|
def construct(self):
|
|
self.teacher_says("Focus on just \\\\ one example")
|
|
self.wait(2)
|
|
|
|
class WalkThroughTwoExample(ShowAveragingCost):
|
|
CONFIG = {
|
|
"random_seed" : 0,
|
|
}
|
|
def setup(self):
|
|
np.random.seed(self.random_seed)
|
|
random.seed(self.random_seed)
|
|
self.setup_bases()
|
|
|
|
def construct(self):
|
|
self.force_skipping()
|
|
|
|
self.setup_network()
|
|
self.setup_diff_words()
|
|
self.show_single_example()
|
|
self.single_example_influencing_weights()
|
|
self.expand_last_layer()
|
|
self.show_activation_formula()
|
|
self.three_ways_to_increase()
|
|
self.note_connections_to_brightest_neurons()
|
|
self.fire_together_wire_together()
|
|
self.show_desired_increase_to_previous_neurons()
|
|
self.only_keeping_track_of_changes()
|
|
self.show_other_output_neurons()
|
|
self.show_recursion()
|
|
|
|
def show_single_example(self):
|
|
two_vect = get_organized_images()[2][0]
|
|
two_out = np.zeros(10)
|
|
two_out[2] = 1.0
|
|
self.show_one_example(two_vect, two_out)
|
|
for layer in self.network_mob.layers:
|
|
layer.neurons.set_fill(opacity = 0)
|
|
|
|
self.activate_network(two_vect)
|
|
self.wait()
|
|
|
|
def single_example_influencing_weights(self):
|
|
two = self.curr_image
|
|
two.save_state()
|
|
|
|
edge_groups = self.network_mob.edge_groups
|
|
def adjust_edge_group_anim(edge_group):
|
|
return OldLaggedStart(
|
|
ApplyFunction, edge_group,
|
|
lambda edge : (
|
|
lambda m : m.rotate_in_place(np.pi/12).set_color(YELLOW),
|
|
edge
|
|
),
|
|
rate_func = wiggle,
|
|
run_time = 1,
|
|
)
|
|
|
|
self.play(
|
|
two.next_to, edge_groups[0].get_corner(DOWN+RIGHT), DOWN,
|
|
adjust_edge_group_anim(edge_groups[0])
|
|
)
|
|
self.play(
|
|
ApplyMethod(
|
|
two.next_to, edge_groups[1].get_corner(UP+RIGHT), UP,
|
|
path_arc = np.pi/6,
|
|
),
|
|
adjust_edge_group_anim(VGroup(*reversed(edge_groups[1])))
|
|
)
|
|
self.play(
|
|
ApplyMethod(
|
|
two.next_to, edge_groups[2].get_corner(DOWN+RIGHT), DOWN,
|
|
path_arc = -np.pi/6,
|
|
),
|
|
adjust_edge_group_anim(edge_groups[2])
|
|
)
|
|
self.play(two.restore)
|
|
self.wait()
|
|
|
|
def expand_last_layer(self):
|
|
neurons = self.network_mob.layers[-1].neurons
|
|
alt_neurons = self.last_layer_copy.neurons
|
|
output_labels = self.network_mob.output_labels
|
|
alt_output_labels = self.last_layer_copy[-1]
|
|
edges = self.network_mob.edge_groups[-1]
|
|
|
|
movers = VGroup(
|
|
neurons, alt_neurons,
|
|
output_labels, alt_output_labels,
|
|
*edges
|
|
)
|
|
to_fade = VGroup(self.brace, self.cost_words, self.double_arrow)
|
|
for mover in movers:
|
|
mover.save_state()
|
|
mover.generate_target()
|
|
mover.target.scale_in_place(2)
|
|
neurons[2].save_state()
|
|
|
|
neurons.target.to_edge(DOWN, MED_LARGE_BUFF)
|
|
output_labels.target.next_to(neurons.target, RIGHT, MED_SMALL_BUFF)
|
|
alt_neurons.target.next_to(neurons.target, RIGHT, buff = 2)
|
|
alt_output_labels.target.next_to(alt_neurons.target, RIGHT, MED_SMALL_BUFF)
|
|
|
|
n_pairs = it.product(
|
|
self.network_mob.layers[-2].neurons,
|
|
neurons.target
|
|
)
|
|
for edge, (n1, n2) in zip(edges, n_pairs):
|
|
r1 = n1.get_width()/2.0
|
|
r2 = n2.get_width()/2.0
|
|
c1 = n1.get_center()
|
|
c2 = n2.get_center()
|
|
vect = c2 - c1
|
|
norm = get_norm(vect)
|
|
unit_vect = vect / norm
|
|
|
|
edge.target.put_start_and_end_on(
|
|
c1 + unit_vect*r1,
|
|
c2 - unit_vect*r2
|
|
)
|
|
|
|
self.play(
|
|
FadeOut(to_fade),
|
|
*list(map(MoveToTarget, movers))
|
|
)
|
|
self.show_decimals(neurons)
|
|
self.cannot_directly_affect_activations()
|
|
self.show_desired_activation_nudges(neurons, output_labels, alt_output_labels)
|
|
self.focus_on_one_neuron(movers)
|
|
|
|
def show_decimals(self, neurons):
|
|
decimals = VGroup()
|
|
for neuron in neurons:
|
|
activation = neuron.get_fill_opacity()
|
|
decimal = DecimalNumber(activation, num_decimal_places = 1)
|
|
decimal.set_width(0.7*neuron.get_width())
|
|
decimal.move_to(neuron)
|
|
if activation > 0.8:
|
|
decimal.set_color(BLACK)
|
|
decimals.add(decimal)
|
|
|
|
self.play(Write(decimals, run_time = 2))
|
|
self.wait()
|
|
self.decimals = decimals
|
|
|
|
def cannot_directly_affect_activations(self):
|
|
words = TextMobject("You can only adjust weights and biases")
|
|
words.next_to(self.curr_image, RIGHT, MED_SMALL_BUFF, UP)
|
|
|
|
edges = VGroup(*self.network_mob.edge_groups.family_members_with_points())
|
|
random.shuffle(edges.submobjects)
|
|
for edge in edges:
|
|
edge.generate_target()
|
|
edge.target.set_stroke(
|
|
random.choice([BLUE, RED]),
|
|
2*random.random()**2,
|
|
)
|
|
|
|
self.play(
|
|
OldLaggedStart(
|
|
Transform, edges,
|
|
lambda e : (e, e.target),
|
|
run_time = 4,
|
|
rate_func = there_and_back,
|
|
),
|
|
Write(words, run_time = 2)
|
|
)
|
|
self.play(FadeOut(words))
|
|
|
|
def show_desired_activation_nudges(self, neurons, output_labels, alt_output_labels):
|
|
arrows = VGroup()
|
|
rects = VGroup()
|
|
for i, neuron, label in zip(it.count(), neurons, alt_output_labels):
|
|
activation = neuron.get_fill_opacity()
|
|
target_val = 1 if i == 2 else 0
|
|
diff = abs(activation - target_val)
|
|
arrow = Arrow(
|
|
ORIGIN, diff*neuron.get_height()*DOWN,
|
|
color = RED,
|
|
)
|
|
arrow.move_to(neuron.get_right())
|
|
arrow.shift(0.175*RIGHT)
|
|
if i == 2:
|
|
arrow.set_color(BLUE)
|
|
arrow.rotate_in_place(np.pi)
|
|
arrows.add(arrow)
|
|
|
|
rect = SurroundingRectangle(VGroup(neuron, label))
|
|
if i == 2:
|
|
rect.set_color(BLUE)
|
|
else:
|
|
rect.set_color(RED)
|
|
rects.add(rect)
|
|
|
|
self.play(
|
|
output_labels.shift, SMALL_BUFF*RIGHT,
|
|
OldLaggedStart(GrowArrow, arrows, run_time = 1)
|
|
)
|
|
self.wait()
|
|
|
|
#Show changing activations
|
|
anims = []
|
|
def get_decimal_update(start, end):
|
|
return lambda a : interpolate(start, end, a)
|
|
for i in range(10):
|
|
target = 1.0 if i == 2 else 0.01
|
|
anims += [neurons[i].set_fill, WHITE, target]
|
|
decimal = self.decimals[i]
|
|
anims.append(ChangingDecimal(
|
|
decimal,
|
|
get_decimal_update(decimal.number, target),
|
|
num_decimal_places = 1
|
|
))
|
|
anims.append(UpdateFromFunc(
|
|
self.decimals[i],
|
|
lambda m : m.set_fill(WHITE if m.number < 0.8 else BLACK)
|
|
))
|
|
|
|
self.play(
|
|
*anims,
|
|
run_time = 3,
|
|
rate_func = there_and_back
|
|
)
|
|
|
|
two_rect = rects[2]
|
|
eight_rect = rects[8].copy()
|
|
non_two_rects = VGroup(*[r for r in rects if r is not two_rect])
|
|
self.play(ShowCreation(two_rect))
|
|
self.wait()
|
|
self.remove(two_rect)
|
|
self.play(ReplacementTransform(two_rect.copy(), non_two_rects))
|
|
self.wait()
|
|
self.play(OldLaggedStart(FadeOut, non_two_rects, run_time = 1))
|
|
self.play(OldLaggedStart(
|
|
ApplyFunction, arrows,
|
|
lambda arrow : (
|
|
lambda m : m.scale_in_place(0.5).set_color(YELLOW),
|
|
arrow,
|
|
),
|
|
rate_func = wiggle
|
|
))
|
|
self.play(ShowCreation(two_rect))
|
|
self.wait()
|
|
self.play(ReplacementTransform(two_rect, eight_rect))
|
|
self.wait()
|
|
self.play(FadeOut(eight_rect))
|
|
|
|
self.arrows = arrows
|
|
|
|
def focus_on_one_neuron(self, expanded_mobjects):
|
|
network_mob = self.network_mob
|
|
neurons = network_mob.layers[-1].neurons
|
|
labels = network_mob.output_labels
|
|
two_neuron = neurons[2]
|
|
neurons.remove(two_neuron)
|
|
two_label = labels[2]
|
|
labels.remove(two_label)
|
|
expanded_mobjects.remove(*two_neuron.edges_in)
|
|
two_decimal = self.decimals[2]
|
|
self.decimals.remove(two_decimal)
|
|
two_arrow = self.arrows[2]
|
|
self.arrows.remove(two_arrow)
|
|
|
|
to_fade = VGroup(*it.chain(
|
|
network_mob.layers[:2],
|
|
network_mob.edge_groups[:2],
|
|
expanded_mobjects,
|
|
self.decimals,
|
|
self.arrows
|
|
))
|
|
|
|
self.play(FadeOut(to_fade))
|
|
self.wait()
|
|
for mob in expanded_mobjects:
|
|
if mob in [neurons, labels]:
|
|
mob.scale(0.5)
|
|
mob.move_to(mob.saved_state)
|
|
else:
|
|
mob.restore()
|
|
for d, a, n in zip(self.decimals, self.arrows, neurons):
|
|
d.scale(0.5)
|
|
d.move_to(n)
|
|
a.scale(0.5)
|
|
a.move_to(n.get_right())
|
|
a.shift(SMALL_BUFF*RIGHT)
|
|
labels.shift(SMALL_BUFF*RIGHT)
|
|
|
|
self.set_variables_as_attrs(
|
|
two_neuron, two_label, two_arrow, two_decimal,
|
|
)
|
|
|
|
def show_activation_formula(self):
|
|
rhs = TexMobject(
|
|
"=", "\\sigma(",
|
|
"w_0", "a_0", "+",
|
|
"w_1", "a_1", "+",
|
|
"\\cdots", "+",
|
|
"w_{n-1}", "a_{n-1}", "+",
|
|
"b", ")"
|
|
)
|
|
equals = rhs[0]
|
|
sigma = VGroup(rhs[1], rhs[-1])
|
|
w_terms = rhs.get_parts_by_tex("w_")
|
|
a_terms = rhs.get_parts_by_tex("a_")
|
|
plus_terms = rhs.get_parts_by_tex("+")
|
|
b = rhs.get_part_by_tex("b", substring = False)
|
|
dots = rhs.get_part_by_tex("dots")
|
|
|
|
w_terms.set_color(BLUE)
|
|
b.set_color(MAROON_B)
|
|
sigma.set_color(YELLOW)
|
|
|
|
rhs.to_corner(UP+RIGHT)
|
|
sigma.save_state()
|
|
sigma.shift(DOWN)
|
|
sigma.set_fill(opacity = 0)
|
|
|
|
prev_neurons = self.network_mob.layers[-2].neurons
|
|
edges = self.two_neuron.edges_in
|
|
|
|
neuron_copy = VGroup(
|
|
self.two_neuron.copy(),
|
|
self.two_decimal.copy(),
|
|
)
|
|
|
|
self.play(
|
|
neuron_copy.next_to, equals.get_left(), LEFT,
|
|
self.curr_image.to_corner, UP+LEFT,
|
|
Write(equals)
|
|
)
|
|
self.play(
|
|
ReplacementTransform(edges.copy(), w_terms),
|
|
Write(VGroup(*plus_terms[:-1])),
|
|
Write(dots),
|
|
run_time = 1.5
|
|
)
|
|
self.wait()
|
|
self.play(ReplacementTransform(
|
|
prev_neurons.copy(), a_terms,
|
|
path_arc = np.pi/2
|
|
))
|
|
self.wait()
|
|
self.play(
|
|
Write(plus_terms[-1]),
|
|
Write(b)
|
|
)
|
|
self.wait()
|
|
self.play(sigma.restore)
|
|
self.wait()
|
|
for mob in b, w_terms, a_terms:
|
|
self.play(
|
|
mob.shift, MED_SMALL_BUFF*DOWN,
|
|
rate_func = there_and_back,
|
|
lag_ratio = 0.5,
|
|
run_time = 1.5
|
|
)
|
|
self.wait()
|
|
|
|
self.set_variables_as_attrs(
|
|
rhs, w_terms, a_terms, b,
|
|
lhs = neuron_copy
|
|
)
|
|
|
|
def three_ways_to_increase(self):
|
|
w_terms = self.w_terms
|
|
a_terms = self.a_terms
|
|
b = self.b
|
|
increase_words = VGroup(
|
|
TextMobject("Increase", "$b$"),
|
|
TextMobject("Increase", "$w_i$"),
|
|
TextMobject("Change", "$a_i$"),
|
|
)
|
|
for words in increase_words:
|
|
words.set_color_by_tex_to_color_map({
|
|
"b" : b.get_color(),
|
|
"w_" : w_terms.get_color(),
|
|
"a_" : a_terms.get_color(),
|
|
})
|
|
increase_words.arrange(
|
|
DOWN, aligned_edge = LEFT,
|
|
buff = LARGE_BUFF
|
|
)
|
|
increase_words.to_edge(LEFT)
|
|
|
|
mobs = [b, w_terms[0], a_terms[0]]
|
|
for words, mob in zip(increase_words, mobs):
|
|
self.play(
|
|
Write(words[0], run_time = 1),
|
|
ReplacementTransform(mob.copy(), words[1])
|
|
)
|
|
self.wait()
|
|
|
|
self.increase_words = increase_words
|
|
|
|
def note_connections_to_brightest_neurons(self):
|
|
w_terms = self.w_terms
|
|
a_terms = self.a_terms
|
|
increase_words = self.increase_words
|
|
prev_neurons = self.network_mob.layers[-2].neurons
|
|
edges = self.two_neuron.edges_in
|
|
|
|
prev_activations = np.array([n.get_fill_opacity() for n in prev_neurons])
|
|
sorted_indices = np.argsort(prev_activations.flatten())
|
|
bright_neurons = VGroup()
|
|
dim_neurons = VGroup()
|
|
edges_to_bright_neurons = VGroup()
|
|
for i in sorted_indices[:5]:
|
|
dim_neurons.add(prev_neurons[i])
|
|
for i in sorted_indices[-4:]:
|
|
bright_neurons.add(prev_neurons[i])
|
|
edges_to_bright_neurons.add(edges[i])
|
|
bright_edges = edges_to_bright_neurons.copy()
|
|
bright_edges.set_stroke(YELLOW, 4)
|
|
|
|
added_words = TextMobject("in proportion to $a_i$")
|
|
added_words.next_to(
|
|
increase_words[1], DOWN,
|
|
1.5*SMALL_BUFF, LEFT
|
|
)
|
|
added_words.set_color(YELLOW)
|
|
|
|
terms_rect = SurroundingRectangle(
|
|
VGroup(w_terms[0], a_terms[0]),
|
|
color = WHITE
|
|
)
|
|
|
|
self.play(OldLaggedStart(
|
|
ApplyFunction, edges,
|
|
lambda edge : (
|
|
lambda m : m.rotate_in_place(np.pi/12).set_stroke(YELLOW),
|
|
edge
|
|
),
|
|
rate_func = wiggle
|
|
))
|
|
self.wait()
|
|
self.play(
|
|
ShowCreation(bright_edges),
|
|
ShowCreation(bright_neurons)
|
|
)
|
|
self.play(OldLaggedStart(
|
|
ApplyMethod, bright_neurons,
|
|
lambda m : (m.shift, MED_LARGE_BUFF*LEFT),
|
|
rate_func = there_and_back
|
|
))
|
|
self.wait()
|
|
self.play(
|
|
ReplacementTransform(bright_edges[0].copy(), w_terms[0]),
|
|
ReplacementTransform(bright_neurons[0].copy(), a_terms[0]),
|
|
ShowCreation(terms_rect)
|
|
)
|
|
self.wait()
|
|
for x in range(2):
|
|
self.play(OldLaggedStart(ShowCreationThenDestruction, bright_edges))
|
|
self.play(OldLaggedStart(ShowCreation, bright_edges))
|
|
self.play(OldLaggedStart(
|
|
ApplyMethod, dim_neurons,
|
|
lambda m : (m.shift, MED_LARGE_BUFF*LEFT),
|
|
rate_func = there_and_back
|
|
))
|
|
self.play(FadeOut(terms_rect))
|
|
self.wait()
|
|
self.play(
|
|
self.curr_image.shift, MED_LARGE_BUFF*RIGHT,
|
|
rate_func = wiggle
|
|
)
|
|
self.wait()
|
|
self.play(Write(added_words))
|
|
self.wait()
|
|
|
|
self.set_variables_as_attrs(
|
|
bright_neurons, bright_edges,
|
|
in_proportion_to_a = added_words
|
|
)
|
|
|
|
def fire_together_wire_together(self):
|
|
bright_neurons = self.bright_neurons
|
|
bright_edges = self.bright_edges
|
|
two_neuron = self.two_neuron
|
|
two_decimal = self.two_decimal
|
|
two_activation = two_decimal.number
|
|
|
|
def get_edge_animation():
|
|
return OldLaggedStart(
|
|
ShowCreationThenDestruction, bright_edges,
|
|
lag_ratio = 0.7
|
|
)
|
|
neuron_arrows = VGroup(*[
|
|
Vector(MED_LARGE_BUFF*RIGHT).next_to(n, LEFT)
|
|
for n in bright_neurons
|
|
])
|
|
two_neuron_arrow = Vector(MED_LARGE_BUFF*DOWN)
|
|
two_neuron_arrow.next_to(two_neuron, UP)
|
|
VGroup(neuron_arrows, two_neuron_arrow).set_color(YELLOW)
|
|
|
|
neuron_rects = VGroup(*list(map(
|
|
SurroundingRectangle, bright_neurons
|
|
)))
|
|
two_neuron_rect = SurroundingRectangle(two_neuron)
|
|
seeing_words = TextMobject("Seeing a 2")
|
|
seeing_words.scale(0.8)
|
|
thinking_words = TextMobject("Thinking about a 2")
|
|
thinking_words.scale(0.8)
|
|
seeing_words.next_to(neuron_rects, UP)
|
|
thinking_words.next_to(two_neuron_arrow, RIGHT)
|
|
|
|
morty = Mortimer()
|
|
morty.scale(0.8)
|
|
morty.to_corner(DOWN+RIGHT)
|
|
words = TextMobject("""
|
|
``Neurons that \\\\
|
|
fire together \\\\
|
|
wire together''
|
|
""")
|
|
words.to_edge(RIGHT)
|
|
|
|
self.play(FadeIn(morty))
|
|
self.play(
|
|
Write(words),
|
|
morty.change, "speaking", words
|
|
)
|
|
self.play(Blink(morty))
|
|
self.play(
|
|
get_edge_animation(),
|
|
morty.change, "pondering", bright_edges
|
|
)
|
|
self.play(get_edge_animation())
|
|
self.play(
|
|
OldLaggedStart(GrowArrow, neuron_arrows),
|
|
get_edge_animation(),
|
|
)
|
|
self.play(
|
|
GrowArrow(two_neuron_arrow),
|
|
morty.change, "raise_right_hand", two_neuron
|
|
)
|
|
self.play(
|
|
ApplyMethod(two_neuron.set_fill, WHITE, 1),
|
|
ChangingDecimal(
|
|
two_decimal,
|
|
lambda a : interpolate(two_activation, 1, a),
|
|
num_decimal_places = 1,
|
|
),
|
|
UpdateFromFunc(
|
|
two_decimal,
|
|
lambda m : m.set_color(WHITE if m.number < 0.8 else BLACK),
|
|
),
|
|
OldLaggedStart(ShowCreation, bright_edges),
|
|
run_time = 2,
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
OldLaggedStart(ShowCreation, neuron_rects),
|
|
Write(seeing_words, run_time = 2),
|
|
morty.change, "thinking", seeing_words
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
ShowCreation(two_neuron_rect),
|
|
Write(thinking_words, run_time = 2),
|
|
morty.look_at, thinking_words
|
|
)
|
|
self.wait()
|
|
self.play(OldLaggedStart(FadeOut, VGroup(
|
|
neuron_rects, two_neuron_rect,
|
|
seeing_words, thinking_words,
|
|
words, morty,
|
|
neuron_arrows, two_neuron_arrow,
|
|
bright_edges,
|
|
)))
|
|
self.play(
|
|
ApplyMethod(two_neuron.set_fill, WHITE, two_activation),
|
|
ChangingDecimal(
|
|
two_decimal,
|
|
lambda a : interpolate(1, two_activation, a),
|
|
num_decimal_places = 1,
|
|
),
|
|
UpdateFromFunc(
|
|
two_decimal,
|
|
lambda m : m.set_color(WHITE if m.number < 0.8 else BLACK),
|
|
),
|
|
)
|
|
|
|
def show_desired_increase_to_previous_neurons(self):
|
|
increase_words = self.increase_words
|
|
two_neuron = self.two_neuron
|
|
two_decimal = self.two_decimal
|
|
edges = two_neuron.edges_in
|
|
prev_neurons = self.network_mob.layers[-2].neurons
|
|
|
|
positive_arrows = VGroup()
|
|
negative_arrows = VGroup()
|
|
all_arrows = VGroup()
|
|
positive_edges = VGroup()
|
|
negative_edges = VGroup()
|
|
positive_neurons = VGroup()
|
|
negative_neurons = VGroup()
|
|
for neuron, edge in zip(prev_neurons, edges):
|
|
value = self.get_edge_value(edge)
|
|
arrow = self.get_neuron_nudge_arrow(edge)
|
|
arrow.move_to(neuron.get_left())
|
|
arrow.shift(SMALL_BUFF*LEFT)
|
|
all_arrows.add(arrow)
|
|
if value > 0:
|
|
positive_arrows.add(arrow)
|
|
positive_edges.add(edge)
|
|
positive_neurons.add(neuron)
|
|
else:
|
|
negative_arrows.add(arrow)
|
|
negative_edges.add(edge)
|
|
negative_neurons.add(neuron)
|
|
for s_edges in positive_edges, negative_edges:
|
|
s_edges.alt_position = VGroup(*[
|
|
Line(LEFT, RIGHT, color = s_edge.get_color())
|
|
for s_edge in s_edges
|
|
])
|
|
s_edges.alt_position.arrange(DOWN, MED_SMALL_BUFF)
|
|
s_edges.alt_position.to_corner(DOWN+RIGHT, LARGE_BUFF)
|
|
|
|
added_words = TextMobject("in proportion to $w_i$")
|
|
added_words.set_color(self.w_terms.get_color())
|
|
added_words.next_to(
|
|
increase_words[-1], DOWN,
|
|
SMALL_BUFF, aligned_edge = LEFT
|
|
)
|
|
|
|
self.play(OldLaggedStart(
|
|
ApplyFunction, prev_neurons,
|
|
lambda neuron : (
|
|
lambda m : m.scale_in_place(0.5).set_color(YELLOW),
|
|
neuron
|
|
),
|
|
rate_func = wiggle
|
|
))
|
|
self.wait()
|
|
for positive in [True, False]:
|
|
if positive:
|
|
arrows = positive_arrows
|
|
s_edges = positive_edges
|
|
neurons = positive_neurons
|
|
color = self.positive_edge_color
|
|
else:
|
|
arrows = negative_arrows
|
|
s_edges = negative_edges
|
|
neurons = negative_neurons
|
|
color = self.negative_edge_color
|
|
s_edges.save_state()
|
|
self.play(Transform(s_edges, s_edges.alt_position))
|
|
self.wait(0.5)
|
|
self.play(s_edges.restore)
|
|
self.play(
|
|
OldLaggedStart(GrowArrow, arrows),
|
|
neurons.set_stroke, color
|
|
)
|
|
self.play(ApplyMethod(
|
|
neurons.set_fill, color, 1,
|
|
rate_func = there_and_back,
|
|
))
|
|
self.wait()
|
|
self.play(
|
|
two_neuron.set_fill, None, 0.8,
|
|
ChangingDecimal(
|
|
two_decimal,
|
|
lambda a : two_neuron.get_fill_opacity()
|
|
),
|
|
run_time = 3,
|
|
rate_func = there_and_back
|
|
)
|
|
self.wait()
|
|
self.play(*[
|
|
ApplyMethod(
|
|
edge.set_stroke, None, 3*edge.get_stroke_width(),
|
|
rate_func = there_and_back,
|
|
run_time = 2
|
|
)
|
|
for edge in edges
|
|
])
|
|
self.wait()
|
|
self.play(Write(added_words, run_time = 1))
|
|
self.play(prev_neurons.set_stroke, WHITE, 2)
|
|
|
|
self.set_variables_as_attrs(
|
|
in_proportion_to_w = added_words,
|
|
prev_neuron_arrows = all_arrows,
|
|
)
|
|
|
|
def only_keeping_track_of_changes(self):
|
|
arrows = self.prev_neuron_arrows
|
|
prev_neurons = self.network_mob.layers[-2].neurons
|
|
rect = SurroundingRectangle(VGroup(arrows, prev_neurons))
|
|
|
|
words1 = TextMobject("No direct influence")
|
|
words1.next_to(rect, UP)
|
|
words2 = TextMobject("Just keeping track")
|
|
words2.move_to(words1)
|
|
|
|
edges = self.network_mob.edge_groups[-2]
|
|
|
|
self.play(ShowCreation(rect))
|
|
self.play(Write(words1))
|
|
self.play(OldLaggedStart(
|
|
Indicate, prev_neurons,
|
|
rate_func = wiggle
|
|
))
|
|
self.wait()
|
|
self.play(OldLaggedStart(
|
|
ShowCreationThenDestruction, edges
|
|
))
|
|
self.play(Transform(words1, words2))
|
|
self.wait()
|
|
self.play(FadeOut(VGroup(words1, rect)))
|
|
|
|
def show_other_output_neurons(self):
|
|
two_neuron = self.two_neuron
|
|
two_decimal = self.two_decimal
|
|
two_arrow = self.two_arrow
|
|
two_label = self.two_label
|
|
two_edges = two_neuron.edges_in
|
|
|
|
prev_neurons = self.network_mob.layers[-2].neurons
|
|
neurons = self.network_mob.layers[-1].neurons
|
|
prev_neuron_arrows = self.prev_neuron_arrows
|
|
arrows_to_fade = VGroup(prev_neuron_arrows)
|
|
|
|
output_labels = self.network_mob.output_labels
|
|
quads = list(zip(neurons, self.decimals, self.arrows, output_labels))
|
|
|
|
self.revert_to_original_skipping_status()
|
|
self.play(
|
|
two_neuron.restore,
|
|
two_decimal.scale, 0.5,
|
|
two_decimal.move_to, two_neuron.saved_state,
|
|
two_arrow.scale, 0.5,
|
|
two_arrow.next_to, two_neuron.saved_state, RIGHT, 0.5*SMALL_BUFF,
|
|
two_label.scale, 0.5,
|
|
two_label.next_to, two_neuron.saved_state, RIGHT, 1.5*SMALL_BUFF,
|
|
FadeOut(VGroup(self.lhs, self.rhs)),
|
|
*[e.restore for e in two_edges]
|
|
)
|
|
for neuron, decimal, arrow, label in quads[:2] + quads[2:5]:
|
|
plusses = VGroup()
|
|
new_arrows = VGroup()
|
|
for edge, prev_arrow in zip(neuron.edges_in, prev_neuron_arrows):
|
|
plus = TexMobject("+").scale(0.5)
|
|
plus.move_to(prev_arrow)
|
|
plus.shift(2*SMALL_BUFF*LEFT)
|
|
new_arrow = self.get_neuron_nudge_arrow(edge)
|
|
new_arrow.move_to(plus)
|
|
new_arrow.shift(2*SMALL_BUFF*LEFT)
|
|
plusses.add(plus)
|
|
new_arrows.add(new_arrow)
|
|
|
|
self.play(
|
|
FadeIn(VGroup(neuron, decimal, arrow, label)),
|
|
OldLaggedStart(ShowCreation, neuron.edges_in),
|
|
)
|
|
self.play(
|
|
ReplacementTransform(neuron.edges_in.copy(), new_arrows),
|
|
Write(plusses, run_time = 2)
|
|
)
|
|
|
|
arrows_to_fade.add(new_arrows, plusses)
|
|
prev_neuron_arrows = new_arrows
|
|
|
|
all_dots_plus = VGroup()
|
|
for arrow in prev_neuron_arrows:
|
|
dots_plus = TexMobject("\\cdots +")
|
|
dots_plus.scale(0.5)
|
|
dots_plus.move_to(arrow.get_center(), RIGHT)
|
|
dots_plus.shift(2*SMALL_BUFF*LEFT)
|
|
all_dots_plus.add(dots_plus)
|
|
arrows_to_fade.add(all_dots_plus)
|
|
|
|
self.play(
|
|
OldLaggedStart(
|
|
FadeIn, VGroup(*it.starmap(VGroup, quads[5:])),
|
|
),
|
|
OldLaggedStart(
|
|
FadeIn, VGroup(*[n.edges_in for n in neurons[5:]])
|
|
),
|
|
Write(all_dots_plus),
|
|
run_time = 3,
|
|
)
|
|
self.wait(2)
|
|
|
|
##
|
|
words = TextMobject("Propagate backwards")
|
|
words.to_edge(UP)
|
|
words.set_color(BLUE)
|
|
target_arrows = prev_neuron_arrows.copy()
|
|
target_arrows.next_to(prev_neurons, RIGHT, SMALL_BUFF)
|
|
rect = SurroundingRectangle(VGroup(
|
|
self.network_mob.layers[-1],
|
|
self.network_mob.output_labels
|
|
))
|
|
rect.set_fill(BLACK, 1)
|
|
rect.set_stroke(BLACK, 0)
|
|
self.play(Write(words))
|
|
self.wait()
|
|
self.play(
|
|
FadeOut(self.network_mob.edge_groups[-1]),
|
|
FadeIn(rect),
|
|
ReplacementTransform(arrows_to_fade, VGroup(target_arrows)),
|
|
)
|
|
self.prev_neuron_arrows = target_arrows
|
|
|
|
def show_recursion(self):
|
|
network_mob = self.network_mob
|
|
words_to_fade = VGroup(
|
|
self.increase_words,
|
|
self.in_proportion_to_w,
|
|
self.in_proportion_to_a,
|
|
)
|
|
edges = network_mob.edge_groups[1]
|
|
neurons = network_mob.layers[2].neurons
|
|
prev_neurons = network_mob.layers[1].neurons
|
|
for neuron in neurons:
|
|
neuron.edges_in.save_state()
|
|
|
|
self.play(
|
|
FadeOut(words_to_fade),
|
|
FadeIn(prev_neurons),
|
|
OldLaggedStart(ShowCreation, edges),
|
|
)
|
|
self.wait()
|
|
for neuron, arrow in zip(neurons, self.prev_neuron_arrows):
|
|
edge_copies = neuron.edges_in.copy()
|
|
for edge in edge_copies:
|
|
edge.set_stroke(arrow.get_color(), 2)
|
|
edge.rotate_in_place(np.pi)
|
|
self.play(
|
|
edges.set_stroke, None, 0.15,
|
|
neuron.edges_in.restore,
|
|
)
|
|
self.play(ShowCreationThenDestruction(edge_copies))
|
|
self.remove(edge_copies)
|
|
|
|
|
|
####
|
|
|
|
def get_neuron_nudge_arrow(self, edge):
|
|
value = self.get_edge_value(edge)
|
|
height = np.sign(value)*0.1 + 0.1*value
|
|
arrow = Vector(height*UP, color = edge.get_color())
|
|
return arrow
|
|
|
|
def get_edge_value(self, edge):
|
|
value = edge.get_stroke_width()
|
|
if Color(edge.get_stroke_color()) == Color(self.negative_edge_color):
|
|
value *= -1
|
|
return value
|
|
|
|
class WriteHebbian(Scene):
|
|
def construct(self):
|
|
words = TextMobject("Hebbian theory")
|
|
words.set_width(FRAME_WIDTH - 1)
|
|
words.to_edge(UP)
|
|
self.play(Write(words))
|
|
self.wait()
|
|
|
|
class NotANeuroScientist(TeacherStudentsScene):
|
|
def construct(self):
|
|
quote = TextMobject("``Neurons that fire together wire together''")
|
|
quote.to_edge(UP)
|
|
self.add(quote)
|
|
asterisks = TextMobject("***")
|
|
asterisks.next_to(quote.get_corner(UP+RIGHT), RIGHT, SMALL_BUFF)
|
|
asterisks.set_color(BLUE)
|
|
|
|
brain = SVGMobject(file_name = "brain")
|
|
brain.set_height(1.5)
|
|
self.add(brain)
|
|
double_arrow = DoubleArrow(LEFT, RIGHT)
|
|
double_arrow.next_to(brain, RIGHT)
|
|
q_marks = TextMobject("???")
|
|
q_marks.next_to(double_arrow, UP)
|
|
|
|
network = NetworkMobject(Network(sizes = [6, 4, 4, 5]))
|
|
network.set_height(1.5)
|
|
network.next_to(double_arrow, RIGHT)
|
|
|
|
group = VGroup(brain, double_arrow, q_marks, network)
|
|
group.next_to(self.students, UP, buff = 1.5)
|
|
self.add(group)
|
|
self.add(ContinualEdgeUpdate(
|
|
network,
|
|
stroke_width_exp = 0.5,
|
|
color = [BLUE, RED],
|
|
))
|
|
|
|
rect = SurroundingRectangle(group)
|
|
no_claim_words = TextMobject("No claims here...")
|
|
no_claim_words.next_to(rect, UP)
|
|
no_claim_words.set_color(YELLOW)
|
|
|
|
brain_outline = brain.copy()
|
|
brain_outline.set_fill(opacity = 0)
|
|
brain_outline.set_stroke(BLUE, 3)
|
|
brain_anim = ShowCreationThenDestruction(brain_outline)
|
|
|
|
words = TextMobject("Definitely not \\\\ a neuroscientist")
|
|
words.next_to(self.teacher, UP, buff = 1.5)
|
|
words.shift_onto_screen()
|
|
arrow = Arrow(words.get_bottom(), self.teacher.get_top())
|
|
|
|
self.play(
|
|
Write(words),
|
|
GrowArrow(arrow),
|
|
self.teacher.change, "guilty", words,
|
|
run_time = 1,
|
|
)
|
|
self.change_student_modes(*3*["sassy"])
|
|
self.play(
|
|
ShowCreation(rect),
|
|
Write(no_claim_words, run_time = 1),
|
|
brain_anim
|
|
)
|
|
self.wait()
|
|
self.play(brain_anim)
|
|
self.play(FocusOn(asterisks))
|
|
self.play(Write(asterisks, run_time = 1))
|
|
for x in range(2):
|
|
self.play(brain_anim)
|
|
self.wait()
|
|
|
|
class ConstructGradientFromAllTrainingExamples(Scene):
|
|
CONFIG = {
|
|
"image_height" : 0.9,
|
|
"eyes_height" : 0.25,
|
|
"n_examples" : 6,
|
|
"change_scale_val" : 0.8,
|
|
}
|
|
def construct(self):
|
|
self.setup_grid()
|
|
self.setup_weights()
|
|
self.show_two_requesting_changes()
|
|
self.show_all_examples_requesting_changes()
|
|
self.average_together()
|
|
self.collapse_into_gradient_vector()
|
|
|
|
def setup_grid(self):
|
|
h_lines = VGroup(*[
|
|
Line(LEFT, RIGHT).scale(0.85*FRAME_X_RADIUS)
|
|
for x in range(6)
|
|
])
|
|
h_lines.arrange(DOWN, buff = 1)
|
|
h_lines.set_stroke(LIGHT_GREY, 2)
|
|
h_lines.to_edge(DOWN, buff = MED_LARGE_BUFF)
|
|
h_lines.to_edge(LEFT, buff = 0)
|
|
|
|
v_lines = VGroup(*[
|
|
Line(UP, DOWN).scale(FRAME_Y_RADIUS - MED_LARGE_BUFF)
|
|
for x in range(self.n_examples + 1)
|
|
])
|
|
v_lines.arrange(RIGHT, buff = 1.4)
|
|
v_lines.set_stroke(LIGHT_GREY, 2)
|
|
v_lines.to_edge(LEFT, buff = 2)
|
|
|
|
# self.add(h_lines, v_lines)
|
|
self.h_lines = h_lines
|
|
self.v_lines = v_lines
|
|
|
|
def setup_weights(self):
|
|
weights = VGroup(*list(map(TexMobject, [
|
|
"w_0", "w_1", "w_2", "\\vdots", "w_{13{,}001}"
|
|
])))
|
|
for i, weight in enumerate(weights):
|
|
weight.move_to(self.get_grid_position(i, 0))
|
|
weights.to_edge(LEFT, buff = MED_SMALL_BUFF)
|
|
|
|
brace = Brace(weights, RIGHT)
|
|
weights_words = brace.get_text("All weights and biases")
|
|
|
|
self.add(weights, brace, weights_words)
|
|
self.set_variables_as_attrs(
|
|
weights, brace, weights_words,
|
|
dots = weights[-2]
|
|
)
|
|
|
|
def show_two_requesting_changes(self):
|
|
two = self.get_example(get_organized_images()[2][0], 0)
|
|
self.two = two
|
|
self.add(two)
|
|
|
|
self.two_changes = VGroup()
|
|
for i in list(range(3)) + [4]:
|
|
weight = self.weights[i]
|
|
bubble, change = self.get_requested_change_bubble(two)
|
|
weight.save_state()
|
|
weight.generate_target()
|
|
weight.target.next_to(two, RIGHT, aligned_edge = DOWN)
|
|
|
|
self.play(
|
|
MoveToTarget(weight),
|
|
two.eyes.look_at_anim(weight.target),
|
|
FadeIn(bubble),
|
|
Write(change, run_time = 1),
|
|
)
|
|
if random.random() < 0.5:
|
|
self.play(two.eyes.blink_anim())
|
|
else:
|
|
self.wait()
|
|
if i == 0:
|
|
added_anims = [
|
|
FadeOut(self.brace),
|
|
FadeOut(self.weights_words),
|
|
]
|
|
elif i == 4:
|
|
dots_copy = self.dots.copy()
|
|
added_anims = [
|
|
dots_copy.move_to,
|
|
self.get_grid_position(3, 0)
|
|
]
|
|
self.first_column_dots = dots_copy
|
|
else:
|
|
added_anims = []
|
|
self.play(
|
|
FadeOut(bubble),
|
|
weight.restore,
|
|
two.eyes.look_at_anim(weight.saved_state),
|
|
change.restore,
|
|
change.scale, self.change_scale_val,
|
|
change.move_to, self.get_grid_position(i, 0),
|
|
*added_anims
|
|
)
|
|
self.two_changes.add(change)
|
|
self.wait()
|
|
|
|
def show_all_examples_requesting_changes(self):
|
|
training_data, validation_data, test_data = load_data_wrapper()
|
|
data = training_data[:self.n_examples-1]
|
|
examples = VGroup(*[
|
|
self.get_example(t[0], j)
|
|
for t, j in zip(data, it.count(1))
|
|
])
|
|
h_dots = TexMobject("\\dots")
|
|
h_dots.next_to(examples, RIGHT, MED_LARGE_BUFF)
|
|
more_h_dots = VGroup(*[
|
|
TexMobject("\\dots").move_to(
|
|
self.get_grid_position(i, self.n_examples)
|
|
)
|
|
for i in range(5)
|
|
])
|
|
more_h_dots.shift(MED_LARGE_BUFF*RIGHT)
|
|
more_h_dots[-2].rotate_in_place(-np.pi/4)
|
|
more_v_dots = VGroup(*[
|
|
self.dots.copy().move_to(
|
|
self.get_grid_position(3, j)
|
|
)
|
|
for j in range(1, self.n_examples)
|
|
])
|
|
|
|
changes = VGroup(*[
|
|
self.get_random_decimal().move_to(
|
|
self.get_grid_position(i, j)
|
|
)
|
|
for i in list(range(3)) + [4]
|
|
for j in range(1, self.n_examples)
|
|
])
|
|
for change in changes:
|
|
change.scale_in_place(self.change_scale_val)
|
|
|
|
self.play(
|
|
OldLaggedStart(FadeIn, examples),
|
|
OldLaggedStart(ShowCreation, self.h_lines),
|
|
OldLaggedStart(ShowCreation, self.v_lines),
|
|
Write(
|
|
h_dots,
|
|
run_time = 2,
|
|
rate_func = squish_rate_func(smooth, 0.7, 1)
|
|
)
|
|
)
|
|
self.play(
|
|
Write(changes),
|
|
Write(more_v_dots),
|
|
Write(more_h_dots),
|
|
*[
|
|
example.eyes.look_at_anim(random.choice(changes))
|
|
for example in examples
|
|
]
|
|
)
|
|
for x in range(2):
|
|
self.play(random.choice(examples).eyes.blink_anim())
|
|
|
|
k = self.n_examples - 1
|
|
self.change_rows = VGroup(*[
|
|
VGroup(two_change, *changes[k*i:k*(i+1)])
|
|
for i, two_change in enumerate(self.two_changes)
|
|
])
|
|
for i in list(range(3)) + [-1]:
|
|
self.change_rows[i].add(more_h_dots[i])
|
|
|
|
self.all_eyes = VGroup(*[
|
|
m.eyes for m in [self.two] + list(examples)
|
|
])
|
|
|
|
self.set_variables_as_attrs(
|
|
more_h_dots, more_v_dots,
|
|
h_dots, changes,
|
|
)
|
|
|
|
def average_together(self):
|
|
rects = VGroup()
|
|
arrows = VGroup()
|
|
averages = VGroup()
|
|
for row in self.change_rows:
|
|
rect = SurroundingRectangle(row)
|
|
arrow = Arrow(ORIGIN, RIGHT)
|
|
arrow.next_to(rect, RIGHT)
|
|
rect.arrow = arrow
|
|
average = self.get_colored_decimal(3*np.mean([
|
|
m.number for m in row
|
|
if isinstance(m, DecimalNumber)
|
|
]))
|
|
average.scale(self.change_scale_val)
|
|
average.next_to(arrow, RIGHT)
|
|
row.target = VGroup(average)
|
|
|
|
rects.add(rect)
|
|
arrows.add(arrow)
|
|
averages.add(average)
|
|
|
|
words = TextMobject("Average over \\\\ all training data")
|
|
words.scale(0.8)
|
|
words.to_corner(UP+RIGHT)
|
|
arrow_to_averages = Arrow(
|
|
words.get_bottom(), averages.get_top(),
|
|
color = WHITE
|
|
)
|
|
|
|
dots = self.dots.copy()
|
|
dots.move_to(VGroup(*averages[-2:]))
|
|
|
|
look_at_anims = self.get_look_at_anims
|
|
|
|
self.play(Write(words, run_time = 1), *look_at_anims(words))
|
|
self.play(ShowCreation(rects[0]), *look_at_anims(rects[0]))
|
|
self.play(
|
|
ReplacementTransform(rects[0].copy(), arrows[0]),
|
|
rects[0].set_stroke, WHITE, 1,
|
|
ReplacementTransform(
|
|
self.change_rows[0].copy(),
|
|
self.change_rows[0].target
|
|
),
|
|
*look_at_anims(averages[0])
|
|
)
|
|
self.play(GrowArrow(arrow_to_averages))
|
|
self.play(
|
|
OldLaggedStart(ShowCreation, VGroup(*rects[1:])),
|
|
*look_at_anims(rects[1])
|
|
)
|
|
self.play(
|
|
OldLaggedStart(
|
|
ReplacementTransform, VGroup(*rects[1:]).copy(),
|
|
lambda m : (m, m.arrow),
|
|
lag_ratio = 0.7,
|
|
),
|
|
VGroup(*rects[1:]).set_stroke, WHITE, 1,
|
|
OldLaggedStart(
|
|
ReplacementTransform, VGroup(*self.change_rows[1:]).copy(),
|
|
lambda m : (m, m.target),
|
|
lag_ratio = 0.7,
|
|
),
|
|
Write(dots),
|
|
*look_at_anims(averages[1])
|
|
)
|
|
self.blink(3)
|
|
self.wait()
|
|
|
|
averages.add(dots)
|
|
self.set_variables_as_attrs(
|
|
rects, arrows, averages,
|
|
arrow_to_averages
|
|
)
|
|
|
|
def collapse_into_gradient_vector(self):
|
|
averages = self.averages
|
|
lb, rb = brackets = TexMobject("[]")
|
|
brackets.scale(2)
|
|
brackets.stretch_to_fit_height(1.2*averages.get_height())
|
|
lb.next_to(averages, LEFT, SMALL_BUFF)
|
|
rb.next_to(averages, RIGHT, SMALL_BUFF)
|
|
brackets.set_fill(opacity = 0)
|
|
|
|
shift_vect = 2*LEFT
|
|
|
|
lhs = TexMobject(
|
|
"-", "\\nabla", "C(",
|
|
"w_1,", "w_2,", "\\dots", "w_{13{,}001}",
|
|
")", "="
|
|
)
|
|
lhs.next_to(lb, LEFT)
|
|
lhs.shift(shift_vect)
|
|
minus = lhs[0]
|
|
w_terms = lhs.get_parts_by_tex("w_")
|
|
dots_term = lhs.get_part_by_tex("dots")
|
|
eta = TexMobject("\\eta")
|
|
eta.move_to(minus, RIGHT)
|
|
eta.set_color(MAROON_B)
|
|
|
|
to_fade = VGroup(*it.chain(
|
|
self.h_lines, self.v_lines,
|
|
self.more_h_dots, self.more_v_dots,
|
|
self.change_rows,
|
|
self.first_column_dots,
|
|
self.rects,
|
|
self.arrows,
|
|
))
|
|
arrow = self.arrow_to_averages
|
|
|
|
self.play(OldLaggedStart(FadeOut, to_fade))
|
|
self.play(
|
|
brackets.shift, shift_vect,
|
|
brackets.set_fill, WHITE, 1,
|
|
averages.shift, shift_vect,
|
|
Transform(arrow, Arrow(
|
|
arrow.get_start(),
|
|
arrow.get_end() + shift_vect,
|
|
buff = 0,
|
|
color = arrow.get_color(),
|
|
)),
|
|
FadeIn(VGroup(*lhs[:3])),
|
|
FadeIn(VGroup(*lhs[-2:])),
|
|
*self.get_look_at_anims(lhs)
|
|
)
|
|
self.play(
|
|
ReplacementTransform(self.weights, w_terms),
|
|
ReplacementTransform(self.dots, dots_term),
|
|
*self.get_look_at_anims(w_terms)
|
|
)
|
|
self.blink(2)
|
|
self.play(
|
|
GrowFromCenter(eta),
|
|
minus.shift, MED_SMALL_BUFF*LEFT
|
|
)
|
|
self.wait()
|
|
|
|
####
|
|
|
|
def get_example(self, in_vect, index):
|
|
result = MNistMobject(in_vect)
|
|
result.set_height(self.image_height)
|
|
|
|
eyes = Eyes(result, height = self.eyes_height)
|
|
result.eyes = eyes
|
|
result.add(eyes)
|
|
result.move_to(self.get_grid_position(0, index))
|
|
result.to_edge(UP, buff = LARGE_BUFF)
|
|
return result
|
|
|
|
def get_grid_position(self, i, j):
|
|
x = VGroup(*self.v_lines[j:j+2]).get_center()[0]
|
|
y = VGroup(*self.h_lines[i:i+2]).get_center()[1]
|
|
return x*RIGHT + y*UP
|
|
|
|
def get_requested_change_bubble(self, example_mob):
|
|
change = self.get_random_decimal()
|
|
words = TextMobject("Change by")
|
|
change.next_to(words, RIGHT)
|
|
change.save_state()
|
|
content = VGroup(words, change)
|
|
|
|
bubble = SpeechBubble(height = 1.5, width = 3)
|
|
bubble.add_content(content)
|
|
group = VGroup(bubble, content)
|
|
group.shift(
|
|
example_mob.get_right() + SMALL_BUFF*RIGHT \
|
|
-bubble.get_corner(DOWN+LEFT)
|
|
)
|
|
|
|
return VGroup(bubble, words), change
|
|
|
|
def get_random_decimal(self):
|
|
return self.get_colored_decimal(
|
|
0.3*(random.random() - 0.5)
|
|
)
|
|
|
|
def get_colored_decimal(self, number):
|
|
result = DecimalNumber(number)
|
|
if result.number > 0:
|
|
plus = TexMobject("+")
|
|
plus.next_to(result, LEFT, SMALL_BUFF)
|
|
result.add_to_back(plus)
|
|
result.set_color(BLUE)
|
|
else:
|
|
result.set_color(RED)
|
|
return result
|
|
|
|
def get_look_at_anims(self, mob):
|
|
return [eyes.look_at_anim(mob) for eyes in self.all_eyes]
|
|
|
|
def blink(self, n):
|
|
for x in range(n):
|
|
self.play(random.choice(self.all_eyes).blink_anim())
|
|
|
|
class WatchPreviousScene(TeacherStudentsScene):
|
|
def construct(self):
|
|
screen = ScreenRectangle(height = 4.5)
|
|
screen.to_corner(UP+LEFT)
|
|
|
|
self.play(
|
|
self.teacher.change, "raise_right_hand", screen,
|
|
self.get_student_changes(
|
|
*["thinking"]*3,
|
|
look_at_arg = screen
|
|
),
|
|
ShowCreation(screen)
|
|
)
|
|
self.wait(10)
|
|
|
|
class OpenCloseSGD(Scene):
|
|
def construct(self):
|
|
term = TexMobject(
|
|
"\\langle", "\\text{Stochastic gradient descent}",
|
|
"\\rangle"
|
|
)
|
|
alt_term0 = TexMobject("\\langle /")
|
|
alt_term0.move_to(term[0], RIGHT)
|
|
|
|
term.save_state()
|
|
center = term.get_center()
|
|
term[0].move_to(center, RIGHT)
|
|
term[2].move_to(center, LEFT)
|
|
term[1].scale(0.0001).move_to(center)
|
|
|
|
self.play(term.restore)
|
|
self.wait(2)
|
|
self.play(Transform(term[0], alt_term0))
|
|
self.wait(2)
|
|
|
|
class OrganizeDataIntoMiniBatches(Scene):
|
|
CONFIG = {
|
|
"n_rows" : 5,
|
|
"n_cols" : 12,
|
|
"example_height" : 1,
|
|
"random_seed" : 0,
|
|
}
|
|
def construct(self):
|
|
self.seed_random_libraries()
|
|
self.add_examples()
|
|
self.shuffle_examples()
|
|
self.divide_into_minibatches()
|
|
self.one_step_per_batch()
|
|
|
|
def seed_random_libraries(self):
|
|
random.seed(self.random_seed)
|
|
np.random.seed(self.random_seed)
|
|
|
|
def add_examples(self):
|
|
examples = self.get_examples()
|
|
self.arrange_examples_in_grid(examples)
|
|
for example in examples:
|
|
example.save_state()
|
|
alt_order_examples = VGroup(*examples)
|
|
for mob in examples, alt_order_examples:
|
|
random.shuffle(mob.submobjects)
|
|
self.arrange_examples_in_grid(examples)
|
|
|
|
self.play(OldLaggedStart(
|
|
FadeIn, alt_order_examples,
|
|
lag_ratio = 0.2,
|
|
run_time = 4
|
|
))
|
|
self.wait()
|
|
|
|
self.examples = examples
|
|
|
|
def shuffle_examples(self):
|
|
self.play(OldLaggedStart(
|
|
ApplyMethod, self.examples,
|
|
lambda m : (m.restore,),
|
|
lag_ratio = 0.3,
|
|
run_time = 3,
|
|
path_arc = np.pi/3,
|
|
))
|
|
self.wait()
|
|
|
|
def divide_into_minibatches(self):
|
|
examples = self.examples
|
|
examples.sort(lambda p : -p[1])
|
|
rows = Group(*[
|
|
Group(*examples[i*self.n_cols:(i+1)*self.n_cols])
|
|
for i in range(self.n_rows)
|
|
])
|
|
|
|
mini_batches_words = TextMobject("``Mini-batches''")
|
|
mini_batches_words.to_edge(UP)
|
|
mini_batches_words.set_color(YELLOW)
|
|
|
|
self.play(
|
|
rows.space_out_submobjects, 1.5,
|
|
rows.to_edge, UP, 1.5,
|
|
Write(mini_batches_words, run_time = 1)
|
|
)
|
|
|
|
rects = VGroup(*[
|
|
SurroundingRectangle(
|
|
row,
|
|
stroke_width = 0,
|
|
fill_color = YELLOW,
|
|
fill_opacity = 0.25,
|
|
)
|
|
for row in rows
|
|
])
|
|
self.play(OldLaggedStart(
|
|
FadeIn, rects,
|
|
lag_ratio = 0.7,
|
|
rate_func = there_and_back
|
|
))
|
|
self.wait()
|
|
|
|
self.set_variables_as_attrs(rows, rects, mini_batches_words)
|
|
|
|
def one_step_per_batch(self):
|
|
rows = self.rows
|
|
brace = Brace(rows[0], UP, buff = SMALL_BUFF)
|
|
text = brace.get_text(
|
|
"Compute gradient descent step (using backprop)",
|
|
buff = SMALL_BUFF
|
|
)
|
|
def indicate_row(row):
|
|
row.sort(lambda p : p[0])
|
|
return OldLaggedStart(
|
|
ApplyFunction, row,
|
|
lambda row : (
|
|
lambda m : m.scale_in_place(0.75).set_color(YELLOW),
|
|
row
|
|
),
|
|
rate_func = wiggle
|
|
)
|
|
|
|
self.play(
|
|
FadeOut(self.mini_batches_words),
|
|
GrowFromCenter(brace),
|
|
Write(text, run_time = 2),
|
|
)
|
|
self.play(indicate_row(rows[0]))
|
|
brace.add(text)
|
|
for last_row, row in zip(rows, rows[1:-1]):
|
|
self.play(
|
|
last_row.shift, UP,
|
|
brace.next_to, row, UP, SMALL_BUFF
|
|
)
|
|
self.play(indicate_row(row))
|
|
self.wait()
|
|
|
|
|
|
###
|
|
|
|
def get_examples(self):
|
|
n_examples = self.n_rows*self.n_cols
|
|
height = self.example_height
|
|
training_data, validation_data, test_data = load_data_wrapper()
|
|
return Group(*[
|
|
MNistMobject(
|
|
t[0],
|
|
rect_kwargs = {"stroke_width" : 2}
|
|
).set_height(height)
|
|
for t in training_data[:n_examples]
|
|
])
|
|
# return Group(*[
|
|
# Square(
|
|
# color = BLUE,
|
|
# stroke_width = 2
|
|
# ).set_height(height)
|
|
# for x in range(n_examples)
|
|
# ])
|
|
|
|
def arrange_examples_in_grid(self, examples):
|
|
examples.arrange_in_grid(
|
|
n_rows = self.n_rows,
|
|
buff = SMALL_BUFF
|
|
)
|
|
|
|
class SGDSteps(ExternallyAnimatedScene):
|
|
pass
|
|
|
|
class GradientDescentSteps(ExternallyAnimatedScene):
|
|
pass
|
|
|
|
class SwimmingInTerms(TeacherStudentsScene):
|
|
def construct(self):
|
|
terms = VGroup(
|
|
TextMobject("Cost surface"),
|
|
TextMobject("Stochastic gradient descent"),
|
|
TextMobject("Mini-batches"),
|
|
TextMobject("Backpropagation"),
|
|
)
|
|
terms.arrange(DOWN)
|
|
terms.to_edge(UP)
|
|
self.play(
|
|
OldLaggedStart(FadeIn, terms),
|
|
self.get_student_changes(*["horrified"]*3)
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
terms[-1].next_to, self.teacher.get_corner(UP+LEFT), UP,
|
|
FadeOut(VGroup(*terms[:-1])),
|
|
self.teacher.change, "raise_right_hand",
|
|
self.get_student_changes(*["pondering"]*3)
|
|
)
|
|
self.wait()
|
|
|
|
class BackpropCode(ExternallyAnimatedScene):
|
|
pass
|
|
|
|
class BackpropCodeAddOn(PiCreatureScene):
|
|
def construct(self):
|
|
words = TextMobject(
|
|
"The code you'd find \\\\ in Nielsen's book"
|
|
)
|
|
words.to_corner(DOWN+LEFT)
|
|
morty = self.pi_creature
|
|
morty.next_to(words, UP)
|
|
self.add(words)
|
|
for mode in ["pondering", "thinking", "happy"]:
|
|
self.play(
|
|
morty.change, "pondering",
|
|
morty.look, UP+LEFT
|
|
)
|
|
self.play(morty.look, LEFT)
|
|
self.wait(2)
|
|
|
|
class CannotFollowCode(TeacherStudentsScene):
|
|
def construct(self):
|
|
self.student_says(
|
|
"I...er...can't follow\\\\ that code at all.",
|
|
target_mode = "confused",
|
|
student_index = 1
|
|
)
|
|
self.play(self.students[1].change, "sad")
|
|
self.change_student_modes(
|
|
"angry", "sad", "angry",
|
|
look_at_arg = self.teacher.eyes
|
|
)
|
|
self.play(self.teacher.change, "hesitant")
|
|
self.wait(2)
|
|
self.teacher_says(
|
|
"Let's get to the \\\\ calculus then",
|
|
target_mode = "hooray",
|
|
added_anims = [self.get_student_changes(*3*["plain"])],
|
|
run_time = 1
|
|
)
|
|
self.wait(2)
|
|
|
|
class EOCWrapper(Scene):
|
|
def construct(self):
|
|
title = TextMobject("Essence of calculus")
|
|
title.to_edge(UP)
|
|
screen = ScreenRectangle(height = 6)
|
|
screen.next_to(title, DOWN)
|
|
|
|
self.add(title)
|
|
self.play(ShowCreation(screen))
|
|
self.wait()
|
|
|
|
class SimplestNetworkExample(PreviewLearning):
|
|
CONFIG = {
|
|
"random_seed" : 6,
|
|
"z_color" : GREEN,
|
|
"cost_color" : RED,
|
|
"desired_output_color" : YELLOW,
|
|
"derivative_scale_val" : 0.85,
|
|
}
|
|
def construct(self):
|
|
self.seed_random_libraries()
|
|
self.collapse_ordinary_network()
|
|
self.show_weights_and_biases()
|
|
self.focus_just_on_last_two_layers()
|
|
self.label_neurons()
|
|
self.show_desired_output()
|
|
self.show_cost()
|
|
self.show_activation_formula()
|
|
self.introduce_z()
|
|
self.break_into_computational_graph()
|
|
self.show_preceding_layer_in_computational_graph()
|
|
self.show_number_lines()
|
|
self.ask_about_w_sensitivity()
|
|
self.show_derivative_wrt_w()
|
|
self.show_chain_of_events()
|
|
self.show_chain_rule()
|
|
self.name_chain_rule()
|
|
self.indicate_everything_on_screen()
|
|
self.prepare_for_derivatives()
|
|
self.compute_derivatives()
|
|
self.get_lost_in_formulas()
|
|
self.fire_together_wire_together()
|
|
self.organize_chain_rule_rhs()
|
|
self.show_average_derivative()
|
|
self.show_gradient()
|
|
self.transition_to_derivative_wrt_b()
|
|
self.show_derivative_wrt_b()
|
|
self.show_derivative_wrt_a()
|
|
self.show_previous_weight_and_bias()
|
|
self.animate_long_path()
|
|
|
|
def seed_random_libraries(self):
|
|
np.random.seed(self.random_seed)
|
|
random.seed(self.random_seed)
|
|
|
|
def collapse_ordinary_network(self):
|
|
network_mob = self.network_mob
|
|
config = dict(self.network_mob_config)
|
|
config.pop("include_output_labels")
|
|
config.update({
|
|
"edge_stroke_width" : 3,
|
|
"edge_propogation_color" : YELLOW,
|
|
"edge_propogation_time" : 1,
|
|
"neuron_radius" : 0.3,
|
|
})
|
|
simple_network = Network(sizes = [1, 1, 1, 1])
|
|
simple_network_mob = NetworkMobject(simple_network, **config)
|
|
self.color_network_edges()
|
|
s_edges = simple_network_mob.edge_groups
|
|
for edge, weight_matrix in zip(s_edges, simple_network.weights):
|
|
weight = weight_matrix[0][0]
|
|
width = 2*abs(weight)
|
|
color = BLUE if weight > 0 else RED
|
|
edge.set_stroke(color, width)
|
|
|
|
def edge_collapse_anims(edges, left_attachment_target):
|
|
return [
|
|
ApplyMethod(
|
|
e.put_start_and_end_on_with_projection,
|
|
left_attachment_target.get_right(),
|
|
e.get_end()
|
|
)
|
|
for e in edges
|
|
]
|
|
|
|
neuron = simple_network_mob.layers[0].neurons[0]
|
|
self.play(
|
|
ReplacementTransform(network_mob.layers[0], neuron),
|
|
*edge_collapse_anims(network_mob.edge_groups[0], neuron)
|
|
)
|
|
for i, layer in enumerate(network_mob.layers[1:]):
|
|
neuron = simple_network_mob.layers[i+1].neurons[0]
|
|
prev_edges = network_mob.edge_groups[i]
|
|
prev_edge_target = simple_network_mob.edge_groups[i]
|
|
if i+1 < len(network_mob.edge_groups):
|
|
edges = network_mob.edge_groups[i+1]
|
|
added_anims = edge_collapse_anims(edges, neuron)
|
|
else:
|
|
added_anims = [FadeOut(network_mob.output_labels)]
|
|
self.play(
|
|
ReplacementTransform(layer, neuron),
|
|
ReplacementTransform(prev_edges, prev_edge_target),
|
|
*added_anims
|
|
)
|
|
self.remove(network_mob)
|
|
self.add(simple_network_mob)
|
|
self.network_mob = simple_network_mob
|
|
self.network = self.network_mob.neural_network
|
|
self.feed_forward(np.array([0.5]))
|
|
self.wait()
|
|
|
|
def show_weights_and_biases(self):
|
|
network_mob = self.network_mob
|
|
edges = VGroup(*[eg[0] for eg in network_mob.edge_groups])
|
|
neurons = VGroup(*[
|
|
layer.neurons[0]
|
|
for layer in network_mob.layers[1:]
|
|
])
|
|
expression = TexMobject(
|
|
"C", "(",
|
|
"w_1", ",", "b_1", ",",
|
|
"w_2", ",", "b_2", ",",
|
|
"w_3", ",", "b_3",
|
|
")"
|
|
)
|
|
expression.shift(2*UP)
|
|
expression.set_color_by_tex("C", RED)
|
|
w_terms = expression.get_parts_by_tex("w_")
|
|
for w, edge in zip(w_terms, edges):
|
|
w.set_color(edge.get_color())
|
|
b_terms = expression.get_parts_by_tex("b_")
|
|
variables = VGroup(*it.chain(w_terms, b_terms))
|
|
other_terms = VGroup(*[m for m in expression if m not in variables])
|
|
random.shuffle(variables.submobjects)
|
|
|
|
self.play(ReplacementTransform(edges.copy(), w_terms))
|
|
self.wait()
|
|
self.play(ReplacementTransform(neurons.copy(), b_terms))
|
|
self.wait()
|
|
self.play(Write(other_terms))
|
|
for x in range(2):
|
|
self.play(OldLaggedStart(
|
|
Indicate, variables,
|
|
rate_func = wiggle,
|
|
run_time = 4,
|
|
))
|
|
self.wait()
|
|
self.play(
|
|
FadeOut(other_terms),
|
|
ReplacementTransform(w_terms, edges),
|
|
ReplacementTransform(b_terms, neurons),
|
|
)
|
|
self.remove(expression)
|
|
|
|
def focus_just_on_last_two_layers(self):
|
|
to_fade = VGroup(*it.chain(*list(zip(
|
|
self.network_mob.layers[:2],
|
|
self.network_mob.edge_groups[:2],
|
|
))))
|
|
for mob in to_fade:
|
|
mob.save_state()
|
|
self.play(OldLaggedStart(
|
|
ApplyMethod, to_fade,
|
|
lambda m : (m.fade, 0.9)
|
|
))
|
|
self.wait()
|
|
|
|
self.prev_layers = to_fade
|
|
|
|
def label_neurons(self):
|
|
neurons = VGroup(*[
|
|
self.network_mob.layers[i].neurons[0]
|
|
for i in (-1, -2)
|
|
])
|
|
decimals = VGroup()
|
|
a_labels = VGroup()
|
|
a_label_arrows = VGroup()
|
|
superscripts = ["L", "L-1"]
|
|
superscript_rects = VGroup()
|
|
for neuron, superscript in zip(neurons, superscripts):
|
|
decimal = self.get_neuron_activation_decimal(neuron)
|
|
label = TexMobject("a^{(%s)}"%superscript)
|
|
label.next_to(neuron, DOWN, buff = LARGE_BUFF)
|
|
superscript_rect = SurroundingRectangle(VGroup(*label[1:]))
|
|
arrow = Arrow(
|
|
label[0].get_top(),
|
|
neuron.get_bottom(),
|
|
buff = SMALL_BUFF,
|
|
color = WHITE
|
|
)
|
|
|
|
decimal.save_state()
|
|
decimal.set_fill(opacity = 0)
|
|
decimal.move_to(label)
|
|
|
|
decimals.add(decimal)
|
|
a_labels.add(label)
|
|
a_label_arrows.add(arrow)
|
|
superscript_rects.add(superscript_rect)
|
|
|
|
self.play(
|
|
Write(label, run_time = 1),
|
|
GrowArrow(arrow),
|
|
)
|
|
self.play(decimal.restore)
|
|
opacity = neuron.get_fill_opacity()
|
|
self.play(
|
|
neuron.set_fill, None, 0,
|
|
ChangingDecimal(
|
|
decimal,
|
|
lambda a : interpolate(opacity, 0.01, a)
|
|
),
|
|
UpdateFromFunc(
|
|
decimal,
|
|
lambda d : d.set_fill(WHITE if d.number < 0.8 else BLACK)
|
|
),
|
|
run_time = 2,
|
|
rate_func = there_and_back,
|
|
)
|
|
self.wait()
|
|
|
|
not_exponents = TextMobject("Not exponents")
|
|
not_exponents.next_to(superscript_rects, DOWN, MED_LARGE_BUFF)
|
|
not_exponents.set_color(YELLOW)
|
|
|
|
self.play(
|
|
OldLaggedStart(
|
|
ShowCreation, superscript_rects,
|
|
lag_ratio = 0.8, run_time = 1.5
|
|
),
|
|
Write(not_exponents, run_time = 2)
|
|
)
|
|
self.wait()
|
|
self.play(*list(map(FadeOut, [not_exponents, superscript_rects])))
|
|
|
|
self.set_variables_as_attrs(
|
|
a_labels, a_label_arrows, decimals,
|
|
last_neurons = neurons
|
|
)
|
|
|
|
def show_desired_output(self):
|
|
neuron = self.network_mob.layers[-1].neurons[0].copy()
|
|
neuron.shift(2*RIGHT)
|
|
neuron.set_fill(opacity = 1)
|
|
decimal = self.get_neuron_activation_decimal(neuron)
|
|
|
|
rect = SurroundingRectangle(neuron)
|
|
words = TextMobject("Desired \\\\ output")
|
|
words.next_to(rect, UP)
|
|
|
|
y_label = TexMobject("y")
|
|
y_label.next_to(neuron, DOWN, LARGE_BUFF)
|
|
y_label.align_to(self.a_labels, DOWN)
|
|
y_label_arrow = Arrow(
|
|
y_label, neuron,
|
|
color = WHITE,
|
|
buff = SMALL_BUFF
|
|
)
|
|
VGroup(words, rect, y_label).set_color(self.desired_output_color)
|
|
|
|
self.play(*list(map(FadeIn, [neuron, decimal])))
|
|
self.play(
|
|
ShowCreation(rect),
|
|
Write(words, run_time = 1)
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
Write(y_label, run_time = 1),
|
|
GrowArrow(y_label_arrow)
|
|
)
|
|
self.wait()
|
|
|
|
self.set_variables_as_attrs(
|
|
y_label, y_label_arrow,
|
|
desired_output_neuron = neuron,
|
|
desired_output_decimal = decimal,
|
|
desired_output_rect = rect,
|
|
desired_output_words = words,
|
|
)
|
|
|
|
def show_cost(self):
|
|
pre_a = self.a_labels[0].copy()
|
|
pre_y = self.y_label.copy()
|
|
|
|
cost_equation = TexMobject(
|
|
"C_0", "(", "\\dots", ")", "=",
|
|
"(", "a^{(L)}", "-", "y", ")", "^2"
|
|
)
|
|
cost_equation.to_corner(UP+RIGHT)
|
|
C0, a, y = [
|
|
cost_equation.get_part_by_tex(tex)
|
|
for tex in ("C_0", "a^{(L)}", "y")
|
|
]
|
|
y.set_color(YELLOW)
|
|
|
|
cost_word = TextMobject("Cost")
|
|
cost_word.next_to(C0[0], LEFT, LARGE_BUFF)
|
|
cost_arrow = Arrow(
|
|
cost_word, C0,
|
|
buff = SMALL_BUFF
|
|
)
|
|
VGroup(C0, cost_word, cost_arrow).set_color(self.cost_color)
|
|
|
|
expression = TexMobject(
|
|
"\\text{For example: }"
|
|
"(", "0.00", "-", "0.00", ")", "^2"
|
|
)
|
|
numbers = expression.get_parts_by_tex("0.00")
|
|
non_numbers = VGroup(*[m for m in expression if m not in numbers])
|
|
expression.next_to(cost_equation, DOWN, aligned_edge = RIGHT)
|
|
decimals = VGroup(
|
|
self.decimals[0],
|
|
self.desired_output_decimal
|
|
).copy()
|
|
decimals.generate_target()
|
|
for d, n in zip(decimals.target, numbers):
|
|
d.replace(n, dim_to_match = 1)
|
|
d.set_color(n.get_color())
|
|
|
|
self.play(
|
|
ReplacementTransform(pre_a, a),
|
|
ReplacementTransform(pre_y, y),
|
|
)
|
|
self.play(OldLaggedStart(
|
|
FadeIn, VGroup(*[m for m in cost_equation if m not in [a, y]])
|
|
))
|
|
self.play(
|
|
MoveToTarget(decimals),
|
|
FadeIn(non_numbers)
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
Write(cost_word, run_time = 1),
|
|
GrowArrow(cost_arrow)
|
|
)
|
|
self.play(C0.shift, MED_SMALL_BUFF*UP, rate_func = wiggle)
|
|
self.wait()
|
|
self.play(*list(map(FadeOut, [decimals, non_numbers])))
|
|
|
|
self.set_variables_as_attrs(
|
|
cost_equation, cost_word, cost_arrow
|
|
)
|
|
|
|
def show_activation_formula(self):
|
|
neuron = self.network_mob.layers[-1].neurons[0]
|
|
edge = self.network_mob.edge_groups[-1][0]
|
|
pre_aL, pre_aLm1 = self.a_labels.copy()
|
|
|
|
formula = TexMobject(
|
|
"a^{(L)}", "=", "\\sigma", "(",
|
|
"w^{(L)}", "a^{(L-1)}", "+", "b^{(L)}", ")"
|
|
)
|
|
formula.next_to(neuron, UP, MED_LARGE_BUFF, RIGHT)
|
|
aL, equals, sigma, lp, wL, aLm1, plus, bL, rp = formula
|
|
wL.set_color(edge.get_color())
|
|
weight_label = wL.copy()
|
|
bL.set_color(MAROON_B)
|
|
bias_label = bL.copy()
|
|
sigma_group = VGroup(sigma, lp, rp)
|
|
sigma_group.save_state()
|
|
sigma_group.set_fill(opacity = 0)
|
|
sigma_group.shift(DOWN)
|
|
|
|
self.play(
|
|
ReplacementTransform(pre_aL, aL),
|
|
Write(equals)
|
|
)
|
|
self.play(ReplacementTransform(
|
|
edge.copy(), wL
|
|
))
|
|
self.wait()
|
|
self.play(ReplacementTransform(pre_aLm1, aLm1))
|
|
self.wait()
|
|
self.play(Write(VGroup(plus, bL), run_time = 1))
|
|
self.wait()
|
|
self.play(sigma_group.restore)
|
|
self.wait()
|
|
|
|
weighted_sum_terms = VGroup(wL, aLm1, plus, bL)
|
|
self.set_variables_as_attrs(
|
|
formula, weighted_sum_terms
|
|
)
|
|
|
|
def introduce_z(self):
|
|
terms = self.weighted_sum_terms
|
|
terms.generate_target()
|
|
terms.target.next_to(
|
|
self.formula, UP,
|
|
buff = MED_LARGE_BUFF,
|
|
aligned_edge = RIGHT
|
|
)
|
|
terms.target.shift(MED_LARGE_BUFF*RIGHT)
|
|
equals = TexMobject("=")
|
|
equals.next_to(terms.target[0][0], LEFT)
|
|
|
|
z_label = TexMobject("z^{(L)}")
|
|
z_label.next_to(equals, LEFT)
|
|
z_label.align_to(terms.target, DOWN)
|
|
z_label.set_color(self.z_color)
|
|
z_label2 = z_label.copy()
|
|
|
|
aL_start = VGroup(*self.formula[:4])
|
|
aL_start.generate_target()
|
|
aL_start.target.align_to(z_label, LEFT)
|
|
z_label2.next_to(aL_start.target, RIGHT, SMALL_BUFF)
|
|
z_label2.align_to(aL_start.target[0], DOWN)
|
|
rp = self.formula[-1]
|
|
rp.generate_target()
|
|
rp.target.next_to(z_label2, RIGHT, SMALL_BUFF)
|
|
rp.target.align_to(aL_start.target, DOWN)
|
|
|
|
self.play(MoveToTarget(terms))
|
|
self.play(Write(z_label), Write(equals))
|
|
self.play(
|
|
ReplacementTransform(z_label.copy(), z_label2),
|
|
MoveToTarget(aL_start),
|
|
MoveToTarget(rp),
|
|
)
|
|
self.wait()
|
|
|
|
zL_formula = VGroup(z_label, equals, *terms)
|
|
aL_formula = VGroup(*list(aL_start) + [z_label2, rp])
|
|
self.set_variables_as_attrs(z_label, zL_formula, aL_formula)
|
|
|
|
def break_into_computational_graph(self):
|
|
network_early_layers = VGroup(*it.chain(
|
|
self.network_mob.layers[:2],
|
|
self.network_mob.edge_groups[:2]
|
|
))
|
|
|
|
wL, aL, plus, bL = self.weighted_sum_terms
|
|
top_terms = VGroup(wL, aL, bL).copy()
|
|
zL = self.z_label.copy()
|
|
aL = self.formula[0].copy()
|
|
y = self.y_label.copy()
|
|
C0 = self.cost_equation[0].copy()
|
|
targets = VGroup()
|
|
for mob in top_terms, zL, aL, C0:
|
|
mob.generate_target()
|
|
targets.add(mob.target)
|
|
y.generate_target()
|
|
top_terms.target.arrange(RIGHT, buff = MED_LARGE_BUFF)
|
|
targets.arrange(DOWN, buff = LARGE_BUFF)
|
|
targets.center().to_corner(DOWN+LEFT)
|
|
y.target.next_to(aL.target, LEFT, LARGE_BUFF, DOWN)
|
|
|
|
top_lines = VGroup(*[
|
|
Line(
|
|
term.get_bottom(),
|
|
zL.target.get_top(),
|
|
buff = SMALL_BUFF
|
|
)
|
|
for term in top_terms.target
|
|
])
|
|
z_to_a_line, a_to_c_line, y_to_c_line = all_lines = [
|
|
Line(
|
|
m1.target.get_bottom(),
|
|
m2.target.get_top(),
|
|
buff = SMALL_BUFF
|
|
)
|
|
for m1, m2 in [
|
|
(zL, aL),
|
|
(aL, C0),
|
|
(y, C0)
|
|
]
|
|
]
|
|
for mob in [top_lines] + all_lines:
|
|
yellow_copy = mob.copy().set_color(YELLOW)
|
|
mob.flash = ShowCreationThenDestruction(yellow_copy)
|
|
|
|
self.play(MoveToTarget(top_terms))
|
|
self.wait()
|
|
self.play(MoveToTarget(zL))
|
|
self.play(
|
|
ShowCreation(top_lines, lag_ratio = 0),
|
|
top_lines.flash
|
|
)
|
|
self.wait()
|
|
self.play(MoveToTarget(aL))
|
|
self.play(
|
|
network_early_layers.fade, 1,
|
|
ShowCreation(z_to_a_line),
|
|
z_to_a_line.flash
|
|
)
|
|
self.wait()
|
|
self.play(MoveToTarget(y))
|
|
self.play(MoveToTarget(C0))
|
|
self.play(*it.chain(*[
|
|
[ShowCreation(line), line.flash]
|
|
for line in (a_to_c_line, y_to_c_line)
|
|
]))
|
|
self.wait(2)
|
|
|
|
comp_graph = VGroup()
|
|
comp_graph.wL, comp_graph.aLm1, comp_graph.bL = top_terms
|
|
comp_graph.top_lines = top_lines
|
|
comp_graph.zL = zL
|
|
comp_graph.z_to_a_line = z_to_a_line
|
|
comp_graph.aL = aL
|
|
comp_graph.y = y
|
|
comp_graph.a_to_c_line = a_to_c_line
|
|
comp_graph.y_to_c_line = y_to_c_line
|
|
comp_graph.C0 = C0
|
|
comp_graph.digest_mobject_attrs()
|
|
self.comp_graph = comp_graph
|
|
|
|
def show_preceding_layer_in_computational_graph(self):
|
|
shift_vect = DOWN
|
|
comp_graph = self.comp_graph
|
|
comp_graph.save_state()
|
|
comp_graph.generate_target()
|
|
comp_graph.target.shift(shift_vect)
|
|
rect = SurroundingRectangle(comp_graph.aLm1)
|
|
|
|
attrs = ["wL", "aLm1", "bL", "zL"]
|
|
new_terms = VGroup()
|
|
for attr in attrs:
|
|
term = getattr(comp_graph, attr)
|
|
tex = term.get_tex_string()
|
|
if "L-1" in tex:
|
|
tex = tex.replace("L-1", "L-2")
|
|
else:
|
|
tex = tex.replace("L", "L-1")
|
|
new_term = TexMobject(tex)
|
|
new_term.set_color(term.get_color())
|
|
new_term.move_to(term)
|
|
new_terms.add(new_term)
|
|
new_edges = VGroup(
|
|
comp_graph.top_lines.copy(),
|
|
comp_graph.z_to_a_line.copy(),
|
|
)
|
|
new_subgraph = VGroup(new_terms, new_edges)
|
|
new_subgraph.next_to(comp_graph.target, UP, SMALL_BUFF)
|
|
self.wLm1 = new_terms[0]
|
|
self.zLm1 = new_terms[-1]
|
|
|
|
prev_neuron = self.network_mob.layers[1]
|
|
prev_neuron.restore()
|
|
prev_edge = self.network_mob.edge_groups[1]
|
|
prev_edge.restore()
|
|
|
|
self.play(
|
|
ShowCreation(rect),
|
|
FadeIn(prev_neuron),
|
|
ShowCreation(prev_edge)
|
|
)
|
|
self.play(
|
|
ReplacementTransform(
|
|
VGroup(prev_neuron, prev_edge).copy(),
|
|
new_subgraph
|
|
),
|
|
UpdateFromAlphaFunc(
|
|
new_terms,
|
|
lambda m, a : m.set_fill(opacity = a)
|
|
),
|
|
MoveToTarget(comp_graph),
|
|
rect.shift, shift_vect
|
|
)
|
|
self.wait(2)
|
|
self.play(
|
|
FadeOut(new_subgraph),
|
|
FadeOut(prev_neuron),
|
|
FadeOut(prev_edge),
|
|
comp_graph.restore,
|
|
rect.shift, -shift_vect,
|
|
rect.set_stroke, BLACK, 0
|
|
)
|
|
VGroup(prev_neuron, prev_edge).fade(1)
|
|
self.remove(rect)
|
|
self.wait()
|
|
|
|
self.prev_comp_subgraph = new_subgraph
|
|
|
|
def show_number_lines(self):
|
|
comp_graph = self.comp_graph
|
|
wL, aLm1, bL, zL, aL, C0 = [
|
|
getattr(comp_graph, attr)
|
|
for attr in ["wL", "aLm1", "bL", "zL", "aL", "C0"]
|
|
]
|
|
wL.val = self.network.weights[-1][0][0]
|
|
aL.val = self.decimals[0].number
|
|
zL.val = sigmoid_inverse(aL.val)
|
|
C0.val = (aL.val - 1)**2
|
|
|
|
number_line = UnitInterval(
|
|
unit_size = 2,
|
|
stroke_width = 2,
|
|
tick_size = 0.075,
|
|
color = LIGHT_GREY,
|
|
)
|
|
|
|
for mob in wL, zL, aL, C0:
|
|
mob.number_line = number_line.deepcopy()
|
|
if mob is wL:
|
|
mob.number_line.next_to(mob, UP, MED_LARGE_BUFF, LEFT)
|
|
else:
|
|
mob.number_line.next_to(mob, RIGHT)
|
|
if mob is C0:
|
|
mob.number_line.x_max = 0.5
|
|
for tick_mark in mob.number_line.tick_marks[1::2]:
|
|
mob.number_line.tick_marks.remove(tick_mark)
|
|
mob.dot = Dot(color = mob.get_color())
|
|
mob.dot.move_to(
|
|
mob.number_line.number_to_point(mob.val)
|
|
)
|
|
if mob is wL:
|
|
path_arc = 0
|
|
dot_spot = mob.dot.get_bottom()
|
|
else:
|
|
path_arc = -0.7*np.pi
|
|
dot_spot = mob.dot.get_top()
|
|
if mob is C0:
|
|
mob_spot = mob[0].get_corner(UP+RIGHT)
|
|
tip_length = 0.15
|
|
else:
|
|
mob_spot = mob.get_corner(UP+RIGHT)
|
|
tip_length = 0.2
|
|
mob.arrow = Arrow(
|
|
mob_spot, dot_spot,
|
|
path_arc = path_arc,
|
|
tip_length = tip_length,
|
|
buff = SMALL_BUFF,
|
|
)
|
|
mob.arrow.set_color(mob.get_color())
|
|
mob.arrow.set_stroke(width = 5)
|
|
|
|
self.play(ShowCreation(
|
|
mob.number_line,
|
|
lag_ratio = 0.5
|
|
))
|
|
self.play(
|
|
ShowCreation(mob.arrow),
|
|
ReplacementTransform(
|
|
mob.copy(), mob.dot,
|
|
path_arc = path_arc
|
|
)
|
|
)
|
|
self.wait()
|
|
|
|
def ask_about_w_sensitivity(self):
|
|
wL, aLm1, bL, zL, aL, C0 = [
|
|
getattr(self.comp_graph, attr)
|
|
for attr in ["wL", "aLm1", "bL", "zL", "aL", "C0"]
|
|
]
|
|
aLm1_val = self.last_neurons[1].get_fill_opacity()
|
|
bL_val = self.network.biases[-1][0]
|
|
|
|
get_wL_val = lambda : wL.number_line.point_to_number(
|
|
wL.dot.get_center()
|
|
)
|
|
get_zL_val = lambda : get_wL_val()*aLm1_val+bL_val
|
|
get_aL_val = lambda : sigmoid(get_zL_val())
|
|
get_C0_val = lambda : (get_aL_val() - 1)**2
|
|
|
|
def generate_dot_update(term, val_func):
|
|
def update_dot(dot):
|
|
dot.move_to(term.number_line.number_to_point(val_func()))
|
|
return dot
|
|
return update_dot
|
|
|
|
dot_update_anims = [
|
|
UpdateFromFunc(term.dot, generate_dot_update(term, val_func))
|
|
for term, val_func in [
|
|
(zL, get_zL_val),
|
|
(aL, get_aL_val),
|
|
(C0, get_C0_val),
|
|
]
|
|
]
|
|
|
|
def shake_dot(run_time = 2, rate_func = there_and_back):
|
|
self.play(
|
|
ApplyMethod(
|
|
wL.dot.shift, LEFT,
|
|
rate_func = rate_func,
|
|
run_time = run_time
|
|
),
|
|
*dot_update_anims
|
|
)
|
|
|
|
wL_line = Line(wL.dot.get_center(), wL.dot.get_center()+LEFT)
|
|
del_wL = TexMobject("\\partial w^{(L)}")
|
|
del_wL.scale(self.derivative_scale_val)
|
|
del_wL.brace = Brace(wL_line, UP, buff = SMALL_BUFF)
|
|
del_wL.set_color(wL.get_color())
|
|
del_wL.next_to(del_wL.brace, UP, SMALL_BUFF)
|
|
|
|
C0_line = Line(C0.dot.get_center(), C0.dot.get_center()+MED_SMALL_BUFF*RIGHT)
|
|
del_C0 = TexMobject("\\partial C_0")
|
|
del_C0.scale(self.derivative_scale_val)
|
|
del_C0.brace = Brace(C0_line, UP, buff = SMALL_BUFF)
|
|
del_C0.set_color(C0.get_color())
|
|
del_C0.next_to(del_C0.brace, UP, SMALL_BUFF)
|
|
|
|
for sym in del_wL, del_C0:
|
|
self.play(
|
|
GrowFromCenter(sym.brace),
|
|
Write(sym, run_time = 1)
|
|
)
|
|
shake_dot()
|
|
self.wait()
|
|
|
|
self.set_variables_as_attrs(
|
|
shake_dot, del_wL, del_C0,
|
|
)
|
|
|
|
def show_derivative_wrt_w(self):
|
|
del_wL = self.del_wL
|
|
del_C0 = self.del_C0
|
|
cost_word = self.cost_word
|
|
cost_arrow = self.cost_arrow
|
|
shake_dot = self.shake_dot
|
|
wL = self.comp_graph.wL
|
|
|
|
dC_dw = TexMobject(
|
|
"{\\partial C_0", "\\over", "\\partial w^{(L)} }"
|
|
)
|
|
dC_dw[0].set_color(del_C0.get_color())
|
|
dC_dw[2].set_color(del_wL.get_color())
|
|
dC_dw.scale(self.derivative_scale_val)
|
|
dC_dw.to_edge(UP, buff = MED_SMALL_BUFF)
|
|
dC_dw.shift(3.5*LEFT)
|
|
|
|
full_rect = SurroundingRectangle(dC_dw)
|
|
full_rect_copy = full_rect.copy()
|
|
words = TextMobject("What we want")
|
|
words.next_to(full_rect, RIGHT)
|
|
words.set_color(YELLOW)
|
|
|
|
denom_rect = SurroundingRectangle(dC_dw[2])
|
|
numer_rect = SurroundingRectangle(dC_dw[0])
|
|
|
|
self.play(
|
|
ReplacementTransform(del_C0.copy(), dC_dw[0]),
|
|
ReplacementTransform(del_wL.copy(), dC_dw[2]),
|
|
Write(dC_dw[1], run_time = 1)
|
|
)
|
|
self.play(
|
|
FadeOut(cost_word),
|
|
FadeOut(cost_arrow),
|
|
ShowCreation(full_rect),
|
|
Write(words, run_time = 1),
|
|
)
|
|
self.wait(2)
|
|
self.play(
|
|
FadeOut(words),
|
|
ReplacementTransform(full_rect, denom_rect)
|
|
)
|
|
self.play(Transform(dC_dw[2].copy(), del_wL, remover = True))
|
|
shake_dot()
|
|
self.play(ReplacementTransform(denom_rect, numer_rect))
|
|
self.play(Transform(dC_dw[0].copy(), del_C0, remover = True))
|
|
shake_dot()
|
|
self.wait()
|
|
self.play(ReplacementTransform(numer_rect, full_rect_copy))
|
|
self.play(FadeOut(full_rect_copy))
|
|
self.wait()
|
|
|
|
self.dC_dw = dC_dw
|
|
|
|
def show_chain_of_events(self):
|
|
comp_graph = self.comp_graph
|
|
wL, zL, aL, C0 = [
|
|
getattr(comp_graph, attr)
|
|
for attr in ["wL", "zL", "aL", "C0"]
|
|
]
|
|
del_wL = self.del_wL
|
|
del_C0 = self.del_C0
|
|
|
|
zL_line = Line(ORIGIN, MED_LARGE_BUFF*LEFT)
|
|
zL_line.shift(zL.dot.get_center())
|
|
del_zL = TexMobject("\\partial z^{(L)}")
|
|
del_zL.set_color(zL.get_color())
|
|
del_zL.brace = Brace(zL_line, DOWN, buff = SMALL_BUFF)
|
|
|
|
aL_line = Line(ORIGIN, MED_SMALL_BUFF*LEFT)
|
|
aL_line.shift(aL.dot.get_center())
|
|
del_aL = TexMobject("\\partial a^{(L)}")
|
|
del_aL.set_color(aL.get_color())
|
|
del_aL.brace = Brace(aL_line, DOWN, buff = SMALL_BUFF)
|
|
|
|
for sym in del_zL, del_aL:
|
|
sym.scale(self.derivative_scale_val)
|
|
sym.brace.stretch_about_point(
|
|
0.5, 1, sym.brace.get_top(),
|
|
)
|
|
sym.shift(
|
|
sym.brace.get_bottom()+SMALL_BUFF*DOWN \
|
|
-sym[0].get_corner(UP+RIGHT)
|
|
)
|
|
|
|
syms = [del_wL, del_zL, del_aL, del_C0]
|
|
for s1, s2 in zip(syms, syms[1:]):
|
|
self.play(
|
|
ReplacementTransform(s1.copy(), s2),
|
|
ReplacementTransform(s1.brace.copy(), s2.brace),
|
|
)
|
|
self.shake_dot(run_time = 1.5)
|
|
self.wait(0.5)
|
|
self.wait()
|
|
|
|
self.set_variables_as_attrs(del_zL, del_aL)
|
|
|
|
def show_chain_rule(self):
|
|
dC_dw = self.dC_dw
|
|
del_syms = [
|
|
getattr(self, attr)
|
|
for attr in ("del_wL", "del_zL", "del_aL", "del_C0")
|
|
]
|
|
|
|
dz_dw = TexMobject(
|
|
"{\\partial z^{(L)}", "\\over", "\\partial w^{(L)}}"
|
|
)
|
|
da_dz = TexMobject(
|
|
"{\\partial a^{(L)}", "\\over", "\\partial z^{(L)}}"
|
|
)
|
|
dC_da = TexMobject(
|
|
"{\\partial C0}", "\\over", "\\partial a^{(L)}}"
|
|
)
|
|
dz_dw[2].set_color(self.del_wL.get_color())
|
|
VGroup(dz_dw[0], da_dz[2]).set_color(self.z_color)
|
|
dC_da[0].set_color(self.cost_color)
|
|
equals = TexMobject("=")
|
|
group = VGroup(equals, dz_dw, da_dz, dC_da)
|
|
group.arrange(RIGHT, SMALL_BUFF)
|
|
group.scale(self.derivative_scale_val)
|
|
group.next_to(dC_dw, RIGHT)
|
|
for mob in group[1:]:
|
|
target_y = equals.get_center()[1]
|
|
y = mob[1].get_center()[1]
|
|
mob.shift((target_y - y)*UP)
|
|
|
|
self.play(Write(equals, run_time = 1))
|
|
for frac, top_sym, bot_sym in zip(group[1:], del_syms[1:], del_syms):
|
|
self.play(Indicate(top_sym, rate_func = wiggle))
|
|
self.play(
|
|
ReplacementTransform(top_sym.copy(), frac[0]),
|
|
FadeIn(frac[1]),
|
|
)
|
|
self.play(Indicate(bot_sym, rate_func = wiggle))
|
|
self.play(ReplacementTransform(
|
|
bot_sym.copy(), frac[2]
|
|
))
|
|
self.wait()
|
|
self.shake_dot()
|
|
self.wait()
|
|
|
|
self.chain_rule_equation = VGroup(dC_dw, *group)
|
|
|
|
def name_chain_rule(self):
|
|
graph_parts = self.get_all_comp_graph_parts()
|
|
equation = self.chain_rule_equation
|
|
rect = SurroundingRectangle(equation)
|
|
group = VGroup(equation, rect)
|
|
group.generate_target()
|
|
group.target.to_corner(UP+LEFT)
|
|
words = TextMobject("Chain rule")
|
|
words.set_color(YELLOW)
|
|
words.next_to(group.target, DOWN)
|
|
|
|
self.play(ShowCreation(rect))
|
|
self.play(
|
|
MoveToTarget(group),
|
|
Write(words, run_time = 1),
|
|
graph_parts.scale, 0.7, graph_parts.get_bottom()
|
|
)
|
|
self.wait(2)
|
|
self.play(*list(map(FadeOut, [rect, words])))
|
|
|
|
def indicate_everything_on_screen(self):
|
|
everything = VGroup(*self.get_top_level_mobjects())
|
|
everything = VGroup(*[m for m in everything.family_members_with_points() if not m.is_subpath])
|
|
self.play(OldLaggedStart(
|
|
Indicate, everything,
|
|
rate_func = wiggle,
|
|
lag_ratio = 0.2,
|
|
run_time = 5
|
|
))
|
|
self.wait()
|
|
|
|
def prepare_for_derivatives(self):
|
|
zL_formula = self.zL_formula
|
|
aL_formula = self.aL_formula
|
|
az_formulas = VGroup(zL_formula, aL_formula)
|
|
cost_equation = self.cost_equation
|
|
desired_output_words = self.desired_output_words
|
|
|
|
az_formulas.generate_target()
|
|
az_formulas.target.to_edge(RIGHT)
|
|
|
|
index = 4
|
|
cost_eq = cost_equation[index]
|
|
z_eq = az_formulas.target[0][1]
|
|
x_shift = (z_eq.get_center() - cost_eq.get_center())[0]*RIGHT
|
|
cost_equation.generate_target()
|
|
Transform(
|
|
VGroup(*cost_equation.target[1:index]),
|
|
VectorizedPoint(cost_eq.get_left())
|
|
).update(1)
|
|
cost_equation.target[0].next_to(cost_eq, LEFT, SMALL_BUFF)
|
|
cost_equation.target.shift(x_shift)
|
|
|
|
self.play(
|
|
FadeOut(self.all_comp_graph_parts),
|
|
FadeOut(self.desired_output_words),
|
|
MoveToTarget(az_formulas),
|
|
MoveToTarget(cost_equation)
|
|
)
|
|
|
|
def compute_derivatives(self):
|
|
cost_equation = self.cost_equation
|
|
zL_formula = self.zL_formula
|
|
aL_formula = self.aL_formula
|
|
chain_rule_equation = self.chain_rule_equation.copy()
|
|
dC_dw, equals, dz_dw, da_dz, dC_da = chain_rule_equation
|
|
|
|
derivs = VGroup(dC_da, da_dz, dz_dw)
|
|
deriv_targets = VGroup()
|
|
for deriv in derivs:
|
|
deriv.generate_target()
|
|
deriv_targets.add(deriv.target)
|
|
deriv_targets.arrange(DOWN, buff = MED_LARGE_BUFF)
|
|
deriv_targets.next_to(dC_dw, DOWN, LARGE_BUFF)
|
|
for deriv in derivs:
|
|
deriv.equals = TexMobject("=")
|
|
deriv.equals.next_to(deriv.target, RIGHT)
|
|
|
|
|
|
#dC_da
|
|
self.play(
|
|
MoveToTarget(dC_da),
|
|
Write(dC_da.equals)
|
|
)
|
|
index = 4
|
|
cost_rhs = VGroup(*cost_equation[index+1:])
|
|
dC_da.rhs = cost_rhs.copy()
|
|
two = dC_da.rhs[-1]
|
|
two.scale(1.5)
|
|
two.next_to(dC_da.rhs[0], LEFT, SMALL_BUFF)
|
|
dC_da.rhs.next_to(dC_da.equals, RIGHT)
|
|
dC_da.rhs.shift(0.7*SMALL_BUFF*UP)
|
|
cost_equation.save_state()
|
|
|
|
self.play(
|
|
cost_equation.next_to, dC_da.rhs,
|
|
DOWN, MED_LARGE_BUFF, LEFT
|
|
)
|
|
self.wait()
|
|
self.play(ReplacementTransform(
|
|
cost_rhs.copy(), dC_da.rhs,
|
|
path_arc = np.pi/2,
|
|
))
|
|
self.wait()
|
|
self.play(cost_equation.restore)
|
|
self.wait()
|
|
|
|
#show_difference
|
|
neuron = self.last_neurons[0]
|
|
decimal = self.decimals[0]
|
|
double_arrow = DoubleArrow(
|
|
neuron.get_right(),
|
|
self.desired_output_neuron.get_left(),
|
|
buff = SMALL_BUFF,
|
|
color = RED
|
|
)
|
|
|
|
moving_decimals = VGroup(
|
|
self.decimals[0].copy(),
|
|
self.desired_output_decimal.copy()
|
|
)
|
|
minus = TexMobject("-")
|
|
minus.move_to(moving_decimals)
|
|
minus.scale(0.7)
|
|
minus.set_fill(opacity = 0)
|
|
moving_decimals.submobjects.insert(1, minus)
|
|
moving_decimals.generate_target(use_deepcopy = True)
|
|
moving_decimals.target.arrange(RIGHT, buff = SMALL_BUFF)
|
|
moving_decimals.target.scale(1.5)
|
|
moving_decimals.target.next_to(
|
|
dC_da.rhs, DOWN,
|
|
buff = MED_LARGE_BUFF,
|
|
aligned_edge = RIGHT,
|
|
)
|
|
moving_decimals.target.set_fill(WHITE, 1)
|
|
|
|
self.play(ReplacementTransform(
|
|
dC_da.rhs.copy(), double_arrow
|
|
))
|
|
self.wait()
|
|
self.play(MoveToTarget(moving_decimals))
|
|
opacity = neuron.get_fill_opacity()
|
|
for target_o in 0, opacity:
|
|
self.wait(2)
|
|
self.play(
|
|
neuron.set_fill, None, target_o,
|
|
*[
|
|
ChangingDecimal(d, lambda a : neuron.get_fill_opacity())
|
|
for d in (decimal, moving_decimals[0])
|
|
]
|
|
)
|
|
self.play(*list(map(FadeOut, [double_arrow, moving_decimals])))
|
|
|
|
#da_dz
|
|
self.play(
|
|
MoveToTarget(da_dz),
|
|
Write(da_dz.equals)
|
|
)
|
|
a_rhs = VGroup(*aL_formula[2:])
|
|
da_dz.rhs = a_rhs.copy()
|
|
prime = TexMobject("'")
|
|
prime.move_to(da_dz.rhs[0].get_corner(UP+RIGHT))
|
|
da_dz.rhs[0].shift(0.5*SMALL_BUFF*LEFT)
|
|
da_dz.rhs.add_to_back(prime)
|
|
da_dz.rhs.next_to(da_dz.equals, RIGHT)
|
|
da_dz.rhs.shift(0.5*SMALL_BUFF*UP)
|
|
aL_formula.save_state()
|
|
self.play(
|
|
aL_formula.next_to, da_dz.rhs,
|
|
DOWN, MED_LARGE_BUFF, LEFT
|
|
)
|
|
self.wait()
|
|
self.play(ReplacementTransform(
|
|
a_rhs.copy(), da_dz.rhs,
|
|
))
|
|
self.wait()
|
|
self.play(aL_formula.restore)
|
|
self.wait()
|
|
|
|
#dz_dw
|
|
self.play(
|
|
MoveToTarget(dz_dw),
|
|
Write(dz_dw.equals)
|
|
)
|
|
z_rhs = VGroup(*zL_formula[2:])
|
|
dz_dw.rhs = z_rhs[1].copy()
|
|
dz_dw.rhs.next_to(dz_dw.equals, RIGHT)
|
|
dz_dw.rhs.shift(SMALL_BUFF*UP)
|
|
zL_formula.save_state()
|
|
self.play(
|
|
zL_formula.next_to, dz_dw.rhs,
|
|
DOWN, MED_LARGE_BUFF, LEFT,
|
|
)
|
|
self.wait()
|
|
rect = SurroundingRectangle(VGroup(*zL_formula[2:4]))
|
|
self.play(ShowCreation(rect))
|
|
self.play(FadeOut(rect))
|
|
self.play(ReplacementTransform(
|
|
z_rhs[1].copy(), dz_dw.rhs,
|
|
))
|
|
self.wait()
|
|
self.play(zL_formula.restore)
|
|
self.wait()
|
|
|
|
self.derivative_equations = VGroup(dC_da, da_dz, dz_dw)
|
|
|
|
def get_lost_in_formulas(self):
|
|
randy = Randolph()
|
|
randy.flip()
|
|
randy.scale(0.7)
|
|
randy.to_edge(DOWN)
|
|
randy.shift(LEFT)
|
|
|
|
self.play(FadeIn(randy))
|
|
self.play(randy.change, "pleading", self.chain_rule_equation)
|
|
self.play(Blink(randy))
|
|
self.play(randy.change, "maybe")
|
|
self.play(Blink(randy))
|
|
self.play(FadeOut(randy))
|
|
|
|
def fire_together_wire_together(self):
|
|
dz_dw = self.derivative_equations[2]
|
|
rhs = dz_dw.rhs
|
|
rhs_copy = rhs.copy()
|
|
del_wL = dz_dw[2].copy()
|
|
rect = SurroundingRectangle(VGroup(dz_dw, dz_dw.rhs))
|
|
edge = self.network_mob.edge_groups[-1][0]
|
|
edge.save_state()
|
|
neuron = self.last_neurons[1]
|
|
decimal = self.decimals[1]
|
|
|
|
def get_decimal_anims():
|
|
return [
|
|
ChangingDecimal(decimal, lambda a : neuron.get_fill_opacity()),
|
|
UpdateFromFunc(
|
|
decimal, lambda m : m.set_color(
|
|
WHITE if neuron.get_fill_opacity() < 0.8 \
|
|
else BLACK
|
|
)
|
|
)
|
|
]
|
|
|
|
self.play(ShowCreation(rect))
|
|
self.play(FadeOut(rect))
|
|
self.play(
|
|
del_wL.next_to, edge, UP, SMALL_BUFF
|
|
)
|
|
self.play(
|
|
edge.set_stroke, None, 10,
|
|
rate_func = wiggle,
|
|
run_time = 3,
|
|
)
|
|
self.wait()
|
|
self.play(rhs.shift, MED_LARGE_BUFF*UP, rate_func = wiggle)
|
|
self.play(
|
|
rhs_copy.move_to, neuron,
|
|
rhs_copy.set_fill, None, 0
|
|
)
|
|
self.remove(rhs_copy)
|
|
self.play(
|
|
neuron.set_fill, None, 0,
|
|
*get_decimal_anims(),
|
|
run_time = 3,
|
|
rate_func = there_and_back
|
|
)
|
|
self.wait()
|
|
|
|
#Fire together wire together
|
|
opacity = neuron.get_fill_opacity()
|
|
self.play(
|
|
neuron.set_fill, None, 0.99,
|
|
*get_decimal_anims()
|
|
)
|
|
self.play(edge.set_stroke, None, 8)
|
|
self.play(
|
|
neuron.set_fill, None, opacity,
|
|
*get_decimal_anims()
|
|
)
|
|
self.play(edge.restore, FadeOut(del_wL))
|
|
self.wait(3)
|
|
|
|
def organize_chain_rule_rhs(self):
|
|
fracs = self.derivative_equations
|
|
equals_group = VGroup(*[frac.equals for frac in fracs])
|
|
rhs_group = VGroup(*[frac.rhs for frac in reversed(fracs)])
|
|
|
|
chain_rule_equation = self.chain_rule_equation
|
|
equals = TexMobject("=")
|
|
equals.next_to(chain_rule_equation, RIGHT)
|
|
|
|
rhs_group.generate_target()
|
|
rhs_group.target.arrange(RIGHT, buff = SMALL_BUFF)
|
|
rhs_group.target.next_to(equals, RIGHT)
|
|
rhs_group.target.shift(SMALL_BUFF*UP)
|
|
|
|
right_group = VGroup(
|
|
self.cost_equation, self.zL_formula, self.aL_formula,
|
|
self.network_mob, self.decimals,
|
|
self.a_labels, self.a_label_arrows,
|
|
self.y_label, self.y_label_arrow,
|
|
self.desired_output_neuron,
|
|
self.desired_output_rect,
|
|
self.desired_output_decimal,
|
|
)
|
|
|
|
self.play(
|
|
MoveToTarget(rhs_group, path_arc = np.pi/2),
|
|
Write(equals),
|
|
FadeOut(fracs),
|
|
FadeOut(equals_group),
|
|
right_group.to_corner, DOWN+RIGHT
|
|
)
|
|
self.wait()
|
|
|
|
rhs_group.add(equals)
|
|
self.chain_rule_rhs = rhs_group
|
|
|
|
def show_average_derivative(self):
|
|
dC0_dw = self.chain_rule_equation[0]
|
|
full_derivative = TexMobject(
|
|
"{\\partial C", "\\over", "\\partial w^{(L)}}",
|
|
"=", "\\frac{1}{n}", "\\sum_{k=0}^{n-1}",
|
|
"{\\partial C_k", "\\over", "\\partial w^{(L)}}"
|
|
)
|
|
full_derivative.set_color_by_tex_to_color_map({
|
|
"partial C" : self.cost_color,
|
|
"partial w" : self.del_wL.get_color()
|
|
})
|
|
full_derivative.to_edge(LEFT)
|
|
|
|
dCk_dw = VGroup(*full_derivative[-3:])
|
|
lhs = VGroup(*full_derivative[:3])
|
|
rhs = VGroup(*full_derivative[4:])
|
|
lhs_brace = Brace(lhs, DOWN)
|
|
lhs_text = lhs_brace.get_text("Derivative of \\\\ full cost function")
|
|
rhs_brace = Brace(rhs, UP)
|
|
rhs_text = rhs_brace.get_text("Average of all \\\\ training examples")
|
|
VGroup(
|
|
full_derivative, lhs_brace, lhs_text, rhs_brace, rhs_text
|
|
).to_corner(DOWN+LEFT)
|
|
|
|
mover = dC0_dw.copy()
|
|
self.play(Transform(mover, dCk_dw))
|
|
self.play(Write(full_derivative, run_time = 2))
|
|
self.remove(mover)
|
|
for brace, text in (rhs_brace, rhs_text), (lhs_brace, lhs_text):
|
|
self.play(
|
|
GrowFromCenter(brace),
|
|
Write(text, run_time = 2),
|
|
)
|
|
self.wait(2)
|
|
self.cycle_through_altnernate_training_examples()
|
|
self.play(*list(map(FadeOut, [
|
|
VGroup(*full_derivative[3:]),
|
|
lhs_brace, lhs_text,
|
|
rhs_brace, rhs_text,
|
|
])))
|
|
|
|
self.dC_dw = lhs
|
|
|
|
def cycle_through_altnernate_training_examples(self):
|
|
neurons = VGroup(
|
|
self.desired_output_neuron, *self.last_neurons
|
|
)
|
|
decimals = VGroup(
|
|
self.desired_output_decimal, *self.decimals
|
|
)
|
|
group = VGroup(neurons, decimals)
|
|
group.save_state()
|
|
|
|
for x in range(20):
|
|
for n, d in zip(neurons, decimals):
|
|
o = np.random.random()
|
|
if n is self.desired_output_neuron:
|
|
o = np.round(o)
|
|
n.set_fill(opacity = o)
|
|
Transform(
|
|
d, self.get_neuron_activation_decimal(n)
|
|
).update(1)
|
|
self.wait(0.2)
|
|
self.play(group.restore, run_time = 0.2)
|
|
|
|
def show_gradient(self):
|
|
dC_dw = self.dC_dw
|
|
dC_dw.generate_target()
|
|
terms = VGroup(
|
|
TexMobject("{\\partial C", "\\over", "\\partial w^{(1)}"),
|
|
TexMobject("{\\partial C", "\\over", "\\partial b^{(1)}"),
|
|
TexMobject("\\vdots"),
|
|
dC_dw.target,
|
|
TexMobject("{\\partial C", "\\over", "\\partial b^{(L)}"),
|
|
)
|
|
for term in terms:
|
|
if isinstance(term, TexMobject):
|
|
term.set_color_by_tex_to_color_map({
|
|
"partial C" : RED,
|
|
"partial w" : BLUE,
|
|
"partial b" : MAROON_B,
|
|
})
|
|
terms.arrange(DOWN, buff = MED_LARGE_BUFF)
|
|
lb, rb = brackets = TexMobject("[]")
|
|
brackets.scale(3)
|
|
brackets.stretch_to_fit_height(1.1*terms.get_height())
|
|
lb.next_to(terms, LEFT, buff = SMALL_BUFF)
|
|
rb.next_to(terms, RIGHT, buff = SMALL_BUFF)
|
|
vect = VGroup(lb, terms, rb)
|
|
vect.set_height(5)
|
|
lhs = TexMobject("\\nabla C", "=")
|
|
lhs[0].set_color(RED)
|
|
lhs.next_to(vect, LEFT)
|
|
VGroup(lhs, vect).to_corner(DOWN+LEFT, buff = LARGE_BUFF)
|
|
terms.remove(dC_dw.target)
|
|
|
|
self.play(
|
|
MoveToTarget(dC_dw),
|
|
Write(vect, run_time = 1)
|
|
)
|
|
terms.add(dC_dw)
|
|
self.play(Write(lhs))
|
|
self.wait(2)
|
|
self.play(FadeOut(VGroup(lhs, vect)))
|
|
|
|
def transition_to_derivative_wrt_b(self):
|
|
all_comp_graph_parts = self.all_comp_graph_parts
|
|
all_comp_graph_parts.scale(
|
|
1.3, about_point = all_comp_graph_parts.get_bottom()
|
|
)
|
|
comp_graph = self.comp_graph
|
|
wL, bL, zL, aL, C0 = [
|
|
getattr(comp_graph, attr)
|
|
for attr in ["wL", "bL", "zL", "aL", "C0"]
|
|
]
|
|
path_to_C = VGroup(wL, zL, aL, C0)
|
|
|
|
top_expression = VGroup(
|
|
self.chain_rule_equation,
|
|
self.chain_rule_rhs
|
|
)
|
|
rect = SurroundingRectangle(top_expression)
|
|
|
|
self.play(ShowCreation(rect))
|
|
self.play(FadeIn(comp_graph), FadeOut(rect))
|
|
for x in range(2):
|
|
self.play(OldLaggedStart(
|
|
Indicate, path_to_C,
|
|
rate_func = there_and_back,
|
|
run_time = 1.5,
|
|
lag_ratio = 0.7,
|
|
))
|
|
self.wait()
|
|
|
|
def show_derivative_wrt_b(self):
|
|
comp_graph = self.comp_graph
|
|
dC0_dw = self.chain_rule_equation[0]
|
|
dz_dw = self.chain_rule_equation[2]
|
|
aLm1 = self.chain_rule_rhs[0]
|
|
left_term_group = VGroup(dz_dw, aLm1)
|
|
dz_dw_rect = SurroundingRectangle(dz_dw)
|
|
|
|
del_w = dC0_dw[2]
|
|
del_b = TexMobject("\\partial b^{(L)}")
|
|
del_b.set_color(MAROON_B)
|
|
del_b.replace(del_w)
|
|
|
|
dz_db = TexMobject(
|
|
"{\\partial z^{(L)}", "\\over", "\\partial b^{(L)}}"
|
|
)
|
|
dz_db.set_color_by_tex_to_color_map({
|
|
"partial z" : self.z_color,
|
|
"partial b" : MAROON_B
|
|
})
|
|
dz_db.replace(dz_dw)
|
|
|
|
one = TexMobject("1")
|
|
one.move_to(aLm1, RIGHT)
|
|
arrow = Arrow(
|
|
dz_db.get_bottom(),
|
|
one.get_bottom(),
|
|
path_arc = np.pi/2,
|
|
color = WHITE,
|
|
)
|
|
arrow.set_stroke(width = 2)
|
|
|
|
wL, bL, zL, aL, C0 = [
|
|
getattr(comp_graph, attr)
|
|
for attr in ["wL", "bL", "zL", "aL", "C0"]
|
|
]
|
|
path_to_C = VGroup(bL, zL, aL, C0)
|
|
def get_path_animation():
|
|
return OldLaggedStart(
|
|
Indicate, path_to_C,
|
|
rate_func = there_and_back,
|
|
run_time = 1.5,
|
|
lag_ratio = 0.7,
|
|
)
|
|
|
|
zL_formula = self.zL_formula
|
|
b_in_z_formula = zL_formula[-1]
|
|
z_formula_rect = SurroundingRectangle(zL_formula)
|
|
b_in_z_rect = SurroundingRectangle(b_in_z_formula)
|
|
|
|
self.play(get_path_animation())
|
|
self.play(ShowCreation(dz_dw_rect))
|
|
self.play(FadeOut(dz_dw_rect))
|
|
self.play(
|
|
left_term_group.shift, DOWN,
|
|
left_term_group.fade, 1,
|
|
)
|
|
self.remove(left_term_group)
|
|
self.chain_rule_equation.remove(dz_dw)
|
|
self.chain_rule_rhs.remove(aLm1)
|
|
self.play(Transform(del_w, del_b))
|
|
self.play(FadeIn(dz_db))
|
|
self.play(get_path_animation())
|
|
self.wait()
|
|
self.play(ShowCreation(z_formula_rect))
|
|
self.wait()
|
|
self.play(ReplacementTransform(z_formula_rect, b_in_z_rect))
|
|
self.wait()
|
|
self.play(
|
|
ReplacementTransform(b_in_z_formula.copy(), one),
|
|
FadeOut(b_in_z_rect)
|
|
)
|
|
self.play(
|
|
ShowCreation(arrow),
|
|
ReplacementTransform(
|
|
dz_db.copy(), one,
|
|
path_arc = arrow.path_arc
|
|
)
|
|
)
|
|
self.wait(2)
|
|
self.play(*list(map(FadeOut, [dz_db, arrow, one])))
|
|
|
|
self.dz_db = dz_db
|
|
|
|
def show_derivative_wrt_a(self):
|
|
denom = self.chain_rule_equation[0][2]
|
|
numer = VGroup(*self.chain_rule_equation[0][:2])
|
|
del_aLm1 = TexMobject("\\partial a^{(L-1)}")
|
|
del_aLm1.scale(0.8)
|
|
del_aLm1.move_to(denom)
|
|
dz_daLm1 = TexMobject(
|
|
"{\\partial z^{(L)}", "\\over", "\\partial a^{(L-1)}}"
|
|
)
|
|
dz_daLm1.scale(0.8)
|
|
dz_daLm1.next_to(self.chain_rule_equation[1], RIGHT, SMALL_BUFF)
|
|
dz_daLm1.shift(0.7*SMALL_BUFF*UP)
|
|
dz_daLm1[0].set_color(self.z_color)
|
|
dz_daLm1_rect = SurroundingRectangle(dz_daLm1)
|
|
wL = self.zL_formula[2].copy()
|
|
wL.next_to(self.chain_rule_rhs[0], LEFT, SMALL_BUFF)
|
|
|
|
arrow = Arrow(
|
|
dz_daLm1.get_bottom(), wL.get_bottom(),
|
|
path_arc = np.pi/2,
|
|
color = WHITE,
|
|
)
|
|
|
|
comp_graph = self.comp_graph
|
|
path_to_C = VGroup(*[
|
|
getattr(comp_graph, attr)
|
|
for attr in ["aLm1", "zL", "aL", "C0"]
|
|
])
|
|
def get_path_animation():
|
|
return OldLaggedStart(
|
|
Indicate, path_to_C,
|
|
rate_func = there_and_back,
|
|
run_time = 1.5,
|
|
lag_ratio = 0.7,
|
|
)
|
|
|
|
zL_formula = self.zL_formula
|
|
z_formula_rect = SurroundingRectangle(zL_formula)
|
|
a_in_z_rect = SurroundingRectangle(VGroup(*zL_formula[2:4]))
|
|
wL_in_z = zL_formula[2]
|
|
|
|
for x in range(3):
|
|
self.play(get_path_animation())
|
|
self.play(
|
|
numer.shift, SMALL_BUFF*UP,
|
|
Transform(denom, del_aLm1)
|
|
)
|
|
self.play(
|
|
FadeIn(dz_daLm1),
|
|
VGroup(*self.chain_rule_equation[-2:]).shift, SMALL_BUFF*RIGHT,
|
|
)
|
|
self.wait()
|
|
self.play(ShowCreation(dz_daLm1_rect))
|
|
self.wait()
|
|
self.play(ReplacementTransform(
|
|
dz_daLm1_rect, z_formula_rect
|
|
))
|
|
self.wait()
|
|
self.play(ReplacementTransform(z_formula_rect, a_in_z_rect))
|
|
self.play(
|
|
ReplacementTransform(wL_in_z.copy(), wL),
|
|
FadeOut(a_in_z_rect)
|
|
)
|
|
self.play(
|
|
ShowCreation(arrow),
|
|
ReplacementTransform(
|
|
dz_daLm1.copy(), wL,
|
|
path_arc = arrow.path_arc
|
|
)
|
|
)
|
|
self.wait(2)
|
|
|
|
self.chain_rule_rhs.add(wL, arrow)
|
|
self.chain_rule_equation.add(dz_daLm1)
|
|
|
|
def show_previous_weight_and_bias(self):
|
|
to_fade = self.chain_rule_rhs
|
|
comp_graph = self.comp_graph
|
|
prev_comp_subgraph = self.prev_comp_subgraph
|
|
prev_comp_subgraph.scale(0.8)
|
|
prev_comp_subgraph.next_to(comp_graph, UP, SMALL_BUFF)
|
|
|
|
prev_layer = VGroup(
|
|
self.network_mob.layers[1],
|
|
self.network_mob.edge_groups[1],
|
|
)
|
|
for mob in prev_layer:
|
|
mob.restore()
|
|
prev_layer.next_to(self.last_neurons, LEFT, buff = 0)
|
|
self.remove(prev_layer)
|
|
|
|
self.play(OldLaggedStart(FadeOut, to_fade, run_time = 1))
|
|
self.play(
|
|
ShowCreation(prev_comp_subgraph, run_time = 1),
|
|
self.chain_rule_equation.to_edge, RIGHT
|
|
)
|
|
self.play(FadeIn(prev_layer))
|
|
|
|
###
|
|
neuron = self.network_mob.layers[1].neurons[0]
|
|
decimal = self.get_neuron_activation_decimal(neuron)
|
|
a_label = TexMobject("a^{(L-2)}")
|
|
a_label.replace(self.a_labels[1])
|
|
arrow = self.a_label_arrows[1].copy()
|
|
VGroup(a_label, arrow).shift(
|
|
neuron.get_center() - self.last_neurons[1].get_center()
|
|
)
|
|
|
|
self.play(
|
|
Write(a_label, run_time = 1),
|
|
Write(decimal, run_time = 1),
|
|
GrowArrow(arrow),
|
|
)
|
|
|
|
def animate_long_path(self):
|
|
comp_graph = self.comp_graph
|
|
path_to_C = VGroup(
|
|
self.wLm1, self.zLm1,
|
|
*[
|
|
getattr(comp_graph, attr)
|
|
for attr in ["aLm1", "zL", "aL", "C0"]
|
|
]
|
|
)
|
|
for x in range(2):
|
|
self.play(OldLaggedStart(
|
|
Indicate, path_to_C,
|
|
rate_func = there_and_back,
|
|
run_time = 1.5,
|
|
lag_ratio = 0.4,
|
|
))
|
|
self.wait(2)
|
|
|
|
###
|
|
|
|
def get_neuron_activation_decimal(self, neuron):
|
|
opacity = neuron.get_fill_opacity()
|
|
decimal = DecimalNumber(opacity, num_decimal_places = 2)
|
|
decimal.set_width(0.85*neuron.get_width())
|
|
if decimal.number > 0.8:
|
|
decimal.set_fill(BLACK)
|
|
decimal.move_to(neuron)
|
|
return decimal
|
|
|
|
def get_all_comp_graph_parts(self):
|
|
comp_graph = self.comp_graph
|
|
result = VGroup(comp_graph)
|
|
for attr in "wL", "zL", "aL", "C0":
|
|
sym = getattr(comp_graph, attr)
|
|
result.add(
|
|
sym.arrow, sym.number_line, sym.dot
|
|
)
|
|
del_sym = getattr(self, "del_" + attr)
|
|
result.add(del_sym, del_sym.brace)
|
|
|
|
self.all_comp_graph_parts = result
|
|
return result
|
|
|
|
class IsntThatOverSimplified(TeacherStudentsScene):
|
|
def construct(self):
|
|
self.student_says(
|
|
"Isn't that over-simplified?",
|
|
target_mode = "raise_right_hand",
|
|
run_time = 1
|
|
)
|
|
self.change_student_modes(
|
|
"pondering", "raise_right_hand", "pondering"
|
|
)
|
|
self.wait()
|
|
self.teacher_says(
|
|
"Not that much, actually!",
|
|
run_time = 1,
|
|
target_mode = "hooray"
|
|
)
|
|
self.wait(2)
|
|
|
|
class GeneralFormulas(SimplestNetworkExample):
|
|
CONFIG = {
|
|
"layer_sizes" : [3, 3, 2],
|
|
"network_mob_config" : {
|
|
"include_output_labels" : False,
|
|
"neuron_to_neuron_buff" : LARGE_BUFF,
|
|
"neuron_radius" : 0.3,
|
|
},
|
|
"edge_stroke_width" : 4,
|
|
"stroke_width_exp" : 0.2,
|
|
"random_seed" : 9,
|
|
}
|
|
def setup(self):
|
|
self.seed_random_libraries()
|
|
self.setup_bases()
|
|
|
|
def construct(self):
|
|
self.setup_network_mob()
|
|
self.show_all_a_labels()
|
|
self.only_show_abstract_a_labels()
|
|
self.add_desired_output()
|
|
self.show_cost()
|
|
self.show_example_weight()
|
|
self.show_values_between_weight_and_cost()
|
|
self.show_weight_chain_rule()
|
|
self.show_derivative_wrt_prev_activation()
|
|
self.show_multiple_paths_from_prev_layer_neuron()
|
|
self.show_previous_layer()
|
|
|
|
def setup_network_mob(self):
|
|
self.color_network_edges()
|
|
self.network_mob.to_edge(LEFT)
|
|
self.network_mob.shift(DOWN)
|
|
in_vect = np.random.random(self.layer_sizes[0])
|
|
self.network_mob.activate_layers(in_vect)
|
|
self.remove(self.network_mob.layers[0])
|
|
self.remove(self.network_mob.edge_groups[0])
|
|
|
|
def show_all_a_labels(self):
|
|
Lm1_neurons = self.network_mob.layers[-2].neurons
|
|
L_neurons = self.network_mob.layers[-1].neurons
|
|
all_arrows = VGroup()
|
|
all_labels = VGroup()
|
|
all_decimals = VGroup()
|
|
all_subscript_rects = VGroup()
|
|
for neurons in L_neurons, Lm1_neurons:
|
|
is_L = neurons is L_neurons
|
|
vect = LEFT if is_L else RIGHT
|
|
s = "L" if is_L else "L-1"
|
|
arrows = VGroup()
|
|
labels = VGroup()
|
|
decimals = VGroup()
|
|
subscript_rects = VGroup()
|
|
for i, neuron in enumerate(neurons):
|
|
arrow = Arrow(ORIGIN, vect)
|
|
arrow.next_to(neuron, -vect)
|
|
arrow.set_fill(WHITE)
|
|
label = TexMobject("a^{(%s)}_%d"%(s, i))
|
|
label.next_to(arrow, -vect, SMALL_BUFF)
|
|
rect = SurroundingRectangle(label[-1], buff = 0.5*SMALL_BUFF)
|
|
decimal = self.get_neuron_activation_decimal(neuron)
|
|
neuron.arrow = arrow
|
|
neuron.label = label
|
|
neuron.decimal = decimal
|
|
arrows.add(arrow)
|
|
labels.add(label)
|
|
decimals.add(decimal)
|
|
subscript_rects.add(rect)
|
|
all_arrows.add(arrows)
|
|
all_labels.add(labels)
|
|
all_decimals.add(decimals)
|
|
all_subscript_rects.add(subscript_rects)
|
|
|
|
start_labels, start_arrows = [
|
|
VGroup(*list(map(VGroup, [group[i][0] for i in (0, 1)]))).copy()
|
|
for group in (all_labels, all_arrows)
|
|
]
|
|
for label in start_labels:
|
|
label[0][-1].set_color(BLACK)
|
|
|
|
self.add(all_decimals)
|
|
self.play(*it.chain(
|
|
list(map(Write, start_labels)),
|
|
[GrowArrow(a[0]) for a in start_arrows]
|
|
))
|
|
self.wait()
|
|
self.play(
|
|
ReplacementTransform(start_labels, all_labels),
|
|
ReplacementTransform(start_arrows, all_arrows),
|
|
)
|
|
self.play(OldLaggedStart(
|
|
ShowCreationThenDestruction,
|
|
VGroup(*all_subscript_rects.family_members_with_points()),
|
|
lag_ratio = 0.7
|
|
))
|
|
self.wait()
|
|
|
|
self.set_variables_as_attrs(
|
|
L_neurons, Lm1_neurons,
|
|
all_arrows, all_labels,
|
|
all_decimals, all_subscript_rects,
|
|
)
|
|
|
|
def only_show_abstract_a_labels(self):
|
|
arrows_to_fade = VGroup()
|
|
labels_to_fade = VGroup()
|
|
labels_to_change = VGroup()
|
|
self.chosen_neurons = VGroup()
|
|
rects = VGroup()
|
|
for x, layer in enumerate(self.network_mob.layers[-2:]):
|
|
for y, neuron in enumerate(layer.neurons):
|
|
if (x == 0 and y == 1) or (x == 1 and y == 0):
|
|
tex = "k" if x == 0 else "j"
|
|
neuron.label.generate_target()
|
|
self.replace_subscript(neuron.label.target, tex)
|
|
self.chosen_neurons.add(neuron)
|
|
labels_to_change.add(neuron.label)
|
|
rects.add(SurroundingRectangle(
|
|
neuron.label.target[-1],
|
|
buff = 0.5*SMALL_BUFF
|
|
))
|
|
else:
|
|
labels_to_fade.add(neuron.label)
|
|
arrows_to_fade.add(neuron.arrow)
|
|
|
|
self.play(
|
|
OldLaggedStart(FadeOut, labels_to_fade),
|
|
OldLaggedStart(FadeOut, arrows_to_fade),
|
|
run_time = 1
|
|
)
|
|
for neuron, rect in zip(self.chosen_neurons, rects):
|
|
self.play(
|
|
MoveToTarget(neuron.label),
|
|
ShowCreation(rect)
|
|
)
|
|
self.play(FadeOut(rect))
|
|
self.wait()
|
|
self.wait()
|
|
|
|
def add_desired_output(self):
|
|
layer = self.network_mob.layers[-1]
|
|
desired_output = layer.deepcopy()
|
|
desired_output.shift(3*RIGHT)
|
|
desired_output_decimals = VGroup()
|
|
arrows = VGroup()
|
|
labels = VGroup()
|
|
for i, neuron in enumerate(desired_output.neurons):
|
|
neuron.set_fill(opacity = i)
|
|
decimal = self.get_neuron_activation_decimal(neuron)
|
|
neuron.decimal = decimal
|
|
neuron.arrow = Arrow(ORIGIN, LEFT, color = WHITE)
|
|
neuron.arrow.next_to(neuron, RIGHT)
|
|
neuron.label = TexMobject("y_%d"%i)
|
|
neuron.label.next_to(neuron.arrow, RIGHT)
|
|
neuron.label.set_color(self.desired_output_color)
|
|
|
|
desired_output_decimals.add(decimal)
|
|
arrows.add(neuron.arrow)
|
|
labels.add(neuron.label)
|
|
rect = SurroundingRectangle(desired_output, buff = 0.5*SMALL_BUFF)
|
|
words = TextMobject("Desired output")
|
|
words.next_to(rect, DOWN)
|
|
VGroup(words, rect).set_color(self.desired_output_color)
|
|
|
|
self.play(
|
|
FadeIn(rect),
|
|
FadeIn(words),
|
|
ReplacementTransform(layer.copy(), desired_output),
|
|
FadeIn(labels),
|
|
*[
|
|
ReplacementTransform(n1.decimal.copy(), n2.decimal)
|
|
for n1, n2 in zip(layer.neurons, desired_output.neurons)
|
|
] + list(map(GrowArrow, arrows))
|
|
)
|
|
self.wait()
|
|
|
|
self.set_variables_as_attrs(
|
|
desired_output,
|
|
desired_output_decimals,
|
|
desired_output_rect = rect,
|
|
desired_output_words = words,
|
|
)
|
|
|
|
def show_cost(self):
|
|
aj = self.chosen_neurons[1].label.copy()
|
|
yj = self.desired_output.neurons[0].label.copy()
|
|
|
|
cost_equation = TexMobject(
|
|
"C_0", "=", "\\sum_{j = 0}^{n_L - 1}",
|
|
"(", "a^{(L)}_j", "-", "y_j", ")", "^2"
|
|
)
|
|
cost_equation.to_corner(UP+RIGHT)
|
|
cost_equation[0].set_color(self.cost_color)
|
|
aj.target = cost_equation.get_part_by_tex("a^{(L)}_j")
|
|
yj.target = cost_equation.get_part_by_tex("y_j")
|
|
yj.target.set_color(self.desired_output_color)
|
|
to_fade_in = VGroup(*[m for m in cost_equation if m not in [aj.target, yj.target]])
|
|
sum_part = cost_equation.get_part_by_tex("sum")
|
|
|
|
self.play(*[
|
|
ReplacementTransform(mob, mob.target)
|
|
for mob in (aj, yj)
|
|
])
|
|
self.play(OldLaggedStart(FadeIn, to_fade_in))
|
|
self.wait(2)
|
|
self.play(OldLaggedStart(
|
|
Indicate, sum_part,
|
|
rate_func = wiggle,
|
|
))
|
|
self.wait()
|
|
for mob in aj.target, yj.target, cost_equation[-1]:
|
|
self.play(Indicate(mob))
|
|
self.wait()
|
|
|
|
self.set_variables_as_attrs(cost_equation)
|
|
|
|
def show_example_weight(self):
|
|
edges = self.network_mob.edge_groups[-1]
|
|
edge = self.chosen_neurons[1].edges_in[1]
|
|
faded_edges = VGroup(*[e for e in edges if e is not edge])
|
|
faded_edges.save_state()
|
|
for faded_edge in faded_edges:
|
|
faded_edge.save_state()
|
|
|
|
w_label = TexMobject("w^{(L)}_{jk}")
|
|
subscripts = VGroup(*w_label[-2:])
|
|
w_label.scale(1.2)
|
|
w_label.add_background_rectangle()
|
|
w_label.next_to(ORIGIN, UP, SMALL_BUFF)
|
|
w_label.rotate(edge.get_angle())
|
|
w_label.shift(edge.get_center())
|
|
w_label.set_color(BLUE)
|
|
|
|
edges.save_state()
|
|
edges.generate_target()
|
|
for e in edges.target:
|
|
e.rotate(-e.get_angle())
|
|
edges.target.arrange(DOWN)
|
|
edges.target.move_to(edges)
|
|
edges.target.to_edge(UP)
|
|
|
|
self.play(MoveToTarget(edges))
|
|
self.play(OldLaggedStart(
|
|
ApplyFunction, edges,
|
|
lambda e : (
|
|
lambda m : m.rotate_in_place(np.pi/12).set_color(YELLOW),
|
|
e
|
|
),
|
|
rate_func = wiggle
|
|
))
|
|
self.play(edges.restore)
|
|
self.play(faded_edges.fade, 0.9)
|
|
for neuron in self.chosen_neurons:
|
|
self.play(Indicate(neuron), Animation(neuron.decimal))
|
|
self.play(Write(w_label))
|
|
self.wait()
|
|
self.play(Indicate(subscripts))
|
|
for x in range(2):
|
|
self.play(Swap(*subscripts))
|
|
self.wait()
|
|
|
|
self.set_variables_as_attrs(faded_edges, w_label)
|
|
|
|
def show_values_between_weight_and_cost(self):
|
|
z_formula = TexMobject(
|
|
"z^{(L)}_j", "=",
|
|
"w^{(L)}_{j0}", "a^{(L-1)}_0", "+",
|
|
"w^{(L)}_{j1}", "a^{(L-1)}_1", "+",
|
|
"w^{(L)}_{j2}", "a^{(L-1)}_2", "+",
|
|
"b^{(L)}_j"
|
|
)
|
|
compact_z_formula = TexMobject(
|
|
"z^{(L)}_j", "=",
|
|
"\\cdots", "", "+"
|
|
"w^{(L)}_{jk}", "a^{(L-1)}_k", "+",
|
|
"\\cdots", "", "", "",
|
|
)
|
|
for expression in z_formula, compact_z_formula:
|
|
expression.to_corner(UP+RIGHT)
|
|
expression.set_color_by_tex_to_color_map({
|
|
"z^" : self.z_color,
|
|
"w^" : self.w_label.get_color(),
|
|
"b^" : MAROON_B,
|
|
})
|
|
w_part = z_formula.get_parts_by_tex("w^")[1]
|
|
aLm1_part = z_formula.get_parts_by_tex("a^{(L-1)}")[1]
|
|
|
|
a_formula = TexMobject(
|
|
"a^{(L)}_j", "=", "\\sigma(", "z^{(L)}_j", ")"
|
|
)
|
|
a_formula.set_color_by_tex("z^", self.z_color)
|
|
a_formula.next_to(z_formula, DOWN, MED_LARGE_BUFF)
|
|
a_formula.align_to(self.cost_equation, LEFT)
|
|
aL_part = a_formula[0]
|
|
|
|
to_fade = VGroup(
|
|
self.desired_output,
|
|
self.desired_output_decimals,
|
|
self.desired_output_rect,
|
|
self.desired_output_words,
|
|
*[
|
|
VGroup(n.arrow, n.label)
|
|
for n in self.desired_output.neurons
|
|
]
|
|
)
|
|
|
|
self.play(
|
|
FadeOut(to_fade),
|
|
self.cost_equation.next_to, a_formula, DOWN, MED_LARGE_BUFF,
|
|
self.cost_equation.to_edge, RIGHT,
|
|
ReplacementTransform(self.w_label[1].copy(), w_part),
|
|
ReplacementTransform(
|
|
self.chosen_neurons[0].label.copy(),
|
|
aLm1_part
|
|
),
|
|
)
|
|
self.play(Write(VGroup(*[m for m in z_formula if m not in [w_part, aLm1_part]])))
|
|
self.wait()
|
|
self.play(ReplacementTransform(
|
|
self.chosen_neurons[1].label.copy(),
|
|
aL_part
|
|
))
|
|
self.play(
|
|
Write(VGroup(*a_formula[1:3] + [a_formula[-1]])),
|
|
ReplacementTransform(
|
|
z_formula[0].copy(),
|
|
a_formula.get_part_by_tex("z^")
|
|
)
|
|
)
|
|
self.wait()
|
|
|
|
self.set_variables_as_attrs(z_formula, compact_z_formula, a_formula)
|
|
|
|
def show_weight_chain_rule(self):
|
|
chain_rule = self.get_chain_rule(
|
|
"{\\partial C_0", "\\over", "\\partial w^{(L)}_{jk}}",
|
|
"=",
|
|
"{\\partial z^{(L)}_j", "\\over", "\\partial w^{(L)}_{jk}}",
|
|
"{\\partial a^{(L)}_j", "\\over", "\\partial z^{(L)}_j}",
|
|
"{\\partial C_0", "\\over", "\\partial a^{(L)}_j}",
|
|
)
|
|
terms = VGroup(*[
|
|
VGroup(*chain_rule[i:i+3])
|
|
for i in range(4,len(chain_rule), 3)
|
|
])
|
|
rects = VGroup(*[
|
|
SurroundingRectangle(term, buff = 0.5*SMALL_BUFF)
|
|
for term in terms
|
|
])
|
|
rects.set_color_by_gradient(GREEN, WHITE, RED)
|
|
|
|
self.play(Transform(
|
|
self.z_formula, self.compact_z_formula
|
|
))
|
|
self.play(Write(chain_rule))
|
|
self.wait()
|
|
self.play(OldLaggedStart(
|
|
ShowCreationThenDestruction, rects,
|
|
lag_ratio = 0.7,
|
|
run_time = 3
|
|
))
|
|
self.wait()
|
|
|
|
self.set_variables_as_attrs(chain_rule)
|
|
|
|
def show_derivative_wrt_prev_activation(self):
|
|
chain_rule = self.get_chain_rule(
|
|
"{\\partial C_0", "\\over", "\\partial a^{(L-1)}_k}",
|
|
"=",
|
|
"\\sum_{j=0}^{n_L - 1}",
|
|
"{\\partial z^{(L)}_j", "\\over", "\\partial a^{(L-1)}_k}",
|
|
"{\\partial a^{(L)}_j", "\\over", "\\partial z^{(L)}_j}",
|
|
"{\\partial C_0", "\\over", "\\partial a^{(L)}_j}",
|
|
)
|
|
formulas = VGroup(self.z_formula, self.a_formula, self.cost_equation)
|
|
|
|
n = chain_rule.index_of_part_by_tex("sum")
|
|
self.play(ReplacementTransform(
|
|
self.chain_rule, VGroup(*chain_rule[:n] + chain_rule[n+1:])
|
|
))
|
|
self.play(Write(chain_rule[n], run_time = 1))
|
|
self.wait()
|
|
|
|
self.set_variables_as_attrs(chain_rule)
|
|
|
|
def show_multiple_paths_from_prev_layer_neuron(self):
|
|
neurons = self.network_mob.layers[-1].neurons
|
|
labels, arrows, decimals = [
|
|
VGroup(*[getattr(n, attr) for n in neurons])
|
|
for attr in ("label", "arrow", "decimal")
|
|
]
|
|
edges = VGroup(*[n.edges_in[1] for n in neurons])
|
|
labels[0].generate_target()
|
|
self.replace_subscript(labels[0].target, "0")
|
|
|
|
paths = [
|
|
VGroup(
|
|
self.chosen_neurons[0].label,
|
|
self.chosen_neurons[0].arrow,
|
|
self.chosen_neurons[0],
|
|
self.chosen_neurons[0].decimal,
|
|
edges[i],
|
|
neurons[i],
|
|
decimals[i],
|
|
arrows[i],
|
|
labels[i],
|
|
)
|
|
for i in range(2)
|
|
]
|
|
path_lines = VGroup()
|
|
for path in paths:
|
|
points = [path[0].get_center()]
|
|
for mob in path[1:]:
|
|
if isinstance(mob, DecimalNumber):
|
|
continue
|
|
points.append(mob.get_center())
|
|
path_line = VMobject()
|
|
path_line.set_points_as_corners(points)
|
|
path_lines.add(path_line)
|
|
path_lines.set_color(YELLOW)
|
|
|
|
chain_rule = self.chain_rule
|
|
n = chain_rule.index_of_part_by_tex("sum")
|
|
brace = Brace(VGroup(*chain_rule[n:]), DOWN, buff = SMALL_BUFF)
|
|
words = brace.get_text("Sum over layer L", buff = SMALL_BUFF)
|
|
|
|
cost_aL = self.cost_equation.get_part_by_tex("a^{(L)}")
|
|
|
|
self.play(
|
|
MoveToTarget(labels[0]),
|
|
FadeIn(labels[1]),
|
|
GrowArrow(arrows[1]),
|
|
edges[1].restore,
|
|
FadeOut(self.w_label),
|
|
)
|
|
for x in range(5):
|
|
anims = [
|
|
ShowCreationThenDestruction(
|
|
path_line,
|
|
run_time = 1.5,
|
|
time_width = 0.5,
|
|
)
|
|
for path_line in path_lines
|
|
]
|
|
if x == 2:
|
|
anims += [
|
|
FadeIn(words),
|
|
GrowFromCenter(brace)
|
|
]
|
|
self.play(*anims)
|
|
self.wait()
|
|
for path, path_line in zip(paths, path_lines):
|
|
label = path[-1]
|
|
self.play(
|
|
OldLaggedStart(
|
|
Indicate, path,
|
|
rate_func = wiggle,
|
|
run_time = 1,
|
|
),
|
|
ShowCreation(path_line),
|
|
Animation(label)
|
|
)
|
|
self.wait()
|
|
group = VGroup(label, cost_aL)
|
|
self.play(
|
|
group.shift, MED_SMALL_BUFF*UP,
|
|
rate_func = wiggle
|
|
)
|
|
self.play(FadeOut(path_line))
|
|
self.wait()
|
|
|
|
def show_previous_layer(self):
|
|
mid_neurons = self.network_mob.layers[1].neurons
|
|
layer = self.network_mob.layers[0]
|
|
edges = self.network_mob.edge_groups[0]
|
|
faded_edges = self.faded_edges
|
|
to_fade = VGroup(
|
|
self.chosen_neurons[0].label,
|
|
self.chosen_neurons[0].arrow,
|
|
)
|
|
for neuron in layer.neurons:
|
|
neuron.add(self.get_neuron_activation_decimal(neuron))
|
|
|
|
all_edges_out = VGroup(*[
|
|
VGroup(*[n.edges_in[i] for n in mid_neurons]).copy()
|
|
for i in range(len(layer.neurons))
|
|
])
|
|
all_edges_out.set_stroke(YELLOW, 3)
|
|
|
|
deriv = VGroup(*self.chain_rule[:3])
|
|
deriv_rect = SurroundingRectangle(deriv)
|
|
mid_neuron_outlines = mid_neurons.copy()
|
|
mid_neuron_outlines.set_fill(opacity = 0)
|
|
mid_neuron_outlines.set_stroke(YELLOW, 5)
|
|
|
|
def get_neurons_decimal_anims(neuron):
|
|
return [
|
|
ChangingDecimal(
|
|
neuron.decimal,
|
|
lambda a : neuron.get_fill_opacity(),
|
|
),
|
|
UpdateFromFunc(
|
|
neuron.decimal,
|
|
lambda m : m.set_fill(
|
|
WHITE if neuron.get_fill_opacity() < 0.8 else BLACK
|
|
)
|
|
)
|
|
]
|
|
|
|
self.play(ShowCreation(deriv_rect))
|
|
self.play(OldLaggedStart(
|
|
ShowCreationThenDestruction,
|
|
mid_neuron_outlines
|
|
))
|
|
self.play(*it.chain(*[
|
|
[
|
|
ApplyMethod(n.set_fill, None, random.random()),
|
|
] + get_neurons_decimal_anims(n)
|
|
for n in mid_neurons
|
|
]), run_time = 4, rate_func = there_and_back)
|
|
self.play(faded_edges.restore)
|
|
self.play(
|
|
OldLaggedStart(
|
|
GrowFromCenter, layer.neurons,
|
|
run_time = 1
|
|
),
|
|
OldLaggedStart(ShowCreation, edges),
|
|
FadeOut(to_fade)
|
|
)
|
|
for x in range(3):
|
|
for edges_out in all_edges_out:
|
|
self.play(ShowCreationThenDestruction(edges_out))
|
|
self.wait()
|
|
|
|
####
|
|
|
|
def replace_subscript(self, label, tex):
|
|
subscript = label[-1]
|
|
new_subscript = TexMobject(tex)[0]
|
|
new_subscript.replace(subscript, dim_to_match = 1)
|
|
label.remove(subscript)
|
|
label.add(new_subscript)
|
|
return label
|
|
|
|
def get_chain_rule(self, *tex):
|
|
chain_rule = TexMobject(*tex)
|
|
chain_rule.scale(0.8)
|
|
chain_rule.to_corner(UP+LEFT)
|
|
chain_rule.set_color_by_tex_to_color_map({
|
|
"C_0" : self.cost_color,
|
|
"z^" : self.z_color,
|
|
"w^" : self.w_label.get_color()
|
|
})
|
|
return chain_rule
|
|
|
|
class ThatsPrettyMuchIt(TeacherStudentsScene):
|
|
def construct(self):
|
|
self.teacher_says(
|
|
"That's pretty \\\\ much it!",
|
|
target_mode = "hooray",
|
|
run_time = 1,
|
|
)
|
|
self.wait(2)
|
|
|
|
class PatYourselfOnTheBack(TeacherStudentsScene):
|
|
def construct(self):
|
|
self.teacher_says(
|
|
"Pat yourself on \\\\ the back!",
|
|
target_mode = "hooray"
|
|
)
|
|
self.change_student_modes(*["hooray"]*3)
|
|
self.wait(3)
|
|
|
|
class ThatsALotToThinkAbout(TeacherStudentsScene):
|
|
def construct(self):
|
|
self.teacher_says(
|
|
"That's a lot to \\\\ think about!",
|
|
target_mode = "surprised"
|
|
)
|
|
self.change_student_modes(*["thinking"]*3)
|
|
self.wait(4)
|
|
|
|
class LayersOfComplexity(Scene):
|
|
def construct(self):
|
|
chain_rule_equations = self.get_chain_rule_equations()
|
|
chain_rule_equations.to_corner(UP+RIGHT)
|
|
|
|
brace = Brace(chain_rule_equations, LEFT)
|
|
arrow = Vector(LEFT, color = RED)
|
|
arrow.next_to(brace, LEFT)
|
|
gradient = TexMobject("\\nabla C")
|
|
gradient.scale(2)
|
|
gradient.set_color(RED)
|
|
gradient.next_to(arrow, LEFT)
|
|
|
|
self.play(OldLaggedStart(FadeIn, chain_rule_equations))
|
|
self.play(GrowFromCenter(brace))
|
|
self.play(GrowArrow(arrow))
|
|
self.play(Write(gradient))
|
|
self.wait()
|
|
|
|
|
|
def get_chain_rule_equations(self):
|
|
w_deriv = TexMobject(
|
|
"{\\partial C", "\\over", "\\partial w^{(l)}_{jk}}",
|
|
"=",
|
|
"a^{(l-1)}_k",
|
|
"\\sigma'(z^{(l)}_j)",
|
|
"{\\partial C", "\\over", "\\partial a^{(l)}_j}",
|
|
)
|
|
lil_rect = SurroundingRectangle(
|
|
VGroup(*w_deriv[-3:]),
|
|
buff = 0.5*SMALL_BUFF
|
|
)
|
|
a_deriv = TexMobject(
|
|
"\\sum_{j = 0}^{n_{l+1} - 1}",
|
|
"w^{(l+1)}_{jk}",
|
|
"\\sigma'(z^{(l+1)}_j)",
|
|
"{\\partial C", "\\over", "\\partial a^{(l+1)}_j}",
|
|
)
|
|
or_word = TextMobject("or")
|
|
last_a_deriv = TexMobject("2(a^{(L)}_j - y_j)")
|
|
|
|
a_deriv.next_to(w_deriv, DOWN, LARGE_BUFF)
|
|
or_word.next_to(a_deriv, DOWN)
|
|
last_a_deriv.next_to(or_word, DOWN, MED_LARGE_BUFF)
|
|
|
|
big_rect = SurroundingRectangle(VGroup(a_deriv, last_a_deriv))
|
|
arrow = Arrow(
|
|
lil_rect.get_corner(DOWN+LEFT),
|
|
big_rect.get_top(),
|
|
)
|
|
|
|
result = VGroup(
|
|
w_deriv, lil_rect, arrow,
|
|
big_rect, a_deriv, or_word, last_a_deriv
|
|
)
|
|
for expression in w_deriv, a_deriv, last_a_deriv:
|
|
expression.set_color_by_tex_to_color_map({
|
|
"C" : RED,
|
|
"z^" : GREEN,
|
|
"w^" : BLUE,
|
|
"b^" : MAROON_B,
|
|
})
|
|
return result
|
|
|
|
class SponsorFrame(PiCreatureScene):
|
|
def construct(self):
|
|
morty = self.pi_creature
|
|
screen = ScreenRectangle(height = 5)
|
|
screen.to_corner(UP+LEFT)
|
|
url = TextMobject("http://3b1b.co/crowdflower")
|
|
url.move_to(screen, UP+LEFT)
|
|
screen.shift(LARGE_BUFF*DOWN)
|
|
arrow = Arrow(LEFT, RIGHT, color = WHITE)
|
|
arrow.next_to(url, RIGHT)
|
|
|
|
t_shirt_words = TextMobject("Free T-Shirt")
|
|
t_shirt_words.scale(1.5)
|
|
t_shirt_words.set_color(YELLOW)
|
|
t_shirt_words.next_to(morty, UP, aligned_edge = RIGHT)
|
|
|
|
human_in_the_loop = TextMobject("Human-in-the-loop approach")
|
|
human_in_the_loop.next_to(screen, DOWN)
|
|
|
|
self.play(
|
|
morty.change, "hooray", t_shirt_words,
|
|
Write(t_shirt_words, run_time = 2)
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
morty.change, "raise_right_hand", screen,
|
|
ShowCreation(screen)
|
|
)
|
|
self.play(
|
|
t_shirt_words.scale, 1./1.5,
|
|
t_shirt_words.next_to, arrow, RIGHT
|
|
)
|
|
self.play(Write(url))
|
|
self.play(GrowArrow(arrow))
|
|
self.wait(2)
|
|
self.play(morty.change, "thinking", url)
|
|
self.wait(3)
|
|
self.play(Write(human_in_the_loop))
|
|
self.play(morty.change, "happy", url)
|
|
self.play(morty.look_at, screen)
|
|
self.wait(7)
|
|
t_shirt_words_outline = t_shirt_words.copy()
|
|
t_shirt_words_outline.set_fill(opacity = 0)
|
|
t_shirt_words_outline.set_stroke(GREEN, 3)
|
|
self.play(
|
|
morty.change, "hooray", t_shirt_words,
|
|
OldLaggedStart(ShowCreation, t_shirt_words_outline),
|
|
)
|
|
self.play(FadeOut(t_shirt_words_outline))
|
|
self.play(OldLaggedStart(
|
|
Indicate, url,
|
|
rate_func = wiggle,
|
|
color = PINK,
|
|
run_time = 3
|
|
))
|
|
self.wait(3)
|
|
|
|
class NN3PatreonThanks(PatreonThanks):
|
|
CONFIG = {
|
|
"specific_patrons" : [
|
|
"Randall Hunt",
|
|
"Burt Humburg",
|
|
"CrypticSwarm",
|
|
"Juan Benet",
|
|
"David Kedmey",
|
|
"Michael Hardwicke",
|
|
"Nathan Weeks",
|
|
"Marcus Schiebold",
|
|
"Ali Yahya",
|
|
"William",
|
|
"Mayank M. Mehrotra",
|
|
"Lukas Biewald",
|
|
"Samantha D. Suplee",
|
|
"Yana Chernobilsky",
|
|
"Kaustuv DeBiswas",
|
|
"Kathryn Schmiedicke",
|
|
"Yu Jun",
|
|
"Dave Nicponski",
|
|
"Damion Kistler",
|
|
"Markus Persson",
|
|
"Yoni Nazarathy",
|
|
"Ed Kellett",
|
|
"Joseph John Cox",
|
|
"Luc Ritchie",
|
|
"1stViewMaths",
|
|
"Jacob Magnuson",
|
|
"Mark Govea",
|
|
"Dagan Harrington",
|
|
"Clark Gaebel",
|
|
"Eric Chow",
|
|
"Mathias Jansson",
|
|
"Robert Teed",
|
|
"Pedro Perez Sanchez",
|
|
"David Clark",
|
|
"Michael Gardner",
|
|
"Harsev Singh",
|
|
"Mads Elvheim",
|
|
"Erik Sundell",
|
|
"Xueqi Li",
|
|
"Dr. David G. Stork",
|
|
"Tianyu Ge",
|
|
"Ted Suzman",
|
|
"Linh Tran",
|
|
"Andrew Busey",
|
|
"John Haley",
|
|
"Ankalagon",
|
|
"Eric Lavault",
|
|
"Boris Veselinovich",
|
|
"Julian Pulgarin",
|
|
"Jeff Linse",
|
|
"Cooper Jones",
|
|
"Ryan Dahl",
|
|
"Jason Hise",
|
|
"Meshal Alshammari",
|
|
"Bernd Sing",
|
|
"Mustafa Mahdi",
|
|
"Mathew Bramson",
|
|
"Jerry Ling",
|
|
"Vecht",
|
|
"Shimin Kuang",
|
|
"Rish Kundalia",
|
|
"Achille Brighton",
|
|
"Ripta Pasay",
|
|
],
|
|
"max_patron_group_size" : 25,
|
|
"patron_scale_val" : 0.7,
|
|
}
|
|
|
|
class Thumbnail(PreviewLearning):
|
|
CONFIG = {
|
|
"layer_sizes" : [8, 6, 6, 4],
|
|
"network_mob_config" : {
|
|
"neuron_radius" : 0.3,
|
|
"neuron_to_neuron_buff" : MED_SMALL_BUFF,
|
|
"include_output_labels" : False,
|
|
},
|
|
"stroke_width_exp" : 1,
|
|
"max_stroke_width" : 5,
|
|
"title" : "Backpropagation",
|
|
"network_scale_val" : 0.8,
|
|
}
|
|
def construct(self):
|
|
self.color_network_edges()
|
|
network_mob = self.network_mob
|
|
network_mob.scale(
|
|
self.network_scale_val,
|
|
about_point = network_mob.get_bottom()
|
|
)
|
|
network_mob.activate_layers(np.random.random(self.layer_sizes[0]))
|
|
|
|
for edge in it.chain(*network_mob.edge_groups):
|
|
arrow = Arrow(
|
|
edge.get_end(), edge.get_start(),
|
|
buff = 0,
|
|
tip_length = 0.1,
|
|
color = edge.get_color()
|
|
)
|
|
network_mob.add(arrow.tip)
|
|
|
|
arrow = Vector(
|
|
3*LEFT,
|
|
tip_length = 0.75,
|
|
rectangular_stem_width = 0.2,
|
|
color = BLUE,
|
|
)
|
|
arrow.next_to(network_mob.edge_groups[1], UP, MED_LARGE_BUFF)
|
|
|
|
network_mob.add(arrow)
|
|
self.add(network_mob)
|
|
|
|
title = TextMobject(self.title)
|
|
title.scale(2)
|
|
title.to_edge(UP)
|
|
self.add(title)
|
|
|
|
class SupplementThumbnail(Thumbnail):
|
|
CONFIG = {
|
|
"title" : "Backpropagation \\\\ calculus",
|
|
"network_scale_val" : 0.7,
|
|
}
|
|
def construct(self):
|
|
Thumbnail.construct(self)
|
|
self.network_mob.to_edge(DOWN, buff = MED_SMALL_BUFF)
|
|
|
|
for layer in self.network_mob.layers:
|
|
for neuron in layer.neurons:
|
|
partial = TexMobject("\\partial")
|
|
partial.move_to(neuron)
|
|
self.remove(neuron)
|
|
self.add(partial)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|