mirror of
https://github.com/3b1b/videos.git
synced 2025-08-31 21:58:59 +00:00
4093 lines
142 KiB
Python
4093 lines
142 KiB
Python
from __future__ import annotations
|
|
|
|
from manim_imports_ext import *
|
|
from _2024.transformers.helpers import *
|
|
from _2024.transformers.embedding import break_into_words
|
|
from _2024.transformers.embedding import break_into_tokens
|
|
from _2024.transformers.embedding import get_piece_rectangles
|
|
|
|
|
|
class AttentionPatterns(InteractiveScene):
|
|
def construct(self):
|
|
# Add sentence
|
|
phrase = " a fluffy blue creature roamed the verdant forest"
|
|
phrase_mob = Text(phrase)
|
|
phrase_mob.move_to(2 * UP)
|
|
words = list(filter(lambda s: s.strip(), phrase.split(" ")))
|
|
word2mob: Dict[str, VMobject] = {
|
|
word: phrase_mob[" " + word][0]
|
|
for word in words
|
|
}
|
|
word_mobs = VGroup(*word2mob.values())
|
|
|
|
self.play(
|
|
LaggedStartMap(FadeIn, word_mobs, shift=0.5 * UP, lag_ratio=0.25)
|
|
)
|
|
self.wait()
|
|
|
|
# Create word rects
|
|
word2rect: Dict[str, VMobject] = dict()
|
|
for word in words:
|
|
rect = SurroundingRectangle(word2mob[word])
|
|
rect.set_height(phrase_mob.get_height() + SMALL_BUFF, stretch=True)
|
|
rect.match_y(phrase_mob)
|
|
rect.set_stroke(GREY, 2)
|
|
rect.set_fill(GREY, 0.2)
|
|
word2rect[word] = rect
|
|
|
|
# Adjectives updating noun
|
|
adjs = ["fluffy", "blue", "verdant"]
|
|
nouns = ["creature", "forest"]
|
|
others = ["a", "roamed", "the"]
|
|
adj_mobs, noun_mobs, other_mobs = [
|
|
VGroup(word2mob[substr] for substr in group)
|
|
for group in [adjs, nouns, others]
|
|
]
|
|
adj_rects, noun_rects, other_rects = [
|
|
VGroup(word2rect[substr] for substr in group)
|
|
for group in [adjs, nouns, others]
|
|
]
|
|
adj_rects.set_submobject_colors_by_gradient(BLUE_C, BLUE_D, GREEN)
|
|
noun_rects.set_color(GREY_BROWN).set_stroke(width=3)
|
|
kw = dict()
|
|
adj_arrows = VGroup(
|
|
Arrow(
|
|
adj_mobs[i].get_top(), noun_mobs[j].get_top(),
|
|
path_arc=-150 * DEGREES, buff=0.1, stroke_color=GREY_B
|
|
)
|
|
for i, j in [(0, 0), (1, 0), (2, 1)]
|
|
)
|
|
|
|
self.play(
|
|
LaggedStartMap(DrawBorderThenFill, adj_rects),
|
|
Animation(adj_mobs),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
LaggedStartMap(DrawBorderThenFill, noun_rects),
|
|
Animation(noun_mobs),
|
|
LaggedStartMap(ShowCreation, adj_arrows, lag_ratio=0.2, run_time=1.5),
|
|
)
|
|
kw = dict(time_width=2, max_stroke_width=10, lag_ratio=0.2, path_arc=150 * DEGREES)
|
|
self.play(
|
|
ContextAnimation(noun_mobs[0], adj_mobs[:2], strengths=[1, 1], **kw),
|
|
ContextAnimation(noun_mobs[1], adj_mobs[2:], strengths=[1], **kw),
|
|
)
|
|
self.wait()
|
|
|
|
# Show embeddings
|
|
all_rects = VGroup(*adj_rects, *noun_rects, *other_rects)
|
|
all_rects.sort(lambda p: p[0])
|
|
embeddings = VGroup(
|
|
NumericEmbedding(length=10).set_width(0.5).next_to(rect, DOWN, buff=1.5)
|
|
for rect in all_rects
|
|
)
|
|
emb_arrows = VGroup(
|
|
Arrow(all_rects[0].get_bottom(), embeddings[0].get_top()).match_x(rect)
|
|
for rect in all_rects
|
|
)
|
|
for index, vect in [(5, LEFT), (6, RIGHT)]:
|
|
embeddings[index].shift(0.1 * vect)
|
|
emb_arrows[index].shift(0.05 * vect)
|
|
|
|
self.play(
|
|
FadeIn(other_rects),
|
|
Animation(word_mobs),
|
|
LaggedStartMap(GrowArrow, emb_arrows),
|
|
LaggedStartMap(FadeIn, embeddings, shift=0.5 * DOWN),
|
|
FadeOut(adj_arrows)
|
|
)
|
|
self.wait()
|
|
|
|
# Mention dimension of embedding
|
|
frame = self.frame
|
|
brace = Brace(embeddings[0], LEFT, buff=SMALL_BUFF)
|
|
dim_value = Integer(12288)
|
|
dim_value.next_to(brace, LEFT)
|
|
dim_value.set_color(YELLOW)
|
|
|
|
self.play(
|
|
GrowFromCenter(brace),
|
|
CountInFrom(dim_value, 0),
|
|
frame.animate.move_to(LEFT)
|
|
)
|
|
self.wait()
|
|
|
|
# Ingest meaning and position
|
|
images = Group(
|
|
ImageMobject(f"Dalle3_{word}").set_height(1.1).next_to(word2rect[word], UP)
|
|
for word in ["fluffy", "blue", "creature", "verdant", "forest"]
|
|
)
|
|
image_vects = VGroup(embeddings[i] for i in [1, 2, 3, 6, 7])
|
|
|
|
self.play(
|
|
LaggedStartMap(FadeIn, images, scale=2, lag_ratio=0.05)
|
|
)
|
|
self.play(
|
|
LaggedStart(
|
|
(self.bake_mobject_into_vector_entries(image, vect, group_type=Group)
|
|
for image, vect in zip(images, image_vects)),
|
|
group_type=Group,
|
|
lag_ratio=0.2,
|
|
run_time=4,
|
|
remover=True
|
|
),
|
|
)
|
|
self.wait()
|
|
self.add(embeddings, images)
|
|
|
|
# Show positions
|
|
pos_labels = VGroup(
|
|
Integer(n, font_size=36).next_to(rect, DOWN, buff=0.1)
|
|
for n, rect in enumerate(all_rects, start=1)
|
|
)
|
|
pos_labels.set_color(TEAL)
|
|
|
|
self.play(
|
|
LaggedStart(
|
|
(arrow.animate.scale(0.7, about_edge=DOWN)
|
|
for arrow in emb_arrows),
|
|
lag_ratio=0.1,
|
|
),
|
|
LaggedStartMap(FadeIn, pos_labels, shift=0.25 * DOWN, lag_ratio=0.1)
|
|
)
|
|
self.play(
|
|
LaggedStart(
|
|
(self.bake_mobject_into_vector_entries(pos, vect)
|
|
for pos, vect in zip(pos_labels, embeddings)),
|
|
lag_ratio=0.2,
|
|
run_time=4,
|
|
remover=True
|
|
),
|
|
)
|
|
self.wait()
|
|
|
|
# Collapse vectors
|
|
template = Tex(R"\vec{\textbf{E}}_{0}")
|
|
template[0].scale(1.5, about_edge=DOWN)
|
|
dec = template.make_number_changeable(0)
|
|
emb_syms = VGroup()
|
|
for n, rect in enumerate(all_rects, start=1):
|
|
dec.set_value(n)
|
|
sym = template.copy()
|
|
sym.next_to(rect, DOWN, buff=0.75)
|
|
sym.set_color(GREY_A)
|
|
emb_syms.add(sym)
|
|
for subgroup in [emb_syms[:4], emb_syms[4:]]:
|
|
subgroup.arrange_to_fit_width(subgroup.get_width())
|
|
|
|
emb_arrows.target = emb_arrows.generate_target()
|
|
|
|
for rect, arrow, sym in zip(all_rects, emb_arrows.target, emb_syms):
|
|
x_min = rect.get_x(LEFT)
|
|
x_max = rect.get_x(RIGHT)
|
|
low_point = sym[0].get_top()
|
|
if x_min < low_point[0] < x_max:
|
|
top_point = np.array([low_point[0], rect.get_y(DOWN), 0])
|
|
else:
|
|
top_point = rect.get_bottom()
|
|
arrow.become(Arrow(top_point, low_point, buff=SMALL_BUFF))
|
|
|
|
all_brackets = VGroup(emb.get_brackets() for emb in embeddings)
|
|
for brackets in all_brackets:
|
|
brackets.target = brackets.generate_target()
|
|
brackets.target.stretch(0, 1, about_edge=UP)
|
|
brackets.target.set_fill(opacity=0)
|
|
|
|
ghost_syms = emb_syms.copy()
|
|
ghost_syms.set_opacity(0)
|
|
|
|
self.play(
|
|
frame.animate.set_x(0).set_anim_args(run_time=2),
|
|
LaggedStart(
|
|
(AnimationGroup(
|
|
LaggedStart(
|
|
(FadeTransform(entry, sym)
|
|
for entry in embedding.get_columns()[0]),
|
|
lag_ratio=0.01,
|
|
group_type=Group
|
|
),
|
|
MoveToTarget(brackets),
|
|
group_type=Group,
|
|
)
|
|
for sym, embedding, brackets in zip(ghost_syms, embeddings, all_brackets)),
|
|
group_type=Group
|
|
),
|
|
LaggedStartMap(FadeIn, emb_syms, shift=UP),
|
|
brace.animate.stretch(0.25, 1, about_edge=UP).set_opacity(0),
|
|
FadeOut(dim_value, 0.25 * UP),
|
|
MoveToTarget(emb_arrows, lag_ratio=0.1, run_time=2),
|
|
LaggedStartMap(FadeOut, pos_labels, shift=UP),
|
|
)
|
|
emb_arrows.refresh_bounding_box(recurse_down=True) # Why?
|
|
self.clear()
|
|
self.add(emb_arrows, all_rects, word_mobs, images, emb_syms)
|
|
self.wait()
|
|
|
|
# Preview desired updates
|
|
emb_sym_primes = VGroup(
|
|
sym.copy().add(Tex("'").move_to(sym.get_corner(UR) + 0.05 * DL))
|
|
for sym in emb_syms
|
|
)
|
|
emb_sym_primes.shift(2 * DOWN)
|
|
emb_sym_primes.set_color(TEAL)
|
|
|
|
full_connections = VGroup()
|
|
for i, sym1 in enumerate(emb_syms, start=1):
|
|
for j, sym2 in enumerate(emb_sym_primes, start=1):
|
|
line = Line(sym1.get_bottom(), sym2.get_top(), buff=SMALL_BUFF)
|
|
line.set_stroke(GREY_B, width=random.random()**2, opacity=random.random()**0.25)
|
|
if (i, j) in [(2, 4), (3, 4), (4, 4), (7, 8), (8, 8)]:
|
|
line.set_stroke(WHITE, width=2 + random.random(), opacity=1)
|
|
full_connections.add(line)
|
|
|
|
blue_fluff = ImageMobject("BlueFluff")
|
|
verdant_forest = ImageMobject("VerdantForest")
|
|
for n, image in [(3, blue_fluff), (7, verdant_forest)]:
|
|
image.match_height(images)
|
|
image.scale(1.2)
|
|
image.next_to(emb_sym_primes[n], DOWN, buff=MED_SMALL_BUFF)
|
|
|
|
self.play(
|
|
ShowCreation(full_connections, lag_ratio=0.01, run_time=2),
|
|
LaggedStart(
|
|
(TransformFromCopy(sym1, sym2)
|
|
for sym1, sym2 in zip(emb_syms, emb_sym_primes)),
|
|
lag_ratio=0.05,
|
|
),
|
|
)
|
|
self.wait()
|
|
self.play(LaggedStart(
|
|
LaggedStart(
|
|
(FadeTransform(im.copy(), blue_fluff, remover=True)
|
|
for im in images[:3]),
|
|
lag_ratio=0.02,
|
|
group_type=Group
|
|
),
|
|
LaggedStart(
|
|
(FadeTransform(im.copy(), verdant_forest, remover=True)
|
|
for im in images[3:]),
|
|
lag_ratio=0.02,
|
|
group_type=Group
|
|
),
|
|
lag_ratio=0.5,
|
|
run_time=2
|
|
))
|
|
self.add(blue_fluff, verdant_forest)
|
|
self.wait()
|
|
|
|
# Show black box that matrix multiples can be added to
|
|
in_arrows = VGroup(
|
|
Vector(0.25 * DOWN, max_width_to_length_ratio=12.0).next_to(sym, DOWN, SMALL_BUFF)
|
|
for sym in emb_syms
|
|
)
|
|
box = Rectangle(15.0, 3.0)
|
|
box.set_fill(GREY_E, 1)
|
|
box.set_stroke(WHITE, 1)
|
|
box.next_to(in_arrows, DOWN, SMALL_BUFF)
|
|
out_arrows = in_arrows.copy()
|
|
out_arrows.next_to(box, DOWN)
|
|
|
|
self.play(
|
|
FadeIn(box, 0.25 * DOWN),
|
|
LaggedStartMap(FadeIn, in_arrows, shift=0.25 * DOWN, lag_ratio=0.025),
|
|
LaggedStartMap(FadeIn, out_arrows, shift=0.25 * DOWN, lag_ratio=0.025),
|
|
FadeOut(full_connections),
|
|
emb_sym_primes.animate.next_to(out_arrows, DOWN, SMALL_BUFF),
|
|
MaintainPositionRelativeTo(blue_fluff, emb_sym_primes),
|
|
MaintainPositionRelativeTo(verdant_forest, emb_sym_primes),
|
|
frame.animate.set_height(10).move_to(4 * UP, UP),
|
|
)
|
|
self.wait()
|
|
|
|
# Clear the board
|
|
self.play(
|
|
frame.animate.set_height(8).move_to(2 * UP).set_anim_args(run_time=1.5),
|
|
LaggedStartMap(FadeOut, Group(
|
|
*images, in_arrows, box, out_arrows, emb_sym_primes,
|
|
blue_fluff, verdant_forest,
|
|
), lag_ratio=0.1)
|
|
)
|
|
|
|
# Ask questions
|
|
word_groups = VGroup(VGroup(*pair) for pair in zip(all_rects, word_mobs))
|
|
for group in word_groups:
|
|
group.save_state()
|
|
q_bubble = SpeechBubble("Any adjectives\nin front of me?")
|
|
q_bubble.move_tip_to(word2rect["creature"].get_top())
|
|
|
|
a_bubbles = SpeechBubble("I am!", direction=RIGHT).replicate(2)
|
|
a_bubbles[1].flip()
|
|
a_bubbles[0].move_tip_to(word2rect["fluffy"].get_top())
|
|
a_bubbles[1].move_tip_to(word2rect["blue"].get_top())
|
|
|
|
self.play(
|
|
FadeIn(q_bubble),
|
|
word_groups[:3].animate.fade(0.75),
|
|
word_groups[4:].animate.fade(0.75),
|
|
)
|
|
self.wait()
|
|
self.play(LaggedStart(
|
|
Restore(word_groups[1]),
|
|
Restore(word_groups[2]),
|
|
*map(Write, a_bubbles),
|
|
lag_ratio=0.5
|
|
))
|
|
self.wait()
|
|
|
|
# Associate questions with vectors
|
|
a_bubbles.save_state()
|
|
q_arrows = VGroup(
|
|
Vector(0.75 * DOWN).next_to(sym, DOWN, SMALL_BUFF)
|
|
for sym in emb_syms
|
|
)
|
|
q_vects = VGroup(
|
|
NumericEmbedding(length=7).set_height(2).next_to(arrow, DOWN)
|
|
for arrow in q_arrows
|
|
)
|
|
question = q_bubble.content
|
|
|
|
|
|
index = words.index("creature")
|
|
q_vect = q_vects[index]
|
|
q_arrow = q_arrows[index]
|
|
self.play(LaggedStart(
|
|
FadeOut(q_bubble.body, DOWN),
|
|
question.animate.scale(0.75).next_to(q_vect, RIGHT),
|
|
FadeIn(q_vect, DOWN),
|
|
GrowArrow(q_arrow),
|
|
frame.animate.move_to(ORIGIN),
|
|
a_bubbles.animate.fade(0.5),
|
|
))
|
|
self.play(
|
|
self.bake_mobject_into_vector_entries(question, q_vect)
|
|
)
|
|
self.wait()
|
|
|
|
# Label query vector
|
|
brace = Brace(q_vect, LEFT, SMALL_BUFF)
|
|
query_word = Text("Query")
|
|
query_word.set_color(YELLOW)
|
|
query_word.next_to(brace, LEFT, SMALL_BUFF)
|
|
dim_text = Text("128-dimensional", font_size=36)
|
|
dim_text.set_color(YELLOW)
|
|
dim_text.next_to(brace, LEFT, SMALL_BUFF)
|
|
dim_text.set_y(query_word.get_y(DOWN))
|
|
|
|
self.play(
|
|
GrowFromCenter(brace),
|
|
FadeIn(query_word, 0.25 * LEFT),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
query_word.animate.next_to(dim_text, UP, SMALL_BUFF),
|
|
FadeIn(dim_text, 0.1 * DOWN),
|
|
)
|
|
self.wait()
|
|
|
|
# Show individual matrix product
|
|
e_vect = NumericEmbedding(length=12)
|
|
e_vect.match_width(q_vect)
|
|
e_vect.next_to(q_vect, DR, buff=1.5)
|
|
matrix = WeightMatrix(shape=(7, 12))
|
|
matrix.match_height(q_vect)
|
|
matrix.next_to(e_vect, LEFT)
|
|
e_label_copy = emb_syms[index].copy()
|
|
e_label_copy.next_to(e_vect, UP)
|
|
q_vect.save_state()
|
|
ghost_q_vect = NumericEmbedding(length=7).match_height(q_vect)
|
|
ghost_q_vect.get_columns().set_opacity(0)
|
|
ghost_q_vect.get_brackets().space_out_submobjects(1.75)
|
|
ghost_q_vect.next_to(e_vect, RIGHT, buff=0.7)
|
|
|
|
mat_brace = Brace(matrix, UP)
|
|
mat_label = Tex("W_Q")
|
|
mat_label.next_to(mat_brace, UP, SMALL_BUFF)
|
|
mat_label.set_color(YELLOW)
|
|
|
|
self.play(
|
|
frame.animate.set_height(11).move_to(all_rects, UP).shift(0.35 * UP),
|
|
FadeOut(a_bubbles),
|
|
FadeInFromPoint(e_vect, emb_syms[index].get_center()),
|
|
FadeInFromPoint(matrix, q_arrow.get_center()),
|
|
TransformFromCopy(emb_syms[index], e_label_copy),
|
|
FadeOut(q_vect),
|
|
TransformFromCopy(q_vect, ghost_q_vect),
|
|
MaintainPositionRelativeTo(question, q_vect),
|
|
)
|
|
self.play(
|
|
GrowFromCenter(mat_brace),
|
|
FadeIn(mat_label, 0.1 * UP),
|
|
)
|
|
self.remove(ghost_q_vect)
|
|
eq, rhs = show_matrix_vector_product(self, matrix, e_vect)
|
|
|
|
new_q_vect = rhs.deepcopy()
|
|
new_q_vect.move_to(q_vect, LEFT)
|
|
|
|
self.play(
|
|
TransformFromCopy(rhs, new_q_vect, path_arc=PI / 2),
|
|
question.animate.next_to(new_q_vect, RIGHT)
|
|
)
|
|
self.wait()
|
|
|
|
# Collapse query vector
|
|
q_sym_template = Tex(R"\vec{\textbf{Q}}_0", font_size=48)
|
|
q_sym_template[0].scale(1.5, about_edge=DOWN)
|
|
q_sym_template.set_color(YELLOW)
|
|
subscript = q_sym_template.make_number_changeable(0)
|
|
q_syms = VGroup()
|
|
for n, arrow in enumerate(q_arrows, start=1):
|
|
subscript.set_value(n)
|
|
sym = q_sym_template.copy()
|
|
sym.next_to(arrow, DOWN, SMALL_BUFF)
|
|
q_syms.add(sym)
|
|
|
|
mat_label2 = mat_label.copy()
|
|
|
|
q_sym = q_syms[index]
|
|
low_q_sym = q_sym.copy()
|
|
low_q_sym.next_to(rhs, UP)
|
|
|
|
self.play(LaggedStart(
|
|
LaggedStart(
|
|
(FadeTransform(entry, q_sym, remover=True)
|
|
for entry in new_q_vect.get_columns()[0]),
|
|
lag_ratio=0.01,
|
|
group_type=Group,
|
|
),
|
|
new_q_vect.get_brackets().animate.stretch(0, 1, about_edge=UP).set_opacity(0),
|
|
FadeOutToPoint(query_word, q_sym.get_center()),
|
|
FadeOutToPoint(dim_text, q_sym.get_center()),
|
|
FadeOut(brace),
|
|
question.animate.next_to(q_sym, DOWN),
|
|
FadeIn(low_q_sym, UP),
|
|
lag_ratio=0.1,
|
|
))
|
|
self.remove(new_q_vect)
|
|
self.add(q_sym)
|
|
self.play(
|
|
mat_label2.animate.scale(0.9).next_to(q_arrow, RIGHT, buff=0.15),
|
|
)
|
|
self.wait()
|
|
|
|
# E to Q rects
|
|
e_rects = VGroup(map(SurroundingRectangle, [emb_syms[index], e_vect]))
|
|
q_rects = VGroup(map(SurroundingRectangle, [q_sym, rhs]))
|
|
e_rects.set_stroke(TEAL, 3)
|
|
q_rects.set_stroke(YELLOW, 3)
|
|
self.play(ShowCreation(e_rects, lag_ratio=0.2))
|
|
self.wait()
|
|
self.play(Transform(e_rects, q_rects))
|
|
self.wait()
|
|
self.play(FadeOut(e_rects))
|
|
|
|
# Add other query vectors
|
|
remaining_q_arrows = VGroup(*q_arrows[:index], *q_arrows[index + 1:])
|
|
remaining_q_syms = VGroup(*q_syms[:index], *q_syms[index + 1:])
|
|
wq_syms = VGroup(
|
|
Tex(R"W_Q", font_size=30).next_to(arrow, RIGHT, buff=0.1)
|
|
for arrow in q_arrows
|
|
)
|
|
wq_syms.set_color(YELLOW)
|
|
subscripts = VGroup(e_label_copy[-1], low_q_sym[-1][0])
|
|
for subscript in subscripts:
|
|
i_sym = Tex("i")
|
|
i_sym.replace(subscript)
|
|
i_sym.scale(0.75)
|
|
i_sym.match_style(subscript)
|
|
subscript.target = i_sym
|
|
|
|
self.play(
|
|
LaggedStartMap(GrowArrow, remaining_q_arrows),
|
|
LaggedStartMap(FadeIn, remaining_q_syms, shift=0.1 * DOWN),
|
|
ReplacementTransform(VGroup(mat_label2), wq_syms, lag_ratio=0.01, run_time=2),
|
|
question.animate.shift(0.25 * DOWN),
|
|
*map(Restore, word_groups),
|
|
*map(MoveToTarget, subscripts),
|
|
)
|
|
self.wait()
|
|
|
|
# Emphasize model weights
|
|
self.play(
|
|
LaggedStartMap(FlashAround, matrix.get_entries(), lag_ratio=1e-2),
|
|
RandomizeMatrixEntries(matrix),
|
|
)
|
|
data_modifying_matrix(self, matrix, word_shape=(3, 8))
|
|
self.wait()
|
|
self.play(
|
|
LaggedStartMap(FadeOut, VGroup(
|
|
matrix, mat_brace, mat_label,
|
|
e_vect, e_label_copy, eq, rhs,
|
|
low_q_sym
|
|
), shift=0.2 * DR)
|
|
)
|
|
self.wait()
|
|
|
|
# Move question
|
|
noun_q_syms = VGroup(q_syms[words.index(word)] for word in ["creature", "forest"])
|
|
|
|
self.play(
|
|
question.animate.shift(0.25 * DOWN).match_x(noun_q_syms)
|
|
)
|
|
|
|
noun_q_lines = VGroup(
|
|
Line(question.get_corner(v), sym.get_corner(-v) + 0.25 * v)
|
|
for sym, v in zip(noun_q_syms, [UL, UR])
|
|
)
|
|
noun_q_lines.set_stroke(GREY, 1)
|
|
self.play(ShowCreation(noun_q_lines, lag_ratio=0))
|
|
self.wait()
|
|
|
|
# Set up keys
|
|
key_word_groups = word_groups.copy()
|
|
key_word_groups.arrange(DOWN, buff=0.75, aligned_edge=RIGHT)
|
|
key_word_groups.next_to(q_syms, DL, buff=LARGE_BUFF)
|
|
key_word_groups.shift(3.0 * LEFT)
|
|
key_emb_syms = emb_syms.copy()
|
|
|
|
k_sym_template = Tex(R"\vec{\textbf{K}}_0", font_size=48)
|
|
k_sym_template[0].scale(1.5, about_edge=DOWN)
|
|
k_sym_template.set_color(TEAL)
|
|
subscript = k_sym_template.make_number_changeable(0)
|
|
|
|
k_syms = VGroup()
|
|
key_emb_arrows = VGroup()
|
|
wk_arrows = VGroup()
|
|
wk_syms = VGroup()
|
|
for group, emb_sym, n in zip(key_word_groups, key_emb_syms, it.count(1)):
|
|
emb_arrow = Vector(0.5 * RIGHT)
|
|
emb_arrow.next_to(group, RIGHT, SMALL_BUFF)
|
|
emb_sym.next_to(emb_arrow, RIGHT, SMALL_BUFF)
|
|
wk_arrow = Vector(0.75 * RIGHT)
|
|
wk_arrow.next_to(emb_sym, RIGHT)
|
|
wk_sym = Tex("W_k", font_size=30)
|
|
wk_sym.set_fill(TEAL, border_width=1)
|
|
wk_sym.next_to(wk_arrow, UP)
|
|
subscript.set_value(n)
|
|
k_sym = k_sym_template.copy()
|
|
k_sym.next_to(wk_arrow, RIGHT, buff=MED_SMALL_BUFF)
|
|
|
|
key_emb_arrows.add(emb_arrow)
|
|
wk_arrows.add(wk_arrow)
|
|
wk_syms.add(wk_sym)
|
|
k_syms.add(k_sym)
|
|
|
|
self.remove(question, noun_q_lines)
|
|
self.play(
|
|
frame.animate.move_to(2.5 * LEFT + 2.75 * DOWN),
|
|
TransformFromCopy(word_groups, key_word_groups),
|
|
TransformFromCopy(emb_arrows, key_emb_arrows),
|
|
TransformFromCopy(emb_syms, key_emb_syms),
|
|
run_time=2,
|
|
)
|
|
self.play(
|
|
LaggedStartMap(GrowArrow, wk_arrows),
|
|
LaggedStartMap(FadeIn, wk_syms, shift=0.1 * UP),
|
|
)
|
|
self.play(LaggedStart(
|
|
(TransformFromCopy(e_sym, k_sym)
|
|
for e_sym, k_sym in zip(key_emb_syms, k_syms)),
|
|
lag_ratio=0.05,
|
|
))
|
|
self.wait()
|
|
|
|
# Isolate examples
|
|
fade_rects = VGroup(
|
|
BackgroundRectangle(VGroup(key_word_groups[0], wk_syms[0], k_syms[0])),
|
|
BackgroundRectangle(VGroup(key_word_groups[3:], wk_syms[3:], k_syms[3:])),
|
|
BackgroundRectangle(wq_syms[2]),
|
|
BackgroundRectangle(VGroup(word_groups[:3], q_syms[:3])),
|
|
BackgroundRectangle(VGroup(word_groups[4:], q_syms[4:])),
|
|
)
|
|
fade_rects.set_fill(BLACK, 0.75)
|
|
fade_rects.set_stroke(BLACK, 3, 1)
|
|
q_bubble = SpeechBubble("Any adjectives\nin front of me?")
|
|
q_bubble.flip(RIGHT)
|
|
q_bubble.next_to(q_syms[3][-1], DOWN, SMALL_BUFF, LEFT)
|
|
a_bubbles = SpeechBubble("I'm an adjective!\nI'm there!").replicate(2)
|
|
a_bubbles[0].pin_to(k_syms[1])
|
|
a_bubbles[1].pin_to(k_syms[2])
|
|
a_bubbles[1].flip(RIGHT, about_edge=DOWN)
|
|
a_bubbles[1].shift(0.5 * DOWN)
|
|
|
|
self.add(fade_rects, word_groups[3])
|
|
self.play(FadeIn(fade_rects))
|
|
self.play(FadeIn(q_bubble, lag_ratio=0.1))
|
|
self.play(FadeIn(a_bubbles, lag_ratio=0.05))
|
|
self.wait()
|
|
|
|
# Show example key matrix
|
|
matrix = WeightMatrix(shape=(7, 12))
|
|
matrix.set_width(5)
|
|
matrix.next_to(k_syms, UP, buff=2.0, aligned_edge=RIGHT)
|
|
mat_rect = SurroundingRectangle(matrix, buff=MED_SMALL_BUFF)
|
|
lil_rect = SurroundingRectangle(wk_syms[1])
|
|
lines = VGroup(
|
|
Line(lil_rect.get_corner(v + UP), mat_rect.get_corner(v + DOWN))
|
|
for v in [LEFT, RIGHT]
|
|
)
|
|
VGroup(mat_rect, lil_rect, *lines).set_stroke(GREY_A, 1)
|
|
|
|
self.play(ShowCreation(lil_rect))
|
|
self.play(
|
|
ShowCreation(lines, lag_ratio=0),
|
|
TransformFromCopy(lil_rect, mat_rect),
|
|
FadeInFromPoint(matrix, lil_rect.get_center()),
|
|
)
|
|
self.wait()
|
|
data_modifying_matrix(self, matrix, word_shape=(3, 8))
|
|
self.play(
|
|
LaggedStartMap(FadeOut, VGroup(matrix, mat_rect, lines, lil_rect), run_time=1)
|
|
)
|
|
self.play(
|
|
LaggedStartMap(FadeOut, VGroup(q_bubble, *a_bubbles), lag_ratio=0.25)
|
|
)
|
|
self.wait()
|
|
|
|
# Draw grid
|
|
emb_arrows.refresh_bounding_box(recurse_down=True)
|
|
q_groups = VGroup(
|
|
VGroup(group[i] for group in [
|
|
emb_arrows, emb_syms, wq_syms, q_arrows, q_syms
|
|
])
|
|
for i in range(len(emb_arrows))
|
|
)
|
|
q_groups.target = q_groups.generate_target()
|
|
q_groups.target.arrange_to_fit_width(12, about_edge=LEFT)
|
|
q_groups.target.shift(0.25 * DOWN)
|
|
|
|
word_groups.target = word_groups.generate_target()
|
|
for word_group, q_group in zip(word_groups.target, q_groups.target):
|
|
word_group.scale(0.7)
|
|
word_group.next_to(q_group[0], UP, SMALL_BUFF)
|
|
|
|
h_lines = VGroup()
|
|
v_buff = 0.5 * (key_word_groups[0].get_y(DOWN) - key_word_groups[1].get_y(UP))
|
|
for kwg in key_word_groups:
|
|
h_line = Line(LEFT, RIGHT).set_width(20)
|
|
h_line.next_to(kwg, UP, buff=v_buff)
|
|
h_line.align_to(key_word_groups, LEFT)
|
|
h_lines.add(h_line)
|
|
|
|
v_lines = VGroup()
|
|
h_buff = 0.5
|
|
for q_group in q_groups.target:
|
|
v_line = Line(UP, DOWN).set_height(14)
|
|
v_line.next_to(q_group, LEFT, buff=h_buff, aligned_edge=UP)
|
|
v_lines.add(v_line)
|
|
v_lines.add(v_lines[-1].copy().next_to(q_groups.target, RIGHT, 0.5, UP))
|
|
|
|
grid_lines = VGroup(*h_lines, *v_lines)
|
|
grid_lines.set_stroke(GREY_A, 1)
|
|
|
|
self.play(
|
|
frame.animate.set_height(15, about_edge=UP).set_x(-2).set_anim_args(run_time=3),
|
|
MoveToTarget(q_groups),
|
|
MoveToTarget(word_groups),
|
|
ShowCreation(h_lines, lag_ratio=0.2),
|
|
ShowCreation(v_lines, lag_ratio=0.2),
|
|
FadeOut(fade_rects),
|
|
)
|
|
|
|
# Take all dot products
|
|
dot_prods = VGroup()
|
|
for k_sym in k_syms:
|
|
for q_sym in q_syms:
|
|
square_center = np.array([q_sym.get_x(), k_sym.get_y(), 0])
|
|
dot = Tex(R".", font_size=72)
|
|
dot.move_to(square_center)
|
|
dot.set_fill(opacity=0)
|
|
dot_prod = VGroup(k_sym.copy(), dot, q_sym.copy())
|
|
dot_prod.target = dot_prod.generate_target()
|
|
dot_prod.target.arrange(RIGHT, buff=0.15)
|
|
dot_prod.target.scale(0.65)
|
|
dot_prod.target.move_to(square_center)
|
|
dot_prod.target.set_fill(opacity=1)
|
|
dot_prods.add(dot_prod)
|
|
|
|
self.play(
|
|
LaggedStartMap(MoveToTarget, dot_prods, lag_ratio=0.025, run_time=4)
|
|
)
|
|
self.wait()
|
|
|
|
# Show grid of dots
|
|
dots = VGroup(
|
|
VGroup(Dot().match_x(q_sym).match_y(k_sym) for q_sym in q_syms)
|
|
for k_sym in k_syms
|
|
)
|
|
for n, row in enumerate(dots, start=1):
|
|
for k, dot in enumerate(row, start=1):
|
|
dot.set_fill(GREY_C, 0.8)
|
|
dot.set_width(random.random())
|
|
dot.target = dot.generate_target()
|
|
dot.target.set_width(0.1 + 0.2 * random.random())
|
|
if (n, k) in [(2, 4), (3, 4), (7, 8)]:
|
|
dot.target.set_width(0.8 + 0.2 * random.random())
|
|
flat_dots = VGroup(*it.chain(*dots))
|
|
|
|
self.play(
|
|
dot_prods.animate.set_fill(opacity=0.75),
|
|
LaggedStartMap(GrowFromCenter, flat_dots)
|
|
)
|
|
self.wait()
|
|
self.play(LaggedStartMap(MoveToTarget, flat_dots, lag_ratio=0.01))
|
|
self.wait()
|
|
|
|
# Resize to reflect true pattern
|
|
k_groups = VGroup(
|
|
VGroup(group[i] for group in [
|
|
key_word_groups, key_emb_arrows,
|
|
key_emb_syms, wk_syms, wk_arrows, k_syms
|
|
])
|
|
for i in range(len(emb_arrows))
|
|
)
|
|
for q_group, word_group in zip(q_groups, word_groups):
|
|
q_group.add_to_back(word_group)
|
|
self.add(k_groups, q_groups, Point())
|
|
|
|
k_fade_rects = VGroup(map(BackgroundRectangle, k_groups))
|
|
q_fade_rects = VGroup(map(BackgroundRectangle, q_groups))
|
|
for rect in (*k_fade_rects, *q_fade_rects):
|
|
rect.scale(1.05)
|
|
rect.set_fill(BLACK, 0.8)
|
|
|
|
self.play(
|
|
frame.animate.move_to([-4.33, -2.4, 0.0]).set_height(9.52),
|
|
FadeIn(k_fade_rects[:1]),
|
|
FadeIn(k_fade_rects[3:]),
|
|
FadeIn(q_fade_rects[:3]),
|
|
FadeIn(q_fade_rects[4:]),
|
|
run_time=2
|
|
)
|
|
self.wait()
|
|
|
|
k_rects = VGroup(map(SurroundingRectangle, k_groups[1:3]))
|
|
k_rects.set_stroke(TEAL, 2)
|
|
q_rects = VGroup(SurroundingRectangle(q_groups[3]))
|
|
q_rects.set_stroke(YELLOW, 2)
|
|
|
|
self.play(
|
|
ShowCreation(k_rects, lag_ratio=0.5, run_time=2),
|
|
LaggedStartMap(
|
|
FlashAround, k_groups[1:3],
|
|
color=TEAL,
|
|
time_width=2,
|
|
lag_ratio=0.25,
|
|
run_time=3
|
|
),
|
|
)
|
|
self.wait()
|
|
self.play(TransformFromCopy(k_rects, q_rects))
|
|
self.wait()
|
|
|
|
# Show numerical dot product
|
|
high_dot_prods = VGroup(dot_prods[8 + 3], dot_prods[2 * 8 + 3])
|
|
dots_to_grow = VGroup(dots[1][3], dots[2][3])
|
|
numerical_dot_prods = VGroup(
|
|
VGroup(
|
|
DecimalNumber(
|
|
np.random.uniform(-100, 10),
|
|
include_sign=True,
|
|
font_size=42,
|
|
num_decimal_places=1,
|
|
edge_to_fix=ORIGIN,
|
|
).move_to(dot)
|
|
for dot in row
|
|
)
|
|
for row in dots
|
|
)
|
|
for n, row in enumerate(numerical_dot_prods):
|
|
row[n].set_value(5 * random.random()) # Add some self relevance
|
|
flat_numerical_dot_prods = VGroup(*it.chain(*numerical_dot_prods))
|
|
for ndp in flat_numerical_dot_prods:
|
|
ndp.set_fill(interpolate_color(RED_E, GREY_C, random.random()))
|
|
high_numerical_dot_prods = VGroup(
|
|
numerical_dot_prods[1][3],
|
|
numerical_dot_prods[2][3],
|
|
numerical_dot_prods[6][7],
|
|
)
|
|
for hdp in high_numerical_dot_prods:
|
|
hdp.set_value(92 + 2 * random.random())
|
|
hdp.set_color(WHITE)
|
|
low_numerical_dot_prod = numerical_dot_prods[5][3]
|
|
low_numerical_dot_prod.set_value(-31.4)
|
|
low_numerical_dot_prod.set_fill(RED_D)
|
|
|
|
self.play(
|
|
*(dtg.animate.scale(1.25) for dtg in dots_to_grow),
|
|
*(CountInFrom(ndp, run_time=1) for ndp in high_numerical_dot_prods[:2]),
|
|
*(VFadeIn(ndp) for ndp in high_numerical_dot_prods[:2]),
|
|
*(FadeOut(dot_prod, run_time=0.5) for dot_prod in dot_prods),
|
|
)
|
|
self.wait()
|
|
|
|
# Show "attends to"
|
|
att_arrow = Arrow(k_rects.get_top(), q_rects.get_left(), path_arc=-90 * DEGREES)
|
|
att_words = TexText("``Attends to''", font_size=72)
|
|
att_words.next_to(att_arrow.pfp(0.4), UL)
|
|
|
|
self.play(
|
|
ShowCreation(att_arrow),
|
|
Write(att_words),
|
|
)
|
|
self.wait()
|
|
self.play(FadeOut(att_words), FadeOut(att_arrow))
|
|
|
|
# Contrast with "the" and "creature"
|
|
self.play(
|
|
frame.animate.move_to([-2.79, -3.66, 0.0]).set_height(12.29),
|
|
*(k_rect.animate.surround(k_groups[5]) for k_rect in k_rects),
|
|
FadeIn(k_fade_rects[1:3]),
|
|
FadeOut(k_fade_rects[5]),
|
|
run_time=2,
|
|
)
|
|
self.play(
|
|
CountInFrom(low_numerical_dot_prod),
|
|
VFadeIn(low_numerical_dot_prod),
|
|
FadeOut(dots[5][3]),
|
|
)
|
|
self.wait()
|
|
|
|
# Zoom out on full grid
|
|
self.play(
|
|
frame.animate.move_to([-1.5, -4.8, 0.0]).set_height(15).set_anim_args(run_time=3),
|
|
LaggedStart(
|
|
FadeOut(k_rects),
|
|
FadeOut(q_rects),
|
|
FadeOut(k_fade_rects[:5]),
|
|
FadeOut(k_fade_rects[6:]),
|
|
FadeOut(q_fade_rects[:3]),
|
|
FadeOut(q_fade_rects[4:]),
|
|
FadeOut(dots),
|
|
LaggedStartMap(FadeIn, numerical_dot_prods),
|
|
Animation(high_numerical_dot_prods.copy(), remover=True),
|
|
Animation(low_numerical_dot_prod.copy(), remover=True),
|
|
)
|
|
)
|
|
self.wait()
|
|
|
|
# Focus on one column
|
|
ndp_columns = VGroup(
|
|
VGroup(row[i] for row in numerical_dot_prods)
|
|
for i in range(len(numerical_dot_prods[0]))
|
|
)
|
|
col_rect = SurroundingRectangle(ndp_columns[3], buff=0.25)
|
|
col_rect.set_stroke(YELLOW, 2)
|
|
weight_words = Text("We want these to\nact like weights", font_size=96)
|
|
weight_words.set_backstroke(BLACK, 8)
|
|
weight_words.next_to(col_rect, RIGHT, buff=MED_LARGE_BUFF)
|
|
weight_words.match_y(h_lines[2])
|
|
|
|
index = words.index("creature")
|
|
self.play(
|
|
ShowCreation(col_rect),
|
|
grid_lines.animate.set_stroke(opacity=0.5),
|
|
ndp_columns[:index].animate.set_opacity(0.35),
|
|
ndp_columns[index + 1:].animate.set_opacity(0.35),
|
|
FadeIn(weight_words, lag_ratio=0.1)
|
|
)
|
|
self.wait()
|
|
|
|
# Show softmax of each column
|
|
self.set_floor_plane("xz")
|
|
col_arrays = [np.array([num.get_value() for num in col]) for col in ndp_columns]
|
|
softmax_arrays = list(map(softmax, col_arrays))
|
|
softmax_cols = VGroup(
|
|
VGroup(DecimalNumber(v) for v in softmax_array)
|
|
for softmax_array in softmax_arrays
|
|
)
|
|
sm_arrows = VGroup()
|
|
sm_labels = VGroup()
|
|
sm_rects = VGroup()
|
|
for sm_col, col in zip(softmax_cols, ndp_columns):
|
|
for sm_val, val in zip(sm_col, col):
|
|
sm_val.move_to(val)
|
|
sm_col.save_state()
|
|
sm_col.shift(6 * OUT)
|
|
sm_rect = SurroundingRectangle(sm_col)
|
|
sm_rect.match_style(col_rect)
|
|
VGroup(sm_col, sm_rect).rotate(30 * DEGREES, DOWN)
|
|
arrow = Arrow(col, sm_col.get_center() + SMALL_BUFF * RIGHT + IN)
|
|
label = Text("softmax", font_size=72)
|
|
label.set_backstroke(BLACK, 5)
|
|
label.rotate(90 * DEGREES, DOWN)
|
|
label.next_to(arrow, UP)
|
|
sm_arrows.add(arrow)
|
|
sm_labels.add(label)
|
|
sm_rects.add(sm_rect)
|
|
|
|
index = words.index("creature")
|
|
self.play(
|
|
frame.animate.reorient(-47, -7, 0, (-2.48, -5.84, -1.09), 20),
|
|
GrowArrow(sm_arrows[index], time_span=(1, 2)),
|
|
FadeIn(sm_labels[index], lag_ratio=0.1, time_span=(1, 2)),
|
|
TransformFromCopy(ndp_columns[index], softmax_cols[index], time_span=(1.5, 3)),
|
|
TransformFromCopy(col_rect, sm_rects[index], time_span=(1.5, 3)),
|
|
FadeOut(weight_words),
|
|
run_time=3
|
|
)
|
|
self.wait()
|
|
|
|
remaining_indices = [*range(index), *range(index + 1, len(ndp_columns))]
|
|
last_index = index
|
|
for index in remaining_indices:
|
|
self.play(
|
|
ndp_columns[last_index].animate.set_opacity(0.35),
|
|
ndp_columns[index].animate.set_opacity(1),
|
|
col_rect.animate.move_to(ndp_columns[index]),
|
|
softmax_cols[last_index].animate.set_opacity(0.25),
|
|
*map(FadeOut, [sm_rects[last_index], sm_arrows[last_index], sm_labels[last_index]]),
|
|
)
|
|
self.play(
|
|
GrowArrow(sm_arrows[index]),
|
|
FadeIn(sm_labels[index], lag_ratio=0.1),
|
|
TransformFromCopy(ndp_columns[index], softmax_cols[index]),
|
|
TransformFromCopy(col_rect, sm_rects[index]),
|
|
)
|
|
last_index = index
|
|
|
|
self.play(
|
|
FadeOut(col_rect),
|
|
*map(FadeOut, [sm_rects[last_index], sm_arrows[last_index], sm_labels[last_index]]),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
frame.animate.reorient(0, 0, 0, (-2.64, -4.8, 0.0), 14.54),
|
|
LaggedStartMap(Restore, softmax_cols, lag_ratio=0.1),
|
|
FadeOut(ndp_columns, time_span=(0, 1.5)),
|
|
run_time=3,
|
|
)
|
|
self.wait()
|
|
|
|
# Label attention pattern
|
|
for n, row in enumerate(dots):
|
|
if n not in [3, 7]:
|
|
row[n].set_width(0.7 + 0.2 * random.random())
|
|
dots[1][3].set_width(0.6 + 0.1 * random.random())
|
|
dots[2][3].set_width(0.6 + 0.1 * random.random())
|
|
dots[6][7].set_width(0.9 + 0.1 * random.random())
|
|
|
|
pattern_words = Text("Attention\nPattern", font_size=120)
|
|
pattern_words.move_to(grid_lines, UL).shift(LEFT)
|
|
|
|
self.play(
|
|
FadeOut(softmax_cols, lag_ratio=0.001),
|
|
FadeIn(dots, lag_ratio=0.001),
|
|
Write(pattern_words),
|
|
run_time=2
|
|
)
|
|
self.wait()
|
|
|
|
# Preview masking
|
|
masked_dots = VGroup()
|
|
for n, row in enumerate(dots):
|
|
masked_dots.add(*row[:n])
|
|
mask_rects = VGroup()
|
|
for dot in masked_dots:
|
|
mask_rect = Square(0.5)
|
|
mask_rect.set_stroke(RED, 2)
|
|
mask_rect.move_to(dot)
|
|
mask_rects.add(mask_rect)
|
|
|
|
lag_ratio=1.0 / len(mask_rects)
|
|
self.play(ShowCreation(mask_rects, lag_ratio=lag_ratio))
|
|
self.play(
|
|
LaggedStart(
|
|
(dot.animate.scale(0) for dot in masked_dots),
|
|
lag_ratio=lag_ratio
|
|
)
|
|
)
|
|
self.play(
|
|
FadeOut(mask_rects, lag_ratio=lag_ratio)
|
|
)
|
|
self.wait()
|
|
|
|
# Set aside keys and queries
|
|
pattern = VGroup(grid_lines, dots)
|
|
for group in q_groups:
|
|
group.sort(lambda p: -p[1])
|
|
group.target = group.generate_target()
|
|
m3 = len(group) - 3
|
|
group.target[m3:].scale(0, about_edge=DOWN)
|
|
group.target[:m3].move_to(group, DOWN)
|
|
|
|
self.play(
|
|
frame.animate.move_to((-2.09, -5.59, 0.0)).set_height(12.95).set_anim_args(run_time=3),
|
|
LaggedStartMap(MoveToTarget, q_groups),
|
|
FadeOut(pattern_words),
|
|
v_lines.animate.stretch(0.95, 1, about_edge=DOWN),
|
|
)
|
|
self.play(
|
|
LaggedStartMap(FadeOut, k_syms, shift=0.5 * DOWN, lag_ratio=0.1),
|
|
LaggedStartMap(FadeOut, wk_syms, shift=0.5 * DOWN, lag_ratio=0.1),
|
|
)
|
|
self.wait()
|
|
|
|
# Add values
|
|
value_color = RED
|
|
big_wv_sym = Tex(R"W_V", font_size=90)
|
|
big_wv_sym.set_color(value_color)
|
|
big_wv_sym.next_to(h_lines, UP, MED_LARGE_BUFF, LEFT)
|
|
wv_word = Text("Value matrix", font_size=90)
|
|
wv_word.next_to(big_wv_sym, UP, MED_LARGE_BUFF)
|
|
wv_word.set_color(value_color)
|
|
|
|
wv_arrows = wk_arrows
|
|
v_sym_template = Tex(R"\vec{\textbf{V}}_{0}")
|
|
v_sym_template[0].scale(1.5, about_edge=DOWN)
|
|
v_sym_template.set_fill(value_color, border_width=1)
|
|
subscript = v_sym_template.make_number_changeable("0")
|
|
|
|
wv_syms = VGroup()
|
|
v_syms = VGroup()
|
|
for n, arrow in enumerate(wv_arrows, start=1):
|
|
wv_sym = Tex("W_V", font_size=36)
|
|
wv_sym.set_fill(value_color, border_width=1)
|
|
wv_sym.next_to(arrow, UP, buff=0.2, aligned_edge=LEFT)
|
|
subscript.set_value(n)
|
|
v_sym = v_sym_template.copy()
|
|
v_sym.next_to(arrow, RIGHT, MED_SMALL_BUFF)
|
|
|
|
v_syms.add(v_sym)
|
|
wv_syms.add(wv_sym)
|
|
|
|
self.play(
|
|
FadeIn(big_wv_sym, 0.5 * DOWN),
|
|
FadeIn(wv_word, lag_ratio=0.1),
|
|
)
|
|
self.play(
|
|
LaggedStart(
|
|
(TransformFromCopy(big_wv_sym, wv_sym)
|
|
for wv_sym in wv_syms),
|
|
lag_ratio=0.15,
|
|
),
|
|
run_time=3
|
|
)
|
|
self.play(
|
|
LaggedStart(
|
|
(TransformFromCopy(e_sym, v_sym)
|
|
for e_sym, v_sym in zip(key_emb_syms, v_syms)),
|
|
lag_ratio=0.15,
|
|
),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
FadeTransform(v_syms, k_syms),
|
|
FadeTransform(wv_syms, wk_syms),
|
|
rate_func=there_and_back_with_pause,
|
|
run_time=3,
|
|
)
|
|
self.remove(k_syms, wk_syms)
|
|
self.add(v_syms, wv_syms)
|
|
self.wait()
|
|
|
|
# Show column of weights
|
|
index = words.index("creature")
|
|
weighted_sum_cols = VGroup()
|
|
for sm_col in softmax_cols:
|
|
weighted_sum_col = VGroup()
|
|
for weight, v_sym in zip(sm_col, v_syms):
|
|
product = VGroup(weight, v_sym.copy())
|
|
product.target = product.generate_target()
|
|
product.target.arrange(RIGHT)
|
|
product.target[1].shift(UP * (
|
|
product.target[0].get_y(DOWN) -
|
|
product.target[1][1].get_y(DOWN)
|
|
))
|
|
product.target.scale(0.75)
|
|
product.target.move_to(weight)
|
|
product.target.set_fill(
|
|
opacity=clip(0.6 + weight.get_value(), 0, 1)
|
|
)
|
|
weighted_sum_col.add(product)
|
|
weighted_sum_cols.add(weighted_sum_col)
|
|
|
|
self.play(
|
|
FadeOut(dots, lag_ratio=0.1),
|
|
FadeIn(q_fade_rects[:index]),
|
|
FadeIn(q_fade_rects[index + 1:]),
|
|
FadeIn(softmax_cols[index]),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
LaggedStartMap(MoveToTarget, weighted_sum_cols[index])
|
|
)
|
|
self.wait()
|
|
|
|
# Emphasize fluffy and blue weights
|
|
rects = VGroup(
|
|
key_word_groups[i][0].copy()
|
|
for i in [1, 2]
|
|
)
|
|
alt_rects = VGroup(
|
|
SurroundingRectangle(value, buff=SMALL_BUFF)
|
|
for value in (* softmax_cols[index][:1], *softmax_cols[index][3:])
|
|
)
|
|
alt_rects.set_stroke(RED, 1)
|
|
self.play(
|
|
LaggedStart(
|
|
(rect.animate.surround(value)
|
|
for rect, value in zip(rects, softmax_cols[index][1:3])),
|
|
lag_ratio=0.2,
|
|
)
|
|
)
|
|
self.wait()
|
|
self.play(Transform(rects, alt_rects))
|
|
self.wait()
|
|
self.play(FadeOut(rects, lag_ratio=0.1))
|
|
|
|
# Show sum (Start re-rendering here, 151)
|
|
emb_sym = emb_syms[index]
|
|
ws_col = weighted_sum_cols[index]
|
|
creature = images[2]
|
|
creature.set_height(1.5)
|
|
creature.next_to(word_groups[index], UP)
|
|
|
|
emb_sym.target = emb_sym.generate_target()
|
|
emb_sym.target.scale(1.25, about_edge=UP)
|
|
sum_rect = SurroundingRectangle(emb_sym.target)
|
|
sum_rect.set_stroke(YELLOW, 2)
|
|
sum_rect.target = sum_rect.generate_target()
|
|
sum_rect.target.surround(ws_col, buff=MED_SMALL_BUFF)
|
|
plusses = VGroup()
|
|
for m1, m2 in zip(ws_col, ws_col[1:]):
|
|
plus = Tex(R"+", font_size=72)
|
|
plus.move_to(midpoint(m1.get_bottom(), m2.get_top()))
|
|
plusses.add(plus)
|
|
|
|
self.play(
|
|
frame.animate.reorient(0, 0, 0, (-2.6, -4.79, 0.0), 15.07).set_anim_args(run_time=2),
|
|
MoveToTarget(emb_sym),
|
|
ShowCreation(sum_rect),
|
|
FadeIn(creature, UP),
|
|
FadeOut(wv_word),
|
|
FadeOut(big_wv_sym),
|
|
)
|
|
self.add(Point(), q_fade_rects[index + 1:]) # Hack
|
|
self.wait()
|
|
self.play(
|
|
frame.animate.reorient(0, 0, 0, (-2.9, -6.5, 0.0), 19).set_anim_args(run_time=2),
|
|
MoveToTarget(sum_rect, run_time=2),
|
|
Write(plusses),
|
|
)
|
|
self.wait()
|
|
|
|
# Show Delta E
|
|
low_eqs = VGroup(
|
|
Tex("=", font_size=72).rotate(PI / 2).next_to(wsc[-1].target, DOWN, buff=0.5)
|
|
for wsc in weighted_sum_cols
|
|
)
|
|
low_eqs.set_color(YELLOW)
|
|
delta_Es = VGroup()
|
|
for emb_sym, eq in zip(emb_syms, low_eqs):
|
|
delta = Tex(R"\Delta")
|
|
delta.match_height(emb_sym[1])
|
|
delta.next_to(emb_sym[1], LEFT, buff=0, aligned_edge=DOWN)
|
|
delta_E = VGroup(delta, emb_sym.copy())
|
|
delta_E.set_color(YELLOW)
|
|
delta_E.set_height(0.8)
|
|
delta_E.next_to(eq, DOWN)
|
|
delta_Es.add(delta_E)
|
|
|
|
self.play(
|
|
LaggedStart(
|
|
(FadeTransform(term.copy(), delta_Es[index])
|
|
for term in weighted_sum_cols[index]),
|
|
lag_ratio=0.05,
|
|
group_type=Group
|
|
),
|
|
Write(low_eqs[index])
|
|
)
|
|
self.wait()
|
|
|
|
# Add Delta E
|
|
creature_group = Group(creature, q_groups[index]).copy()
|
|
creature_group.target = creature_group.generate_target()
|
|
creature_group.target.scale(1.5)
|
|
creature_group.target.next_to(h_lines, RIGHT, buff=4.0)
|
|
creature_group.target.align_to(creature, UP)
|
|
right_plus = Tex("+", font_size=96)
|
|
right_eq = Tex("=", font_size=120).rotate(PI / 2)
|
|
right_plus.next_to(creature_group.target, DOWN)
|
|
creature_delta_E = delta_Es[index].copy()
|
|
creature_delta_E.target = creature_delta_E.generate_target()
|
|
creature_delta_E.target.set_height(1.0)
|
|
creature_delta_E.target.next_to(right_plus, DOWN)
|
|
right_eq.next_to(creature_delta_E.target, DOWN, MED_LARGE_BUFF)
|
|
E_prime = emb_sym_primes[index].copy()
|
|
E_prime.set_height(1.25)
|
|
E_prime.next_to(right_eq, DOWN, MED_LARGE_BUFF)
|
|
blue_fluff.set_height(2.5)
|
|
blue_fluff.next_to(E_prime, DOWN, MED_LARGE_BUFF)
|
|
|
|
self.play(LaggedStart(
|
|
frame.animate.reorient(0, 0, 0, (4.96, -5.61, 0.0), 19.00),
|
|
MoveToTarget(creature_group),
|
|
FadeTransform(sum_rect.copy(), right_plus),
|
|
MoveToTarget(creature_delta_E),
|
|
run_time=2,
|
|
lag_ratio=0.1
|
|
))
|
|
self.wait()
|
|
self.play(
|
|
FadeTransform(creature_group[1][-4].copy(), E_prime),
|
|
FadeTransform(creature_delta_E.copy(), E_prime),
|
|
Write(right_eq),
|
|
FadeTransform(creature_group[0].copy(), blue_fluff, path_arc=-PI / 2, run_time=2)
|
|
)
|
|
self.wait()
|
|
|
|
right_sum_group = Group(
|
|
creature_group, right_plus, creature_delta_E,
|
|
right_eq, E_prime, blue_fluff
|
|
)
|
|
|
|
# Show all column sums
|
|
plus_groups = VGroup(
|
|
plusses.copy().match_x(col[0].target)
|
|
for col in weighted_sum_cols
|
|
)
|
|
plus_groups.set_fill(GREY_C, 1)
|
|
|
|
for col in softmax_cols:
|
|
for value in col:
|
|
value.set_fill(
|
|
opacity=clip(0.6 + value.get_value(), 0, 1)
|
|
)
|
|
|
|
self.play(LaggedStart(
|
|
right_sum_group.animate.fade(0.75),
|
|
FadeOut(sum_rect),
|
|
FadeOut(creature),
|
|
FadeOut(q_fade_rects[:index]),
|
|
FadeOut(q_fade_rects[index + 1:]),
|
|
FadeIn(softmax_cols[:index]),
|
|
FadeIn(softmax_cols[index + 1:]),
|
|
plusses.animate.set_fill(GREY_C, 1),
|
|
run_time=2,
|
|
))
|
|
self.play(
|
|
LaggedStart(
|
|
(LaggedStartMap(MoveToTarget, col)
|
|
for col in weighted_sum_cols),
|
|
lag_ratio=0.1
|
|
),
|
|
v_lines.animate.set_stroke(GREY_A, 4, 1),
|
|
*(
|
|
e_sym.animate.scale(1.25, about_edge=UP)
|
|
for e_sym in (*emb_syms[:index], *emb_syms[index + 1:])
|
|
),
|
|
)
|
|
|
|
other_indices = [*range(index), *range(index + 1, len(plus_groups))]
|
|
self.play(LaggedStart(
|
|
(LaggedStart(
|
|
FadeIn(plus_groups[j], lag_ratio=0.1),
|
|
Write(low_eqs[j]),
|
|
LaggedStart(
|
|
(FadeTransform(ws.copy(), delta_Es[j])
|
|
for ws in weighted_sum_cols[j]),
|
|
lag_ratio=0.05,
|
|
group_type=Group
|
|
),
|
|
lag_ratio=0.05,
|
|
)
|
|
for j in other_indices),
|
|
lag_ratio=0.1,
|
|
group_type=Group
|
|
))
|
|
self.wait()
|
|
|
|
# Add all deltas to embeddings
|
|
equations = VGroup()
|
|
equation_targets = VGroup()
|
|
for E, dE, Ep in zip(emb_syms.copy(), delta_Es.copy(), emb_sym_primes):
|
|
Ep.match_height(E)
|
|
plus = Tex("+", font_size=96)
|
|
eq = Tex("=", font_size=96).rotate(PI / 2)
|
|
equation = VGroup(E, plus, dE, eq, Ep)
|
|
equation.target = equation.generate_target()
|
|
for mob in equation.target[::2]:
|
|
mob.set_height(0.8)
|
|
equation.target.arrange(DOWN)
|
|
for mob in [Ep, plus, eq]:
|
|
mob.set_opacity(0)
|
|
mob.move_to(dE)
|
|
equations.add(equation)
|
|
equation_targets.add(equation.target)
|
|
|
|
equation_targets.scale(1.25)
|
|
equation_targets.arrange(RIGHT, buff=0.75)
|
|
equation_targets.next_to(h_lines, RIGHT, buff=1.5)
|
|
equation_targets.match_y(h_lines)
|
|
|
|
self.play(
|
|
frame.animate.reorient(0, 0, 0, (9.5, -7.17, 0.0), 20.33),
|
|
LaggedStartMap(MoveToTarget, equations, lag_ratio=0.05),
|
|
FadeTransform(right_sum_group, equation_targets[index]),
|
|
run_time=2.0
|
|
)
|
|
self.wait()
|
|
|
|
result_rect = SurroundingRectangle(
|
|
VGroup(eq[-1] for eq in equations),
|
|
buff=0.25
|
|
)
|
|
result_rect.set_stroke(TEAL, 3)
|
|
self.play(
|
|
ShowCreation(result_rect)
|
|
)
|
|
self.wait()
|
|
|
|
def bake_mobject_into_vector_entries(self, mob, vector, path_arc=30 * DEGREES, group_type=None):
    """
    Build (without playing) an animation in which copies of `mob` are
    absorbed into the entries of `vector`.

    One replica of `mob` is made per entry and faded out toward that
    entry's center along an arc of `path_arc`. Alongside this, a
    RandomizeMatrixEntries animation runs on `vector`.

    `group_type` is forwarded to LaggedStart unchanged.
    Returns the combined AnimationGroup.
    """
    targets = vector.get_entries()
    clones = mob.replicate(len(targets))

    def shifted_smooth(t):
        # Drive `smooth` with the shifted time 2t - 1, clipped to [0, 1]
        return clip(smooth(2 * t - 1), 0, 1)

    absorb = LaggedStart(
        (
            FadeOutToPoint(clone, target.get_center(), path_arc=path_arc)
            for clone, target in zip(clones, targets)
        ),
        lag_ratio=0.05,
        group_type=group_type,
        run_time=2,
        remover=True
    )
    shuffle = RandomizeMatrixEntries(
        vector,
        rate_func=shifted_smooth,
        run_time=2
    )
    return AnimationGroup(absorb, shuffle)
