More misc. animations for transformers

This commit is contained in:
Grant Sanderson 2024-03-25 19:09:46 -03:00
parent f7ce17bba6
commit c37b90d662
5 changed files with 1538 additions and 71 deletions

File diff suppressed because it is too large Load diff

View file

@ -99,6 +99,9 @@ def get_word_to_vec_model(model_name="glove-wiki-gigaword-50"):
return model
# For chapter 1
class LyingAboutTokens2(InteractiveScene):
def construct(self):
# Mention next word prediction task
@ -320,7 +323,6 @@ class SoundTokens(InteractiveScene):
self.wait()
class IntroduceEmbeddingMatrix(InteractiveScene):
def construct(self):
# Load words
@ -603,6 +605,7 @@ class Word2VecScene(InteractiveScene):
height=8,
depth=6.4,
)
label_rotation = PI / 2
# embedding_model = "word2vec-google-news-300"
embedding_model = "glove-wiki-gigaword-50"
@ -668,6 +671,7 @@ class Word2VecScene(InteractiveScene):
label_text=word if func_name is None else f"{func_name}({word})",
buff=0,
direction=direction,
label_rotation=self.label_rotation,
**label_config,
)
@ -1790,19 +1794,20 @@ class DotProducts(InteractiveScene):
dual_rotate(75, -95, run_time=8)
class DotProductWithGenderDirection(InteractiveScene):
vec_tex = R"\vec{\text{gen}}"
ref_words = ["man", "woman"]
class DotProductWithPluralDirection(InteractiveScene):
vec_tex = R"\vec{\text{plur}}"
ref_words = ["cat", "cats"]
words = [
"mother", "father",
"aunt", "uncle",
"sister", "brother",
"mama", "papa",
"octopus", "octopi",
"puppy", "puppies",
"student", "students",
"one", "two", "three", "four",
"single", "multiple",
]
x_range = (-5, 7 + 1e-4, 0.25)
x_range = (-8, 5 + 1e-4, 0.25)
colors = [BLUE, RED]
threshold = -1.0
number_line_y = -1.5
threshold = 1.0
def construct(self):
# Initialize equation
@ -1842,12 +1847,16 @@ class DotProductWithGenderDirection(InteractiveScene):
longer_tick_multiple=2.5,
width=12
)
# number_line.rotate(PI / 2)
number_line.add_numbers(
np.arange(*x_range[:2]),
num_decimal_places=1,
font_size=30,
# direction=LEFT
)
number_line.move_to(self.number_line_y * UP)
# number_line.to_edge(LEFT, buff=1.0)
eq_rhs = self.get_equation_rhs(eq_lhs, words[0])
equation = VGroup(eq_lhs, eq_rhs)
low_brace = Brace(equation, DOWN)
@ -1948,19 +1957,18 @@ class DotProductWithGenderDirection(InteractiveScene):
)
class DotProductWithPluralityDirection(DotProductWithGenderDirection):
vec_tex = R"\vec{\text{plur}}"
ref_words = ["cat", "cats"]
class DotProductWithGenderDirection(DotProductWithPluralDirection):
vec_tex = R"\vec{\text{gen}}"
ref_words = ["man", "woman"]
words = [
"octopus", "octopi",
"puppy", "puppies",
"student", "students",
"one", "two", "three", "four",
"single", "multiple",
"mother", "father",
"aunt", "uncle",
"sister", "brother",
"mama", "papa",
]
x_range = (-8, 5 + 1e-4, 0.25)
x_range = (-5, 7 + 1e-4, 0.25)
colors = [BLUE, RED]
threshold = -1.0
threshold = 1.0
class RicherEmbedding(InteractiveScene):
@ -2110,3 +2118,244 @@ class RicherEmbedding(InteractiveScene):
result = VGroup(vect, text)
return result
# For chapter 2
class MultipleMoleEmbeddings(Word2VecScene):
    """
    Show one generic "mole" vector, then three differently colored vectors
    (each paired with an image or formula) that it is refined toward,
    switching between them as the on-screen phrase changes.
    """
    # NOTE(review): the indentation of this class was reconstructed from a
    # flattened diff view; confirm loop-body boundaries against the repository.
    default_frame_orientation = (0, 0)
    # Overrides the PI / 2 default declared on Word2VecScene.
    label_rotation = 0

    def setup(self):
        """Run the parent setup, then switch to an xz floor plane with ambient rotation."""
        super().setup()
        self.set_floor_plane("xz")
        self.frame.add_ambient_rotation()
        self.add_plane()
        # Rotate both the plane and the axes by -90 degrees about RIGHT.
        for mob in [self.plane, self.axes]:
            mob.rotate(-90 * DEGREES, RIGHT)

    def construct(self):
        """Play the three-part animation: generic vector, refined meanings, context switching."""
        # Show generic mole embedding
        frame = self.frame
        frame.reorient(-6, -6, 0, (-0.73, 1.29, -0.57), 5.27)
        phrases = VGroup(map(Text, [
            "American shrew mole",
            "One mole of carbon dioxide",
            "Take a biopsy of the mole",
        ]))
        # These two calls act on the whole group, so they are done once here
        # instead of once per loop iteration (the result is the same).
        # NOTE(review): if each phrase was meant to sit in the corner on its
        # own, this should be phrase.to_corner(UL) inside the loop - confirm.
        phrases.fix_in_frame()
        phrases.to_corner(UL)
        for phrase in phrases:
            phrase["mole"][0].set_color(YELLOW)

        gen_vector = self.get_labeled_vector("mole", coords=(-2, 1.0, 1.5))
        curr_phrase = phrases[1]
        mover = curr_phrase["mole"][0]
        mover.set_backstroke(BLACK, 4)

        self.add(curr_phrase)
        self.wait()
        self.play(
            GrowArrow(gen_vector),
            TransformFromCopy(mover, gen_vector.label),
        )
        self.wait(10)

        # Show three refined meanings
        images = Group(
            ImageMobject("ShrewMole"),
            Tex(R"6.02 \times 10^{23}", font_size=24).set_color(BLUE),
            ImageMobject("LipMole"),
        )
        # images[::2] selects the two ImageMobjects and skips the Tex in the middle.
        for image in images[::2]:
            image.set_height(0.5)
            image.set_opacity(0.75)

        colors = [GREY_BROWN, BLUE, ORANGE]
        ref_vects = VGroup(
            self.get_labeled_vector("", coords=coords)
            for coords in [
                (-1.0, -1.5, 1.5),
                (-4.0, 0.5, 1.0),
                (-0.5, 1.0, 2.5),
            ]
        )
        for vect, image, color in zip(ref_vects, images, colors):
            vect.set_color(color)
            image.next_to(vect.get_end(), UP, SMALL_BUFF)

        gen_vect_group = VGroup(gen_vector, gen_vector.label)

        self.play(
            frame.animate.reorient(-30, -5, 0, (-1.11, 1.35, -0.72), 5.27),
            LaggedStart(
                (TransformFromCopy(gen_vector, ref_vect)
                for ref_vect in ref_vects),
                lag_ratio=0.25,
                run_time=2,
            ),
            LaggedStart(
                (FadeInFromPoint(image, gen_vector.label.get_center())
                for image in images),
                lag_ratio=0.25,
                run_time=2,
                group_type=Group,
            ),
            gen_vect_group.animate.set_opacity(0.25).set_anim_args(run_time=2),
            run_time=2,
        )
        self.wait(3)

        # Pair each refined vector with its image so they can be dimmed together.
        ref_vect_groups = Group(
            Group(*pair) for pair in zip(ref_vects, images)
        )

        # Oscillate between meanings based on context
        diff_vects = VGroup(
            Arrow(gen_vector.get_end(), ref_vect.get_end(), buff=0)
            for ref_vect in ref_vects
        )
        diff_vects.set_color(GREY_B)

        last_phrase = curr_phrase
        last_diff = VGroup()
        for n, diff in enumerate(diff_vects):
            # Target state: everything dimmed except the n-th vector/image pair.
            ref_vect_groups.target = ref_vect_groups.generate_target()
            ref_vect_groups.target.set_opacity(0.2)
            ref_vect_groups.target[n].set_opacity(1)
            if n != 2:
                # The third pair's image is dimmed further (0.1 rather than 0.2).
                ref_vect_groups.target[2][1].set_opacity(0.1)
            phrase = phrases[n]
            self.play(
                gen_vect_group.animate.set_opacity(1),
                MoveToTarget(ref_vect_groups),
                FadeOut(last_phrase, UP),
                FadeIn(phrase, UP),
                FadeOut(last_diff)
            )
            self.play(
                ShowCreation(diff, time_span=(1, 2)),
                TransformFromCopy(gen_vector, ref_vects[n], time_span=(1, 2)),
                ContextAnimation(
                    phrase["mole"][0], phrase,
                    direction=DOWN,
                    fix_in_frame=True,
                ),
            )
            self.wait(3)

            last_phrase = phrase
            last_diff = diff

        self.wait(5)

    def get_basis(self, model):
        """Return the parent's basis multiplied by 2, with its index-2 entry negated."""
        basis = super().get_basis(model) * 2
        basis[2] *= -1
        return basis
