mirror of
https://github.com/3b1b/videos.git
synced 2025-08-31 21:58:59 +00:00
2627 lines
91 KiB
Python
2627 lines
91 KiB
Python
import torch
|
|
from scipy.stats import norm
|
|
|
|
from _2024.transformers.helpers import *
|
|
from manim_imports_ext import *
|
|
|
|
|
|
class LastTwoChapters(InteractiveScene):
    """Recap of the previous two chapters, leading into the MLP blocks.

    The scene fades in two chapter thumbnails, draws a stack of ten blocks
    titled "Transformer", separates the even-indexed ("Attention") blocks
    from the odd-indexed ("MLP") blocks, flies a list of sports facts into
    the MLP blocks and finally frames those blocks with a question.
    """

    # Folder the chapter thumbnails are loaded from when `get_thumbnails`
    # is called without an explicit `folder`.
    thumbnail_folder = "/Users/grant/3Blue1Brown Dropbox/3Blue1Brown/videos/2024/transformers/Thumbnails"

    def construct(self):
        # Show last two chapters
        frame = self.frame
        self.camera.light_source.set_z(15)
        self.set_floor_plane("xz")

        thumbnails = self.get_thumbnails()
        self.play(
            LaggedStartMap(FadeIn, thumbnails, shift=UP, lag_ratio=0.5)
        )
        self.wait()

        # Show transformer schematic
        blocks = Group(self.get_block() for x in range(10))
        blocks[1::2].stretch(2, 2).set_opacity(1)

        blocks.arrange(OUT, buff=0.5)
        blocks.set_depth(8, stretch=True)
        blocks.set_opacity(0.8)
        blocks.apply_depth_test()

        trans_title = Text("Transformer", font_size=96)
        trans_title.next_to(blocks, UP, buff=0.5)

        self.play(
            frame.animate.reorient(-32, 0, 0, (0.56, 2.48, 0.32), 12.75),
            thumbnails.animate.scale(0.5).arrange(RIGHT, buff=2.0).to_edge(UP, buff=0.25),
            LaggedStartMap(FadeIn, blocks, shift=0.25 * UP, scale=1.5, lag_ratio=0.1),
            FadeIn(trans_title, UP),
        )
        self.wait()

        # Break out transformer as sequence of blocks
        # Even-indexed blocks are treated as attention, odd-indexed as MLP.
        att_blocks = blocks[0::2]
        mlp_blocks = blocks[1::2]

        att_title = Text("Attention", font_size=72)
        mlp_title_full = Text("Multilayer Perceptron", font_size=72)
        mlp_title = Text("MLP", font_size=72)

        self.play(
            frame.animate.reorient(-3, -2, 0, (0.23, 2.57, 0.3), 12.75),
            trans_title.animate.shift(2 * UP),
            att_blocks.animate.shift(4 * LEFT),
            mlp_blocks.animate.shift(4 * RIGHT),
        )

        att_icon = self.get_att_icon(att_blocks[-1])
        mlp_icon = self.get_mlp_icon(mlp_blocks[-1])
        att_title.next_to(att_blocks[-1], UP, buff=0.75)
        # Both MLP titles share a position so one can be morphed into the other.
        for title in [mlp_title, mlp_title_full]:
            title.next_to(mlp_blocks[-1], UP, buff=0.75)
        self.play(
            FadeIn(att_icon, lag_ratio=1e-3),
            FadeIn(att_title, UP),
            trans_title.animate.scale(0.75).set_opacity(0.5)
        )
        self.wait()
        self.play(
            Write(mlp_icon),
            FadeIn(mlp_title_full, UP),
        )
        self.wait()
        self.play(
            TransformMatchingStrings(mlp_title_full, mlp_title)
        )
        self.wait()

        # Show sports facts
        # One Text per non-blank line of the data file. Blank lines (for
        # example the empty string left by a trailing newline) are skipped
        # so that no empty Text mobject is created and animated.
        fact_lines = Path(DATA_DIR, "athlete_sports.txt").read_text().splitlines()
        sport_facts = VGroup(
            Text(line)
            for line in fact_lines
            if line.strip()
        )
        for fact in sport_facts:
            # Scatter each fact randomly in a band above the title.
            fact.next_to(trans_title, UP)
            fact.shift(random.uniform(-3, 3) * RIGHT)
            fact.shift(random.uniform(0, 3) * UP)

        self.remove(mlp_icon, mlp_title)
        self.play(
            FadeOut(thumbnails),
            FadeOut(trans_title),
            LaggedStart(
                (Succession(FadeIn(fact), fact.animate.scale(0.5).set_opacity(0).move_to(mlp_blocks))
                for fact in sport_facts),
                lag_ratio=0.15,
            )
        )
        self.wait()

        # Ask what is the MLP
        rect = SurroundingRectangle(Group(mlp_blocks, mlp_title), buff=1.0)
        rect.stretch(0.8, 1)
        rect.match_z(mlp_blocks[-1])
        question = Text("What are these?", font_size=90)
        question.next_to(rect, UP, buff=3.0)
        question.match_color(rect)
        question.set_fill(border_width=0.5)
        arrow = Arrow(question, rect)
        arrow.match_color(rect)

        self.play(
            Group(att_blocks, att_title).animate.fade(0.5),
            ShowCreation(rect),
            Write(question),
            GrowArrow(arrow),
        )
        self.wait()

    def get_thumbnails(self, folder=None):
        """Return a Group of the two framed chapter thumbnails.

        `folder` is the directory holding the "Chapter5_TN5" and
        "Chapter6_TN4" images; when omitted, `self.thumbnail_folder` is used.
        Each thumbnail is a Group of (white outline rectangle, image). The
        result is 3.5 units tall, arranged left to right and fixed in frame.
        """
        if folder is None:
            folder = self.thumbnail_folder
        images = [
            ImageMobject(str(Path(folder, "Chapter5_TN5"))),
            ImageMobject(str(Path(folder, "Chapter6_TN4"))),
        ]
        thumbnails = Group(
            Group(
                SurroundingRectangle(image, buff=0).set_stroke(WHITE, 3),
                image
            )
            for image in images
        )
        thumbnails.set_height(3.5)
        thumbnails.arrange(RIGHT, buff=1.0)
        thumbnails.fix_in_frame()
        return thumbnails

    def get_att_icon(self, block, n_rows=8):
        """Return an `n_rows` x `n_rows` grid of dots with random opacities.

        The grid is scaled to 90% of the height of `block` and aligned with
        it along the OUT edge.
        """
        att_icon = Dot().get_grid(n_rows, n_rows)
        att_icon.set_height(block.get_height() * 0.9)
        att_icon.set_backstroke(BLACK, 0.5)
        for dot in att_icon:
            # Fifth power skews the opacities towards zero.
            dot.set_fill(opacity=random.random()**5)
        att_icon.move_to(block, OUT)
        return att_icon

    def get_mlp_icon(self, block, dot_buff=0.15, layer_buff=1.5, layer0_size=5):
        """Return a three-layer network icon as `VGroup(layers, lines)`.

        The layers hold `layer0_size`, `2 * layer0_size` and `layer0_size`
        dots respectively; every dot in one layer is joined by a line to
        every dot in the next. Dot opacities, line colors and line widths
        are random. The icon is aligned with `block` along the OUT edge.
        """
        layers = VGroup(
            Dot().get_grid(layer0_size, 1, buff=dot_buff),
            Dot().get_grid(2 * layer0_size, 1, buff=dot_buff),
            Dot().get_grid(layer0_size, 1, buff=dot_buff),
        )
        layers.set_height(block.get_height() * 0.9)
        layers.arrange(RIGHT, buff=layer_buff)
        for layer in layers:
            for dot in layer:
                dot.set_fill(opacity=random.random())
        layers.set_stroke(WHITE, 0.5)
        lines = VGroup(
            Line(dot1.get_center(), dot2.get_center(), buff=dot1.get_width() / 2)
            for l1, l2 in zip(layers, layers[1:])
            for dot1 in l1
            for dot2 in l2
        )
        for line in lines:
            line.set_stroke(
                color=value_to_color(random.uniform(-10, 10)),
                width=3 * random.random()**3
            )

        icon = VGroup(layers, lines)
        icon.move_to(block, OUT)
        return icon

    def get_block(self, width=5, height=3, depth=1, color=GREY_D, opacity=0.8):
        """Return a Cube reshaped to (`width`, `height`, `depth`).

        Depth testing is deactivated on the cube, and its submobjects are
        sorted by the dot product of a point with (-1, 1, 1).
        """
        block = Cube(color=color, opacity=opacity)
        block.deactivate_depth_test()
        block.set_shape(width, height, depth)
        block.set_shading(0.5, 0.5, 0.0)
        block.sort(lambda p: np.dot(p, [-1, 1, 1]))
        return block
|
|
|
|
|
|
class AltLastTwoChapters(LastTwoChapters):
    """Variant recap: labelled chapter thumbnails that then slide aside."""

    def construct(self):
        # Lay out the two thumbnails along the top edge
        chapter_numbers = (5, 6)
        thumbs = self.get_thumbnails()
        thumbs.set_height(2.0).arrange(RIGHT, buff=2.0).to_edge(UP)

        # Attach a "Chapter N" caption under each thumbnail
        for number, thumb in zip(chapter_numbers, thumbs):
            caption = Text(f"Chapter {number}")
            caption.next_to(thumb, DOWN, SMALL_BUFF)
            thumb.add(caption)

        self.play(LaggedStartMap(FadeIn, thumbs, shift=UP, lag_ratio=0.5))
        self.wait()

        # Focus on chapter 6
        for thumb in thumbs:
            # generate_target stores the copy on `thumb.target` as well.
            goal = thumb.generate_target()
            goal.scale(1.25)
            # Shrink the caption (last submobject) and re-seat it under the outline.
            goal_caption = goal[-1]
            goal_caption.scale(1.0 / 1.5)
            goal_caption.next_to(goal[0], DOWN, SMALL_BUFF)

        first_goal = thumbs[0].target
        second_goal = thumbs[1].target
        second_goal.set_x(-2.85)
        second_goal.to_edge(UP, MED_SMALL_BUFF)
        first_goal.next_to(second_goal, LEFT, buff=2.5)

        self.play(LaggedStartMap(MoveToTarget, thumbs))
        self.wait()
|
|
|
|
|
|
class MLPIcon(LastTwoChapters):
    """Standalone scene: draw the MLP icon, then flash along its connections."""

    def construct(self):
        # Add network
        icon = self.get_mlp_icon(Square(6), layer_buff=3.0, layer0_size=6)
        draw_icon = Write(icon, stroke_width=0.5, lag_ratio=1e-2, run_time=5)
        self.play(draw_icon)
        self.wait()

        # Propagate through
        # icon[1] holds the connecting lines; work on copies so the drawn
        # icon itself is left untouched.
        connections = icon[1].family_members_with_points()
        flash_lines = VGroup(connections).copy()
        for flash_line in flash_lines:
            flash_line.set_stroke(width=2 * flash_line.get_width())
            flash_line.insert_n_curves(20)

        flash_kw = dict(time_width=1.5, lag_ratio=5e-3, run_time=3)
        self.play(LaggedStartMap(VShowPassingFlash, flash_lines, **flash_kw))
        self.wait()
|
|
|
|
|
|
class MLPStepsPreview(InteractiveScene):
    """Preview of the MLP steps, contrasting structure with emergent behavior.

    A top panel shows the Linear -> ReLU -> Linear structure with two weight
    matrices and a ReLU plot. A lower panel is then introduced, and the
    matrices are moved into it to set up the "Michael Jordan plays
    Basketball" toy example.
    """

    def construct(self):
        # Setup framing
        background = FullScreenRectangle()
        top_frame, low_frame = frames = Rectangle(7, 3.25).replicate(2)
        frames.arrange(DOWN, buff=0.5)
        frames.to_edge(LEFT)
        frames.set_fill(BLACK, 1)
        frames.set_stroke(WHITE, 2)

        titles = VGroup(
            VGroup(Text("Structure:"), Text("Easy")),
            VGroup(Text("Emergent behavior:"), Text("Exceedingly challenging")),
        )
        for title, frame, color in zip(titles, frames, [GREEN, RED]):
            title.scale(2)
            for part in title:
                part.set_max_width(6)
            title.arrange(DOWN, buff=0.5, aligned_edge=LEFT)
            title.next_to(frame, RIGHT, buff=0.5)
            title[1].set_color(color)

        # Save the two-panel layout, then start with the top panel enlarged;
        # the saved states are restored once the lower panel is introduced.
        titles[0].save_state()
        top_frame.save_state()
        top_frame.set_shape(8, 6).center().to_edge(LEFT)
        titles[0].next_to(top_frame, RIGHT, buff=0.5)

        self.add(background)
        self.add(top_frame)
        self.add(titles[0][0])

        # Add all steps
        arrows = Vector(2.2 * RIGHT).get_grid(1, 3, buff=0.25)
        arrows.move_to(top_frame)
        up_proj = WeightMatrix(shape=(10, 6))
        down_proj = WeightMatrix(shape=(6, 10))
        VGroup(up_proj, down_proj).match_width(arrows[0])
        up_proj.next_to(arrows[0], UP, buff=MED_SMALL_BUFF)
        down_proj.next_to(arrows[2], UP, buff=MED_SMALL_BUFF)

        axes = Axes((-4, 4), (0, 4))
        graph = axes.get_graph(lambda x: max(0, x))
        graph.set_stroke(YELLOW, 5)
        plot = VGroup(axes, graph)
        plot.set_width(arrows[0].get_width() * 0.75)
        plot.next_to(arrows[1], UP, buff=MED_SMALL_BUFF)

        labels = VGroup(*map(Text, ["Linear", "ReLU", "Linear"]))
        for label, arrow in zip(labels, arrows):
            label.next_to(arrow, DOWN)

        structure = VGroup(arrows, labels, VGroup(up_proj, plot, down_proj))

        self.play(
            LaggedStartMap(GrowArrow, arrows, lag_ratio=0.5),
            LaggedStartMap(FadeIn, labels, shift=0.5 * RIGHT, lag_ratio=0.5),
            Write(titles[0][1])
        )
        self.play(LaggedStart(
            FadeIn(up_proj, shift=0.5 * UP),
            FadeIn(down_proj, shift=0.5 * UP),
            lag_ratio=0.5
        ))
        self.play(FadeIn(plot, lag_ratio=1e-2))
        self.wait(3)

        # Reference emergent structure

        self.play(
            Restore(top_frame),
            Restore(titles[0]),
            structure.animate.set_width(0.9 * top_frame.saved_state.get_width()).move_to(top_frame.saved_state),
            FadeIn(low_frame, DOWN),
            FadeIn(titles[1][0], DOWN),
        )
        self.play(
            Write(titles[1][1], stroke_color=RED)
        )

        # Data flying
        kw = dict(font_size=16, shift_vect=0.5 * DOWN + 0.5 * RIGHT, word_shape=(5, 5))
        data_modifying_matrix(self, up_proj, **kw)
        data_modifying_matrix(self, down_proj, **kw)
        self.wait()

        # Swap out for toy example
        toy_example_title = Text("Motivating Toy Example", font_size=54)
        toy_example_title.next_to(titles[1][0], DOWN, MED_LARGE_BUFF, aligned_edge=LEFT)
        strike = Line().replace(titles[1][0])
        strike.set_stroke(RED, 8)

        low_matrices = VGroup(up_proj, down_proj)
        # A copy stays in the top panel while the originals move down.
        top_matrices = low_matrices.copy()
        low_matrices.generate_target()
        low_matrices.target.scale(1.75).arrange(RIGHT, buff=0.5)
        low_matrices.target.move_to(low_frame, DOWN).shift(MED_SMALL_BUFF * UP)

        self.play(
            ShowCreation(strike),
            FadeOut(titles[1][1]),
            titles[1][0].animate.set_opacity(0.5)
        )
        self.add(top_matrices)
        self.play(
            MoveToTarget(low_matrices),
            FadeIn(toy_example_title, DOWN)
        )
        self.wait()

        # Write down fact
        row_rect = SurroundingRectangle(low_matrices[0].get_rows()[0], buff=0.1)
        col_rect = SurroundingRectangle(low_matrices[1].get_columns()[0], buff=0.1)
        VGroup(row_rect, col_rect).set_stroke(WHITE, 1)
        fact = Text("Michael Jordan plays Basketball", font_size=36)
        fact.next_to(frames[1].get_top(), DOWN)
        fact.align_to(low_matrices, LEFT)
        mj, bb = fact["Michael Jordan"], fact["plays Basketball"]
        mj_brace = Brace(mj, DOWN, buff=0.1)
        bb_brace = Brace(bb, DOWN).match_y(mj_brace)
        # Only the column side gets an arrow; the highlighted row is linked
        # to its phrase by the brace alone.
        bb_arrow = Arrow(col_rect.get_top(), bb_brace, buff=0.05)

        # Dim every row / column except the first one in each matrix.
        row_cover = BackgroundRectangle(low_matrices[0].get_rows()[1:], buff=0.05)
        col_cover = BackgroundRectangle(low_matrices[1].get_columns()[1:], buff=0.05)
        VGroup(row_cover, col_cover).set_fill(BLACK, 0.75)

        self.play(LaggedStart(
            FadeIn(row_cover),
            FadeIn(row_rect),
            GrowFromCenter(mj_brace),
            FadeIn(mj, 0.5 * UP)
        ))
        self.play(
            FadeIn(col_cover),
            FadeIn(col_rect),
            GrowArrow(bb_arrow),
            GrowFromCenter(bb_brace),
            FadeIn(bb, 0.5 * UP)
        )
        # Re-add in this order so the covers and outlines sit above the matrices.
        self.add(*low_matrices, row_cover, col_cover, row_rect, col_rect)
        self.play(
            RandomizeMatrixEntries(low_matrices[0]),
            RandomizeMatrixEntries(low_matrices[1]),
        )
        self.wait()