|
|
|
|
def scrap():
    """
    Holding area for snippets; not a callable routine.

    The body refers to names (`self`, `dot_prods`, `dots`, ...) that are
    never bound inside this function, and the inline notes say where
    each snippet is meant to be inserted. Everything after the bare
    `return` is unreachable.
    """
    # To be inserted after Show grid of dots sections
    self.remove(dot_prods)
    np.random.seed(time.gmtime().tm_sec)
    pattern = np.random.normal(0, 1, (8, 8))
    for col in range(len(pattern[0])):
        # Entries below the diagonal are set to -inf before the column
        # is passed through softmax
        pattern[col + 1:, col] = -np.inf
        pattern[:, col] = softmax(pattern[:, col])
    for dot_row, weights in zip(dots, pattern):
        for dot, weight in zip(dot_row, weights):
            dot.set_width(weight**0.5)
    dots.set_fill(GREY_B, 1)
    return

    ### To be inserted in "Show softmax" section
    np.random.seed(time.gmtime().tm_sec)
    softmax_arrays = np.random.normal(0, 1, (8, 8))
    for col in range(len(softmax_arrays[0])):
        softmax_arrays[col + 1:, col] = -np.inf
        softmax_arrays[:, col] = softmax(softmax_arrays[:, col])
    softmax_arrays = softmax_arrays.T
    ###