class RefineTowerMeaning(MultipleMoleEmbeddings):
    """
    Three-stage scene: a vector for "Tower", then one for "Eiffel Tower",
    then one for "Miniature Eiffel Tower", each shown with matching images.
    """
    # NOTE(review): indentation was reconstructed from a flattened diff view;
    # the waits after each FadeTransform are assumed to be inside their loops.

    def construct(self):
        # Set up vectors and images
        camera_frame = self.frame
        camera_frame.reorient(-26, -4, 0, (3.27, 1.57, 0.59), 5.28)
        camera_frame.add_ambient_rotation(0.5 * DEGREES)

        # One Text per word, pinned to the screen along the top edge.
        words = VGroup(
            Text(token)
            for token in "Miniature Eiffel Tower".split(" ")
        )
        words.scale(1.25)
        words.to_edge(UP)
        words.fix_in_frame()

        tower_images = Group(ImageMobject(f"Tower{k}") for k in range(1, 5))
        eiffel_tower_images = Group(ImageMobject(f"EiffelTower{k}") for k in range(1, 4))
        mini_eiffel_tower_images = Group(ImageMobject("MiniEiffelTower1"))
        image_groups = Group(
            tower_images,
            eiffel_tower_images,
            mini_eiffel_tower_images,
        )

        vector_coords = [
            (4, -1, 3.0),
            (5, -2, 1.5),
            (-3, -1, 2.5),
        ]
        vectors = VGroup(
            self.get_labeled_vector("", coords=point)
            for point in vector_coords
        )
        palette = [BLUE_D, GREY_B, GREY_C]
        for vector, shade, image_set in zip(vectors, palette, image_groups):
            vector.set_color(shade)
            # Images sit beside the tip, on the side given by the sign of its x-coordinate.
            tip = vector.get_end()
            side = RIGHT * np.sign(tip[0])
            for picture in image_set:
                picture.set_height(1.5)
                picture.next_to(tip, side)

        # Show tower
        tower = words[-1]
        tower.set_x(0)
        # Invisible copy of the first image laid over the word, used as the transform source.
        pre_tower_image = tower_images[0].copy()
        pre_tower_image.fix_in_frame()
        pre_tower_image.replace(tower, stretch=True)
        pre_tower_image.set_opacity(0)

        self.add(tower)
        self.wait()
        self.play(
            GrowArrow(vectors[0]),
            ReplacementTransform(pre_tower_image, tower_images[0]),
            run_time=2,
        )
        for shown, upcoming in zip(tower_images, tower_images[1:]):
            self.play(FadeTransform(shown, upcoming), run_time=2)
            self.wait(2)

        # Eiffel tower
        words[:-1].set_opacity(0)
        eiffel_tower = words[-2:]
        self.play(
            camera_frame.animate.reorient(-4, -7, 0, (2.95, 1.82, 0.49), 6.59),
            eiffel_tower.animate.set_opacity(1).arrange(RIGHT, aligned_edge=DOWN).to_edge(UP),
        )
        last_tower = tower_images[-1]
        self.play(
            vectors[0].animate.set_opacity(0.25),
            last_tower.animate.set_opacity(0.2),
            TransformFromCopy(vectors[0], vectors[1]),
            FadeTransform(last_tower.copy(), eiffel_tower_images[0]),
            ContextAnimation(words[2], words[1], direction=DOWN, fix_in_frame=True),
            run_time=2,
        )
        for shown, upcoming in zip(eiffel_tower_images, eiffel_tower_images[1:]):
            self.play(FadeTransform(shown, upcoming), run_time=2)
            self.wait(2)

        # Miniature eiffel tower
        self.play(
            camera_frame.animate.reorient(-14, -2, 0, (-0.12, 2.21, 0.72), 7.05).set_anim_args(run_time=2),
            words.animate.set_opacity(1).arrange(RIGHT, aligned_edge=DOWN).to_edge(UP),
        )
        last_eiffel = eiffel_tower_images[-1]
        self.play(
            vectors[1].animate.set_opacity(0.25),
            last_eiffel.animate.set_opacity(0.2),
            TransformFromCopy(vectors[1], vectors[2]),
            FadeTransform(last_eiffel.copy(), mini_eiffel_tower_images[0]),
            ContextAnimation(words[2], words[0], direction=DOWN, fix_in_frame=True),
            run_time=2,
        )
        self.wait(10)

