mirror of
https://github.com/3b1b/videos.git
synced 2025-09-18 21:38:53 +00:00
1324 lines
48 KiB
Python
1324 lines
48 KiB
Python
from sqlalchemy.sql.base import _DialectArgDict
|
|
from manim_imports_ext import *
|
|
from _2024.transformers.helpers import *
|
|
|
|
|
|
class AttentionPatterns(InteractiveScene):
|
|
def construct(self):
|
|
# Add sentence
|
|
phrase = " a fluffy blue creature roamed the verdant forest"
|
|
phrase_mob = Text(phrase)
|
|
phrase_mob.move_to(2 * UP)
|
|
words = list(filter(lambda s: s.strip(), phrase.split(" ")))
|
|
word2mob: Dict[str, VMobject] = {
|
|
word: phrase_mob[" " + word][0]
|
|
for word in words
|
|
}
|
|
word_mobs = VGroup(*word2mob.values())
|
|
|
|
self.play(
|
|
LaggedStartMap(FadeIn, word_mobs, shift=0.5 * UP, lag_ratio=0.25)
|
|
)
|
|
self.wait()
|
|
|
|
# Create word rects
|
|
word2rect: Dict[str, VMobject] = dict()
|
|
for word in words:
|
|
rect = SurroundingRectangle(word2mob[word])
|
|
rect.set_height(phrase_mob.get_height() + SMALL_BUFF, stretch=True)
|
|
rect.match_y(phrase_mob)
|
|
rect.set_stroke(GREY, 2)
|
|
rect.set_fill(GREY, 0.2)
|
|
word2rect[word] = rect
|
|
|
|
# Adjectives updating noun
|
|
adjs = ["fluffy", "blue", "verdant"]
|
|
nouns = ["creature", "forest"]
|
|
others = ["a", "roamed", "the"]
|
|
adj_mobs, noun_mobs, other_mobs = [
|
|
VGroup(word2mob[substr] for substr in group)
|
|
for group in [adjs, nouns, others]
|
|
]
|
|
adj_rects, noun_rects, other_rects = [
|
|
VGroup(word2rect[substr] for substr in group)
|
|
for group in [adjs, nouns, others]
|
|
]
|
|
adj_rects.set_submobject_colors_by_gradient(BLUE_C, BLUE_D, GREEN)
|
|
noun_rects.set_color(GREY_BROWN).set_stroke(width=3)
|
|
kw = dict()
|
|
adj_arrows = VGroup(
|
|
Arrow(
|
|
adj_mobs[i].get_top(), noun_mobs[j].get_top(),
|
|
path_arc=-150 * DEGREES, buff=0.1, stroke_color=GREY_B
|
|
)
|
|
for i, j in [(0, 0), (1, 0), (2, 1)]
|
|
)
|
|
|
|
self.play(
|
|
LaggedStartMap(DrawBorderThenFill, adj_rects),
|
|
Animation(adj_mobs),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
LaggedStartMap(DrawBorderThenFill, noun_rects),
|
|
Animation(noun_mobs),
|
|
LaggedStartMap(ShowCreation, adj_arrows, lag_ratio=0.2, run_time=1.5),
|
|
)
|
|
kw = dict(time_width=2, max_stroke_width=10, lag_ratio=0.2, path_arc=150 * DEGREES)
|
|
self.play(
|
|
ContextAnimation(noun_mobs[0], adj_mobs[:2], strengths=[1, 1], **kw),
|
|
ContextAnimation(noun_mobs[1], adj_mobs[2:], strengths=[1], **kw),
|
|
)
|
|
self.wait()
|
|
|
|
# Show embeddings
|
|
all_rects = VGroup(*adj_rects, *noun_rects, *other_rects)
|
|
all_rects.sort(lambda p: p[0])
|
|
embeddings = VGroup(
|
|
NumericEmbedding(length=10).set_width(0.5).next_to(rect, DOWN, buff=1.5)
|
|
for rect in all_rects
|
|
)
|
|
emb_arrows = VGroup(
|
|
Arrow(all_rects[0].get_bottom(), embeddings[0].get_top()).match_x(rect)
|
|
for rect in all_rects
|
|
)
|
|
for index, vect in [(5, LEFT), (6, RIGHT)]:
|
|
embeddings[index].shift(0.1 * vect)
|
|
emb_arrows[index].shift(0.05 * vect)
|
|
|
|
self.play(
|
|
FadeIn(other_rects),
|
|
Animation(word_mobs),
|
|
LaggedStartMap(GrowArrow, emb_arrows),
|
|
LaggedStartMap(FadeIn, embeddings, shift=0.5 * DOWN),
|
|
FadeOut(adj_arrows)
|
|
)
|
|
self.wait()
|
|
|
|
# Mention dimension of embedding
|
|
frame = self.frame
|
|
brace = Brace(embeddings[0], LEFT, buff=SMALL_BUFF)
|
|
dim_value = Integer(12288)
|
|
dim_value.next_to(brace, LEFT)
|
|
dim_value.set_color(YELLOW)
|
|
|
|
self.play(
|
|
GrowFromCenter(brace),
|
|
CountInFrom(dim_value, 0),
|
|
frame.animate.move_to(LEFT)
|
|
)
|
|
self.wait()
|
|
|
|
# Ingest meaning and and position
|
|
images = Group(
|
|
ImageMobject(f"Dalle3_{word}").set_height(1.1).next_to(word2rect[word], UP)
|
|
for word in ["fluffy", "blue", "creature", "verdant", "forest"]
|
|
)
|
|
image_vects = VGroup(embeddings[i] for i in [1, 2, 3, 6, 7])
|
|
|
|
self.play(
|
|
LaggedStartMap(FadeIn, images, scale=2, lag_ratio=0.05)
|
|
)
|
|
self.play(
|
|
LaggedStart(
|
|
(self.bake_mobject_into_vector_entries(image, vect, group_type=Group)
|
|
for image, vect in zip(images, image_vects)),
|
|
group_type=Group,
|
|
lag_ratio=0.2,
|
|
run_time=4,
|
|
remover=True
|
|
),
|
|
)
|
|
self.wait()
|
|
self.add(embeddings, images)
|
|
|
|
# Show positions
|
|
pos_labels = VGroup(
|
|
Integer(n, font_size=36).next_to(rect, DOWN, buff=0.1)
|
|
for n, rect in enumerate(all_rects, start=1)
|
|
)
|
|
pos_labels.set_color(TEAL)
|
|
|
|
self.play(
|
|
LaggedStart(
|
|
(arrow.animate.scale(0.7, about_edge=DOWN)
|
|
for arrow in emb_arrows),
|
|
lag_ratio=0.1,
|
|
),
|
|
LaggedStartMap(FadeIn, pos_labels, shift=0.25 * DOWN, lag_ratio=0.1)
|
|
)
|
|
self.play(
|
|
LaggedStart(
|
|
(self.bake_mobject_into_vector_entries(pos, vect)
|
|
for pos, vect in zip(pos_labels, embeddings)),
|
|
lag_ratio=0.2,
|
|
run_time=4,
|
|
remover=True
|
|
),
|
|
)
|
|
self.wait()
|
|
|
|
# Collapse vectors
|
|
template = Tex(R"\vec{\textbf{E}}_{0}")
|
|
template[0].scale(1.5, about_edge=DOWN)
|
|
dec = template.make_number_changeable(0)
|
|
emb_syms = VGroup()
|
|
for n, rect in enumerate(all_rects, start=1):
|
|
dec.set_value(n)
|
|
sym = template.copy()
|
|
sym.next_to(rect, DOWN, buff=0.75)
|
|
sym.set_color(GREY_A)
|
|
emb_syms.add(sym)
|
|
for subgroup in [emb_syms[:4], emb_syms[4:]]:
|
|
subgroup.arrange_to_fit_width(subgroup.get_width())
|
|
|
|
emb_arrows.target = emb_arrows.generate_target()
|
|
|
|
for rect, arrow, sym in zip(all_rects, emb_arrows.target, emb_syms):
|
|
x_min = rect.get_x(LEFT)
|
|
x_max = rect.get_x(RIGHT)
|
|
low_point = sym[0].get_top()
|
|
if x_min < low_point[0] < x_max:
|
|
top_point = np.array([low_point[0], rect.get_y(DOWN), 0])
|
|
else:
|
|
top_point = rect.get_bottom()
|
|
arrow.become(Arrow(top_point, low_point, buff=SMALL_BUFF))
|
|
|
|
all_brackets = VGroup(emb.get_brackets() for emb in embeddings)
|
|
for brackets in all_brackets:
|
|
brackets.target = brackets.generate_target()
|
|
brackets.target.stretch(0, 1, about_edge=UP)
|
|
brackets.target.set_fill(opacity=0)
|
|
|
|
ghost_syms = emb_syms.copy()
|
|
ghost_syms.set_opacity(0)
|
|
|
|
self.play(
|
|
frame.animate.set_x(0).set_anim_args(run_time=2),
|
|
LaggedStart(
|
|
(AnimationGroup(
|
|
LaggedStart(
|
|
(FadeTransform(entry, sym)
|
|
for entry in embedding.get_columns()[0]),
|
|
lag_ratio=0.01,
|
|
group_type=Group
|
|
),
|
|
MoveToTarget(brackets),
|
|
group_type=Group,
|
|
)
|
|
for sym, embedding, brackets in zip(ghost_syms, embeddings, all_brackets)),
|
|
group_type=Group
|
|
),
|
|
LaggedStartMap(FadeIn, emb_syms, shift=UP),
|
|
brace.animate.stretch(0.25, 1, about_edge=UP).set_opacity(0),
|
|
FadeOut(dim_value, 0.25 * UP),
|
|
MoveToTarget(emb_arrows, lag_ratio=0.1, run_time=2),
|
|
LaggedStartMap(FadeOut, pos_labels, shift=UP),
|
|
)
|
|
emb_arrows.refresh_bounding_box(recurse_down=True) # Why?
|
|
self.clear()
|
|
self.add(emb_arrows, all_rects, word_mobs, images, emb_syms)
|
|
self.wait()
|
|
|
|
# Preview desired updates
|
|
emb_sym_primes = VGroup(
|
|
sym.copy().add(Tex("'").move_to(sym.get_corner(UR) + 0.05 * DL))
|
|
for sym in emb_syms
|
|
)
|
|
emb_sym_primes.shift(2 * DOWN)
|
|
emb_sym_primes.set_color(TEAL)
|
|
|
|
full_connections = VGroup()
|
|
for i, sym1 in enumerate(emb_syms, start=1):
|
|
for j, sym2 in enumerate(emb_sym_primes, start=1):
|
|
line = Line(sym1.get_bottom(), sym2.get_top(), buff=SMALL_BUFF)
|
|
line.set_stroke(GREY_B, width=random.random()**2, opacity=random.random()**0.25)
|
|
if (i, j) in [(2, 4), (3, 4), (4, 4), (7, 8), (8, 8)]:
|
|
line.set_stroke(WHITE, width=2 + random.random(), opacity=1)
|
|
full_connections.add(line)
|
|
|
|
blue_fluff = ImageMobject("BlueFluff")
|
|
verdant_forest = ImageMobject("VerdantForest")
|
|
for n, image in [(3, blue_fluff), (7, verdant_forest)]:
|
|
image.match_height(images)
|
|
image.scale(1.2)
|
|
image.next_to(emb_sym_primes[n], DOWN, buff=MED_SMALL_BUFF)
|
|
|
|
self.play(
|
|
ShowCreation(full_connections, lag_ratio=0.01, run_time=2),
|
|
LaggedStart(
|
|
(TransformFromCopy(sym1, sym2)
|
|
for sym1, sym2 in zip(emb_syms, emb_sym_primes)),
|
|
lag_ratio=0.05,
|
|
),
|
|
)
|
|
self.wait()
|
|
self.play(LaggedStart(
|
|
LaggedStart(
|
|
(FadeTransform(im.copy(), blue_fluff, remover=True)
|
|
for im in images[:3]),
|
|
lag_ratio=0.02,
|
|
group_type=Group
|
|
),
|
|
LaggedStart(
|
|
(FadeTransform(im.copy(), verdant_forest, remover=True)
|
|
for im in images[3:]),
|
|
lag_ratio=0.02,
|
|
group_type=Group
|
|
),
|
|
lag_ratio=0.5,
|
|
run_time=2
|
|
))
|
|
self.add(blue_fluff, verdant_forest)
|
|
self.wait()
|
|
|
|
# Show black box that matrix multiples can be added to
|
|
in_arrows = VGroup(
|
|
Vector(0.25 * DOWN, max_width_to_length_ratio=12.0).next_to(sym, DOWN, SMALL_BUFF)
|
|
for sym in emb_syms
|
|
)
|
|
box = Rectangle(15.0, 3.0)
|
|
box.set_fill(GREY_E, 1)
|
|
box.set_stroke(WHITE, 1)
|
|
box.next_to(in_arrows, DOWN, SMALL_BUFF)
|
|
out_arrows = in_arrows.copy()
|
|
out_arrows.next_to(box, DOWN)
|
|
|
|
self.play(
|
|
FadeIn(box, 0.25 * DOWN),
|
|
LaggedStartMap(FadeIn, in_arrows, shift=0.25 * DOWN, lag_ratio=0.025),
|
|
LaggedStartMap(FadeIn, out_arrows, shift=0.25 * DOWN, lag_ratio=0.025),
|
|
FadeOut(full_connections),
|
|
emb_sym_primes.animate.next_to(out_arrows, DOWN, SMALL_BUFF),
|
|
MaintainPositionRelativeTo(blue_fluff, emb_sym_primes),
|
|
MaintainPositionRelativeTo(verdant_forest, emb_sym_primes),
|
|
frame.animate.set_height(10).move_to(4 * UP, UP),
|
|
)
|
|
self.wait()
|
|
|
|
# Clear the board
|
|
self.play(
|
|
frame.animate.set_height(8).move_to(2 * UP).set_anim_args(run_time=1.5),
|
|
LaggedStartMap(FadeOut, Group(
|
|
*images, in_arrows, box, out_arrows, emb_sym_primes,
|
|
blue_fluff, verdant_forest,
|
|
), lag_ratio=0.1)
|
|
)
|
|
|
|
# Ask questions
|
|
word_groups = VGroup(VGroup(*pair) for pair in zip(all_rects, word_mobs))
|
|
for group in word_groups:
|
|
group.save_state()
|
|
q_bubble = SpeechBubble("Any adjectives\nin front of me?")
|
|
q_bubble.move_tip_to(word2rect["creature"].get_top())
|
|
|
|
a_bubbles = SpeechBubble("I am!", direction=RIGHT).replicate(2)
|
|
a_bubbles[1].flip()
|
|
a_bubbles[0].move_tip_to(word2rect["fluffy"].get_top())
|
|
a_bubbles[1].move_tip_to(word2rect["blue"].get_top())
|
|
|
|
self.play(
|
|
FadeIn(q_bubble),
|
|
word_groups[:3].animate.fade(0.75),
|
|
word_groups[4:].animate.fade(0.75),
|
|
)
|
|
self.wait()
|
|
self.play(LaggedStart(
|
|
Restore(word_groups[1]),
|
|
Restore(word_groups[2]),
|
|
*map(Write, a_bubbles),
|
|
lag_ratio=0.5
|
|
))
|
|
self.wait()
|
|
|
|
# Associate questions with vectors
|
|
a_bubbles.save_state()
|
|
q_arrows = VGroup(
|
|
Vector(0.75 * DOWN).next_to(sym, DOWN, SMALL_BUFF)
|
|
for sym in emb_syms
|
|
)
|
|
q_vects = VGroup(
|
|
NumericEmbedding(length=7).set_height(2).next_to(arrow, DOWN)
|
|
for arrow in q_arrows
|
|
)
|
|
question = q_bubble.content
|
|
|
|
|
|
index = words.index("creature")
|
|
q_vect = q_vects[index]
|
|
q_arrow = q_arrows[index]
|
|
self.play(LaggedStart(
|
|
FadeOut(q_bubble.body, DOWN),
|
|
question.animate.scale(0.75).next_to(q_vect, RIGHT),
|
|
FadeIn(q_vect, DOWN),
|
|
GrowArrow(q_arrow),
|
|
frame.animate.move_to(ORIGIN),
|
|
a_bubbles.animate.fade(0.5),
|
|
))
|
|
self.play(
|
|
self.bake_mobject_into_vector_entries(question, q_vect)
|
|
)
|
|
self.wait()
|
|
|
|
# Label query vector
|
|
brace = Brace(q_vect, LEFT, SMALL_BUFF)
|
|
query_word = Text("Query")
|
|
query_word.set_color(YELLOW)
|
|
query_word.next_to(brace, LEFT, SMALL_BUFF)
|
|
dim_text = Text("128-dimensional", font_size=36)
|
|
dim_text.set_color(YELLOW)
|
|
dim_text.next_to(brace, LEFT, SMALL_BUFF)
|
|
dim_text.set_y(query_word.get_y(DOWN))
|
|
|
|
self.play(
|
|
GrowFromCenter(brace),
|
|
FadeIn(query_word, 0.25 * LEFT),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
query_word.animate.next_to(dim_text, UP, SMALL_BUFF),
|
|
FadeIn(dim_text, 0.1 * DOWN),
|
|
)
|
|
self.wait()
|
|
|
|
# Show individual matrix product
|
|
e_vect = NumericEmbedding(length=12)
|
|
e_vect.match_width(q_vect)
|
|
e_vect.next_to(q_vect, DR, buff=1.5)
|
|
matrix = WeightMatrix(shape=(7, 12))
|
|
matrix.match_height(q_vect)
|
|
matrix.next_to(e_vect, LEFT)
|
|
e_label_copy = emb_syms[index].copy()
|
|
e_label_copy.next_to(e_vect, UP)
|
|
q_vect.save_state()
|
|
ghost_q_vect = NumericEmbedding(length=7).match_height(q_vect)
|
|
ghost_q_vect.get_columns().set_opacity(0)
|
|
ghost_q_vect.get_brackets().space_out_submobjects(1.75)
|
|
ghost_q_vect.next_to(e_vect, RIGHT, buff=0.7)
|
|
|
|
mat_brace = Brace(matrix, UP)
|
|
mat_label = Tex("W_Q")
|
|
mat_label.next_to(mat_brace, UP, SMALL_BUFF)
|
|
mat_label.set_color(YELLOW)
|
|
|
|
self.play(
|
|
frame.animate.set_height(11).move_to(all_rects, UP).shift(0.35 * UP),
|
|
FadeOut(a_bubbles),
|
|
FadeInFromPoint(e_vect, emb_syms[index].get_center()),
|
|
FadeInFromPoint(matrix, q_arrow.get_center()),
|
|
TransformFromCopy(emb_syms[index], e_label_copy),
|
|
FadeOut(q_vect),
|
|
TransformFromCopy(q_vect, ghost_q_vect),
|
|
MaintainPositionRelativeTo(question, q_vect),
|
|
)
|
|
self.play(
|
|
GrowFromCenter(mat_brace),
|
|
FadeIn(mat_label, 0.1 * UP),
|
|
)
|
|
self.remove(ghost_q_vect)
|
|
eq, rhs = show_matrix_vector_product(self, matrix, e_vect)
|
|
|
|
new_q_vect = rhs.deepcopy()
|
|
new_q_vect.move_to(q_vect, LEFT)
|
|
|
|
self.play(
|
|
TransformFromCopy(rhs, new_q_vect, path_arc=PI / 2),
|
|
question.animate.next_to(new_q_vect, RIGHT)
|
|
)
|
|
self.wait()
|
|
|
|
# Collapse query vector
|
|
q_sym_template = Tex(R"\vec{\textbf{Q}}_0", font_size=48)
|
|
q_sym_template[0].scale(1.5, about_edge=DOWN)
|
|
q_sym_template.set_color(YELLOW)
|
|
subscript = q_sym_template.make_number_changeable(0)
|
|
q_syms = VGroup()
|
|
for n, arrow in enumerate(q_arrows, start=1):
|
|
subscript.set_value(n)
|
|
sym = q_sym_template.copy()
|
|
sym.next_to(arrow, DOWN, SMALL_BUFF)
|
|
q_syms.add(sym)
|
|
|
|
mat_label2 = mat_label.copy()
|
|
|
|
q_sym = q_syms[index]
|
|
low_q_sym = q_sym.copy()
|
|
low_q_sym.next_to(rhs, UP)
|
|
globals().update(locals())
|
|
|
|
self.play(LaggedStart(
|
|
LaggedStart(
|
|
(FadeTransform(entry, q_sym, remover=True)
|
|
for entry in new_q_vect.get_columns()[0]),
|
|
lag_ratio=0.01,
|
|
group_type=Group,
|
|
),
|
|
new_q_vect.get_brackets().animate.stretch(0, 1, about_edge=UP).set_opacity(0),
|
|
FadeOutToPoint(query_word, q_sym.get_center()),
|
|
FadeOutToPoint(dim_text, q_sym.get_center()),
|
|
FadeOut(brace),
|
|
question.animate.next_to(q_sym, DOWN),
|
|
FadeIn(low_q_sym, UP),
|
|
lag_ratio=0.1,
|
|
))
|
|
self.remove(new_q_vect)
|
|
self.add(q_sym)
|
|
self.play(
|
|
mat_label2.animate.scale(0.9).next_to(q_arrow, RIGHT, buff=0.15),
|
|
)
|
|
self.wait()
|
|
|
|
# E to Q rects
|
|
e_rects = VGroup(map(SurroundingRectangle, [emb_syms[index], e_vect]))
|
|
q_rects = VGroup(map(SurroundingRectangle, [q_sym, rhs]))
|
|
e_rects.set_stroke(TEAL, 3)
|
|
q_rects.set_stroke(YELLOW, 3)
|
|
self.play(ShowCreation(e_rects, lag_ratio=0.2))
|
|
self.wait()
|
|
self.play(Transform(e_rects, q_rects))
|
|
self.wait()
|
|
self.play(FadeOut(e_rects))
|
|
|
|
# Add other query vectors
|
|
remaining_q_arrows = VGroup(*q_arrows[:index], *q_arrows[index + 1:])
|
|
remaining_q_syms = VGroup(*q_syms[:index], *q_syms[index + 1:])
|
|
wq_syms = VGroup(
|
|
Tex(R"W_Q", font_size=30).next_to(arrow, RIGHT, buff=0.1)
|
|
for arrow in q_arrows
|
|
)
|
|
wq_syms.set_color(YELLOW)
|
|
subscripts = VGroup(e_label_copy[-1], low_q_sym[-1][0])
|
|
for subscript in subscripts:
|
|
i_sym = Tex("i")
|
|
i_sym.replace(subscript)
|
|
i_sym.scale(0.75)
|
|
i_sym.match_style(subscript)
|
|
subscript.target = i_sym
|
|
|
|
self.play(
|
|
LaggedStartMap(GrowArrow, remaining_q_arrows),
|
|
LaggedStartMap(FadeIn, remaining_q_syms, shift=0.1 * DOWN),
|
|
ReplacementTransform(VGroup(mat_label2), wq_syms, lag_ratio=0.01, run_time=2),
|
|
question.animate.shift(0.25 * DOWN),
|
|
*map(Restore, word_groups),
|
|
*map(MoveToTarget, subscripts),
|
|
)
|
|
self.wait()
|
|
|
|
# Emphasize model weights
|
|
self.play(
|
|
LaggedStartMap(FlashAround, matrix.get_entries(), lag_ratio=1e-2),
|
|
RandomizeMatrixEntries(matrix),
|
|
)
|
|
data_modifying_matrix(self, matrix, word_shape=(3, 8))
|
|
self.wait()
|
|
self.play(
|
|
LaggedStartMap(FadeOut, VGroup(
|
|
matrix, mat_brace, mat_label,
|
|
e_vect, e_label_copy, eq, rhs,
|
|
low_q_sym
|
|
), shift=0.2 * DR)
|
|
)
|
|
self.wait()
|
|
|
|
# Move question
|
|
noun_q_syms = VGroup(q_syms[words.index(word)] for word in ["creature", "forest"])
|
|
|
|
self.play(
|
|
question.animate.shift(0.25 * DOWN).match_x(noun_q_syms)
|
|
)
|
|
|
|
noun_q_lines = VGroup(
|
|
Line(question.get_corner(v), sym.get_corner(-v))
|
|
for sym, v in zip(noun_q_syms, [UL, UR])
|
|
)
|
|
noun_q_lines.set_stroke(GREY, 1)
|
|
self.play(ShowCreation(noun_q_lines, lag_ratio=0))
|
|
self.wait()
|
|
|
|
# Set up keys
|
|
key_word_groups = word_groups.copy()
|
|
key_word_groups.arrange(DOWN, buff=0.75, aligned_edge=RIGHT)
|
|
key_word_groups.next_to(q_syms, DL, buff=LARGE_BUFF)
|
|
key_word_groups.shift(3.0 * LEFT)
|
|
key_emb_syms = emb_syms.copy()
|
|
|
|
k_sym_template = Tex(R"\vec{\textbf{K}}_0", font_size=48)
|
|
k_sym_template[0].scale(1.5, about_edge=DOWN)
|
|
k_sym_template.set_color(TEAL)
|
|
subscript = k_sym_template.make_number_changeable(0)
|
|
|
|
k_syms = VGroup()
|
|
key_emb_arrows = VGroup()
|
|
wk_arrows = VGroup()
|
|
wk_syms = VGroup()
|
|
for group, emb_sym, n in zip(key_word_groups, key_emb_syms, it.count(1)):
|
|
emb_arrow = Vector(0.5 * RIGHT)
|
|
emb_arrow.next_to(group, RIGHT, SMALL_BUFF)
|
|
emb_sym.next_to(emb_arrow, RIGHT, SMALL_BUFF)
|
|
wk_arrow = Vector(0.75 * RIGHT)
|
|
wk_arrow.next_to(emb_sym, RIGHT)
|
|
wk_sym = Tex("W_k", font_size=30)
|
|
wk_sym.set_fill(TEAL, border_width=1)
|
|
wk_sym.next_to(wk_arrow, UP)
|
|
subscript.set_value(n)
|
|
k_sym = k_sym_template.copy()
|
|
k_sym.next_to(wk_arrow, RIGHT, buff=MED_SMALL_BUFF)
|
|
|
|
key_emb_arrows.add(emb_arrow)
|
|
wk_arrows.add(wk_arrow)
|
|
wk_syms.add(wk_sym)
|
|
k_syms.add(k_sym)
|
|
|
|
self.play(
|
|
frame.animate.move_to(2.5 * LEFT + 2.75 * DOWN),
|
|
TransformFromCopy(word_groups, key_word_groups),
|
|
TransformFromCopy(emb_arrows, key_emb_arrows),
|
|
TransformFromCopy(emb_syms, key_emb_syms),
|
|
FadeOut(question),
|
|
FadeOut(noun_q_lines),
|
|
run_time=2,
|
|
)
|
|
self.play(
|
|
LaggedStartMap(GrowArrow, wk_arrows),
|
|
LaggedStartMap(FadeIn, wk_syms, shift=0.1 * UP),
|
|
)
|
|
self.play(LaggedStart(
|
|
(TransformFromCopy(e_sym, k_sym)
|
|
for e_sym, k_sym in zip(key_emb_syms, k_syms)),
|
|
lag_ratio=0.05,
|
|
))
|
|
self.wait()
|
|
|
|
# Show example key matrix
|
|
matrix = WeightMatrix(shape=(7, 12))
|
|
matrix.set_width(5)
|
|
matrix.next_to(k_syms, UP, buff=2.0, aligned_edge=RIGHT)
|
|
mat_rect = SurroundingRectangle(matrix, buff=MED_SMALL_BUFF)
|
|
lil_rect = SurroundingRectangle(wk_syms[0])
|
|
lines = VGroup(
|
|
Line(lil_rect.get_corner(v + UP), mat_rect.get_corner(v + DOWN))
|
|
for v in [LEFT, RIGHT]
|
|
)
|
|
VGroup(mat_rect, lil_rect, *lines).set_stroke(GREY_A, 1)
|
|
|
|
self.play(ShowCreation(lil_rect))
|
|
self.play(
|
|
ShowCreation(lines, lag_ratio=0),
|
|
TransformFromCopy(lil_rect, mat_rect),
|
|
FadeInFromPoint(matrix, lil_rect.get_center()),
|
|
)
|
|
self.wait()
|
|
data_modifying_matrix(self, matrix, word_shape=(3, 8))
|
|
self.play(
|
|
LaggedStartMap(FadeOut, VGroup(matrix, mat_rect, lines, lil_rect), run_time=1)
|
|
)
|
|
|
|
# Isolate examples
|
|
fade_rects = VGroup(
|
|
BackgroundRectangle(VGroup(key_word_groups[0], wk_syms[0], k_syms[0])),
|
|
BackgroundRectangle(VGroup(key_word_groups[3:], wk_syms[3:], k_syms[3:])),
|
|
BackgroundRectangle(wq_syms[2]),
|
|
BackgroundRectangle(VGroup(word_groups[:3], q_syms[:3])),
|
|
BackgroundRectangle(VGroup(word_groups[4:], q_syms[4:])),
|
|
)
|
|
fade_rects.set_fill(BLACK, 0.75)
|
|
fade_rects.set_stroke(BLACK, 3, 1)
|
|
q_bubble = SpeechBubble("Any adjectives\nin front of me?")
|
|
q_bubble.flip(RIGHT)
|
|
q_bubble.next_to(q_syms[3][-1], DOWN, SMALL_BUFF, LEFT)
|
|
a_bubbles = SpeechBubble("I'm an adjective!\nI'm there!").replicate(2)
|
|
a_bubbles[0].pin_to(k_syms[1])
|
|
a_bubbles[1].pin_to(k_syms[2])
|
|
a_bubbles[1].flip(RIGHT, about_edge=DOWN)
|
|
a_bubbles[1].shift(0.5 * DOWN)
|
|
|
|
self.add(fade_rects, word_groups[3])
|
|
self.play(FadeIn(fade_rects))
|
|
self.play(FadeIn(q_bubble, lag_ratio=0.1))
|
|
self.play(FadeIn(a_bubbles, lag_ratio=0.05))
|
|
self.wait()
|
|
self.play(
|
|
LaggedStartMap(FadeOut, VGroup(q_bubble, *a_bubbles), lag_ratio=0.25)
|
|
)
|
|
self.wait()
|
|
|
|
# Draw grid
|
|
emb_arrows.refresh_bounding_box(recurse_down=True)
|
|
q_groups = VGroup(
|
|
VGroup(group[i] for group in [
|
|
emb_arrows, emb_syms, wq_syms, q_arrows, q_syms
|
|
])
|
|
for i in range(len(emb_arrows))
|
|
)
|
|
q_groups.target = q_groups.generate_target()
|
|
q_groups.target.arrange_to_fit_width(12, about_edge=LEFT)
|
|
q_groups.target.shift(0.25 * DOWN)
|
|
|
|
word_groups.target = word_groups.generate_target()
|
|
for word_group, q_group in zip(word_groups.target, q_groups.target):
|
|
word_group.scale(0.7)
|
|
word_group.next_to(q_group[0], UP, SMALL_BUFF)
|
|
|
|
h_lines = VGroup()
|
|
v_buff = 0.5 * (key_word_groups[0].get_y(DOWN) - key_word_groups[1].get_y(UP))
|
|
for kwg in key_word_groups:
|
|
h_line = Line(LEFT, RIGHT).set_width(20)
|
|
h_line.next_to(kwg, UP, buff=v_buff)
|
|
h_line.align_to(key_word_groups, LEFT)
|
|
h_lines.add(h_line)
|
|
|
|
v_lines = VGroup()
|
|
h_buff = 0.5
|
|
for q_group in q_groups.target:
|
|
v_line = Line(UP, DOWN).set_height(14)
|
|
v_line.next_to(q_group, LEFT, buff=h_buff, aligned_edge=UP)
|
|
v_lines.add(v_line)
|
|
v_lines.add(v_lines[-1].copy().next_to(q_groups.target, RIGHT, 0.5, UP))
|
|
|
|
grid_lines = VGroup(*h_lines, *v_lines)
|
|
grid_lines.set_stroke(GREY_A, 1)
|
|
|
|
self.play(
|
|
frame.animate.set_height(15, about_edge=UP).set_x(-2).set_anim_args(run_time=3),
|
|
MoveToTarget(q_groups),
|
|
MoveToTarget(word_groups),
|
|
ShowCreation(h_lines, lag_ratio=0.2),
|
|
ShowCreation(v_lines, lag_ratio=0.2),
|
|
FadeOut(fade_rects),
|
|
)
|
|
|
|
# Take all dot products
|
|
dot_prods = VGroup()
|
|
for k_sym in k_syms:
|
|
for q_sym in q_syms:
|
|
square_center = np.array([q_sym.get_x(), k_sym.get_y(), 0])
|
|
dot = Tex(R".", font_size=72)
|
|
dot.move_to(square_center)
|
|
dot.set_fill(opacity=0)
|
|
dot_prod = VGroup(k_sym.copy(), dot, q_sym.copy())
|
|
dot_prod.target = dot_prod.generate_target()
|
|
dot_prod.target.arrange(RIGHT, buff=0.15)
|
|
dot_prod.target.scale(0.65)
|
|
dot_prod.target.move_to(square_center)
|
|
dot_prod.target.set_fill(opacity=1)
|
|
dot_prods.add(dot_prod)
|
|
|
|
self.play(
|
|
LaggedStartMap(MoveToTarget, dot_prods, lag_ratio=0.025, run_time=4)
|
|
)
|
|
self.wait()
|
|
|
|
# Show grid of dots
|
|
dots = VGroup(
|
|
VGroup(Dot().match_x(q_sym).match_y(k_sym) for q_sym in q_syms)
|
|
for k_sym in k_syms
|
|
)
|
|
for n, row in enumerate(dots, start=1):
|
|
for k, dot in enumerate(row, start=1):
|
|
dot.set_fill(GREY_C, 0.8)
|
|
dot.set_width(random.random())
|
|
dot.target = dot.generate_target()
|
|
dot.target.set_width(0.1 + 0.2 * random.random())
|
|
if (n, k) in [(2, 4), (3, 4), (7, 8)]:
|
|
dot.target.set_width(0.8 + 0.2 * random.random())
|
|
flat_dots = VGroup(*it.chain(*dots))
|
|
|
|
self.play(
|
|
dot_prods.animate.set_fill(opacity=0.75),
|
|
LaggedStartMap(GrowFromCenter, flat_dots)
|
|
)
|
|
self.wait()
|
|
self.play(LaggedStartMap(MoveToTarget, flat_dots, lag_ratio=0.01))
|
|
self.wait()
|
|
|
|
# Resize to reflect true pattern
|
|
k_groups = VGroup(
|
|
VGroup(group[i] for group in [
|
|
key_word_groups, key_emb_arrows,
|
|
key_emb_syms, wk_syms, wk_arrows, k_syms
|
|
])
|
|
for i in range(len(emb_arrows))
|
|
)
|
|
for q_group, word_group in zip(q_groups, word_groups):
|
|
q_group.add_to_back(word_group)
|
|
self.add(k_groups, q_groups, Point())
|
|
|
|
k_fade_rects = VGroup(map(BackgroundRectangle, k_groups))
|
|
q_fade_rects = VGroup(map(BackgroundRectangle, q_groups))
|
|
for rect in (*k_fade_rects, *q_fade_rects):
|
|
rect.scale(1.05)
|
|
rect.set_fill(BLACK, 0.8)
|
|
|
|
self.play(
|
|
frame.animate.move_to([-4.33, -2.4, 0.0]).set_height(9.52),
|
|
FadeIn(k_fade_rects[:1]),
|
|
FadeIn(k_fade_rects[3:]),
|
|
FadeIn(q_fade_rects[:3]),
|
|
FadeIn(q_fade_rects[4:]),
|
|
run_time=2
|
|
)
|
|
self.wait()
|
|
|
|
k_rects = VGroup(map(SurroundingRectangle, k_groups[1:3]))
|
|
k_rects.set_stroke(TEAL, 2)
|
|
q_rects = VGroup(SurroundingRectangle(q_groups[3]))
|
|
q_rects.set_stroke(YELLOW, 2)
|
|
|
|
self.play(
|
|
ShowCreation(k_rects, lag_ratio=0.5, run_time=2),
|
|
LaggedStartMap(
|
|
FlashAround, k_groups[1:3],
|
|
color=TEAL,
|
|
time_width=2,
|
|
lag_ratio=0.25,
|
|
run_time=3
|
|
),
|
|
)
|
|
self.wait()
|
|
self.play(TransformFromCopy(k_rects, q_rects))
|
|
self.wait()
|
|
|
|
# Show numerical dot product
|
|
high_dot_prods = VGroup(dot_prods[8 + 3], dot_prods[2 * 8 + 3])
|
|
dots_to_grow = VGroup(dots[1][3], dots[2][3])
|
|
numerical_dot_prods = VGroup(
|
|
VGroup(
|
|
DecimalNumber(
|
|
np.random.uniform(-100, 10),
|
|
include_sign=True,
|
|
font_size=42,
|
|
num_decimal_places=1,
|
|
edge_to_fix=ORIGIN,
|
|
).move_to(dot)
|
|
for dot in row
|
|
)
|
|
for row in dots
|
|
)
|
|
for n, row in enumerate(numerical_dot_prods):
|
|
row[n].set_value(5 * random.random()) # Add some self relevance
|
|
flat_numerical_dot_prods = VGroup(*it.chain(*numerical_dot_prods))
|
|
for ndp in flat_numerical_dot_prods:
|
|
ndp.set_fill(interpolate_color(RED_E, GREY_C, random.random()))
|
|
high_numerical_dot_prods = VGroup(
|
|
numerical_dot_prods[1][3],
|
|
numerical_dot_prods[2][3],
|
|
numerical_dot_prods[6][7],
|
|
)
|
|
for hdp in high_numerical_dot_prods:
|
|
hdp.set_value(92 + 2 * random.random())
|
|
hdp.set_color(WHITE)
|
|
low_numerical_dot_prod = numerical_dot_prods[5][3]
|
|
low_numerical_dot_prod.set_value(-31.4)
|
|
low_numerical_dot_prod.set_fill(RED_D)
|
|
|
|
self.play(
|
|
*(dtg.animate.scale(1.25) for dtg in dots_to_grow),
|
|
*(CountInFrom(ndp, run_time=1) for ndp in high_numerical_dot_prods[:2]),
|
|
*(VFadeIn(ndp) for ndp in high_numerical_dot_prods[:2]),
|
|
*(FadeOut(dot_prod, run_time=0.5) for dot_prod in dot_prods),
|
|
)
|
|
self.wait()
|
|
|
|
# Show "attends to"
|
|
att_arrow = Arrow(k_rects.get_top(), q_rects.get_left(), path_arc=-90 * DEGREES)
|
|
att_words = TexText("``Attend to''", font_size=72)
|
|
att_words.next_to(att_arrow.pfp(0.4), UL)
|
|
|
|
self.play(
|
|
ShowCreation(att_arrow),
|
|
Write(att_words),
|
|
)
|
|
self.wait()
|
|
self.play(FadeOut(att_words), FadeOut(att_arrow))
|
|
|
|
# Contrast with "the" and "creature"
|
|
self.play(
|
|
frame.animate.move_to([-2.79, -3.66, 0.0]).set_height(12.29),
|
|
*(k_rect.animate.surround(k_groups[5]) for k_rect in k_rects),
|
|
FadeIn(k_fade_rects[1:3]),
|
|
FadeOut(k_fade_rects[5]),
|
|
run_time=2,
|
|
)
|
|
self.play(
|
|
CountInFrom(low_numerical_dot_prod),
|
|
VFadeIn(low_numerical_dot_prod),
|
|
FadeOut(dots[5][3]),
|
|
)
|
|
self.wait()
|
|
|
|
# Zoom out on full grid
|
|
self.play(
|
|
frame.animate.move_to([-1.5, -4.8, 0.0]).set_height(15).set_anim_args(run_time=3),
|
|
LaggedStart(
|
|
FadeOut(k_rects),
|
|
FadeOut(q_rects),
|
|
FadeOut(k_fade_rects[:5]),
|
|
FadeOut(k_fade_rects[6:]),
|
|
FadeOut(q_fade_rects[:3]),
|
|
FadeOut(q_fade_rects[4:]),
|
|
FadeOut(dots),
|
|
LaggedStartMap(FadeIn, numerical_dot_prods),
|
|
Animation(high_numerical_dot_prods.copy(), remover=True),
|
|
Animation(low_numerical_dot_prod.copy(), remover=True),
|
|
)
|
|
)
|
|
self.wait()
|
|
|
|
# Focus on one column
|
|
ndp_columns = VGroup(
|
|
VGroup(row[i] for row in numerical_dot_prods)
|
|
for i in range(len(numerical_dot_prods[0]))
|
|
)
|
|
col_rect = SurroundingRectangle(ndp_columns[3], buff=0.25)
|
|
col_rect.set_stroke(YELLOW, 2)
|
|
weight_words = Text("We want these to\nact like weights", font_size=96)
|
|
weight_words.set_backstroke(BLACK, 8)
|
|
weight_words.next_to(col_rect, RIGHT, buff=MED_LARGE_BUFF)
|
|
weight_words.match_y(h_lines[2])
|
|
|
|
index = words.index("creature")
|
|
self.play(
|
|
ShowCreation(col_rect),
|
|
grid_lines.animate.set_stroke(opacity=0.5),
|
|
ndp_columns[:index].animate.set_opacity(0.35),
|
|
ndp_columns[index + 1:].animate.set_opacity(0.35),
|
|
FadeIn(weight_words, lag_ratio=0.1)
|
|
)
|
|
self.wait()
|
|
|
|
# Show softmax of each columns
|
|
self.set_floor_plane("xz")
|
|
col_arrays = [np.array([num.get_value() for num in col]) for col in ndp_columns]
|
|
softmax_arrays = list(map(softmax, col_arrays))
|
|
softmax_cols = VGroup(
|
|
VGroup(DecimalNumber(v) for v in softmax_array)
|
|
for softmax_array in softmax_arrays
|
|
)
|
|
sm_arrows = VGroup()
|
|
sm_labels = VGroup()
|
|
sm_rects = VGroup()
|
|
for sm_col, col in zip(softmax_cols, ndp_columns):
|
|
for sm_val, val in zip(sm_col, col):
|
|
sm_val.move_to(val)
|
|
sm_col.save_state()
|
|
sm_col.shift(6 * OUT)
|
|
sm_rect = SurroundingRectangle(sm_col)
|
|
sm_rect.match_style(col_rect)
|
|
VGroup(sm_col, sm_rect).rotate(30 * DEGREES, DOWN)
|
|
arrow = Arrow(col, sm_col.get_center() + SMALL_BUFF * RIGHT + IN)
|
|
label = Text("softmax", font_size=72)
|
|
label.set_backstroke(BLACK, 5)
|
|
label.rotate(90 * DEGREES, DOWN)
|
|
label.next_to(arrow, UP)
|
|
sm_arrows.add(arrow)
|
|
sm_labels.add(label)
|
|
sm_rects.add(sm_rect)
|
|
|
|
index = words.index("creature")
|
|
self.play(
|
|
frame.animate.reorient(-47, -7, 0, (-2.48, -5.84, -1.09), 20),
|
|
GrowArrow(sm_arrows[index], time_span=(1, 2)),
|
|
FadeIn(sm_labels[index], lag_ratio=0.1, time_span=(1, 2)),
|
|
TransformFromCopy(ndp_columns[index], softmax_cols[index], time_span=(1.5, 3)),
|
|
TransformFromCopy(col_rect, sm_rects[index], time_span=(1.5, 3)),
|
|
FadeOut(weight_words),
|
|
run_time=3
|
|
)
|
|
self.wait()
|
|
|
|
remaining_indices = [*range(index), *range(index + 1, len(ndp_columns))]
|
|
last_index = index
|
|
for index in remaining_indices:
|
|
self.play(
|
|
ndp_columns[last_index].animate.set_opacity(0.35),
|
|
ndp_columns[index].animate.set_opacity(1),
|
|
col_rect.animate.move_to(ndp_columns[index]),
|
|
softmax_cols[last_index].animate.set_opacity(0.25),
|
|
*map(FadeOut, [sm_rects[last_index], sm_arrows[last_index], sm_labels[last_index]]),
|
|
)
|
|
self.play(
|
|
GrowArrow(sm_arrows[index]),
|
|
FadeIn(sm_labels[index], lag_ratio=0.1),
|
|
TransformFromCopy(ndp_columns[index], softmax_cols[index]),
|
|
TransformFromCopy(col_rect, sm_rects[index]),
|
|
)
|
|
last_index = index
|
|
self.play(
|
|
FadeOut(col_rect),
|
|
*map(FadeOut, [sm_rects[last_index], sm_arrows[last_index], sm_labels[last_index]]),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
frame.animate.reorient(0, 0, 0, (-2.64, -4.8, 0.0), 14.54),
|
|
LaggedStartMap(Restore, softmax_cols, lag_ratio=0.1),
|
|
FadeOut(ndp_columns, time_span=(0, 1.5)),
|
|
run_time=3,
|
|
)
|
|
self.wait()
|
|
|
|
# Label attention pattern
|
|
for n, row in enumerate(dots):
|
|
if n not in [3, 7]:
|
|
row[n].set_width(0.7 + 0.2 * random.random())
|
|
dots[1][3].set_width(0.6 + 0.1 * random.random())
|
|
dots[2][3].set_width(0.6 + 0.1 * random.random())
|
|
dots[6][7].set_width(0.9 + 0.1 * random.random())
|
|
|
|
pattern_words = Text("Attention\nPattern", font_size=120)
|
|
pattern_words.move_to(grid_lines, UL).shift(LEFT)
|
|
|
|
self.play(
|
|
FadeOut(softmax_cols, lag_ratio=0.001),
|
|
FadeIn(dots, lag_ratio=0.001),
|
|
Write(pattern_words),
|
|
run_time=2
|
|
)
|
|
self.wait()
|
|
|
|
# Preview masking
|
|
masked_dots = VGroup()
|
|
for n, row in enumerate(dots):
|
|
masked_dots.add(*row[:n])
|
|
mask_rects = VGroup()
|
|
for dot in masked_dots:
|
|
mask_rect = Square(0.5)
|
|
mask_rect.set_stroke(RED, 2)
|
|
mask_rect.move_to(dot)
|
|
mask_rects.add(mask_rect)
|
|
|
|
lag_ratio=1.0 / len(mask_rects)
|
|
self.play(ShowCreation(mask_rects, lag_ratio=lag_ratio))
|
|
self.play(
|
|
LaggedStart(
|
|
(dot.animate.scale(0) for dot in masked_dots),
|
|
lag_ratio=lag_ratio
|
|
)
|
|
)
|
|
self.play(
|
|
FadeOut(mask_rects, lag_ratio=lag_ratio)
|
|
)
|
|
self.wait()
|
|
|
|
# Set aside keys and queries
|
|
pattern = VGroup(grid_lines, dots)
|
|
for group in q_groups:
|
|
group.sort(lambda p: -p[1])
|
|
group.target = group.generate_target()
|
|
m3 = len(group) - 3
|
|
group.target[m3:].scale(0, about_edge=DOWN)
|
|
group.target[:m3].move_to(group, DOWN)
|
|
|
|
self.play(
|
|
frame.animate.move_to((-2.09, -5.59, 0.0)).set_height(12.95).set_anim_args(run_time=3),
|
|
LaggedStartMap(MoveToTarget, q_groups),
|
|
FadeOut(pattern_words),
|
|
v_lines.animate.stretch(0.95, 1, about_edge=DOWN),
|
|
)
|
|
self.play(
|
|
LaggedStartMap(FadeOut, k_syms, shift=0.5 * DOWN, lag_ratio=0.1),
|
|
LaggedStartMap(FadeOut, wk_syms, shift=0.5 * DOWN, lag_ratio=0.1),
|
|
)
|
|
self.wait()
|
|
|
|
# Add values
|
|
value_color = RED
|
|
big_wv_sym = Tex(R"W_V", font_size=90)
|
|
big_wv_sym.set_color(value_color)
|
|
big_wv_sym.next_to(h_lines, UP, MED_LARGE_BUFF, LEFT)
|
|
wv_word = Text("Value matrix", font_size=90)
|
|
wv_word.next_to(big_wv_sym, UP, MED_LARGE_BUFF)
|
|
wv_word.set_color(value_color)
|
|
|
|
wv_arrows = wk_arrows
|
|
v_sym_template = Tex(R"\vec{\textbf{V}}_{0}")
|
|
v_sym_template[0].scale(1.5, about_edge=DOWN)
|
|
v_sym_template.set_fill(value_color, border_width=1)
|
|
subscript = v_sym_template.make_number_changeable("0")
|
|
|
|
wv_syms = VGroup()
|
|
v_syms = VGroup()
|
|
for n, arrow in enumerate(wv_arrows, start=1):
|
|
wv_sym = Tex("W_V", font_size=36)
|
|
wv_sym.set_fill(value_color, border_width=1)
|
|
wv_sym.next_to(arrow, UP, buff=0.2, aligned_edge=LEFT)
|
|
subscript.set_value(n)
|
|
v_sym = v_sym_template.copy()
|
|
v_sym.next_to(arrow, RIGHT, MED_SMALL_BUFF)
|
|
|
|
v_syms.add(v_sym)
|
|
wv_syms.add(wv_sym)
|
|
|
|
self.play(
|
|
FadeIn(big_wv_sym, 0.5 * DOWN),
|
|
FadeIn(wv_word, lag_ratio=0.1),
|
|
)
|
|
self.play(
|
|
LaggedStart(
|
|
(TransformFromCopy(big_wv_sym, wv_sym)
|
|
for wv_sym in wv_syms),
|
|
lag_ratio=0.15,
|
|
),
|
|
run_time=3
|
|
)
|
|
self.play(
|
|
LaggedStart(
|
|
(TransformFromCopy(e_sym, v_sym)
|
|
for e_sym, v_sym in zip(key_emb_syms, v_syms)),
|
|
lag_ratio=0.15,
|
|
),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
FadeTransform(v_syms, k_syms),
|
|
FadeTransform(wv_syms, wk_syms),
|
|
rate_func=there_and_back_with_pause,
|
|
run_time=3,
|
|
)
|
|
self.remove(k_syms, wk_syms)
|
|
self.add(v_syms, wv_syms)
|
|
self.wait()
|
|
|
|
# Show column of weights
|
|
index = words.index("creature")
|
|
weighted_sum_cols = VGroup()
|
|
for sm_col in softmax_cols:
|
|
weighted_sum_col = VGroup()
|
|
for weight, v_sym in zip(sm_col, v_syms):
|
|
product = VGroup(weight, v_sym.copy())
|
|
product.target = product.generate_target()
|
|
product.target.arrange(RIGHT)
|
|
product.target[1].shift(UP * (
|
|
product.target[0].get_y(DOWN) -
|
|
product.target[1][1].get_y(DOWN)
|
|
))
|
|
product.target.scale(0.75)
|
|
product.target.move_to(weight)
|
|
product.target.set_fill(
|
|
opacity=clip(0.6 + weight.get_value(), 0, 1)
|
|
)
|
|
weighted_sum_col.add(product)
|
|
weighted_sum_cols.add(weighted_sum_col)
|
|
|
|
self.play(
|
|
FadeOut(dots, lag_ratio=0.1),
|
|
FadeIn(q_fade_rects[:index]),
|
|
FadeIn(q_fade_rects[index + 1:]),
|
|
FadeIn(softmax_cols[index]),
|
|
)
|
|
self.wait()
|
|
self.play(
|
|
LaggedStartMap(MoveToTarget, weighted_sum_cols[index])
|
|
)
|
|
self.wait()
|
|
|
|
# Emphasize fluffy and blue weights
|
|
rects = VGroup(
|
|
key_word_groups[i][0].copy()
|
|
for i in [1, 2]
|
|
)
|
|
alt_rects = VGroup(
|
|
SurroundingRectangle(value, buff=SMALL_BUFF)
|
|
for value in (* softmax_cols[index][:1], *softmax_cols[index][3:])
|
|
)
|
|
alt_rects.set_stroke(RED, 1)
|
|
self.play(
|
|
LaggedStart(
|
|
(rect.animate.surround(value)
|
|
for rect, value in zip(rects, softmax_cols[index][1:3])),
|
|
lag_ratio=0.2,
|
|
)
|
|
)
|
|
self.wait()
|
|
self.play(Transform(rects, alt_rects))
|
|
self.wait()
|
|
self.play(FadeOut(rects, lag_ratio=0.1))
|
|
|
|
# Show sum
|
|
emb_sym = emb_syms[index]
|
|
ws_col = weighted_sum_cols[index]
|
|
creature = images[2]
|
|
creature.set_height(1.5)
|
|
creature.next_to(word_groups[index], UP)
|
|
|
|
emb_sym.target = emb_sym.generate_target()
|
|
emb_sym.target.scale(1.25, about_edge=UP)
|
|
sum_rect = SurroundingRectangle(emb_sym.target)
|
|
sum_rect.set_stroke(YELLOW, 2)
|
|
sum_rect.target = sum_rect.generate_target()
|
|
sum_rect.target.surround(VGroup(emb_sym.target, ws_col), buff=MED_SMALL_BUFF)
|
|
plusses = VGroup()
|
|
for m1, m2 in zip([emb_sym.target, *ws_col], ws_col):
|
|
plus = Tex(R"+", font_size=72)
|
|
plus.move_to(midpoint(m1.get_bottom(), m2.get_top()))
|
|
plusses.add(plus)
|
|
|
|
self.play(
|
|
frame.animate.reorient(0, 0, 0, (-2.6, -4.79, 0.0), 15.07).set_anim_args(run_time=2),
|
|
MoveToTarget(emb_sym),
|
|
ShowCreation(sum_rect),
|
|
FadeIn(creature, UP),
|
|
FadeOut(wv_word),
|
|
FadeOut(big_wv_sym),
|
|
)
|
|
self.add(Point(), q_fade_rects[index + 1:]) # Hack
|
|
self.wait()
|
|
self.play(
|
|
frame.animate.reorient(0, 0, 0, (-2.9, -6.5, 0.0), 19).set_anim_args(run_time=2),
|
|
MoveToTarget(sum_rect, run_time=2),
|
|
Write(plusses),
|
|
)
|
|
self.wait()
|
|
|
|
# Finish sum
|
|
low_arrows = VGroup(
|
|
Vector(DOWN).next_to(wsc[-1].target, DOWN)
|
|
for wsc in weighted_sum_cols
|
|
)
|
|
for sym, arrow in zip(emb_sym_primes, low_arrows):
|
|
sym.match_height(emb_sym)
|
|
sym.next_to(arrow, DOWN)
|
|
blue_fluff.set_height(2.5)
|
|
blue_fluff.next_to(emb_sym_primes[index], buff=MED_LARGE_BUFF, aligned_edge=UP)
|
|
|
|
self.play(
|
|
TransformFromCopy(emb_syms[index], emb_sym_primes[index]),
|
|
LaggedStart(
|
|
(FadeTransform(prod.copy(), emb_sym_primes[index])
|
|
for prod in ws_col),
|
|
lag_ratio=0.05,
|
|
group_type=Group
|
|
),
|
|
ShowCreation(low_arrows[index]),
|
|
FadeTransform(creature.copy(), blue_fluff)
|
|
)
|
|
self.wait()
|
|
|
|
# Map it over all vectors
|
|
plus_groups = VGroup(
|
|
plusses.copy().match_x(col[0].target)
|
|
for col in weighted_sum_cols
|
|
)
|
|
plus_groups.set_fill(GREY_C, 1)
|
|
|
|
for col in softmax_cols:
|
|
for value in col:
|
|
value.set_fill(
|
|
opacity=clip(0.6 + value.get_value(), 0, 1)
|
|
)
|
|
|
|
self.play(
|
|
frame.animate.reorient(0, 0, 0, (-2.76, -7, 0.0), 16),
|
|
FadeOut(sum_rect),
|
|
FadeOut(creature),
|
|
FadeOut(blue_fluff),
|
|
FadeOut(q_fade_rects[:index]),
|
|
FadeOut(q_fade_rects[index + 1:]),
|
|
FadeIn(softmax_cols[:index]),
|
|
FadeIn(softmax_cols[index + 1:]),
|
|
plusses.animate.set_fill(GREY_C, 1),
|
|
)
|
|
self.play(
|
|
LaggedStart(
|
|
(LaggedStartMap(MoveToTarget, col)
|
|
for col in weighted_sum_cols),
|
|
lag_ratio=0.1
|
|
),
|
|
v_lines.animate.set_stroke(GREY_B, 3, 1),
|
|
*(
|
|
e_sym.animate.scale(1.25, about_edge=UP)
|
|
for e_sym in (*emb_syms[:index], *emb_syms[index + 1:])
|
|
),
|
|
)
|
|
other_indices = [*range(index), *range(index + 1, len(plus_groups))]
|
|
self.play(LaggedStart(
|
|
(LaggedStart(
|
|
FadeIn(plus_groups[j], lag_ratio=0.1),
|
|
GrowArrow(low_arrows[j]),
|
|
LaggedStart(
|
|
(FadeTransform(ws.copy(), emb_sym_primes[j])
|
|
for ws in weighted_sum_cols[j]),
|
|
lag_ratio=0.05,
|
|
group_type=Group
|
|
),
|
|
lag_ratio=0.25,
|
|
)
|
|
for j in other_indices),
|
|
lag_ratio=0.01,
|
|
group_type=Group
|
|
))
|
|
self.wait()
|
|
|
|
def bake_mobject_into_vector_entries(self, mob, vector, path_arc=30 * DEGREES, group_type=None):
    """Build an animation of ``mob`` being "baked" into each entry of ``vector``.

    One copy of ``mob`` is made per entry returned by ``vector.get_entries()``.
    Each copy fades out toward the center of its entry along an arc, and the
    copies are staggered slightly (``lag_ratio=0.05``). At the same time,
    ``RandomizeMatrixEntries`` re-rolls the vector's entries. Its rate function
    evaluates ``smooth`` on ``2 * t - 1`` and clips to [0, 1]. Assuming
    ``smooth`` is non-positive for negative input, the entries hold still for
    the first half of the run and only start changing after the copies have
    had time to land.

    Args:
        mob: mobject that gets copied into the entries.
        vector: object exposing ``get_entries()``; also handed to
            ``RandomizeMatrixEntries``.
        path_arc: arc angle of each copy's path toward its entry
            (30 degrees by default).
        group_type: forwarded unchanged to ``LaggedStart``.

    Returns:
        An ``AnimationGroup`` running both pieces in parallel, each with
        ``run_time=2``.

    ``self`` is not used.
    """
    entries = vector.get_entries()
    clones = mob.replicate(len(entries))

    # One arcing fade-out per (clone, entry) pair.
    flights = [
        FadeOutToPoint(clone, entry.get_center(), path_arc=path_arc)
        for clone, entry in zip(clones, entries)
    ]

    def delayed_rate(t):
        # Remap the second half of the run onto smooth's [0, 1] domain;
        # clip keeps the first half (negative argument) pinned at 0.
        return clip(smooth(2 * t - 1), 0, 1)

    # remover=True: the clones are not meant to remain once they arrive.
    fly_in = LaggedStart(
        flights,
        lag_ratio=0.05,
        group_type=group_type,
        run_time=2,
        remover=True,
    )
    reshuffle = RandomizeMatrixEntries(vector, rate_func=delayed_rate, run_time=2)
    return AnimationGroup(fly_in, reshuffle)