|
|
|
|
def thumbnail():
    """
    Holding area for thumbnail-design snippets; not a callable routine.

    Like `scrap`, the body uses names (`self`, `softmax_cols`, `frame`,
    `last_index`, ...) that are never bound here. The `###` markers say
    where each snippet is meant to be inserted.
    """
    ### Thumbnail design, insert in the middle of softmax show columns ###
    self.remove(q_groups)
    self.add(q_syms)

    # One dot per softmax value, with fill opacity interpolated between
    # 0.1 and 0.9 by that value
    out_dots = VGroup()
    for value in (v for col in softmax_cols for v in col):
        dot = Dot(radius=0.35)
        dot.move_to(value)
        dot.set_fill(WHITE, opacity=interpolate(0.1, 0.9, value.get_value()))
        out_dots.add(dot)
    out_dots.shift(2 * OUT)
    out_dots.set_stroke(WHITE, 2, 0.25)
    self.remove(softmax_cols)
    self.remove(sm_rects[last_index])
    self.add(out_dots)

    index = 3
    focus_col = ndp_columns[index]
    ndp_columns[-1].set_opacity(0.25)
    focus_col.set_opacity(1)

    sm_arrow = sm_arrows[last_index]
    sm_label = sm_labels[last_index]
    sm_label_group = VGroup(sm_arrow, sm_label)
    sm_label_group.match_x(focus_col)
    sm_label.scale(1.5, about_edge=DOWN)
    sm_label.set_fill(border_width=0)

    col_rect.match_x(focus_col)
    col_rect.set_flat_stroke(False)
    sm_col = col_rect.copy()
    # sm_col.set_width(out_dots[0].get_width() + 0.2)
    sm_col.match_z(out_dots)
    sm_col.set_flat_stroke(False)
    self.add(sm_col)
    self.remove(sm_label)
    sm_arrow.set_stroke(width=10)
    sm_arrow.shift(OUT)

    grid_lines.set_stroke(WHITE, 2)
    v_lines.set_height(12, about_edge=DOWN, stretch=True)

    frame.set_field_of_view(35 * DEGREES)
    frame.reorient(-52, -2, 0, (-1.74, -7.1, -0.03), 14.72)
    ###

    ### To be inserted before Set aside keys and queries
    frame.move_to([-4.62, -5.04, 0.0]).set_height(14.5)
    self.remove(pattern_words)

    # Turn each dot's size into fill opacity, then give all dots width 1
    for dot in dots.family_members_with_points():
        strength = dot.get_radius() / 0.5
        dot.set_fill(WHITE, opacity=strength**0.75)
        dot.set_width(1)

    title = Text("Attention", font_size=250)
    title.set_fill(border_width=2)
    title.next_to(q_syms, LEFT, LARGE_BUFF, DOWN)
    title.shift(0.5 * UP)
    # self.add(title)

    for syms, edge in [(q_syms, DOWN), (k_syms, RIGHT)]:
        syms.set_fill(border_width=1.5)
    for q in q_syms:
        q.scale(1.5, about_edge=DOWN)
    for k in k_syms:
        k.scale(1.5, about_edge=RIGHT)

    self.remove(word_groups, q_arrows, emb_arrows, emb_syms, wq_syms)
    key_side = VGroup(key_word_groups, key_emb_syms, key_emb_arrows, wk_arrows, wk_syms)
    key_side.shift(0.25 * LEFT)
    ###