|
|
|
|
|
|
class MatricesVsIntuition(InteractiveScene):
    """Scene showing a single 15x15 weight matrix at the left edge."""

    def construct(self):
        # Add matrix
        weights = WeightMatrix(shape=(15, 15))
        weights.set_height(4).to_edge(LEFT)

        # These two captions are constructed but never positioned or added
        # to the scene.
        caption_strings = (
            "Matrices filled with parameters\nlearned during gradient descent",
            "Motivating examples which risk being\noversimplifications of what true models do",
        )
        for caption_string in caption_strings:
            Text(caption_string)

        self.add(weights)
|
|
|
|
|
|
class BasicMLPWalkThrough(InteractiveScene):
|
|
random_seed = 1
|
|
|
|
def construct(self):
|
|
# Init camera settings
|
|
self.set_floor_plane("xz")
|
|
frame = self.frame
|
|
self.camera.light_source.set_z(15)
|
|
|
|
# Sequence of embeddings comes in to an MLP block
|
|
embedding_array = EmbeddingArray(shape=(6, 9))
|
|
embedding_array.set_width(10)
|
|
|
|
block = VCube(fill_color=GREY_D, fill_opacity=0.5)
|
|
block.sort(lambda p: p[2])
|
|
block[-1].set_fill(opacity=0)
|
|
block.set_stroke(GREY_B, 2, 0.25, behind=False)
|
|
block.set_shading(0.25, 0.25, 0.5)
|
|
block.set_shape(11, 4, 4)
|
|
block.move_to(0.5 * IN, IN)
|
|
block_title = Text("MLP", font_size=90)
|
|
block_title.next_to(block, UP)
|
|
|
|
frame.reorient(-21, -12, 0, (0.34, -0.94, -0.18), 9.79)
|
|
frame.set_field_of_view(30 * DEGREES)
|
|
self.add(block, block_title)
|
|
self.play(FadeIn(embedding_array, shift=2 * OUT))
|
|
self.wait()
|
|
|
|
# Highlight one vector
|
|
index = 3
|
|
emb = embedding_array.embeddings[index]
|
|
highlight_rect = SurroundingRectangle(emb)
|
|
embedding_array.target = embedding_array.generate_target()
|
|
embedding_array.target.set_stroke(width=0)
|
|
embedding_array.target.set_opacity(0.5)
|
|
embedding_array.target[0][index].set_backstroke(BLACK, 2)
|
|
embedding_array.target[0][index].set_opacity(1)
|
|
|
|
self.play(
|
|
MoveToTarget(embedding_array),
|
|
ShowCreation(highlight_rect),
|
|
)
|
|
self.wait()
|
|
|
|
# Reorient
|
|
rot_about_up = 89 * DEGREES
|
|
rot_about_left = 1 * DEGREES
|
|
up_emb = emb.copy() # For use down below
|
|
full_block = Group(block, embedding_array, highlight_rect, block_title)
|
|
full_block.target = full_block.generate_target()
|
|
full_block.target[0].set_depth(16, about_edge=IN, stretch=True)
|
|
full_block.target[0].set_height(5, about_edge=DOWN, stretch=True)
|
|
full_block.target.rotate(rot_about_up, UP)
|
|
full_block.target[:3].rotate(rot_about_left, LEFT)
|
|
full_block.target.scale(0.5)
|
|
full_block.target[3].rotate(90 * DEGREES, DOWN).next_to(full_block.target[0], UP, buff=0.5)
|
|
full_block.target.center().to_edge(DOWN, buff=0.75)
|
|
full_block.target[0][4].set_opacity(0.1)
|
|
|
|
self.play(
|
|
frame.animate.reorient(-3, -2, 0, (-0.0, -2.0, 0.01), 6.48),
|
|
MoveToTarget(full_block),
|
|
run_time=2
|
|
)
|
|
|
|
# Preview the sequence of operations
|
|
values = np.random.uniform(-10, 10, 9)
|
|
values[0] = 1.0
|
|
vects = VGroup(
|
|
NumericEmbedding(values=values, dark_color=GREY_B),
|
|
NumericEmbedding(values=np.clip(values, 0, np.inf), dark_color=GREY_B),
|
|
NumericEmbedding(length=6),
|
|
)
|
|
vects.set_width(emb.get_depth())
|
|
vects.arrange(RIGHT, buff=2.0)
|
|
vects.next_to(emb, RIGHT, buff=2.0)
|
|
|
|
arrows = VGroup(
|
|
Arrow(v1, v2)
|
|
for v1, v2 in zip([emb, *vects[:-1]], vects)
|
|
)
|
|
arrow_labels = VGroup(Text("Linear"), Text("ReLU"), Text("Linear"))
|
|
arrow_labels.scale(0.5)
|
|
|
|
phases = VGroup()
|
|
simple_phases = VGroup()
|
|
for arrow, label, vect in zip(arrows, arrow_labels, vects):
|
|
label.next_to(arrow, UP)
|
|
phases.add(VGroup(arrow, label, vect))
|
|
simple_phases.add(VGroup(arrow, vect))
|
|
|
|
self.play(
|
|
LaggedStartMap(FadeIn, vects, shift=RIGHT, lag_ratio=0.8),
|
|
LaggedStartMap(ShowCreation, arrows, lag_ratio=0.8),
|
|
LaggedStartMap(FadeIn, arrow_labels, lag_ratio=0.8),
|
|
)
|
|
self.wait()
|
|
|
|
# Show the sum
|
|
sum_circuit, output_emb = self.get_sum_circuit(emb, vects[-1])
|
|
|
|
self.play(
|
|
frame.animate.reorient(15, -4, 0, (0.82, -1.91, 0.04), 7.18),
|
|
ShowCreation(sum_circuit, lag_ratio=0.1),
|
|
run_time=2
|
|
)
|
|
self.play(
|
|
TransformFromCopy(emb, output_emb, path_arc=-30 * DEGREES),
|
|
TransformFromCopy(vects[2], output_emb, path_arc=-30 * DEGREES),
|
|
run_time=2
|
|
)
|
|
self.wait()
|
|
|
|
# Show all in parallel
|
|
simple_phases.add_to_back(highlight_rect)
|
|
simple_phases.add(VGroup(sum_circuit, output_emb))
|
|
simple_phase_copies = VGroup(
|
|
simple_phases.copy().match_z(emb)
|
|
for emb in embedding_array.embeddings
|
|
)
|
|
for sp_copy in simple_phase_copies:
|
|
for group in sp_copy[1:]:
|
|
arrow, vect = group
|
|
for entry in vect.get_entries():
|
|
dot = Dot().scale(0.5)
|
|
dot.match_color(entry)
|
|
dot.set_fill(opacity=0.5)
|
|
dot.move_to(entry)
|
|
entry.become(dot)
|
|
group.fade(0.5)
|
|
|
|
self.play(
|
|
frame.animate.reorient(0, -48, 0, (0.55, -2.21, 0.18), 7.05),
|
|
LaggedStart((
|
|
TransformFromCopy(simple_phases, sp_copy)
|
|
for sp_copy in simple_phase_copies
|
|
), lag_ratio=0.1),
|
|
FadeOut(block_title, time_span=(0, 1)),
|
|
run_time=3,
|
|
)
|
|
self.play(frame.animate.reorient(9, -15, 0, (0.55, -2.21, 0.18), 7.05), run_time=4)
|
|
self.play(frame.animate.reorient(-24, -16, 0, (0.18, -2.13, 0.09), 7.63), run_time=12)
|
|
block_title.next_to(block, UP)
|
|
self.play(
|
|
frame.animate.to_default_state(),
|
|
LaggedStartMap(FadeOut, simple_phase_copies, lag_ratio=0.1),
|
|
FadeIn(block_title),
|
|
run_time=2,
|
|
)
|
|
self.wait()
|
|
|
|
# Show MJ -> Basketball example
|
|
example_fact = TexText("``Michael Jordan plays Basketball''", font_size=60)
|
|
example_fact.to_edge(UP)
|
|
|
|
mj = TexText("Michael Jordan", font_size=36)
|
|
mj.next_to(emb, UL)
|
|
mj_lines = VGroup(
|
|
Line(char.get_bottom(), emb.get_top(), buff=0.1, path_arc=10 * DEGREES)
|
|
for char in mj
|
|
)
|
|
mj_lines.set_stroke(YELLOW, 1, 0.5)
|
|
|
|
basketball = TexText("Basketball", font_size=24)
|
|
basketball.next_to(vects[2], UP, buff=0.2)
|
|
|
|
self.play(Write(example_fact))
|
|
self.wait()
|
|
self.play(FadeTransform(example_fact[mj.get_tex()].copy(), mj))
|
|
self.play(Write(mj_lines, stroke_width=2, stroke_color=YELLOW_B, lag_ratio=1e-2))
|
|
self.wait()
|
|
|
|
mover = emb.copy()
|
|
for vect in vects:
|
|
self.play(Transform(mover, vect, rate_func=linear))
|
|
self.remove(mover)
|
|
self.wait()
|
|
self.play(FadeTransform(example_fact[basketball.get_tex()].copy(), basketball))
|
|
self.wait(2)
|
|
|
|
# Multiply by the up-projection
|
|
up_proj = WeightMatrix(shape=(9, 6))
|
|
up_proj.set_height(3)
|
|
up_proj.to_corner(UL)
|
|
up_emb.set_height(2)
|
|
up_emb.next_to(up_proj, RIGHT)
|
|
up_emb[-2:].set_fill(YELLOW) # Brackets
|
|
|
|
self.play(
|
|
phases[1:].animate.set_opacity(0.1),
|
|
sum_circuit.animate.set_stroke(opacity=0.1),
|
|
output_emb.animate.set_opacity(0.1),
|
|
FadeOut(mj),
|
|
FadeOut(mj_lines),
|
|
FadeOut(basketball),
|
|
FadeOut(example_fact),
|
|
)
|
|
self.wait()
|
|
self.play(TransformFromCopy(emb, up_emb))
|
|
self.play(FadeIn(up_proj, lag_ratio=0.01))
|
|
eq, rhs = show_matrix_vector_product(self, up_proj, up_emb)
|
|
self.wait()
|
|
data_modifying_matrix(self, up_proj, word_shape=(4, 7), fix_in_frame=True)
|
|
self.wait()
|
|
|
|
# Show machine
|
|
machine = MachineWithDials(
|
|
width=up_proj.get_width() + SMALL_BUFF,
|
|
height=up_proj.get_height() + SMALL_BUFF,
|
|
n_rows=8,
|
|
n_cols=9,
|
|
)
|
|
machine.move_to(up_proj)
|
|
|
|
self.play(FadeIn(machine))
|
|
self.play(machine.random_change_animation())
|
|
self.wait()
|
|
self.play(FadeOut(machine))
|
|
|
|
# Emphasize dot product with rows
|
|
n, m = up_proj.shape
|
|
n_rows_shown = 5
|
|
R_labels = VGroup(
|
|
Tex(R"\vec{\textbf{R}}_" + f"{{{n}}}")
|
|
for n in [*range(n_rows_shown - 1), "n"]
|
|
)
|
|
R_labels[-2].become(Tex(R"\vdots").replace(R_labels[-2], dim_to_match=1))
|
|
R_labels.arrange(DOWN, buff=0.5)
|
|
R_labels.match_height(up_proj)
|
|
R_labels.move_to(up_proj)
|
|
h_lines = VGroup(
|
|
Line(up_proj.get_brackets()[0], R_labels, buff=0.1),
|
|
Line(R_labels, up_proj.get_brackets()[1], buff=0.1),
|
|
)
|
|
h_lines.set_stroke(GREY_A, 2)
|
|
row_labels = VGroup(
|
|
VGroup(R_label, h_lines.copy().match_y(R_label))
|
|
for R_label in R_labels
|
|
)
|
|
row_matrix = VGroup(
|
|
up_proj.get_brackets().copy(),
|
|
row_labels
|
|
)
|
|
|
|
E_label = Tex(R"\vec{\textbf{E}}")
|
|
E_label.match_height(R_labels[0])
|
|
E_label.set_color(YELLOW)
|
|
E_label.move_to(up_emb)
|
|
E_col = VGroup(
|
|
up_emb[-2:].copy(),
|
|
Line(up_emb.get_top(), E_label, buff=0.1).set_stroke(GREY_A, 2),
|
|
E_label,
|
|
Line(E_label, up_emb.get_bottom(), buff=0.1).set_stroke(GREY_A, 2),
|
|
)
|
|
|
|
dot_prods = VGroup()
|
|
for n, R_label in enumerate(R_labels):
|
|
if n == len(R_labels) - 2:
|
|
dot_prod = R_label.copy()
|
|
else:
|
|
dot_prod = VGroup(
|
|
R_label.copy(),
|
|
Tex(R"\cdot"),
|
|
E_label.copy(),
|
|
)
|
|
dot_prod.arrange(RIGHT, buff=0.1)
|
|
dot_prod[-1].align_to(dot_prod[0][1], DOWN)
|
|
dot_prod.set_width(rhs.get_width() * 0.75)
|
|
dot_prod.move_to(R_label)
|
|
dot_prods.add(dot_prod)
|
|
dot_prods.move_to(rhs)
|
|
dot_prod_rhs = VGroup(
|
|
rhs.get_brackets().copy(),
|
|
dot_prods,
|
|
)
|
|
|
|
self.play(LaggedStart(
|
|
FadeOut(up_proj, scale=1.1),
|
|
FadeIn(row_matrix, scale=1.1),
|
|
FadeOut(up_emb, scale=1.1),
|
|
FadeIn(E_col, scale=1.1),
|
|
FadeOut(rhs, scale=1.1),
|
|
FadeIn(dot_prod_rhs[0], scale=1.1),
|
|
lag_ratio=0.1
|
|
))
|
|
self.wait()
|
|
for row_label, dot_prod in zip(row_labels, dot_prods):
|
|
R_label = row_label[0]
|
|
self.play(
|
|
TransformFromCopy(R_label, dot_prod[0]),
|
|
TransformFromCopy(R_label, dot_prod[1]),
|
|
TransformFromCopy(E_label, dot_prod[2]),
|
|
VShowPassingFlash(
|
|
Line(row_label.get_left(), row_label.get_right()).set_stroke(YELLOW, 5).insert_n_curves(100),
|
|
time_width=1.5
|
|
),
|
|
VShowPassingFlash(
|
|
Line(E_col.get_top(), E_col.get_bottom()).set_stroke(YELLOW, 5).insert_n_curves(100),
|
|
time_width=1.5
|
|
),
|
|
run_time=1
|
|
)
|
|
self.wait()
|
|
|
|
# First name Michael direction
|
|
row_rect = SurroundingRectangle(row_labels[0])
|
|
row_rect.set_stroke(GREY_BROWN, 2)
|
|
row_rect.set_fill(GREY_BROWN, 0.25)
|
|
row_eq = Tex("=").rotate(PI / 2)
|
|
row_eq.next_to(row_rect, UP, SMALL_BUFF)
|
|
first_name_label = Tex(R"\overrightarrow{\text{First Name Michael}}")
|
|
first_name_label.set_stroke(WHITE, 1)
|
|
first_name_label.match_width(row_rect)
|
|
first_name_label.next_to(row_eq, UP)
|
|
|
|
dot_prod = dot_prods[0]
|
|
dp_rect = SurroundingRectangle(dot_prod, buff=0.2)
|
|
dp_rect.set_stroke(RED)
|
|
dp_eq = Tex("=")
|
|
dp_eq.next_to(dp_rect, RIGHT, SMALL_BUFF)
|
|
mde_rhs = VGroup(
|
|
Tex(R"\approx 1 \quad \text{If } \vec{\textbf{E}} \text{ encodes ``First Name Michael''}"),
|
|
Tex(R"\le 0 \quad \text{If not}")
|
|
)
|
|
mde_rhs[0][R"\vec{\textbf{E}}"].set_color(YELLOW)
|
|
mde_rhs.scale(0.75)
|
|
mde_rhs.arrange(DOWN, buff=0.5, aligned_edge=LEFT)
|
|
rhs_brace = Brace(mde_rhs, LEFT)
|
|
rhs_brace.next_to(dp_eq, RIGHT, SMALL_BUFF)
|
|
mde_rhs.next_to(rhs_brace, RIGHT, MED_SMALL_BUFF)
|
|
|
|
self.play(
|
|
FadeIn(row_rect, scale=2),
|
|
FadeTransform(row_labels[0].copy(), first_name_label),
|
|
GrowFromCenter(row_eq),
|
|
frame.animate.reorient(0, 0, 0, (0.22, 0.54, 0.0), 9.27),
|
|
)
|
|
self.wait()
|
|
|
|
self.play(TransformFromCopy(row_rect.copy().set_fill(opacity=0), dp_rect))
|
|
self.play(
|
|
Write(dp_eq),
|
|
GrowFromCenter(rhs_brace),
|
|
FadeIn(mde_rhs),
|
|
)
|
|
self.wait()
|
|
|
|
# "First name Michael" + "Last name Jordan"
|
|
fn_tex = R"\overrightarrow{\text{F.N. Michael}}"
|
|
ln_tex = R"\overrightarrow{\text{L.N. Jordan}}"
|
|
name_sum_label = Tex(f"{fn_tex} + {ln_tex}")
|
|
name_sum_label.match_width(row_rect).scale(1.2)
|
|
name_sum_label.next_to(row_eq, UP)
|
|
|
|
self.play(
|
|
FadeTransform(first_name_label, name_sum_label[:21]),
|
|
FadeIn(name_sum_label[21:], shift=RIGHT, scale=2),
|
|
FadeOut(mde_rhs),
|
|
FadeOut(rhs_brace),
|
|
)
|
|
self.wait()
|
|
|
|
dist_rhs = VGroup(
|
|
Tex(R"(\vec{\textbf{M}} + \vec{\textbf{J}}) \cdot \vec{\textbf{E}}"),
|
|
Tex("="),
|
|
Tex(R"\vec{\textbf{M}} \cdot \vec{\textbf{E}} + \vec{\textbf{J}} \cdot \vec{\textbf{E}}"),
|
|
)
|
|
dist_rhs.scale(0.75)
|
|
dist_rhs.arrange(RIGHT, buff=0.2)
|
|
dist_rhs.next_to(dp_eq, RIGHT)
|
|
for part in dist_rhs:
|
|
part[R"\vec{\textbf{M}}"].set_color(RED_B)
|
|
part[R"\vec{\textbf{J}}"].set_color(RED)
|
|
part[R"\vec{\textbf{E}}"].set_color(YELLOW)
|
|
under_brace = Brace(dist_rhs[2])
|
|
|
|
two_condition = TexText(R"$\approx 2$ \; if $\vec{\textbf{E}}$ encodes ``Michael Jordan''")
|
|
two_condition[R"\vec{\textbf{E}}"].set_color(YELLOW)
|
|
else_condition = TexText(R"$\le 1$ \; Otherwise")
|
|
VGroup(two_condition, else_condition).scale(0.75)
|
|
two_condition.next_to(under_brace, DOWN, aligned_edge=LEFT)
|
|
else_condition.next_to(two_condition, DOWN, MED_LARGE_BUFF, aligned_edge=LEFT)
|
|
|
|
self.play(LaggedStart(
|
|
FadeTransformPieces(name_sum_label[:21].copy(), dist_rhs[0][1:3]),
|
|
FadeTransformPieces(name_sum_label[21].copy(), dist_rhs[0][3]),
|
|
FadeTransformPieces(name_sum_label[22:].copy(), dist_rhs[0][4:6]),
|
|
FadeTransformPieces(dot_prod[1:].copy(), dist_rhs[0][7:]),
|
|
FadeIn(dist_rhs[0][0]),
|
|
FadeIn(dist_rhs[0][6]),
|
|
lag_ratio=0.2
|
|
))
|
|
self.wait()
|
|
self.play(
|
|
TransformMatchingStrings(dist_rhs[0].copy(), dist_rhs[2], lag_ratio=0.01, path_arc=-45 * DEGREES),
|
|
Write(dist_rhs[1])
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
frame.animate.set_y(0.5),
|
|
GrowFromCenter(under_brace),
|
|
FadeIn(two_condition, DOWN)
|
|
)
|
|
self.wait()
|
|
self.play(FadeIn(else_condition, DOWN))
|
|
self.wait(2)
|
|
|
|
# Go back to the numbers
|
|
for entry in rhs.get_entries():
|
|
entry.set_value(np.random.uniform(-10, 10))
|
|
rhs.get_entries()[0].set_value(2.0)
|
|
self.play(
|
|
LaggedStart(*map(FadeOut, [
|
|
name_sum_label, row_eq, row_rect,
|
|
dp_rect, dp_eq, dist_rhs, under_brace,
|
|
two_condition, else_condition,
|
|
]), lag_ratio=0.1, run_time=1),
|
|
frame.animate.reorient(0, 0, 0, (-0.06, -0.06, 0.0), 8.27),
|
|
)
|
|
self.play(
|
|
FadeOut(row_matrix),
|
|
FadeIn(up_proj),
|
|
FadeOut(E_col),
|
|
FadeIn(up_emb),
|
|
FadeOut(dot_prod_rhs),
|
|
FadeIn(rhs),
|
|
)
|
|
|
|
# Show other rows
|
|
questions = VGroup(*map(Text, [
|
|
"Blah",
|
|
"Is it English?",
|
|
"Part of source code?",
|
|
"European country?",
|
|
"In quotation marks?",
|
|
"Something metallic?",
|
|
"A four-legged animal?",
|
|
]))
|
|
questions.scale(0.75)
|
|
rows = up_proj.get_rows()
|
|
rhs_entries = rhs.get_entries()
|
|
last_question = VGroup()
|
|
last_rect = VectorizedPoint(rows[1].get_top())
|
|
for index in range(1, 7):
|
|
for mob in [rows, rhs_entries]:
|
|
mob.target = mob.generate_target()
|
|
mob.target.set_opacity(0.25)
|
|
mob.target[index].set_opacity(1)
|
|
row_rect = SurroundingRectangle(rows[index])
|
|
row_rect.set_stroke(PINK, 2)
|
|
question = questions[index]
|
|
question.next_to(rows[index], UP, buff=0.15)
|
|
question.set_backstroke(BLACK, 3)
|
|
self.play(
|
|
MoveToTarget(rows),
|
|
MoveToTarget(rhs_entries),
|
|
FadeOut(last_question),
|
|
FadeIn(question),
|
|
FadeTransform(last_rect, row_rect, time_span=(0, 0.75)),
|
|
run_time=1.0
|
|
)
|
|
self.wait(0.5)
|
|
last_question = question
|
|
last_rect = row_rect
|
|
self.play(
|
|
rows.animate.set_opacity(1),
|
|
rhs.animate.set_opacity(1),
|
|
FadeOut(last_question),
|
|
FadeOut(last_rect),
|
|
)
|
|
self.wait()
|
|
|
|
# Add a bias
|
|
plus = Tex("+")
|
|
plus.next_to(up_emb, RIGHT)
|
|
bias = WeightMatrix(shape=(9, 1), ellipses_col=None)
|
|
bias.get_entries()[0].set_value(-1).set_color(RED)
|
|
bias.match_height(up_proj)
|
|
bias.next_to(plus)
|
|
bias_name = Text("Bias")
|
|
bias_name.next_to(bias, UP)
|
|
|
|
eq.target = eq.generate_target()
|
|
eq.target.next_to(bias, RIGHT)
|
|
rhs.target = vects[0].copy()
|
|
rhs.target.replace(rhs, dim_to_match=1)
|
|
rhs.target.next_to(eq.target, RIGHT)
|
|
|
|
self.play(
|
|
Write(plus),
|
|
FadeIn(bias, lag_ratio=0.1),
|
|
MoveToTarget(eq),
|
|
MoveToTarget(rhs),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
frame.animate.scale(1.1, about_edge=DOWN),
|
|
Write(bias_name),
|
|
)
|
|
self.wait()
|
|
|
|
# Emphasize the parameters are learned from data
|
|
data_modifying_matrix(self, bias, word_shape=(5, 1), alpha_maxes=(0.4, 0.9), fix_in_frame=True)
|
|
bias.get_entries()[0].set_value(-1).set_color(RED)
|
|
|
|
# Pull up the MJ example again
|
|
fe_rect = SurroundingRectangle(rhs.get_entries()[0], buff=0.1) # fe = First entry
|
|
fe_rect.set_stroke(RED, 3)
|
|
fe_eq = Tex("=")
|
|
fe_eq.next_to(fe_rect, RIGHT, SMALL_BUFF)
|
|
fe_expr = VGroup(dist_rhs[2].copy(), Tex("- 1"))
|
|
fe_expr[1].set_height(fe_expr[0].get_height() * 0.8)
|
|
fe_expr.arrange(RIGHT)
|
|
fe_expr.next_to(fe_eq, RIGHT)
|
|
|
|
bias_rect = SurroundingRectangle(bias.get_entries()[0])
|
|
|
|
self.play(
|
|
ShowCreation(fe_rect),
|
|
FadeIn(fe_eq, RIGHT),
|
|
Write(fe_expr)
|
|
)
|
|
self.wait()
|
|
self.play(ShowCreation(bias_rect))
|
|
self.wait()
|
|
self.play(bias_rect.animate.surround(fe_expr[1]))
|
|
self.wait()
|
|
self.play(bias_rect.animate.surround(fe_expr))
|
|
self.wait()
|
|
|
|
# Show what it means, but now shifted
|
|
conditions = VGroup(
|
|
TexText(R"$\approx 1$ \; if $\vec{\textbf{E}}$ encodes ``Michael Jordan''"),
|
|
TexText(R"$\le 0$ \; Otherwise"),
|
|
)
|
|
conditions[0][R"\vec{\textbf{E}}"].set_color(YELLOW)
|
|
conditions.scale(0.75)
|
|
conditions.arrange(DOWN, buff=0.5, aligned_edge=LEFT)
|
|
under_brace = Brace(fe_expr, DOWN)
|
|
conditions.next_to(under_brace, DOWN, aligned_edge=LEFT)
|
|
|
|
self.play(
|
|
FadeOut(bias_rect),
|
|
GrowFromCenter(under_brace),
|
|
FadeIn(conditions[0], DOWN)
|
|
)
|
|
self.wait()
|
|
self.play(FadeIn(conditions[1], 0.25 * DOWN))
|
|
self.wait(2)
|
|
|
|
self.play(
|
|
frame.animate.reorient(0, 0, 0, (-2.5, 0.44, 0.0), 9.33),
|
|
LaggedStart(*map(FadeOut, [
|
|
fe_rect, fe_eq, fe_expr,
|
|
under_brace, *conditions
|
|
]))
|
|
)
|
|
|
|
# Show the matrix size
|
|
up_proj.refresh_bounding_box()
|
|
row_rects = VGroup(
|
|
SurroundingRectangle(row, buff=0.1)
|
|
for row in up_proj.get_rows()
|
|
)
|
|
row_rects.set_stroke(WHITE, 1)
|
|
row_rects.set_fill(GREY_C, 0.25)
|
|
row_rects[-2].match_width(row_rects, stretch=True)
|
|
|
|
over_brace = Brace(row_rects[0], UP, buff=SMALL_BUFF)
|
|
d_model = 12288
|
|
row_size = Integer(d_model)
|
|
row_size.next_to(over_brace, UP)
|
|
side_brace = Brace(row_rects, LEFT)
|
|
num_rows = Integer(4 * d_model)
|
|
num_rows.next_to(side_brace, LEFT)
|
|
num_rows_expr = Tex(R"4 \times 12{,}288")
|
|
num_rows_expr.next_to(side_brace, LEFT)
|
|
|
|
self.play(
|
|
FadeIn(row_rects, lag_ratio=0.5),
|
|
GrowFromCenter(side_brace),
|
|
CountInFrom(num_rows)
|
|
)
|
|
self.wait()
|
|
self.play(FadeTransform(num_rows, num_rows_expr))
|
|
self.wait()
|
|
self.play(
|
|
FadeTransform(num_rows_expr["12{,}288"].copy(), row_size),
|
|
TransformFromCopy(side_brace, over_brace),
|
|
)
|
|
self.wait()
|
|
self.play(FadeOut(row_rects, lag_ratio=0.1))
|
|
|
|
# Calculate matrix size
|
|
full_product = VGroup(
|
|
num_rows_expr.copy(),
|
|
Tex(R"\times"),
|
|
row_size.copy(),
|
|
Tex(Rf"="),
|
|
Integer(4 * d_model * d_model)
|
|
)
|
|
full_product.scale(1.5)
|
|
full_product.arrange(RIGHT, buff=MED_SMALL_BUFF)
|
|
full_product.next_to(row_rects, UP, buff=2.5)
|
|
|
|
self.play(LaggedStart(
|
|
frame.animate.reorient(0, 0, 0, (-3.88, 1.51, 0.0), 11.35),
|
|
TransformFromCopy(num_rows_expr, full_product[0]),
|
|
FadeIn(full_product[1], UP),
|
|
TransformFromCopy(row_size, full_product[2]),
|
|
lag_ratio=0.25,
|
|
run_time=2
|
|
))
|
|
self.play(
|
|
TransformFromCopy(full_product[:3], full_product[3:])
|
|
)
|
|
self.wait()
|
|
self.play(FlashAround(full_product[-1], run_time=2, time_width=1.5))
|
|
|
|
# Count bias parameters
|
|
bias_count = Tex(R"4 \times 12{,}288")
|
|
bias_count.match_height(full_product)
|
|
bias_count.match_y(full_product)
|
|
bias_count.match_x(bias)
|
|
bias_rect = SurroundingRectangle(VGroup(bias, bias_name))
|
|
bias_rect.set_stroke(BLUE_B)
|
|
bias_arrow = Arrow(bias_rect.get_top(), bias_count.get_bottom())
|
|
bias_arrow.match_color(bias_rect)
|
|
bias_count.match_color(bias_rect)
|
|
|
|
div_eq = Tex(R"{4 \times 12{,}288 \over 603{,}979{,}776} \approx 0.00008 ")
|
|
div_eq[R"{4 \times 12{,}288"].match_color(bias_rect)
|
|
div_eq.next_to(frame.get_corner(UR), DL, buff=MED_LARGE_BUFF)
|
|
div_eq.shift(RIGHT)
|
|
|
|
self.play(ShowCreation(bias_rect))
|
|
self.play(
|
|
GrowArrow(bias_arrow),
|
|
FadeInFromPoint(bias_count, bias_arrow.get_start()),
|
|
full_product.animate.scale(0.8).shift(3.5 * LEFT)
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
frame.animate.set_x(-3.0),
|
|
FadeTransform(bias_count.copy(), div_eq[R"4 \times 12{,}288"]),
|
|
Write(div_eq[R"\over"]),
|
|
FadeTransform(full_product[-1].copy(), div_eq[R"603{,}979{,}776}"]),
|
|
Write(div_eq[R"\approx 0.00008"]),
|
|
)
|
|
self.wait()
|
|
|
|
self.play(
|
|
frame.animate.reorient(0, 0, 0, (-2.5, 0.44, 0.0), 9.33),
|
|
*map(FadeOut, [full_product, bias_rect, bias_arrow, bias_count, div_eq])
|
|
)
|
|
|
|
# Collapse
|
|
substrs = [R"W_\uparrow", R"\vec{\textbf{E}}_i", "+", R"\vec{\textbf{B}}_\uparrow"]
|
|
linear_expr = Tex(" ".join(substrs))
|
|
W_up, E_i, plus2, B_up = [linear_expr[ss] for ss in substrs]
|
|
VGroup(W_up, B_up).set_color(BLUE)
|
|
E_i.set_color(YELLOW)
|
|
linear_expr.move_to(plus).shift(0.6 * LEFT)
|
|
|
|
low_emb_label = E_i.copy()
|
|
low_emb_label.scale(0.5).next_to(emb, UP)
|
|
|
|
self.play(
|
|
frame.animate.reorient(0, 0, 0, (-0.03, 0.03, 0.0), 8.34),
|
|
ReplacementTransform(up_proj, W_up, lag_ratio=1e-3),
|
|
FadeOut(side_brace, RIGHT, scale=0.5),
|
|
FadeOut(num_rows_expr, RIGHT, scale=0.5),
|
|
FadeOut(over_brace, DR, scale=0.5),
|
|
FadeOut(row_size, DR, scale=0.5),
|
|
)
|
|
self.wait()
|
|
self.play(ReplacementTransform(up_emb, E_i, lag_ratio=1e-2))
|
|
self.play(TransformFromCopy(E_i, low_emb_label))
|
|
self.wait()
|
|
self.play(
|
|
ReplacementTransform(plus, plus2),
|
|
ReplacementTransform(bias, B_up, lag_ratio=1e-2),
|
|
FadeOut(bias_name, DL),
|
|
VGroup(eq, rhs).animate.next_to(B_up, RIGHT).shift(0.1 * DOWN),
|
|
run_time=2
|
|
)
|
|
self.wait()
|
|
|
|
# Add parameters below first linear arrow
|
|
self.play(
|
|
linear_expr.animate.scale(0.5).next_to(arrows[0], DOWN, buff=0.1),
|
|
ReplacementTransform(rhs, vects[0]),
|
|
FadeOut(eq, 4 * DOWN + LEFT),
|
|
run_time=2
|
|
)
|
|
self.wait()
|
|
|
|
# Pull up ReLU
|
|
self.play(phases[1].animate.set_opacity(1))
|
|
phase1_copy = VGroup(vects[0], arrows[1], vects[1]).copy()
|
|
phase1_copy.save_state()
|
|
|
|
self.play(
|
|
phase1_copy.animate.scale(2.0).next_to(full_block, UP, buff=0.5),
|
|
frame.animate.reorient(0, 0, 0, (-0.26, 0.54, 0.0), 9.40)
|
|
)
|
|
self.wait()
|
|
|
|
# Break down ReLU
|
|
relu_arrow = phase1_copy[1]
|
|
neg_arrows = VGroup()
|
|
pos_arrows = VGroup()
|
|
neg_left_rects = VGroup()
|
|
zero_right_rects = VGroup()
|
|
pos_left_rects = VGroup()
|
|
pos_right_rects = VGroup()
|
|
in_vect = phase1_copy[0]
|
|
out_vect = phase1_copy[2]
|
|
for e1, e2 in zip(in_vect.get_entries(), out_vect.get_entries()):
|
|
arrow = Arrow(e1, e2, buff=0.3)
|
|
if e1.get_value() > 0:
|
|
arrow.set_color(BLUE)
|
|
pos_arrows.add(arrow)
|
|
pos_left_rects.add(SurroundingRectangle(e1, color=BLUE))
|
|
pos_right_rects.add(SurroundingRectangle(e2, color=BLUE))
|
|
else:
|
|
arrow.set_color(RED)
|
|
neg_arrows.add(arrow)
|
|
neg_left_rects.add(SurroundingRectangle(e1, color=RED))
|
|
zero_right_rects.add(SurroundingRectangle(e2, color=RED))
|
|
VGroup(neg_left_rects, zero_right_rects, pos_left_rects, pos_right_rects).set_stroke(width=2)
|
|
|
|
self.play(ShowCreation(neg_left_rects, lag_ratio=0.5))
|
|
self.wait()
|
|
self.play(
|
|
TransformFromCopy(neg_left_rects, zero_right_rects, lag_ratio=0.5),
|
|
ShowCreation(neg_arrows, lag_ratio=0.5),
|
|
FadeOut(relu_arrow),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
FadeOut(neg_left_rects, lag_ratio=0.25),
|
|
FadeOut(zero_right_rects, lag_ratio=0.25),
|
|
FadeOut(neg_arrows, lag_ratio=0.25),
|
|
ShowCreation(pos_left_rects)
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
ShowCreation(pos_arrows, lag_ratio=0.5),
|
|
TransformFromCopy(pos_left_rects, pos_right_rects, lag_ratio=0.5),
|
|
)
|
|
self.wait()
|
|
|
|
# Graph ReLU
|
|
relu_title_full = Text("Rectified\nLinear\nUnit", alignment="LEFT")
|
|
relu_title_full.next_to(relu_arrow, UP)
|
|
|
|
axes = Axes((-4, 4), (-1, 4))
|
|
axes.set_width(6)
|
|
axes.next_to(phase1_copy, RIGHT, buff=1.0)
|
|
axes.add_coordinate_labels(font_size=16)
|
|
relu_graph = axes.get_graph(lambda x: max(0, x), discontinuities=[0])
|
|
relu_graph.set_stroke(YELLOW, 4)
|
|
plot = VGroup(axes, relu_graph)
|
|
|
|
relu_graph_label = Text("ReLU")
|
|
relu_graph_label.match_color(relu_graph)
|
|
relu_graph_label.move_to(axes, UL)
|
|
|
|
self.play(
|
|
frame.animate.set_x(2.7),
|
|
FadeIn(relu_arrow),
|
|
FadeIn(relu_title_full, 0.1 * UP, lag_ratio=0.1, run_time=2),
|
|
FadeOut(pos_arrows, lag_ratio=0.25),
|
|
FadeOut(pos_left_rects, lag_ratio=0.25),
|
|
FadeOut(pos_right_rects, lag_ratio=0.25),
|
|
FadeIn(plot, RIGHT),
|
|
)
|
|
self.wait()
|
|
self.play(*(
|
|
TransformFromCopy(relu_title_full[substr], relu_graph_label[substr])
|
|
for substr in ["Re", "L", "U"]
|
|
))
|
|
self.add(relu_graph_label)
|
|
|
|
# Recall the meaning of the first entry
|
|
mid_vect = phase1_copy[0]
|
|
conditions_rect = SurroundingRectangle(conditions, buff=0.25)
|
|
conditions_rect.set_stroke(YELLOW, 1)
|
|
under_brace = Brace(conditions_rect, DOWN, buff=SMALL_BUFF)
|
|
VGroup(conditions, conditions_rect, under_brace).next_to(mid_vect, UP)
|
|
fe_rect = SurroundingRectangle(mid_vect.get_entries()[0])
|
|
|
|
condition_group = VGroup(fe_rect, under_brace, conditions, conditions_rect)
|
|
|
|
self.play(
|
|
frame.animate.reorient(0, 0, 0, (2.61, 0.97, 0.0), 11.5),
|
|
ShowCreation(fe_rect),
|
|
GrowFromCenter(under_brace),
|
|
)
|
|
self.play(
|
|
TransformFromCopy(fe_rect, conditions_rect),
|
|
FadeInFromPoint(conditions, fe_rect.get_center()),
|
|
)
|
|
self.wait()
|
|
self.play(condition_group.animate.match_x(phase1_copy[2]))
|
|
|
|
equals = Tex("=")
|
|
ineq = conditions[1][0]
|
|
equals.replace(ineq, dim_to_match=0)
|
|
self.play(
|
|
FlashAround(equals, run_time=2, time_width=1.5),
|
|
ineq.animate.become(equals)
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
frame.animate.reorient(0, 0, 0, (2.48, 0.33, 0.0), 9.17),
|
|
FadeOut(condition_group, lag_ratio=0.01)
|
|
)
|
|
|
|
# Graph GeLU
|
|
gelu_title_full = Text("Gaussian\nError\nLinear\nUnit", font_size=42, alignment="LEFT")
|
|
gelu_title_full.next_to(relu_arrow, UP)
|
|
gelu_graph = axes.get_graph(lambda x: x * norm.cdf(x))
|
|
gelu_graph.set_stroke(GREEN, 4)
|
|
|
|
gelu_graph_label = Text("GELU")
|
|
gelu_graph_label.next_to(relu_graph_label, DOWN, buff=MED_LARGE_BUFF, aligned_edge=LEFT)
|
|
gelu_graph_label.match_color(gelu_graph)
|
|
|
|
self.play(
|
|
FadeTransform(relu_title_full, gelu_title_full),
|
|
relu_graph_label.animate.set_fill(opacity=0.25),
|
|
relu_graph.animate.set_stroke(opacity=0.25),
|
|
ShowCreation(gelu_graph),
|
|
TransformFromCopy(relu_graph_label, gelu_graph_label)
|
|
)
|
|
self.wait(2)
|
|
self.play(
|
|
gelu_graph.animate.set_stroke(opacity=0.25),
|
|
gelu_graph_label.animate.set_fill(opacity=0.25),
|
|
relu_graph.animate.set_stroke(opacity=1),
|
|
relu_graph_label.animate.set_fill(opacity=1),
|
|
FadeTransform(gelu_title_full, relu_title_full),
|
|
)
|
|
self.wait()
|
|
|
|
# Describe these as neurons
|
|
neuron_word = Text("Neurons", font_size=72)
|
|
neuron_word.next_to(phase1_copy, RIGHT, buff=2.5)
|
|
neuron_arrows = VGroup(
|
|
Arrow(neuron_word.get_left(), entry.get_right(), buff=0.4, stroke_width=3)
|
|
for entry in phase1_copy[2].get_entries()
|
|
)
|
|
|
|
self.play(
|
|
plot.animate.set_width(2).next_to(relu_arrow, DOWN),
|
|
FadeOut(VGroup(relu_graph_label, gelu_graph_label, gelu_graph)),
|
|
Write(neuron_word),
|
|
ShowCreation(neuron_arrows, lag_ratio=0.2, run_time=3),
|
|
LaggedStartMap(
|
|
FlashAround, phase1_copy[2].get_entries(),
|
|
time_width=3.0,
|
|
lag_ratio=0.05,
|
|
time_span=(1, 4),
|
|
run_time=4
|
|
)
|
|
)
|
|
self.wait()
|
|
|
|
# Show the classic dots picture
|
|
blocking_rect = BackgroundRectangle(VGroup(phase1_copy), buff=0.1)
|
|
blocking_rect.set_fill(BLACK, 1)
|
|
up_emb.move_to(blocking_rect, LEFT)
|
|
dots = VGroup(
|
|
Dot(radius=0.15).move_to(entry).set_fill(WHITE, opacity=clip(entry.get_value(), 0, 1))
|
|
for entry in phase1_copy[2].get_entries()
|
|
)
|
|
dots.set_stroke(WHITE, 2)
|
|
up_emb = emb.copy()
|
|
up_emb.rotate(PI / 2, DOWN)
|
|
up_emb.rotate(1 * DEGREES)
|
|
up_emb.match_width(phase1_copy[0])
|
|
up_emb.move_to(phase1_copy[0]).shift(RIGHT)
|
|
up_emb[-2:].set_color(YELLOW)
|
|
lines = VGroup(
|
|
Line(entry.get_right() + 0.05 * RIGHT, dot).set_stroke(
|
|
color=value_to_color(random.uniform(-10, 10)),
|
|
width=3 * random.random()**2,
|
|
)
|
|
for entry in up_emb.get_entries()
|
|
for dot in dots
|
|
)
|
|
|
|
self.play(
|
|
FadeIn(blocking_rect),
|
|
Write(dots),
|
|
)
|
|
self.play(TransformFromCopy(emb, up_emb))
|
|
self.play(ShowCreation(lines, lag_ratio=3 / len(lines)))
|
|
self.wait()
|
|
self.play(
|
|
LaggedStart(*map(FadeOut, [up_emb, *lines, blocking_rect, *dots]), lag_ratio=0.01)
|
|
)
|
|
|
|
# Discuss active and inactive
|
|
entry = phase1_copy[2].get_entries()[0]
|
|
entry_rect = SurroundingRectangle(entry)
|
|
entry_rect.set_stroke(YELLOW, 2)
|
|
active_words = TexText(R"``Michael Jordan'' neuron is \emph{active}")
|
|
active = active_words["active"][0]
|
|
active.set_color(BLUE_B)
|
|
active_words.next_to(entry_rect, UP, aligned_edge=LEFT)
|
|
active_words.shift(LEFT)
|
|
inactive = TexText(R"\emph{inactive}")
|
|
inactive.set_color(RED)
|
|
inactive.move_to(active, LEFT)
|
|
|
|
self.play(
|
|
frame.animate.reorient(0, 0, 0, (2.45, 0.58, 0.0), 9.65),
|
|
ShowCreation(entry_rect),
|
|
Write(active_words, run_time=1),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
ChangeDecimalToValue(entry, 0),
|
|
ReplacementTransform(active, inactive[2:]),
|
|
GrowFromCenter(inactive[:2]),
|
|
)
|
|
active_words.add(inactive)
|
|
self.wait()
|
|
|
|
# Replace the ReLU diagram portion
|
|
self.play(
|
|
Restore(phase1_copy),
|
|
TransformMatchingStrings(relu_title_full, arrow_labels[1]),
|
|
plot.animate.scale(0.5).next_to(arrows[1], DOWN, SMALL_BUFF),
|
|
FadeOut(neuron_word, DOWN),
|
|
FadeOut(neuron_arrows, DOWN, lag_ratio=0.1),
|
|
FadeOut(entry_rect, DOWN),
|
|
FadeOut(active_words, DOWN, lag_ratio=0.01),
|
|
run_time=1.5
|
|
)
|
|
self.remove(phase1_copy)
|
|
|
|
# Down projection
|
|
neurons = vects[1].copy()
|
|
neurons.target = neurons.generate_target()
|
|
neurons.target.set_height(4)
|
|
neurons.target.move_to(3 * RIGHT + 2.5 * UP)
|
|
down_proj = WeightMatrix(shape=(6, 9))
|
|
down_proj.set_height(2.75)
|
|
down_proj.next_to(neurons.target, LEFT)
|
|
|
|
plus = Tex("+")
|
|
plus.next_to(neurons.target, RIGHT)
|
|
bias = WeightMatrix(shape=(6, 1))
|
|
bias.match_height(down_proj)
|
|
bias.next_to(plus, RIGHT)
|
|
|
|
equals = Tex("=")
|
|
equals.next_to(bias, RIGHT)
|
|
rhs = vects[2].copy()
|
|
rhs.set_opacity(1)
|
|
rhs.match_height(bias)
|
|
rhs.next_to(equals, RIGHT)
|
|
|
|
self.play(phases[2].animate.set_opacity(1))
|
|
self.play(MoveToTarget(neurons))
|
|
self.play(FadeTransform(arrows[2].copy(), down_proj))
|
|
self.wait()
|
|
temp_eq, temp_rhs = show_matrix_vector_product(self, down_proj, neurons)
|
|
self.wait()
|
|
self.play(
|
|
FadeOut(temp_eq, DOWN),
|
|
FadeOut(temp_rhs, DOWN),
|
|
Write(plus),
|
|
FadeIn(bias, RIGHT),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
Write(equals),
|
|
TransformFromCopy(vects[2], rhs),
|
|
)
|
|
self.wait()
|
|
|
|
# Name it as the down-projection
|
|
over_brace = Brace(down_proj, UP)
|
|
name = TexText("``Down projection''")
|
|
name.next_to(over_brace, UP)
|
|
|
|
side_brace = Brace(rhs, RIGHT)
|
|
dim_count = Integer(12288)
|
|
dim_count.next_to(side_brace, RIGHT)
|
|
|
|
self.play(
|
|
CountInFrom(dim_count),
|
|
GrowFromCenter(side_brace),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
Write(name),
|
|
GrowFromCenter(over_brace),
|
|
)
|
|
self.wait()
|
|
|
|
# Show column-by-column
|
|
col_matrix = self.get_col_matrix(down_proj, 7)
|
|
bias_as_col = self.get_col_matrix(bias, 1, dots_index=None, sym="B", top_index="", width_multiple=0.7)
|
|
n_labels = VGroup(
|
|
Tex(f"n_{{{m}}}")
|
|
for m in [*range(6), "m"]
|
|
)
|
|
n_labels.arrange(DOWN, buff=0.5)
|
|
n_labels.match_height(neurons.get_entries())
|
|
n_labels.move_to(neurons.get_entries())
|
|
n_labels.replace_submobject(-2, Tex(R"\vdots").move_to(n_labels[-2]))
|
|
n_labels.set_color(BLUE)
|
|
n_vect = VGroup(neurons[-2:].copy(), n_labels)
|
|
|
|
self.play(
|
|
LaggedStart(*map(FadeOut, [over_brace, name, side_brace, dim_count])),
|
|
LaggedStart(
|
|
FadeOut(down_proj),
|
|
FadeIn(col_matrix),
|
|
FadeOut(neurons),
|
|
FadeIn(n_vect),
|
|
FadeOut(bias),
|
|
FadeIn(bias_as_col),
|
|
)
|
|
)
|
|
self.wait()
|
|
|
|
# Expand the column interpretation
|
|
over_brace = Brace(VGroup(col_matrix, n_vect), UP)
|
|
scaled_cols = VGroup(
|
|
VGroup(n_label, col_label[0]).copy()
|
|
for n_label, col_label in zip(n_labels, col_matrix[1])
|
|
)
|
|
scaled_cols.target = VGroup()
|
|
for pair in scaled_cols:
|
|
pair.target = pair.generate_target()
|
|
pair.target[0].scale(1.5)
|
|
pair.target.arrange(RIGHT, buff=0.1, aligned_edge=DOWN)
|
|
scaled_cols.target.add(pair.target)
|
|
scaled_cols.target[-2].become(Tex(R"\dots"))
|
|
scaled_cols.target.arrange(RIGHT, buff=0.75)
|
|
scaled_cols.target.set_width(1.25 * over_brace.get_width())
|
|
scaled_cols.target.next_to(over_brace, UP, buff=0.5)
|
|
|
|
plusses = VGroup(
|
|
Tex("+").move_to(midpoint(m1.get_right(), m2.get_left()))
|
|
for m1, m2 in zip(scaled_cols.target, scaled_cols.target[1:])
|
|
)
|
|
|
|
self.play(
|
|
frame.animate.reorient(0, 0, 0, (-0.27, 1.04, 0.0), 11.06),
|
|
GrowFromCenter(over_brace),
|
|
LaggedStartMap(MoveToTarget, scaled_cols, lag_ratio=0.7, run_time=5),
|
|
LaggedStartMap(FadeIn, plusses, lag_ratio=0.7, run_time=5),
|
|
)
|
|
self.wait()
|
|
|
|
# Highlight each set
|
|
last_rects = VGroup()
|
|
all_rect_groups = VGroup()
|
|
for tup in zip(col_matrix[1], n_labels, scaled_cols):
|
|
rects = VGroup(SurroundingRectangle(mob) for mob in tup)
|
|
rects.set_stroke(YELLOW, 2)
|
|
self.play(
|
|
FadeOut(last_rects),
|
|
FadeIn(rects),
|
|
)
|
|
self.wait(0.5)
|
|
all_rect_groups.add(rects)
|
|
last_rects = rects
|
|
self.play(FadeOut(last_rects))
|
|
|
|
# First column as basketball
|
|
col_rect, n_rect, prod_rect = rects = all_rect_groups[0]
|
|
basketball = Text("Basketball", font_size=60)
|
|
basketball.set_color("#F88158")
|
|
basketball.next_to(col_rect, LEFT)
|
|
basketball.save_state()
|
|
basketball.rotate(-PI / 2)
|
|
basketball.move_to(col_rect)
|
|
basketball.set_opacity(0)
|
|
|
|
n0_term = scaled_cols[0][0]
|
|
n0_term.save_state()
|
|
one = Tex("1", font_size=60).move_to(n0_term, DR).set_color(BLUE)
|
|
zero = Tex("0", font_size=60).move_to(n0_term, DR).set_color(RED)
|
|
|
|
self.play(
|
|
ShowCreation(col_rect),
|
|
col_matrix[1][1:].animate.set_opacity(0.5),
|
|
n_labels[1:].animate.set_opacity(0.5),
|
|
scaled_cols[1:].animate.set_opacity(0.5),
|
|
plusses.animate.set_opacity(0.5)
|
|
)
|
|
self.play(Restore(basketball, path_arc=PI / 2))
|
|
self.wait()
|
|
self.play(TransformFromCopy(col_rect, n_rect))
|
|
self.wait()
|
|
self.play(
|
|
TransformFromCopy(col_rect, prod_rect),
|
|
TransformFromCopy(n_rect, prod_rect),
|
|
)
|
|
self.play(Transform(n0_term, one))
|
|
self.wait()
|
|
self.play(Transform(n0_term, zero))
|
|
self.wait()
|
|
self.play(Restore(n0_term))
|
|
n0_term.restore()
|
|
self.wait()
|
|
|
|
# Cycle through columns one more time
|
|
rects.add(basketball)
|
|
for index in range(1, len(all_rect_groups)):
|
|
self.play(
|
|
FadeOut(all_rect_groups[index - 1]),
|
|
FadeIn(all_rect_groups[index]),
|
|
col_matrix[1][index].animate.set_opacity(1),
|
|
n_labels[index].animate.set_opacity(1),
|
|
scaled_cols[index].animate.set_opacity(1),
|
|
plusses[index - 1].animate.set_opacity(1),
|
|
)
|
|
self.wait(0.5)
|
|
self.play(FadeOut(all_rect_groups[-1]))
|
|
|
|
# Highlight bias
|
|
bias_rect = SurroundingRectangle(bias)
|
|
bias_brace = Brace(bias_rect, UP)
|
|
bias_word = Text("Bias")
|
|
bias_word.next_to(bias_brace, UP, MED_SMALL_BUFF)
|
|
|
|
self.play(
|
|
ReplacementTransform(over_brace, bias_brace),
|
|
FadeIn(bias_rect),
|
|
FadeOut(plusses, lag_ratio=0.1),
|
|
FadeOut(scaled_cols, lag_ratio=0.1),
|
|
)
|
|
self.play(FadeIn(bias_word, 0.5 * UP))
|
|
self.wait()
|
|
self.play(LaggedStart(*map(FadeOut, [bias_word, bias_brace, bias_rect])))
|
|
|
|
# Collapse the down projection
|
|
W_down = Tex(R"W_\downarrow", font_size=60).set_color(BLUE)
|
|
B_down = Tex(R"\vec{\textbf{B}}_\downarrow", font_size=60).set_color(BLUE_B)
|
|
W_down.next_to(neurons, LEFT)
|
|
B_down.move_to(bias_as_col)
|
|
WB_down = VGroup(W_down, B_down)
|
|
n_rect = Rectangle(1, 1)
|
|
n_rect.set_height(W_down.get_height())
|
|
n_rect.move_to(n_vect)
|
|
n_rect.set_fill(GREY_C)
|
|
n_rect.set_stroke(WHITE, 1)
|
|
|
|
down_proj_expr = VGroup(W_down, n_vect, plus, B_down)
|
|
down_proj_expr.target = down_proj_expr.generate_target()
|
|
down_proj_expr.target[1].become(VGroup(n_rect))
|
|
down_proj_expr.target.arrange(RIGHT, buff=SMALL_BUFF)
|
|
down_proj_expr.target.scale(0.4)
|
|
down_proj_expr.target.next_to(arrows[2], DOWN)
|
|
|
|
self.play(ReplacementTransform(col_matrix, W_down, lag_ratio=5e-3, run_time=2))
|
|
self.play(ReplacementTransform(bias_as_col, B_down, lag_ratio=1e-2))
|
|
self.wait()
|
|
self.play(
|
|
LaggedStart(
|
|
MoveToTarget(down_proj_expr),
|
|
FadeOut(equals, 2 * DOWN + 0.5 * LEFT),
|
|
ReplacementTransform(rhs, vects[2]),
|
|
lag_ratio=0.25,
|
|
time_span=(0, 1.5),
|
|
),
|
|
frame.animate.reorient(0, -14, 0, (-0.1, -2.03, 0.01), 6.31),
|
|
run_time=2,
|
|
)
|
|
self.wait()
|
|
|
|
# Add it to the original
|
|
faded_sum_circuit = sum_circuit.copy()
|
|
sum_circuit.set_stroke(opacity=1)
|
|
sum_circuit.insert_n_curves(20)
|
|
|
|
self.add(faded_sum_circuit)
|
|
self.play(
|
|
frame.animate.reorient(13, -8, 0, (0.15, -2.05, 0.0), 6.52),
|
|
ShowCreation(sum_circuit, lag_ratio=0.5),
|
|
low_emb_label.animate.shift(0.2 * LEFT).set_anim_args(time_span=(0, 1)),
|
|
FadeOut(output_emb),
|
|
run_time=2,
|
|
)
|
|
self.remove(faded_sum_circuit)
|
|
output_emb.set_fill(opacity=1)
|
|
self.play(LaggedStart(
|
|
TransformFromCopy(emb, output_emb, path_arc=-45 * DEGREES),
|
|
TransformFromCopy(vects[2], output_emb, path_arc=-45 * DEGREES),
|
|
run_time=2,
|
|
lag_ratio=0.2,
|
|
))
|
|
self.wait()
|
|
|
|
# Yet again, emphasize the MJ example
|
|
m_color = interpolate_color_by_hsl(GREY_BROWN, WHITE, 0.5)
|
|
j_color = RED_B
|
|
b_color = basketball.get_color()
|
|
m_tex = Tex(R"\overrightarrow{\text{F.N. Michael}}").set_color(m_color)
|
|
j_tex = Tex(R"\overrightarrow{\text{L.N. Jordan}}").set_color(j_color)
|
|
b_tex = Tex(R"\overrightarrow{\text{Basketball}}").set_color(b_color)
|
|
mj = VGroup(m_tex, Tex("+"), j_tex).copy()
|
|
mjb = VGroup(m_tex, Tex("+"), j_tex, Tex("+"), b_tex).copy()
|
|
for tex_mob in [mj, mjb]:
|
|
tex_mob.set_height(0.45)
|
|
tex_mob.arrange(RIGHT, buff=SMALL_BUFF)
|
|
tex_mob.set_fill(border_width=1)
|
|
mj.next_to(low_emb_label, UP, buff=1.0).shift(0.5 * LEFT)
|
|
mjb.next_to(output_emb, UP, buff=1.5).shift(1.0 * RIGHT)
|
|
mj_arrow = Arrow(mj.get_bottom(), low_emb_label, buff=0.1)
|
|
mjb_arrow = Arrow(output_emb.get_top(), mjb.get_bottom(), buff=0.15)
|
|
|
|
self.play(
|
|
frame.animate.reorient(4, -6, 0, (-0.29, -1.76, 0.02), 7.70),
|
|
FadeIn(mj, lag_ratio=0.1),
|
|
ShowCreation(mj_arrow)
|
|
)
|
|
self.play(Transform(mj.copy(), emb.copy().set_opacity(0), lag_ratio=0.005, remover=True, run_time=2))
|
|
mover = emb.copy()
|
|
for vect in [*vects, output_emb]:
|
|
self.play(Transform(mover, vect, rate_func=linear))
|
|
self.remove(mover)
|
|
self.play(
|
|
frame.animate.reorient(-3, -5, 0, (1.09, -1.48, -0.03), 9.61),
|
|
FadeTransform(mj.copy(), mjb[:3]),
|
|
FadeTransformPieces(mj.copy()[-1:], mjb[3:]),
|
|
ShowCreation(mjb_arrow),
|
|
run_time=2,
|
|
)
|
|
self.wait(2)
|
|
self.play(
|
|
frame.animate.reorient(21, -14, 0, (-0.13, -2.21, 0.11), 6.91).set_anim_args(run_time=5),
|
|
LaggedStartMap(FadeOut, VGroup(mj, mj_arrow, mjb_arrow, mjb)),
|
|
)
|
|
|
|
# Show it done in parallel to all embeddings
|
|
self.play(
|
|
frame.animate.reorient(14, -12, 0, (0.55, -2.21, 0.18), 7.05),
|
|
LaggedStart((
|
|
TransformFromCopy(simple_phases, sp_copy)
|
|
for sp_copy in simple_phase_copies
|
|
), lag_ratio=0.1),
|
|
FadeOut(block_title, time_span=(0, 1)),
|
|
run_time=5,
|
|
)
|
|
self.play(
|
|
frame.animate.reorient(42, -23, 0, (0.55, -2.21, 0.18), 7.05),
|
|
run_time=8
|
|
)
|
|
|
|
self.wait()
|
|
|
|
# Show neurons?
|
|
sum_circuits = VGroup(
|
|
sum_circuit,
|
|
*(sp[0] for sp in simple_phase_copies),
|
|
*(sp[-1] for sp in simple_phase_copies),
|
|
)
|
|
n_vects = VGroup(vects[1], *(sp[2][1] for sp in simple_phase_copies))
|
|
|
|
neuron_points = np.array([
|
|
entry.get_center()
|
|
for vect in n_vects[1:]
|
|
for entry in vect.get_entries()
|
|
])
|
|
neurons = DotCloud(neuron_points)
|
|
neurons.set_radius(0.075)
|
|
neurons.set_shading(0.25, 0.25, 0.5)
|
|
neurons.apply_depth_test()
|
|
rgbas = np.random.random(len(neuron_points))
|
|
rgbas = rgbas.repeat(4).reshape((rgbas.size, 4))
|
|
rgbas[:, 3] = 1
|
|
neurons.set_rgba_array(rgbas)
|
|
neuron_ellipses = VGroup(
|
|
n_vect.get_ellipses()
|
|
for n_vect in n_vects[1:]
|
|
)
|
|
|
|
self.play(
|
|
frame.animate.reorient(11, -5, 0, (0.55, -2.21, 0.18), 7.05),
|
|
sum_circuits.animate.set_stroke(width=1, opacity=0.2),
|
|
FadeOut(block[4]),
|
|
run_time=2
|
|
)
|
|
self.play(
|
|
frame.animate.reorient(-11, -5, 0, (0.55, -2.21, 0.18), 7.05).set_anim_args(run_time=4),
|
|
FadeOut(n_vects),
|
|
ShowCreation(neurons, run_time=2),
|
|
FadeIn(neuron_ellipses, time_span=(1, 2)),
|
|
)
|
|
self.add(neuron_ellipses)
|
|
self.play(frame.animate.reorient(13, -7, 0, (0.55, -2.21, 0.18), 7.05), run_time=4)
|
|
self.wait()
|
|
|
|
def get_sum_circuit(
    self, in_vect, diff_vect,
    v_buff=0.15,
    h_buff=0.5,
    y_diff=0.65,
    color=YELLOW
):
    """
    Build the wiring that carries in_vect up and over to a circled plus
    placed to the right of diff_vect, plus the vector holding their sum.

    Parameters:
        in_vect, diff_vect: vector mobjects exposing get_entries(), whose
            entries support get_value() / set_value().
        v_buff: gap between the top of in_vect and the start of the wire.
        h_buff: horizontal gap on either side of the circled plus.
        y_diff: height above the top of in_vect at which the wire runs.
        color: stroke color applied to the whole circuit.

    Returns a pair (circuit, output), where circuit is
    VGroup(top_line, oplus, h_line1, h_line2) and output is a copy of
    diff_vect placed directly after h_line2 whose entries are the
    entrywise sums of in_vect and diff_vect.
    """
    # Circled plus symbol, sitting h_buff to the right of diff_vect
    plus = VGroup(Line(UP, DOWN), Line(LEFT, RIGHT))
    plus.scale(0.6)
    circle = Circle(radius=1)
    oplus = VGroup(circle, plus)
    oplus.set_height(0.3)
    oplus.next_to(diff_vect, RIGHT, buff=h_buff)

    # Wire: up from in_vect, across at a fixed height, down into the plus
    p0 = in_vect.get_top() + v_buff * UP
    p1 = in_vect.get_top() + y_diff * UP
    # Work on a private copy: the y-coordinate is overwritten in place, and
    # the array returned by get_center() may be shared with oplus's cached
    # geometry (presumably why the original needed the refresh below).
    p2 = oplus.get_center().copy()
    p2[1] = p1[1]
    p3 = oplus.get_top()
    top_line = VMobject()
    top_line.set_points_as_corners([p0, p1, p2, p3])

    # Kept from the original as a safeguard; harmless now that p2 is a copy
    oplus.refresh_bounding_box()
    h_line1 = Line(diff_vect.get_right(), oplus.get_left())
    h_line2 = Line(oplus.get_right(), oplus.get_right() + h_buff * RIGHT)

    # Output vector: same layout as diff_vect, values are the entrywise sum
    output = diff_vect.copy()
    output.next_to(h_line2, RIGHT, buff=0)
    for e1, e2, e3 in zip(in_vect.get_entries(), diff_vect.get_entries(), output.get_entries()):
        e3.set_value(e1.get_value() + e2.get_value())

    circuit = VGroup(top_line, oplus, h_line1, h_line2)
    circuit.set_stroke(color, 3)

    return circuit, output
