3b1b-videos/_2024/transformers/mlp.py
2024-10-13 21:05:24 -05:00

2627 lines
91 KiB
Python

import torch
from scipy.stats import norm
from _2024.transformers.helpers import *
from manim_imports_ext import *
class LastTwoChapters(InteractiveScene):
    """Recap scene leading into the MLP chapter.

    Shows two chapter thumbnails, lays out a "Transformer" as ten stacked
    blocks, separates the even-indexed blocks (labelled "Attention") from the
    odd-indexed ones (labelled "MLP"), flies a list of sports facts into the
    MLP blocks and finally asks "What are these?".
    """

    # Folder containing the thumbnail images loaded by get_thumbnails().
    # Kept as a class attribute so a subclass (or another machine) can point
    # it somewhere else without editing the method body.
    thumbnail_folder = "/Users/grant/3Blue1Brown Dropbox/3Blue1Brown/videos/2024/transformers/Thumbnails"

    def construct(self):
        """Play the full recap animation."""
        # Show last two chapters
        frame = self.frame
        self.camera.light_source.set_z(15)
        self.set_floor_plane("xz")

        thumbnails = self.get_thumbnails()
        self.play(
            LaggedStartMap(FadeIn, thumbnails, shift=UP, lag_ratio=0.5)
        )
        self.wait()

        # Show transformer schematic
        blocks = Group(self.get_block() for _ in range(10))
        blocks[1::2].stretch(2, 2).set_opacity(1)
        blocks.arrange(OUT, buff=0.5)
        blocks.set_depth(8, stretch=True)
        blocks.set_opacity(0.8)
        blocks.apply_depth_test()

        trans_title = Text("Transformer", font_size=96)
        trans_title.next_to(blocks, UP, buff=0.5)

        self.play(
            frame.animate.reorient(-32, 0, 0, (0.56, 2.48, 0.32), 12.75),
            thumbnails.animate.scale(0.5).arrange(RIGHT, buff=2.0).to_edge(UP, buff=0.25),
            LaggedStartMap(FadeIn, blocks, shift=0.25 * UP, scale=1.5, lag_ratio=0.1),
            FadeIn(trans_title, UP),
        )
        self.wait()

        # Break out transformer as sequence of blocks
        att_blocks = blocks[0::2]
        mlp_blocks = blocks[1::2]

        att_title = Text("Attention", font_size=72)
        mlp_title_full = Text("Multilayer Perceptron", font_size=72)
        mlp_title = Text("MLP", font_size=72)

        self.play(
            frame.animate.reorient(-3, -2, 0, (0.23, 2.57, 0.3), 12.75),
            trans_title.animate.shift(2 * UP),
            att_blocks.animate.shift(4 * LEFT),
            mlp_blocks.animate.shift(4 * RIGHT),
        )

        # Icons and titles are positioned only after the blocks have moved,
        # since they are placed relative to the last block of each group.
        att_icon = self.get_att_icon(att_blocks[-1])
        mlp_icon = self.get_mlp_icon(mlp_blocks[-1])
        att_title.next_to(att_blocks[-1], UP, buff=0.75)
        for title in [mlp_title, mlp_title_full]:
            title.next_to(mlp_blocks[-1], UP, buff=0.75)

        self.play(
            FadeIn(att_icon, lag_ratio=1e-3),
            FadeIn(att_title, UP),
            trans_title.animate.scale(0.75).set_opacity(0.5)
        )
        self.wait()
        self.play(
            Write(mlp_icon),
            FadeIn(mlp_title_full, UP),
        )
        self.wait()
        self.play(
            TransformMatchingStrings(mlp_title_full, mlp_title)
        )
        self.wait()

        # Show sports facts
        # Blank lines (e.g. the empty string that split("\n") produces after a
        # trailing newline) are skipped so no empty Text mobject is created.
        fact_lines = Path(DATA_DIR, "athlete_sports.txt").read_text().split("\n")
        sport_facts = VGroup(
            Text(line)
            for line in fact_lines
            if line.strip()
        )
        for fact in sport_facts:
            fact.next_to(trans_title, UP)
            fact.shift(random.uniform(-3, 3) * RIGHT)
            fact.shift(random.uniform(0, 3) * UP)

        self.remove(mlp_icon, mlp_title)
        self.play(
            FadeOut(thumbnails),
            FadeOut(trans_title),
            LaggedStart(
                (Succession(FadeIn(fact), fact.animate.scale(0.5).set_opacity(0).move_to(mlp_blocks))
                for fact in sport_facts),
                lag_ratio=0.15,
            )
        )
        self.wait()

        # Ask what is the MLP
        rect = SurroundingRectangle(Group(mlp_blocks, mlp_title), buff=1.0)
        rect.stretch(0.8, 1)
        rect.match_z(mlp_blocks[-1])
        question = Text("What are these?", font_size=90)
        question.next_to(rect, UP, buff=3.0)
        question.match_color(rect)
        question.set_fill(border_width=0.5)
        arrow = Arrow(question, rect)
        arrow.match_color(rect)

        self.play(
            Group(att_blocks, att_title).animate.fade(0.5),
            ShowCreation(rect),
            Write(question),
            GrowArrow(arrow),
        )
        self.wait()

    def get_thumbnails(self):
        """Return a Group of two framed thumbnail images, fixed in frame.

        Each element is a Group of (white outline rectangle, image). The
        images are loaded from ``self.thumbnail_folder``.
        """
        images = [
            ImageMobject(str(Path(self.thumbnail_folder, name)))
            for name in ("Chapter5_TN5", "Chapter6_TN4")
        ]
        thumbnails = Group(
            Group(
                SurroundingRectangle(image, buff=0).set_stroke(WHITE, 3),
                image
            )
            for image in images
        )
        thumbnails.set_height(3.5)
        thumbnails.arrange(RIGHT, buff=1.0)
        thumbnails.fix_in_frame()
        return thumbnails

    def get_att_icon(self, block, n_rows=8):
        """Return an ``n_rows`` x ``n_rows`` grid of dots placed on ``block``.

        The grid is scaled to 90% of the block's height and each dot gets a
        random fill opacity (``random.random()**5``, so most are faint).
        """
        att_icon = Dot().get_grid(n_rows, n_rows)
        att_icon.set_height(block.get_height() * 0.9)
        att_icon.set_backstroke(BLACK, 0.5)
        for dot in att_icon:
            dot.set_fill(opacity=random.random()**5)
        att_icon.move_to(block, OUT)
        return att_icon

    def get_mlp_icon(self, block, dot_buff=0.15, layer_buff=1.5, layer0_size=5):
        """Return a small three-layer network icon placed on ``block``.

        The result is ``VGroup(layers, lines)``: three columns of dots with
        ``layer0_size``, ``2 * layer0_size`` and ``layer0_size`` dots, plus
        one line for every pair of dots in adjacent columns. Dot opacities,
        line colors and line widths are randomized.
        """
        layers = VGroup(
            Dot().get_grid(layer0_size, 1, buff=dot_buff),
            Dot().get_grid(2 * layer0_size, 1, buff=dot_buff),
            Dot().get_grid(layer0_size, 1, buff=dot_buff),
        )
        layers.set_height(block.get_height() * 0.9)
        layers.arrange(RIGHT, buff=layer_buff)
        for layer in layers:
            for dot in layer:
                dot.set_fill(opacity=random.random())
        layers.set_stroke(WHITE, 0.5)

        lines = VGroup(
            Line(dot1.get_center(), dot2.get_center(), buff=dot1.get_width() / 2)
            for l1, l2 in zip(layers, layers[1:])
            for dot1 in l1
            for dot2 in l2
        )
        for line in lines:
            line.set_stroke(
                color=value_to_color(random.uniform(-10, 10)),
                width=3 * random.random()**3
            )

        icon = VGroup(layers, lines)
        icon.move_to(block, OUT)
        return icon

    def get_block(self, width=5, height=3, depth=1, color=GREY_D, opacity=0.8):
        """Return a Cube reshaped to ``width`` x ``height`` x ``depth``.

        Depth testing is deactivated on the individual block, and its
        submobjects are sorted by the dot product of a point with [-1, 1, 1].
        """
        block = Cube(color=color, opacity=opacity)
        block.deactivate_depth_test()
        block.set_shape(width, height, depth)
        block.set_shading(0.5, 0.5, 0.0)
        block.sort(lambda p: np.dot(p, [-1, 1, 1]))
        return block
class AltLastTwoChapters(LastTwoChapters):
    """Alternate recap: captioned chapter thumbnails that are then rearranged."""

    def construct(self):
        """Fade in the two labelled thumbnails, then move them to new targets."""
        # Show last two chapters
        thumbnails = self.get_thumbnails()
        thumbnails.set_height(2.0)
        thumbnails.arrange(RIGHT, buff=2.0)
        thumbnails.to_edge(UP)

        for chapter_number, thumbnail in zip([5, 6], thumbnails):
            caption = Text(f"Chapter {chapter_number}")
            caption.next_to(thumbnail, DOWN, SMALL_BUFF)
            # The caption becomes the last submobject of its thumbnail.
            thumbnail.add(caption)

        fade_in_all = LaggedStartMap(FadeIn, thumbnails, shift=UP, lag_ratio=0.5)
        self.play(fade_in_all)
        self.wait()

        # Focus on chapter 6
        for thumbnail in thumbnails:
            target = thumbnail.generate_target()
            thumbnail.target = target
            target.scale(1.25)
            # Index 0 is the outline rectangle, index -1 the caption added above.
            outline, caption = target[0], target[-1]
            caption.scale(1.0 / 1.5).next_to(outline, DOWN, SMALL_BUFF)

        first_target = thumbnails[0].target
        second_target = thumbnails[1].target
        second_target.set_x(-2.85)
        second_target.to_edge(UP, MED_SMALL_BUFF)
        first_target.next_to(second_target, LEFT, buff=2.5)

        self.play(
            LaggedStartMap(MoveToTarget, thumbnails)
        )
        self.wait()
class MLPIcon(LastTwoChapters):
    """Standalone animation of the MLP network icon with a flash along its lines."""

    def construct(self):
        """Write the network icon, then sweep a passing flash over its lines."""
        # Add network
        network = self.get_mlp_icon(Square(6), layer_buff=3.0, layer0_size=6)
        draw_network = Write(network, stroke_width=0.5, lag_ratio=1e-2, run_time=5)
        self.play(draw_network)
        self.wait()

        # Propagate through
        # get_mlp_icon returns VGroup(layers, lines), so index 1 is the lines.
        connections = network[1]
        flash_lines = VGroup(connections.family_members_with_points()).copy()
        for flash_line in flash_lines:
            flash_line.set_stroke(width=2 * flash_line.get_width())
            flash_line.insert_n_curves(20)

        propagate = LaggedStartMap(
            VShowPassingFlash,
            flash_lines,
            time_width=1.5,
            lag_ratio=5e-3,
            run_time=3,
        )
        self.play(propagate)
        self.wait()