|
|
|
|
|
|
class MyseteryNovel(InteractiveScene):
    """
    Lay out a multi-paragraph story read from "murder_story.txt", cover
    its final word with a "???" box, scroll from the opening sentence to
    the end, and attach an embedding vector below the last visible word
    while context animations run from the text into that vector.
    """

    def construct(self):
        # Create paragraphs
        story = Path(DATA_DIR, "murder_story.txt").read_text()
        paragraphs = VGroup(
            get_paragraph(chunk.split(" "), line_len=40)
            for chunk in story.split("\n\n")
        )
        # The paragraph at index 4 is swapped out for vertical dots
        ellipsis = Tex(R"\vdots", font_size=200)
        paragraphs.replace_submobject(4, ellipsis)
        paragraphs.arrange(DOWN, buff=1.0, aligned_edge=LEFT)
        ellipsis.match_x(paragraphs)
        self.add(paragraphs)

        # Mark last word
        closing_para = paragraphs[-1]
        last_word = closing_para["Derek!\""][0]
        hidden_box = SurroundingRectangle(last_word)
        hidden_box.set_stroke(YELLOW, 2)
        hidden_box.set_fill(YELLOW, 0.25)
        q_marks = Tex("???")
        q_marks.move_to(hidden_box)
        hidden_box.add(q_marks)
        hidden_box.shift(0.05 * DR)

        # Collapse the real word so that only the box is visible
        last_word.scale(0)
        last_word.set_fill(BLACK)
        self.add(hidden_box)

        # Show the first line
        frame = self.frame
        frame.set_y(15)
        paragraphs.set_fill(opacity=0.25)
        opening = paragraphs[0]["It was a dark and stormy night."][0]
        reveal_opening = opening.animate.set_fill(opacity=1).set_anim_args(lag_ratio=0.1)
        self.play(reveal_opening)
        self.wait()

        # Scroll down
        penultimate_words = closing_para["therefore, the murderer was"][0]
        self.play(
            frame.animate.set_y(-15.4),
            paragraphs.animate.set_fill(opacity=1).set_anim_args(lag_ratio=0.01),
            run_time=5,
        )
        self.wait()
        # A copy of the final phrase is added before everything is dimmed
        self.add(penultimate_words.copy())
        self.play(paragraphs.animate.set_opacity(0.25))
        self.wait()

        # Show the final vector
        was = penultimate_words[-3:]
        arrow = FillArrow(ORIGIN, DOWN, buff=0, thickness=0.07)
        arrow.next_to(was, DOWN, MED_SMALL_BUFF)
        vect = NumericEmbedding(length=12)
        vect.set_height(5)
        vect.next_to(arrow, DOWN)

        intro_anims = [
            frame.animate.set_y(-17.5).set_height(12.5),
            FadeIn(arrow, scale=3, shift=DOWN),
            FadeIn(vect, DOWN),
        ]
        self.play(LaggedStart(*intro_anims, run_time=2))
        self.context_anim(closing_para, vect)
        self.wait()

        # Zoom out more
        vect_group = VGroup(arrow, vect)
        enlarged = vect_group.generate_target()
        enlarged.scale(2.35, about_edge=UP)
        self.play(
            paragraphs.animate.set_fill(opacity=0.8),
            frame.animate.set_height(37).set_y(-14),
            MoveToTarget(vect_group),
            run_time=2
        )
        self.context_anim(paragraphs[-4:], vect)

    def context_anim(self, source, vect):
        """
        Play one ContextAnimation per entry of `vect`, staggered with
        LaggedStart, together with a RandomizeMatrixEntries on `vect`.

        The pieces of `source` that have points are dealt out to the
        entries by stride: entry i receives pieces i, i + len, i + 2*len...
        """
        pieces = VGroup(*source.family_members_with_points())
        entries = vect.get_entries()
        stride = len(entries)

        def pull_into(index, entry):
            return ContextAnimation(
                entry, pieces[index::stride],
                path_arc=-PI / 2,
                run_time=5,
                lag_ratio=1e-3,
                max_stroke_width=2
            )

        self.play(
            LaggedStart(
                (pull_into(index, entry) for index, entry in enumerate(entries)),
                lag_ratio=0.1,
            ),
            RandomizeMatrixEntries(vect, run_time=5),
        )
|
|
|
|
|
|
class RoadNotTaken(InteractiveScene):
    """
    Show a four-stanza poem, outline all of the text preceding the word
    "less" in its second-to-last line, point an arrow at a sequence of
    earlier phrases, and finally turn the word "one" into an embedding.
    """

    def construct(self):
        # Add poem
        # NOTE(review): the stanza literals are reproduced exactly as they
        # appear in this view (no leading whitespace on any line). Confirm
        # against the upstream file before changing whitespace inside them.
        stanza_strs = [
            """
Two roads diverged in a yellow wood,
And sorry I could not travel both
And be one traveler, long I stood
And looked down one as far as I could
To where it bent in the undergrowth;
""",
            """
Then took the other, as just as fair,
And having perhaps the better claim,
Because it was grassy and wanted wear;
Though as for that the passing there
Had worn them really about the same,
""",
            """
And both that morning equally lay
In leaves no step had trodden black.
Oh, I kept the first for another day!
Yet knowing how way leads on to way,
I doubted if I should ever come back.
""",
            """
I shall be telling this with a sigh
Somewhere ages and ages hence:
Two roads diverged in a wood, and I—
I took the one less traveled by,
And that has made all the difference.
""",
        ]
        poem = Text("\n\n".join(stanza_strs), alignment="LEFT")
        stanzas = VGroup(poem[text][0] for text in stanza_strs)
        stanzas.arrange_in_grid(h_buff=1.5, v_buff=1.0, fill_rows_first=False)
        stanzas.set_width(FRAME_WIDTH - 1)
        stanzas.move_to(0.5 * UP)
        # Presumably needed because the stanzas were repositioned as
        # slices of `poem` rather than through `poem` itself
        poem.refresh_bounding_box(recurse_down=True)

        self.play(FadeIn(poem, lag_ratio=0.01, run_time=4))
        self.wait()

        # Note all text until "one"
        less = poem["less"][-1]
        one = poem["one"][-1]
        last_and = poem["And"][-1]

        # Start from a box around the whole poem and carve away two
        # enlarged boxes anchored at "less" and at the final "And"
        context_outline = SurroundingRectangle(poem)
        cutouts = VGroup(
            SurroundingRectangle(word).scale(10, about_edge=UL)
            for word in [less, last_and]
        )
        for cutout in cutouts:
            context_outline = Difference(context_outline, cutout)
        context_outline.set_stroke(TEAL, 3)

        less_index = poem.submobjects.index(less[0])
        active_portion = poem[:less_index]
        faded_portion = poem[less_index:]

        less_rect = SurroundingRectangle(less)
        less_rect.set_stroke(YELLOW, 3)
        one_rect = SurroundingRectangle(one)
        one_rect.become(Difference(one_rect, less_rect))
        one_rect.match_height(less_rect, about_edge=DOWN, stretch=True)
        one_rect.set_stroke(BLUE, 3)

        arrow = Vector(0.75 * UP)
        arrow.next_to(one, DOWN, SMALL_BUFF)
        arrow.set_stroke(YELLOW)

        active_portion_copy = active_portion.copy()
        active_portion_copy.set_color(TEAL_B)

        self.play(
            FadeIn(context_outline),
            Write(active_portion_copy, run_time=2, stroke_color=TEAL, lag_ratio=0.01),
            faded_portion.animate.set_fill(opacity=0.5),
        )
        self.play(FadeOut(active_portion_copy))
        self.wait()
        self.play(GrowArrow(arrow))
        self.wait()
        self.play(
            ShowCreation(less_rect),
            less.animate.set_fill(opacity=1),
            arrow.animate.match_x(less),
        )
        self.wait()
        self.remove(less_rect)
        self.play(
            arrow.animate.match_x(one),
            TransformFromCopy(less_rect, one_rect),
        )
        self.wait()

        # Highlight "two roads"
        # From here on, work with copies so that dimming `poem` leaves
        # these pieces untouched
        one_copy = one.copy()
        less_copy = less.copy()
        two_roads = poem["Two roads"][-1].copy()
        took_the = poem["I took the"][-1].copy()

        self.play(
            FadeIn(two_roads, lag_ratio=0.1),
            FadeIn(took_the, lag_ratio=0.1),
            FadeIn(one_copy),
            arrow.animate.rotate(-PI / 2).next_to(two_roads, LEFT, SMALL_BUFF),
            poem.animate.set_fill(opacity=0.5),
            run_time=1.5
        )
        self.wait()

        # Highlight "took the other" and "grassy and wanted wear"
        def bright_copy(text):
            # Full-opacity white copy of the first occurrence of `text`
            return poem[text][0].copy().set_fill(WHITE, 1)

        top_two_roads = bright_copy("Two roads diverged")
        took_other = bright_copy("Then took the other")
        wanted_wear = bright_copy("it was grassy and wanted wear")

        # (phrase, how far to turn the arrow, which side of the phrase)
        stops = [
            (top_two_roads, PI / 2, DOWN),
            (took_other, 3 * PI / 4, UP),
            (wanted_wear, -PI / 2, DOWN),
        ]
        for phrase, turn, side in stops:
            self.play(
                arrow.animate.rotate(turn).next_to(phrase, side, SMALL_BUFF),
                FadeIn(phrase),
            )
            self.wait()

        # Highlight words throughout
        active_portion_copy.set_fill(YELLOW_A, 1)
        flashes = (
            FadeIn(char, rate_func=there_and_back_with_pause)
            for char in active_portion_copy
        )
        self.play(LaggedStart(flashes, lag_ratio=0.005, run_time=6))
        self.wait()

        # Show less again
        self.play(
            arrow.animate.rotate(-PI / 4).next_to(less_copy, DOWN, SMALL_BUFF),
            ShowCreation(less_rect),
            less_copy.animate.set_fill(WHITE, 1)
        )
        self.wait()

        # Show final embedding
        frame = self.frame
        embedding = NumericEmbedding(length=10)
        embedding.set_height(3)
        # Leave room between the word and the vector for the arrow
        gap = arrow.get_length() + 2 * SMALL_BUFF
        embedding.next_to(one_copy, DOWN, buff=gap)

        self.play(
            arrow.animate.rotate(PI).next_to(one_copy, DOWN, SMALL_BUFF).set_anim_args(path_arc=PI),
            frame.animate.set_height(9).move_to(DOWN)
        )
        self.play(TransformFromCopy(one_copy, embedding))
        self.play(RandomizeMatrixEntries(embedding))
        self.wait()
|
|
|
|
|
|
class QueryMap(InteractiveScene):
    """
    Depict a labelled vector in a slowly rotating 3D "Embedding space"
    being mapped onto a 2D plane, fixed in frame, titled "Query/Key space".

    Subclasses change the content by overriding the class attributes
    below; `construct` reads everything through `self`.
    """
    # Symbol and color used for the map arrow and for the output vector
    map_tex = "W_Q"
    map_color = YELLOW
    # Two-line label attached to the input vector
    src_name = "Creature"
    pos_word = "position 4"
    # Label attached to the output vector
    trg_name = "Any adjectives\nbefore position 4?"
    in_vect_color = BLUE_B
    # Coordinates in the 3D axes and in the 2D plane respectively
    in_vect_coords = (3, 2, -2)
    out_vect_coords = (-2, -1)

    def construct(self):
        # Setup 3d axes
        axes_3d = ThreeDAxes((-4, 4), (-3, 3), (-4, 4))
        floor_style = dict(
            stroke_color=GREY,
            stroke_width=1,
        )
        xz_plane = NumberPlane(
            (-4, 4), (-4, 4),
            background_line_style=floor_style,
            faded_line_ratio=0
        )
        xz_plane.rotate(90 * DEGREES, RIGHT)
        xz_plane.move_to(axes_3d)
        xz_plane.axes.set_opacity(0)
        axes_3d.add(xz_plane)
        axes_3d.set_height(2.0)

        self.set_floor_plane("xz")
        frame = self.frame
        frame.set_field_of_view(30 * DEGREES)
        frame.reorient(-32, 0, 0, (2.13, 1.11, 0.27), 4.50)
        frame.add_ambient_rotation(1 * DEGREES)

        self.add(axes_3d)

        # Set up target plane
        plane = NumberPlane(
            (-3, 3), (-3, 3),
            faded_line_ratio=1,
            background_line_style=dict(
                stroke_color=BLUE,
                stroke_width=1,
                stroke_opacity=0.75
            ),
            faded_line_style=dict(
                stroke_color=BLUE,
                stroke_width=1,
                stroke_opacity=0.25,
            )
        )
        plane.set_height(3.5)
        plane.to_corner(DR)

        arrow = Tex(R"\longrightarrow")
        arrow.set_width(2)
        arrow.stretch(0.75, 1)
        arrow.next_to(plane, LEFT, buff=1.0)
        arrow.set_color(self.map_color)

        map_name = Tex(self.map_tex, font_size=72)
        map_name.set_color(self.map_color)
        map_name.next_to(arrow.get_left(), UR, SMALL_BUFF)
        map_name.shift(0.25 * RIGHT)

        # The 2D side stays put while the camera rotates
        flat_parts = [plane, arrow, map_name]
        for part in flat_parts:
            part.fix_in_frame()
        self.add(*flat_parts)

        # Add titles
        titles = VGroup(
            Text("Embedding space"),
            Text("Query/Key space"),
        )
        subtitles = VGroup(
            Text("12,288-dimensional"),
            Text("128-dimensional"),
        )
        subtitles.scale(0.75)
        subtitles.set_fill(GREY_B)
        embedding_x = -frame.get_x() * FRAME_HEIGHT / frame.get_height()
        x_values = [embedding_x, plane.get_x()]
        for title, subtitle, x_value in zip(titles, subtitles, x_values):
            subtitle.next_to(title, DOWN, SMALL_BUFF)
            title.add(subtitle)
            title.next_to(plane, UP, MED_LARGE_BUFF)
            title.set_x(x_value)
            title.fix_in_frame()

        self.add(titles)

        # Show vector transformation
        origin_3d = axes_3d.get_origin()
        in_vect = Arrow(origin_3d, axes_3d.c2p(*self.in_vect_coords), buff=0)
        in_vect.set_stroke(self.in_vect_color)
        in_vect_label = TexText("``" + self.src_name + "''", font_size=24)
        pos_label = Text(self.pos_word, font_size=16)
        pos_label.next_to(in_vect_label, DOWN, SMALL_BUFF)
        pos_label.set_opacity(0.75)
        in_vect_label.add(pos_label)
        in_vect_label.set_color(self.in_vect_color)
        in_vect_label.next_to(in_vect.get_end(), UP, SMALL_BUFF)

        out_vect = Arrow(plane.get_origin(), plane.c2p(*self.out_vect_coords), buff=0)
        out_vect.set_stroke(self.map_color)
        out_vect_label = Text(self.trg_name, font_size=30)
        out_vect_label.next_to(out_vect.get_end(), DOWN, buff=0.2)
        out_vect_label.set_backstroke(BLACK, 5)
        VGroup(out_vect, out_vect_label).fix_in_frame()

        self.play(
            GrowArrow(in_vect),
            FadeInFromPoint(in_vect_label, axes_3d.get_origin()),
        )
        self.wait(2)
        self.play(
            TransformFromCopy(in_vect, out_vect),
            FadeTransform(in_vect_label.copy(), out_vect_label),
            run_time=2,
        )
        self.wait(20)
        self.play(FadeOut(out_vect_label))
        self.wait(5)
|
|
|
|
|
|
class KeyMap(QueryMap):
    """
    Variant of QueryMap: same `construct`, with the map, labels, colors
    and vector coordinates overridden below.
    """
    # Map arrow / output vector
    map_tex = "W_K"
    map_color = TEAL

    # Input vector and its two-line label
    src_name = "Fluffy"
    pos_word = "position 2"
    in_vect_color = BLUE_B
    in_vect_coords = (-3, 1, 2)

    # Output vector and its label
    trg_name = "Adjective at\nposition 2"
    out_vect_coords = (-1.75, -1)
|
|
|
|
|
|
class DescribeAttentionEquation(InteractiveScene):
    """
    Replace a still image ("AttentionPaperStill") with the attention
    equation, expand Q and K into bracketed arrays, rearrange them into
    a grid of scaled dot products, then highlight the softmax and V.
    """

    def construct(self):
        # Stage image
        image = ImageMobject("AttentionPaperStill")
        image.set_height(FRAME_HEIGHT)
        self.add(image)

        # Add equation
        equation = Tex(R"\text{Attention}(Q, K, V) = \text{softmax}({QK^T \over \sqrt{d_k}}) V")
        # Hard-coded size and position, presumably matched to the still
        equation.set_height(1.06929)
        equation.move_to([-0.41406, 1.177, 0])

        self.play(
            FadeIn(equation),
            FadeOut(image),
        )
        self.wait()

        # Show Q and K arrays
        syms = ["Q", "K"]
        colors = [YELLOW, TEAL]
        arrays = VGroup(
            self.get_array_representation(sym, color)
            for sym, color in zip(syms, colors)
        )
        q_array, k_array = arrays
        arrays.arrange(RIGHT, buff=1.5)
        arrays.next_to(equation, DOWN, buff=1.0)

        lil_rects = VGroup()
        rect_lines = VGroup()
        big_rects = VGroup()
        for arr, sym, color in zip(arrays, syms, colors):
            # Both small boxes are sized from the "Q" glyph and then
            # slid horizontally onto the symbol in question
            lil_rect = SurroundingRectangle(equation["Q"][0])
            lil_rect.match_x(equation[sym][0])
            big_rect = SurroundingRectangle(arr)
            connectors = VGroup(
                Line(lil_rect.get_corner(DOWN + side), big_rect.get_corner(UP + side))
                for side in [LEFT, RIGHT]
            )
            VGroup(lil_rect, big_rect, connectors).set_stroke(color, 2)
            lil_rects.add(lil_rect)
            rect_lines.add(connectors)
            big_rects.add(big_rect)

            self.play(
                ShowCreation(lil_rect),
                equation[sym].animate.set_color(color),
            )
            self.play(
                TransformFromCopy(lil_rect, big_rect),
                FadeInFromPoint(arr, lil_rect.get_center()),
                ShowCreation(connectors, lag_ratio=0)
            )
            self.wait()

        # Highlight numerator
        num_rect = SurroundingRectangle(equation["QK^T"])
        num_rect.set_stroke(BLUE, 2)

        self.play(
            ReplacementTransform(lil_rects[0], num_rect),
            ReplacementTransform(lil_rects[1], num_rect),
            FadeOut(rect_lines)
        )
        self.wait()

        # Arrange for grid
        frame = self.frame
        # Index 1 of each array is its group of terms; detach it so the
        # brackets and lines can fade out without it
        qs = q_array[1]
        ks = k_array[1]
        q_array.remove(qs)
        k_array.remove(ks)

        h_buff = 0.8
        v_buff = 0.6

        qs_row = qs.generate_target()
        qs_row.scale(0.75)
        qs_row.arrange(RIGHT, buff=h_buff)
        qs_row.next_to(equation, DOWN, buff=0.75)

        ks_col = ks.generate_target()
        ks_col.scale(0.75)
        ks_col.arrange(DOWN, buff=MED_LARGE_BUFF)
        # The second-to-last term is the dots; turn them vertical
        ks_col[-2].rotate(PI / 2)
        ks_col.next_to(qs_row, DL, buff=v_buff)

        self.play(
            frame.animate.move_to(1.5 * DOWN),
            FadeOut(q_array),
            FadeOut(k_array),
            MoveToTarget(qs),
            MoveToTarget(ks),
            big_rects[0].animate.surround(qs.target).set_stroke(opacity=0),
            big_rects[1].animate.surround(ks.target).set_stroke(opacity=0),
            run_time=2
        )

        # Add grid lines
        grid = VGroup(qs, ks)

        v_proto = Line(UP, DOWN).match_height(grid).scale(1.1)
        v_lines = v_proto.replicate(len(qs) + 1)
        for v_line, anchor in zip(v_lines, (ks, *qs)):
            v_line.next_to(anchor, RIGHT, buff=h_buff / 2)
            v_line.align_to(qs, UP)

        h_proto = Line(LEFT, RIGHT).match_width(grid).scale(1.1)
        h_lines = h_proto.replicate(len(ks) + 1)
        for h_line, anchor in zip(h_lines, (qs, *ks)):
            h_line.next_to(anchor, DOWN, buff=v_buff / 2)
            h_line.align_to(ks, LEFT)

        VGroup(v_lines, h_lines).set_stroke(GREY_B, 1)

        grid.add(v_lines, h_lines)

        self.play(
            FadeIn(h_lines, lag_ratio=0.1),
            FadeIn(v_lines, lag_ratio=0.1),
            ks[-2].animate.match_y(h_lines[-3:-1]),
        )

        # Dot products
        dot_prods = VGroup()
        for q in qs:
            for k in ks:
                dot = Tex(".")
                dot.match_x(q)
                dot.match_y(k)
                dot_prod = VGroup(q.copy(), dot, k.copy())
                cell = dot_prod.generate_target()
                cell.arrange(RIGHT, buff=SMALL_BUFF)
                cell.scale(0.7)
                cell.move_to(dot)
                # A length of 3 identifies the dots placeholder; in that
                # row or column only the dots themselves are kept
                if len(q) == 3:
                    cell[1:].scale(0)
                    for piece in cell:
                        piece.move_to(dot)
                elif len(k) == 3:
                    cell[:-1].scale(0)
                    for piece in cell:
                        piece.move_to(dot)
                dot.set_opacity(0)
                dot_prods.add(dot_prod)

        self.play(
            LaggedStartMap(MoveToTarget, dot_prods, lag_ratio=0.01),
            run_time=3
        )
        self.wait()

        # Add sqrt to denominator
        sqrt_part = equation[R"\over \sqrt{d_k}"][0]

        denoms = VGroup()
        for dot_prod in dot_prods:
            # Every product gets a fresh target, even those skipped below,
            # since MoveToTarget is later mapped over all of them
            shrunk = dot_prod.generate_target()
            if 3 in [len(dot_prod[0]), len(dot_prod[2])]:
                continue
            denom = sqrt_part.copy()
            denom.set_fill(opacity=0.9)
            denom.match_width(dot_prod)
            denom.move_to(dot_prod.get_center(), UP)
            shrunk.next_to(denom, UP, buff=SMALL_BUFF)
            VGroup(shrunk, denom).scale(0.75)
            denoms.add(denom)

        self.play(num_rect.animate.surround(equation[R"QK^T \over \sqrt{d_k}"]))
        self.play(
            LaggedStartMap(MoveToTarget, dot_prods, lag_ratio=0.05, time_span=(1, 3)),
            LaggedStart(
                (TransformFromCopy(sqrt_part, denom) for denom in denoms),
                lag_ratio=0.01,
            ),
            run_time=3
        )
        self.wait()

        # Highlight softmax
        softmax_part = equation[R"\text{softmax}({QK^T \over \sqrt{d_k}})"]
        self.play(num_rect.animate.surround(softmax_part))
        self.wait()

        # Mention V
        v_parts = equation["V"]
        v_rects = VGroup(SurroundingRectangle(part) for part in v_parts)
        v_rects.set_stroke(RED, 3)

        self.play(
            ReplacementTransform(VGroup(num_rect), v_rects),
            v_parts.animate.set_color(RED),
        )
        self.wait()

    def get_array_representation(self, sym, color=WHITE, length=7):
        """
        Build a bracketed row of terms `sym_1, sym_2, ..., sym_n`.

        `length` terms are created, numbered from 1 with the last one
        subscripted "n"; the second-to-last term is then replaced by dots.
        Each term gets a pair of short vertical lines stacked around it.

        Returns VGroup(brackets, terms, term_lines).
        """
        template = Tex(f"{sym}_0")
        template.set_fill(color)
        subscript = template.make_number_changeable(0)

        terms = VGroup()
        term_lines = VGroup()
        term_groups = VGroup()
        for n in range(1, length + 1):
            if n != length:
                subscript.set_value(n)
            else:
                subscript.become(Tex("n").replace(subscript))
            subscript.set_color(color)
            term = template.copy()

            # Two short vertical lines, stacked with a gap sized from the
            # term's height, centered on the term
            line_pair = Line(ORIGIN, 0.5 * UP).replicate(2)
            line_pair.arrange(DOWN, buff=term.get_height() + 2 * SMALL_BUFF)
            line_pair.move_to(term)

            term_lines.add(line_pair)
            terms.add(term)
            term_groups.add(VGroup(term, line_pair))
        term_groups.arrange(RIGHT, buff=MED_SMALL_BUFF)

        # Swap the second-to-last term for dots.
        # NOTE(review): that term's line pair stays in term_lines, which
        # is part of the returned group; confirm this is intended.
        dots = Tex(R"\dots")
        dots.replace(terms[-2], dim_to_match=0)
        terms.replace_submobject(length - 2, dots)
        term_groups.remove(term_groups[-2])

        brackets = Tex("[]")
        brackets.stretch(1.5, 1)
        brackets.set_height(term_groups.get_height() + MED_SMALL_BUFF)
        for bracket, side in zip(brackets, [LEFT, RIGHT]):
            bracket.next_to(terms, side, SMALL_BUFF)

        return VGroup(brackets, terms, term_lines)