|
|
|
|
def get_col_matrix(self, matrix, n_cols_shown, dots_index=-2, sym="C", top_index="m-1", width_multiple=1.0):
    """
    Return a column-by-column stand-in for matrix: a copy of its brackets
    together with one labelled column marker per shown column.

    The labels are indexed 0 .. n_cols_shown - 2 followed by top_index,
    laid out over the matrix entries and scaled to width_multiple times
    their width. Unless dots_index is None, the marker at that index is
    swapped for horizontal dots.
    """
    entries = matrix.get_entries()
    subscripts = [*range(n_cols_shown - 1), top_index]

    # One symbol per column, spread across the area the entries occupy
    symbols = VGroup(*(
        Tex(Rf"\vec{{\textbf{{{sym}}}}}_{{{n}}}")
        for n in subscripts
    ))
    symbols.arrange(RIGHT, buff=0.5)
    symbols.move_to(entries)
    symbols.set_width(entries.get_width() * width_multiple)

    # Thin strokes running from the matrix edges to just shy of each symbol
    lower_stroke = Line(matrix.get_bottom(), symbols.get_bottom() + SMALL_BUFF * DOWN)
    upper_stroke = Line(symbols.get_top() + SMALL_BUFF * UP, matrix.get_top())
    strokes = VGroup(lower_stroke, upper_stroke)
    strokes.set_stroke(WHITE, 1)

    columns = VGroup()
    for symbol in symbols:
        columns.add(VGroup(symbol, strokes.copy().match_x(symbol)))

    if dots_index is not None:
        ellipsis = Tex(R"\hdots")
        ellipsis.move_to(columns[dots_index])
        columns.replace_submobject(dots_index, ellipsis)

    brackets = matrix.get_brackets().copy()
    return VGroup(brackets, columns)