class MLPStepsPreview(InteractiveScene):
    """Preview contrasting the MLP's simple structure with its emergent behavior.

    A top panel shows the Linear -> ReLU -> Linear pipeline, a lower panel is
    introduced for "Emergent behavior", which is then struck out in favor of a
    "Motivating Toy Example" built around one fact sentence.
    """

    def construct(self):
        """Play the whole preview animation."""
        # Setup framing
        background = FullScreenRectangle()
        frames = Rectangle(7, 3.25).replicate(2)
        top_frame, low_frame = frames
        frames.arrange(DOWN, buff=0.5)
        frames.to_edge(LEFT)
        frames.set_fill(BLACK, 1)
        frames.set_stroke(WHITE, 2)

        structure_title = VGroup(Text("Structure:"), Text("Easy"))
        behavior_title = VGroup(Text("Emergent behavior:"), Text("Exceedingly challenging"))
        titles = VGroup(structure_title, behavior_title)
        for title, panel, verdict_color in zip(titles, frames, [GREEN, RED]):
            title.scale(2)
            for part in title:
                part.set_max_width(6)
            title.arrange(DOWN, buff=0.5, aligned_edge=LEFT)
            title.next_to(panel, RIGHT, buff=0.5)
            title[1].set_color(verdict_color)

        # Remember the two-panel layout, then start with an enlarged top panel.
        structure_title.save_state()
        top_frame.save_state()
        top_frame.set_shape(8, 6).center().to_edge(LEFT)
        structure_title.next_to(top_frame, RIGHT, buff=0.5)

        self.add(background, top_frame, structure_title[0])

        # Add all steps
        arrows = Vector(2.2 * RIGHT).get_grid(1, 3, buff=0.25)
        arrows.move_to(top_frame)
        first_arrow, middle_arrow, last_arrow = arrows

        up_proj = WeightMatrix(shape=(10, 6))
        down_proj = WeightMatrix(shape=(6, 10))
        VGroup(up_proj, down_proj).match_width(first_arrow)
        up_proj.next_to(first_arrow, UP, buff=MED_SMALL_BUFF)
        down_proj.next_to(last_arrow, UP, buff=MED_SMALL_BUFF)

        axes = Axes((-4, 4), (0, 4))
        graph = axes.get_graph(lambda x: max(0, x))
        graph.set_stroke(YELLOW, 5)
        plot = VGroup(axes, graph)
        plot.set_width(first_arrow.get_width() * 0.75)
        plot.next_to(middle_arrow, UP, buff=MED_SMALL_BUFF)

        labels = VGroup(*(Text(step) for step in ["Linear", "ReLU", "Linear"]))
        for label, arrow in zip(labels, arrows):
            label.next_to(arrow, DOWN)

        structure = VGroup(arrows, labels, VGroup(up_proj, plot, down_proj))

        self.play(
            LaggedStartMap(GrowArrow, arrows, lag_ratio=0.5),
            LaggedStartMap(FadeIn, labels, shift=0.5 * RIGHT, lag_ratio=0.5),
            Write(structure_title[1])
        )
        self.play(LaggedStart(
            *(FadeIn(matrix, shift=0.5 * UP) for matrix in (up_proj, down_proj)),
            lag_ratio=0.5
        ))
        self.play(FadeIn(plot, lag_ratio=1e-2))
        self.wait(3)

        # Reference emergent structure
        small_top_frame = top_frame.saved_state
        self.play(
            Restore(top_frame),
            Restore(structure_title),
            structure.animate.set_width(0.9 * small_top_frame.get_width()).move_to(small_top_frame),
            FadeIn(low_frame, DOWN),
            FadeIn(behavior_title[0], DOWN),
        )
        self.play(
            Write(behavior_title[1], stroke_color=RED)
        )

        # Data flying
        kw = dict(font_size=16, shift_vect=0.5 * DOWN + 0.5 * RIGHT, word_shape=(5, 5))
        for matrix in (up_proj, down_proj):
            data_modifying_matrix(self, matrix, **kw)
        self.wait()

        # Swap out for toy example
        toy_example_title = Text("Motivating Toy Example", font_size=54)
        toy_example_title.next_to(behavior_title[0], DOWN, MED_LARGE_BUFF, aligned_edge=LEFT)
        strike = Line().replace(behavior_title[0])
        strike.set_stroke(RED, 8)

        low_matrices = VGroup(up_proj, down_proj)
        # Copies stay behind in the top panel while the originals move down.
        top_matrices = low_matrices.copy()
        low_matrices.generate_target()
        moved = low_matrices.target
        moved.scale(1.75).arrange(RIGHT, buff=0.5)
        moved.move_to(low_frame, DOWN).shift(MED_SMALL_BUFF * UP)

        self.play(
            ShowCreation(strike),
            FadeOut(behavior_title[1]),
            behavior_title[0].animate.set_opacity(0.5)
        )
        self.add(top_matrices)
        self.play(
            MoveToTarget(low_matrices),
            FadeIn(toy_example_title, DOWN)
        )
        self.wait()

        # Write down fact
        left_matrix, right_matrix = low_matrices[0], low_matrices[1]
        row_rect = SurroundingRectangle(left_matrix.get_rows()[0], buff=0.1)
        col_rect = SurroundingRectangle(right_matrix.get_columns()[0], buff=0.1)
        VGroup(row_rect, col_rect).set_stroke(WHITE, 1)

        fact = Text("Michael Jordan plays Basketball", font_size=36)
        fact.next_to(low_frame.get_top(), DOWN)
        fact.align_to(low_matrices, LEFT)
        mj = fact["Michael Jordan"]
        bb = fact["plays Basketball"]
        mj_brace = Brace(mj, DOWN, buff=0.1)
        bb_brace = Brace(bb, DOWN).match_y(mj_brace)
        # NOTE(review): mj_arrow is constructed but never animated or added in
        # this scene, unlike bb_arrow below.
        mj_arrow = Arrow(row_rect, mj_brace, buff=0.05)
        bb_arrow = Arrow(col_rect.get_top(), bb_brace, buff=0.05)

        # Dim everything except the first row / first column.
        row_cover = BackgroundRectangle(left_matrix.get_rows()[1:], buff=0.05)
        col_cover = BackgroundRectangle(right_matrix.get_columns()[1:], buff=0.05)
        VGroup(row_cover, col_cover).set_fill(BLACK, 0.75)

        self.play(LaggedStart(
            FadeIn(row_cover),
            FadeIn(row_rect),
            GrowFromCenter(mj_brace),
            FadeIn(mj, 0.5 * UP)
        ))
        self.play(
            FadeIn(col_cover),
            FadeIn(col_rect),
            GrowArrow(bb_arrow),
            GrowFromCenter(bb_brace),
            FadeIn(bb, 0.5 * UP)
        )

        # Re-add so the covers and outlines are layered after the matrices.
        self.add(*low_matrices, row_cover, col_cover, row_rect, col_rect)
        self.play(
            *(RandomizeMatrixEntries(matrix) for matrix in low_matrices)
        )
        self.wait()
class MatricesVsIntuition(InteractiveScene):
    """Static scene showing a 15x15 weight matrix at the left edge."""

    def construct(self):
        """Build the matrix and add it to the scene."""
        # Add matrix
        matrix = WeightMatrix(shape=(15, 15))
        matrix.set_height(4)
        matrix.to_edge(LEFT)

        # NOTE(review): these two captions are constructed but never
        # positioned or added to the scene.
        caption_strings = (
            "Matrices filled with parameters\nlearned during gradient descent",
            "Motivating examples which risk being\noversimplifications of what true models do",
        )
        for caption_string in caption_strings:
            Text(caption_string)

        self.add(matrix)
