mirror of
https://github.com/3b1b/videos.git
synced 2025-09-18 21:38:53 +00:00
1531 lines
51 KiB
Python
1531 lines
51 KiB
Python
import torch
|
|
from manim_imports_ext import *
|
|
from _2024.transformers.helpers import *
|
|
from _2024.transformers.embedding import *
|
|
|
|
|
|
class HighLevelNetworkFlow(InteractiveScene):
|
|
# example_text = "To date, the cleverest thinker of all time was blank"
|
|
use_words = False
|
|
# possible_next_tokens = [
|
|
# (" the", 0.0882),
|
|
# (" probably", 0.0437),
|
|
# (" John", 0.0404),
|
|
# (" Sir", 0.0366),
|
|
# (" Albert", 0.0363),
|
|
# (" Ber", 0.0331),
|
|
# (" a", 0.029),
|
|
# (" Isaac", 0.0201),
|
|
# (" undoubtedly", 0.0158),
|
|
# (" arguably", 0.0133),
|
|
# (" Im", 0.0116),
|
|
# (" Einstein", 0.0113),
|
|
# (" Ludwig", 0.0104),
|
|
# ]
|
|
hide_block_labels = False
|
|
block_to_title_direction = UP
|
|
|
|
example_text = "The Computer History Museum is located in Mountain"
|
|
possible_next_tokens = [
|
|
("Mountain", 0.706),
|
|
("the", 0.128),
|
|
("San", 0.078),
|
|
("Seattle", 0.006),
|
|
("New", 0.003),
|
|
("Palo", 0.002),
|
|
("Google", 0.002),
|
|
("southern", 0.002),
|
|
]
|
|
|
|
def setup(self):
|
|
super().setup()
|
|
self.set_floor_plane("xz")
|
|
self.camera.light_source.move_to([-10, 10, 15])
|
|
self.camera.light_source.add_updater(lambda m: m.set_z(self.frame.get_z() + 10))
|
|
self.layers = VGroup()
|
|
self.blocks = Group()
|
|
self.mlps = Group()
|
|
|
|
def construct(self):
|
|
frame = self.frame
|
|
|
|
# Embedding
|
|
self.show_initial_text_embedding()
|
|
|
|
# Passing through some layers
|
|
self.progress_through_attention_block(target_frame_x=-2)
|
|
self.progress_through_mlp_block(
|
|
show_one_by_one=True,
|
|
sideview_orientation=(-78, -10, 0),
|
|
)
|
|
|
|
# # Temporary
|
|
# block = self.blocks[-1]
|
|
# nodes = self.mlps[-1][0]
|
|
# lines = self.mlps[-1][1]
|
|
# lines.save_state()
|
|
# lines.set_stroke(width=0.5, opacity=0.5)
|
|
# self.play(
|
|
# FadeOut(self.token_blocks),
|
|
# FadeOut(self.token_arrows),
|
|
# FadeOut(self.final_word_question),
|
|
# FadeOut(self.blocks[0]),
|
|
# FadeOut(self.layers),
|
|
# FadeOut(block[0]),
|
|
# )
|
|
# self.play(
|
|
# self.frame.animate.reorient(-65, -8, 0, (-0.77, 1.7, 7.26), 12.57),
|
|
# Restore(lines, lag_ratio=0.01),
|
|
# run_time=4
|
|
# )
|
|
# self.play(self.frame.animate.reorient(-15, -8, 0, (-0.77, 1.7, 7.26), 12.57), run_time=4)
|
|
|
|
#
|
|
orientation = frame.get_euler_angles() / DEGREES
|
|
mlp_kw = dict(sideview_orientation=orientation, final_orientation=orientation)
|
|
att_kw = dict(target_orientation=orientation)
|
|
for x in [-3, -4, -5, -6, -7]:
|
|
self.progress_through_attention_block(target_frame_x=x, **att_kw)
|
|
self.progress_through_mlp_block(**mlp_kw)
|
|
|
|
# Show how it ends
|
|
self.remove_mlps()
|
|
self.mention_repetitions()
|
|
self.focus_on_last_layer()
|
|
self.show_unembedding()
|
|
|
|
def get_embedding_array(
|
|
self,
|
|
shape=(9, 10),
|
|
height=4,
|
|
dots_index=-4,
|
|
buff_ratio=0.4,
|
|
bracket_color=GREY_B,
|
|
backstroke_width=3,
|
|
add_background_rectangle=False,
|
|
):
|
|
return EmbeddingArray(
|
|
shape=shape,
|
|
height=height,
|
|
dots_index=dots_index,
|
|
buff_ratio=buff_ratio,
|
|
bracket_color=bracket_color,
|
|
backstroke_width=backstroke_width,
|
|
add_background_rectangle=add_background_rectangle,
|
|
)
|
|
|
|
def swap_embedding_for_dots(self, embedding_array, dots_index=-4):
|
|
embedding_array.swap_embedding_for_dots(dots_index)
|
|
|
|
def get_next_layer_array(self, embedding_array, z_buff=3.0):
|
|
next_array = embedding_array.copy()
|
|
embeddings = [part for part in next_array if isinstance(part, NumericEmbedding)]
|
|
for embedding in embeddings:
|
|
for entry in embedding.get_entries():
|
|
entry.set_value(random.uniform(*embedding.value_range))
|
|
next_array.shift(z_buff * OUT)
|
|
return next_array
|
|
|
|
def get_block(
|
|
self,
|
|
layer,
|
|
depth=2.0,
|
|
buff=1.0,
|
|
size_buff=1.0,
|
|
color=GREY_E,
|
|
opacity=0.75,
|
|
shading=(0.25, 0.1, 0.0),
|
|
title="Attention",
|
|
title_font_size=96,
|
|
title_backstroke_width=3,
|
|
):
|
|
# Block
|
|
body = Cube(color=color, opacity=opacity)
|
|
body.deactivate_depth_test()
|
|
width, height = layer.get_shape()[:2]
|
|
body.set_shape(width + size_buff, height + size_buff, depth)
|
|
body.set_shading(0.5, 0.5, 0.0)
|
|
body.next_to(layer, OUT, buff=buff)
|
|
body.sort(lambda p: np.dot(p, [-1, 1, 1]))
|
|
|
|
title = Text(title, font_size=title_font_size)
|
|
title.set_backstroke(BLACK, title_backstroke_width)
|
|
title.next_to(body, self.block_to_title_direction, buff=0.1)
|
|
block = Group(body, title)
|
|
block.body = body
|
|
block.title = title
|
|
if self.hide_block_labels:
|
|
title.set_opacity(0)
|
|
|
|
return block
|
|
|
|
def show_initial_text_embedding(self, word_scale_factor=0.6, bump_first=False):
|
|
# Mention next word prediction task
|
|
phrase = Text(self.example_text)
|
|
phrase.set_max_width(FRAME_WIDTH - 1)
|
|
if self.use_words:
|
|
words = break_into_words(phrase)
|
|
rects = get_piece_rectangles(words)
|
|
else:
|
|
words = break_into_tokens(phrase)
|
|
rects = get_piece_rectangles(
|
|
words, leading_spaces=True, h_buff=0
|
|
)
|
|
|
|
words.remove(words[-1])
|
|
q_marks = Text("???")
|
|
rects[-1].set_color(YELLOW)
|
|
q_marks.next_to(rects[-1], DOWN)
|
|
|
|
big_rect = Rectangle()
|
|
big_rect.replace(rects[:-1], stretch=True)
|
|
big_rect.set_stroke(GREY_B, 2)
|
|
arrow = Arrow(big_rect.get_top(), rects[-1].get_top(), path_arc=-120 * DEGREES)
|
|
arrow.scale(0.5, about_edge=DR)
|
|
|
|
self.play(ShowIncreasingSubsets(words, run_time=1))
|
|
self.add(rects[-1])
|
|
self.play(LaggedStart(
|
|
FadeIn(big_rect),
|
|
ShowCreation(arrow),
|
|
Write(q_marks),
|
|
lag_ratio=0.3,
|
|
))
|
|
self.wait()
|
|
self.play(
|
|
FadeOut(big_rect),
|
|
LaggedStart(*(
|
|
DrawBorderThenFill(rect)
|
|
for rect in rects[:-1]
|
|
), lag_ratio=0.02),
|
|
LaggedStart(*(
|
|
token.animate.match_color(rect)
|
|
for token, rect in zip(words, rects)
|
|
)),
|
|
FadeOut(arrow)
|
|
)
|
|
self.wait()
|
|
|
|
# Label the tokens
|
|
token_label = Text("Tokens", font_size=72)
|
|
token_label.to_edge(UP)
|
|
arrows = VGroup(
|
|
Arrow(token_label.get_bottom(), rect.get_top()).match_color(rect)
|
|
for rect in rects[:-1]
|
|
)
|
|
|
|
self.play(FadeIn(token_label, UP))
|
|
self.play(LaggedStartMap(VFadeInThenOut, arrows, lag_ratio=0.25, run_time=4))
|
|
self.play(FadeOut(token_label, DOWN))
|
|
|
|
# Show words into vectors
|
|
layer = self.get_embedding_array(
|
|
shape=(10, len(words)),
|
|
dots_index=None,
|
|
)
|
|
vectors = layer.embeddings
|
|
|
|
blocks = VGroup(*(VGroup(rect, token) for rect, token in zip(rects, words)))
|
|
q_group = VGroup(rects[-1], q_marks)
|
|
blocks.target = blocks.generate_target()
|
|
for block, vector in zip(blocks.target, vectors):
|
|
block.scale(word_scale_factor)
|
|
block.next_to(layer, UP, buff=1.5)
|
|
block.match_x(vector)
|
|
|
|
arrows = VGroup(*(
|
|
Arrow(block, vect, stroke_width=3)
|
|
for block, vect in zip(blocks.target, vectors)
|
|
))
|
|
word_to_index = dict(zip(self.example_text.split(" "), it.count()))
|
|
# self.swap_embedding_for_dots(layer, word_to_index.get("...", -4))
|
|
|
|
if bump_first:
|
|
blocks.target[0].next_to(blocks.target[1], LEFT, buff=0.1)
|
|
|
|
self.play(
|
|
self.frame.animate.move_to(1.0 * UP),
|
|
MoveToTarget(blocks),
|
|
q_group.animate.scale(word_scale_factor).next_to(blocks.target, RIGHT, aligned_edge=UP),
|
|
LaggedStartMap(FadeIn, vectors, shift=0.5 * DOWN),
|
|
LaggedStartMap(GrowFromCenter, arrows),
|
|
Write(layer.dots)
|
|
)
|
|
# self.play(Write(layer.brackets))
|
|
self.wait()
|
|
|
|
self.token_blocks = blocks
|
|
self.token_arrows = arrows
|
|
self.final_word_question = VGroup(rects[-1], q_marks)
|
|
|
|
self.layers.add(layer)
|
|
|
|
def progress_through_attention_block(
|
|
self,
|
|
depth=2.0,
|
|
target_orientation=(-40, -15, 0),
|
|
target_frame_height=14,
|
|
target_frame_x=-2,
|
|
target_frame_y=0,
|
|
attention_anim_run_time=5,
|
|
label_text="Attention"
|
|
):
|
|
# self.embed()
|
|
# # Test
|
|
# self.frame.clear_updaters()
|
|
# attention_anim_run_time = 16
|
|
# target_orientation = (-53, -16, 0, (-1.08, -0.57, 1.12), 11.27)
|
|
# self.camera.light_source.set_z(10)
|
|
|
|
layer = self.layers[-1]
|
|
layer.brackets.set_fill(opacity=1)
|
|
layer.save_state()
|
|
block = self.get_block(layer, title=label_text, depth=depth)
|
|
z_diff = block.get_z() - layer.get_z()
|
|
block_opacity = block.body[0].get_opacity()
|
|
block.body[0].set_opacity(0)
|
|
block.title.set_backstroke(BLACK, 5)
|
|
new_layer = layer.copy()
|
|
new_layer.match_z(block)
|
|
new_layer.set_backstroke(BLACK, 1)
|
|
|
|
self.frame.target = self.frame.generate_target()
|
|
self.frame.target.reorient(*target_orientation)
|
|
self.frame.target.set_height(target_frame_height)
|
|
self.frame.target.set_x(target_frame_x)
|
|
self.frame.target.set_y(target_frame_y)
|
|
|
|
self.play(
|
|
MoveToTarget(self.frame),
|
|
LaggedStart(
|
|
layer.animate.set_opacity(0.25),
|
|
FadeIn(block),
|
|
TransformFromCopy(layer, new_layer),
|
|
(self.blocks[-1].title if len(self.blocks) > 0 else VGroup()).animate.set_opacity(0.25),
|
|
lag_ratio=0.3
|
|
),
|
|
self.token_arrows.animate.set_opacity(0.1),
|
|
run_time=2
|
|
)
|
|
self.play_simple_attention_animation(
|
|
new_layer,
|
|
run_time=attention_anim_run_time,
|
|
added_anims=[self.frame.animate.reorient(0, -15, 0, (-0.07, 0.45, 2.03), 9.25)]
|
|
)
|
|
|
|
# Take new layer out of block
|
|
self.add(*block.body, block.title, new_layer)
|
|
new_z = block.get_z() + z_diff
|
|
self.play(
|
|
block.body[0].animate.set_opacity(block_opacity),
|
|
new_layer.animate.set_z(new_z),
|
|
self.frame.animate.set_z(new_z),
|
|
Restore(layer),
|
|
)
|
|
self.add(block, new_layer)
|
|
|
|
if False:
|
|
# Highlight example
|
|
index = 4
|
|
highlight = new_layer[0][index].copy()
|
|
highlight.set_backstroke(BLACK, 3)
|
|
highlight.set_fill(border_width=1)
|
|
rect = SurroundingRectangle(highlight)
|
|
rect.set_stroke(TEAL, 3)
|
|
new_arrow = Arrow(self.token_blocks[index].get_bottom(), rect.get_top(), tip_angle=45 * DEGREES)
|
|
new_arrow.set_fill(GREY_A, border_width=4)
|
|
|
|
self.add(block.body, new_layer, block.title),
|
|
self.play(LaggedStart(
|
|
self.frame.animate.reorient(0, -14, 0, (-0.18, 0.73, 3.26), 8.31),
|
|
block.body.animate.set_color(BLACK).set_shading(0.1, 0.1, 0.5),
|
|
block.title.animate.set_opacity(0.1),
|
|
new_layer.animate.fade(0.75).set_stroke(width=0),
|
|
self.token_blocks[index].animate.scale(2, about_edge=DOWN),
|
|
self.token_blocks[:index].animate.shift(0.35 * LEFT),
|
|
FadeIn(highlight),
|
|
ShowCreation(rect),
|
|
FadeIn(new_arrow, scale=4),
|
|
run_time=3
|
|
))
|
|
self.wait()
|
|
|
|
self.blocks.add(block)
|
|
self.layers.add(new_layer)
|
|
|
|
def play_simple_attention_animation(self, layer, run_time=5, added_anims=[]):
|
|
arc_groups = VGroup()
|
|
for _ in range(3):
|
|
for n, e1 in enumerate(layer.embeddings):
|
|
arc_group = VGroup()
|
|
for e2 in layer.embeddings[n + 1:]:
|
|
sign = (-1)**int(e2.get_x() > e1.get_x())
|
|
arc_group.add(Line(
|
|
e1.get_top(), e2.get_top(),
|
|
path_arc=sign * PI / 3,
|
|
stroke_color=random_bright_color(hue_range=(0.1, 0.3)),
|
|
stroke_width=5 * random.random()**5,
|
|
))
|
|
arc_group.shuffle()
|
|
if len(arc_group) > 0:
|
|
arc_groups.add(arc_group)
|
|
|
|
self.play(
|
|
LaggedStart(*(
|
|
AnimationGroup(
|
|
LaggedStartMap(VShowPassingFlash, arc_group.copy(), time_width=2, lag_ratio=0.15),
|
|
LaggedStartMap(ShowCreationThenFadeOut, arc_group, lag_ratio=0.15),
|
|
)
|
|
for arc_group in arc_groups
|
|
), lag_ratio=0.0),
|
|
LaggedStartMap(RandomizeMatrixEntries, layer.embeddings, lag_ratio=0.0),
|
|
*added_anims,
|
|
run_time=run_time
|
|
)
|
|
self.add(layer)
|
|
|
|
def progress_through_mlp_block(
|
|
self,
|
|
n_neurons=20,
|
|
# depth=3.0,
|
|
depth=4.0,
|
|
buff=1.0,
|
|
dot_buff_ratio=0.2,
|
|
neuron_color=GREY_C,
|
|
neuron_shading=(0.25, 0.75, 0.2),
|
|
sideview_orientation=(-60, -5, 0),
|
|
final_orientation=(-51, -18, 0),
|
|
show_one_by_one=False,
|
|
# label_text="Multilayer\nPerceptron",
|
|
# title_font_size=72,
|
|
label_text="Feedforward",
|
|
title_font_size=90,
|
|
):
|
|
# MLP Test
|
|
layer = self.layers[-1]
|
|
block = self.get_block(
|
|
layer,
|
|
depth=depth,
|
|
buff=buff,
|
|
size_buff=1.0,
|
|
title=label_text,
|
|
title_font_size=title_font_size,
|
|
)
|
|
|
|
# New layer
|
|
new_layer = self.get_next_layer_array(layer)
|
|
new_layer.next_to(block.body, OUT, buff)
|
|
|
|
# Neurons
|
|
def get_neurons(points):
|
|
neurons = DotCloud(points, radius=0.1)
|
|
neurons.make_3d()
|
|
neurons.set_glow_factor(0)
|
|
neurons.set_shading(*neuron_shading)
|
|
neurons.set_color(neuron_color)
|
|
neurons.set_opacity(np.random.uniform(0.5, 1, len(points)))
|
|
return neurons
|
|
|
|
all_neuron_points = np.zeros((0, 3))
|
|
neuron_clusters = Group()
|
|
connections = VGroup()
|
|
y_min = block.body.get_y(DOWN) + SMALL_BUFF
|
|
y_max = block.body.get_y(UP) - SMALL_BUFF
|
|
block_z = block.body.get_z()
|
|
for embedding in layer.embeddings:
|
|
l1_points = np.array([e.get_center() for e in embedding.get_columns()[0]])
|
|
l3_points = l1_points.copy()
|
|
l1_points[:, 2] = block_z - 0.4 * depth
|
|
l3_points[:, 2] = block_z + 0.4 * depth
|
|
x0, y0, z0 = embedding.get_center()
|
|
l2_points = np.array([
|
|
[x0, y, block_z]
|
|
for y in np.linspace(y_min, y_max, n_neurons)
|
|
])
|
|
new_points = np.vstack([l1_points, l2_points, l3_points])
|
|
all_neuron_points = np.vstack([all_neuron_points, new_points])
|
|
neuron_clusters.add(get_neurons(new_points))
|
|
# Lines
|
|
weights = VGroup(*(
|
|
Line(
|
|
p1, p2,
|
|
buff=0.1,
|
|
stroke_width=3,
|
|
stroke_opacity=random.random(),
|
|
stroke_color=value_to_color(random.uniform(-10, 10))
|
|
)
|
|
for points1, points2 in [
|
|
[l1_points, l2_points],
|
|
[l2_points, l3_points],
|
|
]
|
|
for p1, p2 in it.product(points1, points2)
|
|
if random.random() < 0.2
|
|
))
|
|
connections.add(weights)
|
|
neurons = get_neurons(all_neuron_points)
|
|
connections.apply_depth_test()
|
|
connections.set_flat_stroke(False)
|
|
|
|
# Flow through layer
|
|
if show_one_by_one:
|
|
networks = Group(*(Group(cluster, lines.copy()) for cluster, lines in zip(neuron_clusters, connections)))
|
|
last_network = VectorizedPoint()
|
|
emb_pairs = list(zip(layer.embeddings, new_layer.embeddings))
|
|
index = 4
|
|
block.title.rotate(PI / 2, DOWN)
|
|
self.play(
|
|
self.blocks[-1].title.animate.set_opacity(0.25),
|
|
FadeIn(block.title, time_span=(0, 1)),
|
|
self.frame.animate.reorient(*sideview_orientation),
|
|
Write(networks[index][1]),
|
|
FadeIn(networks[index][0]),
|
|
FadeTransform(emb_pairs[index][0].copy(), emb_pairs[index][1]),
|
|
run_time=3
|
|
)
|
|
lag_kw = dict(lag_ratio=0.5, run_time=9)
|
|
self.remove(block.title)
|
|
self.play(
|
|
LaggedStart(*(
|
|
FadeTransform(e1.copy(), e2)
|
|
for e1, e2 in [*emb_pairs[:index], *emb_pairs[index + 1:]]
|
|
), **lag_kw),
|
|
LaggedStart(*(
|
|
Write(network[1])
|
|
for network in [*networks[:index], *networks[index + 1:]]
|
|
), **lag_kw),
|
|
LaggedStart(*(
|
|
FadeIn(network[0])
|
|
for network in [*networks[:index], *networks[index + 1:]]
|
|
), **lag_kw),
|
|
self.frame.animate.reorient(0, -37, 0, (-1.08, 2.29, 7.99), 12.27),
|
|
block.title.animate.rotate(PI / 2, UP),
|
|
run_time=15,
|
|
)
|
|
self.remove(networks)
|
|
self.add(neurons, connections, block.body, block.title, new_layer)
|
|
self.play(
|
|
FadeIn(block.body),
|
|
FadeIn(new_layer),
|
|
FadeOut(new_layer.embeddings.copy()),
|
|
# self.frame.animate.reorient(*final_orientation).match_z(new_layer),
|
|
self.frame.animate.reorient(-48, -12, 0).match_z(new_layer),
|
|
run_time=2
|
|
)
|
|
else:
|
|
self.play(
|
|
FadeIn(block.title, time_span=(0, 1)),
|
|
self.blocks[-1].title.animate.set_opacity(0.25),
|
|
self.frame.animate.reorient(*sideview_orientation),
|
|
FadeIn(neurons, time_span=(0, 1)),
|
|
Write(connections, stroke_width=3),
|
|
TransformFromCopy(layer, new_layer),
|
|
run_time=3
|
|
)
|
|
self.add(block.body, block.title, new_layer)
|
|
self.play(
|
|
self.frame.animate.reorient(*final_orientation).match_z(new_layer),
|
|
FadeIn(block.body),
|
|
run_time=2,
|
|
)
|
|
|
|
# Aggregate
|
|
self.mlps.add(Group(neurons, connections))
|
|
self.blocks.add(block)
|
|
self.layers.add(new_layer)
|
|
|
|
def remove_mlps(self):
|
|
self.remove(self.mlps)
|
|
|
|
def mention_repetitions(self, depth=8):
|
|
# Mention repetition
|
|
frame = self.frame
|
|
layer = self.layers[-1]
|
|
block = self.blocks[-1].body
|
|
|
|
thin_blocks = block.replicate(2)
|
|
thin_blocks.set_depth(1.0, stretch=True)
|
|
|
|
dots = Tex(".....", font_size=250)
|
|
brace = Brace(dots, UP)
|
|
brace_text = brace.get_text("Many\nrepetitions")
|
|
rep_label = Group(dots, brace, brace_text)
|
|
rep_label.set_width(depth)
|
|
rep_label.rotate(PI / 2, DOWN)
|
|
rep_label.next_to(layer, OUT, buff=3.0)
|
|
rep_label.align_to(ORIGIN, DOWN)
|
|
|
|
thin_blocks[0].align_to(brace, IN)
|
|
thin_blocks[1].align_to(brace, OUT)
|
|
dots.scale(0.5)
|
|
VGroup(brace, brace_text).next_to(thin_blocks, UP)
|
|
|
|
final_layer = self.get_next_layer_array(layer)
|
|
final_layer.set_z(rep_label.get_z(OUT) + 1)
|
|
final_layer.save_state()
|
|
final_layer.become(layer)
|
|
final_layer.set_opacity(0)
|
|
|
|
self.play(
|
|
frame.animate.reorient(-57, -8, 0, (-6.59, 1.2, 57.84), 30),
|
|
FadeIn(thin_blocks, lag_ratio=0.1),
|
|
GrowFromCenter(brace),
|
|
Write(brace_text, time_span=(1, 2)),
|
|
Write(dots),
|
|
run_time=3
|
|
)
|
|
self.play(
|
|
Restore(final_layer, run_time=2)
|
|
)
|
|
self.wait()
|
|
|
|
self.rep_label = rep_label
|
|
self.blocks.add(*thin_blocks)
|
|
self.layers.add(final_layer)
|
|
|
|
def focus_on_last_layer(self):
|
|
# Last Layer
|
|
layer = self.layers[-1]
|
|
rect = BackgroundRectangle(layer)
|
|
rect.set_fill(BLACK, 0.5)
|
|
rect.scale(3)
|
|
|
|
self.rep_label.target = self.rep_label.generate_target()
|
|
self.rep_label.target[1:].set_opacity(0)
|
|
self.rep_label.target[0].set_opacity(0.5)
|
|
|
|
target_frame_center = layer.get_center() + 4 * RIGHT + UP
|
|
|
|
self.add(self.rep_label, rect, layer)
|
|
self.play(
|
|
FadeIn(rect),
|
|
MoveToTarget(self.rep_label),
|
|
*map(FadeOut, [block.title for block in self.blocks if hasattr(block, "title")]),
|
|
self.frame.animate.reorient(-3, -12, 0, target_frame_center, 12.00),
|
|
run_time=3,
|
|
)
|
|
|
|
def show_unembedding(self):
|
|
# Unembedding
|
|
label_font_size = 30
|
|
layer = self.layers[-1]
|
|
last_embedding = layer.embeddings[-1]
|
|
rect = SurroundingRectangle(last_embedding)
|
|
rect.set_stroke(YELLOW, 3)
|
|
|
|
words, dist = zip(*self.possible_next_tokens)
|
|
bars = BarChart(dist).bars
|
|
bars.rotate(-90 * DEGREES)
|
|
bars.set_width(2.5, stretch=True)
|
|
bars.next_to(layer, RIGHT, buff=4.0)
|
|
bars.set_y(2)
|
|
for bar, word, value in zip(bars, words, dist):
|
|
percentage = DecimalNumber(
|
|
100 * value,
|
|
num_decimal_places=2,
|
|
unit="%",
|
|
font_size=label_font_size,
|
|
)
|
|
percentage.next_to(bar, RIGHT)
|
|
text = Text(word, font_size=label_font_size)
|
|
text.next_to(bar, LEFT)
|
|
bar.push_self_into_submobjects()
|
|
bar.add(text, percentage)
|
|
|
|
dots = Tex(R"\vdots")
|
|
dots.next_to(bars[-1][1], DOWN)
|
|
bars.add(dots)
|
|
brace = Brace(bars, LEFT)
|
|
|
|
arrow = Line()
|
|
arrow.clear_points()
|
|
arrow.start_new_path(rect.get_top())
|
|
arrow.add_cubic_bezier_curve_to(
|
|
rect.get_top() + UP,
|
|
rect.get_top() + UR,
|
|
rect.get_top() + 2 * RIGHT,
|
|
)
|
|
arrow.add_line_to(
|
|
brace.get_left() + 0.1 * LEFT,
|
|
)
|
|
arrow.make_smooth(approx=False)
|
|
arrow.add_tip()
|
|
arrow.set_color(YELLOW)
|
|
|
|
self.play(ShowCreation(rect))
|
|
self.play(
|
|
ShowCreation(arrow),
|
|
GrowFromCenter(brace),
|
|
LaggedStartMap(FadeIn, bars, shift=DOWN)
|
|
)
|
|
self.wait()
|
|
# Test
|
|
self.play(
|
|
self.frame.animate.reorient(-5, -8, 0, (2.99, 2.67, 71.6), 13.60),
|
|
run_time=8
|
|
)
|
|
|
|
|
|
class SimplifiedFlow(HighLevelNetworkFlow):
|
|
example_text = "Word vectors will be updated to encode more than mere words"
|
|
attention_anim_run_time = 1.0
|
|
orientation = (-55, -19, 0)
|
|
target_frame_x = -2
|
|
x_range = np.linspace(-2, -8, 5)
|
|
frame_height_growth_rate = 0.15
|
|
|
|
def construct(self):
|
|
# Test
|
|
self.camera.light_source.set_z(60)
|
|
self.show_initial_text_embedding(word_scale_factor=0.6)
|
|
self.show_simple_flow(self.x_range)
|
|
|
|
def show_simple_flow(self, x_range, orientation=None):
|
|
if orientation is None:
|
|
orientation = self.orientation
|
|
curr_time = float(self.time)
|
|
curr_height = self.frame.get_height()
|
|
self.frame.add_updater(lambda f: f.set_height(curr_height + self.frame_height_growth_rate * (self.time - curr_time)))
|
|
for x in x_range:
|
|
self.progress_through_attention_block(
|
|
target_orientation=orientation,
|
|
target_frame_x=self.target_frame_x,
|
|
attention_anim_run_time=self.attention_anim_run_time,
|
|
)
|
|
self.progress_through_mlp_block(
|
|
sideview_orientation=orientation,
|
|
final_orientation=orientation,
|
|
)
|
|
|
|
|
|
class AltIntro(HighLevelNetworkFlow):
|
|
example_text = "Four score and seven years ago our fathers"
|
|
|
|
def construct(self):
|
|
self.show_initial_text_embedding(word_scale_factor=0.6)
|
|
|
|
|
|
class SimplifiedFlowAlternateAngle(SimplifiedFlow):
|
|
example_text = "The goal of the network is to predict the next token"
|
|
attention_anim_run_time = 1.0
|
|
orientation = (-10, -20, 0)
|
|
target_frame_x = 0
|
|
hide_block_labels = True
|
|
|
|
|
|
class LongerFlowLongerAttention(SimplifiedFlow):
|
|
example_text = "The goal of the network is to predict the next token"
|
|
attention_anim_run_time = 3.0
|
|
target_frame_x = 0
|
|
hide_block_labels = True
|
|
x_range = np.linspace(-2, -12, 7)
|
|
|
|
|
|
class MentionContextSizeAndUnembedding(SimplifiedFlow):
|
|
example_text = "Harry Potter was a highly unusual boy ... least favourite teacher, Professor Snape"
|
|
attention_anim_run_time = 5.0
|
|
use_words = True
|
|
|
|
def construct(self):
|
|
# Initial flow
|
|
self.show_initial_text_embedding(word_scale_factor=0.5)
|
|
self.cycle_through_embeddings()
|
|
self.show_simple_flow(
|
|
np.linspace(-2, -5, 2),
|
|
orientation=(-50, -19, 0)
|
|
)
|
|
self.show_context_size()
|
|
|
|
# Skip to the end
|
|
self.remove_mlps()
|
|
self.mention_repetitions()
|
|
self.focus_on_last_layer()
|
|
self.remove(*self.layers[:-1])
|
|
|
|
# Discuss unembedding
|
|
self.show_desired_output()
|
|
self.show_unembedding_matrix()
|
|
|
|
def cycle_through_embeddings(self):
|
|
layer = self.layers[0]
|
|
blocks = self.token_blocks
|
|
arrows = self.token_arrows
|
|
embeddings = VGroup(*layer.embeddings, *layer.dots)
|
|
embeddings.sort(lambda p: p[0])
|
|
rect_groups = VGroup(*(
|
|
VGroup(*(
|
|
BackgroundRectangle(mob, buff=0.2 if mob in arrows else 0.1)
|
|
for mob in tup
|
|
))
|
|
for tup in zip(blocks, arrows, embeddings)
|
|
))
|
|
rect_groups.set_opacity(0)
|
|
self.add(rect_groups)
|
|
|
|
for index in range(len(rect_groups)):
|
|
rect_groups.generate_target()
|
|
rect_groups.target.set_opacity(0.75)
|
|
rect_groups.target[index].set_opacity(0)
|
|
self.play(MoveToTarget(rect_groups))
|
|
self.play(FadeOut(rect_groups))
|
|
|
|
def show_context_size(self):
|
|
# Zoom in
|
|
frame = self.frame
|
|
frame.save_state()
|
|
block_titles = VGroup(*(block.title for block in self.blocks))
|
|
layer = self.layers[-1]
|
|
|
|
self.play(
|
|
frame.animate.reorient(-27, -18, 0, (-0.84, -0.36, 18.15), 9.00),
|
|
FadeOut(block_titles, lag_ratio=0.01),
|
|
run_time=2,
|
|
)
|
|
|
|
# Show size
|
|
font_size = 72
|
|
brace = Brace(layer.embeddings, DOWN)
|
|
label = brace.get_text("Context size", font_size=font_size)
|
|
label.generate_target()
|
|
rhs = Tex("= 2{,}048", font_size=font_size)
|
|
full_label = VGroup(label.target, rhs)
|
|
full_label.arrange(RIGHT, aligned_edge=UP)
|
|
full_label.next_to(brace, DOWN)
|
|
|
|
dim_brace = Brace(layer.embeddings, RIGHT)
|
|
dim_label = dim_brace.get_text("12,288", font_size=font_size)
|
|
|
|
self.play(
|
|
GrowFromCenter(brace),
|
|
FadeIn(label, 0.5 * DOWN),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
MoveToTarget(label),
|
|
Write(rhs)
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
FadeOut(layer.brackets),
|
|
GrowFromCenter(dim_brace),
|
|
FadeIn(dim_label, 0.5 * LEFT),
|
|
)
|
|
self.wait(3)
|
|
|
|
# Return
|
|
self.play(
|
|
Restore(frame, run_time=3),
|
|
FadeIn(block_titles, time_span=(2, 3), lag_ratio=0.01),
|
|
FadeIn(layer.brackets),
|
|
LaggedStartMap(FadeOut, VGroup(
|
|
brace, label, rhs, dim_brace, dim_label
|
|
), shift=DOWN)
|
|
)
|
|
for layer, block in zip(self.layers, self.blocks):
|
|
self.add(layer, block)
|
|
self.add(self.layers[-1])
|
|
|
|
def show_desired_output(self):
|
|
# Show phrase
|
|
frame = self.frame
|
|
phrase = Text(self.example_text)
|
|
phrase.set_max_width(FRAME_WIDTH - 1)
|
|
phrase.to_corner(UL)
|
|
phrase.fix_in_frame()
|
|
last_word = phrase[self.example_text.split(" ")[-1]][0]
|
|
last_word.set_opacity(0)
|
|
rect = SurroundingRectangle(last_word, buff=0.1)
|
|
rect.set_stroke(YELLOW, 2)
|
|
rect.set_fill(YELLOW, 0.5)
|
|
rect.align_to(last_word, LEFT)
|
|
q_marks = Text("???")
|
|
q_marks.next_to(rect, DOWN)
|
|
q_group = VGroup(rect, q_marks)
|
|
q_group.fix_in_frame()
|
|
|
|
self.play(FadeIn(phrase, lag_ratio=0.1, run_time=2))
|
|
self.play(FadeIn(q_group, 0.2 * RIGHT))
|
|
self.wait()
|
|
|
|
# Get all possible next words
|
|
layer = self.layers[-1]
|
|
word_strs = [
|
|
"...",
|
|
"Snake",
|
|
"Snape",
|
|
"Snare",
|
|
"...",
|
|
"Treks",
|
|
"Trelawney",
|
|
"Trellis",
|
|
"...",
|
|
"Quirky",
|
|
"Quirrell",
|
|
"Quirt",
|
|
"..."
|
|
]
|
|
words = VGroup()
|
|
dots = VGroup()
|
|
for word_str in word_strs:
|
|
word = Text(word_str)
|
|
words.add(word)
|
|
if word_str == "...":
|
|
dots.add(word)
|
|
word.rotate(PI / 2)
|
|
words.arrange(DOWN, aligned_edge=LEFT)
|
|
dots.shift(0.25 * RIGHT)
|
|
words.set_max_height(5.0)
|
|
words.to_edge(DOWN, buff=1.0)
|
|
words.to_edge(RIGHT, buff=0.1)
|
|
words.fix_in_frame()
|
|
|
|
# Add probability bars
|
|
values = [0, 0.78, 0, 0, 0.16, 0, 0, 0.06, 0]
|
|
bars = VGroup()
|
|
probs = VGroup()
|
|
value_iter = iter(values)
|
|
bar_height = 0.8 * words[1].get_height()
|
|
for word in words:
|
|
if word in dots:
|
|
continue
|
|
value = next(value_iter)
|
|
prob = DecimalNumber(value, font_size=24)
|
|
bar = Rectangle(10 * value * bar_height, bar_height)
|
|
bar.next_to(word, LEFT)
|
|
hsl = list(Color(BLUE).get_hsl())
|
|
hsl[0] = interpolate(0.5, 0.6, value)
|
|
bar.set_fill(Color(hsl=hsl), 1)
|
|
bar.set_stroke(BLUE, 1)
|
|
prob.next_to(bar, LEFT)
|
|
probs.add(prob)
|
|
bars.add(bar)
|
|
|
|
probs.fix_in_frame()
|
|
bars.fix_in_frame()
|
|
brace = Brace(VGroup(probs, dots), LEFT).fix_in_frame()
|
|
arrow = Vector(RIGHT, stroke_width=8).fix_in_frame()
|
|
arrow.set_color(YELLOW)
|
|
arrow.next_to(brace, LEFT)
|
|
|
|
# Show creation
|
|
self.play(
|
|
LaggedStartMap(FadeIn, words, shift=0.1 * DOWN),
|
|
LaggedStartMap(FadeIn, bars),
|
|
LaggedStartMap(FadeIn, probs, shift=0.2 * LEFT),
|
|
GrowFromCenter(brace),
|
|
GrowArrow(arrow),
|
|
)
|
|
self.wait()
|
|
|
|
self.prob_group = VGroup(arrow, brace, words, bars, probs)
|
|
self.phrase = VGroup(phrase, q_group)
|
|
|
|
def show_unembedding_matrix(self, vector_index=-1):
|
|
# Clear frame
|
|
frame = self.frame
|
|
prob_group = self.prob_group
|
|
prob_group.save_state()
|
|
prob_group.generate_target()
|
|
prob_group.target.scale(0.5, about_edge=DR)
|
|
prob_group.target.to_corner(DR)
|
|
prob_group.target[0].set_opacity(0)
|
|
|
|
self.play(
|
|
FadeOut(self.phrase, UP),
|
|
MoveToTarget(prob_group),
|
|
frame.animate.reorient(0, 1, 0, (4.98, 3.34, 30.08), 12.00),
|
|
run_time=2,
|
|
)
|
|
|
|
# Show the weight matrix
|
|
layer = self.layers[-1]
|
|
vector = layer.embeddings[vector_index].copy()
|
|
matrix = WeightMatrix(
|
|
shape=(15, vector.shape[0]),
|
|
ellipses_row=8,
|
|
)
|
|
matrix.set_height(6)
|
|
matrix.next_to(vector, UP, aligned_edge=RIGHT, buff=1.0)
|
|
matrix.shift(LEFT)
|
|
last_vector_rect = rect = SurroundingRectangle(vector, buff=0.1)
|
|
rect.set_stroke(YELLOW, 3)
|
|
vector.generate_target()
|
|
vector.target.next_to(matrix, RIGHT)
|
|
last_vect_arrow = Arrow(
|
|
rect.get_top(), vector.target.get_bottom(),
|
|
)
|
|
last_vect_arrow.set_color(YELLOW)
|
|
|
|
self.play(
|
|
FadeIn(matrix, scale=0.8, shift=DR),
|
|
*map(FadeOut, [self.token_blocks, self.token_arrows, self.final_word_question]),
|
|
)
|
|
self.play(ShowCreation(rect))
|
|
self.play(MoveToTarget(vector), ShowCreation(last_vect_arrow))
|
|
|
|
# Show matrix vector product
|
|
eq, rhs = show_matrix_vector_product(self, matrix, vector)
|
|
|
|
# Count values
|
|
brace = Brace(rhs, RIGHT)
|
|
brace_label = brace.get_tex(R"\sim 50k \text{ values}", font_size=60, buff=0.25)
|
|
|
|
self.play(
|
|
GrowFromCenter(brace),
|
|
FadeIn(brace_label, lag_ratio=0.1),
|
|
)
|
|
brace_group = VGroup(brace, brace_label)
|
|
|
|
# Show words
|
|
word_strs = [
|
|
"aah",
|
|
"aardvark",
|
|
"aardwolf",
|
|
"aargh",
|
|
"ab",
|
|
"aback",
|
|
"abacterial",
|
|
"abacus",
|
|
"...",
|
|
"zygote",
|
|
"zygotic",
|
|
"zyme",
|
|
"zymogen",
|
|
"zymosis",
|
|
"zzz",
|
|
]
|
|
words = VGroup(*map(Text, word_strs))
|
|
words[word_strs.index("...")].rotate(PI / 2)
|
|
for word, entry in zip(words, rhs):
|
|
word.set_max_height(entry.get_height())
|
|
word.next_to(rhs, RIGHT, buff=0.25)
|
|
word.match_y(entry)
|
|
|
|
self.play(
|
|
LaggedStartMap(FadeIn, words, shift=0.5 * RIGHT),
|
|
brace_group.animate.next_to(words, RIGHT),
|
|
)
|
|
self.wait()
|
|
|
|
# Mention softmax
|
|
big_rect = SurroundingRectangle(VGroup(rhs, words), buff=0.25)
|
|
big_rect.set_fill(TEAL, 0.1)
|
|
big_rect.set_stroke(TEAL, 3)
|
|
softmax_arrow = Vector(2.2 * RIGHT, stroke_width=8)
|
|
softmax_arrow.next_to(big_rect, RIGHT, buff=0.1)
|
|
softmax_label = Text("softmax", font_size=48)
|
|
softmax_label.next_to(softmax_arrow, UP, buff=0.35)
|
|
|
|
prob_group.generate_target()
|
|
prob_group.target[0].scale(0).move_to(prob_group)
|
|
prob_group.target.set_height(4)
|
|
prob_group.target.to_edge(UP, buff=0.25).to_edge(RIGHT, buff=0.1)
|
|
|
|
self.play(
|
|
LaggedStart(
|
|
FadeOut(brace_group),
|
|
DrawBorderThenFill(big_rect),
|
|
ShowCreation(softmax_arrow),
|
|
FadeIn(softmax_label, lag_ratio=0.1),
|
|
),
|
|
MoveToTarget(prob_group, time_span=(1, 2))
|
|
)
|
|
prob_group.unfix_from_frame()
|
|
prob_group.match_height(rhs)
|
|
prob_group.next_to(softmax_arrow, RIGHT, buff=0.25)
|
|
self.wait()
|
|
|
|
# Ask about other vectors
|
|
rects = VGroup(*(
|
|
SurroundingRectangle(emb)
|
|
for emb in layer.embeddings[:-1]
|
|
))
|
|
rects.set_stroke(PINK, 3)
|
|
rects.set_fill(PINK, 0.25)
|
|
question = Text("What about these?", font_size=90)
|
|
question.next_to(layer, DOWN, buff=2.0)
|
|
question_arrows = VGroup()
|
|
for rect in rects:
|
|
question_arrows.add(Arrow(
|
|
question.get_top(), rect.get_bottom(),
|
|
stroke_color=rect.get_color()
|
|
))
|
|
|
|
self.play(
|
|
frame.animate.reorient(-1, 0, 0, (1.69, 1.57, 29.79), 15.88),
|
|
Write(question),
|
|
run_time=2,
|
|
)
|
|
|
|
last_rect = VGroup()
|
|
for rect, arrow in zip(rects, question_arrows):
|
|
self.play(
|
|
FadeOut(last_rect),
|
|
FadeIn(rect),
|
|
FadeIn(arrow),
|
|
)
|
|
last_rect = rect
|
|
self.play(FadeOut(last_rect))
|
|
self.wait()
|
|
|
|
# Move back
|
|
self.play(
|
|
FadeOut(question_arrows, lag_ratio=0.1),
|
|
FadeOut(question, lag_ratio=0.1),
|
|
frame.animate.reorient(0, 1, 0, (4.27, 3.46, 29.83), 12.86),
|
|
run_time=2,
|
|
)
|
|
# self.play(
|
|
# LaggedStartMap(FadeOut, VGroup(
|
|
# *question_arrows, question,
|
|
# matrix, vector, eq, rhs,
|
|
# big_rect, softmax_arrow, prob_group, words, softmax_label,
|
|
# last_vector_rect, last_vect_arrow,
|
|
# )),
|
|
# FadeOut(question_arrows, lag_ratio=0.1),
|
|
# FadeOut(question, lag_ratio=0.1),
|
|
# frame.animate.reorient(-2, 0, 0, (2.72, 1.82, 29.44), 10.31),
|
|
# run_time=2,
|
|
# )
|
|
self.wait()
|
|
|
|
# Name the unembedding matrix
|
|
matrix_rect = SurroundingRectangle(matrix, buff=0.1)
|
|
matrix_rect.set_stroke(BLUE, 3)
|
|
name = Text("Unembedding\nmatrix", font_size=60)
|
|
label = Tex("W_U", font_size=90)
|
|
name.next_to(matrix_rect, LEFT)
|
|
label.next_to(name, DOWN)
|
|
|
|
self.play(
|
|
Write(matrix_rect, stroke_width=5, stroke_color=BLUE),
|
|
Write(name, run_time=2),
|
|
frame.animate.reorient(0, 1, 0, (4.24, 3.15, 29.81), 13.23),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
FadeIn(label, DOWN),
|
|
name.animate.next_to(label, UP, buff=1)
|
|
)
|
|
self.wait()
|
|
|
|
# Data flying
|
|
data_modifying_matrix(self, matrix)
|
|
self.wait()
|
|
|
|
# Count parameters
|
|
label.set_backstroke(BLACK, 6)
|
|
entries = VGroup(*matrix.elements, *matrix.ellipses)
|
|
row_rects = VGroup(*map(SurroundingRectangle, matrix.get_rows()))
|
|
col_rects = VGroup(*map(SurroundingRectangle, matrix.get_columns()))
|
|
VGroup(row_rects, col_rects).set_stroke(GREY, 1).set_fill(GREY_B, 0.5)
|
|
|
|
left_brace = Brace(matrix, LEFT)
|
|
top_brace = Brace(matrix, UP)
|
|
vocab_count = Integer(50257, font_size=90)
|
|
vocab_count.next_to(left_brace, LEFT)
|
|
dim_count = Integer(12288, font_size=90)
|
|
dim_count.next_to(top_brace, UP)
|
|
|
|
top_equation = VGroup(
|
|
Text("Total parameters = "),
|
|
Integer(vocab_count.get_value()),
|
|
Tex(R"\times"),
|
|
Integer(dim_count.get_value()),
|
|
Tex("="),
|
|
Integer(vocab_count.get_value() * dim_count.get_value()).set_color(YELLOW),
|
|
)
|
|
top_equation.arrange(RIGHT)
|
|
top_equation.scale(2)
|
|
top_equation.next_to(dim_count, UP, buff=1)
|
|
top_equation.align_to(vocab_count, LEFT)
|
|
|
|
self.play(
|
|
label.animate.move_to(matrix),
|
|
FadeOut(name, lag_ratio=0.1),
|
|
entries.animate.set_fill(opacity=0.5, border_width=0),
|
|
)
|
|
self.add(row_rects, label)
|
|
self.play(
|
|
GrowFromCenter(left_brace, time_span=(0, 1)),
|
|
CountInFrom(vocab_count),
|
|
LaggedStartMap(VFadeInThenOut, row_rects, lag_ratio=0.2),
|
|
run_time=2.5,
|
|
)
|
|
self.add(col_rects, label)
|
|
self.play(
|
|
GrowFromCenter(top_brace, time_span=(0, 1)),
|
|
CountInFrom(dim_count),
|
|
LaggedStartMap(VFadeInThenOut, col_rects, lag_ratio=0.2),
|
|
frame.animate.reorient(0, 0, 0, (5.66, 4.08, 29.89), 14.32),
|
|
run_time=2.5,
|
|
)
|
|
self.wait()
|
|
self.play(LaggedStart(
|
|
Write(top_equation[0:6:2]),
|
|
TransformFromCopy(vocab_count, top_equation[1]),
|
|
TransformFromCopy(dim_count, top_equation[3]),
|
|
frame.animate.reorient(0, -1, 0, (4.15, 5.07, 29.75), 17.21),
|
|
run_time=3
|
|
))
|
|
self.play(
|
|
FadeTransform(top_equation[1:4].copy(), top_equation[-1])
|
|
)
|
|
self.wait()
|
|
|
|
|
|
class FlowForMLPIntroReview(SimplifiedFlow):
|
|
example_text = "That which does not kill you only makes you stronger"
|
|
x_range = np.linspace(-2, -4, 3)
|
|
frame_height_growth_rate = 0.3
|
|
possible_next_tokens = [
|
|
("stronger", 0.906),
|
|
("stranger", 0.028),
|
|
("more", 0.006),
|
|
("weaker", 0.003),
|
|
("...", 0.003),
|
|
("Strong", 0.002),
|
|
("wish", 0.002),
|
|
("STR", 0.002),
|
|
]
|
|
|
|
def construct(self):
|
|
# Initial flow
|
|
self.camera.light_source.set_z(20)
|
|
self.show_initial_text_embedding(word_scale_factor=0.6)
|
|
self.play(self.frame.animate.scale(1.25))
|
|
self.show_simple_flow(self.x_range)
|
|
|
|
# Show how it ends
|
|
self.remove_mlps()
|
|
self.mention_repetitions()
|
|
self.frame.clear_updaters()
|
|
self.focus_on_last_layer()
|
|
self.play(self.frame.animate.scale(1.15).shift(RIGHT))
|
|
self.show_unembedding()
|
|
|
|
|
|
class FlowForCHM(SimplifiedFlow):
|
|
example_text = "Down by the river bank ... until they jumped into the pond"
|
|
use_words = True
|
|
x_range = np.linspace(-2, -4, 3)
|
|
frame_height_growth_rate = 0.3
|
|
|
|
def construct(self):
|
|
# Initial flow
|
|
self.camera.light_source.set_z(20)
|
|
self.show_initial_text_embedding(word_scale_factor=0.6)
|
|
self.play(self.frame.animate.scale(1.25))
|
|
self.show_simple_flow(self.x_range)
|
|
|
|
# Show how it ends
|
|
self.remove_mlps()
|
|
self.mention_repetitions()
|
|
self.frame.clear_updaters()
|
|
self.focus_on_last_layer()
|
|
self.play(self.frame.animate.scale(1.15).shift(RIGHT))
|
|
self.show_unembedding()
|
|
|
|
def progress_through_mlp_block(self, *args, **kwargs):
|
|
kwargs.update(
|
|
label_text="Feed Forward",
|
|
title_font_size=92,
|
|
)
|
|
super().progress_through_mlp_block(*args, **kwargs)
|
|
|
|
|
|
class FlowForCHM2(FlowForCHM):
|
|
example_text = "The Computer History Museum is located in Mountain"
|
|
possible_next_tokens = [
|
|
("Mountain", 0.706),
|
|
("the", 0.128),
|
|
("San", 0.078),
|
|
("Seattle", 0.006),
|
|
("New", 0.003),
|
|
("Palo", 0.002),
|
|
("Google", 0.002),
|
|
("southern", 0.002),
|
|
]
|
|
|
|
def show_initial_text_embedding(self, word_scale_factor=0.6, bump_first=False):
|
|
super().show_initial_text_embedding(word_scale_factor=0.5)
|
|
|
|
|
|
class TextToNumerical(SimplifiedFlow):
|
|
example_text = "Text must be encoded as numbers blah"
|
|
use_words = True
|
|
|
|
def construct(self):
|
|
# Test
|
|
self.show_initial_text_embedding(word_scale_factor=0.75)
|
|
|
|
|
|
class FlowForCHMNoText(FlowForCHM2):
|
|
def get_block(self, *args, **kwargs):
|
|
kwargs.update(title="")
|
|
return super().get_block(*args, **kwargs)
|
|
|
|
|
|
class TextPassageIntro(InteractiveScene):
|
|
example_text = MentionContextSizeAndUnembedding.example_text
|
|
|
|
def construct(self):
|
|
# Read in passage
|
|
passage_str = Path(DATA_DIR, "harry_potter_3.txt").read_text()
|
|
passage_str = passage_str.replace("\n", "\n\\\\")
|
|
passage = TexText(passage_str, alignment="", additional_preamble=R"\tiny")
|
|
passage.set_height(FRAME_HEIGHT - 1)
|
|
passage[-len("Snape"):].set_opacity(0)
|
|
|
|
# Initial surroundings
|
|
frame = self.frame
|
|
lh, rh = (1176, 1540)
|
|
section = passage[lh:rh].copy()
|
|
section.save_state()
|
|
section.set_width(FRAME_WIDTH - 1)
|
|
section.center()
|
|
|
|
word_lh, word_rh = (150, 155)
|
|
word = section[word_lh:word_rh].copy()
|
|
word.save_state()
|
|
word.set_height(0.75).center()
|
|
|
|
self.play(Write(word))
|
|
self.wait()
|
|
self.play(
|
|
Restore(word, time_span=(0, 0.5), remover=True),
|
|
ShowIncreasingSubsets(section),
|
|
run_time=1
|
|
)
|
|
self.play(ContextAnimation(word, section))
|
|
self.wait()
|
|
self.play(
|
|
Restore(section, remover=True),
|
|
ShowIncreasingSubsets(VGroup(*passage[:lh], *passage[rh:]))
|
|
)
|
|
self.add(passage)
|
|
self.play(ContextAnimation(
|
|
passage[lh:rh][word_lh:word_rh],
|
|
VGroup(
|
|
*passage[0:11],
|
|
*passage[211:217],
|
|
# *passage[2366:2374],
|
|
),
|
|
lag_ratio=0.01
|
|
))
|
|
self.wait()
|
|
|
|
# Compress
|
|
start, end = self.example_text.split(" ... ")
|
|
short_text = Text(self.example_text)
|
|
short_text["Snape"].set_opacity(0)
|
|
short_text.set_max_width(FRAME_WIDTH - 1)
|
|
dots = short_text["..."][0]
|
|
|
|
lh, rh = (31, 2474)
|
|
self.play(
|
|
FadeTransformPieces(
|
|
passage[:lh],
|
|
short_text[start][0],
|
|
),
|
|
ReplacementTransform(passage[lh:rh], dots),
|
|
FadeTransformPieces(
|
|
passage[rh:],
|
|
short_text[end][0],
|
|
),
|
|
run_time=3
|
|
)
|
|
self.add(short_text)
|
|
self.wait()
|
|
|
|
|
|
class ThumbnailBase(HighLevelNetworkFlow):
|
|
block_to_title_direction = LEFT
|
|
def construct(self):
|
|
# Add blocks
|
|
self.show_initial_text_embedding(word_scale_factor=0.6)
|
|
for x in range(4):
|
|
self.progress_through_attention_block(
|
|
depth=1.5,
|
|
label_text="Att"
|
|
)
|
|
self.progress_through_mlp_block(
|
|
depth=2.0,
|
|
label_text="MLP"
|
|
)
|
|
self.frame.reorient(-45, -10, 0, (-2.9, 0.95, 24.44), 14.34)
|
|
|
|
# Fade title
|
|
for n, block in enumerate(self.blocks):
|
|
block.title.set_opacity(0)
|
|
block.body.set_opacity(0.5)
|
|
self.remove(*self.layers[:-1])
|
|
|
|
self.mention_repetitions()
|
|
self.rep_label[0].apply_depth_test()
|
|
self.rep_label[1:].set_opacity(0)
|
|
self.add(self.blocks)
|
|
|
|
example_text = "The initials GPT stand for Generative Pre-trained Transformer"
|
|
|
|
|
|
class MoleExample1(HighLevelNetworkFlow):
|
|
block_to_title_direction = LEFT
|
|
highlighted_group_index = 1
|
|
|
|
def construct(self):
|
|
# Show three phrases
|
|
phrase_strs = [
|
|
"American shrew mole",
|
|
"One mole of carbon dioxide",
|
|
"Take a biopsy of the mole",
|
|
]
|
|
phrases = VGroup(map(Text, phrase_strs))
|
|
phrases.arrange(DOWN, buff=2.0)
|
|
phrases.move_to(0.25 * DOWN)
|
|
|
|
self.play(Write(phrases[0]), run_time=1)
|
|
self.wait()
|
|
for i in [1, 2]:
|
|
self.play(
|
|
Transform(phrases[i - 1]["mole"].copy(), phrases[i]["mole"].copy(), remover=True),
|
|
FadeIn(phrases[i], lag_ratio=0.1)
|
|
)
|
|
self.wait()
|
|
|
|
# Add mole images
|
|
images = Group(
|
|
ImageMobject("ShrewMole").set_height(1),
|
|
Tex(R"6.02 \times 10^{23}").set_color(TEAL),
|
|
ImageMobject("LipMole").set_height(1),
|
|
)
|
|
braces = VGroup()
|
|
mole_words = VGroup()
|
|
for image, phrase in zip(images, phrases):
|
|
mole_word = phrase["mole"][0]
|
|
brace = Brace(mole_word, UP, SMALL_BUFF)
|
|
image.next_to(brace, UP, SMALL_BUFF)
|
|
braces.add(brace)
|
|
mole_words.add(mole_word)
|
|
|
|
self.play(
|
|
LaggedStartMap(GrowFromCenter, braces, lag_ratio=0.5),
|
|
LaggedStartMap(FadeIn, images, shift=UP, lag_ratio=0.5),
|
|
mole_words.animate.set_color(YELLOW).set_anim_args(lag_ratio=0.1),
|
|
)
|
|
self.wait()
|
|
|
|
# Subdivide
|
|
word_groups = VGroup()
|
|
for phrase in phrases:
|
|
words = break_into_words(phrase.copy())
|
|
rects = get_piece_rectangles(
|
|
words, leading_spaces=False, h_buff=0.05
|
|
)
|
|
word_group = VGroup(VGroup(*pair) for pair in zip(rects, words))
|
|
word_groups.add(word_group)
|
|
|
|
self.play(
|
|
FadeIn(word_groups),
|
|
LaggedStartMap(FadeOut, braces, shift=0.25 * DOWN, lag_ratio=0.25),
|
|
LaggedStartMap(FadeOut, images, shift=0.25 * DOWN, lag_ratio=0.25),
|
|
run_time=1
|
|
)
|
|
self.remove(phrases)
|
|
self.wait()
|
|
|
|
# Divide into three regions
|
|
for group, sign in zip(word_groups, [-1, 0, 1]):
|
|
group.target = group.generate_target()
|
|
group.target.scale(0.75)
|
|
group.target.set_x(sign * FRAME_WIDTH / 3)
|
|
group.target.to_edge(UP)
|
|
|
|
v_lines = Line(UP, DOWN).replicate(2)
|
|
v_lines.set_height(FRAME_HEIGHT)
|
|
v_lines.arrange(RIGHT, buff=FRAME_WIDTH / 3)
|
|
v_lines.center()
|
|
v_lines.set_stroke(GREY_B, 1)
|
|
|
|
self.play(
|
|
LaggedStartMap(MoveToTarget, word_groups),
|
|
ShowCreation(v_lines, lag_ratio=0.5, time_span=(1, 2))
|
|
)
|
|
|
|
# Show vector embeddings
|
|
embs = VGroup()
|
|
arrows = VGroup()
|
|
seed_array = np.random.uniform(0, 10, 7)
|
|
for group in word_groups:
|
|
for word in group:
|
|
arrow = Vector(0.5 * DOWN)
|
|
arrow.next_to(word, DOWN, SMALL_BUFF)
|
|
size = sum(len(m.get_points()) for m in word.family_members_with_points())
|
|
values = (seed_array * size % 10)
|
|
emb = NumericEmbedding(values=values)
|
|
emb.set_height(2)
|
|
emb.next_to(arrow, DOWN, SMALL_BUFF)
|
|
|
|
arrows.add(arrow)
|
|
embs.add(emb)
|
|
mole_indices = [2, 4, 13]
|
|
non_mole_indices = [n for n in range(len(embs)) if n not in mole_indices]
|
|
|
|
mole_vect_rects = VGroup(
|
|
SurroundingRectangle(embs[index])
|
|
for index in mole_indices
|
|
)
|
|
mole_vect_rects.set_stroke(YELLOW, 2)
|
|
|
|
self.play(
|
|
LaggedStartMap(GrowArrow, arrows),
|
|
LaggedStartMap(FadeIn, embs, shift=0.25 * DOWN),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
LaggedStartMap(ShowCreation, mole_vect_rects),
|
|
VGroup(arrows[j] for j in non_mole_indices).animate.set_fill(opacity=0.5),
|
|
VGroup(embs[j] for j in non_mole_indices).animate.set_fill(opacity=0.5),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
FadeOut(mole_vect_rects)
|
|
)
|
|
|
|
# Prepare to pass through an attention block
|
|
wg_lens = [len(wg) for wg in word_groups]
|
|
indices = [0, *np.cumsum(wg_lens)]
|
|
full_groups = VGroup(
|
|
VGroup(wg, arrows[i:j], embs[i:j])
|
|
for wg, i, j in zip(word_groups, indices, indices[1:])
|
|
)
|
|
highlighted_group = full_groups[self.highlighted_group_index]
|
|
fade_groups = [fg for n, fg in enumerate(full_groups) if n != self.highlighted_group_index]
|
|
highlighted_group.target = highlighted_group.generate_target()
|
|
highlighted_group.target.scale(1.5, about_edge=UP)
|
|
highlighted_group.target.space_out_submobjects(1.1)
|
|
highlighted_group.target.center()
|
|
highlighted_group.target[2].set_fill(opacity=1)
|
|
|
|
self.play(
|
|
FadeOut(v_lines, time_span=(0, 1)),
|
|
MoveToTarget(highlighted_group, lag_ratio=5e-4),
|
|
*(
|
|
FadeOut(
|
|
fg,
|
|
shift=fg.get_center() - highlighted_group.get_center() + 2 * DOWN,
|
|
lag_ratio=1e-3
|
|
)
|
|
for fg in fade_groups
|
|
),
|
|
run_time=2
|
|
)
|
|
self.wait()
|
|
|
|
# Pass through attention
|
|
layer = VGroup(highlighted_group[2])
|
|
layer.embeddings = highlighted_group[2]
|
|
self.layers.set_submobjects([])
|
|
self.layers.add(layer)
|
|
|
|
self.progress_through_attention_block(target_frame_x=-2)
|
|
self.wait()
|