|
|
|
|
|
|
class NonlinearityOfLanguage(InteractiveScene):
    """
    Scene showing two perpendicular directions ("First Name Michael" and
    "Last Name Jordan"), an embedding vector that moves around the plane,
    and a live readout of its projection onto the M + J direction, followed
    by shaded "Yes" / "No" regions split perpendicular to that direction.
    """

    def construct(self):
        # Set up axes and M + J
        unit_size = 2.5

        plane = NumberPlane(
            axis_config=dict(
                stroke_width=1,
            ),
            background_line_style=dict(
                stroke_color=BLUE_D,
                stroke_width=1,
                stroke_opacity=0.75
            ),
            faded_line_ratio=1,
            unit_size=unit_size,
        )
        m_vect = Vector(unit_size * RIGHT).rotate(60 * DEGREES, about_point=ORIGIN)
        j_vect = m_vect.copy().rotate(-90 * DEGREES, about_point=ORIGIN)
        m_vect.set_color(YELLOW)
        j_vect.set_color(RED)
        # Faded copies completing the parallelogram, tip-to-tail
        m_ghost = m_vect.copy().shift(j_vect.get_vector())
        j_ghost = j_vect.copy().shift(m_vect.get_vector())
        VGroup(m_ghost, j_ghost).set_stroke(opacity=0.25)

        # Tip of M + J, and the long line through the origin in that direction
        sum_point = m_ghost.get_end()
        span_line = Line(-sum_point, sum_point)
        span_line.set_length(2 * FRAME_WIDTH)
        span_line.set_stroke(WHITE, 2, opacity=0.5)

        self.add(plane)
        self.add(m_vect, m_ghost, j_vect, j_ghost)
        self.add(span_line)

        # Label vectors
        m_label = Text("First Name Michael")
        j_label = Text("Last Name Jordan")
        for label, vect in [(m_label, m_vect), (j_label, j_vect)]:
            label.scale(0.6)
            label.match_color(vect)
            # Place the label on the side its vector points towards vertically
            direction = np.sign(vect.get_vector()[1]) * UP
            label.next_to(ORIGIN, direction, buff=0.2, aligned_edge=LEFT)
            label.rotate(vect.get_angle(), about_point=ORIGIN)
            label.set_backstroke(BLACK, 3)

        self.add(m_label)
        self.add(j_label)

        # Add dot product expression
        expr = Tex(R"(\vec{\textbf{M}} + \vec{\textbf{J}}) \cdot \textbf{E}")
        expr[1:3].match_color(m_vect)
        expr[4:6].match_color(j_vect)
        expr.to_corner(UL)
        self.add(expr)

        # Set up embedding with dot product tracker
        emb_point = VectorizedPoint(unit_size * UL)
        emb = Vector()
        emb.add_updater(lambda m: m.put_start_and_end_on(ORIGIN, emb_point.get_center()))
        normalized_sum = normalize(sum_point)

        def get_line_point():
            # Projection of the embedding tip onto the M + J line
            return normalized_sum * np.dot(normalized_sum, emb_point.get_center())

        # Segment from the origin to the projected point
        shadow = Line()
        shadow.set_stroke(PINK, 3)
        shadow.add_updater(lambda m: m.put_start_and_end_on(ORIGIN, get_line_point()))

        dot = Dot()
        dot.set_fill(PINK, 1)
        dot.f_always.move_to(get_line_point)

        dashed_line = always_redraw(
            lambda: DashedLine(emb_point.get_center(), get_line_point()).set_stroke(PINK, 2)
        )

        # Scale the projection so the readout is 2.0 when the embedding sits at
        # sum_point. This is derived from sum_point instead of the former
        # hard-coded 3.535534 (= 2.5 * sqrt(2)), which silently depended on
        # unit_size staying at 2.5.
        dot_scale = 2.0 / np.linalg.norm(sum_point)

        dp_decimal = DecimalNumber(font_size=36)
        dp_decimal.match_color(dot)
        dp_decimal.f_always.set_value(lambda: dot_scale * np.dot(normalized_sum, emb_point.get_center()))
        dp_decimal.always.next_to(dot, DR, buff=SMALL_BUFF)

        self.add(shadow, emb, dot, dashed_line, dp_decimal)

        # Start just off the origin, then tour the corners of the parallelogram
        emb_point.move_to(ORIGIN + 0.01 * UP)
        for point in [m_vect.get_end(), m_ghost.get_end(), j_vect.get_end(), m_ghost.get_end()]:
            self.play(emb_point.animate.move_to(point), run_time=3)

        # Set up names
        names = VGroup(
            Text(name, font_size=36)
            for name in [
                "Michael Jordan",
                "Michael Phelps",
                "Alexis Jordan",
            ]
        )
        name_points = [
            sum_point,
            m_vect.get_end(),
            j_vect.get_end(),
        ]
        for name, point in zip(names, name_points):
            name.set_backstroke(BLACK, 3)
            direction = RIGHT + np.sign(point[1]) * UP
            name.next_to(point, direction, buff=0.1)

        # Go through names
        name = names[0].copy()
        name_ghosts = names.copy().set_fill(opacity=0.75).set_stroke(width=0)

        self.play(
            FadeIn(name, 0.5 * UP),
            Rotate(emb_point, TAU, about_point=emb_point.get_center() + 0.15 * DL, run_time=4),
        )
        self.wait()
        self.add(name_ghosts[0])
        self.play(
            Transform(name, names[1]),
            emb_point.animate.move_to(m_vect.get_end()),
            run_time=2,
        )
        self.wait()
        self.add(name_ghosts[1])
        self.play(
            Transform(name, names[2]),
            emb_point.animate.move_to(j_vect.get_end()).set_anim_args(path_arc=30 * DEGREES),
            run_time=2,
        )
        self.add(name_ghosts[2])
        self.wait()

        # Show other names
        other_point = span_line.pfp(0.45)
        other_word = Text("(Other)", font_size=36)
        other_word.set_fill(GREY_B)
        other_word.next_to(other_point, UL, buff=0)

        self.play(
            emb_point.animate.move_to(other_point),
            LaggedStart(
                FadeOut(name),
                FadeIn(other_word),
                lag_ratio=0.5,
            ),
            run_time=3
        )
        self.wait()

        # Show "yes" vs. "no" regions
        regions = FullScreenRectangle().scale(2).replicate(2)
        regions.arrange(LEFT, buff=0)
        regions[0].set_fill(GREEN_B, 0.35)
        regions[1].set_fill(RED, 0.25)
        # Turn the shared edge perpendicular to the span line, then slide it
        # 85% of the way out towards sum_point
        regions.rotate(span_line.get_angle(), about_point=ORIGIN)
        regions.shift(0.85 * sum_point)

        yes_no_words = VGroup(
            Text("Yes", font_size=72).set_fill(GREEN).to_corner(UR),
            Text("No", font_size=72).set_fill(RED).to_edge(UP).shift(LEFT),
        )

        for region, word in zip(regions, yes_no_words):
            self.play(FadeIn(region), FadeIn(word))
            self.wait()