class BasicMLPWalkThrough(InteractiveScene):
random_seed = 1
def construct(self):
# Init camera settings
self.set_floor_plane("xz")
frame = self.frame
self.camera.light_source.set_z(15)
# Sequence of embeddings comes in to an MLP block
embedding_array = EmbeddingArray(shape=(6, 9))
embedding_array.set_width(10)
block = VCube(fill_color=GREY_D, fill_opacity=0.5)
block.sort(lambda p: p[2])
block[-1].set_fill(opacity=0)
block.set_stroke(GREY_B, 2, 0.25, behind=False)
block.set_shading(0.25, 0.25, 0.5)
block.set_shape(11, 4, 4)
block.move_to(0.5 * IN, IN)
block_title = Text("MLP", font_size=90)
block_title.next_to(block, UP)
frame.reorient(-21, -12, 0, (0.34, -0.94, -0.18), 9.79)
frame.set_field_of_view(30 * DEGREES)
self.add(block, block_title)
self.play(FadeIn(embedding_array, shift=2 * OUT))
self.wait()
# Highlight one vector
index = 3
emb = embedding_array.embeddings[index]
highlight_rect = SurroundingRectangle(emb)
embedding_array.target = embedding_array.generate_target()
embedding_array.target.set_stroke(width=0)
embedding_array.target.set_opacity(0.5)
embedding_array.target[0][index].set_backstroke(BLACK, 2)
embedding_array.target[0][index].set_opacity(1)
self.play(
MoveToTarget(embedding_array),
ShowCreation(highlight_rect),
)
self.wait()
# Reorient
rot_about_up = 89 * DEGREES
rot_about_left = 1 * DEGREES
up_emb = emb.copy() # For use down below
full_block = Group(block, embedding_array, highlight_rect, block_title)
full_block.target = full_block.generate_target()
full_block.target[0].set_depth(16, about_edge=IN, stretch=True)
full_block.target[0].set_height(5, about_edge=DOWN, stretch=True)
full_block.target.rotate(rot_about_up, UP)
full_block.target[:3].rotate(rot_about_left, LEFT)
full_block.target.scale(0.5)
full_block.target[3].rotate(90 * DEGREES, DOWN).next_to(full_block.target[0], UP, buff=0.5)
full_block.target.center().to_edge(DOWN, buff=0.75)
full_block.target[0][4].set_opacity(0.1)
self.play(
frame.animate.reorient(-3, -2, 0, (-0.0, -2.0, 0.01), 6.48),
MoveToTarget(full_block),
run_time=2
)
# Preview the sequence of operations
values = np.random.uniform(-10, 10, 9)
values[0] = 1.0
vects = VGroup(
NumericEmbedding(values=values, dark_color=GREY_B),
NumericEmbedding(values=np.clip(values, 0, np.inf), dark_color=GREY_B),
NumericEmbedding(length=6),
)
vects.set_width(emb.get_depth())
vects.arrange(RIGHT, buff=2.0)
vects.next_to(emb, RIGHT, buff=2.0)
arrows = VGroup(
Arrow(v1, v2)
for v1, v2 in zip([emb, *vects[:-1]], vects)
)
arrow_labels = VGroup(Text("Linear"), Text("ReLU"), Text("Linear"))
arrow_labels.scale(0.5)
phases = VGroup()
simple_phases = VGroup()
for arrow, label, vect in zip(arrows, arrow_labels, vects):
label.next_to(arrow, UP)
phases.add(VGroup(arrow, label, vect))
simple_phases.add(VGroup(arrow, vect))
self.play(
LaggedStartMap(FadeIn, vects, shift=RIGHT, lag_ratio=0.8),
LaggedStartMap(ShowCreation, arrows, lag_ratio=0.8),
LaggedStartMap(FadeIn, arrow_labels, lag_ratio=0.8),
)
self.wait()
# Show the sum
sum_circuit, output_emb = self.get_sum_circuit(emb, vects[-1])
self.play(
frame.animate.reorient(15, -4, 0, (0.82, -1.91, 0.04), 7.18),
ShowCreation(sum_circuit, lag_ratio=0.1),
run_time=2
)
self.play(
TransformFromCopy(emb, output_emb, path_arc=-30 * DEGREES),
TransformFromCopy(vects[2], output_emb, path_arc=-30 * DEGREES),
run_time=2
)
self.wait()
# Show all in parallel
simple_phases.add_to_back(highlight_rect)
simple_phases.add(VGroup(sum_circuit, output_emb))
simple_phase_copies = VGroup(
simple_phases.copy().match_z(emb)
for emb in embedding_array.embeddings
)
for sp_copy in simple_phase_copies:
for group in sp_copy[1:]:
arrow, vect = group
for entry in vect.get_entries():
dot = Dot().scale(0.5)
dot.match_color(entry)
dot.set_fill(opacity=0.5)
dot.move_to(entry)
entry.become(dot)
group.fade(0.5)
self.play(
frame.animate.reorient(0, -48, 0, (0.55, -2.21, 0.18), 7.05),
LaggedStart((
TransformFromCopy(simple_phases, sp_copy)
for sp_copy in simple_phase_copies
), lag_ratio=0.1),
FadeOut(block_title, time_span=(0, 1)),
run_time=3,
)
self.play(frame.animate.reorient(9, -15, 0, (0.55, -2.21, 0.18), 7.05), run_time=4)
self.play(frame.animate.reorient(-24, -16, 0, (0.18, -2.13, 0.09), 7.63), run_time=12)
block_title.next_to(block, UP)
self.play(
frame.animate.to_default_state(),
LaggedStartMap(FadeOut, simple_phase_copies, lag_ratio=0.1),
FadeIn(block_title),
run_time=2,
)
self.wait()
# Show MJ -> Basketball example
example_fact = TexText("``Michael Jordan plays Basketball''", font_size=60)
example_fact.to_edge(UP)
mj = TexText("Michael Jordan", font_size=36)
mj.next_to(emb, UL)
mj_lines = VGroup(
Line(char.get_bottom(), emb.get_top(), buff=0.1, path_arc=10 * DEGREES)
for char in mj
)
mj_lines.set_stroke(YELLOW, 1, 0.5)
basketball = TexText("Basketball", font_size=24)
basketball.next_to(vects[2], UP, buff=0.2)
self.play(Write(example_fact))
self.wait()
self.play(FadeTransform(example_fact[mj.get_tex()].copy(), mj))
self.play(Write(mj_lines, stroke_width=2, stroke_color=YELLOW_B, lag_ratio=1e-2))
self.wait()
mover = emb.copy()
for vect in vects:
self.play(Transform(mover, vect, rate_func=linear))
self.remove(mover)
self.wait()
self.play(FadeTransform(example_fact[basketball.get_tex()].copy(), basketball))
self.wait(2)
# Multiply by the up-projection
up_proj = WeightMatrix(shape=(9, 6))
up_proj.set_height(3)
up_proj.to_corner(UL)
up_emb.set_height(2)
up_emb.next_to(up_proj, RIGHT)
up_emb[-2:].set_fill(YELLOW) # Brackets
self.play(
phases[1:].animate.set_opacity(0.1),
sum_circuit.animate.set_stroke(opacity=0.1),
output_emb.animate.set_opacity(0.1),
FadeOut(mj),
FadeOut(mj_lines),
FadeOut(basketball),
FadeOut(example_fact),
)
self.wait()
self.play(TransformFromCopy(emb, up_emb))
self.play(FadeIn(up_proj, lag_ratio=0.01))
eq, rhs = show_matrix_vector_product(self, up_proj, up_emb)
self.wait()
data_modifying_matrix(self, up_proj, word_shape=(4, 7), fix_in_frame=True)
self.wait()
# Show machine
machine = MachineWithDials(
width=up_proj.get_width() + SMALL_BUFF,
height=up_proj.get_height() + SMALL_BUFF,
n_rows=8,
n_cols=9,
)
machine.move_to(up_proj)
self.play(FadeIn(machine))
self.play(machine.random_change_animation())
self.wait()
self.play(FadeOut(machine))
# Emphasize dot product with rows
n, m = up_proj.shape
n_rows_shown = 5
R_labels = VGroup(
Tex(R"\vec{\textbf{R}}_" + f"{{{n}}}")
for n in [*range(n_rows_shown - 1), "n"]
)
R_labels[-2].become(Tex(R"\vdots").replace(R_labels[-2], dim_to_match=1))
R_labels.arrange(DOWN, buff=0.5)
R_labels.match_height(up_proj)
R_labels.move_to(up_proj)
h_lines = VGroup(
Line(up_proj.get_brackets()[0], R_labels, buff=0.1),
Line(R_labels, up_proj.get_brackets()[1], buff=0.1),
)
h_lines.set_stroke(GREY_A, 2)
row_labels = VGroup(
VGroup(R_label, h_lines.copy().match_y(R_label))
for R_label in R_labels
)
row_matrix = VGroup(
up_proj.get_brackets().copy(),
row_labels
)
E_label = Tex(R"\vec{\textbf{E}}")
E_label.match_height(R_labels[0])
E_label.set_color(YELLOW)
E_label.move_to(up_emb)
E_col = VGroup(
up_emb[-2:].copy(),
Line(up_emb.get_top(), E_label, buff=0.1).set_stroke(GREY_A, 2),
E_label,
Line(E_label, up_emb.get_bottom(), buff=0.1).set_stroke(GREY_A, 2),
)
dot_prods = VGroup()
for n, R_label in enumerate(R_labels):
if n == len(R_labels) - 2:
dot_prod = R_label.copy()
else:
dot_prod = VGroup(
R_label.copy(),
Tex(R"\cdot"),
E_label.copy(),
)
dot_prod.arrange(RIGHT, buff=0.1)
dot_prod[-1].align_to(dot_prod[0][1], DOWN)
dot_prod.set_width(rhs.get_width() * 0.75)
dot_prod.move_to(R_label)
dot_prods.add(dot_prod)
dot_prods.move_to(rhs)
dot_prod_rhs = VGroup(
rhs.get_brackets().copy(),
dot_prods,
)
self.play(LaggedStart(
FadeOut(up_proj, scale=1.1),
FadeIn(row_matrix, scale=1.1),
FadeOut(up_emb, scale=1.1),
FadeIn(E_col, scale=1.1),
FadeOut(rhs, scale=1.1),
FadeIn(dot_prod_rhs[0], scale=1.1),
lag_ratio=0.1
))
self.wait()
for row_label, dot_prod in zip(row_labels, dot_prods):
R_label = row_label[0]
self.play(
TransformFromCopy(R_label, dot_prod[0]),
TransformFromCopy(R_label, dot_prod[1]),
TransformFromCopy(E_label, dot_prod[2]),
VShowPassingFlash(
Line(row_label.get_left(), row_label.get_right()).set_stroke(YELLOW, 5).insert_n_curves(100),
time_width=1.5
),
VShowPassingFlash(
Line(E_col.get_top(), E_col.get_bottom()).set_stroke(YELLOW, 5).insert_n_curves(100),
time_width=1.5
),
run_time=1
)
self.wait()
# First name Michael direction
row_rect = SurroundingRectangle(row_labels[0])
row_rect.set_stroke(GREY_BROWN, 2)
row_rect.set_fill(GREY_BROWN, 0.25)
row_eq = Tex("=").rotate(PI / 2)
row_eq.next_to(row_rect, UP, SMALL_BUFF)
first_name_label = Tex(R"\overrightarrow{\text{First Name Michael}}")
first_name_label.set_stroke(WHITE, 1)
first_name_label.match_width(row_rect)
first_name_label.next_to(row_eq, UP)
dot_prod = dot_prods[0]
dp_rect = SurroundingRectangle(dot_prod, buff=0.2)
dp_rect.set_stroke(RED)
dp_eq = Tex("=")
dp_eq.next_to(dp_rect, RIGHT, SMALL_BUFF)
mde_rhs = VGroup(
Tex(R"\approx 1 \quad \text{If } \vec{\textbf{E}} \text{ encodes ``First Name Michael''}"),
Tex(R"\le 0 \quad \text{If not}")
)
mde_rhs[0][R"\vec{\textbf{E}}"].set_color(YELLOW)
mde_rhs.scale(0.75)
mde_rhs.arrange(DOWN, buff=0.5, aligned_edge=LEFT)
rhs_brace = Brace(mde_rhs, LEFT)
rhs_brace.next_to(dp_eq, RIGHT, SMALL_BUFF)
mde_rhs.next_to(rhs_brace, RIGHT, MED_SMALL_BUFF)
self.play(
FadeIn(row_rect, scale=2),
FadeTransform(row_labels[0].copy(), first_name_label),
GrowFromCenter(row_eq),
frame.animate.reorient(0, 0, 0, (0.22, 0.54, 0.0), 9.27),
)
self.wait()
self.play(TransformFromCopy(row_rect.copy().set_fill(opacity=0), dp_rect))
self.play(
Write(dp_eq),
GrowFromCenter(rhs_brace),
FadeIn(mde_rhs),
)
self.wait()
# "First name Michael" + "Last name Jordan"
fn_tex = R"\overrightarrow{\text{F.N. Michael}}"
ln_tex = R"\overrightarrow{\text{L.N. Jordan}}"
name_sum_label = Tex(f"{fn_tex} + {ln_tex}")
name_sum_label.match_width(row_rect).scale(1.2)
name_sum_label.next_to(row_eq, UP)
self.play(
FadeTransform(first_name_label, name_sum_label[:21]),
FadeIn(name_sum_label[21:], shift=RIGHT, scale=2),
FadeOut(mde_rhs),
FadeOut(rhs_brace),
)
self.wait()
dist_rhs = VGroup(
Tex(R"(\vec{\textbf{M}} + \vec{\textbf{J}}) \cdot \vec{\textbf{E}}"),
Tex("="),
Tex(R"\vec{\textbf{M}} \cdot \vec{\textbf{E}} + \vec{\textbf{J}} \cdot \vec{\textbf{E}}"),
)
dist_rhs.scale(0.75)
dist_rhs.arrange(RIGHT, buff=0.2)
dist_rhs.next_to(dp_eq, RIGHT)
for part in dist_rhs:
part[R"\vec{\textbf{M}}"].set_color(RED_B)
part[R"\vec{\textbf{J}}"].set_color(RED)
part[R"\vec{\textbf{E}}"].set_color(YELLOW)
under_brace = Brace(dist_rhs[2])
two_condition = TexText(R"$\approx 2$ \; if $\vec{\textbf{E}}$ encodes ``Michael Jordan''")
two_condition[R"\vec{\textbf{E}}"].set_color(YELLOW)
else_condition = TexText(R"$\le 1$ \; Otherwise")
VGroup(two_condition, else_condition).scale(0.75)
two_condition.next_to(under_brace, DOWN, aligned_edge=LEFT)
else_condition.next_to(two_condition, DOWN, MED_LARGE_BUFF, aligned_edge=LEFT)
self.play(LaggedStart(
FadeTransformPieces(name_sum_label[:21].copy(), dist_rhs[0][1:3]),
FadeTransformPieces(name_sum_label[21].copy(), dist_rhs[0][3]),
FadeTransformPieces(name_sum_label[22:].copy(), dist_rhs[0][4:6]),
FadeTransformPieces(dot_prod[1:].copy(), dist_rhs[0][7:]),
FadeIn(dist_rhs[0][0]),
FadeIn(dist_rhs[0][6]),
lag_ratio=0.2
))
self.wait()
self.play(
TransformMatchingStrings(dist_rhs[0].copy(), dist_rhs[2], lag_ratio=0.01, path_arc=-45 * DEGREES),
Write(dist_rhs[1])
)
self.wait()
self.play(
frame.animate.set_y(0.5),
GrowFromCenter(under_brace),
FadeIn(two_condition, DOWN)
)
self.wait()
self.play(FadeIn(else_condition, DOWN))
self.wait(2)
# Go back to the numbers
for entry in rhs.get_entries():
entry.set_value(np.random.uniform(-10, 10))
rhs.get_entries()[0].set_value(2.0)
self.play(
LaggedStart(*map(FadeOut, [
name_sum_label, row_eq, row_rect,
dp_rect, dp_eq, dist_rhs, under_brace,
two_condition, else_condition,
]), lag_ratio=0.1, run_time=1),
frame.animate.reorient(0, 0, 0, (-0.06, -0.06, 0.0), 8.27),
)
self.play(
FadeOut(row_matrix),
FadeIn(up_proj),
FadeOut(E_col),
FadeIn(up_emb),
FadeOut(dot_prod_rhs),
FadeIn(rhs),
)
# Show other rows
questions = VGroup(*map(Text, [
"Blah",
"Is it English?",
"Part of source code?",
"European country?",
"In quotation marks?",
"Something metallic?",
"A four-legged animal?",
]))
questions.scale(0.75)
rows = up_proj.get_rows()
rhs_entries = rhs.get_entries()
last_question = VGroup()
last_rect = VectorizedPoint(rows[1].get_top())
for index in range(1, 7):
for mob in [rows, rhs_entries]:
mob.target = mob.generate_target()
mob.target.set_opacity(0.25)
mob.target[index].set_opacity(1)
row_rect = SurroundingRectangle(rows[index])
row_rect.set_stroke(PINK, 2)
question = questions[index]
question.next_to(rows[index], UP, buff=0.15)
question.set_backstroke(BLACK, 3)
self.play(
MoveToTarget(rows),
MoveToTarget(rhs_entries),
FadeOut(last_question),
FadeIn(question),
FadeTransform(last_rect, row_rect, time_span=(0, 0.75)),
run_time=1.0
)
self.wait(0.5)
last_question = question
last_rect = row_rect
self.play(
rows.animate.set_opacity(1),
rhs.animate.set_opacity(1),
FadeOut(last_question),
FadeOut(last_rect),
)
self.wait()
# Add a bias
plus = Tex("+")
plus.next_to(up_emb, RIGHT)
bias = WeightMatrix(shape=(9, 1), ellipses_col=None)
bias.get_entries()[0].set_value(-1).set_color(RED)
bias.match_height(up_proj)
bias.next_to(plus)
bias_name = Text("Bias")
bias_name.next_to(bias, UP)
eq.target = eq.generate_target()
eq.target.next_to(bias, RIGHT)
rhs.target = vects[0].copy()
rhs.target.replace(rhs, dim_to_match=1)
rhs.target.next_to(eq.target, RIGHT)
self.play(
Write(plus),
FadeIn(bias, lag_ratio=0.1),
MoveToTarget(eq),
MoveToTarget(rhs),
)
self.wait()
self.play(
frame.animate.scale(1.1, about_edge=DOWN),
Write(bias_name),
)
self.wait()
# Emphasize the parameters are learned from data
data_modifying_matrix(self, bias, word_shape=(5, 1), alpha_maxes=(0.4, 0.9), fix_in_frame=True)
bias.get_entries()[0].set_value(-1).set_color(RED)
# Pull up the MJ example again
fe_rect = SurroundingRectangle(rhs.get_entries()[0], buff=0.1) # fe = First entry
fe_rect.set_stroke(RED, 3)
fe_eq = Tex("=")
fe_eq.next_to(fe_rect, RIGHT, SMALL_BUFF)
fe_expr = VGroup(dist_rhs[2].copy(), Tex("- 1"))
fe_expr[1].set_height(fe_expr[0].get_height() * 0.8)
fe_expr.arrange(RIGHT)
fe_expr.next_to(fe_eq, RIGHT)
bias_rect = SurroundingRectangle(bias.get_entries()[0])
self.play(
ShowCreation(fe_rect),
FadeIn(fe_eq, RIGHT),
Write(fe_expr)
)
self.wait()
self.play(ShowCreation(bias_rect))
self.wait()
self.play(bias_rect.animate.surround(fe_expr[1]))
self.wait()
self.play(bias_rect.animate.surround(fe_expr))
self.wait()
# Show what it means, but now shifted
conditions = VGroup(
TexText(R"$\approx 1$ \; if $\vec{\textbf{E}}$ encodes ``Michael Jordan''"),
TexText(R"$\le 0$ \; Otherwise"),
)
conditions[0][R"\vec{\textbf{E}}"].set_color(YELLOW)
conditions.scale(0.75)
conditions.arrange(DOWN, buff=0.5, aligned_edge=LEFT)
under_brace = Brace(fe_expr, DOWN)
conditions.next_to(under_brace, DOWN, aligned_edge=LEFT)
self.play(
FadeOut(bias_rect),
GrowFromCenter(under_brace),
FadeIn(conditions[0], DOWN)
)
self.wait()
self.play(FadeIn(conditions[1], 0.25 * DOWN))
self.wait(2)
self.play(
frame.animate.reorient(0, 0, 0, (-2.5, 0.44, 0.0), 9.33),
LaggedStart(*map(FadeOut, [
fe_rect, fe_eq, fe_expr,
under_brace, *conditions
]))
)
# Show the matrix size
up_proj.refresh_bounding_box()
row_rects = VGroup(
SurroundingRectangle(row, buff=0.1)
for row in up_proj.get_rows()
)
row_rects.set_stroke(WHITE, 1)
row_rects.set_fill(GREY_C, 0.25)
row_rects[-2].match_width(row_rects, stretch=True)
over_brace = Brace(row_rects[0], UP, buff=SMALL_BUFF)
d_model = 12288
row_size = Integer(d_model)
row_size.next_to(over_brace, UP)
side_brace = Brace(row_rects, LEFT)
num_rows = Integer(4 * d_model)
num_rows.next_to(side_brace, LEFT)
num_rows_expr = Tex(R"4 \times 12{,}288")
num_rows_expr.next_to(side_brace, LEFT)
self.play(
FadeIn(row_rects, lag_ratio=0.5),
GrowFromCenter(side_brace),
CountInFrom(num_rows)
)
self.wait()
self.play(FadeTransform(num_rows, num_rows_expr))
self.wait()
self.play(
FadeTransform(num_rows_expr["12{,}288"].copy(), row_size),
TransformFromCopy(side_brace, over_brace),
)
self.wait()
self.play(FadeOut(row_rects, lag_ratio=0.1))
# Calculate matrix size
full_product = VGroup(
num_rows_expr.copy(),
Tex(R"\times"),
row_size.copy(),
Tex(Rf"="),
Integer(4 * d_model * d_model)
)
full_product.scale(1.5)
full_product.arrange(RIGHT, buff=MED_SMALL_BUFF)
full_product.next_to(row_rects, UP, buff=2.5)
self.play(LaggedStart(
frame.animate.reorient(0, 0, 0, (-3.88, 1.51, 0.0), 11.35),
TransformFromCopy(num_rows_expr, full_product[0]),
FadeIn(full_product[1], UP),
TransformFromCopy(row_size, full_product[2]),
lag_ratio=0.25,
run_time=2
))
self.play(
TransformFromCopy(full_product[:3], full_product[3:])
)
self.wait()
self.play(FlashAround(full_product[-1], run_time=2, time_width=1.5))
# Count bias parameters
bias_count = Tex(R"4 \times 12{,}288")
bias_count.match_height(full_product)
bias_count.match_y(full_product)
bias_count.match_x(bias)
bias_rect = SurroundingRectangle(VGroup(bias, bias_name))
bias_rect.set_stroke(BLUE_B)
bias_arrow = Arrow(bias_rect.get_top(), bias_count.get_bottom())
bias_arrow.match_color(bias_rect)
bias_count.match_color(bias_rect)
div_eq = Tex(R"{4 \times 12{,}288 \over 603{,}979{,}776} \approx 0.00008 ")
div_eq[R"{4 \times 12{,}288"].match_color(bias_rect)
div_eq.next_to(frame.get_corner(UR), DL, buff=MED_LARGE_BUFF)
div_eq.shift(RIGHT)
self.play(ShowCreation(bias_rect))
self.play(
GrowArrow(bias_arrow),
FadeInFromPoint(bias_count, bias_arrow.get_start()),
full_product.animate.scale(0.8).shift(3.5 * LEFT)
)
self.wait()
self.play(
frame.animate.set_x(-3.0),
FadeTransform(bias_count.copy(), div_eq[R"4 \times 12{,}288"]),
Write(div_eq[R"\over"]),
FadeTransform(full_product[-1].copy(), div_eq[R"603{,}979{,}776}"]),
Write(div_eq[R"\approx 0.00008"]),
)
self.wait()
self.play(
frame.animate.reorient(0, 0, 0, (-2.5, 0.44, 0.0), 9.33),
*map(FadeOut, [full_product, bias_rect, bias_arrow, bias_count, div_eq])
)
# Collapse
substrs = [R"W_\uparrow", R"\vec{\textbf{E}}_i", "+", R"\vec{\textbf{B}}_\uparrow"]
linear_expr = Tex(" ".join(substrs))
W_up, E_i, plus2, B_up = [linear_expr[ss] for ss in substrs]
VGroup(W_up, B_up).set_color(BLUE)
E_i.set_color(YELLOW)
linear_expr.move_to(plus).shift(0.6 * LEFT)
low_emb_label = E_i.copy()
low_emb_label.scale(0.5).next_to(emb, UP)
self.play(
frame.animate.reorient(0, 0, 0, (-0.03, 0.03, 0.0), 8.34),
ReplacementTransform(up_proj, W_up, lag_ratio=1e-3),
FadeOut(side_brace, RIGHT, scale=0.5),
FadeOut(num_rows_expr, RIGHT, scale=0.5),
FadeOut(over_brace, DR, scale=0.5),
FadeOut(row_size, DR, scale=0.5),
)
self.wait()
self.play(ReplacementTransform(up_emb, E_i, lag_ratio=1e-2))
self.play(TransformFromCopy(E_i, low_emb_label))
self.wait()
self.play(
ReplacementTransform(plus, plus2),
ReplacementTransform(bias, B_up, lag_ratio=1e-2),
FadeOut(bias_name, DL),
VGroup(eq, rhs).animate.next_to(B_up, RIGHT).shift(0.1 * DOWN),
run_time=2
)
self.wait()
# Add parameters below first linear arrow
self.play(
linear_expr.animate.scale(0.5).next_to(arrows[0], DOWN, buff=0.1),
ReplacementTransform(rhs, vects[0]),
FadeOut(eq, 4 * DOWN + LEFT),
run_time=2
)
self.wait()
# Pull up ReLU
self.play(phases[1].animate.set_opacity(1))
phase1_copy = VGroup(vects[0], arrows[1], vects[1]).copy()
phase1_copy.save_state()
self.play(
phase1_copy.animate.scale(2.0).next_to(full_block, UP, buff=0.5),
frame.animate.reorient(0, 0, 0, (-0.26, 0.54, 0.0), 9.40)
)
self.wait()
# Break down ReLU
relu_arrow = phase1_copy[1]
neg_arrows = VGroup()
pos_arrows = VGroup()
neg_left_rects = VGroup()
zero_right_rects = VGroup()
pos_left_rects = VGroup()
pos_right_rects = VGroup()
in_vect = phase1_copy[0]
out_vect = phase1_copy[2]
for e1, e2 in zip(in_vect.get_entries(), out_vect.get_entries()):
arrow = Arrow(e1, e2, buff=0.3)
if e1.get_value() > 0:
arrow.set_color(BLUE)
pos_arrows.add(arrow)
pos_left_rects.add(SurroundingRectangle(e1, color=BLUE))
pos_right_rects.add(SurroundingRectangle(e2, color=BLUE))
else:
arrow.set_color(RED)
neg_arrows.add(arrow)
neg_left_rects.add(SurroundingRectangle(e1, color=RED))
zero_right_rects.add(SurroundingRectangle(e2, color=RED))
VGroup(neg_left_rects, zero_right_rects, pos_left_rects, pos_right_rects).set_stroke(width=2)
self.play(ShowCreation(neg_left_rects, lag_ratio=0.5))
self.wait()
self.play(
TransformFromCopy(neg_left_rects, zero_right_rects, lag_ratio=0.5),
ShowCreation(neg_arrows, lag_ratio=0.5),
FadeOut(relu_arrow),
)
self.wait()
self.play(
FadeOut(neg_left_rects, lag_ratio=0.25),
FadeOut(zero_right_rects, lag_ratio=0.25),
FadeOut(neg_arrows, lag_ratio=0.25),
ShowCreation(pos_left_rects)
)
self.wait()
self.play(
ShowCreation(pos_arrows, lag_ratio=0.5),
TransformFromCopy(pos_left_rects, pos_right_rects, lag_ratio=0.5),
)
self.wait()
# Graph ReLU
relu_title_full = Text("Rectified\nLinear\nUnit", alignment="LEFT")
relu_title_full.next_to(relu_arrow, UP)
axes = Axes((-4, 4), (-1, 4))
axes.set_width(6)
axes.next_to(phase1_copy, RIGHT, buff=1.0)
axes.add_coordinate_labels(font_size=16)
relu_graph = axes.get_graph(lambda x: max(0, x), discontinuities=[0])
relu_graph.set_stroke(YELLOW, 4)
plot = VGroup(axes, relu_graph)
relu_graph_label = Text("ReLU")
relu_graph_label.match_color(relu_graph)
relu_graph_label.move_to(axes, UL)
self.play(
frame.animate.set_x(2.7),
FadeIn(relu_arrow),
FadeIn(relu_title_full, 0.1 * UP, lag_ratio=0.1, run_time=2),
FadeOut(pos_arrows, lag_ratio=0.25),
FadeOut(pos_left_rects, lag_ratio=0.25),
FadeOut(pos_right_rects, lag_ratio=0.25),
FadeIn(plot, RIGHT),
)
self.wait()
self.play(*(
TransformFromCopy(relu_title_full[substr], relu_graph_label[substr])
for substr in ["Re", "L", "U"]
))
self.add(relu_graph_label)
# Recall the meaning of the first entry
mid_vect = phase1_copy[0]
conditions_rect = SurroundingRectangle(conditions, buff=0.25)
conditions_rect.set_stroke(YELLOW, 1)
under_brace = Brace(conditions_rect, DOWN, buff=SMALL_BUFF)
VGroup(conditions, conditions_rect, under_brace).next_to(mid_vect, UP)
fe_rect = SurroundingRectangle(mid_vect.get_entries()[0])
condition_group = VGroup(fe_rect, under_brace, conditions, conditions_rect)
self.play(
frame.animate.reorient(0, 0, 0, (2.61, 0.97, 0.0), 11.5),
ShowCreation(fe_rect),
GrowFromCenter(under_brace),
)
self.play(
TransformFromCopy(fe_rect, conditions_rect),
FadeInFromPoint(conditions, fe_rect.get_center()),
)
self.wait()
self.play(condition_group.animate.match_x(phase1_copy[2]))
equals = Tex("=")
ineq = conditions[1][0]
equals.replace(ineq, dim_to_match=0)
self.play(
FlashAround(equals, run_time=2, time_width=1.5),
ineq.animate.become(equals)
)
self.wait()
self.play(
frame.animate.reorient(0, 0, 0, (2.48, 0.33, 0.0), 9.17),
FadeOut(condition_group, lag_ratio=0.01)
)
# Graph GeLU
gelu_title_full = Text("Gaussian\nError\nLinear\nUnit", font_size=42, alignment="LEFT")
gelu_title_full.next_to(relu_arrow, UP)
gelu_graph = axes.get_graph(lambda x: x * norm.cdf(x))
gelu_graph.set_stroke(GREEN, 4)
gelu_graph_label = Text("GELU")
gelu_graph_label.next_to(relu_graph_label, DOWN, buff=MED_LARGE_BUFF, aligned_edge=LEFT)
gelu_graph_label.match_color(gelu_graph)
self.play(
FadeTransform(relu_title_full, gelu_title_full),
relu_graph_label.animate.set_fill(opacity=0.25),
relu_graph.animate.set_stroke(opacity=0.25),
ShowCreation(gelu_graph),
TransformFromCopy(relu_graph_label, gelu_graph_label)
)
self.wait(2)
self.play(
gelu_graph.animate.set_stroke(opacity=0.25),
gelu_graph_label.animate.set_fill(opacity=0.25),
relu_graph.animate.set_stroke(opacity=1),
relu_graph_label.animate.set_fill(opacity=1),
FadeTransform(gelu_title_full, relu_title_full),
)
self.wait()
# Describe these as neurons
neuron_word = Text("Neurons", font_size=72)
neuron_word.next_to(phase1_copy, RIGHT, buff=2.5)
neuron_arrows = VGroup(
Arrow(neuron_word.get_left(), entry.get_right(), buff=0.4, stroke_width=3)
for entry in phase1_copy[2].get_entries()
)
self.play(
plot.animate.set_width(2).next_to(relu_arrow, DOWN),
FadeOut(VGroup(relu_graph_label, gelu_graph_label, gelu_graph)),
Write(neuron_word),
ShowCreation(neuron_arrows, lag_ratio=0.2, run_time=3),
LaggedStartMap(
FlashAround, phase1_copy[2].get_entries(),
time_width=3.0,
lag_ratio=0.05,
time_span=(1, 4),
run_time=4
)
)
self.wait()
# Show the classic dots picture
blocking_rect = BackgroundRectangle(VGroup(phase1_copy), buff=0.1)
blocking_rect.set_fill(BLACK, 1)
up_emb.move_to(blocking_rect, LEFT)
dots = VGroup(
Dot(radius=0.15).move_to(entry).set_fill(WHITE, opacity=clip(entry.get_value(), 0, 1))
for entry in phase1_copy[2].get_entries()
)
dots.set_stroke(WHITE, 2)
up_emb = emb.copy()
up_emb.rotate(PI / 2, DOWN)
up_emb.rotate(1 * DEGREES)
up_emb.match_width(phase1_copy[0])
up_emb.move_to(phase1_copy[0]).shift(RIGHT)
up_emb[-2:].set_color(YELLOW)
lines = VGroup(
Line(entry.get_right() + 0.05 * RIGHT, dot).set_stroke(
color=value_to_color(random.uniform(-10, 10)),
width=3 * random.random()**2,
)
for entry in up_emb.get_entries()
for dot in dots
)
self.play(
FadeIn(blocking_rect),
Write(dots),
)
self.play(TransformFromCopy(emb, up_emb))
self.play(ShowCreation(lines, lag_ratio=3 / len(lines)))
self.wait()
self.play(
LaggedStart(*map(FadeOut, [up_emb, *lines, blocking_rect, *dots]), lag_ratio=0.01)
)
# Discuss active and inactive
entry = phase1_copy[2].get_entries()[0]
entry_rect = SurroundingRectangle(entry)
entry_rect.set_stroke(YELLOW, 2)
active_words = TexText(R"``Michael Jordan'' neuron is \emph{active}")
active = active_words["active"][0]
active.set_color(BLUE_B)
active_words.next_to(entry_rect, UP, aligned_edge=LEFT)
active_words.shift(LEFT)
inactive = TexText(R"\emph{inactive}")
inactive.set_color(RED)
inactive.move_to(active, LEFT)
self.play(
frame.animate.reorient(0, 0, 0, (2.45, 0.58, 0.0), 9.65),
ShowCreation(entry_rect),
Write(active_words, run_time=1),
)
self.wait()
self.play(
ChangeDecimalToValue(entry, 0),
ReplacementTransform(active, inactive[2:]),
GrowFromCenter(inactive[:2]),
)
active_words.add(inactive)
self.wait()
# Replace the ReLU diagram portion
self.play(
Restore(phase1_copy),
TransformMatchingStrings(relu_title_full, arrow_labels[1]),
plot.animate.scale(0.5).next_to(arrows[1], DOWN, SMALL_BUFF),
FadeOut(neuron_word, DOWN),
FadeOut(neuron_arrows, DOWN, lag_ratio=0.1),
FadeOut(entry_rect, DOWN),
FadeOut(active_words, DOWN, lag_ratio=0.01),
run_time=1.5
)
self.remove(phase1_copy)
# Down projection
neurons = vects[1].copy()
neurons.target = neurons.generate_target()
neurons.target.set_height(4)
neurons.target.move_to(3 * RIGHT + 2.5 * UP)
down_proj = WeightMatrix(shape=(6, 9))
down_proj.set_height(2.75)
down_proj.next_to(neurons.target, LEFT)
plus = Tex("+")
plus.next_to(neurons.target, RIGHT)
bias = WeightMatrix(shape=(6, 1))
bias.match_height(down_proj)
bias.next_to(plus, RIGHT)
equals = Tex("=")
equals.next_to(bias, RIGHT)
rhs = vects[2].copy()
rhs.set_opacity(1)
rhs.match_height(bias)
rhs.next_to(equals, RIGHT)
self.play(phases[2].animate.set_opacity(1))
self.play(MoveToTarget(neurons))
self.play(FadeTransform(arrows[2].copy(), down_proj))
self.wait()
temp_eq, temp_rhs = show_matrix_vector_product(self, down_proj, neurons)
self.wait()
self.play(
FadeOut(temp_eq, DOWN),
FadeOut(temp_rhs, DOWN),
Write(plus),
FadeIn(bias, RIGHT),
)
self.wait()
self.play(
Write(equals),
TransformFromCopy(vects[2], rhs),
)
self.wait()
# Name it as the down-projection
over_brace = Brace(down_proj, UP)
name = TexText("``Down projection''")
name.next_to(over_brace, UP)
side_brace = Brace(rhs, RIGHT)
dim_count = Integer(12288)
dim_count.next_to(side_brace, RIGHT)
self.play(
CountInFrom(dim_count),
GrowFromCenter(side_brace),
)
self.wait()
self.play(
Write(name),
GrowFromCenter(over_brace),
)
self.wait()
# Show column-by-column
col_matrix = self.get_col_matrix(down_proj, 7)
bias_as_col = self.get_col_matrix(bias, 1, dots_index=None, sym="B", top_index="", width_multiple=0.7)
n_labels = VGroup(
Tex(f"n_{{{m}}}")
for m in [*range(6), "m"]
)
n_labels.arrange(DOWN, buff=0.5)
n_labels.match_height(neurons.get_entries())
n_labels.move_to(neurons.get_entries())
n_labels.replace_submobject(-2, Tex(R"\vdots").move_to(n_labels[-2]))
n_labels.set_color(BLUE)
n_vect = VGroup(neurons[-2:].copy(), n_labels)
self.play(
LaggedStart(*map(FadeOut, [over_brace, name, side_brace, dim_count])),
LaggedStart(
FadeOut(down_proj),
FadeIn(col_matrix),
FadeOut(neurons),
FadeIn(n_vect),
FadeOut(bias),
FadeIn(bias_as_col),
)
)
self.wait()
# Expand the column interpretation
over_brace = Brace(VGroup(col_matrix, n_vect), UP)
scaled_cols = VGroup(
VGroup(n_label, col_label[0]).copy()
for n_label, col_label in zip(n_labels, col_matrix[1])
)
scaled_cols.target = VGroup()
for pair in scaled_cols:
pair.target = pair.generate_target()
pair.target[0].scale(1.5)
pair.target.arrange(RIGHT, buff=0.1, aligned_edge=DOWN)
scaled_cols.target.add(pair.target)
scaled_cols.target[-2].become(Tex(R"\dots"))
scaled_cols.target.arrange(RIGHT, buff=0.75)
scaled_cols.target.set_width(1.25 * over_brace.get_width())
scaled_cols.target.next_to(over_brace, UP, buff=0.5)
plusses = VGroup(
Tex("+").move_to(midpoint(m1.get_right(), m2.get_left()))
for m1, m2 in zip(scaled_cols.target, scaled_cols.target[1:])
)
self.play(
frame.animate.reorient(0, 0, 0, (-0.27, 1.04, 0.0), 11.06),
GrowFromCenter(over_brace),
LaggedStartMap(MoveToTarget, scaled_cols, lag_ratio=0.7, run_time=5),
LaggedStartMap(FadeIn, plusses, lag_ratio=0.7, run_time=5),
)
self.wait()
# Highlight each set
last_rects = VGroup()
all_rect_groups = VGroup()
for tup in zip(col_matrix[1], n_labels, scaled_cols):
rects = VGroup(SurroundingRectangle(mob) for mob in tup)
rects.set_stroke(YELLOW, 2)
self.play(
FadeOut(last_rects),
FadeIn(rects),
)
self.wait(0.5)
all_rect_groups.add(rects)
last_rects = rects
self.play(FadeOut(last_rects))
# First column as basketball
col_rect, n_rect, prod_rect = rects = all_rect_groups[0]
basketball = Text("Basketball", font_size=60)
basketball.set_color("#F88158")
basketball.next_to(col_rect, LEFT)
basketball.save_state()
basketball.rotate(-PI / 2)
basketball.move_to(col_rect)
basketball.set_opacity(0)
n0_term = scaled_cols[0][0]
n0_term.save_state()
one = Tex("1", font_size=60).move_to(n0_term, DR).set_color(BLUE)
zero = Tex("0", font_size=60).move_to(n0_term, DR).set_color(RED)
self.play(
ShowCreation(col_rect),
col_matrix[1][1:].animate.set_opacity(0.5),
n_labels[1:].animate.set_opacity(0.5),
scaled_cols[1:].animate.set_opacity(0.5),
plusses.animate.set_opacity(0.5)
)
self.play(Restore(basketball, path_arc=PI / 2))
self.wait()
self.play(TransformFromCopy(col_rect, n_rect))
self.wait()
self.play(
TransformFromCopy(col_rect, prod_rect),
TransformFromCopy(n_rect, prod_rect),
)
self.play(Transform(n0_term, one))
self.wait()
self.play(Transform(n0_term, zero))
self.wait()
self.play(Restore(n0_term))
n0_term.restore()
self.wait()
# Cycle through columns one more time
rects.add(basketball)
for index in range(1, len(all_rect_groups)):
self.play(
FadeOut(all_rect_groups[index - 1]),
FadeIn(all_rect_groups[index]),
col_matrix[1][index].animate.set_opacity(1),
n_labels[index].animate.set_opacity(1),
scaled_cols[index].animate.set_opacity(1),
plusses[index - 1].animate.set_opacity(1),
)
self.wait(0.5)
self.play(FadeOut(all_rect_groups[-1]))
# Highlight bias
bias_rect = SurroundingRectangle(bias)
bias_brace = Brace(bias_rect, UP)
bias_word = Text("Bias")
bias_word.next_to(bias_brace, UP, MED_SMALL_BUFF)
self.play(
ReplacementTransform(over_brace, bias_brace),
FadeIn(bias_rect),
FadeOut(plusses, lag_ratio=0.1),
FadeOut(scaled_cols, lag_ratio=0.1),
)
self.play(FadeIn(bias_word, 0.5 * UP))
self.wait()
self.play(LaggedStart(*map(FadeOut, [bias_word, bias_brace, bias_rect])))
# Collapse the down projection
W_down = Tex(R"W_\downarrow", font_size=60).set_color(BLUE)
B_down = Tex(R"\vec{\textbf{B}}_\downarrow", font_size=60).set_color(BLUE_B)
W_down.next_to(neurons, LEFT)
B_down.move_to(bias_as_col)
WB_down = VGroup(W_down, B_down)
n_rect = Rectangle(1, 1)
n_rect.set_height(W_down.get_height())
n_rect.move_to(n_vect)
n_rect.set_fill(GREY_C)
n_rect.set_stroke(WHITE, 1)
down_proj_expr = VGroup(W_down, n_vect, plus, B_down)
down_proj_expr.target = down_proj_expr.generate_target()
down_proj_expr.target[1].become(VGroup(n_rect))
down_proj_expr.target.arrange(RIGHT, buff=SMALL_BUFF)
down_proj_expr.target.scale(0.4)
down_proj_expr.target.next_to(arrows[2], DOWN)
self.play(ReplacementTransform(col_matrix, W_down, lag_ratio=5e-3, run_time=2))
self.play(ReplacementTransform(bias_as_col, B_down, lag_ratio=1e-2))
self.wait()
self.play(
LaggedStart(
MoveToTarget(down_proj_expr),
FadeOut(equals, 2 * DOWN + 0.5 * LEFT),
ReplacementTransform(rhs, vects[2]),
lag_ratio=0.25,
time_span=(0, 1.5),
),
frame.animate.reorient(0, -14, 0, (-0.1, -2.03, 0.01), 6.31),
run_time=2,
)
self.wait()
# Add it to the original
faded_sum_circuit = sum_circuit.copy()
sum_circuit.set_stroke(opacity=1)
sum_circuit.insert_n_curves(20)
self.add(faded_sum_circuit)
self.play(
frame.animate.reorient(13, -8, 0, (0.15, -2.05, 0.0), 6.52),
ShowCreation(sum_circuit, lag_ratio=0.5),
low_emb_label.animate.shift(0.2 * LEFT).set_anim_args(time_span=(0, 1)),
FadeOut(output_emb),
run_time=2,
)
self.remove(faded_sum_circuit)
output_emb.set_fill(opacity=1)
self.play(LaggedStart(
TransformFromCopy(emb, output_emb, path_arc=-45 * DEGREES),
TransformFromCopy(vects[2], output_emb, path_arc=-45 * DEGREES),
run_time=2,
lag_ratio=0.2,
))
self.wait()
# Yet again, emphasize the MJ example
m_color = interpolate_color_by_hsl(GREY_BROWN, WHITE, 0.5)
j_color = RED_B
b_color = basketball.get_color()
m_tex = Tex(R"\overrightarrow{\text{F.N. Michael}}").set_color(m_color)
j_tex = Tex(R"\overrightarrow{\text{L.N. Jordan}}").set_color(j_color)
b_tex = Tex(R"\overrightarrow{\text{Basketball}}").set_color(b_color)
mj = VGroup(m_tex, Tex("+"), j_tex).copy()
mjb = VGroup(m_tex, Tex("+"), j_tex, Tex("+"), b_tex).copy()
for tex_mob in [mj, mjb]:
tex_mob.set_height(0.45)
tex_mob.arrange(RIGHT, buff=SMALL_BUFF)
tex_mob.set_fill(border_width=1)
mj.next_to(low_emb_label, UP, buff=1.0).shift(0.5 * LEFT)
mjb.next_to(output_emb, UP, buff=1.5).shift(1.0 * RIGHT)
mj_arrow = Arrow(mj.get_bottom(), low_emb_label, buff=0.1)
mjb_arrow = Arrow(output_emb.get_top(), mjb.get_bottom(), buff=0.15)
self.play(
frame.animate.reorient(4, -6, 0, (-0.29, -1.76, 0.02), 7.70),
FadeIn(mj, lag_ratio=0.1),
ShowCreation(mj_arrow)
)
self.play(Transform(mj.copy(), emb.copy().set_opacity(0), lag_ratio=0.005, remover=True, run_time=2))
mover = emb.copy()
for vect in [*vects, output_emb]:
self.play(Transform(mover, vect, rate_func=linear))
self.remove(mover)
self.play(
frame.animate.reorient(-3, -5, 0, (1.09, -1.48, -0.03), 9.61),
FadeTransform(mj.copy(), mjb[:3]),
FadeTransformPieces(mj.copy()[-1:], mjb[3:]),
ShowCreation(mjb_arrow),
run_time=2,
)
self.wait(2)
self.play(
frame.animate.reorient(21, -14, 0, (-0.13, -2.21, 0.11), 6.91).set_anim_args(run_time=5),
LaggedStartMap(FadeOut, VGroup(mj, mj_arrow, mjb_arrow, mjb)),
)
# Show it done in parallel to all embeddings
self.play(
frame.animate.reorient(14, -12, 0, (0.55, -2.21, 0.18), 7.05),
LaggedStart((
TransformFromCopy(simple_phases, sp_copy)
for sp_copy in simple_phase_copies
), lag_ratio=0.1),
FadeOut(block_title, time_span=(0, 1)),
run_time=5,
)
self.play(
frame.animate.reorient(42, -23, 0, (0.55, -2.21, 0.18), 7.05),
run_time=8
)
self.wait()
# Show neurons?
sum_circuits = VGroup(
sum_circuit,
*(sp[0] for sp in simple_phase_copies),
*(sp[-1] for sp in simple_phase_copies),
)
n_vects = VGroup(vects[1], *(sp[2][1] for sp in simple_phase_copies))
neuron_points = np.array([
entry.get_center()
for vect in n_vects[1:]
for entry in vect.get_entries()
])
neurons = DotCloud(neuron_points)
neurons.set_radius(0.075)
neurons.set_shading(0.25, 0.25, 0.5)
neurons.apply_depth_test()
rgbas = np.random.random(len(neuron_points))
rgbas = rgbas.repeat(4).reshape((rgbas.size, 4))
rgbas[:, 3] = 1
neurons.set_rgba_array(rgbas)
neuron_ellipses = VGroup(
n_vect.get_ellipses()
for n_vect in n_vects[1:]
)
self.play(
frame.animate.reorient(11, -5, 0, (0.55, -2.21, 0.18), 7.05),
sum_circuits.animate.set_stroke(width=1, opacity=0.2),
FadeOut(block[4]),
run_time=2
)
self.play(
frame.animate.reorient(-11, -5, 0, (0.55, -2.21, 0.18), 7.05).set_anim_args(run_time=4),
FadeOut(n_vects),
ShowCreation(neurons, run_time=2),
FadeIn(neuron_ellipses, time_span=(1, 2)),
)
self.add(neuron_ellipses)
self.play(frame.animate.reorient(13, -7, 0, (0.55, -2.21, 0.18), 7.05), run_time=4)
self.wait()
def get_sum_circuit(
    self, in_vect, diff_vect,
    v_buff=0.15,
    h_buff=0.5,
    y_diff=0.65,
    color=YELLOW
):
    """
    Build the wiring that adds in_vect to diff_vect.

    A circled plus is placed h_buff to the right of diff_vect. A bent
    wire starts v_buff above the top of in_vect, rises to y_diff above
    it, runs across, and drops onto the top of the circled plus. Two
    short horizontal wires lead into and out of the plus.

    Returns a pair (circuit, output):
      - circuit: VGroup of (top wire, circled plus, wire in, wire out),
        all stroked with `color` at width 3.
      - output: a copy of diff_vect placed flush against the end of the
        outgoing wire, with each entry set to the sum of the matching
        entries of in_vect and diff_vect.
    """
    # Circled plus, sitting to the right of the vector being added
    cross = VGroup(Line(UP, DOWN), Line(LEFT, RIGHT))
    cross.scale(0.6)
    oplus = VGroup(Circle(radius=1), cross)
    oplus.set_height(0.3)
    oplus.next_to(diff_vect, RIGHT, buff=h_buff)

    # Corner points of the wire running from in_vect over to the plus
    wire_start = in_vect.get_top() + v_buff * UP
    wire_peak = in_vect.get_top() + y_diff * UP
    above_plus = oplus.get_center()
    above_plus[1] = wire_peak[1]
    plus_top = oplus.get_top()
    top_line = VMobject()
    top_line.set_points_as_corners([wire_start, wire_peak, above_plus, plus_top])

    oplus.refresh_bounding_box()  # Why?

    # Horizontal wires into and out of the plus
    h_line1 = Line(diff_vect.get_right(), oplus.get_left())
    h_line2 = Line(oplus.get_right(), oplus.get_right() + h_buff * RIGHT)

    # The output vector holds the entry-wise sum of the two inputs
    output = diff_vect.copy()
    output.next_to(h_line2, RIGHT, buff=0)
    entry_triples = zip(
        in_vect.get_entries(),
        diff_vect.get_entries(),
        output.get_entries(),
    )
    for in_entry, diff_entry, out_entry in entry_triples:
        out_entry.set_value(in_entry.get_value() + diff_entry.get_value())

    circuit = VGroup(top_line, oplus, h_line1, h_line2)
    circuit.set_stroke(color, 3)
    return circuit, output