View file

@ -1681,8 +1681,6 @@ class DistinguishWeightsAndData(InteractiveScene):
v_line.to_edge(UP, buff=0)
v_line.set_stroke(GREY_A, 2)
self.add(titles)
# Set up matrices
matrices = VGroup(
WeightMatrix(

View file

@ -23,6 +23,7 @@ class HighLevelNetworkFlow(InteractiveScene):
(" Ludwig", 0.0104),
]
hide_block_labels = False
block_to_title_direction = UP
def setup(self):
super().setup()
@ -153,7 +154,7 @@ class HighLevelNetworkFlow(InteractiveScene):
title = Text(title, font_size=title_font_size)
title.set_backstroke(BLACK, title_backstroke_width)
title.next_to(body, UP, buff=0.1)
title.next_to(body, self.block_to_title_direction, buff=0.1)
block = Group(body, title)
block.body = body
block.title = title
@ -221,7 +222,6 @@ class HighLevelNetworkFlow(InteractiveScene):
self.play(LaggedStartMap(VFadeInThenOut, arrows, lag_ratio=0.25, run_time=4))
self.play(FadeOut(token_label, DOWN))
# Show words into vectors
layer = self.get_embedding_array(
shape=(len(words), 10),
@ -1165,4 +1165,169 @@ class TextPassageIntro(InteractiveScene):
run_time=3
)
self.add(short_text)
self.wait()
self.wait()
class MoleExample1(HighLevelNetworkFlow):
    """
    Show three phrases containing the word "mole", split them into words with
    an embedding column under each, then send one phrase's embeddings through
    an attention block.
    """
    # NOTE(review): indentation was reconstructed from a flattened diff view;
    # confirm loop-body boundaries (e.g. the wait after each phrase fade-in).

    # Overrides the UP default on HighLevelNetworkFlow, which uses this
    # direction when placing a block's title next to its body.
    block_to_title_direction = LEFT
    # Which of the three phrase groups is kept and enlarged before the attention step.
    highlighted_group_index = 1

    def construct(self):
        # Show three phrases
        phrase_strs = [
            "American shrew mole",
            "One mole of carbon dioxide",
            "Take a biopsy of the mole",
        ]
        phrases = VGroup(map(Text, phrase_strs))
        phrases.arrange(DOWN, buff=2.0)
        phrases.move_to(0.25 * DOWN)

        self.play(Write(phrases[0]), run_time=1)
        self.wait()
        for i in [1, 2]:
            # Copies are transformed (and removed afterwards) so the phrases themselves are untouched.
            self.play(
                Transform(phrases[i - 1]["mole"].copy(), phrases[i]["mole"].copy(), remover=True),
                FadeIn(phrases[i], lag_ratio=0.1)
            )
            self.wait()

        # Add mole images
        images = Group(
            ImageMobject("ShrewMole").set_height(1),
            Tex(R"6.02 \times 10^{23}").set_color(TEAL),
            ImageMobject("LipMole").set_height(1),
        )
        braces = VGroup()
        mole_words = VGroup()
        for image, phrase in zip(images, phrases):
            # Brace over the word "mole", with the matching image above the brace.
            mole_word = phrase["mole"][0]
            brace = Brace(mole_word, UP, SMALL_BUFF)
            image.next_to(brace, UP, SMALL_BUFF)
            braces.add(brace)
            mole_words.add(mole_word)

        self.play(
            LaggedStartMap(GrowFromCenter, braces, lag_ratio=0.5),
            LaggedStartMap(FadeIn, images, shift=UP, lag_ratio=0.5),
            mole_words.animate.set_color(YELLOW).set_anim_args(lag_ratio=0.1),
        )
        self.wait()

        # Subdivide
        word_groups = VGroup()
        for phrase in phrases:
            # Work on a copy; the original phrases are removed from the scene below.
            words = break_into_words(phrase.copy())
            rects = get_piece_rectangles(
                words, leading_spaces=False, h_buff=0.05
            )
            # Each entry is a (rectangle, word) pair.
            word_group = VGroup(VGroup(*pair) for pair in zip(rects, words))
            word_groups.add(word_group)

        self.play(
            FadeIn(word_groups),
            LaggedStartMap(FadeOut, braces, shift=0.25 * DOWN, lag_ratio=0.25),
            LaggedStartMap(FadeOut, images, shift=0.25 * DOWN, lag_ratio=0.25),
            run_time=1
        )
        self.remove(phrases)
        self.wait()

        # Divide into three regions
        for group, sign in zip(word_groups, [-1, 0, 1]):
            # One third of the frame width left of center, centered, and right of center.
            group.target = group.generate_target()
            group.target.scale(0.75)
            group.target.set_x(sign * FRAME_WIDTH / 3)
            group.target.to_edge(UP)

        v_lines = Line(UP, DOWN).replicate(2)
        v_lines.set_height(FRAME_HEIGHT)
        v_lines.arrange(RIGHT, buff=FRAME_WIDTH / 3)
        v_lines.center()
        v_lines.set_stroke(GREY_B, 1)

        self.play(
            LaggedStartMap(MoveToTarget, word_groups),
            ShowCreation(v_lines, lag_ratio=0.5, time_span=(1, 2))
        )

        # Show vector embeddings
        embs = VGroup()
        arrows = VGroup()
        seed_array = np.random.uniform(0, 10, 7)
        for group in word_groups:
            for word in group:
                arrow = Vector(0.5 * DOWN)
                arrow.next_to(word, DOWN, SMALL_BUFF)
                # Values depend only on the shared seed array and this word's total
                # point count, so equal point counts give equal values within a run.
                size = sum(len(m.get_points()) for m in word.family_members_with_points())
                values = (seed_array * size % 10)
                emb = NumericEmbedding(values=values)
                emb.set_height(2)
                emb.next_to(arrow, DOWN, SMALL_BUFF)
                arrows.add(arrow)
                embs.add(emb)
        # Positions of "mole" in the flat arrows/embs lists.
        # NOTE(review): assumes break_into_words yields one piece per
        # space-separated word (3 + 5 + 6 pieces) - confirm.
        mole_indices = [2, 4, 13]
        non_mole_indices = [n for n in range(len(embs)) if n not in mole_indices]

        mole_vect_rects = VGroup(
            SurroundingRectangle(embs[index])
            for index in mole_indices
        )
        mole_vect_rects.set_stroke(YELLOW, 2)

        # Copies every local name into module globals, so local names here are
        # externally visible. NOTE(review): presumably a workaround so the
        # generator expressions below resolve names when run interactively - confirm.
        globals().update(locals())
        self.play(
            LaggedStartMap(GrowArrow, arrows),
            LaggedStartMap(FadeIn, embs, shift=0.25 * DOWN),
        )
        self.wait()
        self.play(
            LaggedStartMap(ShowCreation, mole_vect_rects),
            VGroup(arrows[j] for j in non_mole_indices).animate.set_fill(opacity=0.5),
            VGroup(embs[j] for j in non_mole_indices).animate.set_fill(opacity=0.5),
        )
        self.wait()
        self.play(
            FadeOut(mole_vect_rects)
        )

        # Prepare to pass through an attention block
        wg_lens = [len(wg) for wg in word_groups]
        # Slice boundaries into the flat arrows/embs lists, one span per phrase.
        indices = [0, *np.cumsum(wg_lens)]
        full_groups = VGroup(
            VGroup(wg, arrows[i:j], embs[i:j])
            for wg, i, j in zip(word_groups, indices, indices[1:])
        )
        highlighted_group = full_groups[self.highlighted_group_index]
        fade_groups = [fg for n, fg in enumerate(full_groups) if n != self.highlighted_group_index]

        highlighted_group.target = highlighted_group.generate_target()
        highlighted_group.target.scale(1.5, about_edge=UP)
        highlighted_group.target.space_out_submobjects(1.1)
        highlighted_group.target.center()
        # Index 2 is this group's embeddings; restore full fill after the earlier 0.5 dimming.
        highlighted_group.target[2].set_fill(opacity=1)

        # Second export of locals to module globals (see the note on the first call).
        globals().update(locals())
        self.play(
            FadeOut(v_lines, time_span=(0, 1)),
            MoveToTarget(highlighted_group, lag_ratio=5e-4),
            *(
                FadeOut(
                    fg,
                    shift=fg.get_center() - highlighted_group.get_center() + 2 * DOWN,
                    lag_ratio=1e-3
                )
                for fg in fade_groups
            ),
            run_time=2
        )
        self.wait()

        # Pass through attention
        layer = VGroup(highlighted_group[2])
        layer.embeddings = highlighted_group[2]
        # Replace whatever self.layers held with this single layer.
        # NOTE(review): self.layers and progress_through_attention_block are
        # presumably provided by HighLevelNetworkFlow - not visible here.
        self.layers.set_submobjects([])
        self.layers.add(layer)
        self.progress_through_attention_block(target_frame_x=-2)
        self.wait()

View file

@ -460,9 +460,10 @@ class GamePlan(InteractiveScene):
prev_thumbnails.set_width(FRAME_WIDTH - 2)
prev_thumbnails.move_to(2 * UP)
new_thumbnails = Group( # TODO, give these images
Rectangle().set_stroke(width=0).set_fill(BLACK, 1)
for vid in tr_vids
tn_dir = "/Users/grant/3Blue1Brown Dropbox/3Blue1Brown/videos/2024/transformers/Thumbnails/"
new_thumbnails = Group(
ImageMobject(os.path.join(tn_dir, f"Chapter{n}"))
for n in range(5, 8)
)
for tn1, tn2 in zip(prev_thumbnails, new_thumbnails):
tn2.replace(tn1, stretch=True)
@ -567,6 +568,37 @@ class SeaOfNumbersUnderlay(TeacherStudentsScene):
self.wait(8)
class Outdated(TeacherStudentsScene):
    """Teacher holds up a GPT-3 label; a student asks whether it is outdated."""
    # NOTE(review): indentation was reconstructed from a flattened diff view.

    def construct(self):
        # Add label
        model_name = Text("GPT-3", font="Consolas", font_size=72)
        logo = SVGMobject("OpenAI.svg").set_fill(WHITE)
        # Logo is twice as tall as the model name.
        logo.set_height(2.0 * model_name.get_height())

        gpt3_label = VGroup(logo, model_name)
        gpt3_label.arrange(RIGHT)
        gpt3_label.scale(0.75)

        # The parameter count is placed under the already-scaled logo/name row
        # and only then added to the label.
        param_count = Text("175B Parameters").set_color(BLUE)
        param_count.next_to(gpt3_label, DOWN, aligned_edge=LEFT)
        gpt3_label.add(param_count)
        gpt3_label.move_to(self.hold_up_spot, DOWN)

        teacher = self.teacher
        teacher.body.insert_n_curves(100)
        self.play(
            teacher.change("raise_right_hand"),
            FadeIn(gpt3_label, UP),
        )
        self.play(
            self.change_students("raise_left_hand", "hesitant", "sassy")
        )

        question = TexText("Isn't that outdated?")
        self.play(self.students[0].says(question))
        self.wait(3)
class ConfusionAtScreen(TeacherStudentsScene):
def construct(self):
self.play(