|
|
|
|
|
|
class Superposition(InteractiveScene):
|
|
def construct(self):
    """
    Play the superposition sequence: an undulating "N-dimensional space"
    bubble, a preview of a few idea directions, three exactly
    perpendicular vectors, the relaxed 89-to-91 degree condition, a failed
    attempt to squeeze a fourth vector into 3D, and finally a rotating
    cluster of many vectors.
    """
    # Add undulating bubble to encompass N-dimensional space
    frame = self.frame
    bubble = self.undulating_bubble()
    bubble_label = TexText(R"$N$-dimensional\\ Space")
    bubble_label.set_height(1)
    bubble_label["$N$"].set_color(YELLOW)
    bubble_label.next_to(bubble, LEFT)

    self.add(bubble)
    self.add(bubble_label)

    # Preview some ideas
    idea_words = VGroup(Text("Latin"), Text("Microphone"), Text("Basketball"), Text("The 1920s"))
    idea_words.scale(0.75)
    vectors = VGroup()  # Placeholder; rebound in the "many vectors" section
    idea_vects = VGroup()
    heading = DOWN
    for word, color in zip(idea_words, [PINK, GREEN, ORANGE, BLUE]):
        # Each idea points 80 degrees further around than the previous one
        heading = rotate_vector(heading, 80 * DEGREES)
        arrow = Vector(1.25 * normalize(heading))
        word.next_to(arrow.get_end(), arrow.get_vector(), buff=SMALL_BUFF)
        labelled_arrow = VGroup(arrow, word)
        labelled_arrow.set_color(color)
        labelled_arrow.shift(bubble.get_center())
        idea_vects.add(labelled_arrow)

    # Start zoomed in on the bubble, then pull back to the saved framing
    frame.save_state()
    frame.scale(0.75)
    frame.move_to(VGroup(bubble, bubble_label))
    self.play(
        Restore(frame, run_time=7),
        LaggedStartMap(VFadeInThenOut, idea_vects, lag_ratio=0.5, run_time=5)
    )

    # Written conditions and answer
    conditions = [
        R"$90^\circ$ apart",
        R"between $89^\circ$ and $91^\circ$ apart"
    ]
    tasks = VGroup(*(
        TexText(Rf"Choose multiple vectors,\\ each pair {phrase}", font_size=42, alignment="")
        for phrase in conditions
    ))
    task1, task2 = tasks
    task1[R"90^\circ"].set_color(RED)
    task2[R"$89^\circ$ and $91^\circ$"].set_color(BLUE)
    task1.center().to_edge(UP)
    task2.move_to(task1, UL)

    answers = ["$N$", R"$\approx \exp(\epsilon \cdot N)$"]
    maxima = VGroup(*(
        TexText(fR"Maximum \# of vectors: {answer}", font_size=42)
        for answer in answers
    ))
    maximum1, maximum2 = maxima
    for maximum in maxima:
        maximum.next_to(tasks, DOWN, buff=LARGE_BUFF, aligned_edge=LEFT)
    for maximum in maxima:
        maximum["N"].set_color(YELLOW)

    # Add 3 vectors such that each pair is 90-degrees
    perp_vectors = VGroup(*map(Vector, [RIGHT, UP, OUT]))
    perp_vectors.set_shading(0.25, 0.25, 0.25)
    perp_vectors.set_submobject_colors_by_gradient(RED, GREEN, BLUE)
    elbow_placements = [(0, UP), (-PI / 2, UP), (PI / 2, RIGHT)]
    elbows = VGroup(*(
        Elbow(width=0.1).rotate(angle, axis, about_point=ORIGIN).set_stroke(WHITE, 2)
        for angle, axis in elbow_placements
    ))
    elbows.set_stroke(GREY_A, 2)

    perp_group = VGroup(perp_vectors, elbows)
    perp_group.rotate(-10 * DEGREES, UP)
    perp_group.rotate(20 * DEGREES, RIGHT)
    perp_group.scale(2)
    perp_group.move_to(bubble)

    self.play(
        FadeIn(task1),
        LaggedStartMap(GrowArrow, perp_vectors[:2], lag_ratio=0.5)
    )
    self.play(ShowCreation(elbows[0]))
    self.play(
        GrowArrow(perp_vectors[2]),
        LaggedStartMap(ShowCreation, elbows[1:3], lag_ratio=0.5),
    )
    self.play(
        Rotate(perp_group, -50 * DEGREES, axis=perp_vectors[1].get_vector(), run_time=15),
        Write(maximum1, time_span=(2, 4)),
    )

    # Relax the assumption
    ninety_part = task1[conditions[0]]
    cross = Cross(ninety_part)
    crossed_part = VGroup(ninety_part, cross)
    new_cond = task2[conditions[1]]
    new_cond.align_to(ninety_part, LEFT)

    pairs = VGroup(get_vector_pair(89), get_vector_pair(91))
    pairs.arrange(RIGHT)
    pairs.to_corner(UL)

    self.play(
        FadeOut(maximum1),
        ShowCreation(cross),
    )
    self.play(
        crossed_part.animate.shift(0.5 * DOWN).set_fill(opacity=0.5),
        Write(new_cond),
        LaggedStartMap(FadeIn, pairs, lag_ratio=0.25),
    )
    self.play(
        Rotate(perp_group, 50 * DEGREES, axis=perp_vectors[1].get_vector(), run_time=10)
    )

    # Struggle with 3 vectors (Sub out the title)
    three_d_label = TexText(R"3-dimensional\\ Space")
    three_d_label["3"].set_color(BLUE)
    three_d_label.move_to(bubble_label, UL)
    bubble_label.save_state()

    perp_vectors.save_state()
    alt_vects = perp_vectors.copy()
    origin = perp_vectors[0].get_start()
    # Nudge each copy by 5 degrees about a random axis
    for nudged in alt_vects:
        nudged.rotate(5 * DEGREES, axis=normalize(np.random.random(3)), about_point=origin)

    # One candidate extra vector per pair, pointing along the pair's bisector
    new_vects = VGroup(*(
        Arrow(ORIGIN, v1.get_length() * normalize(v1.get_vector() + v2.get_vector()), buff=0).shift(origin)
        for v1, v2 in it.combinations(perp_vectors, 2)
    ))
    new_vects.set_color(YELLOW)
    new_vect = new_vects[0]

    def shake(vect):
        self.play(
            vect.animate.rotate(5 * DEGREES, RIGHT, about_point=origin),
            rate_func=lambda t: wiggle(t, 9)
        )

    self.play(
        FadeIn(three_d_label, DOWN),
        bubble_label.animate.to_edge(DOWN).set_opacity(0.5)
    )
    self.play(
        GrowArrow(new_vect),
        Transform(perp_vectors, alt_vects)
    )
    shake(new_vect)
    self.play(
        Restore(perp_vectors),
        Transform(new_vect, new_vects[1])
    )
    shake(new_vect)
    self.play(
        Transform(perp_vectors, alt_vects),
        Transform(new_vect, new_vects[2])
    )
    shake(new_vect)
    self.wait()
    self.play(
        new_vect.animate.scale(0, about_point=origin),
        ApplyMethod(perp_group.scale, 0, dict(about_point=origin), lag_ratio=0.25),
        Restore(bubble_label),
        FadeOut(three_d_label, UP),
        run_time=2
    )
    self.remove(new_vect, perp_group)

    # Stack on many vectors
    dodec = Dodecahedron()
    # Despite the original "vertices" name, these are the face centers
    face_centers = [face.get_center() for face in dodec]
    vectors = VGroup(*(Vector(center) for center in face_centers))
    vectors.set_flat_stroke(True)
    vectors.rotate(30 * DEGREES, UR)
    for vector in vectors:
        vector.always.set_perpendicular_to_camera(self.frame)
        vector.set_color(random_bright_color(hue_range=(0.5, 0.7)))
    vectors.move_to(bubble)

    self.wait(6)
    self.play(
        FadeOut(crossed_part),
        Write(maximum2),
        Rotating(vectors, TAU, axis=UP, run_time=20),
        LaggedStartMap(VFadeIn, vectors, lag_ratio=0.5, run_time=8)
    )
    self.wait()

    # Somehow communicate exponential scaling