def get_col_matrix(self, matrix, n_cols_shown, dots_index=-2, sym="C", top_index="m-1", width_multiple=1.0):
    """
    Build a symbolic, column-by-column stand-in for `matrix`.

    Each shown column becomes a label of the form vec{sym}_index, with
    thin vertical lines running from the label up to the top and down to
    the bottom of the matrix. The labels are indexed
    0, ..., n_cols_shown - 2 followed by `top_index`. They are laid out
    over the matrix entries and scaled to `width_multiple` times the
    entries' width.

    If dots_index is not None, the column at that index is replaced by
    horizontal dots.

    Returns VGroup(copy of the matrix brackets, VGroup of columns).
    """
    entries = matrix.get_entries()
    indices = [*range(n_cols_shown - 1), top_index]

    C_labels = VGroup(
        Tex(Rf"\vec{{\textbf{{{sym}}}}}_{{{n}}}")
        for n in indices
    )
    C_labels.arrange(RIGHT, buff=0.5)
    C_labels.move_to(entries)
    C_labels.set_width(entries.get_width() * width_multiple)

    # Template pair of lines (below and above the labels), copied per column
    lower_line = Line(matrix.get_bottom(), C_labels.get_bottom() + SMALL_BUFF * DOWN)
    upper_line = Line(C_labels.get_top() + SMALL_BUFF * UP, matrix.get_top())
    v_lines = VGroup(lower_line, upper_line)
    v_lines.set_stroke(WHITE, 1)

    col_labels = VGroup()
    for C_label in C_labels:
        column = VGroup(C_label, v_lines.copy().match_x(C_label))
        col_labels.add(column)

    if dots_index is not None:
        dots = Tex(R"\hdots")
        dots.move_to(col_labels[dots_index])
        col_labels.replace_submobject(dots_index, dots)

    brackets = matrix.get_brackets().copy()
    return VGroup(brackets, col_labels)