|
|
|
|
|
|
class RoadNotTaken(InteractiveScene):
    """Display the poem "The Road Not Taken", one stanza per grid cell.

    Only the opening beat is implemented. The four stanzas are laid out in a
    grid, scaled to fit the frame, and faded in. The remaining beats are
    outlined as TODO markers at the end of ``construct``.
    """

    def construct(self):
        # Add poem
        poem_stanzas = [
            [
                "Two roads diverged in a yellow wood,",
                "And sorry I could not travel both",
                "And be one traveler, long I stood",
                "And looked down one as far as I could",
                "To where it bent in the undergrowth;",
            ],
            [
                "Then took the other, as just as fair,",
                "And having perhaps the better claim,",
                "Because it was grassy and wanted wear;",
                "Though as for that the passing there",
                "Had worn them really about the same,",
            ],
            [
                "And both that morning equally lay",
                "In leaves no step had trodden black.",
                "Oh, I kept the first for another day!",
                "Yet knowing how way leads on to way,",
                "I doubted if I should ever come back.",
            ],
            [
                "I shall be telling this with a sigh",
                "Somewhere ages and ages hence:",
                "Two roads diverged in a wood, and I—",
                "I took the one less traveled by,",
                "And that has made all the difference.",
            ],
        ]
        text_kwargs = dict(alignment="LEFT")

        # Each stanza string begins and ends with a newline, exactly like a
        # triple-quoted literal whose quotes sit on their own lines.
        stanzas = VGroup(*(
            Text("\n".join(["", *stanza, ""]), **text_kwargs)
            for stanza in poem_stanzas
        ))

        # fill_rows_first=False: consecutive stanzas go down a column before
        # the next column is started.
        stanzas.arrange_in_grid(h_buff=1.5, v_buff=1.0, fill_rows_first=False)
        # Width FRAME_WIDTH - 1, centered at x = 0 below, leaves half a unit
        # of margin on each side.
        stanzas.set_width(FRAME_WIDTH - 1)
        stanzas.move_to(0.5 * UP)

        fade_in_poem = FadeIn(stanzas, lag_ratio=0.01, run_time=4)
        self.play(fade_in_poem)
        self.wait()

        # TODO: note all text until "one"

        # TODO: highlight "two roads"

        # TODO: highlight "took the other" and "grassy and wanted wear"

        # TODO: somehow highlight words throughout
|
|
|
|
|
|
|