|
|
|
|
def undulating_bubble(self):
    """Build a thought-bubble outline whose boundary ripples over time.

    The outline is styled, placed at the bottom edge of the frame, and
    given an updater that rescales every point of its rest shape by a
    sinusoid of that point's angle and of ``self.time``.

    Returns:
        The bubble mobject, with the undulation updater already attached.
    """
    # NOTE(review): [0][-1] presumably selects the main outline of the
    # thought bubble -- confirm against ThoughtBubble's structure.
    bubble = ThoughtBubble(filler_shape=(6, 3))[0][-1]
    bubble.set_stroke(WHITE, 1)
    bubble.set_fill(GREY)
    bubble.set_shading(0.5, 0.5, 0)
    bubble.to_edge(DOWN)

    # Rest shape, recentered on the mean of its points so that multiplying
    # a point by a scalar stretches it radially.
    points = bubble.get_points().copy()
    points -= np.mean(points, 0)
    # The rest shape never changes, so the per-point angles are computed
    # once here instead of on every frame inside the updater.
    angles = np.apply_along_axis(angle_of_vector, 1, points)

    def update_bubble(bubble):
        stretch_factors = 1.0 + 0.05 * np.sin(6 * angles + self.time)
        bubble.set_points(points * stretch_factors[:, np.newaxis])
        # The stretched points are expressed around the origin, so put the
        # outline back in its place at the bottom of the frame.
        bubble.set_x(0).to_edge(DOWN)

    bubble.add_updater(update_bubble)
    return bubble