class NonlinearityOfLanguage(InteractiveScene):
    """
    2D picture of testing an embedding E against the direction M + J.

    Two perpendicular vectors, labelled "First Name Michael" (M) and
    "Last Name Jordan" (J), are drawn on a number plane. A movable
    embedding E is projected onto the line through M + J, and the value
    of (M + J) . E is displayed next to the projection. The scene then
    steps E through several example names and finally shades "Yes" and
    "No" regions on either side of a line perpendicular to M + J.
    """
    def construct(self):
        # Set up axes and M + J
        unit_size = 2.5
        plane = NumberPlane(
            axis_config=dict(
                stroke_width=1,
            ),
            background_line_style=dict(
                stroke_color=BLUE_D,
                stroke_width=1,
                stroke_opacity=0.75
            ),
            faded_line_ratio=1,
            unit_size=unit_size,
        )
        m_vect = Vector(unit_size * RIGHT).rotate(60 * DEGREES, about_point=ORIGIN)
        j_vect = m_vect.copy().rotate(-90 * DEGREES, about_point=ORIGIN)
        m_vect.set_color(YELLOW)
        j_vect.set_color(RED)
        # Faded copies completing the parallelogram spanned by M and J
        m_ghost = m_vect.copy().shift(j_vect.get_vector())
        j_ghost = j_vect.copy().shift(m_vect.get_vector())
        VGroup(m_ghost, j_ghost).set_stroke(opacity=0.25)

        # Tip of M + J, and the line through the origin in that direction
        sum_point = m_ghost.get_end()
        span_line = Line(-sum_point, sum_point)
        span_line.set_length(2 * FRAME_WIDTH)
        span_line.set_stroke(WHITE, 2, opacity=0.5)

        self.add(plane)
        self.add(m_vect, m_ghost, j_vect, j_ghost)
        self.add(span_line)

        # Label vectors
        m_label = Text("First Name Michael")
        j_label = Text("Last Name Jordan")
        for label, vect in [(m_label, m_vect), (j_label, j_vect)]:
            label.scale(0.6)
            label.match_color(vect)
            # Place the label on the side of the vector facing away from the x-axis
            direction = np.sign(vect.get_vector()[1]) * UP
            label.next_to(ORIGIN, direction, buff=0.2, aligned_edge=LEFT)
            label.rotate(vect.get_angle(), about_point=ORIGIN)
            label.set_backstroke(BLACK, 3)

        self.add(m_label)
        self.add(j_label)

        # Add dot product expression
        expr = Tex(R"(\vec{\textbf{M}} + \vec{\textbf{J}}) \cdot \textbf{E}")
        expr[1:3].match_color(m_vect)
        expr[4:6].match_color(j_vect)
        expr.to_corner(UL)

        self.add(expr)

        # Set up embedding with dot product tracker
        emb_point = VectorizedPoint(unit_size * UL)
        emb = Vector()
        emb.add_updater(lambda m: m.put_start_and_end_on(ORIGIN, emb_point.get_center()))

        normalized_sum = normalize(sum_point)

        def get_line_point():
            # Orthogonal projection of the embedding onto the M + J line
            return normalized_sum * np.dot(normalized_sum, emb_point.get_center())

        def get_dot_product():
            # (M + J) . E, measured in units where M and J (each drawn with
            # length unit_size) count as unit vectors. Derived from unit_size
            # rather than a hard-coded constant so the readout stays correct
            # if unit_size is changed.
            return np.dot(sum_point, emb_point.get_center()) / unit_size**2

        shadow = Line()
        shadow.set_stroke(PINK, 3)
        shadow.add_updater(lambda m: m.put_start_and_end_on(ORIGIN, get_line_point()))
        dot = Dot()
        dot.set_fill(PINK, 1)
        dot.f_always.move_to(get_line_point)
        dashed_line = always_redraw(
            lambda: DashedLine(emb_point.get_center(), get_line_point()).set_stroke(PINK, 2)
        )
        dp_decimal = DecimalNumber(font_size=36)
        dp_decimal.match_color(dot)
        dp_decimal.f_always.set_value(get_dot_product)
        dp_decimal.always.next_to(dot, DR, buff=SMALL_BUFF)

        self.add(shadow, emb, dot, dashed_line, dp_decimal)

        # Sweep the embedding through M, M + J, J and back to M + J
        emb_point.move_to(ORIGIN + 0.01 * UP)
        for point in [m_vect.get_end(), m_ghost.get_end(), j_vect.get_end(), m_ghost.get_end()]:
            self.play(emb_point.animate.move_to(point), run_time=3)

        # Set up names
        names = VGroup(
            Text(name, font_size=36)
            for name in [
                "Michael Jordan",
                "Michael Phelps",
                "Alexis Jordan",
            ]
        )
        name_points = [
            sum_point,
            m_vect.get_end(),
            j_vect.get_end(),
        ]
        for name, point in zip(names, name_points):
            name.set_backstroke(BLACK, 3)
            direction = RIGHT + np.sign(point[1]) * UP
            name.next_to(point, direction, buff=0.1)

        # Go through names
        name = names[0].copy()
        name_ghosts = names.copy().set_fill(opacity=0.75).set_stroke(width=0)
        self.play(
            FadeIn(name, 0.5 * UP),
            Rotate(emb_point, TAU, about_point=emb_point.get_center() + 0.15 * DL, run_time=4),
        )
        self.wait()
        self.add(name_ghosts[0])
        self.play(
            Transform(name, names[1]),
            emb_point.animate.move_to(m_vect.get_end()),
            run_time=2,
        )
        self.wait()
        self.add(name_ghosts[1])
        self.play(
            Transform(name, names[2]),
            emb_point.animate.move_to(j_vect.get_end()).set_anim_args(path_arc=30 * DEGREES),
            run_time=2,
        )
        self.add(name_ghosts[2])
        self.wait()

        # Show other names
        other_point = span_line.pfp(0.45)
        other_word = Text("(Other)", font_size=36)
        other_word.set_fill(GREY_B)
        other_word.next_to(other_point, UL, buff=0)
        self.play(
            emb_point.animate.move_to(other_point),
            LaggedStart(
                FadeOut(name),
                FadeIn(other_word),
                lag_ratio=0.5,
            ),
            run_time=3
        )
        self.wait()

        # Show "yes" vs. "no" regions
        # Two large rectangles meeting along a line perpendicular to M + J,
        # shifted 85% of the way out towards the tip of M + J
        regions = FullScreenRectangle().scale(2).replicate(2)
        regions.arrange(LEFT, buff=0)
        regions[0].set_fill(GREEN_B, 0.35)
        regions[1].set_fill(RED, 0.25)
        regions.rotate(span_line.get_angle(), about_point=ORIGIN)
        regions.shift(0.85 * sum_point)

        yes_no_words = VGroup(
            Text("Yes", font_size=72).set_fill(GREEN).to_corner(UR),
            Text("No", font_size=72).set_fill(RED).to_edge(UP).shift(LEFT),
        )

        for region, word in zip(regions, yes_no_words):
            self.play(FadeIn(region), FadeIn(word))
        self.wait()