|
|
|
|
|
|
class ShowAllPossibleNextTokenPredictions(InteractiveScene):
    """
    Scene showing a phrase with a "???" next-token box, then every prefix
    of the phrase stacked with its own next-token box.
    """

    def construct(self):
        # Add phrase
        phrase = Text("the fluffy blue creature roamed the verdant forest despite")
        plain_words = break_into_words(phrase)
        rects = get_piece_rectangles(plain_words)
        words = VGroup(VGroup(*pair) for pair in zip(rects, plain_words))
        # The final word is dropped from the displayed context; only its
        # rectangle is reused (below) as the box for the unknown next token.
        words = words[:-1]
        words.to_edge(LEFT, buff=MED_LARGE_BUFF)

        next_token_box = rects[-1].copy()
        next_token_box.set_color(YELLOW)
        next_token_box.set_stroke(YELLOW, 3)
        next_token_box.next_to(words, RIGHT, buff=LARGE_BUFF)
        q_marks = Tex("???")
        q_marks.move_to(next_token_box)
        next_token_box.add(q_marks)

        arrow = Arrow(words, next_token_box, buff=SMALL_BUFF)

        self.add(words)
        self.play(
            GrowArrow(arrow),
            FadeIn(next_token_box, RIGHT)
        )
        self.wait()

        # Set up subphrases: one scaled copy per prefix length
        scale_factor = 0.75
        v_buff = 0.4
        subphrases = VGroup(
            words[:n].copy().scale(scale_factor)
            for n in range(1, len(words) + 1)
        )
        subphrases.arrange(DOWN, buff=v_buff, aligned_edge=LEFT)
        subphrases.to_corner(UL)

        rhs = VGroup(arrow, next_token_box)
        alt_rhss = VGroup(
            rhs.copy().scale(scale_factor).next_to(subphrase, RIGHT, SMALL_BUFF)
            for subphrase in subphrases
        )

        self.play(
            Transform(words, subphrases[-1]),
            Transform(rhs, alt_rhss[-1]),
        )
        # Walk from the longest prefix to the shortest, spawning each
        # shorter row from the one beneath it.
        for n in range(len(subphrases) - 1, 0, -1):
            sp1 = subphrases[n]
            sp2 = subphrases[n - 1]
            rhs1 = alt_rhss[n]
            rhs2 = alt_rhss[n - 1]
            self.play(
                TransformFromCopy(sp1[:len(sp2)], sp2),
                TransformFromCopy(rhs1, rhs2),
                rate_func=linear,
                run_time=0.5
            )
        self.wait()

        # Highlight two examples
        # Move each row's arrow from its right-hand-side group into the
        # subphrase itself, and save a state so the row can be restored later.
        for subphrase, alt_rhs in zip(subphrases, alt_rhss):
            row_arrow = alt_rhs[0]
            alt_rhs.remove(row_arrow)
            subphrase.add(row_arrow)
            subphrase.save_state()
        index = 3
        self.play(LaggedStart(
            FadeOut(alt_rhss),
            FadeOut(rhs),
            words.animate.fade(0.75),
            subphrases[:index].animate.fade(0.75),
            subphrases[index + 1:].animate.fade(0.75),
            subphrases[index].animate.align_to(3 * RIGHT, RIGHT),
        ))
        self.wait()
        self.play(
            subphrases[index].animate.align_to(subphrases, LEFT).fade(0.75),
            subphrases[5].animate.restore().align_to(3 * RIGHT, RIGHT),
        )
        self.wait()

    @staticmethod
    def get_next_word_distribution():
        """
        Unimplemented placeholder; returns None.

        Declared as a staticmethod because it takes no arguments: without
        the decorator, calling it on an instance would raise TypeError.
        """
        pass
|
|
|
|
|
|
class ShowMasking(InteractiveScene):
    """
    Scene with two 6x6 grids: random raw values on the left, and on the
    right the column-wise softmax of those values after the entries with
    row index > column index have been set to negative infinity.
    """

    def construct(self):
        # Set up two patterns
        shape = (6, 6)
        n_cols = shape[1]

        left_grid = Square().get_grid(*shape, buff=0)
        left_grid.set_shape(5.5, 5)
        left_grid.to_edge(LEFT)
        left_grid.set_y(-0.5)
        left_grid.set_stroke(GREY_B, 1)

        right_grid = left_grid.copy()
        right_grid.to_edge(RIGHT)

        arrow = Arrow(left_grid, right_grid)
        sm_label = Text("softmax")
        sm_label.next_to(arrow, UP)

        titles = VGroup(
            Text("Unnormalized\nAttention Pattern"),
            Text("Normalized\nAttention Pattern"),
        )
        for grid, title in zip(VGroup(left_grid, right_grid), titles):
            title.next_to(grid, UP, buff=MED_LARGE_BUFF)

        values_array = np.random.normal(0, 2, shape)
        font_size = 30
        raw_values = VGroup(
            DecimalNumber(
                number,
                include_sign=True,
                font_size=font_size,
            ).move_to(cell)
            for cell, number in zip(left_grid, values_array.flatten())
        )

        for mob in (left_grid, right_grid, titles, arrow, sm_label, raw_values):
            self.add(mob)

        # Highlight lower lefts
        # Entries with row > column get a -infinity target, and the
        # underlying array is masked to match.
        changers = VGroup()
        for flat_index, dec in enumerate(raw_values):
            row, col = divmod(flat_index, n_cols)
            if row <= col:
                continue
            neg_inf = Tex(R"-\infty", font_size=36)
            neg_inf.move_to(dec)
            neg_inf.set_fill(RED, border_width=1.5)
            dec.target = neg_inf
            changers.add(dec)
            values_array[row, col] = -np.inf
        rects = VGroup(map(SurroundingRectangle, changers))
        rects.set_stroke(RED, 3)

        self.play(LaggedStartMap(ShowCreation, rects))
        self.play(
            LaggedStartMap(FadeOut, rects),
            LaggedStartMap(MoveToTarget, changers)
        )
        self.wait()

        # Normalized values: softmax applied to each column separately
        softmaxed_columns = [softmax(column) for column in values_array.T]
        normalized_array = np.array(softmaxed_columns).T
        normalized_values = VGroup(
            DecimalNumber(number, font_size=font_size).move_to(cell)
            for cell, number in zip(right_grid, normalized_array.flatten())
        )
        for flat_index, dec in enumerate(normalized_values):
            opacity = interpolate(0.5, 1, rush_from(dec.get_value()))
            dec.set_fill(opacity=opacity)
            row, col = divmod(flat_index, n_cols)
            if row > col:
                # Masked positions are drawn in red
                dec.set_fill(RED, 0.75)

        self.play(
            LaggedStart(
                (
                    FadeTransform(raw.copy(), normalized)
                    for raw, normalized in zip(raw_values, normalized_values)
                ),
                lag_ratio=0.05,
                group_type=Group
            )
        )
        self.wait()
|
|
|
|
|
|
class ScalingAPattern(InteractiveScene):
    """
    Scene with a 50x50 grid of dots sized from a masked, column-wise
    softmax of random values, labelled with Q and K symbols, followed by
    a slow camera zoom-out.
    """

    def construct(self):
        # Position grid
        N = 50
        grid = Square(side_length=1.0).get_grid(N, N, buff=0)
        grid.set_stroke(GREY_A, 1)
        grid.stretch(0.89, 0)
        grid.stretch(0.70, 1)
        # Alternate placement kept from the original, currently unused:
        # grid.move_to(1.67 * LEFT + 1.596 * UP, UL)
        grid.move_to(5.0 * LEFT + 2.5 * UP, UL)
        self.add(grid)

        # Dots
        values = np.random.normal(0, 1, (N, N))
        # Mask every entry strictly below the diagonal (row > column)
        values[np.tril_indices(N, k=-1)] = -np.inf

        dots = VGroup()
        for col_index, column in enumerate(values.T):
            weights = softmax(column)
            for row_index, weight in enumerate(weights):
                dot = Dot(radius=0.3 * weight**0.75)
                dot.move_to(grid[row_index * N + col_index])
                dots.add(dot)
        dots.set_fill(GREY_C, 1)
        self.add(dots)

        # Add symbols
        q_template, k_template = (
            Tex(tex).set_color(color)
            for tex, color in [
                (R"\vec{\textbf{Q}}_0", YELLOW),
                (R"\vec{\textbf{K}}_0", TEAL),
            ]
        )
        for template in (q_template, k_template):
            template.scale(0.75)
            template.substr = template.make_number_changeable("0")

        def indexed_copies(template, squares, direction, buff):
            # One copy of the template per square, numbered from 1
            copies = VGroup()
            for number, square in enumerate(squares, start=1):
                template.substr.set_value(number)
                template.next_to(square, direction, buff=buff)
                copies.add(template.copy())
            return copies

        # grid[:N] is the first row of squares, grid[::N] the first column
        qs = indexed_copies(q_template, grid[:N], UP, SMALL_BUFF)
        ks = indexed_copies(k_template, grid[::N], LEFT, 2 * SMALL_BUFF)
        self.add(qs, ks)

        # Slowly zoom out
        self.play(
            self.frame.animate.reorient(0, 0, 0, (14.72, -14.71, 0.0), 38.06),
            grid.animate.set_stroke(width=1, opacity=0.25),
            dots.animate.set_fill(GREY_B, 1).set_stroke(GREY_B, 1),
            run_time=20,
        )
        self.wait()
|
|
|
|
|
|
class IntroduceValueMatrix(InteractiveScene):
    """
    Scene combining 3D word vectors with an overlay of word embeddings
    and a W_V matrix-vector product.

    `fix_new_entries_in_frame` is toggled throughout `construct`; while it
    is True, the overridden `add` calls `fix_in_frame()` on every mobject
    it receives.
    """

    def setup(self):
        # The flag is read by the overridden add(); it is assigned before
        # the parent setup runs (presumably because that may call add()).
        self.fix_new_entries_in_frame = False
        super().setup()

    def construct(self):
        # Initialized axes
        frame = self.frame
        self.set_floor_plane("xz")
        axes = ThreeDAxes((-4, 4), (-4, 4), (-4, 4))
        floor_grid = NumberPlane(
            (-4, 4), (-4, 4),
            background_line_style=dict(
                stroke_color=GREY,
                stroke_width=1,
                stroke_opacity=0.5,
            )
        )
        floor_grid.axes.set_opacity(0)
        floor_grid.rotate(PI / 2, RIGHT)
        axes.add(floor_grid)

        frame.reorient(5, -4, 0, (-4.66, 2.07, 0.04), 12.48)
        # frame.add_ambient_rotation()
        self.add(axes)

        # Add word pair
        words = VGroup(Text("blue"), Text("fluffy"), Text("creature"))
        words.scale(1.5)
        words.arrange(RIGHT, aligned_edge=UP)
        words.to_edge(UP)
        words.to_edge(LEFT, buff=0)
        rects = get_piece_rectangles(words, h_buff=0.1)
        for rect, rect_color in zip(rects, [BLUE, TEAL, ORANGE]):
            rect.set_color(rect_color)
        down_arrows = VGroup(Vector(DOWN).next_to(rect, DOWN) for rect in rects)
        embeddings = VGroup(
            NumericEmbedding(length=8).set_height(4.0).next_to(down_arrow, DOWN)
            for down_arrow in down_arrows
        )

        # Everything belonging to "blue" starts invisible and is revealed later
        blue_group = VGroup(rects[0], words[0], down_arrows[0], embeddings[0])
        blue_group.set_opacity(0)

        self.fix_new_entries_in_frame = True
        self.add(rects)
        self.add(words)
        self.add(down_arrows)
        self.add(embeddings)

        # Add word vectors (each is a Group of an arrow and an image)
        creature_vect = self.get_labeled_vector(axes, (-2, 3, 1), ORANGE, "Dalle3_creature")
        with_fluffy_vect = self.get_labeled_vector(axes, (2, 3, 1), GREY_BROWN, "Dalle3_creature_2")
        with_blue_vect = self.get_labeled_vector(axes, (1, 2, 4), BLUE, "BlueFluff")

        self.wait()
        self.fix_new_entries_in_frame = False
        # NOTE(review): index 1 is "fluffy", yet it is used as the source
        # for the creature vector (index 2 is "creature") - confirm intent.
        emb_ghost_arrow = Arrow(
            embeddings[1].get_bottom(), embeddings[1].get_top(), buff=0
        ).fix_in_frame().set_stroke(width=10, opacity=0.25)
        self.play(
            FadeTransform(words[1].copy(), creature_vect[1]),
            TransformFromCopy(
                emb_ghost_arrow,
                creature_vect[0],
            )
        )
        self.add(creature_vect)

        # Show influence
        diff_vect = Arrow(
            creature_vect[0].get_end(),
            with_fluffy_vect[0].get_end(),
            buff=0
        )
        diff_vect.scale(0.95)
        self.fix_new_entries_in_frame = False
        self.play(
            FadeTransform(creature_vect[1].copy(), with_fluffy_vect[1]),
            TransformFromCopy(creature_vect[0], with_fluffy_vect[0]),
            run_time=3,
        )
        self.add(with_fluffy_vect)
        self.play(GrowArrow(diff_vect, run_time=2))

        self.fix_new_entries_in_frame = True
        fluffy_entries = embeddings[1].get_entries()
        self.play(
            RandomizeMatrixEntries(embeddings[2], time_span=(1, 5)),
            LaggedStart(
                (
                    ContextAnimation(entry, fluffy_entries, path_arc=10 * DEGREES, lag_ratio=0.1)
                    for entry in embeddings[2].get_entries()
                ),
                lag_ratio=0.01,
                run_time=5,
            ),
        )
        self.wait()

        # Make room
        corner_group = VGroup(rects, words, down_arrows, embeddings)
        self.play(
            frame.animate.reorient(10, -7, 0, (-8.33, -0.79, 0.37), 16.82),
            corner_group.animate.set_height(3).to_edge(UP, buff=0.25).set_x(-2),
            run_time=2
        )

        # Show value matrix
        value_matrix = WeightMatrix(shape=(8, 8))
        value_matrix.set_height(2.75)
        value_matrix.to_corner(DL)
        matrix_brace = Brace(value_matrix, UP)
        matrix_label = Tex("W_V")
        matrix_label.next_to(matrix_brace, UP)
        matrix_label.set_color(RED)

        fluff_emb = embeddings[1]
        in_vect_rect = SurroundingRectangle(fluff_emb)
        in_vect_rect.set_stroke(TEAL, 2)
        in_vect = fluff_emb.copy()
        in_vect.match_height(value_matrix)
        in_vect.next_to(value_matrix, RIGHT, SMALL_BUFF)
        in_vect_path = self.get_top_vect_to_low_vect_path(fluff_emb, in_vect, TEAL)

        self.fix_new_entries_in_frame = True
        self.play(
            FadeIn(value_matrix, lag_ratio=1e-3),
            GrowFromCenter(matrix_brace),
            FadeIn(matrix_label, shift=0.25 * UP)
        )
        self.play(ShowCreation(in_vect_rect))
        self.play(
            ShowCreation(in_vect_path),
            TransformFromCopy(fluff_emb, in_vect, path_arc=-20 * DEGREES),
            run_time=2
        )

        # Show matrix product
        _, value_vect = show_matrix_vector_product(self, value_matrix, in_vect)
        self.wait()

        # Position value vect
        value_rect = SurroundingRectangle(value_vect)
        value_rect.set_stroke(RED, 2)
        value_label = Text("Value")
        value_label.next_to(value_rect, RIGHT)
        value_label.set_color(RED)
        value_label.set_backstroke()
        self.fix_new_entries_in_frame = True
        self.play(
            ShowCreation(value_rect),
            FadeIn(value_label, lag_ratio=0.1)
        )
        self.wait()

        # Second copy of the label, placed next to the 3D difference arrow
        value_label2 = value_label.copy()
        value_label2.set_backstroke(BLACK, 5)
        value_label2.scale(1.5)
        value_label2.next_to(diff_vect, UP, MED_SMALL_BUFF)
        value_label2.unfix_from_frame()

        self.fix_new_entries_in_frame = False
        self.play(
            frame.animate.reorient(29, -2, 0, (-7.48, 1.91, 1.21), 11.89),
            FadeInFromPoint(value_label2, np.array([-4, -5, 0])),
            TransformFromCopy(value_rect, diff_vect),
            run_time=2
        )
        self.wait()

        # Show blue
        blue_group.target = blue_group.generate_target()
        blue_target = blue_group.target
        blue_target[0].set_stroke(opacity=1)
        blue_target[0].set_fill(opacity=0.2)
        blue_target[1:].set_opacity(1)
        blue_target.shift(0.2 * LEFT)

        blue_path = self.get_top_vect_to_low_vect_path(blue_target, in_vect, BLUE)
        blue_emb = blue_group[3]
        blue_in_vect = blue_emb.copy().set_opacity(1)
        blue_in_vect.replace(in_vect)

        self.fix_new_entries_in_frame = True
        self.play(
            MoveToTarget(blue_group),
            LaggedStartMap(FadeOut, VGroup(
                in_vect_path, in_vect_rect,
                value_vect, value_rect, value_label,
                value_label2,
            )),
            run_time=1
        )
        self.play(
            TransformFromCopy(blue_emb, blue_in_vect),
            ShowCreation(blue_path),
            FadeOut(in_vect, 3 * DOWN),
            run_time=1.5
        )
        _, blue_value_vect = show_matrix_vector_product(self, value_matrix, blue_in_vect)

        # Show in diagram
        diff2 = Arrow(
            with_fluffy_vect[0].get_end(),
            with_blue_vect[0].get_end(),
            buff=0.05
        )
        diff2.set_flat_stroke(False)
        rhs_rect = SurroundingRectangle(blue_value_vect)
        rhs_rect.set_stroke(RED, 2)

        self.fix_new_entries_in_frame = True
        self.play(ShowCreation(rhs_rect))
        self.fix_new_entries_in_frame = False
        self.add(diff2)
        self.play(
            TransformFromCopy(rhs_rect, diff2),
            FadeIn(diff2),
            frame.animate.reorient(-16, -3, 0, (-6.41, 2.78, 1.37), 13.21),
            TransformFromCopy(with_fluffy_vect[0], with_blue_vect[0]),
            FadeTransform(with_fluffy_vect[1].copy(), with_blue_vect[1]),
            run_time=2,
        )
        frame.add_ambient_rotation(2 * DEGREES)
        self.wait(8)

    def get_top_vect_to_low_vect_path(self, top_vect, low_vect, color, top_buff=0.1, low_buff=0.2, bezier_factor=1.5):
        """
        Return a CubicBezier from just below `top_vect` to just above
        `low_vect`, stroked with `color` at width 3.

        The end points are offset by `top_buff` / `low_buff`; the two
        handles are offset vertically by `bezier_factor`.
        """
        start = top_vect.get_bottom()
        end = low_vect.get_top()
        path = CubicBezier(
            start + top_buff * DOWN,
            start + bezier_factor * DOWN,
            end + bezier_factor * UP,
            end + low_buff * UP,
        )
        path.set_stroke(color, 3)
        return path

    def get_labeled_vector(self, axes, coords, color, image_name, image_height=1.0):
        """
        Return Group(arrow, image): an arrow from the origin of `axes` to
        `coords`, with the named image placed above the arrow's tip.
        """
        tip_point = axes.c2p(*coords)
        arrow = Arrow(axes.get_origin(), tip_point, buff=0)
        arrow.set_color(color)

        picture = ImageMobject(image_name)
        picture.set_height(image_height)
        picture.next_to(arrow.get_end(), UP, MED_SMALL_BUFF)

        return Group(arrow, picture)

    def add(self, *mobjects):
        """Add mobjects, first calling fix_in_frame() on each if the flag is set."""
        should_fix = self.fix_new_entries_in_frame
        if should_fix:
            for mobject in mobjects:
                mobject.fix_in_frame()
        super().add(*mobjects)