|
|
|
|
|
|
class StackOfVectors(InteractiveScene):
    """Show a tall stack of row vectors wrapped in one big pair of brackets."""

    def construct(self):
        # Set up the big matrix: one single-row embedding per stacked vector
        n_rows = 20
        row_vects = VGroup(
            NumericEmbedding(shape=(1, 9), ellipses_col=-5, value_range=(-1, 1))
            for _ in range(n_rows)
        )
        row_vects.arrange(DOWN)
        for row_vect in row_vects:
            # Line each row's own brackets up with the edges of the stack
            for index, edge in enumerate((LEFT, RIGHT)):
                row_vect.brackets[index].align_to(row_vects, edge)
        row_vects.set_height(6)
        row_vects.to_edge(DOWN)

        # Swap the second-to-last row for vertical dots
        dots_row = row_vects[-2]
        dots_row.become(Tex(R"\vdots").replace(dots_row, dim_to_match=1))

        # One tall pair of brackets around the whole stack
        outer_brackets = NumericEmbedding(shape=(n_rows, 9)).brackets
        outer_brackets.set_height(row_vects.get_height() + MED_SMALL_BUFF)
        outer_brackets[0].next_to(row_vects, LEFT, SMALL_BUFF)
        outer_brackets[1].next_to(row_vects, RIGHT, SMALL_BUFF)

        # Braces labeling the dimension (top) and the vector count (side)
        dim_brace = Brace(row_vects[0], UP)
        dim_label = dim_brace.get_text("100-dimensional")
        count_brace = Brace(outer_brackets, LEFT)
        count_label = count_brace.get_text("10,000\nvectors")

        # Bring in the rows, the outer brackets and the dimension label
        self.play(
            GrowFromCenter(dim_brace),
            FadeIn(dim_label, lag_ratio=0.1),
            LaggedStartMap(
                FadeIn, row_vects,
                shift=0.25 * DOWN,
                lag_ratio=0.1,
                run_time=3,
            ),
            *(GrowFromCenter(bracket) for bracket in outer_brackets),
        )

        # Shuffle the entries of every row except the last two
        entry_shuffles = [
            RandomizeMatrixEntries(row_vect)
            for row_vect in row_vects[:-2]
        ]
        self.play(LaggedStart(entry_shuffles, lag_ratio=0.05))
        self.wait()

        # Label how many vectors are in the stack
        self.play(
            GrowFromCenter(count_brace),
            FadeIn(count_label, lag_ratio=0.1),
        )
        self.wait(4)
|
|
|
|
|
|
class ShowAngleRange(InteractiveScene):
    """Sweep the angle of a vector pair from 10 up to 180, then back to 95."""

    def construct(self):
        # The tracked value is handed straight to get_vector_pair
        # (presumably an angle in degrees -- confirm against that helper).
        angle_tracker = ValueTracker(10)

        def draw_pair():
            return get_vector_pair(
                angle_tracker.get_value(),
                length=3,
                colors=(RED, GREEN),
            )

        # The pair is rebuilt every frame from the tracker's current value
        vect_pair = always_redraw(draw_pair)
        self.add(vect_pair)

        # (target value, run time) for each sweep, with a pause after each
        sweeps = [(180, 8), (95, 3)]
        for target_value, duration in sweeps:
            self.play(
                angle_tracker.animate.set_value(target_value),
                run_time=duration,
            )
            self.wait()
|
|
|
|
|
|
class MLPFeatures(InteractiveScene):
    """Contrast one-neuron-per-feature with features spread over a layer.

    First a single neuron of the second layer lights up for each feature
    name; afterwards every feature name is paired with a random activation
    pattern across the whole second layer.
    """

    def construct(self):
        # Add neurons: two columns of dots, 8 and 16 tall
        dot_radius = 0.15
        layers = VGroup(
            Dot(radius=dot_radius).get_grid(n_neurons, 1, buff=dot_radius / 2)
            for n_neurons in [8, 16]
        )
        layer1, layer2 = layers
        layer2.arrange(DOWN, buff=dot_radius)
        layers.arrange(RIGHT, buff=3.0)
        layers.to_edge(LEFT, buff=LARGE_BUFF)
        layers.set_stroke(WHITE, 1)
        for input_neuron in layer1:
            input_neuron.set_fill(opacity=random.random())
        layer2.set_fill(opacity=0)

        self.add(layers)

        # Add connections
        self.add(get_network_connections(layer1, layer2))

        # Feature names, consumed in order across both parts of the scene
        feature_names = iter([
            "Table",
            "Slang",
            "AM Radio",
            "Humble",
            "Notebook",
            "Transparent",
            "Duration",
            "Madonna",
            "Mirror",
            "Pole Vaulting",
            "Albert Einstein",
            "Authentic",
            "Scientific",
            "Passionate",
            "Bell Laboratories",
            "Uzbekistan",
            "Umbrella",
            "Immanuel Kant",
            "Baroque Music",
            "Intense",
            "Clock",
            "Water skiing",
            "Ancient Egypt",
            "Ambiguous",
            "Volume",
            "Alexander the Great",
            "Innovative",
            "Religious",
        ])

        # Show single-neuron features
        # The neurons come first in the zip so that running out of neurons
        # does not pull an extra name from the iterator.
        lit_neuron = VGroup()
        shown_label = VGroup()
        for neuron, name in zip(layer2[:15], feature_names):
            label = Text(name, font_size=36)
            label.next_to(neuron, buff=SMALL_BUFF)

            self.play(
                FadeOut(shown_label),
                FadeIn(label),
                lit_neuron.animate.set_fill(opacity=0),
                neuron.animate.set_fill(opacity=1),
            )

            lit_neuron, shown_label = neuron, label

        # Show polysemantic features
        brace = Brace(layer2, RIGHT)

        def randomize_activations(layer):
            # Give every dot of the target copy a fresh random fill opacity
            target = layer.generate_target()
            for dot in target:
                dot.set_fill(opacity=random.random())
            return MoveToTarget(layer)

        # Enlarge the label currently on screen and park it beside the brace
        self.play(
            shown_label.animate.scale(48 / 36).next_to(brace, RIGHT),
            GrowFromCenter(brace),
            randomize_activations(layer2),
        )
        self.wait()

        # range() comes first so exactly 12 names are consumed
        for _, name in zip(range(12), feature_names):
            label = Text(name)
            label.next_to(brace, RIGHT)
            self.play(
                FadeOut(shown_label),
                FadeIn(label),
                randomize_activations(layer2),
            )
            self.wait(0.5)

            shown_label = label