class Superposition(InteractiveScene):
    """
    Scene contrasting exactly perpendicular vectors with nearly perpendicular ones.

    Inside an undulating bubble labelled "N-dimensional Space", the scene
    shows three mutually perpendicular vectors (maximum count: N), relaxes
    the condition to "between 89 and 91 degrees apart", shows attempts to
    fit a fourth vector in 3 dimensions, and finally fades in many vectors
    alongside the approximate maximum exp(epsilon * N).
    """
    def construct(self):
        # Add undulating bubble to encompass N-dimensional space
        frame = self.frame
        bubble = self.undulating_bubble()
        bubble_label = TexText(R"$N$-dimensional\\ Space")
        bubble_label.set_height(1)
        bubble_label["$N$"].set_color(YELLOW)
        bubble_label.next_to(bubble, LEFT)
        self.add(bubble)
        self.add(bubble_label)

        # Preview some ideas
        ideas = VGroup(Text("Latin"), Text("Microphone"), Text("Basketball"), Text("The 1920s"))
        ideas.scale(0.75)
        idea_vects = VGroup()
        vect = DOWN
        colors = [PINK, GREEN, ORANGE, BLUE]
        for idea, color in zip(ideas, colors):
            # Each successive direction is rotated 80 degrees from the last
            vect = rotate_vector(vect, 80 * DEGREES)
            vector = Vector(1.25 * normalize(vect))
            idea.next_to(vector.get_end(), vector.get_vector(), buff=SMALL_BUFF)
            idea_vect = VGroup(vector, idea)
            idea_vect.set_color(color)
            idea_vect.shift(bubble.get_center())
            idea_vects.add(idea_vect)

        frame.save_state()
        frame.scale(0.75)
        frame.move_to(VGroup(bubble, bubble_label))
        self.play(
            Restore(frame, run_time=7),
            LaggedStartMap(VFadeInThenOut, idea_vects, lag_ratio=0.5, run_time=5)
        )

        # Written conditions and answer
        conditions = [
            R"$90^\circ$ apart",
            R"between $89^\circ$ and $91^\circ$ apart"
        ]
        task1, task2 = tasks = VGroup(
            TexText(Rf"Choose multiple vectors,\\ each pair {phrase}", font_size=42, alignment="")
            for phrase in conditions
        )
        task1[R"90^\circ"].set_color(RED)
        task2[R"$89^\circ$ and $91^\circ$"].set_color(BLUE)
        task1.center().to_edge(UP)
        task2.move_to(task1, UL)

        maximum1, maximum2 = maxima = VGroup(
            TexText(fR"Maximum \# of vectors: {answer}", font_size=42)
            for answer in ["$N$", R"$\approx \exp(\epsilon \cdot N)$"]
        )
        for maximum in maxima:
            maximum.next_to(tasks, DOWN, buff=LARGE_BUFF, aligned_edge=LEFT)
        maximum1["N"].set_color(YELLOW)
        maximum2["N"].set_color(YELLOW)

        # Add 3 vectors such that each pair is 90-degrees
        perp_vectors = VGroup(*map(Vector, [RIGHT, UP, OUT]))
        perp_vectors.set_shading(0.25, 0.25, 0.25)
        perp_vectors.set_submobject_colors_by_gradient(RED, GREEN, BLUE)
        elbows = VGroup(
            Elbow(width=0.1).rotate(angle, axis, about_point=ORIGIN).set_stroke(WHITE, 2)
            for angle, axis in [(0, UP), (-PI / 2, UP), (PI / 2, RIGHT)]
        )
        elbows.set_stroke(GREY_A, 2)
        perp_group = VGroup(perp_vectors, elbows)
        perp_group.rotate(-10 * DEGREES, UP)
        perp_group.rotate(20 * DEGREES, RIGHT)
        perp_group.scale(2)
        perp_group.move_to(bubble)

        self.play(
            FadeIn(task1),
            LaggedStartMap(GrowArrow, perp_vectors[:2], lag_ratio=0.5)
        )
        self.play(ShowCreation(elbows[0]))
        self.play(
            GrowArrow(perp_vectors[2]),
            LaggedStartMap(ShowCreation, elbows[1:3], lag_ratio=0.5),
        )
        self.play(
            Rotate(perp_group, -50 * DEGREES, axis=perp_vectors[1].get_vector(), run_time=15),
            Write(maximum1, time_span=(2, 4)),
        )

        # Relax the assumption
        ninety_part = task1[conditions[0]]
        cross = Cross(ninety_part)
        crossed_part = VGroup(ninety_part, cross)
        new_cond = task2[conditions[1]]
        new_cond.align_to(ninety_part, LEFT)

        pairs = VGroup(get_vector_pair(89), get_vector_pair(91))
        pairs.arrange(RIGHT)
        pairs.to_corner(UL)

        self.play(
            FadeOut(maximum1),
            ShowCreation(cross),
        )
        self.play(
            crossed_part.animate.shift(0.5 * DOWN).set_fill(opacity=0.5),
            Write(new_cond),
            LaggedStartMap(FadeIn, pairs, lag_ratio=0.25),
        )
        self.play(
            Rotate(perp_group, 50 * DEGREES, axis=perp_vectors[1].get_vector(), run_time=10)
        )

        # Struggle with 3 vectors (Sub out the title)
        three_d_label = TexText(R"3-dimensional\\ Space")
        three_d_label["3"].set_color(BLUE)
        three_d_label.move_to(bubble_label, UL)
        bubble_label.save_state()

        pv = perp_vectors
        pv.save_state()
        # Copies of the three vectors, each nudged 5 degrees about a random axis
        alt_vects = pv.copy()
        origin = pv[0].get_start()
        for vect in alt_vects:
            vect.rotate(5 * DEGREES, axis=normalize(np.random.random(3)), about_point=origin)

        # Candidate fourth vectors: one along the sum direction of each pair
        new_vects = VGroup()
        for (v1, v2) in it.combinations(pv, 2):
            tip = v1.get_length() * normalize(v1.get_vector() + v2.get_vector())
            new_vects.add(Arrow(ORIGIN, tip, buff=0).shift(origin))
        new_vects.set_color(YELLOW)
        new_vect = new_vects[0]

        def shake(vect):
            # Wiggle the vector back and forth about the shared origin
            self.play(
                vect.animate.rotate(5 * DEGREES, RIGHT, about_point=origin),
                rate_func=lambda t: wiggle(t, 9)
            )

        self.play(
            FadeIn(three_d_label, DOWN),
            bubble_label.animate.to_edge(DOWN).set_opacity(0.5)
        )
        self.play(
            GrowArrow(new_vect),
            Transform(perp_vectors, alt_vects)
        )
        shake(new_vect)
        self.play(
            Restore(perp_vectors),
            Transform(new_vect, new_vects[1])
        )
        shake(new_vect)
        self.play(
            Transform(perp_vectors, alt_vects),
            Transform(new_vect, new_vects[2])
        )
        shake(new_vect)
        self.wait()
        self.play(
            new_vect.animate.scale(0, about_point=origin),
            ApplyMethod(perp_group.scale, 0, dict(about_point=origin), lag_ratio=0.25),
            Restore(bubble_label),
            FadeOut(three_d_label, UP),
            run_time=2
        )
        self.remove(new_vect, perp_group)

        # Stack on many vectors
        # One vector per dodecahedron face, pointing at that face's center
        dodec = Dodecahedron()
        vertices = [face.get_center() for face in dodec]
        vectors = VGroup(Vector(vert) for vert in vertices)
        vectors.set_flat_stroke(True)
        vectors.rotate(30 * DEGREES, UR)
        for vector in vectors:
            vector.always.set_perpendicular_to_camera(self.frame)
            vector.set_color(random_bright_color(hue_range=(0.5, 0.7)))
        vectors.move_to(bubble)

        self.wait(6)
        self.play(
            FadeOut(crossed_part),
            Write(maximum2),
            Rotating(vectors, TAU, axis=UP, run_time=20),
            LaggedStartMap(VFadeIn, vectors, lag_ratio=0.5, run_time=8)
        )
        self.wait()

        # Somehow communicate exponential scaling

    def undulating_bubble(self):
        """
        Return a thought-bubble outline whose boundary ripples over time.

        The outline is taken from a ThoughtBubble, restyled, and pinned to
        the bottom edge of the frame. An updater rescales every point of a
        fixed, mean-centered reference outline by
        1 + 0.05 * sin(6 * angle + self.time), where `angle` is that
        point's angle_of_vector value, then re-centers the bubble
        horizontally at the bottom edge.
        """
        bubble = ThoughtBubble(filler_shape=(6, 3))[0][-1]
        bubble.set_stroke(WHITE, 1)
        bubble.set_fill(GREY)
        bubble.set_shading(0.5, 0.5, 0)
        bubble.to_edge(DOWN)

        # Fixed reference outline, centered on the mean of its points
        points = bubble.get_points().copy()
        points -= np.mean(points, 0)
        # The reference outline never changes, so its angles are computed
        # once here instead of on every frame inside the updater
        angles = np.apply_along_axis(angle_of_vector, 1, points)

        def update_bubble(bubble):
            stretch_factors = 1.0 + 0.05 * np.sin(6 * angles + self.time)
            bubble.set_points(points * stretch_factors[:, np.newaxis])
            bubble.set_x(0).to_edge(DOWN)

        bubble.add_updater(update_bubble)
        return bubble
