Up to weight matrix introduction of nn

This commit is contained in:
Grant Sanderson 2017-09-29 14:17:13 -07:00
parent 2d20084573
commit 7f69b6aa93
4 changed files with 656 additions and 14 deletions

View file

@ -36,6 +36,8 @@ python extract_scene.py -p example_scenes.py SquareToCircle
`-p` gives a preview of an animation, `-w` will write it to a file, and `-s` will show/save the final image in the animation.
You will probably want to change the MOVIE_DIR constant to be whatever direction you want video files to output to.
Look through the old_projects folder to see the code for previous 3b1b videos.
While developing a scene, the `-s` flag is helpful to just see what things look like at the end without having to generate the full animation. It can also be helpful to put `self.force_skipping()` at the top of the construct method, and `self.revert_to_original_skipping_status()` before the portion of the scene that you want to test, and run with the `-p` flag to just see a preview of one part of the scene.

View file

@ -330,6 +330,10 @@ class LaggedStart(Animation):
anim.update(alpha)
return self
def clean_up(self, *args, **kwargs):
for anim in self.subanimations:
anim.clean_up(*args, **kwargs)
class DelayByOrder(Animation):
"""
Modifier of animation.

View file

@ -153,6 +153,8 @@ class NetworkMobject(VGroup):
radius = self.neuron_radius,
stroke_color = self.neuron_stroke_color,
stroke_width = self.neuron_stroke_width,
fill_color = self.neuron_fill_color,
fill_opacity = 0,
)
for x in range(n_neurons)
])
@ -277,13 +279,17 @@ class MNistNetworkMobject(NetworkMobject):
class NetworkScene(Scene):
CONFIG = {
"layer_sizes" : [8, 6, 6, 4],
"network_mob_config" : {},
}
def setup(self):
self.add_network()
def add_network(self):
self.network = Network(sizes = self.layer_sizes)
self.network_mob = NetworkMobject(self.network)
self.network_mob = NetworkMobject(
self.network,
**self.network_mob_config
)
self.add(self.network_mob)
def feed_forward(self, input_vector, false_confidence = False, added_anims = None):
@ -731,12 +737,12 @@ class PreviewMNistNetwork(NetworkScene):
def feed_in_image(self, in_vect):
image = PixelsFromVect(in_vect)
image.next_to(self.network_mob, LEFT, LARGE_BUFF, UP)
big_rect = SurroundingRectangle(image, color = BLUE)
image_rect = SurroundingRectangle(image, color = BLUE)
start_neurons = self.network_mob.layers[0].neurons.copy()
start_neurons.set_stroke(WHITE, width = 0)
start_neurons.set_fill(WHITE, 0)
self.play(FadeIn(image), FadeIn(big_rect))
self.play(FadeIn(image), FadeIn(image_rect))
self.feed_forward(in_vect, added_anims = [
self.get_image_to_layer_one_animation(image, start_neurons)
])
@ -749,10 +755,13 @@ class PreviewMNistNetwork(NetworkScene):
self.network_mob.output_labels[n],
))
self.play(ShowCreation(rect))
self.play(FadeOut(rect))
self.reset_display(rect, image, image_rect)
def reset_display(self, answer_rect, image, image_rect):
self.play(FadeOut(answer_rect))
self.play(
FadeOut(image),
FadeOut(big_rect),
FadeOut(image_rect),
self.network_mob.deactivate_layers,
)
@ -2787,23 +2796,33 @@ class ContinualEdgeUpdate(ContinualAnimation):
def __init__(self, network_mob, **kwargs):
n_cycles = 5
edges = VGroup(*it.chain(*network_mob.edge_groups))
self.move_to_targets = []
for edge in edges:
edge.colors = [edge.get_color()] + [
edge.colors = [
random.choice([GREEN, GREEN, GREEN, RED])
for x in range(n_cycles)
]
edge.widths = [edge.get_stroke_width()] + [
edge.widths = [
3*random.random()**7
for x in range(n_cycles)
]
edge.cycle_time = 1 + random.random()
edge.generate_target()
edge.target.set_stroke(edge.colors[0], edge.widths[0])
self.move_to_targets.append(MoveToTarget(edge))
self.edges = edges
ContinualAnimation.__init__(self, edges, **kwargs)
def update_mobject(self, dt):
if self.internal_time < 1:
alpha = smooth(self.internal_time)
for move_to_target in self.move_to_targets:
move_to_target.update(alpha)
return
for edge in self.edges:
t = self.internal_time/edge.cycle_time
alpha = (self.internal_time%edge.cycle_time)/edge.cycle_time
t = (self.internal_time-1)/edge.cycle_time
alpha = ((self.internal_time-1)%edge.cycle_time)/edge.cycle_time
low_n = int(t)%len(edge.colors)
high_n = int(t+1)%len(edge.colors)
color = interpolate_color(edge.colors[low_n], edge.colors[high_n], alpha)
@ -3029,6 +3048,617 @@ class ShowRemainingNetwork(IntroduceWeights):
*added_anims
)
class ImagineSettingByHand(Scene):
def construct(self):
randy = Randolph()
randy.scale(0.7)
randy.to_corner(DOWN+LEFT)
bubble = randy.get_bubble()
network_mob = NetworkMobject(
Network(sizes = [8, 6, 6, 4]),
neuron_stroke_color = WHITE
)
network_mob.scale(0.7)
network_mob.move_to(bubble.get_bubble_center())
network_mob.shift(MED_SMALL_BUFF*RIGHT + SMALL_BUFF*(UP+RIGHT))
self.add(randy, bubble, network_mob)
self.add(ContinualEdgeUpdate(network_mob))
self.play(randy.change, "pondering")
self.dither()
self.play(Blink(randy))
self.dither()
self.play(randy.change, "horrified", network_mob)
self.play(Blink(randy))
self.dither(10)
class WhenTheNetworkFails(MoreHonestMNistNetworkPreview):
CONFIG = {
"network_mob_config" : {"layer_to_layer_buff" : 2}
}
def construct(self):
self.setup_network_mob()
self.black_box()
self.incorrect_classification()
self.ask_about_weights()
def setup_network_mob(self):
self.network_mob.scale(0.8)
self.network_mob.to_edge(DOWN)
def black_box(self):
network_mob = self.network_mob
layers = VGroup(*network_mob.layers[1:3])
box = SurroundingRectangle(
layers,
stroke_color = WHITE,
fill_color = BLACK,
fill_opacity = 0.8,
)
words = TextMobject("...rather than treating this as a black box")
words.next_to(box, UP, LARGE_BUFF)
self.play(
Write(words, run_time = 2),
DrawBorderThenFill(box)
)
self.dither()
self.play(*map(FadeOut, [words, box]))
def incorrect_classification(self):
network = self.network
training_data, validation_data, test_data = load_data_wrapper()
for in_vect, result in test_data[20:]:
network_answer = np.argmax(network.feedforward(in_vect))
if network_answer != result:
break
self.feed_in_image(in_vect)
wrong = TextMobject("Wrong!")
wrong.highlight(RED)
wrong.next_to(self.network_mob.layers[-1], UP+RIGHT)
self.play(Write(wrong, run_time = 1))
def ask_about_weights(self):
question = TextMobject(
"What weights are used here?\\\\",
"What are they doing?"
)
question.next_to(self.network_mob, UP)
self.add(ContinualEdgeUpdate(self.network_mob))
self.play(Write(question))
self.dither(10)
###
def reset_display(self, *args):
pass
class EvenWhenItWorks(TeacherStudentsScene):
def construct(self):
self.teacher_says(
"Even when it works,\\\\",
"dig into why."
)
self.change_student_modes(*["pondering"]*3)
self.dither(7)
class IntroduceWeightMatrix(NetworkScene):
CONFIG = {
"network_mob_config" : {
"neuron_stroke_color" : WHITE,
"neuron_fill_color" : WHITE,
"neuron_radius" : 0.35,
"layer_to_layer_buff" : 2,
},
"layer_sizes" : [8, 6],
}
def construct(self):
self.setup_network_mob()
self.show_weighted_sum()
self.organize_activations_into_column()
self.organize_weights_as_matrix()
self.show_meaning_of_matrix_row()
self.connect_weighted_sum_to_matrix_multiplication()
self.add_bias_vector()
self.apply_sigmoid()
self.write_clean_final_expression()
def setup_network_mob(self):
self.network_mob.to_edge(LEFT, buff = LARGE_BUFF)
self.network_mob.layers[1].neurons.shift(0.02*RIGHT)
def show_weighted_sum(self):
self.fade_many_neurons()
self.activate_first_layer()
self.show_first_neuron_weighted_sum()
self.add_bias()
self.add_sigmoid()
self.dither()
##
def fade_many_neurons(self):
anims = []
neurons = self.network_mob.layers[1].neurons
for neuron in neurons[1:]:
neuron.save_state()
neuron.edges_in.save_state()
anims += [
neuron.fade, 0.8,
neuron.set_fill, None, 0,
neuron.edges_in.fade, 0.8,
]
anims += [
Animation(neurons[0]),
Animation(neurons[0].edges_in),
]
self.play(*anims)
def activate_first_layer(self):
layer = self.network_mob.layers[0]
activations = 0.7*np.random.random(len(layer.neurons))
active_layer = self.network_mob.get_active_layer(0, activations)
a_labels = VGroup(*[
TexMobject("a^{(0)}_%d"%d)
for d in range(len(layer.neurons))
])
for label, neuron in zip(a_labels, layer.neurons):
label.scale(0.75)
label.move_to(neuron)
self.play(
Transform(layer, active_layer),
Write(a_labels, run_time = 2)
)
self.a_labels = a_labels
def show_first_neuron_weighted_sum(self):
neuron = self.network_mob.layers[1].neurons[0]
a_labels = VGroup(*self.a_labels[:2]).copy()
a_labels.generate_target()
w_labels = VGroup(*[
TexMobject("w_{0, %d}"%d)
for d in range(len(a_labels))
])
weighted_sum = VGroup()
symbols = VGroup()
for a_label, w_label in zip(a_labels.target, w_labels):
a_label.scale(1./0.75)
plus = TexMobject("+")
weighted_sum.add(w_label, a_label, plus)
symbols.add(plus)
weighted_sum.add(
TexMobject("\\cdots"),
TexMobject("+"),
TexMobject("w_{0, n}"),
TexMobject("a^{(0)}_n"),
)
weighted_sum.arrange_submobjects(RIGHT)
a1_label = TexMobject("a^{(1)}_0")
a1_label.next_to(neuron, RIGHT)
equals = TexMobject("=").next_to(a1_label, RIGHT)
weighted_sum.next_to(equals, RIGHT)
symbols.add(*weighted_sum[-4:-2])
w_labels.add(weighted_sum[-2])
a_labels.add(self.a_labels[-1].copy())
a_labels.target.add(weighted_sum[-1])
a_labels.add(VGroup(*self.a_labels[2:-1]).copy())
a_labels.target.add(VectorizedPoint(weighted_sum[-4].get_center()))
VGroup(a1_label, equals, weighted_sum).scale(
0.75, about_point = a1_label.get_left()
)
w_labels.highlight(GREEN)
w_labels.shift(0.6*SMALL_BUFF*DOWN)
a_labels.target.shift(0.5*SMALL_BUFF*UP)
self.play(
Write(a1_label),
Write(equals),
neuron.set_fill, None, 0.3,
run_time = 1
)
self.play(MoveToTarget(a_labels, run_time = 1.5))
self.play(
Write(w_labels),
Write(symbols),
)
self.a1_label = a1_label
self.a1_equals = equals
self.w_labels = w_labels
self.a_labels_in_sum = a_labels
self.symbols = symbols
self.weighted_sum = VGroup(w_labels, a_labels, symbols)
def add_bias(self):
weighted_sum = self.weighted_sum
bias = TexMobject("+\\,", "b_0")
bias.scale(0.75)
bias.next_to(weighted_sum, RIGHT, SMALL_BUFF)
bias.shift(0.5*SMALL_BUFF*DOWN)
name = TextMobject("Bias")
name.scale(0.75)
name.next_to(bias, DOWN, MED_LARGE_BUFF)
arrow = Arrow(name, bias, buff = SMALL_BUFF)
VGroup(name, arrow, bias).highlight(BLUE)
self.play(
FadeIn(name),
FadeIn(bias),
GrowArrow(arrow),
)
self.weighted_sum.add(bias)
self.bias = bias
self.bias_name = VGroup(name, arrow)
def add_sigmoid(self):
weighted_sum = self.weighted_sum
weighted_sum.generate_target()
sigma, lp, rp = mob = TexMobject("\\sigma\\big(\\big)")
# mob.scale(0.75)
sigma.move_to(weighted_sum.get_left())
sigma.shift(0.5*SMALL_BUFF*(DOWN+RIGHT))
lp.next_to(sigma, RIGHT, SMALL_BUFF)
weighted_sum.target.next_to(lp, RIGHT, SMALL_BUFF)
rp.next_to(weighted_sum.target, RIGHT, SMALL_BUFF)
name = TextMobject("Sigmoid")
name.next_to(sigma, UP, MED_LARGE_BUFF)
arrow = Arrow(name, sigma, buff = SMALL_BUFF)
sigmoid_name = VGroup(name, arrow)
VGroup(sigmoid_name, mob).highlight(YELLOW)
self.play(
FadeIn(mob),
MoveToTarget(weighted_sum),
MaintainPositionRelativeTo(self.bias_name, self.bias),
)
self.play(FadeIn(sigmoid_name))
self.sigma = sigma
self.sigma_parens = VGroup(lp, rp)
self.sigmoid_name = sigmoid_name
##
def organize_activations_into_column(self):
a_labels = self.a_labels.copy()
a_labels.generate_target()
column = a_labels.target
a_labels_in_sum = self.a_labels_in_sum
dots = TexMobject("\\vdots")
mid_as = VGroup(*column[2:-1])
Transform(mid_as, dots).update(1)
last_a = column[-1]
new_last_a = TexMobject(
last_a.get_tex_string().replace("7", "n")
)
new_last_a.replace(last_a)
Transform(last_a, new_last_a).update(1)
VGroup(
*column[:2] + [mid_as] + [column[-1]]
).arrange_submobjects(DOWN)
column.shift(DOWN + 3.5*RIGHT)
pre_brackets = self.get_brackets(a_labels)
post_bracketes = self.get_brackets(column)
pre_brackets.set_fill(opacity = 0)
self.play(LaggedStart(
Indicate, self.a_labels,
rate_func = there_and_back
))
self.play(
MoveToTarget(a_labels),
Transform(pre_brackets, post_bracketes),
run_time = 2
)
self.dither()
self.play(*[
LaggedStart(Indicate, mob, rate_func = there_and_back)
for mob in a_labels, a_labels_in_sum
])
self.dither()
self.a_column = a_labels
self.a_column_brackets = pre_brackets
def organize_weights_as_matrix(self):
a_column = self.a_column
a_column_brackets = self.a_column_brackets
w_brackets = a_column_brackets.copy()
w_brackets.next_to(a_column_brackets, LEFT, SMALL_BUFF)
lwb, rwb = w_brackets
w_labels = self.w_labels.copy()
w_labels.submobjects.insert(
2, self.symbols[-2].copy()
)
w_labels.generate_target()
w_labels.target.arrange_submobjects(RIGHT)
w_labels.target.next_to(a_column[0], LEFT, buff = 0.8)
lwb.next_to(w_labels.target, LEFT, SMALL_BUFF)
lwb.align_to(rwb, UP)
row_1, row_k = [
VGroup(*map(TexMobject, [
"w_{%s, 0}"%i,
"w_{%s, 1}"%i,
"\\cdots",
"w_{%s, k}"%i,
]))
for i in "1", "n"
]
dots_row = VGroup(*map(TexMobject, [
"\\vdots", "\\vdots", "\\ddots", "\\vdots"
]))
lower_rows = VGroup(row_1, dots_row, row_k)
lower_rows.scale(0.75)
last_row = w_labels.target
for row in lower_rows:
for target, mover in zip(last_row, row):
mover.move_to(target)
if "w" in mover.get_tex_string():
mover.highlight(GREEN)
row.next_to(last_row, DOWN, buff = 0.45)
last_row = row
self.play(
MoveToTarget(w_labels),
Write(w_brackets, run_time = 1)
)
self.play(FadeIn(
lower_rows,
run_time = 3,
submobject_mode = "lagged_start",
))
self.dither()
self.top_matrix_row = w_labels
self.lower_matrix_rows = lower_rows
self.matrix_brackets = w_brackets
def show_meaning_of_matrix_row(self):
row = self.top_matrix_row
edges = self.network_mob.layers[1].neurons[0].edges_in.copy()
edges.set_stroke(GREEN, 5)
rect = SurroundingRectangle(row, color = GREEN_B)
self.play(ShowCreation(rect))
for x in range(2):
self.play(LaggedStart(
ShowCreationThenDestruction, edges,
lag_ratio = 0.8
))
self.dither()
self.top_row_rect = rect
def connect_weighted_sum_to_matrix_multiplication(self):
a_column = self.a_column
a_brackets = self.a_column_brackets
top_row_rect = self.top_row_rect
column_rect = SurroundingRectangle(a_column)
equals = TexMobject("=")
equals.next_to(a_brackets, RIGHT)
result_brackets = a_brackets.copy()
result_terms = VGroup()
for i in 0, 1, 4, -1:
a = a_column[i]
if i == 4:
mob = TexMobject("\\vdots")
else:
# mob = Circle(radius = 0.2, color = YELLOW)
mob = TexMobject("?").scale(1.3).highlight(YELLOW)
result_terms.add(mob.move_to(a))
VGroup(result_brackets, result_terms).next_to(equals, RIGHT)
brace = Brace(
VGroup(self.w_labels, self.a_labels_in_sum), DOWN
)
arrow = Arrow(
brace.get_bottom(),
result_terms[0].get_top(),
buff = SMALL_BUFF
)
self.play(
GrowArrow(arrow),
GrowFromCenter(brace),
)
self.play(
Write(equals),
FadeIn(result_brackets),
)
self.play(ShowCreation(column_rect))
self.play(ReplacementTransform(
VGroup(top_row_rect, column_rect).copy(),
result_terms[0]
))
self.play(LaggedStart(
FadeIn, VGroup(*result_terms[1:])
))
self.dither(2)
self.play(*map(FadeOut, [
result_terms, result_brackets, equals,
arrow, brace,
top_row_rect, column_rect
]))
def add_bias_vector(self):
bias = self.bias
bias_name = self.bias_name
a_column_brackets = self.a_column_brackets
a_column = self.a_column
plus = TexMobject("+")
b_brackets = a_column_brackets.copy()
b_column = VGroup(*map(TexMobject, [
"b_0", "b_1", "\\vdots", "b_n",
]))
b_column.scale(0.85)
b_column.arrange_submobjects(DOWN, buff = 0.35)
b_column.move_to(a_column)
b_column.highlight(BLUE)
plus.next_to(a_column_brackets, RIGHT)
VGroup(b_brackets, b_column).next_to(plus, RIGHT)
bias_rect = SurroundingRectangle(bias)
self.play(ShowCreation(bias_rect))
self.play(FadeOut(bias_rect))
self.play(
Write(plus),
Write(b_brackets),
Transform(self.bias[1].copy(), b_column[0]),
run_time = 1
)
self.play(LaggedStart(
FadeIn, VGroup(*b_column[1:])
))
self.dither()
self.bias_plus = plus
self.b_brackets = b_brackets
self.b_column = b_column
def apply_sigmoid(self):
expression_bounds = VGroup(
self.matrix_brackets[0], self.b_brackets[1]
)
sigma = self.sigma.copy()
slp, srp = self.sigma_parens.copy()
big_lp, big_rp = parens = TexMobject("()")
parens.scale(3)
parens.stretch_to_fit_height(expression_bounds.get_height())
big_lp.next_to(expression_bounds, LEFT, SMALL_BUFF)
big_rp.next_to(expression_bounds, RIGHT, SMALL_BUFF)
parens.highlight(YELLOW)
self.play(
sigma.scale, 2,
sigma.next_to, big_lp, LEFT, SMALL_BUFF,
Transform(slp, big_lp),
Transform(srp, big_rp),
)
self.dither(2)
self.big_sigma_group = VGroup(VGroup(sigma), slp, srp)
def write_clean_final_expression(self):
self.fade_weighted_sum()
expression = TexMobject(
"\\textbf{a}^{(1)}",
"=",
"\\sigma",
"\\big(",
"\\textbf{W}",
"\\textbf{a}^{(0)}",
"+",
"\\textbf{b}",
"\\big)",
)
expression.highlight_by_tex_to_color_map({
"sigma" : YELLOW,
"big" : YELLOW,
"W" : GREEN,
"\\textbf{b}" : BLUE
})
expression.next_to(self.big_sigma_group, UP, LARGE_BUFF)
a1, equals, sigma, lp, W, a0, plus, b, rp = expression
neuron_anims = []
neurons = VGroup(*self.network_mob.layers[1].neurons[1:])
for neuron in neurons:
neuron_anims += [
neuron.restore,
neuron.set_fill, None, random.random()
]
neuron_anims += [
neuron.edges_in.restore
]
self.play(ReplacementTransform(
VGroup(
self.top_matrix_row, self.lower_matrix_rows,
self.matrix_brackets
).copy(),
VGroup(W),
))
self.play(ReplacementTransform(
VGroup(self.a_column, self.a_column_brackets).copy(),
VGroup(VGroup(a0)),
))
self.play(
ReplacementTransform(
VGroup(self.b_column, self.b_brackets).copy(),
VGroup(VGroup(b))
),
ReplacementTransform(
self.bias_plus.copy(), plus
)
)
self.play(ReplacementTransform(
self.big_sigma_group.copy(),
VGroup(sigma, lp, rp)
))
self.dither()
self.play(*neuron_anims, run_time = 2)
self.play(
ReplacementTransform(neurons.copy(), a1),
FadeIn(equals)
)
self.dither(2)
def fade_weighted_sum(self):
self.play(*map(FadeOut, [
self.a1_label, self.a1_equals,
self.sigma, self.sigma_parens,
self.weighted_sum,
self.bias_name,
self.sigmoid_name,
]))
###
def get_brackets(self, mob):
lb, rb = both = TexMobject("\\big[\\big]")
both.scale_to_fit_width(mob.get_width())
both.stretch_to_fit_height(1.2*mob.get_height())
lb.next_to(mob, LEFT, SMALL_BUFF)
rb.next_to(mob, RIGHT, SMALL_BUFF)
return both
class EoLA3Wrapper(Scene):
def construct(self):
pass

View file

@ -84,11 +84,7 @@ class Line(VMobject):
self.account_for_buff()
def account_for_buff(self):
anchors = self.get_anchors()
length = sum([
np.linalg.norm(a2-a1)
for a1, a2 in zip(anchors, anchors[1:])
])
length = self.get_arc_length()
if length < 2*self.buff or self.buff == 0:
return
buff_proportion = self.buff / length
@ -117,6 +113,16 @@ class Line(VMobject):
start, end = self.get_start_and_end()
return np.linalg.norm(start - end)
def get_arc_length(self):
if self.path_arc:
anchors = self.get_anchors()
return sum([
np.linalg.norm(a2-a1)
for a1, a2 in zip(anchors, anchors[1:])
])
else:
return self.get_length()
def get_start_and_end(self):
return self.get_start(), self.get_end()