From c37b90d6625d40098a1a19c83811537ceac6ddae Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Mon, 25 Mar 2024 19:09:46 -0300 Subject: [PATCH] More misc. animations for transformers --- _2024/transformers/attention.py | 1109 ++++++++++++++++++++++++++-- _2024/transformers/embedding.py | 289 +++++++- _2024/transformers/ml_basics.py | 2 - _2024/transformers/network_flow.py | 171 ++++- _2024/transformers/supplements.py | 38 +- 5 files changed, 1538 insertions(+), 71 deletions(-) diff --git a/_2024/transformers/attention.py b/_2024/transformers/attention.py index 65d2518..a315e17 100644 --- a/_2024/transformers/attention.py +++ b/_2024/transformers/attention.py @@ -1,3 +1,4 @@ +from sqlalchemy.sql.base import _DialectArgDict from manim_imports_ext import * from _2024.transformers.helpers import * @@ -5,10 +6,10 @@ from _2024.transformers.helpers import * class AttentionPatterns(InteractiveScene): def construct(self): # Add sentence - phrase = " the fluffy blue creature foraged in a verdant forest" + phrase = " a fluffy blue creature roamed the verdant forest" phrase_mob = Text(phrase) phrase_mob.move_to(2 * UP) - words = phrase.split(" ") + words = list(filter(lambda s: s.strip(), phrase.split(" "))) word2mob: Dict[str, VMobject] = { word: phrase_mob[" " + word][0] for word in words @@ -33,7 +34,7 @@ class AttentionPatterns(InteractiveScene): # Adjectives updating noun adjs = ["fluffy", "blue", "verdant"] nouns = ["creature", "forest"] - others = ["the", "foraged", "in", "a"] + others = ["a", "roamed", "the"] adj_mobs, noun_mobs, other_mobs = [ VGroup(word2mob[substr] for substr in group) for group in [adjs, nouns, others] @@ -113,7 +114,7 @@ class AttentionPatterns(InteractiveScene): ImageMobject(f"Dalle3_{word}").set_height(1.1).next_to(word2rect[word], UP) for word in ["fluffy", "blue", "creature", "verdant", "forest"] ) - image_vects = VGroup(embeddings[i] for i in [1, 2, 3, 7, 8]) + image_vects = VGroup(embeddings[i] for i in [1, 2, 3, 6, 7]) self.play( LaggedStartMap(FadeIn, images, scale=2, lag_ratio=0.05) @@ -159,25 +160,29 @@ class AttentionPatterns(InteractiveScene): # Collapse vectors template = Tex(R"\vec{\textbf{E}}_{0}") + template[0].scale(1.5, about_edge=DOWN) dec = template.make_number_changeable(0) - vect_syms = VGroup() + emb_syms = VGroup() for n, rect in enumerate(all_rects, start=1): dec.set_value(n) sym = template.copy() - sym.next_to(rect, DOWN, buff=LARGE_BUFF) + sym.next_to(rect, DOWN, buff=0.75) sym.set_color(GREY_A) - vect_syms.add(sym) - prev_center = vect_syms.get_center() - vect_syms.arrange_to_fit_width(vect_syms.get_width()) - vect_syms.move_to(prev_center) + emb_syms.add(sym) + for subgroup in [emb_syms[:4], emb_syms[4:]]: + subgroup.arrange_to_fit_width(subgroup.get_width()) emb_arrows.target = emb_arrows.generate_target() - for arrow, rect, sym in zip(emb_arrows.target, all_rects, vect_syms): - arrow.become(Arrow( - rect.get_bottom(), - sym[0].get_top(), - buff=SMALL_BUFF - )) + + for rect, arrow, sym in zip(all_rects, emb_arrows.target, emb_syms): + x_min = rect.get_x(LEFT) + x_max = rect.get_x(RIGHT) + low_point = sym[0].get_top() + if x_min < low_point[0] < x_max: + top_point = np.array([low_point[0], rect.get_y(DOWN), 0]) + else: + top_point = rect.get_bottom() + arrow.become(Arrow(top_point, low_point, buff=SMALL_BUFF)) all_brackets = VGroup(emb.get_brackets() for emb in embeddings) for brackets in all_brackets: @@ -185,10 +190,11 @@ class AttentionPatterns(InteractiveScene): brackets.target.stretch(0, 1, about_edge=UP) brackets.target.set_fill(opacity=0) - ghost_syms = vect_syms.copy() + ghost_syms = emb_syms.copy() ghost_syms.set_opacity(0) self.play( + frame.animate.set_x(0).set_anim_args(run_time=2), LaggedStart( (AnimationGroup( LaggedStart( @@ -203,56 +209,59 @@ class AttentionPatterns(InteractiveScene): for sym, embedding, brackets in zip(ghost_syms, embeddings, all_brackets)), group_type=Group ), - LaggedStartMap(FadeIn, vect_syms, shift=UP), + LaggedStartMap(FadeIn, emb_syms, shift=UP), brace.animate.stretch(0.25, 1, about_edge=UP).set_opacity(0), FadeOut(dim_value, 0.25 * UP), MoveToTarget(emb_arrows, lag_ratio=0.1, run_time=2), - LaggedStartMap(FadeOut, pos_labels, shift=UP) + LaggedStartMap(FadeOut, pos_labels, shift=UP), ) + emb_arrows.refresh_bounding_box(recurse_down=True) # Why? self.clear() - self.add(emb_arrows, all_rects, word_mobs, images, vect_syms) + self.add(emb_arrows, all_rects, word_mobs, images, emb_syms) self.wait() # Preview desired updates - vect_sym_primes = VGroup( + emb_sym_primes = VGroup( sym.copy().add(Tex("'").move_to(sym.get_corner(UR) + 0.05 * DL)) - for sym in vect_syms + for sym in emb_syms ) - vect_sym_primes.shift(2 * DOWN) - vect_sym_primes.set_color(TEAL) + emb_sym_primes.shift(2 * DOWN) + emb_sym_primes.set_color(TEAL) - full_connections = VGroup( - Line(sym1.get_bottom(), sym2.get_top(), buff=SMALL_BUFF) - for sym2 in vect_sym_primes - for sym1 in vect_syms - ) - full_connections.set_stroke(GREY_B, 1) + full_connections = VGroup() + for i, sym1 in enumerate(emb_syms, start=1): + for j, sym2 in enumerate(emb_sym_primes, start=1): + line = Line(sym1.get_bottom(), sym2.get_top(), buff=SMALL_BUFF) + line.set_stroke(GREY_B, width=random.random()**2, opacity=random.random()**0.25) + if (i, j) in [(2, 4), (3, 4), (4, 4), (7, 8), (8, 8)]: + line.set_stroke(WHITE, width=2 + random.random(), opacity=1) + full_connections.add(line) blue_fluff = ImageMobject("BlueFluff") verdant_forest = ImageMobject("VerdantForest") - for n, image in [(3, blue_fluff), (8, verdant_forest)]: + for n, image in [(3, blue_fluff), (7, verdant_forest)]: image.match_height(images) image.scale(1.2) - image.next_to(vect_sym_primes[n], DOWN, buff=MED_SMALL_BUFF) + image.next_to(emb_sym_primes[n], DOWN, buff=MED_SMALL_BUFF) self.play( ShowCreation(full_connections, lag_ratio=0.01, run_time=2), LaggedStart( (TransformFromCopy(sym1, sym2) - for sym1, sym2 in zip(vect_syms, vect_sym_primes)), + for sym1, sym2 in zip(emb_syms, emb_sym_primes)), lag_ratio=0.05, ), ) self.wait() self.play(LaggedStart( LaggedStart( - (FadeTransform(im.copy(), blue_fluff) + (FadeTransform(im.copy(), blue_fluff, remover=True) for im in images[:3]), lag_ratio=0.02, group_type=Group ), LaggedStart( - (FadeTransform(im.copy(), verdant_forest) + (FadeTransform(im.copy(), verdant_forest, remover=True) for im in images[3:]), lag_ratio=0.02, group_type=Group @@ -260,26 +269,985 @@ class AttentionPatterns(InteractiveScene): lag_ratio=0.5, run_time=2 )) + self.add(blue_fluff, verdant_forest) + self.wait() # Show black box that matrix multiples can be added to - in_arrows = VGroup(Vector(0.5 * DOWN).next_to(sym, DOWN) for sym in vect_syms) - dots = VGroup(Tex(R"\vdots").next_to(arrow, DOWN) for arrow in in_arrows) - box = Rectangle(vect_syms.get_width() + 1, 1.5) + in_arrows = VGroup( + Vector(0.25 * DOWN, max_width_to_length_ratio=12.0).next_to(sym, DOWN, SMALL_BUFF) + for sym in emb_syms + ) + box = Rectangle(15.0, 3.0) box.set_fill(GREY_E, 1) box.set_stroke(WHITE, 1) - box.next_to(in_arrows, DOWN) + box.next_to(in_arrows, DOWN, SMALL_BUFF) out_arrows = in_arrows.copy() out_arrows.next_to(box, DOWN) - self.add(box) - self.add(out_arrows) - self.add(vect_sym_primes) + self.play( + FadeIn(box, 0.25 * DOWN), + LaggedStartMap(FadeIn, in_arrows, shift=0.25 * DOWN, lag_ratio=0.025), + LaggedStartMap(FadeIn, out_arrows, shift=0.25 * DOWN, lag_ratio=0.025), + FadeOut(full_connections), + emb_sym_primes.animate.next_to(out_arrows, DOWN, SMALL_BUFF), + MaintainPositionRelativeTo(blue_fluff, emb_sym_primes), + MaintainPositionRelativeTo(verdant_forest, emb_sym_primes), + frame.animate.set_height(10).move_to(4 * UP, UP), + ) + self.wait() + + # Clear the board + self.play( + frame.animate.set_height(8).move_to(2 * UP).set_anim_args(run_time=1.5), + LaggedStartMap(FadeOut, Group( + *images, in_arrows, box, out_arrows, emb_sym_primes, + blue_fluff, verdant_forest, + ), lag_ratio=0.1) + ) # Ask questions + word_groups = VGroup(VGroup(*pair) for pair in zip(all_rects, word_mobs)) + for group in word_groups: + group.save_state() + q_bubble = SpeechBubble("Any adjectives\nin front of me?") + q_bubble.move_tip_to(word2rect["creature"].get_top()) + + a_bubbles = SpeechBubble("I am!", direction=RIGHT).replicate(2) + a_bubbles[1].flip() + a_bubbles[0].move_tip_to(word2rect["fluffy"].get_top()) + a_bubbles[1].move_tip_to(word2rect["blue"].get_top()) + + self.play( + FadeIn(q_bubble), + word_groups[:3].animate.fade(0.75), + word_groups[4:].animate.fade(0.75), + ) + self.wait() + self.play(LaggedStart( + Restore(word_groups[1]), + Restore(word_groups[2]), + *map(Write, a_bubbles), + lag_ratio=0.5 + )) + self.wait() # Associate questions with vectors + a_bubbles.save_state() + q_arrows = VGroup( + Vector(0.75 * DOWN).next_to(sym, DOWN, SMALL_BUFF) + for sym in emb_syms + ) + q_vects = VGroup( + NumericEmbedding(length=7).set_height(2).next_to(arrow, DOWN) + for arrow in q_arrows + ) + question = q_bubble.content - # Show matrices + + index = words.index("creature") + q_vect = q_vects[index] + q_arrow = q_arrows[index] + self.play(LaggedStart( + FadeOut(q_bubble.body, DOWN), + question.animate.scale(0.75).next_to(q_vect, RIGHT), + FadeIn(q_vect, DOWN), + GrowArrow(q_arrow), + frame.animate.move_to(ORIGIN), + a_bubbles.animate.fade(0.5), + )) + self.play( + self.bake_mobject_into_vector_entries(question, q_vect) + ) + self.wait() + + # Label query vector + brace = Brace(q_vect, LEFT, SMALL_BUFF) + query_word = Text("Query") + query_word.set_color(YELLOW) + query_word.next_to(brace, LEFT, SMALL_BUFF) + dim_text = Text("128-dimensional", font_size=36) + dim_text.set_color(YELLOW) + dim_text.next_to(brace, LEFT, SMALL_BUFF) + dim_text.set_y(query_word.get_y(DOWN)) + + self.play( + GrowFromCenter(brace), + FadeIn(query_word, 0.25 * LEFT), + ) + self.wait() + self.play( + query_word.animate.next_to(dim_text, UP, SMALL_BUFF), + FadeIn(dim_text, 0.1 * DOWN), + ) + self.wait() + + # Show individual matrix product + e_vect = NumericEmbedding(length=12) + e_vect.match_width(q_vect) + e_vect.next_to(q_vect, DR, buff=1.5) + matrix = WeightMatrix(shape=(7, 12)) + matrix.match_height(q_vect) + matrix.next_to(e_vect, LEFT) + e_label_copy = emb_syms[index].copy() + e_label_copy.next_to(e_vect, UP) + q_vect.save_state() + ghost_q_vect = NumericEmbedding(length=7).match_height(q_vect) + ghost_q_vect.get_columns().set_opacity(0) + ghost_q_vect.get_brackets().space_out_submobjects(1.75) + ghost_q_vect.next_to(e_vect, RIGHT, buff=0.7) + + mat_brace = Brace(matrix, UP) + mat_label = Tex("W_Q") + mat_label.next_to(mat_brace, UP, SMALL_BUFF) + mat_label.set_color(YELLOW) + + self.play( + frame.animate.set_height(11).move_to(all_rects, UP).shift(0.35 * UP), + FadeOut(a_bubbles), + FadeInFromPoint(e_vect, emb_syms[index].get_center()), + FadeInFromPoint(matrix, q_arrow.get_center()), + TransformFromCopy(emb_syms[index], e_label_copy), + FadeOut(q_vect), + TransformFromCopy(q_vect, ghost_q_vect), + MaintainPositionRelativeTo(question, q_vect), + ) + self.play( + GrowFromCenter(mat_brace), + FadeIn(mat_label, 0.1 * UP), + ) + self.remove(ghost_q_vect) + eq, rhs = show_matrix_vector_product(self, matrix, e_vect) + + new_q_vect = rhs.deepcopy() + new_q_vect.move_to(q_vect, LEFT) + + self.play( + TransformFromCopy(rhs, new_q_vect, path_arc=PI / 2), + question.animate.next_to(new_q_vect, RIGHT) + ) + self.wait() + + # Collapse query vector + q_sym_template = Tex(R"\vec{\textbf{Q}}_0", font_size=48) + q_sym_template[0].scale(1.5, about_edge=DOWN) + q_sym_template.set_color(YELLOW) + subscript = q_sym_template.make_number_changeable(0) + q_syms = VGroup() + for n, arrow in enumerate(q_arrows, start=1): + subscript.set_value(n) + sym = q_sym_template.copy() + sym.next_to(arrow, DOWN, SMALL_BUFF) + q_syms.add(sym) + + mat_label2 = mat_label.copy() + + q_sym = q_syms[index] + low_q_sym = q_sym.copy() + low_q_sym.next_to(rhs, UP) + globals().update(locals()) + + self.play(LaggedStart( + LaggedStart( + (FadeTransform(entry, q_sym, remover=True) + for entry in new_q_vect.get_columns()[0]), + lag_ratio=0.01, + group_type=Group, + ), + new_q_vect.get_brackets().animate.stretch(0, 1, about_edge=UP).set_opacity(0), + FadeOutToPoint(query_word, q_sym.get_center()), + FadeOutToPoint(dim_text, q_sym.get_center()), + FadeOut(brace), + question.animate.next_to(q_sym, DOWN), + FadeIn(low_q_sym, UP), + lag_ratio=0.1, + )) + self.remove(new_q_vect) + self.add(q_sym) + self.play( + mat_label2.animate.scale(0.9).next_to(q_arrow, RIGHT, buff=0.15), + ) + self.wait() + + # E to Q rects + e_rects = VGroup(map(SurroundingRectangle, [emb_syms[index], e_vect])) + q_rects = VGroup(map(SurroundingRectangle, [q_sym, rhs])) + e_rects.set_stroke(TEAL, 3) + q_rects.set_stroke(YELLOW, 3) + self.play(ShowCreation(e_rects, lag_ratio=0.2)) + self.wait() + self.play(Transform(e_rects, q_rects)) + self.wait() + self.play(FadeOut(e_rects)) + + # Add other query vectors + remaining_q_arrows = VGroup(*q_arrows[:index], *q_arrows[index + 1:]) + remaining_q_syms = VGroup(*q_syms[:index], *q_syms[index + 1:]) + wq_syms = VGroup( + Tex(R"W_Q", font_size=30).next_to(arrow, RIGHT, buff=0.1) + for arrow in q_arrows + ) + wq_syms.set_color(YELLOW) + subscripts = VGroup(e_label_copy[-1], low_q_sym[-1][0]) + for subscript in subscripts: + i_sym = Tex("i") + i_sym.replace(subscript) + i_sym.scale(0.75) + i_sym.match_style(subscript) + subscript.target = i_sym + + self.play( + LaggedStartMap(GrowArrow, remaining_q_arrows), + LaggedStartMap(FadeIn, remaining_q_syms, shift=0.1 * DOWN), + ReplacementTransform(VGroup(mat_label2), wq_syms, lag_ratio=0.01, run_time=2), + question.animate.shift(0.25 * DOWN), + *map(Restore, word_groups), + *map(MoveToTarget, subscripts), + ) + self.wait() + + # Emphasize model weights + self.play( + LaggedStartMap(FlashAround, matrix.get_entries(), lag_ratio=1e-2), + RandomizeMatrixEntries(matrix), + ) + data_modifying_matrix(self, matrix, word_shape=(3, 8)) + self.wait() + self.play( + LaggedStartMap(FadeOut, VGroup( + matrix, mat_brace, mat_label, + e_vect, e_label_copy, eq, rhs, + low_q_sym + ), shift=0.2 * DR) + ) + self.wait() + + # Move question + noun_q_syms = VGroup(q_syms[words.index(word)] for word in ["creature", "forest"]) + + self.play( + question.animate.shift(0.25 * DOWN).match_x(noun_q_syms) + ) + + noun_q_lines = VGroup( + Line(question.get_corner(v), sym.get_corner(-v)) + for sym, v in zip(noun_q_syms, [UL, UR]) + ) + noun_q_lines.set_stroke(GREY, 1) + self.play(ShowCreation(noun_q_lines, lag_ratio=0)) + self.wait() + + # Set up keys + key_word_groups = word_groups.copy() + key_word_groups.arrange(DOWN, buff=0.75, aligned_edge=RIGHT) + key_word_groups.next_to(q_syms, DL, buff=LARGE_BUFF) + key_word_groups.shift(3.0 * LEFT) + key_emb_syms = emb_syms.copy() + + k_sym_template = Tex(R"\vec{\textbf{K}}_0", font_size=48) + k_sym_template[0].scale(1.5, about_edge=DOWN) + k_sym_template.set_color(TEAL) + subscript = k_sym_template.make_number_changeable(0) + + k_syms = VGroup() + key_emb_arrows = VGroup() + wk_arrows = VGroup() + wk_syms = VGroup() + for group, emb_sym, n in zip(key_word_groups, key_emb_syms, it.count(1)): + emb_arrow = Vector(0.5 * RIGHT) + emb_arrow.next_to(group, RIGHT, SMALL_BUFF) + emb_sym.next_to(emb_arrow, RIGHT, SMALL_BUFF) + wk_arrow = Vector(0.75 * RIGHT) + wk_arrow.next_to(emb_sym, RIGHT) + wk_sym = Tex("W_k", font_size=30) + wk_sym.set_fill(TEAL, border_width=1) + wk_sym.next_to(wk_arrow, UP) + subscript.set_value(n) + k_sym = k_sym_template.copy() + k_sym.next_to(wk_arrow, RIGHT, buff=MED_SMALL_BUFF) + + key_emb_arrows.add(emb_arrow) + wk_arrows.add(wk_arrow) + wk_syms.add(wk_sym) + k_syms.add(k_sym) + + self.play( + frame.animate.move_to(2.5 * LEFT + 2.75 * DOWN), + TransformFromCopy(word_groups, key_word_groups), + TransformFromCopy(emb_arrows, key_emb_arrows), + TransformFromCopy(emb_syms, key_emb_syms), + FadeOut(question), + FadeOut(noun_q_lines), + run_time=2, + ) + self.play( + LaggedStartMap(GrowArrow, wk_arrows), + LaggedStartMap(FadeIn, wk_syms, shift=0.1 * UP), + ) + self.play(LaggedStart( + (TransformFromCopy(e_sym, k_sym) + for e_sym, k_sym in zip(key_emb_syms, k_syms)), + lag_ratio=0.05, + )) + self.wait() + + # Show example key matrix + matrix = WeightMatrix(shape=(7, 12)) + matrix.set_width(5) + matrix.next_to(k_syms, UP, buff=2.0, aligned_edge=RIGHT) + mat_rect = SurroundingRectangle(matrix, buff=MED_SMALL_BUFF) + lil_rect = SurroundingRectangle(wk_syms[0]) + lines = VGroup( + Line(lil_rect.get_corner(v + UP), mat_rect.get_corner(v + DOWN)) + for v in [LEFT, RIGHT] + ) + VGroup(mat_rect, lil_rect, *lines).set_stroke(GREY_A, 1) + + self.play(ShowCreation(lil_rect)) + self.play( + ShowCreation(lines, lag_ratio=0), + TransformFromCopy(lil_rect, mat_rect), + FadeInFromPoint(matrix, lil_rect.get_center()), + ) + self.wait() + data_modifying_matrix(self, matrix, word_shape=(3, 8)) + self.play( + LaggedStartMap(FadeOut, VGroup(matrix, mat_rect, lines, lil_rect), run_time=1) + ) + + # Isolate examples + fade_rects = VGroup( + BackgroundRectangle(VGroup(key_word_groups[0], wk_syms[0], k_syms[0])), + BackgroundRectangle(VGroup(key_word_groups[3:], wk_syms[3:], k_syms[3:])), + BackgroundRectangle(wq_syms[2]), + BackgroundRectangle(VGroup(word_groups[:3], q_syms[:3])), + BackgroundRectangle(VGroup(word_groups[4:], q_syms[4:])), + ) + fade_rects.set_fill(BLACK, 0.75) + fade_rects.set_stroke(BLACK, 3, 1) + q_bubble = SpeechBubble("Any adjectives\nin front of me?") + q_bubble.flip(RIGHT) + q_bubble.next_to(q_syms[3][-1], DOWN, SMALL_BUFF, LEFT) + a_bubbles = SpeechBubble("I'm an adjective!\nI'm there!").replicate(2) + a_bubbles[0].pin_to(k_syms[1]) + a_bubbles[1].pin_to(k_syms[2]) + a_bubbles[1].flip(RIGHT, about_edge=DOWN) + a_bubbles[1].shift(0.5 * DOWN) + + self.add(fade_rects, word_groups[3]) + self.play(FadeIn(fade_rects)) + self.play(FadeIn(q_bubble, lag_ratio=0.1)) + self.play(FadeIn(a_bubbles, lag_ratio=0.05)) + self.wait() + self.play( + LaggedStartMap(FadeOut, VGroup(q_bubble, *a_bubbles), lag_ratio=0.25) + ) + self.wait() + + # Draw grid + emb_arrows.refresh_bounding_box(recurse_down=True) + q_groups = VGroup( + VGroup(group[i] for group in [ + emb_arrows, emb_syms, wq_syms, q_arrows, q_syms + ]) + for i in range(len(emb_arrows)) + ) + q_groups.target = q_groups.generate_target() + q_groups.target.arrange_to_fit_width(12, about_edge=LEFT) + q_groups.target.shift(0.25 * DOWN) + + word_groups.target = word_groups.generate_target() + for word_group, q_group in zip(word_groups.target, q_groups.target): + word_group.scale(0.7) + word_group.next_to(q_group[0], UP, SMALL_BUFF) + + h_lines = VGroup() + v_buff = 0.5 * (key_word_groups[0].get_y(DOWN) - key_word_groups[1].get_y(UP)) + for kwg in key_word_groups: + h_line = Line(LEFT, RIGHT).set_width(20) + h_line.next_to(kwg, UP, buff=v_buff) + h_line.align_to(key_word_groups, LEFT) + h_lines.add(h_line) + + v_lines = VGroup() + h_buff = 0.5 + for q_group in q_groups.target: + v_line = Line(UP, DOWN).set_height(14) + v_line.next_to(q_group, LEFT, buff=h_buff, aligned_edge=UP) + v_lines.add(v_line) + v_lines.add(v_lines[-1].copy().next_to(q_groups.target, RIGHT, 0.5, UP)) + + grid_lines = VGroup(*h_lines, *v_lines) + grid_lines.set_stroke(GREY_A, 1) + + self.play( + frame.animate.set_height(15, about_edge=UP).set_x(-2).set_anim_args(run_time=3), + MoveToTarget(q_groups), + MoveToTarget(word_groups), + ShowCreation(h_lines, lag_ratio=0.2), + ShowCreation(v_lines, lag_ratio=0.2), + FadeOut(fade_rects), + ) + + # Take all dot products + dot_prods = VGroup() + for k_sym in k_syms: + for q_sym in q_syms: + square_center = np.array([q_sym.get_x(), k_sym.get_y(), 0]) + dot = Tex(R".", font_size=72) + dot.move_to(square_center) + dot.set_fill(opacity=0) + dot_prod = VGroup(k_sym.copy(), dot, q_sym.copy()) + dot_prod.target = dot_prod.generate_target() + dot_prod.target.arrange(RIGHT, buff=0.15) + dot_prod.target.scale(0.65) + dot_prod.target.move_to(square_center) + dot_prod.target.set_fill(opacity=1) + dot_prods.add(dot_prod) + + self.play( + LaggedStartMap(MoveToTarget, dot_prods, lag_ratio=0.025, run_time=4) + ) + self.wait() + + # Show grid of dots + dots = VGroup( + VGroup(Dot().match_x(q_sym).match_y(k_sym) for q_sym in q_syms) + for k_sym in k_syms + ) + for n, row in enumerate(dots, start=1): + for k, dot in enumerate(row, start=1): + dot.set_fill(GREY_C, 0.8) + dot.set_width(random.random()) + dot.target = dot.generate_target() + dot.target.set_width(0.1 + 0.2 * random.random()) + if (n, k) in [(2, 4), (3, 4), (7, 8)]: + dot.target.set_width(0.8 + 0.2 * random.random()) + flat_dots = VGroup(*it.chain(*dots)) + + self.play( + dot_prods.animate.set_fill(opacity=0.75), + LaggedStartMap(GrowFromCenter, flat_dots) + ) + self.wait() + self.play(LaggedStartMap(MoveToTarget, flat_dots, lag_ratio=0.01)) + self.wait() + + # Resize to reflect true pattern + k_groups = VGroup( + VGroup(group[i] for group in [ + key_word_groups, key_emb_arrows, + key_emb_syms, wk_syms, wk_arrows, k_syms + ]) + for i in range(len(emb_arrows)) + ) + for q_group, word_group in zip(q_groups, word_groups): + q_group.add_to_back(word_group) + self.add(k_groups, q_groups, Point()) + + k_fade_rects = VGroup(map(BackgroundRectangle, k_groups)) + q_fade_rects = VGroup(map(BackgroundRectangle, q_groups)) + for rect in (*k_fade_rects, *q_fade_rects): + rect.scale(1.05) + rect.set_fill(BLACK, 0.8) + + self.play( + frame.animate.move_to([-4.33, -2.4, 0.0]).set_height(9.52), + FadeIn(k_fade_rects[:1]), + FadeIn(k_fade_rects[3:]), + FadeIn(q_fade_rects[:3]), + FadeIn(q_fade_rects[4:]), + run_time=2 + ) + self.wait() + + k_rects = VGroup(map(SurroundingRectangle, k_groups[1:3])) + k_rects.set_stroke(TEAL, 2) + q_rects = VGroup(SurroundingRectangle(q_groups[3])) + q_rects.set_stroke(YELLOW, 2) + + self.play( + ShowCreation(k_rects, lag_ratio=0.5, run_time=2), + LaggedStartMap( + FlashAround, k_groups[1:3], + color=TEAL, + time_width=2, + lag_ratio=0.25, + run_time=3 + ), + ) + self.wait() + self.play(TransformFromCopy(k_rects, q_rects)) + self.wait() + + # Show numerical dot product + high_dot_prods = VGroup(dot_prods[8 + 3], dot_prods[2 * 8 + 3]) + dots_to_grow = VGroup(dots[1][3], dots[2][3]) + numerical_dot_prods = VGroup( + VGroup( + DecimalNumber( + np.random.uniform(-100, 10), + include_sign=True, + font_size=42, + num_decimal_places=1, + edge_to_fix=ORIGIN, + ).move_to(dot) + for dot in row + ) + for row in dots + ) + for n, row in enumerate(numerical_dot_prods): + row[n].set_value(5 * random.random()) # Add some self relevance + flat_numerical_dot_prods = VGroup(*it.chain(*numerical_dot_prods)) + for ndp in flat_numerical_dot_prods: + ndp.set_fill(interpolate_color(RED_E, GREY_C, random.random())) + high_numerical_dot_prods = VGroup( + numerical_dot_prods[1][3], + numerical_dot_prods[2][3], + numerical_dot_prods[6][7], + ) + for hdp in high_numerical_dot_prods: + hdp.set_value(92 + 2 * random.random()) + hdp.set_color(WHITE) + low_numerical_dot_prod = numerical_dot_prods[5][3] + low_numerical_dot_prod.set_value(-31.4) + low_numerical_dot_prod.set_fill(RED_D) + + self.play( + *(dtg.animate.scale(1.25) for dtg in dots_to_grow), + *(CountInFrom(ndp, run_time=1) for ndp in high_numerical_dot_prods[:2]), + *(VFadeIn(ndp) for ndp in high_numerical_dot_prods[:2]), + *(FadeOut(dot_prod, run_time=0.5) for dot_prod in dot_prods), + ) + self.wait() + + # Show "attends to" + att_arrow = Arrow(k_rects.get_top(), q_rects.get_left(), path_arc=-90 * DEGREES) + att_words = TexText("``Attend to''", font_size=72) + att_words.next_to(att_arrow.pfp(0.4), UL) + + self.play( + ShowCreation(att_arrow), + Write(att_words), + ) + self.wait() + self.play(FadeOut(att_words), FadeOut(att_arrow)) + + # Contrast with "the" and "creature" + self.play( + frame.animate.move_to([-2.79, -3.66, 0.0]).set_height(12.29), + *(k_rect.animate.surround(k_groups[5]) for k_rect in k_rects), + FadeIn(k_fade_rects[1:3]), + FadeOut(k_fade_rects[5]), + run_time=2, + ) + self.play( + CountInFrom(low_numerical_dot_prod), + VFadeIn(low_numerical_dot_prod), + FadeOut(dots[5][3]), + ) + self.wait() + + # Zoom out on full grid + self.play( + frame.animate.move_to([-1.5, -4.8, 0.0]).set_height(15).set_anim_args(run_time=3), + LaggedStart( + FadeOut(k_rects), + FadeOut(q_rects), + FadeOut(k_fade_rects[:5]), + FadeOut(k_fade_rects[6:]), + FadeOut(q_fade_rects[:3]), + FadeOut(q_fade_rects[4:]), + FadeOut(dots), + LaggedStartMap(FadeIn, numerical_dot_prods), + Animation(high_numerical_dot_prods.copy(), remover=True), + Animation(low_numerical_dot_prod.copy(), remover=True), + ) + ) + self.wait() + + # Focus on one column + ndp_columns = VGroup( + VGroup(row[i] for row in numerical_dot_prods) + for i in range(len(numerical_dot_prods[0])) + ) + col_rect = SurroundingRectangle(ndp_columns[3], buff=0.25) + col_rect.set_stroke(YELLOW, 2) + weight_words = Text("We want these to\nact like weights", font_size=96) + weight_words.set_backstroke(BLACK, 8) + weight_words.next_to(col_rect, RIGHT, buff=MED_LARGE_BUFF) + weight_words.match_y(h_lines[2]) + + index = words.index("creature") + self.play( + ShowCreation(col_rect), + grid_lines.animate.set_stroke(opacity=0.5), + ndp_columns[:index].animate.set_opacity(0.35), + ndp_columns[index + 1:].animate.set_opacity(0.35), + FadeIn(weight_words, lag_ratio=0.1) + ) + self.wait() + + # Show softmax of each columns + self.set_floor_plane("xz") + col_arrays = [np.array([num.get_value() for num in col]) for col in ndp_columns] + softmax_arrays = list(map(softmax, col_arrays)) + softmax_cols = VGroup( + VGroup(DecimalNumber(v) for v in softmax_array) + for softmax_array in softmax_arrays + ) + sm_arrows = VGroup() + sm_labels = VGroup() + sm_rects = VGroup() + for sm_col, col in zip(softmax_cols, ndp_columns): + for sm_val, val in zip(sm_col, col): + sm_val.move_to(val) + sm_col.save_state() + sm_col.shift(6 * OUT) + sm_rect = SurroundingRectangle(sm_col) + sm_rect.match_style(col_rect) + VGroup(sm_col, sm_rect).rotate(30 * DEGREES, DOWN) + arrow = Arrow(col, sm_col.get_center() + SMALL_BUFF * RIGHT + IN) + label = Text("softmax", font_size=72) + label.set_backstroke(BLACK, 5) + label.rotate(90 * DEGREES, DOWN) + label.next_to(arrow, UP) + sm_arrows.add(arrow) + sm_labels.add(label) + sm_rects.add(sm_rect) + + index = words.index("creature") + self.play( + frame.animate.reorient(-47, -7, 0, (-2.48, -5.84, -1.09), 20), + GrowArrow(sm_arrows[index], time_span=(1, 2)), + FadeIn(sm_labels[index], lag_ratio=0.1, time_span=(1, 2)), + TransformFromCopy(ndp_columns[index], softmax_cols[index], time_span=(1.5, 3)), + TransformFromCopy(col_rect, sm_rects[index], time_span=(1.5, 3)), + FadeOut(weight_words), + run_time=3 + ) + self.wait() + + remaining_indices = [*range(index), *range(index + 1, len(ndp_columns))] + last_index = index + for index in remaining_indices: + self.play( + ndp_columns[last_index].animate.set_opacity(0.35), + ndp_columns[index].animate.set_opacity(1), + col_rect.animate.move_to(ndp_columns[index]), + softmax_cols[last_index].animate.set_opacity(0.25), + *map(FadeOut, [sm_rects[last_index], sm_arrows[last_index], sm_labels[last_index]]), + ) + self.play( + GrowArrow(sm_arrows[index]), + FadeIn(sm_labels[index], lag_ratio=0.1), + TransformFromCopy(ndp_columns[index], softmax_cols[index]), + TransformFromCopy(col_rect, sm_rects[index]), + ) + last_index = index + self.play( + FadeOut(col_rect), + *map(FadeOut, [sm_rects[last_index], sm_arrows[last_index], sm_labels[last_index]]), + ) + self.wait() + self.play( + frame.animate.reorient(0, 0, 0, (-2.64, -4.8, 0.0), 14.54), + LaggedStartMap(Restore, softmax_cols, lag_ratio=0.1), + FadeOut(ndp_columns, time_span=(0, 1.5)), + run_time=3, + ) + self.wait() + + # Label attention pattern + for n, row in enumerate(dots): + if n not in [3, 7]: + row[n].set_width(0.7 + 0.2 * random.random()) + dots[1][3].set_width(0.6 + 0.1 * random.random()) + dots[2][3].set_width(0.6 + 0.1 * random.random()) + dots[6][7].set_width(0.9 + 0.1 * random.random()) + + pattern_words = Text("Attention\nPattern", font_size=120) + pattern_words.move_to(grid_lines, UL).shift(LEFT) + + self.play( + FadeOut(softmax_cols, lag_ratio=0.001), + FadeIn(dots, lag_ratio=0.001), + Write(pattern_words), + run_time=2 + ) + self.wait() + + # Preview masking + masked_dots = VGroup() + for n, row in enumerate(dots): + masked_dots.add(*row[:n]) + mask_rects = VGroup() + for dot in masked_dots: + mask_rect = Square(0.5) + mask_rect.set_stroke(RED, 2) + mask_rect.move_to(dot) + mask_rects.add(mask_rect) + + lag_ratio=1.0 / len(mask_rects) + self.play(ShowCreation(mask_rects, lag_ratio=lag_ratio)) + self.play( + LaggedStart( + (dot.animate.scale(0) for dot in masked_dots), + lag_ratio=lag_ratio + ) + ) + self.play( + FadeOut(mask_rects, lag_ratio=lag_ratio) + ) + self.wait() + + # Set aside keys and queries + pattern = VGroup(grid_lines, dots) + for group in q_groups: + group.sort(lambda p: -p[1]) + group.target = group.generate_target() + m3 = len(group) - 3 + group.target[m3:].scale(0, about_edge=DOWN) + group.target[:m3].move_to(group, DOWN) + + self.play( + frame.animate.move_to((-2.09, -5.59, 0.0)).set_height(12.95).set_anim_args(run_time=3), + LaggedStartMap(MoveToTarget, q_groups), + FadeOut(pattern_words), + v_lines.animate.stretch(0.95, 1, about_edge=DOWN), + ) + self.play( + LaggedStartMap(FadeOut, k_syms, shift=0.5 * DOWN, lag_ratio=0.1), + LaggedStartMap(FadeOut, wk_syms, shift=0.5 * DOWN, lag_ratio=0.1), + ) + self.wait() + + # Add values + value_color = RED + big_wv_sym = Tex(R"W_V", font_size=90) + big_wv_sym.set_color(value_color) + big_wv_sym.next_to(h_lines, UP, MED_LARGE_BUFF, LEFT) + wv_word = Text("Value matrix", font_size=90) + wv_word.next_to(big_wv_sym, UP, MED_LARGE_BUFF) + wv_word.set_color(value_color) + + wv_arrows = wk_arrows + v_sym_template = Tex(R"\vec{\textbf{V}}_{0}") + v_sym_template[0].scale(1.5, about_edge=DOWN) + v_sym_template.set_fill(value_color, border_width=1) + subscript = v_sym_template.make_number_changeable("0") + + wv_syms = VGroup() + v_syms = VGroup() + for n, arrow in enumerate(wv_arrows, start=1): + wv_sym = Tex("W_V", font_size=36) + wv_sym.set_fill(value_color, border_width=1) + wv_sym.next_to(arrow, UP, buff=0.2, aligned_edge=LEFT) + subscript.set_value(n) + v_sym = v_sym_template.copy() + v_sym.next_to(arrow, RIGHT, MED_SMALL_BUFF) + + v_syms.add(v_sym) + wv_syms.add(wv_sym) + + self.play( + FadeIn(big_wv_sym, 0.5 * DOWN), + FadeIn(wv_word, lag_ratio=0.1), + ) + self.play( + LaggedStart( + (TransformFromCopy(big_wv_sym, wv_sym) + for wv_sym in wv_syms), + lag_ratio=0.15, + ), + run_time=3 + ) + self.play( + LaggedStart( + (TransformFromCopy(e_sym, v_sym) + for e_sym, v_sym in zip(key_emb_syms, v_syms)), + lag_ratio=0.15, + ), + ) + self.wait() + self.play( + FadeTransform(v_syms, k_syms), + FadeTransform(wv_syms, wk_syms), + rate_func=there_and_back_with_pause, + run_time=3, + ) + self.remove(k_syms, wk_syms) + self.add(v_syms, wv_syms) + self.wait() + + # Show column of weights + index = words.index("creature") + weighted_sum_cols = VGroup() + for sm_col in softmax_cols: + weighted_sum_col = VGroup() + for weight, v_sym in zip(sm_col, v_syms): + product = VGroup(weight, v_sym.copy()) + product.target = product.generate_target() + product.target.arrange(RIGHT) + product.target[1].shift(UP * ( + product.target[0].get_y(DOWN) - + product.target[1][1].get_y(DOWN) + )) + product.target.scale(0.75) + product.target.move_to(weight) + product.target.set_fill( + opacity=clip(0.6 + weight.get_value(), 0, 1) + ) + weighted_sum_col.add(product) + weighted_sum_cols.add(weighted_sum_col) + + self.play( + FadeOut(dots, lag_ratio=0.1), + FadeIn(q_fade_rects[:index]), + FadeIn(q_fade_rects[index + 1:]), + FadeIn(softmax_cols[index]), + ) + self.wait() + self.play( + LaggedStartMap(MoveToTarget, weighted_sum_cols[index]) + ) + self.wait() + + # Emphasize fluffy and blue weights + rects = VGroup( + key_word_groups[i][0].copy() + for i in [1, 2] + ) + alt_rects = VGroup( + SurroundingRectangle(value, buff=SMALL_BUFF) + for value in (* softmax_cols[index][:1], *softmax_cols[index][3:]) + ) + alt_rects.set_stroke(RED, 1) + self.play( + LaggedStart( + (rect.animate.surround(value) + for rect, value in zip(rects, softmax_cols[index][1:3])), + lag_ratio=0.2, + ) + ) + self.wait() + self.play(Transform(rects, alt_rects)) + self.wait() + self.play(FadeOut(rects, lag_ratio=0.1)) + + # Show sum + emb_sym = emb_syms[index] + ws_col = weighted_sum_cols[index] + creature = images[2] + creature.set_height(1.5) + creature.next_to(word_groups[index], UP) + + emb_sym.target = emb_sym.generate_target() + emb_sym.target.scale(1.25, about_edge=UP) + sum_rect = SurroundingRectangle(emb_sym.target) + sum_rect.set_stroke(YELLOW, 2) + sum_rect.target = sum_rect.generate_target() + sum_rect.target.surround(VGroup(emb_sym.target, ws_col), buff=MED_SMALL_BUFF) + plusses = VGroup() + for m1, m2 in zip([emb_sym.target, *ws_col], ws_col): + plus = Tex(R"+", font_size=72) + plus.move_to(midpoint(m1.get_bottom(), m2.get_top())) + plusses.add(plus) + + self.play( + frame.animate.reorient(0, 0, 0, (-2.6, -4.79, 0.0), 15.07).set_anim_args(run_time=2), + MoveToTarget(emb_sym), + ShowCreation(sum_rect), + FadeIn(creature, UP), + FadeOut(wv_word), + FadeOut(big_wv_sym), + ) + self.add(Point(), q_fade_rects[index + 1:]) # Hack + self.wait() + self.play( + frame.animate.reorient(0, 0, 0, (-2.9, -6.5, 0.0), 19).set_anim_args(run_time=2), + MoveToTarget(sum_rect, run_time=2), + Write(plusses), + ) + self.wait() + + # Finish sum + low_arrows = VGroup( + Vector(DOWN).next_to(wsc[-1].target, DOWN) + for wsc in weighted_sum_cols + ) + for sym, arrow in zip(emb_sym_primes, low_arrows): + sym.match_height(emb_sym) + sym.next_to(arrow, DOWN) + blue_fluff.set_height(2.5) + blue_fluff.next_to(emb_sym_primes[index], buff=MED_LARGE_BUFF, aligned_edge=UP) + + self.play( + TransformFromCopy(emb_syms[index], emb_sym_primes[index]), + LaggedStart( + (FadeTransform(prod.copy(), emb_sym_primes[index]) + for prod in ws_col), + lag_ratio=0.05, + group_type=Group + ), + ShowCreation(low_arrows[index]), + FadeTransform(creature.copy(), blue_fluff) + ) + self.wait() + + # Map it over all vectors + plus_groups = VGroup( + plusses.copy().match_x(col[0].target) + for col in weighted_sum_cols + ) + plus_groups.set_fill(GREY_C, 1) + + for col in softmax_cols: + for value in col: + value.set_fill( + opacity=clip(0.6 + value.get_value(), 0, 1) + ) + + self.play( + frame.animate.reorient(0, 0, 0, (-2.76, -7, 0.0), 16), + FadeOut(sum_rect), + FadeOut(creature), + FadeOut(blue_fluff), + FadeOut(q_fade_rects[:index]), + FadeOut(q_fade_rects[index + 1:]), + FadeIn(softmax_cols[:index]), + FadeIn(softmax_cols[index + 1:]), + plusses.animate.set_fill(GREY_C, 1), + ) + self.play( + LaggedStart( + (LaggedStartMap(MoveToTarget, col) + for col in weighted_sum_cols), + lag_ratio=0.1 + ), + v_lines.animate.set_stroke(GREY_B, 3, 1), + *( + e_sym.animate.scale(1.25, about_edge=UP) + for e_sym in (*emb_syms[:index], *emb_syms[index + 1:]) + ), + ) + other_indices = [*range(index), *range(index + 1, len(plus_groups))] + self.play(LaggedStart( + (LaggedStart( + FadeIn(plus_groups[j], lag_ratio=0.1), + GrowArrow(low_arrows[j]), + LaggedStart( + (FadeTransform(ws.copy(), emb_sym_primes[j]) + for ws in weighted_sum_cols[j]), + lag_ratio=0.05, + group_type=Group + ), + lag_ratio=0.25, + ) + for j in other_indices), + lag_ratio=0.01, + group_type=Group + )) + self.wait() def bake_mobject_into_vector_entries(self, mob, vector, path_arc=30 * DEGREES, group_type=None): entries = vector.get_entries() @@ -299,3 +1267,58 @@ class AttentionPatterns(InteractiveScene): run_time=2 ), ) + + +class RoadNotTaken(InteractiveScene): + def construct(self): + # Add poem + kw = dict(alignment="LEFT") + stanzas = VGroup( + Text(""" + Two roads diverged in a yellow wood, + And sorry I could not travel both + And be one traveler, long I stood + And looked down one as far as I could + To where it bent in the undergrowth; + """, **kw), + Text(""" + Then took the other, as just as fair, + And having perhaps the better claim, + Because it was grassy and wanted wear; + Though as for that the passing there + Had worn them really about the same, + """, **kw), + Text(""" + And both that morning equally lay + In leaves no step had trodden black. + Oh, I kept the first for another day! + Yet knowing how way leads on to way, + I doubted if I should ever come back. + """, **kw), + Text(""" + I shall be telling this with a sigh + Somewhere ages and ages hence: + Two roads diverged in a wood, and I— + I took the one less traveled by, + And that has made all the difference. + """, **kw), + ) + stanzas.arrange_in_grid(h_buff=1.5, v_buff=1.0, fill_rows_first=False) + stanzas.set_width(FRAME_WIDTH - 1) + stanzas.move_to(0.5 * UP) + + self.play( + FadeIn(stanzas, lag_ratio=0.01, run_time=4) + ) + self.wait() + + # Note all text until "one" + + # Highlight "two roads" + + # Highlight "took the other" and "grassy and wanted wear" + + # Somehow higlight words throughout + + + diff --git a/_2024/transformers/embedding.py b/_2024/transformers/embedding.py index cd1f654..50a2d0b 100644 --- a/_2024/transformers/embedding.py +++ b/_2024/transformers/embedding.py @@ -99,6 +99,9 @@ def get_word_to_vec_model(model_name="glove-wiki-gigaword-50"): return model +# For chapter 1 + + class LyingAboutTokens2(InteractiveScene): def construct(self): # Mention next word prediction task @@ -320,7 +323,6 @@ class SoundTokens(InteractiveScene): self.wait() - class IntroduceEmbeddingMatrix(InteractiveScene): def construct(self): # Load words @@ -603,6 +605,7 @@ class Word2VecScene(InteractiveScene): height=8, depth=6.4, ) + label_rotation = PI / 2 # embedding_model = "word2vec-google-news-300" embedding_model = "glove-wiki-gigaword-50" @@ -668,6 +671,7 @@ class Word2VecScene(InteractiveScene): label_text=word if func_name is None else f"{func_name}({word})", buff=0, direction=direction, + label_rotation=self.label_rotation, **label_config, ) @@ -1790,19 +1794,20 @@ class DotProducts(InteractiveScene): dual_rotate(75, -95, run_time=8) -class DotProductWithGenderDirection(InteractiveScene): - vec_tex = R"\vec{\text{gen}}" - ref_words = ["man", "woman"] +class DotProductWithPluralDirection(InteractiveScene): + vec_tex = R"\vec{\text{plur}}" + ref_words = ["cat", "cats"] words = [ - "mother", "father", - "aunt", "uncle", - "sister", "brother", - "mama", "papa", + "octopus", "octopi", + "puppy", "puppies", + "student", "students", + "one", "two", "three", "four", + "single", "multiple", ] - x_range = (-5, 7 + 1e-4, 0.25) + x_range = (-8, 5 + 1e-4, 0.25) colors = [BLUE, RED] + threshold = -1.0 number_line_y = -1.5 - threshold = 1.0 def construct(self): # Initialize equation @@ -1842,12 +1847,16 @@ class DotProductWithGenderDirection(InteractiveScene): longer_tick_multiple=2.5, width=12 ) + # number_line.rotate(PI / 2) number_line.add_numbers( np.arange(*x_range[:2]), num_decimal_places=1, font_size=30, + # direction=LEFT ) number_line.move_to(self.number_line_y * UP) + # number_line.to_edge(LEFT, buff=1.0) + eq_rhs = self.get_equation_rhs(eq_lhs, words[0]) equation = VGroup(eq_lhs, eq_rhs) low_brace = Brace(equation, DOWN) @@ -1948,19 +1957,18 @@ class DotProductWithGenderDirection(InteractiveScene): ) -class DotProductWithPluralityDirection(DotProductWithGenderDirection): - vec_tex = R"\vec{\text{plur}}" - ref_words = ["cat", "cats"] +class DotProductWithGenderDirection(DotProductWithPluralDirection): + vec_tex = R"\vec{\text{gen}}" + ref_words = ["man", "woman"] words = [ - "octopus", "octopi", - "puppy", "puppies", - "student", "students", - "one", "two", "three", "four", - "single", "multiple", + "mother", "father", + "aunt", "uncle", + "sister", "brother", + "mama", "papa", ] - x_range = (-8, 5 + 1e-4, 0.25) + x_range = (-5, 7 + 1e-4, 0.25) colors = [BLUE, RED] - threshold = -1.0 + threshold = 1.0 class RicherEmbedding(InteractiveScene): @@ -2110,3 +2118,244 @@ class RicherEmbedding(InteractiveScene): result = VGroup(vect, text) return result + + +# For chapter 2 + +class MultipleMoleEmbeddings(Word2VecScene): + default_frame_orientation = (0, 0) + label_rotation = 0 + + def setup(self): + super().setup() + self.set_floor_plane("xz") + self.frame.add_ambient_rotation() + self.add_plane() + for mob in [self.plane, self.axes]: + mob.rotate(-90 * DEGREES, RIGHT) + + def construct(self): + # Show generic mole embedding + frame = self.frame + frame.reorient(-6, -6, 0, (-0.73, 1.29, -0.57), 5.27) + phrases = VGroup(map(Text, [ + "American shrew mole", + "One mole of carbon dioxide", + "Take a biopsy of the mole", + ])) + for phrase in phrases: + phrases.fix_in_frame() + phrases.to_corner(UL) + phrase["mole"][0].set_color(YELLOW) + + gen_vector = self.get_labeled_vector("mole", coords=(-2, 1.0, 1.5)) + curr_phrase = phrases[1] + mover = curr_phrase["mole"][0] + mover.set_backstroke(BLACK, 4) + + self.add(curr_phrase) + self.wait() + self.play( + GrowArrow(gen_vector), + TransformFromCopy(mover, gen_vector.label), + ) + self.wait(10) + + # Show three refined meanings + images = Group( + ImageMobject("ShrewMole"), + Tex(R"6.02 \times 10^{23}", font_size=24).set_color(BLUE), + ImageMobject("LipMole"), + ) + for image in images[::2]: + image.set_height(0.5) + image.set_opacity(0.75) + + colors = [GREY_BROWN, BLUE, ORANGE] + ref_vects = VGroup( + self.get_labeled_vector("", coords=coords) + for coords in [ + (-1.0, -1.5, 1.5), + (-4.0, 0.5, 1.0), + (-0.5, 1.0, 2.5), + ] + ) + for vect, image, color in zip(ref_vects, images, colors): + vect.set_color(color) + image.next_to(vect.get_end(), UP, SMALL_BUFF) + + gen_vect_group = VGroup(gen_vector, gen_vector.label) + + self.play( + frame.animate.reorient(-30, -5, 0, (-1.11, 1.35, -0.72), 5.27), + LaggedStart( + (TransformFromCopy(gen_vector, ref_vect) + for ref_vect in ref_vects), + lag_ratio=0.25, + run_time=2, + ), + LaggedStart( + (FadeInFromPoint(image, gen_vector.label.get_center()) + for image in images), + lag_ratio=0.25, + run_time=2, + group_type=Group, + ), + gen_vect_group.animate.set_opacity(0.25).set_anim_args(run_time=2), + run_time=2, + ) + self.wait(3) + + ref_vect_groups = Group( + Group(*pair) for pair in zip(ref_vects, images) + ) + + # Oscillate between meanings based on context + diff_vects = VGroup( + Arrow(gen_vector.get_end(), ref_vect.get_end(), buff=0) + for ref_vect in ref_vects + ) + diff_vects.set_color(GREY_B) + + last_phrase = curr_phrase + last_diff = VGroup() + for n, diff in enumerate(diff_vects): + ref_vect_groups.target = ref_vect_groups.generate_target() + ref_vect_groups.target.set_opacity(0.2) + ref_vect_groups.target[n].set_opacity(1) + if n != 2: + ref_vect_groups.target[2][1].set_opacity(0.1) + phrase = phrases[n] + self.play( + gen_vect_group.animate.set_opacity(1), + MoveToTarget(ref_vect_groups), + FadeOut(last_phrase, UP), + FadeIn(phrase, UP), + FadeOut(last_diff) + ) + self.play( + ShowCreation(diff, time_span=(1, 2)), + TransformFromCopy(gen_vector, ref_vects[n], time_span=(1, 2)), + ContextAnimation( + phrase["mole"][0], phrase, + direction=DOWN, + fix_in_frame=True, + ), + ) + self.wait(3) + + last_phrase = phrase + last_diff = diff + + self.wait(5) + + def get_basis(self, model): + basis = super().get_basis(model) * 2 + basis[2] *= -1 + return basis + + + +class RefineTowerMeaning(MultipleMoleEmbeddings): + def construct(self): + # Set up vectors and images + frame = self.frame + frame.reorient(-26, -4, 0, (3.27, 1.57, 0.59), 5.28) + frame.add_ambient_rotation(0.5 * DEGREES) + + words = VGroup(Text(word) for word in "Miniature Eiffel Tower".split(" ")) + words.scale(1.25) + words.to_edge(UP) + words.fix_in_frame() + + tower_images = Group( + ImageMobject(f"Tower{n}") + for n in range(1, 5) + ) + eiffel_tower_images = Group( + ImageMobject(f"EiffelTower{n}") + for n in range(1, 4) + ) + mini_eiffel_tower_images = Group( + ImageMobject("MiniEiffelTower1") + ) + image_groups = Group( + tower_images, + eiffel_tower_images, + mini_eiffel_tower_images + ) + + vectors = VGroup( + self.get_labeled_vector("", coords=coords) + for coords in [ + (4, -1, 3.0), + (5, -2, 1.5), + (-3, -1, 2.5), + ] + ) + colors = [BLUE_D, GREY_B, GREY_C] + for vector, color, image_group in zip(vectors, colors, image_groups): + vector.set_color(color) + for image in image_group: + image.set_height(1.5) + image.next_to(vector.get_end(), RIGHT * np.sign(vector.get_end()[0])) + + # Show tower + tower = words[-1] + tower.set_x(0) + pre_tower_image = tower_images[0].copy() + pre_tower_image.fix_in_frame() + pre_tower_image.replace(tower, stretch=True) + pre_tower_image.set_opacity(0) + + self.add(tower) + self.wait() + self.play( + GrowArrow(vectors[0]), + ReplacementTransform(pre_tower_image, tower_images[0]), + run_time=2, + ) + for ti1, ti2 in zip(tower_images, tower_images[1:]): + self.play( + FadeTransform(ti1, ti2), + run_time=2 + ) + self.wait(2) + + # Eiffel tower + words[:-1].set_opacity(0) + eiffel_tower = words[-2:] + + self.play( + frame.animate.reorient(-4, -7, 0, (2.95, 1.82, 0.49), 6.59), + eiffel_tower.animate.set_opacity(1).arrange(RIGHT, aligned_edge=DOWN).to_edge(UP), + ) + self.play( + vectors[0].animate.set_opacity(0.25), + tower_images[-1].animate.set_opacity(0.2), + TransformFromCopy(vectors[0], vectors[1]), + FadeTransform(tower_images[-1].copy(), eiffel_tower_images[0]), + ContextAnimation(words[2], words[1], direction=DOWN, fix_in_frame=True), + run_time=2, + ) + for ti1, ti2 in zip(eiffel_tower_images, eiffel_tower_images[1:]): + self.play( + FadeTransform(ti1, ti2), + run_time=2 + ) + self.wait(2) + + # Miniature eiffel tower + self.play( + frame.animate.reorient(-14, -2, 0, (-0.12, 2.21, 0.72), 7.05).set_anim_args(run_time=2), + words.animate.set_opacity(1).arrange(RIGHT, aligned_edge=DOWN).to_edge(UP), + ) + self.play( + vectors[1].animate.set_opacity(0.25), + eiffel_tower_images[-1].animate.set_opacity(0.2), + TransformFromCopy(vectors[1], vectors[2]), + FadeTransform(eiffel_tower_images[-1].copy(), mini_eiffel_tower_images[0]), + ContextAnimation(words[2], words[0], direction=DOWN, fix_in_frame=True), + run_time=2, + ) + self.wait(10) diff --git a/_2024/transformers/ml_basics.py b/_2024/transformers/ml_basics.py index c5622ca..7b0ea0b 100644 --- a/_2024/transformers/ml_basics.py +++ b/_2024/transformers/ml_basics.py @@ -1681,8 +1681,6 @@ class DistinguishWeightsAndData(InteractiveScene): v_line.to_edge(UP, buff=0) v_line.set_stroke(GREY_A, 2) - self.add(titles) - # Set up matrices matrices = VGroup( WeightMatrix( diff --git a/_2024/transformers/network_flow.py b/_2024/transformers/network_flow.py index 755c818..e3833bb 100644 --- a/_2024/transformers/network_flow.py +++ b/_2024/transformers/network_flow.py @@ -23,6 +23,7 @@ class HighLevelNetworkFlow(InteractiveScene): (" Ludwig", 0.0104), ] hide_block_labels = False + block_to_title_direction = UP def setup(self): super().setup() @@ -153,7 +154,7 @@ class HighLevelNetworkFlow(InteractiveScene): title = Text(title, font_size=title_font_size) title.set_backstroke(BLACK, title_backstroke_width) - title.next_to(body, UP, buff=0.1) + title.next_to(body, self.block_to_title_direction, buff=0.1) block = Group(body, title) block.body = body block.title = title @@ -221,7 +222,6 @@ class HighLevelNetworkFlow(InteractiveScene): self.play(LaggedStartMap(VFadeInThenOut, arrows, lag_ratio=0.25, run_time=4)) self.play(FadeOut(token_label, DOWN)) - # Show words into vectors layer = self.get_embedding_array( shape=(len(words), 10), @@ -1165,4 +1165,169 @@ class TextPassageIntro(InteractiveScene): run_time=3 ) self.add(short_text) - self.wait() \ No newline at end of file + self.wait() + + +class MoleExample1(HighLevelNetworkFlow): + block_to_title_direction = LEFT + highlighted_group_index = 1 + + def construct(self): + # Show three phrases + phrase_strs = [ + "American shrew mole", + "One mole of carbon dioxide", + "Take a biopsy of the mole", + ] + phrases = VGroup(map(Text, phrase_strs)) + phrases.arrange(DOWN, buff=2.0) + phrases.move_to(0.25 * DOWN) + + self.play(Write(phrases[0]), run_time=1) + self.wait() + for i in [1, 2]: + self.play( + Transform(phrases[i - 1]["mole"].copy(), phrases[i]["mole"].copy(), remover=True), + FadeIn(phrases[i], lag_ratio=0.1) + ) + self.wait() + + # Add mole images + images = Group( + ImageMobject("ShrewMole").set_height(1), + Tex(R"6.02 \times 10^{23}").set_color(TEAL), + ImageMobject("LipMole").set_height(1), + ) + braces = VGroup() + mole_words = VGroup() + for image, phrase in zip(images, phrases): + mole_word = phrase["mole"][0] + brace = Brace(mole_word, UP, SMALL_BUFF) + image.next_to(brace, UP, SMALL_BUFF) + braces.add(brace) + mole_words.add(mole_word) + + self.play( + LaggedStartMap(GrowFromCenter, braces, lag_ratio=0.5), + LaggedStartMap(FadeIn, images, shift=UP, lag_ratio=0.5), + mole_words.animate.set_color(YELLOW).set_anim_args(lag_ratio=0.1), + ) + self.wait() + + # Subdivide + word_groups = VGroup() + for phrase in phrases: + words = break_into_words(phrase.copy()) + rects = get_piece_rectangles( + words, leading_spaces=False, h_buff=0.05 + ) + word_group = VGroup(VGroup(*pair) for pair in zip(rects, words)) + word_groups.add(word_group) + + self.play( + FadeIn(word_groups), + LaggedStartMap(FadeOut, braces, shift=0.25 * DOWN, lag_ratio=0.25), + LaggedStartMap(FadeOut, images, shift=0.25 * DOWN, lag_ratio=0.25), + run_time=1 + ) + self.remove(phrases) + self.wait() + + # Divide into three regions + for group, sign in zip(word_groups, [-1, 0, 1]): + group.target = group.generate_target() + group.target.scale(0.75) + group.target.set_x(sign * FRAME_WIDTH / 3) + group.target.to_edge(UP) + + v_lines = Line(UP, DOWN).replicate(2) + v_lines.set_height(FRAME_HEIGHT) + v_lines.arrange(RIGHT, buff=FRAME_WIDTH / 3) + v_lines.center() + v_lines.set_stroke(GREY_B, 1) + + self.play( + LaggedStartMap(MoveToTarget, word_groups), + ShowCreation(v_lines, lag_ratio=0.5, time_span=(1, 2)) + ) + + # Show vector embeddings + embs = VGroup() + arrows = VGroup() + seed_array = np.random.uniform(0, 10, 7) + for group in word_groups: + for word in group: + arrow = Vector(0.5 * DOWN) + arrow.next_to(word, DOWN, SMALL_BUFF) + size = sum(len(m.get_points()) for m in word.family_members_with_points()) + values = (seed_array * size % 10) + emb = NumericEmbedding(values=values) + emb.set_height(2) + emb.next_to(arrow, DOWN, SMALL_BUFF) + + arrows.add(arrow) + embs.add(emb) + mole_indices = [2, 4, 13] + non_mole_indices = [n for n in range(len(embs)) if n not in mole_indices] + + mole_vect_rects = VGroup( + SurroundingRectangle(embs[index]) + for index in mole_indices + ) + mole_vect_rects.set_stroke(YELLOW, 2) + + globals().update(locals()) + self.play( + LaggedStartMap(GrowArrow, arrows), + LaggedStartMap(FadeIn, embs, shift=0.25 * DOWN), + ) + self.wait() + self.play( + LaggedStartMap(ShowCreation, mole_vect_rects), + VGroup(arrows[j] for j in non_mole_indices).animate.set_fill(opacity=0.5), + VGroup(embs[j] for j in non_mole_indices).animate.set_fill(opacity=0.5), + ) + self.wait() + self.play( + FadeOut(mole_vect_rects) + ) + + # Prepare to pass through an attention block + wg_lens = [len(wg) for wg in word_groups] + indices = [0, *np.cumsum(wg_lens)] + full_groups = VGroup( + VGroup(wg, arrows[i:j], embs[i:j]) + for wg, i, j in zip(word_groups, indices, indices[1:]) + ) + highlighted_group = full_groups[self.highlighted_group_index] + fade_groups = [fg for n, fg in enumerate(full_groups) if n != self.highlighted_group_index] + highlighted_group.target = highlighted_group.generate_target() + highlighted_group.target.scale(1.5, about_edge=UP) + highlighted_group.target.space_out_submobjects(1.1) + highlighted_group.target.center() + highlighted_group.target[2].set_fill(opacity=1) + + globals().update(locals()) + self.play( + FadeOut(v_lines, time_span=(0, 1)), + MoveToTarget(highlighted_group, lag_ratio=5e-4), + *( + FadeOut( + fg, + shift=fg.get_center() - highlighted_group.get_center() + 2 * DOWN, + lag_ratio=1e-3 + ) + for fg in fade_groups + ), + run_time=2 + ) + self.wait() + + # Pass through attention + layer = VGroup(highlighted_group[2]) + layer.embeddings = highlighted_group[2] + self.layers.set_submobjects([]) + self.layers.add(layer) + + self.progress_through_attention_block(target_frame_x=-2) + self.wait() diff --git a/_2024/transformers/supplements.py b/_2024/transformers/supplements.py index bafa689..347bece 100644 --- a/_2024/transformers/supplements.py +++ b/_2024/transformers/supplements.py @@ -460,9 +460,10 @@ class GamePlan(InteractiveScene): prev_thumbnails.set_width(FRAME_WIDTH - 2) prev_thumbnails.move_to(2 * UP) - new_thumbnails = Group( # TODO, give these images - Rectangle().set_stroke(width=0).set_fill(BLACK, 1) - for vid in tr_vids + tn_dir = "/Users/grant/3Blue1Brown Dropbox/3Blue1Brown/videos/2024/transformers/Thumbnails/" + new_thumbnails = Group( + ImageMobject(os.path.join(tn_dir, f"Chapter{n}")) + for n in range(5, 8) ) for tn1, tn2 in zip(prev_thumbnails, new_thumbnails): tn2.replace(tn1, stretch=True) @@ -567,6 +568,37 @@ class SeaOfNumbersUnderlay(TeacherStudentsScene): self.wait(8) +class Outdated(TeacherStudentsScene): + def construct(self): + # Add label + text = Text("GPT-3", font="Consolas", font_size=72) + openai_logo = SVGMobject("OpenAI.svg") + openai_logo.set_fill(WHITE) + openai_logo.set_height(2.0 * text.get_height()) + gpt3_label = VGroup(openai_logo, text) + gpt3_label.arrange(RIGHT) + gpt3_label.scale(0.75) + param_count = Text("175B Parameters") + param_count.set_color(BLUE) + param_count.next_to(gpt3_label, DOWN, aligned_edge=LEFT) + gpt3_label.add(param_count) + + gpt3_label.move_to(self.hold_up_spot, DOWN) + + morty = self.teacher + morty.body.insert_n_curves(100) + + self.play( + morty.change("raise_right_hand"), + FadeIn(gpt3_label, UP), + ) + self.play(self.change_students("raise_left_hand", "hesitant", "sassy")) + self.play( + self.students[0].says(TexText("Isn't that outdated?")) + ) + self.wait(3) + + class ConfusionAtScreen(TeacherStudentsScene): def construct(self): self.play(