class StackOfVectors(InteractiveScene):
    """
    Show a tall stack of row vectors inside one large pair of brackets,
    labelled "100-dimensional" across the top and "10,000 vectors" down
    the side.
    """
    def construct(self):
        # Set up the big matrix
        n_rows = 20
        rows = VGroup(
            NumericEmbedding(shape=(1, 9), ellipses_col=-5, value_range=(-1, 1))
            for _ in range(n_rows)
        )
        rows.arrange(DOWN)
        # Line up every row's own brackets with the edges of the whole stack
        for row in rows:
            row.brackets[0].align_to(rows, LEFT)
            row.brackets[1].align_to(rows, RIGHT)
        rows.set_height(6)
        rows.to_edge(DOWN)

        # The second-to-last row stands in for all the omitted ones
        vdots = Tex(R"\vdots")
        vdots.replace(rows[-2], dim_to_match=1)
        rows[-2].become(vdots)

        # One large pair of brackets wrapping the whole stack
        brackets = NumericEmbedding(shape=(n_rows, 9)).brackets
        brackets.set_height(rows.get_height() + MED_SMALL_BUFF)
        for bracket, side in zip(brackets, [LEFT, RIGHT]):
            bracket.next_to(rows, side, SMALL_BUFF)

        top_brace = Brace(rows[0], UP)
        top_label = top_brace.get_text("100-dimensional")
        side_brace = Brace(brackets, LEFT)
        side_label = side_brace.get_text("10,000\nvectors")

        self.play(
            GrowFromCenter(top_brace),
            FadeIn(top_label, lag_ratio=0.1),
            LaggedStartMap(FadeIn, rows, shift=0.25 * DOWN, lag_ratio=0.1, run_time=3),
            *(GrowFromCenter(bracket) for bracket in brackets)
        )

        # Shuffle the entries of every numeric row (all but the last two)
        randomizers = [RandomizeMatrixEntries(row) for row in rows[:-2]]
        self.play(LaggedStart(*randomizers, lag_ratio=0.05))
        self.wait()

        # Label the number of vectors
        self.play(
            GrowFromCenter(side_brace),
            FadeIn(side_label, lag_ratio=0.1),
        )
        self.wait(4)
class ShowAngleRange(InteractiveScene):
    """
    Sweep the angle of a redrawn vector pair from 10 up to 180, then
    back down to 95.
    """
    def construct(self):
        # Test
        angle_tracker = ValueTracker(10)

        def build_pair():
            # Rebuilt every frame from the tracker's current value
            return get_vector_pair(angle_tracker.get_value(), length=3, colors=(RED, GREEN))

        vect_pair = always_redraw(build_pair)
        self.add(vect_pair)

        # (target angle, run time) for each sweep
        sweeps = [(180, 8), (95, 3)]
        for target_angle, duration in sweeps:
            self.play(
                angle_tracker.animate.set_value(target_angle),
                run_time=duration,
            )
            self.wait()
class MLPFeatures(InteractiveScene):
    """
    Two-layer dot diagram contrasting one-feature-per-neuron with features
    spread across many neurons.

    First, single neurons of the second layer light up one at a time, each
    next to a feature name. Then the whole second layer jumps between
    random activation patterns while feature names cycle beside a brace.
    """
    def construct(self):
        # Add neurons
        radius = 0.15
        layers = VGroup(
            Dot(radius=radius).get_grid(n_neurons, 1, buff=radius / 2)
            for n_neurons in [8, 16]
        )
        layer1, layer2 = layers
        layer2.arrange(DOWN, buff=radius)
        layers.arrange(RIGHT, buff=3.0)
        layers.to_edge(LEFT, buff=LARGE_BUFF)
        layers.set_stroke(WHITE, 1)
        # Random activations on the first layer, second layer starts dark
        for neuron in layer1:
            neuron.set_fill(opacity=random.random())
        layer2.set_fill(opacity=0)
        self.add(layers)

        # Add connections
        self.add(get_network_connections(layer1, layer2))

        # Show single-neuron features
        feature_names = [
            "Table",
            "Slang",
            "AM Radio",
            "Humble",
            "Notebook",
            "Transparent",
            "Duration",
            "Madonna",
            "Mirror",
            "Pole Vaulting",
            "Albert Einstein",
            "Authentic",
            "Scientific",
            "Passionate",
            "Bell Laboratories",
            "Uzbekistan",
            "Umbrella",
            "Immanuel Kant",
            "Baroque Music",
            "Intense",
            "Clock",
            "Water skiing",
            "Ancient Egypt",
            "Ambiguous",
            "Volume",
            "Alexander the Great",
            "Innovative",
            "Religious",
        ]
        # Shared iterator: both loops below keep drawing names from it
        features = iter(feature_names)

        prev_neuron = VGroup()
        prev_label = VGroup()
        for neuron in layer2[:15]:
            feature_label = Text(next(features), font_size=36)
            feature_label.next_to(neuron, buff=SMALL_BUFF)
            self.play(
                FadeOut(prev_label),
                FadeIn(feature_label),
                prev_neuron.animate.set_fill(opacity=0),
                neuron.animate.set_fill(opacity=1),
            )
            prev_neuron = neuron
            prev_label = feature_label

        # Show polysemantic features
        brace = Brace(layer2, RIGHT)

        def to_random_state(layer):
            # Animate every dot of the layer to a fresh random opacity
            layer.generate_target()
            for dot in layer.target:
                dot.set_fill(opacity=random.random())
            return MoveToTarget(layer)

        self.play(
            feature_label.animate.scale(48 / 36).next_to(brace, RIGHT),
            GrowFromCenter(brace),
            to_random_state(layer2),
        )
        self.wait()
        for _ in range(12):
            feature_label = Text(next(features))
            feature_label.next_to(brace, RIGHT)
            self.play(
                FadeOut(prev_label),
                FadeIn(feature_label),
                to_random_state(layer2),
            )
            self.wait(0.5)
            prev_label = feature_label