|
|
|
|
|
|
class BreakDownThreeSteps(BasicMLPWalkThrough):
    """Summarize the MLP as three steps inside one labeled box.

    Three arrows are drawn in a row. The outer two carry weight matrices
    (``up_proj`` and ``down_proj``) and the middle one carries a ReLU graph.
    The rows of the first matrix and the columns of the second are then
    replaced by symbolic labels, and a column of dots is drawn where the
    third vector's entries sit.
    """

    def construct(self):
        # Add four vectors, spaced apart.
        # They are never added to the scene; they only provide the layout
        # for the arrows and, later, for the neuron dots.
        vectors = VGroup(
            NumericEmbedding(length=n)
            for n in [8, 16, 16, 8]
        )
        vectors.set_height(6)
        vectors.arrange(RIGHT, buff=3.5)
        vectors[2].shift(1.1 * LEFT)
        vectors[1].shift(0.2 * LEFT)
        vectors.shift(DOWN)
        # Third vector holds the ReLU of the second, entry by entry
        for e1, e2 in zip(vectors[1].get_entries(), vectors[2].get_entries()):
            e2.set_value(max(e1.get_value(), 0))

        # Add arrows between them
        arrows = VGroup(
            Arrow(v1, v2)
            for v1, v2 in zip(vectors, vectors[1:])
        )
        arrows.shift(DOWN)

        for vect in vectors:
            vect.scale(0.75)
            vect.shift(0.25 * UP)

        # Put matrices on outer two
        up_proj, down_proj = matrices = VGroup(
            WeightMatrix(shape=(12, 6)),
            WeightMatrix(shape=(6, 11)),
        )
        matrices.scale(0.25)
        for arrow, mat in zip(arrows[::2], matrices):
            mat.next_to(arrow, UP)

        # Put ReLU graph on the middle
        axes = Axes((-3, 3), (0, 3))
        graph = axes.get_graph(lambda x: max(x, 0))
        graph.set_color(BLUE)
        relu = VGroup(axes, graph)
        relu.match_width(arrows[1])
        relu.next_to(arrows[1], UP)

        # Full box
        box = SurroundingRectangle(VGroup(arrows, matrices), buff=1.0)
        box.set_stroke(WHITE, 2)
        box.set_fill(GREY_E, 1)
        title = Text("Multilayer Perceptron", font_size=60)
        title.next_to(box, UP, SMALL_BUFF)

        self.add(box, title)

        # Animate them all in.
        # Brackets start collapsed and invisible so that Restore can grow
        # them back to their saved state.
        for matrix in matrices:
            matrix.brackets.save_state()
            matrix.brackets.stretch(0, 0).set_opacity(0)

        self.play(
            LaggedStartMap(GrowArrow, arrows, lag_ratio=0.5),
            FadeIn(up_proj.get_rows(), lag_ratio=0.1, time_span=(0.0, 1.5)),
            FadeIn(down_proj.get_rows(), lag_ratio=0.1, time_span=(1.5, 3.0)),
            Restore(up_proj.brackets, time_span=(0.0, 1.5)),
            Restore(down_proj.brackets, time_span=(1.5, 3.0)),
            Write(relu, time_span=(1, 2)),
            run_time=3
        )
        self.wait()

        # Show row replacement on the first
        n_rows_shown = 8
        R_labels = VGroup(
            Tex(R"\vec{\textbf{R}}_{" + str(index) + "}")
            for index in [*range(n_rows_shown - 1), "n-1"]
        )
        # Second-to-last label becomes vertical dots
        R_labels[-2].become(Tex(R"\vdots").replace(R_labels[-2], dim_to_match=1))
        R_labels.arrange(DOWN, buff=0.5)
        R_labels.match_height(up_proj)
        R_labels.move_to(up_proj)
        # Horizontal rules from each bracket to the column of labels
        h_lines = VGroup(
            Line(up_proj.get_brackets()[0], R_labels, buff=0.1),
            Line(R_labels, up_proj.get_brackets()[1], buff=0.1),
        )
        h_lines.set_stroke(GREY_A, 2)
        row_labels = VGroup(
            VGroup(R_label, h_lines.copy().match_y(R_label))
            for R_label in R_labels
        )
        row_labels.set_color(YELLOW)

        self.play(
            FadeOut(up_proj.get_rows(), lag_ratio=0.1),
            FadeIn(row_labels, lag_ratio=0.1),
        )
        self.wait()
        # Send a fading, enlarged copy of the first row label up and to the left
        self.play(
            row_labels[0][0].copy().animate.scale(2).next_to(title, UL).shift(2 * LEFT).set_opacity(0),
        )
        self.wait()

        # Show the neurons
        dots = VGroup(
            Dot().set_fill(opacity=random.random()).move_to(entry)
            for entry in vectors[2].get_columns()[0]
        )
        for dot in dots:
            dot.match_x(dots[0])
        dots.set_stroke(WHITE, 1)
        self.play(Write(dots))
        self.wait()

        # Show column replacement on the second
        col_matrix = self.get_col_matrix(down_proj, 8)
        col_labels = col_matrix[1]
        col_labels.set_color(RED_B)

        self.play(
            FadeOut(down_proj.get_columns(), lag_ratio=0.1),
            FadeIn(col_labels, lag_ratio=0.1),
        )
        self.wait()
        # Send a fading, enlarged copy of the first column label up and to the right
        self.play(
            col_labels[0][0].copy().animate.scale(2).next_to(title, UR).shift(2 * RIGHT).set_opacity(0),
        )
        self.wait()
|
|
|
|
|
|
class SuperpositionVectorBundle(InteractiveScene):
    """Cram more and more vectors into 3D space, keeping them spread apart.

    Starting from the three standard basis vectors, random unit vectors are
    added one at a time. After each addition the whole bundle is nudged by
    gradient descent toward being pairwise near-orthogonal and unit length.
    """

    def construct(self):
        # Setup
        frame = self.frame
        axes = ThreeDAxes(z_range=(-3, 3))
        axes.scale(0.5)
        vects = VGroup(
            self.get_new_vector(v)
            for v in np.identity(3)
        )

        frame.reorient(23, 71, 0, (0.0, 0.0, 0.5), 3.5)
        frame.add_ambient_rotation(4 * DEGREES)
        self.add(frame)
        self.add(axes)
        self.add(vects)
        self.wait(2)

        # Add new vectors one at a time, re-spacing the bundle after each.
        # Later rounds get a longer run time.
        n_vects = 10
        for n in range(n_vects):
            new_vect = self.get_new_vector(normalize(np.random.uniform(-1, 1, 3)))
            vects.add(new_vect)
            self.space_out_vectors(vects, run_time=3 + 0.5 * n)
        self.wait(5)

        # TODO (from the original stub): repeatedly cram more vectors into a
        # space. The optimization in space_out_vectors uses PyTorch, not
        # TensorFlow as the original note said.

    def get_new_vector(self, coords, color=None, opacity=0.9):
        """Create a styled vector from the origin to ``coords``.

        Args:
            coords: End point handed to ``Vector``.
            color: Fill color; when ``None`` a random bright color is drawn
                with hue in (0.4, 0.6) and luminance in (0.5, 0.9).
            opacity: Fill opacity of the vector.

        Returns:
            The vector, set to stay perpendicular to the scene's camera frame.
        """
        if color is None:
            color = random_bright_color(hue_range=(0.4, 0.6), luminance_range=(0.5, 0.9))
        vect = Vector(coords, thickness=2.0)
        vect.set_fill(color, opacity=opacity, border_width=2)
        vect.always.set_perpendicular_to_camera(self.frame)
        return vect

    def space_out_vectors(self, vects, run_time=4, learning_rate=0.01):
        """Animate the vectors drifting toward mutual near-orthogonality.

        The vector tips form the rows of a matrix ``M`` that is optimized
        with Adam. Each call of the updater takes one optimizer step on a
        loss built from ``M @ M.T - I`` and then moves the arrows to the
        new rows.

        Args:
            vects: Group of vectors to adjust in place; each is redrawn
                from the origin to its optimized tip.
            run_time: Duration of the animation.
            learning_rate: Learning rate given to the Adam optimizer.
        """
        num_vectors = len(vects)
        ends = np.array([v.get_end() for v in vects])
        matrix = torch.from_numpy(ends)
        matrix.requires_grad_(True)

        optimizer = torch.optim.Adam([matrix], lr=learning_rate)
        id_mat = torch.eye(num_vectors, num_vectors)

        def update_vects(vects):
            optimizer.zero_grad()
            dot_products = matrix @ matrix.T
            # Punish deviation from orthogonal; the sixth power weights
            # large deviations far more heavily than small ones.
            diff = dot_products - id_mat
            loss = diff.pow(6).sum()

            # Extra incentive to keep rows normalized
            loss += num_vectors * diff.diag().pow(2).sum()
            loss.backward()
            optimizer.step()

            # Detach once, then redraw every arrow at its updated tip
            new_ends = matrix.detach().numpy()
            for vect, end in zip(vects, new_ends):
                vect.put_start_and_end_on(ORIGIN, end)

        self.play(UpdateFromFunc(vects, update_vects, run_time=run_time))
|
|
|
|
|
|
# Some old stubs
|
|
|
|
|
|
class ClassicNeuralNetworksPicture(InteractiveScene):
    """Placeholder scene; it currently draws nothing."""

    def construct(self):
        """Intentionally empty stub."""
        return None
|
|
|
|
|
|
class ShowBiasBakedIntoWeightMatrix(LastTwoChapters):
    """Show Attention and MLP blocks in a repeated layer, then add Layer Norm.

    The scene starts with a wrapper around the Attention and MLP blocks,
    widens it to include a Layer Norm block, and finally flashes around the
    Layer Norm block's parameter-count label.
    """

    def construct(self):
        # Add initial blocks
        frame = self.frame
        square = Square(2.0)
        att_icon = self.get_att_icon(square)
        att_icon.set_stroke(WHITE, 1, 0.5)
        mlp_icon = self.get_mlp_icon(square, layer_buff=1.0)
        lnm_icon = self.get_layer_norm_icon()
        lnm_icon.match_height(mlp_icon)

        att_block = self.get_block(att_icon, "Attention", "604M Parameters", color=YELLOW)
        mlp_block = self.get_block(mlp_icon, "MLP", "1.2B Parameters", color=BLUE)
        lnm_block = self.get_block(lnm_icon, "Layer Norm", "49K Parameters", color=GREY_B)

        blocks = VGroup(att_block, mlp_block, lnm_block)
        blocks.arrange(RIGHT, buff=1.5)

        # The small wrapper is built around a copy of the first two blocks
        lil_wrapper = self.get_layer_wrapper(blocks[:2].copy())
        big_wrapper = self.get_layer_wrapper(blocks)

        self.add(lil_wrapper, blocks[:2])
        frame.match_x(blocks[:2])
        self.wait()

        # Widen the wrapper to take in the Layer Norm block
        self.play(
            frame.animate.match_x(blocks),
            ReplacementTransform(lil_wrapper, big_wrapper),
            FadeIn(lnm_block, RIGHT),
        )
        self.wait()
        # Index 2 of a block group is its lower label (see get_block)
        self.play(FlashAround(lnm_block[2], run_time=3, time_width=2))
        self.wait()

    def get_layer_norm_icon(self):
        """Build an icon of one bell curve mapped by an arrow to another.

        Returns:
            ``VGroup(axes1, graph1, arrow, axes2, graph2)``, with the first
            axes above the second.
        """
        axes1, axes2 = all_axes = VGroup(
            Axes((-4, 4), (0, 1, 0.25))
            for _ in range(2)
        )
        all_axes.set_shape(1.5, 0.5)
        all_axes.arrange(DOWN, buff=1.0)
        # Top curve: shifted and widened normal pdf; bottom curve: standard
        # normal pdf scaled up by 1.5.
        graph1 = axes1.get_graph(lambda x: 0.5 * norm.pdf(0.5 * x - 0.5))
        graph2 = axes2.get_graph(lambda x: 1.5 * norm.pdf(x))
        graph1.set_stroke(BLUE).set_fill(BLUE, 0.25)
        graph2.set_stroke(BLUE).set_fill(BLUE, 0.25)
        arrow = Arrow(axes1, axes2, buff=0.1)

        return VGroup(axes1, graph1, arrow, axes2, graph2)

    def get_layer_wrapper(self, blocks):
        """Wrap a row of blocks in a rectangle with a loop-back arrow.

        Args:
            blocks: Group of block groups as returned by ``get_block``;
                element 0 of each is used as the arrow endpoint.

        Returns:
            ``VGroup(rect, arrows, wrapped_arrow, multiple)``, where
            ``multiple`` is the "times 96" label placed above the
            wrapped arrow.
        """
        beige = "#F5F5DC"
        # Only the rectangle of the resulting block group is kept
        rect = self.get_block(blocks, color=beige, buff=0.5, height=4)[0]
        wrapped_arrow = self.get_wrapped_arrow(rect)
        multiple = Tex(R"\times 96")
        multiple.next_to(wrapped_arrow, UP)

        # Arrows between the rectangles of consecutive blocks
        arrows = VGroup()
        for b1, b2 in zip(blocks, blocks[1:]):
            arrows.add(Arrow(b1[0], b2[0], buff=0.1))

        return VGroup(rect, arrows, wrapped_arrow, multiple)

    def get_block(
        self, content,
        upper_label="",
        lower_label="",
        upper_font_size=42,
        lower_font_size=36,
        buff=0.25,
        height=2,
        color=BLUE,
        stroke_width=3,
        fill_opacity=0.2
    ):
        """Surround ``content`` with a rounded, labeled rectangle.

        Args:
            content: Mobject to surround.
            upper_label: Text placed above the rectangle.
            lower_label: Text placed below the rectangle.
            upper_font_size: Font size of the upper label.
            lower_font_size: Font size of the lower label.
            buff: Padding between the content and the rectangle.
            height: Height the rectangle is stretched to.
            color: Stroke and fill color of the rectangle.
            stroke_width: Stroke width of the rectangle.
            fill_opacity: Fill opacity of the rectangle.

        Returns:
            ``VGroup(block, content, low_label, top_label)``.
        """
        block = SurroundingRectangle(content, buff=buff)
        block.set_height(height, stretch=True)
        block.round_corners(radius=0.25)
        # Honor the stroke_width argument; it was previously ignored in
        # favor of a hard-coded 3, which is still the default.
        block.set_stroke(color, stroke_width)
        block.set_fill(color, fill_opacity)

        low_label = Text(lower_label, font_size=lower_font_size)
        low_label.next_to(block, DOWN, MED_SMALL_BUFF)
        top_label = Text(upper_label, font_size=upper_font_size)
        top_label.next_to(block, UP, MED_SMALL_BUFF)

        return VGroup(block, content, low_label, top_label)

    def get_wrapped_arrow(self, big_block, buff=0.75, color=GREY_B, stroke_width=4):
        """Build an arrow that loops over a block from its right to its left.

        The path leaves the right edge of ``big_block``, passes ``buff``
        beyond its upper corners, and comes back in at the left edge.

        Args:
            big_block: Mobject the arrow wraps around.
            buff: How far outside the block the path runs.
            color: Color of the path and of the tip.
            stroke_width: Stroke width of the path.

        Returns:
            The path mobject, with the arrow tip added as a submobject.
        """
        vertices = [
            big_block.get_corner(RIGHT),
            big_block.get_corner(RIGHT) + buff * RIGHT,
            big_block.get_corner(UR) + buff * UR,
            big_block.get_corner(UL) + buff * UL,
            big_block.get_corner(LEFT) + buff * LEFT,
            big_block.get_corner(LEFT),
        ]
        line = Polygon(*vertices)
        line.round_corners()
        # Drop the last two points -- presumably the polygon's closing
        # segment -- so that the path stays open.
        line.set_points(line.get_points()[:-2, :])
        line.set_stroke(color, stroke_width)
        tip = ArrowTip().move_to(line.get_end(), RIGHT)
        tip.set_color(color)
        line.add(tip)
        return line
|
|
|
|
|
|
class AlmostOrthogonal(InteractiveScene):
    """Placeholder scene; it currently draws nothing."""

    def construct(self):
        """Intentionally empty stub."""
        return None
|