|
|
|
|
|
|
class CountMatrixParameters(InteractiveScene):
    """
    Scene that counts the parameters of the query, key and value matrices
    and then factors the value matrix into two smaller matrices.
    """

    # Font size shared by every count / product label made by the helpers
    count_font_size = 36

    def construct(self):
        # Add three matrices
        d_embed = 12_288
        d_key = 128
        # Displayed shape only; the counted dimensions are d_key and d_embed
        key_mat_shape = (5, 10)

        que_mat = WeightMatrix(shape=key_mat_shape)
        key_mat = WeightMatrix(shape=key_mat_shape)
        val_mat = WeightMatrix(shape=(key_mat_shape[1], key_mat_shape[1]))
        matrices = VGroup(que_mat, key_mat, val_mat)
        for matrix in matrices:
            matrix.set_max_width(4)

        matrices.arrange(DOWN, buff=0.75)

        colors = [YELLOW, TEAL, RED]

        titles = VGroup(Text("Query"), Text("Key"), Text("Value"))
        que_title, key_title, val_title = titles
        titles.arrange(DOWN, aligned_edge=LEFT)
        titles.next_to(matrices, LEFT, LARGE_BUFF)
        for title, matrix, color in zip(titles, matrices, colors):
            title.match_y(matrix)
            title.set_color(color)

        self.play(
            LaggedStartMap(FadeIn, titles, shift=0.25 * LEFT, lag_ratio=0.5),
            LaggedStart(
                (FadeIn(matrix, lag_ratio=1e-2)
                 for matrix in matrices),
                lag_ratio=0.5,
            )
        )
        self.wait()

        # Data animations
        change_anims = [RandomizeMatrixEntries(mat) for mat in matrices]
        highlight_anims = [
            LaggedStartMap(FlashUnder, mat.get_entries(), lag_ratio=5e-3, stroke_width=1)
            for mat in matrices
        ]

        self.play(
            LaggedStart(highlight_anims, lag_ratio=0.2),
            LaggedStart(change_anims, lag_ratio=0.2),
            run_time=3
        )

        # Ask about total number of parameters
        rects = VGroup(
            SurroundingRectangle(entry, buff=0.025)
            for matrix in matrices
            for entry in matrix.get_entries()
        )
        rects.set_stroke(WHITE, 1)
        question = Text("How many\nparameters?")
        question.next_to(matrices, RIGHT, LARGE_BUFF)

        self.play(
            ShowCreation(rects, lag_ratio=5e-3, run_time=2),
            Write(question)
        )
        self.play(FadeOut(rects))
        self.wait()

        # Make room to count query/key
        value_group = VGroup(val_title, val_mat)
        value_group.save_state()
        qk_mats = matrices[:2]
        qk_mats.target = qk_mats.generate_target()
        qk_mats.target.arrange(RIGHT, buff=3.0)
        qk_mats.target.move_to(DR)

        self.play(
            FadeOut(question, DR),
            value_group.animate.scale(0.25).to_corner(DR).fade(0.25),
            MoveToTarget(qk_mats),
            que_title.animate.next_to(qk_mats.target[0], UP, buff=2.0),
            key_title.animate.next_to(qk_mats.target[1], UP, buff=2.0),
        )

        # Count up query and key
        que_col_count = self.show_column_count(que_mat, d_embed)
        key_col_count = self.show_column_count(key_mat, d_embed)
        self.wait()
        que_row_count = self.show_row_count(que_mat, d_key)
        key_row_count = self.show_row_count(key_mat, d_key)
        self.wait()

        que_product = self.show_product(
            que_col_count, que_row_count,
            added_anims=[que_title.animate.shift(UP)]
        )
        key_product = self.show_product(
            key_col_count, key_row_count,
            added_anims=[key_title.animate.shift(UP)]
        )
        self.wait()

        # Pull up the value matrix
        qk_titles = titles[:2]
        qk_titles.target = qk_titles.generate_target()
        qk_titles.target.arrange(DOWN, buff=2.0, aligned_edge=LEFT)
        qk_titles.target.to_corner(UL)
        qk_titles.target.scale(0.5, about_edge=UL)

        qk_mats.target = qk_mats.generate_target()

        # Copies of the two product totals (last element of each equation)
        qk_rhss = VGroup(que_product[-1], key_product[-1]).copy()
        qk_rhss.target = qk_rhss.generate_target()

        for mat, title, rhs in zip(qk_mats.target, qk_titles.target, qk_rhss.target):
            rhs.scale(0.5)
            mat.scale(0.5)
            rhs.next_to(title, DOWN, SMALL_BUFF, aligned_edge=LEFT)
            mat.next_to(VGroup(title, rhs), RIGHT, buff=MED_LARGE_BUFF)

        self.play(
            MoveToTarget(qk_titles),
            MoveToTarget(qk_mats),
            MoveToTarget(qk_rhss),
            FadeOut(VGroup(
                que_product, key_product,
                que_col_count, que_row_count,
                key_col_count, key_row_count,
            ), shift=0.5 * UL, lag_ratio=1e-3, time_span=(0, 1.0)),
            value_group.animate.restore().arrange(DOWN, buff=1.0).move_to(2.0 * RIGHT + 0.5 * DOWN),
            run_time=2,
        )
        self.wait()

        # Count up current value
        in_vect = NumericEmbedding(length=key_mat_shape[1])
        in_vect.match_height(val_mat)
        in_vect.next_to(val_mat, RIGHT, SMALL_BUFF)

        val_col_count = self.show_column_count(
            val_mat, d_embed,
            added_anims=[val_title.animate.shift(UP)]
        )
        self.play(FadeIn(in_vect))
        # `rhs` (the product vector) is reused further down for the
        # "d_output" label, so it must not be rebound in between.
        _, rhs = show_matrix_vector_product(self, val_mat, in_vect)
        val_row_count = self.show_row_count(val_mat, d_embed)
        self.wait()
        val_product = self.show_product(
            val_col_count, val_row_count,
            added_anims=[val_title.animate.shift(UP)]
        )
        self.wait()

        # Compare the two
        frame = self.frame
        qk_groups = VGroup(
            VGroup(*trip)
            for trip in zip(qk_mats, qk_titles, qk_rhss)
        )
        for group, y in zip(qk_groups, [+1.25, -1.25]):
            group.save_state()
            group.target = group.generate_target()
            group.target.scale(2)
            group.target.next_to(val_mat, LEFT, buff=2.5)
            group.target.set_y(y)

        self.play(
            frame.animate.reorient(0, 0, 0, (-1.58, 0.02, 0.0), 9.22),
            LaggedStartMap(MoveToTarget, qk_groups),
        )
        self.wait()

        # Circle both
        val_rhs_rect = SurroundingRectangle(val_product[-1])
        val_rhs_rect.set_stroke(RED_B, 3)
        qk_rhs_rects = VGroup(
            SurroundingRectangle(qk_rhs) for qk_rhs in qk_rhss
        )
        qk_rhs_rects[0].set_stroke(YELLOW, 3)
        qk_rhs_rects[1].set_stroke(TEAL, 3)

        big_rect = FullScreenFadeRectangle()
        big_rect.scale(2)
        big_rect.set_fill(opacity=0.5)
        val_rhs_copy = val_product[-1].copy()
        qk_rhs_copies = qk_rhss.copy()

        self.add(big_rect, val_rhs_copy)
        self.play(
            FadeIn(big_rect),
            ShowCreation(val_rhs_rect)
        )
        self.wait()
        self.play(
            TransformFromCopy(VGroup(val_rhs_rect), qk_rhs_rects),
            FadeIn(qk_rhs_copies)
        )
        self.wait()
        self.play(
            LaggedStartMap(FadeOut, VGroup(
                big_rect, qk_rhs_copies, val_rhs_copy,
                qk_rhs_rects, val_rhs_rect
            ))
        )

        # Cross out
        cross = Cross(val_product, stroke_width=[0, 12, 0]).scale(1.1)
        self.play(LaggedStart(
            FadeOut(qk_groups, 2 * UR, scale=0.5),
            ShowCreation(cross),
            frame.animate.set_height(FRAME_HEIGHT).move_to(RIGHT),
            run_time=2,
            lag_ratio=0.1
        ))
        self.wait()
        self.play(FadeOut(val_product), FadeOut(cross))

        # Factor out
        val_down_mat = WeightMatrix(shape=key_mat_shape)
        val_up_mat = WeightMatrix(shape=(key_mat_shape[1], 4))
        val_down_mat.match_width(val_mat)
        val_up_mat.match_height(in_vect)

        val_down_mat.move_to(val_mat, RIGHT)
        val_up_mat.next_to(val_down_mat, LEFT, SMALL_BUFF)

        self.remove(val_mat)
        self.play(
            TransformFromCopy(val_mat.get_brackets(), val_down_mat.get_brackets()),
            TransformFromCopy(val_mat.get_columns(), val_down_mat.get_columns()),
            TransformFromCopy(val_mat.get_brackets(), val_up_mat.get_brackets()),
            TransformFromCopy(val_mat.get_rows(), val_up_mat.get_rows()),
            val_col_count.animate.next_to(val_down_mat, UP, SMALL_BUFF),
            val_row_count.animate.next_to(val_up_mat, LEFT, SMALL_BUFF),
        )
        self.add(val_down_mat)
        self.wait()

        # Circle the full linear map
        big_rect = SurroundingRectangle(VGroup(val_row_count, val_col_count))
        big_rect.round_corners(radius=0.25)
        big_rect.set_stroke(RED_B, 2)
        linear_map_words = Text("Linear map")
        linear_map_words.next_to(big_rect, UP)
        linear_map_words.set_color(RED_B)

        in_label, out_label = [
            VGroup(Text(text), Integer(d_embed))
            for text in ["d_input", "d_output"]
        ]
        for label, array, shift in [(in_label, in_vect, LEFT), (out_label, rhs, RIGHT)]:
            label.arrange(DOWN)
            label.scale(0.65)
            label.next_to(array, UP, buff=LARGE_BUFF)
            label.shift(0.25 * shift)
            arrow = Arrow(label, array)
            label.add(arrow)

        self.play(
            FadeIn(big_rect),
            FadeTransform(val_title, linear_map_words),
        )
        self.wait()
        self.play(FadeIn(in_label, lag_ratio=0.1))
        self.play(FadeIn(out_label, lag_ratio=0.1))
        self.wait(2)

        # Show the value_down map
        val_down_group = VGroup(val_down_mat, val_col_count)
        val_up_group = VGroup(val_up_mat, val_row_count)
        val_down_group.save_state()
        val_up_group.save_state()

        small_row_count = self.show_row_count(
            val_down_mat, d_key,
            added_anims=[val_up_group.animate.scale(0.5).to_edge(LEFT, buff=1.25).fade(0.5)]
        )
        self.wait()
        self.play(frame.animate.set_y(0.5))
        self.wait()

        value_down_rect = SurroundingRectangle(
            VGroup(small_row_count, val_down_mat, val_col_count)
        )
        value_down_rect.round_corners(radius=0.25)
        value_down_rect.set_stroke(RED_B, 2)
        value_down_title = TexText(R"Value$_\downarrow$")
        value_down_title.set_fill(RED_B)
        value_down_title.next_to(val_down_mat, DOWN)

        self.remove(big_rect)
        self.play(
            TransformFromCopy(big_rect, value_down_rect),
            FadeOut(linear_map_words),
            FadeIn(value_down_title, DOWN)
        )
        self.wait()

        # Show value_up map
        # The row-count brace/label is rotated to sit on top of the
        # value_up matrix; positions come from val_up_group's saved state
        # because the group itself is currently shrunk and faded.
        small_row_count.target = small_row_count.generate_target()
        small_row_count.target.rotate(-PI / 2)
        small_row_count.target[1].rotate(PI / 2)
        small_row_count.target[0].stretch_to_fit_width(val_up_group.saved_state[0].get_width())
        small_row_count.target[1].next_to(small_row_count.target[0], UP, SMALL_BUFF)
        small_row_count.target.next_to(val_up_group.saved_state[0], UP, SMALL_BUFF)
        big_rect.set_height(3.9, stretch=True)
        big_rect.align_to(VGroup(val_down_mat, val_up_group.saved_state), DR)
        big_rect.shift(0.8 * DOWN + 0.05 * RIGHT)
        linear_map_words.next_to(big_rect, UP)

        value_up_title = TexText(R"Value$_\uparrow$")
        value_up_title.set_fill(RED_B)
        value_up_title.next_to(val_up_group.saved_state[0], DOWN)

        self.play(LaggedStart(
            val_down_group.animate.fade(0.5),
            value_down_title.animate.fade(0.5),
            ReplacementTransform(value_down_rect, big_rect),
            Restore(val_up_group),
            MoveToTarget(small_row_count),
            FadeIn(linear_map_words, shift=0.5 * UP),
            run_time=2,
        ))
        val_up_group.add(small_row_count)
        self.wait()
        self.play(TransformFromCopy(value_down_title, value_up_title))
        self.wait()

        # Low rank label
        low_rank_words = TexText("``Low rank'' transformation")
        low_rank_words.next_to(big_rect, UP)
        low_rank_words.shift(0.5 * LEFT)
        self.play(
            val_down_group.animate.set_fill(opacity=1),
            value_down_title.animate.set_fill(opacity=1),
            FadeTransform(linear_map_words, low_rank_words)
        )
        self.wait()

    def scrap(self):
        """
        Unused leftover animation code.

        NOTE(review): this method reads names that are only defined as
        locals of construct() (val_down_group, val_up_group, big_rect,
        in_label, out_label, linear_map_words, val_title), so calling it
        as written raises NameError. Kept verbatim for reference.
        """
        # Label the value matrix
        tiny_buff = 0.025
        value_rect = SurroundingRectangle(val_down_group, buff=tiny_buff)
        value_rect.stretch(1.2, 1)
        value_rect.round_corners(0.1)
        value_rect.set_stroke(RED, 3)
        value_arrow = Vector(DOWN)
        value_arrow.match_color(value_rect)
        value_arrow.next_to(value_rect, UP, SMALL_BUFF)

        val_up_group.save_state()
        out_rect = SurroundingRectangle(val_up_group, buff=tiny_buff)
        out_rect.set_height(big_rect.get_height() - SMALL_BUFF, stretch=True)
        out_rect.match_y(big_rect)
        out_rect.round_corners(0.1)
        out_rect.set_stroke(PINK, 3)
        out_arrow = Vector(0.5 * DOWN)
        out_arrow.next_to(out_rect, UP, SMALL_BUFF)
        out_arrow.match_color(out_rect)
        output_title = TexText("Output$^{*}$")
        output_title.match_color(out_rect)
        output_title.next_to(out_arrow, UP, SMALL_BUFF)

        self.play(LaggedStart(
            Restore(val_down_group),
            LaggedStartMap(FadeOut, VGroup(in_label, out_label)),
            TransformFromCopy(big_rect, value_rect),
            FadeOut(linear_map_words),
            val_title.animate.next_to(value_arrow, UP, SMALL_BUFF),
            FadeIn(value_arrow, shift=DOWN),
            val_up_group.animate.fade(0.5),
        ))
        self.wait()
        self.play(LaggedStart(
            TransformFromCopy(big_rect, out_rect),
            TransformFromCopy(value_arrow, out_arrow),
            FadeTransform(val_title.copy(), output_title),
            Restore(val_up_group),
        ))
        self.wait()

    def show_column_count(self, matrix, count, added_anims=()):
        """
        Animate a brace and counting number above the columns of `matrix`.

        Args:
            matrix: object providing get_columns().
            count: integer shown above the brace.
            added_anims: extra animations played alongside; the default is
                an immutable empty tuple rather than a shared list.

        Returns:
            VGroup(top_brace, count_mob).
        """
        cols = matrix.get_columns()
        # One highlight per column, all shaped like the first column
        col_rects = VGroup(SurroundingRectangle(cols[0], buff=0).match_x(col) for col in cols)
        col_rects.set_stroke(WHITE, 1, 0.5)
        col_rects.set_fill(GREY_D, 0.5)
        top_brace = Brace(col_rects, UP, buff=SMALL_BUFF)
        count_mob = Integer(count, font_size=self.count_font_size)
        count_mob.next_to(top_brace, UP)

        self.play(
            GrowFromCenter(top_brace),
            CountInFrom(count_mob, 0),
            FadeIn(col_rects, lag_ratio=0.25),
            *added_anims,
        )
        self.play(FadeOut(col_rects))
        return VGroup(top_brace, count_mob)

    def show_row_count(self, matrix, count, added_anims=()):
        """
        Animate a brace and counting number to the left of `matrix`.

        Args:
            matrix: object providing get_rows().
            count: integer shown left of the brace.
            added_anims: extra animations played alongside; the default is
                an immutable empty tuple rather than a shared list.

        Returns:
            VGroup(left_brace, count_mob).
        """
        rows = matrix.get_rows()
        # One highlight per row, all shaped like the first row
        row_rects = VGroup(SurroundingRectangle(rows[0], buff=0).match_y(row) for row in rows)
        row_rects.set_stroke(WHITE, 1, 0.5)
        row_rects.set_fill(GREY_D, 0.5)
        left_brace = Brace(matrix, LEFT, buff=SMALL_BUFF)
        count_mob = Integer(count, font_size=self.count_font_size)
        count_mob.next_to(left_brace, LEFT)

        self.play(
            GrowFromCenter(left_brace),
            CountInFrom(count_mob, 0),
            FadeIn(row_rects, lag_ratio=0.25),
            *added_anims,
        )
        self.play(FadeOut(row_rects))
        return VGroup(left_brace, count_mob)

    def show_product(self, col_count, row_count, added_anims=()):
        """
        Animate "rows x cols = product" above the column count.

        Args:
            col_count: group returned by show_column_count.
            row_count: group returned by show_row_count.
            added_anims: extra animations played with the first step; the
                default is an immutable empty tuple rather than a shared list.

        Returns:
            The equation VGroup; its last element is the product.
        """
        col_dec = col_count[1]
        row_dec = row_count[1]
        prod_dec = Integer(
            col_dec.get_value() * row_dec.get_value(),
            font_size=self.count_font_size
        )

        equation = VGroup(
            row_dec.copy(),
            Tex(R"\times", font_size=self.count_font_size),
            col_dec.copy(),
            Tex(R"=", font_size=self.count_font_size),
            prod_dec
        )
        equation.arrange(RIGHT, buff=SMALL_BUFF)
        # Top-align both factors with the product
        for index in [0, 2]:
            equation[index].align_to(equation[4], UP)
        equation.next_to(col_dec, UP, buff=1.0)

        self.play(
            TransformFromCopy(row_dec, equation[0]),
            FadeIn(equation[1]),
            TransformFromCopy(col_dec, equation[2]),
            FadeIn(equation[3]),
            *added_anims
        )
        self.play(
            FadeTransform(equation[0].copy(), equation[4]),
            FadeTransform(equation[2].copy(), equation[4]),
        )
        self.add(equation)
        return equation
|
|
|
|
|
|
class LowRankTransformation(InteractiveScene):
    """
    Scene with three sets of axes (3D, 2D, 3D) joined by arrows, a vector
    in each, and a "Low-rank transformation" title.
    """

    def construct(self):
        # Add three sets of axes
        self.frame.set_field_of_view(10 * DEGREES)

        all_axes = VGroup(
            self.get_3d_axes(),
            self.get_2d_axes(),
            self.get_3d_axes(),
        )
        all_axes.arrange(RIGHT, buff=2.0)
        all_axes.set_width(FRAME_WIDTH - 2)
        all_axes.move_to(0.5 * DOWN)

        dim_labels = VGroup(*(
            Text(text)
            for text in ("12,288 dims", "128 dims", "12,288 dims")
        ))
        dim_labels.scale(0.75)
        dim_labels.set_fill(GREY_A)
        for axes, label in zip(all_axes, dim_labels):
            label.next_to(axes, UP, buff=MED_LARGE_BUFF)

        # One arrow on each side of the middle axes
        map_arrows = Tex(R"\rightarrow", font_size=96).replicate(2)
        map_arrows.set_color(YELLOW)
        middle_axes = all_axes[1]
        for side, map_arrow in zip([LEFT, RIGHT], map_arrows):
            map_arrow.next_to(middle_axes, side, buff=0.5)

        axes_group = VGroup(all_axes, dim_labels)
        self.add(axes_group)
        self.add(map_arrows)

        # Add vectors
        vector_specs = [
            ((4, 2, 1), BLUE),
            ((2, 3), RED_B),
            ((-3, 3, -2), RED_C),
        ]
        vects = VGroup()
        for axes, (coords, color) in zip(all_axes, vector_specs):
            vects.add(
                Arrow(axes.get_origin(), axes.c2p(*coords), buff=0, stroke_color=color)
            )

        self.add(vects[0])
        for source, destination in zip(vects, vects[1:]):
            self.play(TransformFromCopy(source, destination))

        for axes, vect in zip(all_axes, vects):
            axes.add(vect)

        def spin_about_own_y_axis(m, dt):
            m.rotate(2 * dt * DEGREES, axis=m.y_axis.get_vector())

        # Only the first and last (3D) axes get the rotation updater
        for axes in all_axes[0::2]:
            axes.add_updater(spin_about_own_y_axis)
        self.wait(3)

        # Add title
        big_rect = SurroundingRectangle(axes_group, buff=0.5)
        big_rect.round_corners(radius=0.5)
        big_rect.set_stroke(RED_B, 2)
        title = Text("Low-rank transformation", font_size=72)
        title.next_to(big_rect, UP, buff=MED_LARGE_BUFF)

        self.play(
            ShowCreation(big_rect),
            FadeIn(title, shift=0.25 * UP)
        )
        self.wait(5)

    def get_3d_axes(self, height=3):
        """Return ThreeDAxes over (-4, 4) per axis, set to `height` and rotated twice."""
        axes = ThreeDAxes((-4, 4), (-4, 4), (-4, 4))
        axes.set_height(height)
        axes.rotate(20 * DEGREES, DOWN)
        axes.rotate(5 * DEGREES, RIGHT)
        return axes

    def get_2d_axes(self, height=2):
        """Return a NumberPlane over (-4, 4) x (-4, 4), set to `height`."""
        line_style = dict(
            stroke_color=GREY_B,
            stroke_width=1,
            stroke_opacity=0.5
        )
        result = NumberPlane(
            (-4, 4), (-4, 4),
            faded_line_ratio=0,
            background_line_style=line_style
        )
        result.set_height(height)
        return result