class BreakDownThreeSteps(BasicMLPWalkThrough):
    """Recap the MLP block as up-projection, ReLU, then down-projection.

    Draws a titled box holding three arrows, with a weight matrix above the
    first and last arrow and a ReLU graph above the middle one.  The rows of
    the first matrix are then swapped for row-vector labels, a column of
    neuron dots is written, and the columns of the second matrix are swapped
    for the labels produced by ``self.get_col_matrix``.
    """

    def construct(self):
        # Add four vectors, spaced apart.  They are never added to the scene;
        # they only serve as layout anchors for the arrows and neuron dots.
        vectors = VGroup(
            NumericEmbedding(length=n)
            for n in [8, 16, 16, 8]
        )
        vectors.set_height(6)
        vectors.arrange(RIGHT, buff=3.5)
        vectors[2].shift(1.1 * LEFT)
        vectors[1].shift(0.2 * LEFT)
        vectors.shift(DOWN)
        # Third vector is the entrywise max(value, 0) of the second
        for e1, e2 in zip(vectors[1].get_entries(), vectors[2].get_entries()):
            e2.set_value(max(e1.get_value(), 0))

        # Add arrows between them
        arrows = VGroup(
            Arrow(v1, v2)
            for v1, v2 in zip(vectors, vectors[1:])
        )
        arrows.shift(DOWN)
        for vect in vectors:
            vect.scale(0.75)
            vect.shift(0.25 * UP)

        # Put matrices on outer two
        up_proj, down_proj = matrices = VGroup(
            WeightMatrix(shape=(12, 6)),
            WeightMatrix(shape=(6, 11)),
        )
        matrices.scale(0.25)
        for arrow, mat in zip(arrows[::2], matrices):
            mat.next_to(arrow, UP)

        # Put ReLU graph on the middle
        axes = Axes((-3, 3), (0, 3))
        graph = axes.get_graph(lambda x: max(x, 0))
        graph.set_color(BLUE)
        relu = VGroup(axes, graph)
        relu.match_width(arrows[1])
        relu.next_to(arrows[1], UP)

        # Full box
        box = SurroundingRectangle(VGroup(arrows, matrices), buff=1.0)
        box.set_stroke(WHITE, 2)
        box.set_fill(GREY_E, 1)
        title = Text("Multilayer Perceptron", font_size=60)
        title.next_to(box, UP, SMALL_BUFF)
        self.add(box, title)

        # Animate them all in
        # Brackets start collapsed and invisible so Restore can open them up
        for matrix in matrices:
            matrix.brackets.save_state()
            matrix.brackets.stretch(0, 0).set_opacity(0)

        self.play(
            LaggedStartMap(GrowArrow, arrows, lag_ratio=0.5),
            FadeIn(up_proj.get_rows(), lag_ratio=0.1, time_span=(0.0, 1.5)),
            FadeIn(down_proj.get_rows(), lag_ratio=0.1, time_span=(1.5, 3.0)),
            Restore(up_proj.brackets, time_span=(0.0, 1.5)),
            Restore(down_proj.brackets, time_span=(1.5, 3.0)),
            Write(relu, time_span=(1, 2)),
            run_time=3
        )
        self.wait()

        # Show row replacement on the first
        n_rows_shown = 8
        R_labels = VGroup(
            Tex(R"\vec{\textbf{R}}_{" + str(index) + "}")
            for index in [*range(n_rows_shown - 1), "n-1"]
        )
        # Second-to-last label becomes an ellipsis before the final "n-1" row
        R_labels[-2].become(Tex(R"\vdots").replace(R_labels[-2], dim_to_match=1))
        R_labels.arrange(DOWN, buff=0.5)
        R_labels.match_height(up_proj)
        R_labels.move_to(up_proj)
        h_lines = VGroup(
            Line(up_proj.get_brackets()[0], R_labels, buff=0.1),
            Line(R_labels, up_proj.get_brackets()[1], buff=0.1),
        )
        h_lines.set_stroke(GREY_A, 2)
        row_labels = VGroup(
            VGroup(R_label, h_lines.copy().match_y(R_label))
            for R_label in R_labels
        )
        row_labels.set_color(YELLOW)

        self.play(
            FadeOut(up_proj.get_rows(), lag_ratio=0.1),
            FadeIn(row_labels, lag_ratio=0.1),
        )
        self.wait()
        # A copy of the first row label grows, drifts up-left and fades away
        self.play(
            row_labels[0][0].copy().animate.scale(2).next_to(title, UL).shift(2 * LEFT).set_opacity(0),
        )
        self.wait()

        # Show the neurons
        dots = VGroup(
            Dot().set_fill(opacity=random.random()).move_to(entry)
            for entry in vectors[2].get_columns()[0]
        )
        for dot in dots:
            dot.match_x(dots[0])
        dots.set_stroke(WHITE, 1)
        self.play(Write(dots))
        self.wait()

        # Show column replacement on the second
        col_matrix = self.get_col_matrix(down_proj, 8)
        col_labels = col_matrix[1]
        col_labels.set_color(RED_B)

        self.play(
            FadeOut(down_proj.get_columns(), lag_ratio=0.1),
            FadeIn(col_labels, lag_ratio=0.1),
        )
        self.wait()
        # Mirror image of the row-label flourish, towards the upper right
        self.play(
            col_labels[0][0].copy().animate.scale(2).next_to(title, UR).shift(2 * RIGHT).set_opacity(0),
        )
        self.wait()
class SuperpositionVectorBundle(InteractiveScene):
    """3D bundle of vectors nudged towards mutual orthogonality.

    Starts from the three standard basis vectors, then repeatedly adds a
    random unit vector and runs a gradient-based optimisation (one Adam step
    per rendered frame) that spreads the arrow tips apart.
    """

    def construct(self):
        # Setup
        frame = self.frame
        axes = ThreeDAxes(z_range=(-3, 3))
        axes.scale(0.5)
        vects = VGroup(
            self.get_new_vector(basis_vect)
            for basis_vect in np.identity(3)
        )

        frame.reorient(23, 71, 0, (0.0, 0.0, 0.5), 3.5)
        frame.add_ambient_rotation(4 * DEGREES)
        self.add(frame, axes, vects)
        self.wait(2)

        # Add a new vector
        n_vects = 10
        for step in range(n_vects):
            direction = normalize(np.random.uniform(-1, 1, 3))
            vects.add(self.get_new_vector(direction))
            # Later rounds get more time, since more vectors must settle
            self.space_out_vectors(vects, run_time=3 + 0.5 * step)
        self.wait(5)

    def get_new_vector(self, coords, color=None, opacity=0.9):
        """Return a Vector to `coords`, randomly colored unless `color` is given.

        The arrow is given an updater (via `always`) that calls
        `set_perpendicular_to_camera` with this scene's frame.
        """
        fill_color = (
            random_bright_color(hue_range=(0.4, 0.6), luminance_range=(0.5, 0.9))
            if color is None
            else color
        )
        arrow = Vector(coords, thickness=2.0)
        arrow.set_fill(fill_color, opacity=opacity, border_width=2)
        arrow.always.set_perpendicular_to_camera(self.frame)
        return arrow

    def space_out_vectors(self, vects, run_time=4, learning_rate=0.01):
        """Animate `vects` while Adam pushes their Gram matrix towards identity.

        Each call of the per-frame updater performs a single optimiser step on
        the matrix of arrow tips and then re-anchors every arrow at the origin.
        """
        count = len(vects)
        tips = torch.from_numpy(np.array([vect.get_end() for vect in vects]))
        tips.requires_grad_(True)
        adam = torch.optim.Adam([tips], lr=learning_rate)
        identity = torch.eye(count, count)

        def step_and_redraw(group):
            adam.zero_grad()
            # Punish deviation from orthogonal.  The high power penalises
            # large off-diagonal dot products far more than small ones.
            # (A hinge-style alternative, (|diff| - 0.01).relu().sum(), was
            # tried here previously.)
            deviation = tips @ tips.T - identity
            spread_term = deviation.pow(6).sum()
            # Extra incentive to keep rows normalized
            norm_term = count * deviation.diag().pow(2).sum()
            loss = spread_term + norm_term
            loss.backward()
            adam.step()
            for arrow, tip in zip(group, tips):
                arrow.put_start_and_end_on(ORIGIN, tip.detach().numpy())

        self.play(UpdateFromFunc(vects, step_and_redraw, run_time=run_time))
# Some old stubs
class ClassicNeuralNetworksPicture(InteractiveScene):
    # Stub scene: construct() contains only `pass`, so it adds no mobjects
    # and plays no animations.
    def construct(self):
        pass
class ShowBiasBakedIntoWeightMatrix(LastTwoChapters):
    """Show Attention, MLP and Layer Norm blocks inside a repeated-layer wrapper.

    The scene opens on the Attention and MLP blocks inside a wrapper marked
    "x 96", widens the wrapper to take in a Layer Norm block, and finally
    flashes around that block's lower (parameter count) label.
    """

    def construct(self):
        # Add initial blocks
        frame = self.frame

        # The square only provides a reference size for the two icons
        square = Square(2.0)
        att_icon = self.get_att_icon(square)
        att_icon.set_stroke(WHITE, 1, 0.5)
        mlp_icon = self.get_mlp_icon(square, layer_buff=1.0)
        lnm_icon = self.get_layer_norm_icon()
        lnm_icon.match_height(mlp_icon)

        att_block = self.get_block(att_icon, "Attention", "604M Parameters", color=YELLOW)
        mlp_block = self.get_block(mlp_icon, "MLP", "1.2B Parameters", color=BLUE)
        lnm_block = self.get_block(lnm_icon, "Layer Norm", "49K Parameters", color=GREY_B)

        blocks = VGroup(att_block, mlp_block, lnm_block)
        blocks.arrange(RIGHT, buff=1.5)

        # Small wrapper is built around a copy so it stays independent of the
        # blocks that the big wrapper later surrounds.
        lil_wrapper = self.get_layer_wrapper(blocks[:2].copy())
        big_wrapper = self.get_layer_wrapper(blocks)

        self.add(lil_wrapper, blocks[:2])
        frame.match_x(blocks[:2])
        self.wait()
        self.play(
            frame.animate.match_x(blocks),
            ReplacementTransform(lil_wrapper, big_wrapper),
            FadeIn(lnm_block, RIGHT),
        )
        self.wait()
        # Index 2 of a block group is its lower label (see get_block)
        self.play(FlashAround(lnm_block[2], run_time=3, time_width=2))
        self.wait()

    def get_layer_norm_icon(self):
        """Return an icon of two stacked bell curves joined by an arrow.

        The upper curve is a shifted, widened normal density and the lower
        one a standard normal density, both drawn in blue.

        Returns:
            VGroup(axes1, graph1, arrow, axes2, graph2)
        """
        axes1, axes2 = all_axes = VGroup(
            Axes((-4, 4), (0, 1, 0.25))
            for _ in range(2)
        )
        all_axes.set_shape(1.5, 0.5)
        all_axes.arrange(DOWN, buff=1.0)
        graph1 = axes1.get_graph(lambda x: 0.5 * norm.pdf(0.5 * x - 0.5))
        graph2 = axes2.get_graph(lambda x: 1.5 * norm.pdf(x))
        graph1.set_stroke(BLUE).set_fill(BLUE, 0.25)
        graph2.set_stroke(BLUE).set_fill(BLUE, 0.25)
        arrow = Arrow(axes1, axes2, buff=0.1)

        return VGroup(axes1, graph1, arrow, axes2, graph2)

    def get_layer_wrapper(self, blocks):
        """Return the decoration that marks `blocks` as one repeated layer.

        Args:
            blocks: Group of block groups as returned by `get_block`; element
                0 of each is used as the endpoint for connecting arrows.

        Returns:
            VGroup(rect, arrows, wrapped_arrow, multiple): the surrounding
            rectangle, arrows between consecutive blocks, the loop-back
            arrow, and the "x 96" label placed above it.
        """
        beige = "#F5F5DC"
        # Only the rectangle of the temporary block group is kept
        rect = self.get_block(blocks, color=beige, buff=0.5, height=4)[0]
        wrapped_arrow = self.get_wrapped_arrow(rect)
        multiple = Tex(R"\times 96")
        multiple.next_to(wrapped_arrow, UP)

        arrows = VGroup()
        for b1, b2 in zip(blocks, blocks[1:]):
            arrows.add(Arrow(b1[0], b2[0], buff=0.1))

        return VGroup(rect, arrows, wrapped_arrow, multiple)

    def get_block(
        self, content,
        upper_label="",
        lower_label="",
        upper_font_size=42,
        lower_font_size=36,
        buff=0.25,
        height=2,
        color=BLUE,
        stroke_width=3,
        fill_opacity=0.2
    ):
        """Wrap `content` in a rounded, filled rectangle with two labels.

        Args:
            content: Mobject to surround.
            upper_label: Text placed above the rectangle.
            lower_label: Text placed below the rectangle.
            upper_font_size: Font size of the upper label.
            lower_font_size: Font size of the lower label.
            buff: Padding between `content` and the rectangle.
            height: Height the rectangle is stretched to.
            color: Stroke and fill color of the rectangle.
            stroke_width: Stroke width of the rectangle.
            fill_opacity: Fill opacity of the rectangle.

        Returns:
            VGroup(block, content, low_label, top_label), in that order.
        """
        block = SurroundingRectangle(content, buff=buff)
        block.set_height(height, stretch=True)
        block.round_corners(radius=0.25)
        # Use the caller's stroke width; it used to be hard-coded to 3 here,
        # which silently ignored the parameter.
        block.set_stroke(color, stroke_width)
        block.set_fill(color, fill_opacity)
        low_label = Text(lower_label, font_size=lower_font_size)
        low_label.next_to(block, DOWN, MED_SMALL_BUFF)
        top_label = Text(upper_label, font_size=upper_font_size)
        top_label.next_to(block, UP, MED_SMALL_BUFF)

        return VGroup(block, content, low_label, top_label)

    def get_wrapped_arrow(self, big_block, buff=0.75, color=GREY_B, stroke_width=4):
        """Return an arrow looping from the right side of `big_block` over to its left.

        The path leaves the right edge midpoint, runs `buff` outside the
        block over its top, and comes back in at the left edge midpoint,
        where an arrow tip is attached.

        Args:
            big_block: Mobject to loop around.
            buff: Distance of the path from the block.
            color: Color of the path and tip.
            stroke_width: Stroke width of the path.
        """
        vertices = [
            big_block.get_corner(RIGHT),
            big_block.get_corner(RIGHT) + buff * RIGHT,
            big_block.get_corner(UR) + buff * UR,
            big_block.get_corner(UL) + buff * UL,
            big_block.get_corner(LEFT) + buff * LEFT,
            big_block.get_corner(LEFT),
        ]
        line = Polygon(*vertices)
        line.round_corners()
        # Presumably drops the polygon's closing segment so the path stays
        # open and ends at the left edge (verify against Polygon's points).
        line.set_points(line.get_points()[:-2, :])
        line.set_stroke(color, stroke_width)
        tip = ArrowTip().move_to(line.get_end(), RIGHT)
        tip.set_color(color)
        line.add(tip)
        return line
class AlmostOrthogonal(InteractiveScene):
    # NOTE(review): as far as the visible code shows, construct() is just
    # `pass`, i.e. this scene currently plays no animations.
    def construct(self):
        pass