|
|
|
|
|
|
class ThinkAboutOverallMap(InteractiveScene):
    """Scene drawing a rounded rectangle with a two-line caption above it."""

    def construct(self):
        # Test
        outline = Rectangle(6.5, 2.75)
        outline.round_corners(radius=0.5)
        outline.set_stroke(RED_B, 2)

        caption = Text("Think about the\noverall map")
        caption.next_to(outline, UP, aligned_edge=LEFT)
        caption.shift(0.5 * RIGHT)

        draw_outline = ShowCreation(outline)
        reveal_caption = FadeIn(caption, UP)
        self.play(draw_outline, reveal_caption)
        self.wait()
|
|
|
|
|
|
class CrossAttention(InteractiveScene):
    """Show an attention pattern between an English and a French sentence.

    The scene first connects the two sentences with weighted lines, then
    lays the tokens out around a grid (English along the top, French down
    the left side) and fills the grid with query-key dot products and dots
    sized by the attention weights.
    """

    def construct(self):
        # Show both
        en_tokens = self.get_words("I do not want to pet it")
        fr_tokens = self.get_words("Je ne veux pas le caresser", hue_range=(0.2, 0.3))
        phrases = VGroup(en_tokens, fr_tokens)
        phrases.arrange(DOWN, buff=2.0)
        self.play(LaggedStartMap(FadeIn, en_tokens, scale=2, lag_ratio=0.25))
        self.wait()
        self.play(LaggedStartMap(FadeIn, fr_tokens, scale=2, lag_ratio=0.25))
        self.wait()

        # Create attention pattern
        # Each row holds the raw scores of one English token (indexed via
        # en_tokens[n] below); each entry corresponds to one French token.
        unnormalized_pattern = [
            [3, 0, 0, 0, 0, 0],
            [0, 1, 1.3, 1, 0, 0],
            [0, 3, 0, 3, 0, 0],
            [0, 0, 3, 0, 0, 0],
            [0, 0, 0, 0, 0, 3],
            [0, 0, 0, 0, 0, 3],
            [0, 0, 0, 0, 3, 0],
        ]
        # Softmax is applied per English token; the transpose puts French
        # tokens on rows, matching the row-major order of the grid below.
        english_weights = np.array([
            softmax(scores) for scores in unnormalized_pattern
        ])
        attention_pattern = english_weights.T

        # Show connections
        lines = VGroup()
        for n, row in enumerate(english_weights):
            for k, value in enumerate(row):
                line = Line(en_tokens[n].get_bottom(), fr_tokens[k].get_top(), buff=0)
                # Each token group is (rectangle, word); color comes from the rectangle
                line.set_stroke(
                    color=[
                        en_tokens[n][0].get_color(),
                        fr_tokens[k][0].get_color(),
                    ],
                    width=3,
                    opacity=value,
                )
                lines.add(line)

        self.play(ShowCreation(lines, lag_ratio=0.01, run_time=2))
        self.wait(2)
        self.play(FadeOut(lines))

        # Create grid
        grid = Square().get_grid(len(fr_tokens), len(en_tokens), buff=0)
        grid.stretch(1.2, 0)
        grid.set_stroke(GREY_B, 1)
        grid.set_height(5.0)
        grid.to_edge(DOWN, buff=SMALL_BUFF)
        grid.set_x(1)

        # Create qk symbols
        q_sym_generator = self.get_symbol_generator(R"\vec{\textbf{Q}}_0", color=YELLOW)
        k_sym_generator = self.get_symbol_generator(R"\vec{\textbf{K}}_0", color=TEAL)
        e_sym_generator = self.get_symbol_generator(R"\vec{\textbf{E}}_0", color=GREY_B)
        f_sym_generator = self.get_symbol_generator(R"\vec{\textbf{F}}_0", color=BLUE)

        q_syms = VGroup(q_sym_generator(n + 1) for n in range(len(en_tokens)))
        k_syms = VGroup(k_sym_generator(n + 1) for n in range(len(fr_tokens)))
        e_syms = VGroup(e_sym_generator(n + 1) for n in range(len(en_tokens)))
        f_syms = VGroup(f_sym_generator(n + 1) for n in range(len(fr_tokens)))
        VGroup(q_syms, k_syms, e_syms, f_syms).scale(0.65)

        # Queries and English embeddings go above the first squares of the grid
        for q_sym, e_sym, square in zip(q_syms, e_syms, grid):
            q_sym.next_to(square, UP, SMALL_BUFF)
            e_sym.next_to(q_sym, UP, buff=0.65)

        # Keys and French embeddings go left of every len(en_tokens)-th square
        for k_sym, f_sym, square in zip(k_syms, f_syms, grid[::len(en_tokens)]):
            k_sym.next_to(square, LEFT, SMALL_BUFF)
            f_sym.next_to(k_sym, LEFT, buff=0.75)

        q_arrows = VGroup(Arrow(*pair, buff=0.1) for pair in zip(e_syms, q_syms))
        k_arrows = VGroup(Arrow(*pair, buff=0.1) for pair in zip(f_syms, k_syms))
        e_arrows = VGroup(Vector(0.4 * DOWN).next_to(e_sym, UP, SMALL_BUFF) for e_sym in e_syms)
        f_arrows = VGroup(Vector(0.5 * RIGHT).next_to(f_sym, LEFT, SMALL_BUFF) for f_sym in f_syms)
        arrows = VGroup(q_arrows, k_arrows, e_arrows, f_arrows)
        arrows.set_color(GREY_B)

        wq_syms = VGroup(
            Tex("W_Q", font_size=20, fill_color=YELLOW).next_to(arrow, RIGHT, buff=0.1)
            for arrow in q_arrows
        )
        wk_syms = VGroup(
            Tex("W_K", font_size=20, fill_color=TEAL).next_to(arrow, UP, buff=0.1)
            for arrow in k_arrows
        )

        # Move tokens into place
        # generate_target() stores the copy on `.target` itself
        en_tokens.generate_target()
        fr_tokens.generate_target()
        for token, arrow in zip(en_tokens.target, e_arrows):
            token.next_to(arrow, UP, SMALL_BUFF)
        for token, arrow in zip(fr_tokens.target, f_arrows):
            token.next_to(arrow, LEFT, SMALL_BUFF)
        self.play(
            MoveToTarget(en_tokens),
            MoveToTarget(fr_tokens),
        )
        self.play(
            LaggedStartMap(GrowArrow, e_arrows),
            LaggedStartMap(GrowArrow, f_arrows),
            LaggedStartMap(FadeIn, e_syms, shift=0.25 * DOWN),
            LaggedStartMap(FadeIn, f_syms, shift=0.25 * RIGHT),
            lag_ratio=0.25,
            run_time=1.5,
        )
        self.play(
            LaggedStartMap(GrowArrow, q_arrows),
            LaggedStartMap(GrowArrow, k_arrows),
            LaggedStartMap(FadeIn, wq_syms, shift=0.25 * DOWN),
            LaggedStartMap(FadeIn, wk_syms, shift=0.25 * RIGHT),
            LaggedStartMap(FadeIn, q_syms, shift=0.5 * DOWN),
            LaggedStartMap(FadeIn, k_syms, shift=0.5 * RIGHT),
            lag_ratio=0.25,
            run_time=1.5,
        )
        self.play(FadeIn(grid, lag_ratio=1e-2), run_time=2)
        self.wait()

        # Show dot products
        dot_prods = VGroup()
        for q_sym in q_syms:
            for k_sym in k_syms:
                # Invisible anchor at the grid position shared by this q/k pair
                dot = Tex(".")
                dot.match_x(q_sym)
                dot.match_y(k_sym)
                # Copies start on top of the original symbols and move to the anchor
                dot_prod = VGroup(q_sym.copy(), dot, k_sym.copy())
                dot_prod.target = dot_prod.generate_target()
                dot_prod.target.arrange(RIGHT, buff=SMALL_BUFF)
                dot_prod.target.scale(0.7)
                dot_prod.target.move_to(dot)
                # The target was generated above, so only the start state is hidden
                dot.set_opacity(0)
                dot_prods.add(dot_prod)

        self.play(
            LaggedStartMap(MoveToTarget, dot_prods, lag_ratio=0.01),
            run_time=3
        )
        self.wait()

        # Show dots
        dots = VGroup()
        for square, value in zip(grid, attention_pattern.flatten()):
            dot = Dot(radius=value * 0.4)
            dot.set_fill(GREY_B, 1)
            dot.move_to(square)
            dots.add(dot)

        self.play(
            LaggedStartMap(GrowFromCenter, dots, lag_ratio=1e-2),
            dot_prods.animate.set_fill(opacity=0.2).set_anim_args(lag_ratio=1e-3),
            run_time=4
        )
        self.wait()

    def get_words(self, text, hue_range=(0.5, 0.6)):
        """Return a VGroup with one (rectangle, word) VGroup per word of `text`.

        Words come from `break_into_words` and rectangles from
        `get_piece_rectangles`, which receives `hue_range`.
        """
        sent = Text(text)
        tokens = break_into_words(sent)
        rects = get_piece_rectangles(tokens, hue_range=hue_range)
        return VGroup(VGroup(*pair) for pair in zip(rects, tokens))

    def get_symbol_generator(self, raw_tex, subsrc="0", color=WHITE):
        """Return a function mapping a number to a colored copy of `raw_tex`.

        `subsrc` is the substring of `raw_tex` that is made changeable; each
        call of the returned function sets it to the given number and returns
        a copy of the shared template.
        """
        template = Tex(raw_tex)
        template.set_color(color)
        subscr = template.make_number_changeable(subsrc)

        def get_sym(number):
            # The template is shared and mutated, so hand back a copy
            subscr.set_value(number)
            return template.copy()

        return get_sym
|
|
|
|
|
|
class CarCrashedExample(InteractiveScene):
    """Highlight one word of a sentence and animate context flowing into it."""

    def construct(self):
        # Add sentence
        sentence = Text("... when suddenly they crashed the car into a tree ...")
        pieces = break_into_words(sentence)
        boxes = get_piece_rectangles(pieces)
        word_groups = VGroup(
            VGroup(box, piece)
            for box, piece in zip(boxes, pieces)
        )

        # Target word, plus the glyphs of the three word groups just before it
        car = word_groups[6]
        context_glyphs = it.chain.from_iterable(
            group[1] for group in word_groups[3:6]
        )
        crashed = VGroup(*context_glyphs)
        arrow = Vector(UP)
        arrow.next_to(car, UP, SMALL_BUFF)

        self.play(LaggedStartMap(FadeIn, word_groups, shift=0.25 * UP, lag_ratio=0.25))
        # Dim everything outside the window of interest
        leading, trailing = word_groups[:3], word_groups[7:]
        self.play(
            leading.animate.fade(0.5),
            trailing.animate.fade(0.5),
            FadeIn(arrow),
        )
        self.wait()

        # Influence
        influence = ContextAnimation(car, crashed, direction=DOWN, run_time=5)
        self.play(influence)
        self.wait()
|
|
|
|
|
|
class TwoHarrysExample(InteractiveScene):
    """Contrast two keyword sentences that both end in "Harry"."""

    def construct(self):
        # Build both sentences
        keyword_sets = [
            ("wizard", "Hogwarts", "Hermione", "Harry"),
            ("Queen", "Sussex", "William", "Harry"),
        ]
        sentences = VGroup(*[
            break_into_words(Text("... " + " ... ".join(keywords)))
            for keywords in keyword_sets
        ])
        first, second = sentences
        sentences.arrange(DOWN, buff=2.0, aligned_edge=RIGHT)
        sentences.to_edge(LEFT)

        def play_context(sentence):
            # Every other piece from index 1 up to (not including) the last
            # is flattened into one group feeding the final piece.
            target_word = sentence[-1]
            sources = VGroup(*it.chain.from_iterable(sentence[1:-1:2]))
            animation = ContextAnimation(
                target_word,
                sources,
                direction=DOWN,
                path_arc=PI / 4,
                run_time=5,
                lag_ratio=0.025,
            )
            self.play(animation)

        self.add(first)
        play_context(first)
        self.wait()
        self.play(FadeTransformPieces(first.copy(), second))
        play_context(second)
|
|
|
|
|
|
class ManyTypesOfUpdates(InteractiveScene):
    """Cycle through several kinds of word-to-word association over a sentence.

    For each association type, arcs are drawn between pairs of words while
    the four weight matrices at the top have their entries re-randomized.
    """

    # When True, the text description of each association type fades in with
    # its arcs and fades out with them. Defaults to False, which matches the
    # previous rendering (the description animation used to be commented out).
    show_descriptions = False

    def construct(self):
        # Add matrices
        shapes = [(4, 8), (4, 8), (8, 4), (4, 8)]
        names = ["W_Q", "W_K", R"\uparrow W_V", R"\downarrow W_V"]
        colors = [YELLOW, TEAL, RED_B, RED_C]

        matrices = VGroup(
            WeightMatrix(shape=shape)
            for shape in shapes
        )
        buff_ratio = 0.35
        matrices.arrange(RIGHT, buff=matrices[0].get_width() * buff_ratio)
        # The last two matrices are placed closer together than the rest
        matrices[-1].next_to(matrices[-2], RIGHT, buff=matrices[-2].get_width() * 0.1)
        matrices.center()
        matrices.set_width(FRAME_WIDTH - 2)
        matrices.to_edge(UP, buff=1.0)
        titles = VGroup(
            Tex(name).set_color(color).match_x(mat).to_edge(UP, buff=MED_SMALL_BUFF)
            for name, color, mat in zip(names, colors, matrices)
        )
        # Tighten the gap between the leading arrow glyph and the next glyph
        for title in titles[2:]:
            title[0].next_to(title[1], LEFT, buff=0.5 * SMALL_BUFF)

        self.add(matrices, titles)

        # Add phrase
        phrase = Text("John hit the brakes sharply, they screeched loudly, and he jolted forward.")
        raw_words = break_into_words(phrase)
        rects = get_piece_rectangles(raw_words)
        rects.fade(0.5)
        words = VGroup(VGroup(*pair) for pair in zip(rects, raw_words))
        words.set_width(FRAME_WIDTH - 1)
        words.center().set_y(-2)

        self.add(words)

        # Set up association types
        # Each entry is (description, [(i, j, strength), ...]) where i and j
        # index into `words` and strength controls the arc's opacity.
        attention_types = [
            (
                "Adverb to verb",
                [
                    (1, 4, 1.0),
                    (6, 7, 1.0),
                ]
            ),
            (
                "Subject to verb",
                [
                    (0, 1, 1.0),
                    (3, 6, 0.5),
                    (5, 6, 0.5),
                    (0, 10, 0.5),
                    (9, 10, 0.5),
                ],
            ),
            (
                "Antecedent to pronoun",
                [
                    (0, 9, 1.0),
                    (3, 5, 1.0),
                ]
            ),
            (
                "Related to the subject",
                [
                    (0, 1, 0.25),
                    (0, 3, 0.25),
                    (0, 9, 0.2),
                    (0, 10, 0.2),
                    (0, 11, 0.2),
                ]
            ),
            (
                "Related to the object",
                [
                    (3, 4, 0.2),
                    (3, 5, 0.5),
                    (3, 6, 0.35),
                    (3, 7, 0.2),
                ]
            ),
        ]

        # Animate
        last_group = VGroup()
        for description, connection_specs in attention_types:
            connections = VGroup(
                Line(
                    words[i].get_top(),
                    words[j].get_top(),
                    path_arc=-PI / 2,
                    stroke_color=random_bright_color(
                        hue_range=(0.3, 0.5),
                        luminance_range=(0.5, 0.7),
                    ),
                    stroke_opacity=strength**0.5,
                )
                for (i, j, strength) in connection_specs
            )
            # Taper each arc to zero width at both ends
            connections.set_stroke(width=(0, 5, 5, 5, 0))
            connections.shuffle()

            animations = [FadeOut(last_group)]
            next_group = VGroup(connections)
            if self.show_descriptions:
                desc = Text(description)
                desc.center()
                animations.append(FadeIn(desc, shift=0.25 * UP))
                next_group = VGroup(desc, connections)
            animations.extend([
                ShowCreation(connections, lag_ratio=0.25, run_time=0.5 * len(connections)),
                LaggedStart(
                    (self.get_matrix_update_anim(mat)
                     for mat in matrices),
                    lag_ratio=0.15,
                ),
                LaggedStart(
                    (VShowPassingFlash(
                        line.copy().insert_n_curves(100).set_stroke(width=10),
                        time_width=2.0,
                        run_time=2,
                    )
                     for line in connections),
                    lag_ratio=0.1,
                ),
            ])
            self.play(*animations)
            self.wait(2)
            # Whatever was introduced this round fades out at the start of the next
            last_group = next_group

    def get_matrix_update_anim(self, matrix):
        """Return an animation that underlines each entry of `matrix` while randomizing it.

        The underlines are created and faded out with a small lag, alongside
        a `RandomizeMatrixEntries` animation on the same matrix.
        """
        rects = VGroup(
            Underline(entry, buff=0.05)
            for entry in matrix.get_entries()
        )
        rects.set_stroke(WHITE, 1)
        return AnimationGroup(
            LaggedStartMap(ShowCreationThenFadeOut, rects, lag_ratio=1e-2),
            RandomizeMatrixEntries(matrix)
        )
|
|
|
|
|
|
class MultiHeadedAttention(InteractiveScene):
    """Go from a single attention head to many parallel heads and sum their outputs.

    The scene proceeds in stages: retitle "Single head" to "Multi-headed",
    stack copies of attention-pattern images in depth, label each with its
    own weight-matrix symbols, flatten a few heads back into a 2D row, show
    the per-head value sums, and finally add the per-head changes to the
    original embedding.

    NOTE(review): images are loaded from a hard-coded absolute directory on
    the author's machine; the scene will not find them elsewhere.
    """

    def construct(self):
        # Mention head
        background_rect = FullScreenRectangle()
        single_title = Text("Single head of attention")
        multiple_title = Text("Multi-headed attention")
        titles = VGroup(single_title, multiple_title)
        for title in titles:
            title.scale(1.25)
            title.to_edge(UP)

        screen_rect = ScreenRectangle(height=6)
        screen_rect.set_fill(BLACK, 1)
        screen_rect.set_stroke(WHITE, 3)
        screen_rect.next_to(titles, DOWN, buff=0.5)

        # The word "head" inside the first title
        head = single_title["head"][0]

        self.add(background_rect)
        self.add(single_title)
        self.add(screen_rect)
        self.wait()
        self.play(
            FlashAround(head, run_time=2),
            head.animate.set_color(YELLOW),
        )
        self.wait()

        # Change title
        # Matching substrings are transformed into each other; "ed" fades in
        # and "of" fades out since they have no counterpart.
        kw = dict(path_arc=45 * DEGREES)
        self.play(
            FadeTransform(single_title["Single"], multiple_title["Multi-"], **kw),
            FadeTransform(single_title["head"], multiple_title["head"], **kw),
            FadeIn(multiple_title["ed"], 0.25 * RIGHT),
            FadeTransform(single_title["attention"], multiple_title["attention"], **kw),
            FadeOut(single_title["of"])
        )
        self.add(multiple_title)

        # Set up images
        n_heads = 15
        directory = "/Users/grant/3Blue1Brown Dropbox/3Blue1Brown/videos/2024/transformers/attention/images/"
        heads = Group()
        for n in range(n_heads):
            # Cycles through the images AttentionPattern1 ... AttentionPattern4
            im = ImageMobject(os.path.join(directory, f"AttentionPattern{n % 4 + 1}"))
            im.set_opacity(1)
            # Nudge the image slightly along OUT relative to its backing rectangle
            im.shift(0.01 * OUT)
            rect = SurroundingRectangle(im, buff=0)
            rect.set_fill(BLACK, 0.75)
            rect.set_stroke(WHITE, 1, 1)
            # Each head is Group(rect, image); later code indexes [0] and [1]
            heads.add(Group(rect, im))

        # Show many parallel layers
        self.set_floor_plane("xz")
        frame = self.frame
        multiple_title.fix_in_frame()
        background_rect.fix_in_frame()

        heads.set_height(4)
        heads.arrange(OUT, buff=1.0)
        heads.move_to(DOWN)
        pre_head = ImageMobject(os.path.join(directory, f"AttentionPattern0"))

        pre_head.replace(screen_rect)
        # From here on `pre_head` is the screen rectangle together with its image
        pre_head = Group(screen_rect, pre_head)

        self.add(pre_head)
        self.wait()
        self.play(
            frame.animate.reorient(41, -12, 0, (-1.0, -1.42, 1.09), 12.90).set_anim_args(run_time=2),
            background_rect.animate.set_fill(opacity=0.75),
            FadeTransform(pre_head, heads[-1], time_span=(1, 2)),
        )
        # Every head emerges from a copy of the last head in the stack
        self.play(
            frame.animate.reorient(48, -11, 0, (-1.0, -1.42, 1.09), 12.90),
            LaggedStart(
                (FadeTransform(heads[-1].copy(), image)
                 for image in heads),
                lag_ratio=0.1,
                group_type=Group,
            ),
            run_time=4,
        )
        self.add(heads)
        self.wait()

        # Show matrices
        colors = [YELLOW, TEAL, RED, PINK]
        texs = ["W_Q", "W_K", R"\downarrow W_V", R"\uparrow W_V"]
        n_shown = 9
        # One group of symbols per matrix type; within a group, one symbol per
        # head, taken from the end of `heads` backwards and numbered from 1.
        wq_syms, wk_syms, wv_down_syms, wv_up_syms = sym_groups = VGroup(
            VGroup(
                Tex(tex + f"^{{({n})}}", font_size=36).next_to(image, UP, MED_SMALL_BUFF)
                for n, image in enumerate(heads[:-n_shown - 1:-1], start=1)
            ).set_color(color).set_backstroke(BLACK, 5)
            for tex, color in zip(texs, colors)
        )
        # Pull the leading arrow glyph close to the glyph that follows it
        for group in wv_down_syms, wv_up_syms:
            for sym in group:
                sym[0].next_to(sym[1], LEFT, buff=0.025)
        dots = Tex(R"\dots", font_size=90)
        dots.rotate(PI / 2, UP)
        # Stored because the same rotation is undone when flattening later
        sym_rot_angle = 70 * DEGREES
        for syms in sym_groups:
            syms.align_to(heads, LEFT)
            for sym in syms:
                sym.rotate(sym_rot_angle, UP)
            # Append a styled copy of the ellipsis to the end of each group
            dots.next_to(syms, IN, buff=0.5)
            dots.match_style(syms[0])
            syms.add(dots.copy())

        up_shift = 0.75 * UP
        self.play(
            LaggedStartMap(FadeIn, wq_syms, shift=0.2 * UP, lag_ratio=0.25),
            frame.animate.reorient(59, -7, 0, (-1.62, 0.25, 1.29), 14.18),
            run_time=2,
        )
        # Each new group fades in while all previously shown groups shift up
        for n in range(1, len(sym_groups)):
            self.play(
                LaggedStartMap(FadeIn, sym_groups[n], shift=0.2 * UP, lag_ratio=0.1),
                sym_groups[:n].animate.shift(up_shift),
                run_time=1,
            )
        self.wait()

        # Count up 96 heads
        depth = heads.get_depth()
        brace = Brace(Line(LEFT, RIGHT).set_width(0.5 * depth), UP).scale(2)
        brace_label = brace.get_text("96", font_size=96, buff=MED_SMALL_BUFF)
        brace_group = VGroup(brace, brace_label)
        # Built flat, then rotated about UP to run along the depth of the stack
        brace_group.rotate(PI / 2, UP)
        brace_group.next_to(heads, UP, buff=MED_LARGE_BUFF)

        self.add(brace, brace_label, sym_groups)
        self.play(
            frame.animate.reorient(62, -6, 0, (-0.92, -0.08, -0.51), 14.18).set_anim_args(run_time=5),
            GrowFromCenter(brace),
            sym_groups.animate.set_fill(opacity=0.5).set_stroke(width=0),
            FadeIn(brace_label, 0.5 * UP, time_span=(0.5, 1.5)),
        )

        # Set up pure attention patterns, flattened
        for head in heads:
            n_rows = 8
            # A column of n_rows squares, repeated n_rows times side by side
            grid = Square().get_grid(n_rows, 1, buff=0).get_grid(1, n_rows, buff=0)
            grid.set_stroke(WHITE, 1, 0.5)
            grid.set_height(0.9 * head.get_height())
            grid.move_to(head)

            # Random scores; in column n, rows after n are set to -inf before
            # that column is passed through softmax.
            pattern = np.random.normal(0, 1, (n_rows, n_rows))
            for n in range(len(pattern[0])):
                pattern[:, n][n + 1:] = -np.inf
                pattern[:, n] = softmax(pattern[:, n])
            pattern = pattern.T

            dots = VGroup()
            for col, values in zip(grid, pattern):
                for square, value in zip(col, values):
                    # Skip weights too small to draw
                    if value < 1e-3:
                        continue
                    dot = Dot(radius=0.4 * square.get_height() * value)
                    dot.move_to(square)
                    dots.add(dot)
            dots.set_fill(GREY_B, 1)
            grid.add(dots)

            # The head becomes Group(rect, image, grid). The target is generated
            # while the grid is still visible; the grid on the current head is
            # hidden afterwards, and the target's image (index 1) is hidden.
            head.add(grid)
            head.target = head.generate_target()
            grid.set_opacity(0)
            head.target[1].set_opacity(0)
            head.target[0].set_opacity(1)

        n_shown = 4
        heads_target = Group(h.target for h in heads)
        heads_target.arrange(LEFT, buff=MED_LARGE_BUFF)
        heads_target.set_height(1.5)
        heads_target.to_edge(LEFT)
        heads_target.shift(2 * UP)
        # Only the last n_shown targets stay visible
        heads_target[:-n_shown].set_opacity(0)

        # Set up key/query targets
        for group in sym_groups:
            group.generate_target()
        group_targets = [group.target for group in sym_groups]

        # Targets are paired with heads starting from the last head
        for head, wq, wk, wv_down, wv_up in zip(heads_target[::-1], *group_targets):
            for sym in [wq, wk, wv_down, wv_up]:
                sym.set_fill(opacity=1)
                sym.set_height(0.35)
                # Undo the rotation applied when the symbols were created
                sym.rotate(-sym_rot_angle, UP)
            wk.next_to(head, UP, aligned_edge=LEFT)
            wq.next_to(wk, RIGHT, buff=0.35)
            wv_up.next_to(head, UP, aligned_edge=LEFT)
            wv_down.next_to(wv_up, RIGHT, buff=0.35)

        for group in group_targets:
            group[n_shown:].set_opacity(0)

        # Animate the flattening
        right_dots = Tex(R"\dots", font_size=96)
        right_dots.move_to(heads_target[-n_shown - 1], LEFT).shift(MED_SMALL_BUFF * RIGHT)

        brace_group.target = brace_group.generate_target()
        brace_group.target.shift(UP)
        brace_group.target.set_opacity(0)

        self.play(
            frame.animate.reorient(0, 0, 0, ORIGIN, FRAME_HEIGHT).set_anim_args(run_time=2),
            FadeOut(multiple_title, UP),
            MoveToTarget(brace_group, remover=True),
            MoveToTarget(wq_syms, time_span=(0.5, 2)),
            MoveToTarget(wk_syms, time_span=(0.5, 2)),
            FadeOut(wv_down_syms),
            FadeOut(wv_up_syms),
            LaggedStartMap(MoveToTarget, heads, lag_ratio=0.01),
            Write(right_dots, time_span=(1.5, 2.0)),
        )

        # Replace the heads on screen by (rect, grid) pairs for the heads at
        # the end of the stack, in reverse order.
        att_patterns = VGroup(
            VGroup(head[0], head[2])
            for head in heads[:len(heads) - n_shown - 1:-1]
        )
        self.remove(heads)
        self.add(att_patterns)

        # Show value maps
        # The value-map symbols were faded out above; snap them to the targets
        # prepared earlier so they can be faded back in below.
        for group in [wv_up_syms, wv_down_syms]:
            group.become(group.target)

        value_diagrams = VGroup()
        arrows = VGroup()
        all_v_stacks = VGroup()
        for pattern, wv_up, wv_down, idx in zip(att_patterns, wv_up_syms, wv_down_syms, it.count(1)):
            rect = pattern[0].copy()

            # One vertical weighted sum: w_1 v_1 + w_2 v_2 + w_3 v_3 + ...
            v_stack = VGroup(Tex(Rf"\vec{{\textbf{{v}}}}_{n}") for n in range(1, 4))
            v_stack.arrange(DOWN, buff=LARGE_BUFF)
            v_stack.set_color(RED)
            plusses = VGroup()
            coefs = VGroup()
            for n, v_term in enumerate(v_stack):
                coef = Tex(f"w_{n + 1}")
                coef.next_to(v_term, LEFT, SMALL_BUFF)
                coef.set_fill(GREY_B)
                plus = Tex("+")
                plus.next_to(VGroup(coef, v_term), DOWN)
                plusses.add(plus)
                coefs.add(coef)
            dots = Tex(R"\vdots")
            dots.next_to(plusses, DOWN)
            v_stack.add(coefs, plusses, dots)

            # Four copies of that sum side by side, sized to fit the rectangle
            v_stacks = v_stack.replicate(4)
            v_stacks.arrange(RIGHT, buff=LARGE_BUFF)
            v_stacks.set_height(rect.get_height() * 0.85)
            v_stacks.set_fill(border_width=1)

            # Initial content of the diagram: a row of value vectors tagged
            # with this head's index.
            v_terms = VGroup(
                *(Tex(Rf"\vec{{\textbf{{v}}}}_{n}^{{({idx})}}") for n in range(1, 4)),
                Tex(R"\dots")
            )
            v_terms[:3].set_color(RED)
            v_terms.arrange(RIGHT)
            v_terms.set_width(0.8 * rect.get_width())
            v_terms.move_to(rect)

            diagram = VGroup(rect, v_terms)
            diagram.to_edge(DOWN, buff=1.5)

            v_stacks.move_to(rect)
            all_v_stacks.add(v_stacks)

            VGroup(wv_up, wv_down).next_to(diagram, UP, buff=SMALL_BUFF, aligned_edge=LEFT)

            arrow = Arrow(pattern, diagram, buff=0.5)
            arrow.shift(0.25 * UP)

            value_diagrams.add(diagram)
            arrows.add(arrow)

        right_dots2 = right_dots.copy()

        self.play(
            LaggedStart(
                (FadeTransform(m1.copy(), m2)
                 for m1, m2 in zip(att_patterns, value_diagrams)),
                lag_ratio=0.25,
                group_type=Group,
            ),
            LaggedStartMap(FadeIn, wv_up_syms, shift=DOWN, lag_ratio=0.25),
            LaggedStartMap(FadeIn, wv_down_syms, shift=DOWN, lag_ratio=0.25),
            LaggedStartMap(GrowArrow, arrows, lag_ratio=0.25),
            right_dots2.animate.match_y(value_diagrams).set_anim_args(time_span=(1.0, 1.75)),
        )
        self.wait()

        # Swap the row of value vectors in each diagram for the weighted sums
        self.play(
            LaggedStart(
                (Transform(VGroup(diagram[1]), v_stacks)
                 for diagram, v_stacks in zip(value_diagrams, all_v_stacks)),
                lag_ratio=0.25,
                run_time=2
            )
        )
        # Rebuild the diagrams as (rect, stacks) so later code can index [1]
        self.remove(value_diagrams)
        new_diagrams = VGroup(
            VGroup(vd[0], stacks)
            for vd, stacks in zip(value_diagrams, all_v_stacks)
        )
        value_diagrams = new_diagrams
        self.add(value_diagrams)

        # Show sums
        # Index of the stack (within each diagram) that gets highlighted
        index = 2
        rects = VGroup()
        delta_Es = VGroup()
        arrows = VGroup()
        for n, diagram in enumerate(value_diagrams, start=1):
            diagram.target = diagram.generate_target()
            stacks = diagram.target[1]
            # Dim every stack, then restore the highlighted one
            stacks.set_opacity(0.5)
            stacks[index].set_opacity(1, border_width=1)
            rect = SurroundingRectangle(stacks[index], buff=0.05)

            arrow = Vector(0.5 * DOWN)
            arrow.set_color(BLUE)
            arrow.next_to(rect, DOWN, SMALL_BUFF)

            delta_E = Tex(Rf"\Delta \vec{{\textbf{{E}}}}^{{({n})}}_i", font_size=36)
            delta_E.set_color(BLUE)
            delta_E.next_to(arrow, DOWN, SMALL_BUFF)

            rects.add(rect)
            arrows.add(arrow)
            delta_Es.add(delta_E)

        rects.set_stroke(BLUE, 2)

        self.play(
            LaggedStartMap(MoveToTarget, value_diagrams),
            LaggedStartMap(ShowCreation, rects),
            LaggedStartMap(GrowArrow, arrows),
            LaggedStartMap(FadeIn, delta_Es, shift=0.5 * DOWN),
        )
        self.wait()

        # Add together all changes
        low_delta_Es = delta_Es.copy()
        low_delta_Es.scale(1.5)
        low_delta_Es.arrange(RIGHT, buff=0.75)
        low_delta_Es.next_to(delta_Es, DOWN, buff=1.0)
        # A plus sign follows every term, including the last, which leads into the dots
        plusses = VGroup(
            Tex("+", font_size=72).next_to(ldE, buff=0.1).shift(0.1 * DOWN)
            for ldE in low_delta_Es
        )
        dots = Tex(R"\dots", font_size=72).next_to(plusses, RIGHT)

        self.play(
            TransformFromCopy(delta_Es, low_delta_Es),
            Write(plusses),
            Write(dots),
            frame.animate.reorient(0, 0, 0, (-0.99, -1.51, 0.0), 10.71),
        )
        self.wait()

        # Include original embedding
        og_emb = Tex(R"\vec{\textbf{E}}_i", font_size=72)
        og_emb_plus = Tex("+", font_size=72)
        og_emb_plus.next_to(low_delta_Es, LEFT, SMALL_BUFF)
        og_emb.next_to(og_emb_plus, LEFT, 2 * SMALL_BUFF)
        lil_rect = SurroundingRectangle(og_emb)
        big_rect = SurroundingRectangle(VGroup(og_emb, low_delta_Es, dots), buff=0.25)
        lil_rect.set_stroke(WHITE, 2)
        big_rect.set_stroke(TEAL, 3)
        og_label = Text("Original\nembedding")
        new_label = Text("New\nembedding")
        new_label.set_color(TEAL)
        # Both labels share one position; one is later transformed into the other
        for label in [og_label, new_label]:
            label.next_to(lil_rect, LEFT, buff=MED_LARGE_BUFF)

        self.play(
            FadeIn(og_emb, shift=RIGHT, scale=0.5),
            Write(og_emb_plus),
            FadeIn(og_label, shift=RIGHT),
        )
        self.play(ShowCreation(lil_rect))
        self.wait()
        # Grow the box from the original embedding to the whole sum
        self.play(
            ReplacementTransform(lil_rect, big_rect),
            FadeTransform(og_label, new_label)
        )
        self.wait()
|
|
|
|
|
|
class OutputMatrix(InteractiveScene):
    """Show how the per-head "up" value matrices are stapled into one output matrix.

    Three factored value maps (an up matrix followed by a down matrix) are
    introduced one at a time. Copies of the up matrices are then joined side
    by side inside a single pair of brackets and labeled "Output matrix",
    after which the original up matrices fade out.
    """

    def construct(self):
        # Set up all heads
        matrix_pairs = VGroup(self.get_factored_value_map() for _ in range(3))
        matrix_pairs.arrange(RIGHT, buff=LARGE_BUFF)
        matrix_pairs.to_edge(LEFT)
        matrix_pairs.set_y(1)
        dots = Tex(R"\dots", font_size=120)
        dots.next_to(matrix_pairs, RIGHT, LARGE_BUFF)

        rects = VGroup(SurroundingRectangle(pair, buff=0.25) for pair in matrix_pairs)
        rects.set_stroke(RED, 2)
        labels = VGroup()
        for n, rect in enumerate(rects, start=1):
            # Stretch downward (top edge fixed) to leave room for the matrix labels
            rect.set_height(2.5, stretch=True, about_edge=UP)
            rect.round_corners(radius=0.1)
            label = Text(f"Head {n}\nValue map", font_size=36)
            label.next_to(rect, UP)
            labels.add(label)

        up_labels = VGroup()
        down_labels = VGroup()
        for n, pair in enumerate(matrix_pairs, start=1):
            # Each pair is (up matrix, down matrix); see get_factored_value_map
            down_label = TexText(Rf"Value$^{{({n})}}_{{\downarrow}}$", font_size=30)
            up_label = TexText(Rf"Value$^{{({n})}}_{{\uparrow}}$", font_size=30)
            mat_label_specs = zip([up_label, down_label], pair, [ORIGIN, 0.25 * RIGHT])
            for mat_label, mat, offset in mat_label_specs:
                mat_label.next_to(pair, DOWN, buff=0.5)
                # Enlarge the final glyph before the pointer arrow is appended
                mat_label[-1].scale(1.5, about_edge=UL)
                mat_label.match_x(mat)
                mat_label.shift(offset)
                arrow = FillArrow(mat_label[2], mat, thickness=0.025)
                arrow.scale(0.6)
                # The arrow becomes the label's last submobject
                mat_label.add(arrow)

            up_labels.add(up_label)
            down_labels.add(down_label)

        up_labels.set_fill(RED_B)
        down_labels.set_fill(RED_C)

        # Animate
        for pair, rect, label, up_label, down_label in zip(matrix_pairs, rects, labels, up_labels, down_labels):
            mat_labels = VGroup(up_label, down_label)
            self.play(
                FadeIn(label, 0.25 * UP),
                LaggedStartMap(FadeIn, pair, scale=1.25, lag_ratio=0.5),
                LaggedStartMap(FadeIn, mat_labels, lag_ratio=0.5),
                ShowCreation(rect),
            )
        self.play(Write(dots))
        self.wait()

        # Aggregate into the output matrix
        up_matrices = VGroup(pair[0] for pair in matrix_pairs)
        stapled_up_matrices = up_matrices.copy()
        for mat in stapled_up_matrices:
            # Collapse and hide the last two submobjects of each copy so the
            # copies can sit flush against one another.
            brackets = mat[-2:]
            brackets[0].stretch(0, 0, about_edge=RIGHT)
            brackets[1].stretch(0, 0, about_edge=LEFT)
            brackets.set_opacity(0)
        stapled_up_matrices.arrange(RIGHT, buff=SMALL_BUFF)
        stapled_up_matrices.scale(2)
        stapled_up_matrices.next_to(rects, DOWN, buff=1.5)

        up_labels.target = up_labels.generate_target()
        lines = VGroup()
        for stum, up_label in zip(stapled_up_matrices, up_labels.target):
            # Thin divider to the right of each stapled block
            line = Line(UP, DOWN).match_height(stum)
            line.set_stroke(WHITE, 1)
            line.next_to(stum, RIGHT, buff=SMALL_BUFF / 2)
            lines.add(line)
            # Hide and collapse the pointer arrow appended to the label earlier
            up_label[-1].set_opacity(0)
            up_label[-1].scale(0, about_edge=DOWN)
            up_label.scale(0.75)
            up_label.next_to(stum, UP, buff=SMALL_BUFF)

        out_dots = dots.copy()
        out_dots.scale(0.5)
        out_dots.next_to(lines, RIGHT)
        out_brackets = up_matrices[0].get_brackets().copy()
        out_brackets.match_height(stapled_up_matrices)
        out_brackets[0].next_to(stapled_up_matrices, LEFT, SMALL_BUFF)
        out_brackets[1].next_to(out_dots, RIGHT, SMALL_BUFF)

        out_matrix = VGroup(stapled_up_matrices, lines, out_dots, out_brackets)

        self.play(
            self.frame.animate.reorient(0, 0, 0, (-0.88, -0.87, 0.0), 8.00),
            up_matrices.animate.set_opacity(0.5),
            TransformFromCopy(up_matrices, stapled_up_matrices, lag_ratio=1e-4),
            MoveToTarget(up_labels),
            TransformFromCopy(dots, out_dots),
            FadeIn(lines, lag_ratio=0.5),
            FadeIn(out_brackets, scale=1.25),
            run_time=2
        )
        self.wait()

        # Circle and label output
        out_rect = SurroundingRectangle(VGroup(out_matrix, up_labels), buff=MED_SMALL_BUFF)
        out_rect.round_corners(radius=0.1)
        out_rect.set_stroke(PINK, 3)
        out_label = Text("Output\nmatrix")
        out_label.set_color(PINK)
        out_label.next_to(out_rect, LEFT)

        self.play(
            ShowCreation(out_rect),
            FadeIn(out_label, shift=0.25 * LEFT, scale=1.25),
        )
        self.wait()

        # Center the down matrices
        self.play(
            LaggedStart(
                (pair[1].animate.shift(0.5 * LEFT)
                 for pair in matrix_pairs),
                lag_ratio=0.05,
            ),
            LaggedStart(
                (label.animate.shift(0.5 * LEFT)
                 for label in down_labels),
                lag_ratio=0.05,
            ),
            LaggedStartMap(FadeOut, up_matrices)
        )
        self.wait()

    def get_factored_value_map(self, big_d=7, lil_d=4, height=1.0):
        """Return a VGroup of two weight matrices arranged left to right.

        The first has shape (big_d, lil_d) and the second (lil_d, big_d).
        The pair is scaled to `height`.
        """
        matrices = VGroup(
            WeightMatrix(shape=(big_d, lil_d)),
            WeightMatrix(shape=(lil_d, big_d)),
        )
        matrices.arrange(RIGHT, buff=matrices[0].get_width() * 0.1)
        matrices.set_height(height)
        return matrices
|
|
|
|
|
|
class Parallelizability(InteractiveScene):
    """Flash many curves from "Input" through a column of "+ x" symbols to "output"."""

    def construct(self):
        # Set up curves
        n_instances = 20
        comp_syms = Tex(R"+\,\times").replicate(n_instances)
        comp_syms.arrange(DOWN)
        comp_syms.set_height(5.5)
        comp_syms.to_edge(DOWN)

        # Shared start and end points on either side of the symbol column
        left_point = comp_syms.get_left() + 2 * LEFT
        right_point = comp_syms.get_right() + 2 * RIGHT

        def build_curve(sym):
            # Bezier in from the shared start, straight across the symbol,
            # then bezier out to the shared end.
            entry = sym.get_left()
            leave = sym.get_right()
            path = VMobject()
            path.start_new_path(left_point)
            path.add_cubic_bezier_curve_to(left_point + RIGHT, entry + LEFT, entry)
            path.add_line_to(leave)
            path.add_cubic_bezier_curve_to(leave + RIGHT, right_point + LEFT, right_point)
            path.insert_n_curves(10)
            return path

        curves = VGroup(*[build_curve(sym) for sym in comp_syms])
        curves.set_stroke(width=(0, 2, 2, 2, 0))
        curves.set_submobject_colors_by_gradient(TEAL, BLUE)

        # Setup words
        in_word = Text("Input")
        in_word.next_to(left_point, LEFT, SMALL_BUFF)
        out_word = Text("output")
        out_word.next_to(right_point, RIGHT, SMALL_BUFF)
        self.add(comp_syms, in_word, out_word)

        # GPU symbol
        gpu = SVGMobject("gpu_large.svg")
        gpu.set_fill(GREY_B)
        gpu.set_width(1.5)
        gpu.next_to(comp_syms, UP)
        gpu_name = Text("GPU")
        gpu_name.next_to(gpu, UP)
        gpu_name.set_fill(GREY_B)
        self.add(gpu, gpu_name)

        # Animation
        n_passes = 4
        for _ in range(n_passes):
            # Reshuffle so the flashes fire in a new order on each pass
            curves.shuffle()
            flashes = LaggedStartMap(
                ShowPassingFlash,
                curves,
                lag_ratio=5e-3,
                time_width=1.5,
                run_time=4,
            )
            self.play(flashes)
|