from manim_imports_ext import * from _2024.transformers.generation import * from _2024.transformers.helpers import * from _2024.transformers.embedding import * from _2024.transformers.ml_basics import * # Intro class HoldUpThumbnail(TeacherStudentsScene): def construct(self): # Test im = ImageMobject("/Users/grant/3Blue1Brown Dropbox/3Blue1Brown/videos/2024/transformers/Thumbnails/Chapter5_TN3.png") im_group = Group( SurroundingRectangle(im, buff=0).set_stroke(WHITE, 3), im ) im_group.set_height(3) im_group.move_to(self.hold_up_spot, DOWN) morty = self.teacher stds = self.students self.play( FadeIn(im_group, UP), morty.change("raise_right_hand", look_at=im_group), self.change_students("tease", "happy", "tease", look_at=im_group), ) self.wait(4) class IsThisUsefulToShare(TeacherStudentsScene): def construct(self): # Test morty = self.teacher self.play( morty.says("Do you find\nthis useful?"), self.change_students("pondering", "hesitant", "well", look_at=self.screen) ) self.wait(3) self.play(self.change_students("thinking", "pondering", "tease")) self.wait(3) class AskAboutAttention(TeacherStudentsScene): def construct(self): # Test stds = self.students morty = self.teacher self.play( morty.change("tease"), stds[2].says("Can you explain what\nAttention does?", mode="raise_left_hand", bubble_direction=LEFT), stds[1].change("pondering", self.screen), stds[0].change("pondering", self.screen), ) self.wait(4) # Version 1 class PredictTheNextWord(SimpleAutogregression): text_corner = 3.5 * UP + 6.5 * LEFT machine_name = "Large\nLanguage\nModel" seed_text = "Paris is a city in" model = "gpt3" n_shown_predictions = 12 random_seed = 2 def construct(self): # Setup machine text_mob, next_word_line, machine = self.init_text_and_machine() machine.move_to(ORIGIN) machine[1].set_backstroke(BLACK, 3) text_group = VGroup(text_mob, next_word_line) text_group.save_state() text_group.scale(1.5) text_group.match_x(machine[0]).to_edge(UP) # Introduce the machine in_arrow = Arrow(text_group, machine[0].get_top(), thickness=5) frame = self.frame self.set_floor_plane("xz") blocks = machine[0] llm_text = machine[1] block_outlines = blocks.copy() block_outlines.set_fill(opacity=0) block_outlines.set_stroke(GREY_B, 2) block_outlines.insert_n_curves(20) flat_dials, last_dials = self.get_machine_dials(blocks) self.clear() frame.reorient(-31, -4, -5, (-0.24, -0.26, -0.06), 3) self.play( FadeIn(blocks, shift=0.0, lag_ratio=0.01), LaggedStartMap(VShowPassingFlash, block_outlines.family_members_with_points(), time_width=2.0, lag_ratio=0.01, remover=True), LaggedStartMap(VFadeInThenOut, flat_dials, lag_ratio=0.001, remover=True), Write(llm_text, time_span=(2, 4), stroke_color=WHITE), FadeIn(last_dials, time_span=(4, 5)), frame.animate.reorient(0, 0, 0, (-0.17, -0.12, 0.0), 4.50), run_time=6, ) blocks[-1].add(last_dials) self.play( frame.animate.to_default_state(), FadeIn(text_group, UP), GrowFromCenter(in_arrow), run_time=3 ) # Single word prediction out_arrow = Vector(1.5 * RIGHT, thickness=5) out_arrow.next_to(machine[0][-1], RIGHT) prediction = Text("France", font_size=72) prediction.next_to(out_arrow, RIGHT) self.animate_text_input( text_mob, machine, position_text_over_machine=False, ) self.play( LaggedStart( (TransformFromCopy(VectorizedPoint(machine.get_right()), letter) for letter in prediction), lag_ratio=0.05, ), GrowArrow(out_arrow) ) self.wait() machine.replace_submobject(2, out_arrow) # Probability distribution self.play(FadeOut(prediction, DOWN)) bar_groups = self.animate_prediction_ouptut(machine, self.cur_str) self.wait() # Show auto_regression self.play( Restore(text_group), FadeOut(in_arrow), ) seed_label = Text("Seed text") seed_label.set_color(YELLOW) seed_label.next_to(text_mob, DOWN) self.play( FadeIn(seed_label, rate_func=there_and_back_with_pause), FlashAround(text_mob, time_width=2), frame.animate.reorient(0, 0, 0, (0.7, -0.01, 0.0), 8.52), run_time=2, ) self.animate_random_sample(bar_groups) new_text_mob = self.animate_word_addition( bar_groups, text_mob, next_word_line, ) # More! for n in range(20): text_mob = self.new_selection_cycle( text_mob, next_word_line, machine, quick=True, skip_anims=(n > 5), ) self.wait(0.25) def get_machine_dials(self, blocks): dials = VGroup( Dial().get_grid(8, 12).set_width(0.9 * block.get_width()).move_to(block) for block in blocks ) dials.set_stroke(opacity=0.5) for group in dials: for dial in group: dial.set_value(dial.get_random_value()) flat_dials = VGroup(*it.chain(*dials)) last_dials = dials[-1].copy() last_dials.set_stroke(opacity=0.1) return flat_dials, last_dials class LotsOfTextIntoTheMachine(PredictTheNextWord): run_time = 25 max_snippet_width = 3 def construct(self): # Add machine text_mob, next_word_line, machine = self.init_text_and_machine() machine.scale(1.5) self.clear() self.add(machine) blocks, title = machine[:2] dials = Dial().get_grid(8, 12).set_width(0.9 * blocks[-1].get_width()).move_to(blocks[-1]) dials.set_stroke(opacity=0.1) blocks[-1].add(dials) machine.center() machine[1].set_stroke(BLACK, 3) # Feed in lots of text snippets = self.get_text_snippets() text_mobs = VGroup(get_paragraph(snippet.split(" "), line_len=25) for snippet in snippets) directions = compass_directions(12, start_vect=UR) for text_mob, vect in zip(text_mobs, it.cycle(directions)): text_mob.set_max_width(self.max_snippet_width) text_mob.move_to(5 * vect).shift_onto_screen(buff=0.25) self.play( LaggedStart( (Succession( FadeIn(text_mob), text_mob.animate.set_opacity(0).move_to(machine.get_center()), ) for text_mob in text_mobs), lag_ratio=0.05, run_time=self.run_time ) ) self.remove(text_mobs) self.wait() def get_text_snippets(self): facts = Path(DATA_DIR, "pile_of_text.txt").read_text().split("\n") random.shuffle(facts) return facts class EvenMoreTextIntoMachine(LotsOfTextIntoTheMachine): run_time = 40 max_snippet_width = 2.5 n_examples = 300 context_size = 25 def get_text_snippets(self): book = Path(DATA_DIR, "tale_of_two_cities.txt").read_text() book = book.replace("\n", " ") words = list(filter(lambda m: m, book.split(" "))) context_size = self.context_size result = [] for n in range(self.n_examples): index = random.randint(0, len(words) - context_size - 1) result.append(" ".join(words[index:index + context_size])) return result class WriteTransformer(InteractiveScene): def construct(self): text = Text("Transformer", font_size=120) self.play(Write(text)) self.wait() class LabelVector(InteractiveScene): def construct(self): brace = Brace(Line(UP, DOWN).set_height(4), RIGHT) name = Text("Vector", font_size=72) name.next_to(brace, RIGHT) name.set_backstroke(BLACK, 5) self.play( GrowFromCenter(brace), Write(name), ) self.wait() class AdjustingTheMachine(InteractiveScene): def construct(self): # Add a machine and repeatedly tweak it frame = self.frame self.set_floor_plane("xz") frame.reorient(-28, -17, 0, ORIGIN, 8.91) self.camera.light_source.move_to([-10, 10, 10]) machine = MachineWithDials(n_rows=10, n_cols=12) machine.set_height(6) blocks = VCube().replicate(10) blocks.set_shape(machine.get_width(), machine.get_height(), 1.0) blocks.deactivate_depth_test() cam_loc = self.frame.get_implied_camera_location() for block in blocks: block.sort(lambda p: -get_norm(p - cam_loc)) blocks.set_fill(GREY_D, 1) blocks.set_shading(0.2, 0.5, 0.25) blocks.arrange(OUT, buff=0.5) blocks.move_to(machine, OUT) self.add(blocks) self.add(machine) frame.clear_updaters() frame.add_updater(lambda f: f.set_theta(-30 * DEGREES * math.cos(0.1 * self.time))) self.add(frame) for x in range(6): self.play(machine.random_change_animation(lag_factor=0.1)) class FirthQuote(InteractiveScene): def construct(self): # Show Quote quote = TexText(R"``You shall know a word\\by the company it keeps!''", font_size=60) image = ImageMobject("JohnRFirth") # From https://www.cambridge.org/core/journals/bulletin-of-the-school-of-oriental-and-african-studies/article/john-rupert-firth/D926AFCBF99AD17D5C7A7A9C0558DFDC image.set_height(6.5) image.to_corner(UL, buff=0.5) name = Text("John R. Firth") name.next_to(image, DOWN) quote.move_to(midpoint(image.get_right(), RIGHT_SIDE)) quote.to_edge(UP) self.play( FadeIn(image, 0.25 * UP), FadeIn(name, lag_ratio=0.1) ) self.play(Write(quote)) self.wait() # Show two sentences phrases = VGroup( Text("Down by the river bank"), Text("Deposit a check at the bank"), ) bank = Text("bank", font_size=90) bank.set_color(TEAL) bank.match_x(quote).match_y(image) for phrase in phrases: phrase["bank"].set_color(TEAL) phrases.arrange(DOWN, buff=1.0, aligned_edge=LEFT) phrases.next_to(quote, DOWN, buff=2.5) phrases[1].set_opacity(0.15) banks = VGroup( phrase["bank"][0] for phrase in phrases ) self.play( FadeIn(bank, scale=2, lag_ratio=0.25), quote.animate.scale(0.7, about_edge=UP).set_opacity(0.75) ) self.wait() self.remove(bank) self.play( FadeIn(phrases[0][:len("downbytheriver")], lag_ratio=0.1), FadeIn(phrases[1][:len("depositacheckatthe")], lag_ratio=0.1), *(TransformFromCopy(bank, bank2) for bank2 in banks) ) self.wait() self.play( phrases[0].animate.set_opacity(0.5), phrases[1].animate.set_opacity(1), ) self.wait() # Isolate both phrases self.play(LaggedStart( FadeOut(image, LEFT, scale=0.5), FadeOut(name, LEFT, scale=0.5), FadeOut(quote, LEFT, scale=0.5), phrases.animate.set_opacity(1).arrange(DOWN, buff=3.5, aligned_edge=LEFT).move_to(0.5 * UP), )) self.wait() # Recreate word = Text("bank", font_size=72) word.set_color(TEAL) self.clear() self.add(word) self.wait() self.remove(word) self.play( *( FadeIn(phrase[phrase.get_text().replace("bank", "")]) for phrase in phrases ), *( TransformFromCopy(word, phrase["bank"][0]) for phrase in phrases ) ) self.add(phrases) # Show influence query_rects = VGroup( SurroundingRectangle(bank) for bank in banks ) query_rects.set_stroke(TEAL, 2) query_rects.set_fill(TEAL, 0.25) key_rects = VGroup( SurroundingRectangle(phrases[0]["river"]), SurroundingRectangle(phrases[1]["Deposit"]), SurroundingRectangle(phrases[1]["check"]), ) key_rects.set_stroke(BLUE, 2) key_rects.set_fill(BLUE, 0.5) key_rects[2].match_height(key_rects[1], about_edge=UP, stretch=True) arrows = VGroup( Arrow(key_rects[0].get_top(), banks[0].get_top(), path_arc=-180 * DEGREES, buff=0.1), Arrow(key_rects[1].get_top(), banks[1].get_top(), path_arc=-90 * DEGREES), Arrow(key_rects[2].get_top(), banks[1].get_top(), path_arc=-90 * DEGREES), ) arrows.set_color(BLUE) key_rects.save_state() key_rects[0].become(query_rects[0]) key_rects[1].become(query_rects[1]) key_rects[2].become(query_rects[1]) key_rects.set_opacity(0) self.add(query_rects, phrases) self.play(FadeIn(query_rects, lag_ratio=0.25)) self.wait() self.add(key_rects, phrases) self.play(Restore(key_rects, lag_ratio=0.1, path_arc=PI / 4, run_time=2)) self.play(LaggedStartMap(Write, arrows, stroke_width=5, run_time=3)) self.wait() # Show images images = Group( ImageMobject("RiverBank"), ImageMobject("FederalReserve"), ) for image, bank in zip(images, banks): image.set_height(2.0) image.next_to(bank, DOWN, MED_SMALL_BUFF, aligned_edge=LEFT) self.play( LaggedStart( (FadeTransform(Group(word).copy(), image) for word, image in zip(banks, images)), lag_ratio=0.5, group_type=Group, ) ) self.wait(2) class DownByTheRiverHeader(InteractiveScene): def construct(self): words = Text("Down by the river bank ...") rect = SurroundingRectangle(words["bank"]) rect.set_fill(BLUE, 0.5) rect.set_stroke(BLUE, 3) brace = Brace(rect, DOWN, buff=SMALL_BUFF) self.add(rect, words, brace) class RiverBankProbParts(SimpleAutogregression): seed_text = "Down by the river bank, " model = "gpt3" def construct(self): # Test text_mob, next_word_line, machine = self.init_text_and_machine() machine.set_x(0) words = [ "water", "river", "lake", "grass", "waves", "shallows", "pool", "depths", "foam", "mist", ] probs = softmax([6, 5, 4, 4, 3.5, 3.25, 3, 3, 2.5, 2]) bar_groups = self.get_distribution(words, probs, machine) self.clear() bar_groups.set_height(6).center() self.play( LaggedStartMap(FadeIn, bar_groups, shift=0.25 * DOWN, run_time=3) ) self.wait() class FourStepsWithParameters(InteractiveScene): def construct(self): # Add rectangles and titles self.add(FullScreenRectangle(fill_color=GREY_E)) rects = Square().replicate(4) rects.arrange(RIGHT, buff=0.25 * rects[0].get_width()) rects.set_width(FRAME_WIDTH - 1.0) rects.center().to_edge(UP, buff=0.5) rects.set_fill(BLACK, 1) rects.set_stroke(WHITE, 2) names = VGroup(*map(TexText, [ R"Text snippets\\$\downarrow$\\Vectors", R"Attention", R"Feedforward", R"Final prediction", ])) for name, rect in zip(names, rects): name.scale(0.8) name.next_to(rect, DOWN) self.add(rects) self.play(LaggedStartMap(FadeIn, names, shift=0.25 * DOWN, lag_ratio=0.25)) self.wait() # Show many dials machines = VGroup( MachineWithDials( width=rect.get_width(), height=3.0, n_rows=9, n_cols=6, ) for rect in rects ) for machine, rect in zip(machines, rects): machine.next_to(rect, DOWN, buff=0) machine[0].set_opacity(0) machine.scale(rect.get_width() / machine.dials.get_width(), about_edge=UP) machine.dials.shift(0.25 * UP) for dial in machine.dials: dial.set_value(0) self.play( LaggedStart(( LaggedStart( (GrowFromPoint(dial, machine.get_top()) for dial in machine.dials), lag_ratio=0.025, ) for machine in machines ), lag_ratio=0.25), LaggedStartMap(FadeOut, names) ) for _ in range(2): self.play( LaggedStart( (machine.random_change_animation() for machine in machines), lag_ratio=0.2, ) ) class ChatbotFeedback(InteractiveScene): random_seed = 404 def construct(self): # Test self.frame.set_height(10).move_to(DOWN) user_prompt = "User: How and when was the internet invented?" prompt_mob = Text(user_prompt) prompt_mob.to_edge(UP) prompt_mob["User:"].set_color(BLUE) self.answer_mob = Text("AI Assistant:") self.answer_mob.next_to(prompt_mob, DOWN, buff=1.0, aligned_edge=LEFT) self.answer_mob.set_color(YELLOW) self.og_answer_mob = self.answer_mob self.add(prompt_mob, self.answer_mob) # Show multiple answer for n in range(8): self.give_answer(prompt_mob) mark = self.judge_answer() self.add(self.og_answer_mob) self.play(FadeOut(self.answer_mob), FadeOut(mark)) self.answer_mob = self.og_answer_mob def display_answer(self, text): new_answer_mob = get_paragraph(text.replace("\n", " ").split(" ")) new_answer_mob[:len(self.og_answer_mob)].match_style(self.og_answer_mob) new_answer_mob.move_to(self.og_answer_mob, UL) self.remove(self.answer_mob) self.answer_mob = new_answer_mob self.add(self.answer_mob) def give_answer(self, prompt_mob, max_responses=100): answer = self.og_answer_mob.get_text() user_prompt = prompt_mob.get_text() for n in range(max_responses): answer, stop = self.add_to_answer(user_prompt, answer) if stop: break self.display_answer(answer) self.wait(2 / 30) def judge_answer(self): mark = random.choice([ Checkmark().set_color(GREEN), Exmark().set_color(RED), ]) mark.scale(5) mark.next_to(self.answer_mob, RIGHT, aligned_edge=UP) rect = SurroundingRectangle(self.answer_mob) rect.match_color(mark) self.play(FadeIn(mark, scale=2), FadeIn(rect, scale=1.05)) self.wait() return VGroup(mark, rect) def add_to_answer(self, user_prompt: str, answer: str): try: tokens, probs = gpt3_predict_next_token("\n\n".join([user_prompt, answer])) token = random.choices(tokens, np.array(probs) / sum(probs))[0] except IndexError: return answer, True stop = False if token == '<|endoftext|>': stop = True else: answer += token return answer, stop class ContrastWithEarlierFrame(InteractiveScene): def construct(self): # Test vline = Line(UP, DOWN) vline.set_height(FRAME_HEIGHT) self.add(vline) titles = VGroup( VGroup( Text("Most earlier models"), # Vector(0.75 * DOWN, thickness=4), # Text("One word at a time") ), VGroup( Text("Transformers"), # Vector(0.75 * DOWN, thickness=4), # Text("All words in parallel") ), ) for title, vect in zip(titles, [LEFT, RIGHT]): title.arrange(DOWN, buff=0.2) title.scale(1.5) title.move_to(FRAME_WIDTH * vect / 4) title.to_edge(UP) self.add(titles) class SequentialProcessing(InteractiveScene): def construct(self): # Add text text = Text("Down by the river bank, where I used to go fishing ...") text.move_to(1.0 * DOWN) words = break_into_words(text) rects = get_piece_rectangles(words) blocks = VGroup(VGroup(rect, word) for rect, word in zip(rects, words)) blocks.save_state() self.add(blocks) # Vector wandering over vect = NumericEmbedding() vect.set_width(1.0) vect.next_to(rects[0], UP) for n in range(len(blocks) - 1): blocks.target = blocks.saved_state.copy() blocks.target[:n].fade(0.75) blocks.target[n + 1:].fade(0.75) self.play( vect.animate.next_to(blocks[n], UP), MoveToTarget(blocks) ) self.play( LaggedStart( (ContextAnimation(elem, blocks[n][1], lag_ratio=0.01) for elem in vect.get_entries()), lag_ratio=0.01, ), RandomizeMatrixEntries(vect), run_time=2 ) # Version 2 class PartialScript(SimpleAutogregression): machine_name = "Magic next\nword predictor" machine_phi = 5 * DEGREES machine_theta = 6 * DEGREES def construct(self): # Set frame frame = self.frame self.set_floor_plane("xz") # Unfurl script curled_script_img = ImageMobject("HumanAIScript") curled_script_img.set_height(7) curves = VGroup(SVGMobject("JaggedCurl1")[0], SVGMobject("JaggedCurl2")[0]) for curve in curves: curve.make_smooth(approx=False) curve.insert_n_curves(100) curve.set_stroke(WHITE, 3) curve.set_fill(opacity=0) curve.set_height(5) curves[1].scale(curves[0].get_arc_length() / curves[1].get_arc_length()) resolution = (2, 200) # Change surface_kw = dict(u_range=(-6, 6), v_range=(0.05, 0.95), resolution=resolution) curled_script_templates = Group( ParametricSurface( lambda u, v: (*curve.pfp(v)[:2], u), **surface_kw ) for curve in curves ) curled_script_templates[1].rotate(PI / 2, UP) curled_script_templates[0].rotate(-PI / 2) flat_script_template = ParametricSurface( lambda u, v: (u, v, 0), **surface_kw ) curled_script0 = TexturedSurface(curled_script_templates[0], "HumanAIScript") curled_script1 = TexturedSurface(curled_script_templates[1], "HumanAIScript") curled_script1_torn = TexturedSurface(curled_script_templates[1], "HumanAIScriptTorn") flat_script = TexturedSurface(flat_script_template, "HumanAIScriptTorn") flat_script.replace(curled_script_img, stretch=True) for script in [curled_script0, curled_script1]: script.set_shading(0.25, 0.25, 0.35) curled_script1_torn.set_shading(0, 0, 0) flat_script.set_shading(0, 0, 0) frame.reorient(0, -1, 0, (-0.28, 0.69, 0.0), 14.43) self.play( TransformFromCopy(curled_script0, curled_script1), frame.animate.reorient(56, -17, 0, (-0.2, -1.52, -2.39), 20.05), run_time=3 ) self.play( frame.animate.reorient(-6, -11, 0, (1.06, -1.22, -2.65), 20.05), run_time=8, ) self.play( FadeOut(curled_script1, shift=1e-2 * IN), FadeIn(curled_script1_torn, shift=1e-2 * IN), ) self.play( ReplacementTransform(curled_script1_torn, flat_script), frame.animate.to_default_state(), run_time=2 ) self.wait() # Show the machine machine = self.get_transformer_drawing() machine[1].set_height(0.7).set_stroke(width=2) machine[1].set_opacity(0) machine.remove(machine[-1]) machine.set_height(3) machine.to_edge(RIGHT) self.play( flat_script.animate.set_height(5).to_edge(LEFT), FadeIn(machine, lag_ratio=0.01) ) self.add(machine) self.wait() # Show example input and output out_arrow = Vector(DOWN, thickness=6) out_arrow.next_to(machine, DOWN) in_arrow = out_arrow.copy().next_to(machine, UP, SMALL_BUFF) in_text = Text("To be or not to _") in_text[-1].stretch(3, 0, about_edge=LEFT) in_text.next_to(in_arrow, UP) prediction = Text("be", font_size=72) prediction.next_to(out_arrow, DOWN) self.play(FadeIn(in_text), GrowArrow(in_arrow)) self.animate_text_input(in_text, machine, position_text_over_machine=False) self.play( GrowArrow(out_arrow), FadeIn(prediction, DOWN), ) self.wait() # Clear the board script_text = self.get_text() script_text.set_width(0.89 * flat_script.get_width()) script_text.next_to(flat_script.get_top(), DOWN, buff=0.33) font_size = 48 * (script_text[0].get_height() / Text("H").get_height()) completion = "A transistor is a semiconductor device used to amplify or switch electronic signals. It consists of three layers of semiconductor material, either p-type or n-type, forming a structure with terminals called the emitter, base, and collector." words = completion.split(" ") paragraph = get_paragraph(completion.split(" "), font_size=font_size) paragraph.next_to(script_text, DOWN, aligned_edge=LEFT) paragraph.set_color(YELLOW) self.play( FadeIn(script_text), FadeOut(flat_script), FadeOut(VGroup(in_text, in_arrow, prediction)), ) # Repeatedly add predictions machine.scale(1.25, about_edge=RIGHT) out_arrow.next_to(machine, DOWN, buff=0.5) blocks = machine[0] dials = Dial().get_grid(11, 16) dials.set_width(blocks[-1].get_width() * 0.95) dials.rotate(5 * DEGREES, RIGHT).rotate(10 * DEGREES, UP) dials.move_to(blocks[-1]) dials.set_stroke(opacity=0.5) for dial in dials: dial.set_value(dial.get_random_value()) dials.set_z_index(2) self.add(dials) curr_answer = VGroup() curr_answer.next_to(script_text, DOWN) for n in range(6): word = words[n] prediction = Text(words[n], font_size=72) prediction.next_to(out_arrow, DOWN) word_in_answer = paragraph[len(curr_answer):len(curr_answer) + len(word)] word_in_answer.set_color(YELLOW) mover = VGroup(script_text, curr_answer).copy() if n > 2: self.skip_animations = True self.play( mover.animate.set_height(1.8).next_to(machine, UP, SMALL_BUFF).set_anim_args(path_arc=-30 * DEGREES), ) self.animate_text_input( mover, machine, position_text_over_machine=False, lag_ratio=1e-3 ) self.play(FadeIn(prediction, DOWN, rate_func=rush_from, run_time=0.5)) if n > 2: self.skip_animations = False self.wait(0.5) self.skip_animations = True self.play( curr_answer.animate.set_color(WHITE), Transform(prediction, word_in_answer), FadeOut(mover), ) curr_answer.add(*word_in_answer) self.add(curr_answer) self.remove(prediction) def get_text(self): script_text = Text(""" Human: Can you explain the history of transistors and how they're relevant to computers? What is a transistor, and how exactly is it used to perform computations? AI assistant: """, alignment="LEFT") script_text["Human"].set_color(BLUE) script_text["AI assistant"].set_color(TEAL) script_text.set_height(4).to_edge(UP) return script_text def create_image(self): # Create image script_text = self.get_text() script_text.set_fill(BLACK) script_text["Human"].set_fill(BLUE_D) script_text["AI assistant"].set_fill(TEAL_D) self.add(FullScreenRectangle(fill_color="#FCF5E5", fill_opacity=1)) self.add(script_text) # Add off test tear_off = SVGMobject('TearOff') tear_off.set_stroke(width=0) tear_off.set_fill(BLACK, 1) tear_off.set_width(7.5) tear_off.next_to(script_text, DOWN, buff=-0.2) self.add(tear_off) class ShowMachineWithDials(PredictTheNextWord): words = ['worst', 'age', 'worse', 'best', 'most', 'end', 'very', 'blur'] logprobs = [4.0, 2.15, 1.89, 1.4, 0.1, -0.18, -0.23, -0.61] def construct(self): # Show machine (same position as in PredictTheNextWord) frame = self.frame self.set_floor_plane("xz") blocks, llm_text, flat_dials, last_dials = self.get_blocks_and_dials() self.clear() self.add(frame) frame.reorient(0, 0, 0, (-0.17, -0.12, 0.0), 4.50) self.add(blocks, llm_text, last_dials) # Prepare dial highlight last_dials.target = last_dials.generate_target() self.fix_dials(last_dials.target) small_rect = SurroundingRectangle(last_dials[0], buff=0.025) small_rect.set_stroke(BLUE, 2) big_rect = small_rect.copy().scale(4) big_rect.next_to(blocks, UP, buff=SMALL_BUFF, aligned_edge=LEFT + OUT) big_rect.shift(1.5 * RIGHT) big_dial = last_dials[0].copy().scale(4).set_stroke(opacity=1) big_dial.move_to(big_rect) rect_lines = VGroup( Line(small_rect.get_corner(UL), big_rect.get_corner(DL)), Line(small_rect.get_corner(UR), big_rect.get_corner(DR)), ) rect_lines.set_stroke(WHITE, width=(1, 3)) highlighed_parameter_group = VGroup(small_rect, rect_lines, big_rect, big_dial) last_dials.set_stroke(width=1, opacity=1) self.play( MoveToTarget(last_dials), FadeOut(llm_text), FadeIn(small_rect), ) # Show an example input and output example = self.get_example(blocks) in_text, in_arrow, out_arrow, bar_groups = example logprobs = example.logprobs true_probs = 100 * softmax(logprobs) bar_groups = self.get_output_distribution(self.words, 0.1 * logprobs, out_arrow) self.play( LaggedStart( ShowCreation(rect_lines, lag_ratio=0), TransformFromCopy(small_rect, big_rect), TransformFromCopy(last_dials[0], big_dial), FadeIn(in_text), GrowArrow(in_arrow), FadeIn(bar_groups), GrowArrow(out_arrow), ), frame.animate.reorient(0, 0, 0, (-0.43, 0.38, 0.0), 7.05), run_time=2 ) self.play( last_dials[0].animate_set_value(0.8), big_dial.animate_set_value(0.8), LaggedStart( (dial.animate_set_value(dial.get_random_value()) for dial in last_dials[1:]), lag_ratio=1.0 / len(last_dials), ), *( self.bar_group_change_animation(bg, value) for bg, value in zip(bar_groups[:-1], true_probs) ), run_time=3 ) self.wait() # Play around tweaking the parameters, and seeing the output change self.play( LaggedStart( (dial.animate_set_value(0) for dial in last_dials[:12]), lag_ratio=0.01, ), big_dial.animate_set_value(0), self.bar_group_change_animation(bar_groups[0], 50), self.bar_group_change_animation(bar_groups[1], 34), self.bar_group_change_animation(bar_groups[2], 5), run_time=4, ) self.play( LaggedStart( (dial.animate_set_value(1) for dial in last_dials[:12]), lag_ratio=0.01, ), big_dial.animate_set_value(1), self.bar_group_change_animation(bar_groups[0], 80), self.bar_group_change_animation(bar_groups[1], 5), self.bar_group_change_animation(bar_groups[2], 15), run_time=4, ) self.wait() # Mention randomness random_words = Text("Initially random") random_words.next_to(blocks, UP) random_words.set_color(RED) out_dots = Tex(R"...", font_size=120) out_dots.next_to(out_arrow, RIGHT) self.play( FadeOut(big_rect), Uncreate(rect_lines, lag_ratio=0), FadeOut(small_rect), Transform(big_dial, last_dials[0]) ) self.play( Write(random_words), LaggedStart( (dial.animate_set_value(dial.get_random_value()) for dial in last_dials), lag_ratio=0.5 / len(last_dials), run_time=2 ), FadeOut(bar_groups), ) self.play(Write(out_dots)) self.wait() self.play( FadeOut(dots), FadeOut(random_words), FadeIn(bar_groups), ) # Show many many parameters example.save_state() blocks.save_state() last_dials.save_state() all_dials = VGroup(*flat_dials, *last_dials) all_dials.generate_target() all_dials.target.space_out_submobjects(3) new_dials = VGroup( all_dials.target.copy().shift(3 * 2 * x * (flat_dials.get_center() - last_dials.get_center())) for x in range(1, 9) ) self.play( FadeOut(example), FadeOut(blocks), FadeIn(flat_dials), FadeOut(bar_groups), FadeOut(out_arrow), ) self.play( FadeOut(highlighed_parameter_group), MoveToTarget(all_dials), LaggedStart( (TransformFromCopy(all_dials.copy().set_opacity(0), nd) for nd in new_dials), lag_ratio=0.05, ), frame.animate.reorient(-9, 0, 0, (-0.71, -0.07, -0.06), 9.64), run_time=4 ) self.wait() def get_blocks_and_dials(self): machine = self.get_transformer_drawing() machine.move_to(ORIGIN) self.machine = machine blocks = machine[0] llm_text = machine[1] llm_text.set_backstroke(BLACK, 2) flat_dials, last_dials = self.get_machine_dials(blocks) return blocks, llm_text, flat_dials, last_dials def get_example(self, blocks): in_text = Text("It was the best\nof times it was\nthe _", alignment="LEFT") in_text[-1].stretch(4, 0, about_edge=LEFT) in_text.next_to(blocks, LEFT, LARGE_BUFF) in_arrow = Arrow(in_text, blocks) out_arrow = Vector(RIGHT) out_arrow.next_to(blocks[-1], RIGHT, buff=0.1) logprobs = np.array(self.logprobs) bar_groups = self.get_output_distribution(self.words, logprobs, out_arrow) example = VGroup(in_text, in_arrow, out_arrow, bar_groups) example.logprobs = logprobs return example def fix_dials(self, dials): for dial in dials: dial.set_stroke(width=1, opacity=1) dial.needle.set_stroke(width=(2, 0)) return dials def bar_group_change_animation(self, bar_group, new_value): text, rect, value_mob = bar_group buff = value_mob.get_left() - rect.get_right() factor = new_value / value_mob.get_value() return AnimationGroup( rect.animate.stretch(factor, 0, about_edge=LEFT), ChangeDecimalToValue(value_mob, new_value), UpdateFromFunc(text, lambda m: value_mob.move_to(rect.get_right() + buff, LEFT)), ) def get_output_distribution(self, words, logprobs, out_arrow): probs = softmax(logprobs) bar_groups = self.get_distribution(words, probs, self.machine, width_100p=1.0) bar_groups.next_to(out_arrow, RIGHT) return bar_groups class ShowSingleTrainingExample(ShowMachineWithDials): logprobs = [4.0, 6.15, 1.89, 1.4, 0.1, -0.18, -0.23, -0.61] def construct(self): # Add state from before frame = self.frame self.set_floor_plane("xz") blocks, llm_text, flat_dials, last_dials = self.get_blocks_and_dials() self.fix_dials(last_dials) example = self.get_example(blocks) in_text, in_arrow, out_arrow, bar_groups = example self.add(blocks, last_dials) # Show example up top parts = ("It was the best of times it was the", "worst") sentence = Text(" ".join(parts)) start = sentence[parts[0]][0] end = sentence[parts[1]][0] sentence.set_width(10) sentence.next_to(blocks, UP, buff=1.5) start_rect = SurroundingRectangle(start) start_rect.set_stroke(BLUE, 2) start_rect.set_fill(BLUE, 0.2) end_rect = SurroundingRectangle(end) end_rect.match_height(start_rect, stretch=True).match_y(start_rect) end_rect.set_stroke(YELLOW, 2) end_rect.set_fill(YELLOW, 0.2) arrow = Arrow(start_rect.get_top(), end_rect.get_top(), path_arc=-90 * DEGREES, thickness=5) arrow.set_fill(border_width=1) frame.reorient(0, 0, 0, (-0.36, 0.97, 0.0), 7.52) self.play(FadeIn(sentence, UP)) self.play( LaggedStartMap(DrawBorderThenFill, VGroup(start_rect, end_rect)), FadeIn(arrow), ) self.remove(last_dials) self.play(LaggedStart( AnimationGroup( TransformFromCopy(start, in_text[:-1]), TransformFromCopy(end_rect, in_text[-1]), FadeIn(in_arrow) ), LaggedStart( ( block.animate.set_color( block.get_color() if block is blocks[-1] else TEAL ).set_anim_args(rate_func=there_and_back) for block in blocks ), group=blocks, lag_ratio=0.1, run_time=1 ), Animation(last_dials), GrowArrow(out_arrow), LaggedStartMap(GrowFromPoint, bar_groups, point=out_arrow.get_start()), lag_ratio=0.3 )) self.wait() # Flag bad prediction out_rects = VGroup( SurroundingRectangle(bg) for bg in bar_groups[:2] ) out_rects.set_stroke(RED, 3) annotations = VGroup( Tex(tex, font_size=60).next_to(rect, LEFT, buff=SMALL_BUFF) for rect, tex in zip(out_rects, [R"\uparrow", R"\downarrow"]) ) annotations.set_color(RED) self.play( FadeTransform(end_rect.copy(), out_rects[0]), Write(annotations[0]), ) self.wait() self.play( FadeTransform(*out_rects), FadeTransform(*annotations), ) self.wait() self.play( FadeOut(out_rects[1]), FadeOut(annotations[1]), ) # Adjust self.play( LaggedStart( (dial.animate_set_value(dial.get_random_value()) for dial in last_dials), lag_ratio=1.0 / len(last_dials), ), LaggedStart( (FlashAround(dial, stroke_width=2, color=YELLOW, time_width=1, buff=0.025) for dial in last_dials), lag_ratio=1.0 / len(last_dials), ), self.bar_group_change_animation(bar_groups[0], 70), self.bar_group_change_animation(bar_groups[1], 20), self.bar_group_change_animation(bar_groups[2], 8), run_time=6 ) class ParameterWeight(InteractiveScene): def construct(self): # Test text = Text("Parameter / Weight", font_size=72) text.to_edge(UP) text.set_color(YELLOW) param = text["Parameter"][0] param.save_state() param.set_x(0) self.play(Write(param)) self.wait() self.play(LaggedStart( Restore(param), FadeIn(text["/ Weight"]), )) self.wait() class LargeInLargeLanguageModel(InteractiveScene): def construct(self): # Test text = Text("Large Language Model", font_size=72) text.to_edge(UP) large = text["Large"][0] large.save_state() large.set_x(0) self.add(large) self.play(FlashUnder(large), large.animate.set_color(YELLOW)) self.play( Restore(large, path_arc=-30 * DEGREES), Write(text[len(large):], time_span=(0.5, 1.5)) ) self.wait() class ThousandsOfWords(InteractiveScene): def construct(self): # Find passage file = Path("/Users/grant/3Blue1Brown Dropbox/3Blue1Brown/videos/2024/transformers/data/tale_of_two_cities.txt") novel = file.read_text() start_index = novel.index("It was the best of times") end_index = novel.index("There were a king with a large jaw") # Add text passage = novel[start_index:start_index + 5000].replace("\n", " ") text = get_paragraph(passage.split(" "), line_len=150) text.set_width(14) text.to_edge(UP) self.add(text) class EnormousAmountOfTrainingText(PremiseOfMLWithText): def construct(self): # Setup self.init_data() # n_rows = n_cols = 41 n_rows = n_cols = 9 screens = VGroup() for row in range(n_rows): for col in range(n_cols): screen = self.get_screen() screen.move_to(FRAME_WIDTH * row * RIGHT + FRAME_HEIGHT * col * DOWN) screens.add(screen) screens.center() screens.submobjects.sort(key=lambda sm: get_norm(sm.machine.get_center())) self.add(screens) # Add frame growth frame = self.frame frame.clear_updaters() frame.add_updater(lambda m: m.set_height(FRAME_HEIGHT * np.exp(0.2 * self.time))) # Show lots of new data inner_screens = screens[:25] n_examples = 20 for n in range(n_examples): self.play(LaggedStart( *(self.change_example_animation(screen, show_dial_change=True) for screen in inner_screens), lag_ratio=0.1, run_time=0.5, )) def change_example_animation(self, screen, show_dial_change=True): new_example = VGroup(*self.new_input_output_example(*screen.arrows)) time_span = (0, 0.35) anims = [ FadeOut(screen.training_example, time_span=time_span), FadeIn(new_example, time_span=time_span), ] if show_dial_change: anims.append(screen.machine.random_change_animation(run_time=0.5)) screen.training_example = new_example return AnimationGroup(*anims) def get_screen(self): border = FullScreenRectangle() border.set_fill(opacity=0) border.set_stroke(WHITE, 2) machine = MachineWithDials(width=3.5, height=2.5, n_rows=5, n_cols=7) machine.move_to(1.0 * RIGHT) in_arrow, out_arrow = arrows = Vector(RIGHT).replicate(2) in_arrow.next_to(machine, LEFT) out_arrow.next_to(machine, RIGHT) in_data, out_data = training_example = VGroup( *self.new_input_output_example(in_arrow, out_arrow) ) screen = VGroup( border, machine, arrows, training_example ) screen.border = border screen.machine = machine screen.arrows = arrows screen.training_example = training_example return screen def new_input_output_example(self, in_arrow, out_arrow): in_data, out_data = super().new_input_output_example(in_arrow, out_arrow) in_data.scale(0.8, about_edge=RIGHT) out_data.scale(0.8, about_edge=LEFT) return in_data, out_data class BadChatBot(InteractiveScene): def construct(self): # Add bot bot = self.get_bot() bot.set_height(3) lines = Line(LEFT, RIGHT).get_grid(4, 1, buff=0.25) lines.set_stroke(WHITE, 1) lines[-1].stretch(0.5, 0, about_edge=LEFT) lines.set_width(3) bubble = SpeechBubble(lines, buff=MED_LARGE_BUFF) bubble.set_stroke(width=5) bubble.pin_to(bot).shift(DOWN) self.add(bot) self.play(Write(bubble, run_time=3)) self.blink(bot) self.wait() # Make lines bad self.play( LaggedStart( (Transform(line, self.get_scribble(line)) for line in lines), lag_ratio=0.1, run_time=2 ) ) for _ in range(2): self.blink(bot) self.wait(2) def get_scribble(self, line): freqs = np.random.random(5) graph = FunctionGraph( lambda x: 0.05 * sum(math.sin(freq * TAU * x) for freq in freqs), x_range=(0, 5, 0.1) ) graph.put_start_and_end_on(*line.get_start_and_end()) graph.match_style(line) graph.set_stroke(color=RED) return graph def get_bot(self): bot = SVGMobject("Bot") subpaths = bot[0].get_subpaths() bot[0].set_points([*subpaths[0], subpaths[0][-1], *subpaths[1]]) eyes = VGroup(Dot().replace(VMobject().set_points(subpath)) for subpath in subpaths[2:]) bot.eyes = eyes bot.add(eyes) bot.set_stroke(width=0) bot.set_height(4) bot.set_fill(GREY_B) bot.set_shading(0.5, 0.5, 1) return bot def blink(self, bot): self.play( bot.eyes.animate.stretch(0, 1).set_anim_args(rate_func=squish_rate_func(there_and_back)) ) class WriteRLHF(InteractiveScene): def construct(self): text = Text("Step 2: RLHF") full_text = Text("Reinforcement Learning\nwith Human Feedback") full_text.next_to(text, UP, LARGE_BUFF) full_text.align_to(text, RIGHT).shift(RIGHT) initials = VGroup(full_text[letter[0]][0][0] for letter in "RLHF") full_text.remove(*initials) self.add(text) self.wait() self.play( TransformFromCopy(text["RLHF"][0], initials, lag_ratio=0.25), Write(full_text, time_span=(1.5, 3)), run_time=3 ) self.wait() class RLHFWorker(InteractiveScene): def construct(self): # Test self.add(FullScreenRectangle().set_fill(GREY_E, 1)) # worker = SVGMobject("computer_stall") worker = SVGMobject("comp_worker") worker.set_height(4) worker.move_to(4 * LEFT) worker.set_fill(GREY_C, 1) rect = Rectangle(7, 5) rect.to_edge(RIGHT) rect.set_stroke(WHITE, 2) rect.set_fill(BLACK, 1) self.add(worker) self.add(rect) class RLHFWorkers(ShowMachineWithDials): def construct(self): # Add workers self.add(FullScreenRectangle().set_fill(GREY_E, 1)) workers = SVGMobject("comp_worker").get_grid(3, 2, buff=0.5) workers.set_height(7) workers.to_edge(LEFT) workers.set_fill(GREY_C, 1) self.add(workers) # Machine blocks, llm_text, flat_dials, last_dials = self.get_blocks_and_dials() machine = VGroup(blocks, last_dials) machine.set_height(4) machine.center().to_edge(RIGHT, buff=LARGE_BUFF) last_dials.set_stroke(opacity=1) self.add(machine) for _ in range(8): self.play(LaggedStart( (dial.animate_set_value(dial.get_random_value()) for dial in last_dials), lag_ratio=0.5 / len(last_dials), run_time=2 )) self.wait() class SerialProcessing(InteractiveScene): phrase = "It was the best of times it was the worst of times" phrase_center = 2 * UP def construct(self): # Set up words words = self.get_words() rects = get_piece_rectangles(words) self.add(rects) self.add(words) # Animate in the vectors vectors = VGroup( self.get_abstract_vector().next_to(word, DOWN, LARGE_BUFF) for word in words ) last_vect = VGroup(VectorizedPoint(rects[0].get_bottom())) for word, vect in zip(words, vectors): self.play( FadeIn(vect, run_time=2), LaggedStart( (ContextAnimation( square, VGroup(*word, *last_vect), direction=DOWN, lag_ratio=0.01, path_arc=30 * DEGREES ) for square in vect), lag_ratio=0.05, run_time=2 ), last_vect.animate.set_opacity(0.2) ) last_vect = vect def get_words(self): result = break_into_words(Text(self.phrase)) result.move_to(self.phrase_center) return result def get_abstract_vector(self, values=None, default_length=10, elem_size=0.2): if values is None: values = np.random.uniform(-1, 1, default_length) result = Square().get_grid(len(values), 1, buff=0) result.set_width(elem_size) result.set_stroke(WHITE, 1) for square, value in zip(result, values): color = value_to_color(value, min_value=0, max_value=1) square.set_fill(color, opacity=1) return result class ParallelProcessing(SerialProcessing): def construct(self): # Set up words words = self.get_words() rects = get_piece_rectangles(words) self.add(rects) self.add(words) # Animate in the vectors vectors = VGroup( self.get_abstract_vector().next_to(word, DOWN, buff=1.5) for word in words ) lines = VGroup( Line( rect.get_bottom(), vect.get_top(), buff=0.05, stroke_color=WHITE, stroke_width=2 * random.random()**3 ) for rect in rects for vect in vectors ) lines.shuffle() for vect, word in zip(vectors, words): vect.save_state() for square in vect: square.move_to(word) square.set_opacity(0) self.play( LaggedStartMap(ShowCreation, lines, lag_ratio=0.01), LaggedStartMap(Restore, vectors, lag_ratio=0) ) self.play(lines.animate.set_stroke(opacity=0.25)) self.wait() class ManyComputationsPerUnitTimeV2(InteractiveScene): def construct(self): # Add computations box = Rectangle(5, 5) label = Text("1 Billion computations per Second") label.next_to(box, UP) self.add(box) self.add(label) comps = self.get_computations(box) self.add(comps) self.wait(3) # Place box into minute interval width = FRAME_WIDTH - 1 number_lines = VGroup( minute_line := NumberLine((0, 60, 1), width=width, big_tick_spacing=10), hour_line := NumberLine((0, 60, 1), width=width, big_tick_spacing=10), day_line := NumberLine((0, 24, 1), width=width, big_tick_spacing=6), month_line := NumberLine((0, 31, 1), width=width), year_line := NumberLine((0, 12, 1), width=width), y100_line := NumberLine((0, 100, 1), width=width), y10k_line := NumberLine((0, 100, 1), width=width), y1M_line := NumberLine((0, 100, 1), width=width), y100M_line := NumberLine((0, 100, 1), width=width), ) number_lines.move_to(DOWN) first_ticks = minute_line.ticks[:2] sec_brace = Brace(first_ticks, DOWN, buff=0, tex_string=R"\underbrace{\qquad\qquad}") sec_label = Text("Second", font_size=30).next_to(sec_brace, DOWN, SMALL_BUFF) self.play( ShowCreation(minute_line, lag_ratio=0.01), box.animate.match_width(first_ticks).move_to(first_ticks.get_center(), DOWN).set_stroke(width=1), TransformFromCopy(label["Second"][0], sec_label), GrowFromCenter(sec_brace), run_time=2 ) # Add other boxes minute_label = self.get_timeline_full_label(number_lines[1], "Minute") new_boxes = VGroup( box.copy().move_to(tick.get_center(), DL) for tick in minute_line.ticks[1:-1] ) for new_box in new_boxes: new_box.save_state() new_box.move_to(box) computations = VGroup( self.get_computations(new_box, n_iterations=1) for new_box in new_boxes ) # computations = VGroup() # If needed self.add(computations) self.play( FadeIn(minute_label, DOWN), LaggedStartMap(Restore, new_boxes, lag_ratio=0.1), run_time=2 ) self.wait(2) # Add labels minute_line.add(minute_label) names = ["Hour", "Day", "Month", "Year", "100 Years", "10,000 Years", "1,000,000 Years", "100,000,000 Years"] for line, name in zip(number_lines[1:], names): line.label = self.get_timeline_full_label(line, name) line.add(line.label) # Arrange all lines number_lines[1:].arrange(DOWN, buff=2.0) number_lines[1:].next_to(minute_line, DOWN, buff=2.0) scale_lines = VGroup() for nl1, nl2 in zip(number_lines, number_lines[1:]): n = len(nl2.ticks) // 2 mini_line = Line(nl2.ticks[n - 1].get_center(), nl2.ticks[n].get_center()) pair = VGroup( DashedLine(nl1.get_start(), mini_line.get_start()), DashedLine(nl1.get_end(), mini_line.get_end()), ) pair.set_stroke(WHITE, 2) nl1.target = nl1.copy() nl1.target.replace(mini_line, dim_to_match=0) nl1.target.shift(mini_line.pfp(0.5) - nl1.target.pfp(0.5)) scale_lines.add(pair) # Start panning down lag_ratio = 1.5 self.play( LaggedStart( *(AnimationGroup(*(ShowCreation(sl) for sl in pair)) for pair in scale_lines), lag_ratio=lag_ratio, ), LaggedStart( *(FadeIn(nl) for nl in number_lines[1:]), lag_ratio=lag_ratio, ), LaggedStart( *(TransformFromCopy(nl, nl.target) for nl in number_lines[:-1]), lag_ratio=lag_ratio, ), self.frame.animate.set_y(number_lines[-1].get_y() + 2).set_width(18).set_anim_args( rate_func=lambda t: interpolate(smooth(t), linear(t), there_and_back_with_pause(t, pause_ratio=0.8)) ), run_time=30 ) self.play(self.frame.animate.reorient(0, 0, 0, (-0.03, -11.55, 0.0), 31.76), run_time=4) self.wait(4) def fade_in_bigger_interval(self, new_interval, prev_interval, fader, scale_factor, added_anims=[]): pivot = prev_interval.n2p(0) new_interval.save_state() new_interval.scale(scale_factor, about_point=pivot) new_interval[:-1].set_opacity(0) new_interval[-1].set_fill(BLACK) self.play( Restore(new_interval), prev_interval.animate.scale(1.0 / scale_factor, about_point=pivot).set_fill(border_width=0), fader.animate.scale(1.0 / scale_factor, about_point=pivot).set_opacity(0), *added_anims, run_time=4, rate_func=rush_from ) self.remove(fader) def get_timeline_full_label(self, timeline, name): brace = Brace(Line().set_width(7), UP, buff=MED_SMALL_BUFF) brace.set_fill(border_width=5) brace.match_width(timeline) brace.next_to(timeline, UP, buff=MED_SMALL_BUFF) label = Text(name, font_size=72) label.next_to(brace, UP, MED_SMALL_BUFF) label.next_to(timeline, DOWN) return label return VGroup(brace, label) def get_computations(self, box, n_lines=10, n_iterations=3, n_digits=4, cycle_time=0.5): # Try adding lines lines = VGroup() for iteration in range(n_iterations): cluster = VGroup() for n in range(n_lines): x = random.uniform(0, 10**(n_digits)) y = random.uniform(0, 10**(n_digits)) if random.choice([True, False]): comb = x * y sym = Tex(R"\times") else: comb = x + y sym = Tex(R"+") line = VGroup( DecimalNumber(x, num_decimal_places=3), sym, DecimalNumber(y, num_decimal_places=3), Tex("="), DecimalNumber(comb, num_decimal_places=3) ) line.arrange(RIGHT, buff=SMALL_BUFF) lines.add(line) cluster.add(line) cluster.arrange(DOWN, buff=MED_LARGE_BUFF, aligned_edge=LEFT) cluster.set_max_height(0.9 * box.get_height()) cluster.set_max_width(0.9 * box.get_width()) cluster.move_to(box) # Add updater def update_lines(lines): sigma = 0.12 alpha = (self.time / (cycle_time * n_iterations)) % 1 step = 1.0 / len(lines) for n, line in enumerate(lines): x = min(( abs(a - n * step) for a in (alpha - 1, alpha, alpha + 1) )) y = np.exp(-x**2 / sigma**2) line.set_fill(opacity=y) lines.set_height(0.9 * box.get_height()) lines.move_to(box) lines.clear_updaters() lines.add_updater(update_lines) return lines def old(self): # Repeatedly scale down to_fade = VGroup(sec_brace, sec_label, box, comps, new_boxes, computations) scale_factors = [60, 24, 365, 1000] for new_int, prev_int, scale_factor in zip(number_lines[1:], number_lines[0:], scale_factors): self.fade_in_bigger_interval( new_int, prev_int, to_fade, scale_factor, added_anims=[label.animate.set_opacity(0)], ) self.wait(2) to_fade = prev_int # Multiply last line by 100 self.fade_in_bigger_interval( y1M_line, millenium_line, year_line, 1000, added_anims=[self.frame.animate.reorient(0, 0, 0, (-3.51, -5.18, 0.0), 12.93)], ) lines = Line(LEFT, RIGHT).replicate(100) lines.match_width(y1M_line) lines.arrange_to_fit_height(10) lines.sort(lambda p: -p[1]) lines.set_stroke(WHITE, 1) lines.move_to(y1M_line[0].get_center(), UP) side_brace, label100M = self.get_timeline_full_label(y1M_line, "100,000,000 Years") side_brace.rotate(PI / 2) side_brace.match_height(lines) side_brace.next_to(lines, LEFT) label100M.next_to(side_brace, LEFT) self.play( LaggedStart( (TransformFromCopy(lines[0].copy().set_opacity(0), line) for line in lines), lag_ratio=0.03, run_time=2 ), FadeIn(side_brace, scale=10, shift=2 * DOWN, time_span=(1, 2)), FadeIn(label100M, time_span=(1, 2)), ) self.wait() class VectorLabel(InteractiveScene): def construct(self): # Test brace = Brace(Line(4 * UP, ORIGIN), LEFT) brace.center() brace.set_stroke(WHITE, 3) text = Text("Vector", font_size=90) text.next_to(brace, LEFT, MED_SMALL_BUFF) text.shift(SMALL_BUFF * UP) self.play( GrowFromCenter(brace), Write(text) ) self.play( FlashUnder(text, color=YELLOW) ) self.wait() class ParameterToVectorAnnotation(InteractiveScene): def construct(self): # Test dials = VGroup(Dial(value_range=(-10, 10, 1)) for _ in range(10)) dials.arrange(DOWN) dials.set_height(5) values = [1, 4.3, 2, 0.9, -1.5, 2.9, -1.2, 7.8, 0, -2.3] arrows = VGroup( Vector(0.5 * RIGHT, thickness=2).next_to(dial, RIGHT, buff=SMALL_BUFF) for dial in dials ) self.play( Write(dials, lag_ratio=0.01), LaggedStartMap(GrowArrow, arrows), ) self.play(LaggedStart( (dial.animate_set_value(value) for dial, value in zip(dials, values)), lag_ratio=0.05, )) self.wait() class ThreeWordsToOne(InteractiveScene): def construct(self): # Test image = ImageMobject("CHMTopText") image.set_height(FRAME_HEIGHT) # self.add(image) phrase = Text("Computer History Museum", font_size=61) words = VGroup(phrase[word][0] for word in phrase.get_text().split(" ")) words.move_to([0, 2.627, 0]) og_words = words.copy() og_words.shift(DOWN) words[0].shift(0.13 * LEFT) words[2].shift(0.4 * RIGHT) colors = ["#63DCF7", "#90C9FA", "#85D4FE"] for word, color in zip(words, colors): word.set_color(color) words.save_state() self.add(words) self.wait() # Back to unity rect = SurroundingRectangle(og_words) rect.set_color(RED) chm_image = ImageMobject("/Users/grant/3Blue1Brown Dropbox/3Blue1Brown/videos/2024/transformers/chm/images/CHM_Exterior.jpeg") chm_image.match_width(rect) chm_image.next_to(rect, DOWN) self.play(Transform(words, og_words)) self.play( ShowCreation(rect), FadeIn(chm_image, DOWN) ) self.wait() # Three pieces rects = VGroup( SurroundingRectangle(word).set_fill(color, 0.2).set_stroke(color, 2) for word, color in zip(words.saved_state, colors) ) words.set_z_index(1) icons = VGroup( SVGMobject("GenericComputer.svg"), SVGMobject("History.svg"), SVGMobject("Museum.svg"), ) for word, icon in zip(words.saved_state, icons): icon.set_fill(word.get_color(), 1, border_width=1) icon.set_height(1) icon.next_to(word, DOWN) self.remove(chm_image) self.play( ReplacementTransform(VGroup(rect), rects), Restore(words), *( FadeTransform(chm_image.copy(), icon) for icon in icons ) ) self.wait() class ExamplePhraseHeader(InteractiveScene): def construct(self): # Test phrase = Text("The Computer History Museum\nis located in ?????") phrase.to_edge(UP) rect = SurroundingRectangle(phrase).set_stroke(WHITE, 2) q_marks = phrase["?????"][0] q_marks[::4].set_fill(opacity=0) q_rect = SurroundingRectangle(q_marks) q_rect.set_fill(YELLOW, 0.25) q_rect.set_stroke(YELLOW, 2) self.add(q_rect) self.add(phrase) class TrainingDataCHM(InteractiveScene): def construct(self): # Test passages = [ "The Computer History Museum (CHM) is a museum ... located in Mountain View...", "Computer History Museum ... 1401 N. Shoreline Blvd. Mountain View...", "Things to do in Mountain View ... the Computer History Museum ...", "While I was in Mountain View ... stopped by the Computer History Museum ...", ] items = VGroup( get_paragraph(passage.split(" "), line_len=35, font_size=30) for passage in passages ) items.arrange(DOWN, buff=LARGE_BUFF, aligned_edge=LEFT) items.to_corner(DL) items.shift(0.5 * UP) dots = Tex(R"\vdots") dots.next_to(items, DOWN, MED_LARGE_BUFF) dots.shift_onto_screen(buff=MED_SMALL_BUFF) items.add(dots) title = Text("Training Data") title.next_to(items, UP, buff=LARGE_BUFF) title.shift_onto_screen(buff=MED_SMALL_BUFF) underline = Underline(title) chm_phrases = VGroup(item["Computer History Museum"] for item in items) mv_phrases = VGroup(item["Mountain View"] for item in items) self.play( FadeIn(title), ShowCreation(underline), LaggedStartMap(FadeIn, items, shift=DOWN, lag_ratio=0.15) ) self.wait() self.play(chm_phrases.animate.set_color(RED).set_anim_args(lag_ratio=0.1)) self.wait() self.play(mv_phrases.animate.set_color(PINK).set_anim_args(lag_ratio=0.1)) self.wait() # Arrows to ffn ffn_point = 3 * RIGHT + DOWN arrows = VGroup( Arrow( item.get_right(), interpolate(item.get_right(), ffn_point, 0.6), path_arc=arc * DEGREES, ) for item, arc in zip(items[:-1], range(-40, 40, 20)) ) arrows.set_fill(border_width=1) self.play(Write(arrows, lag_ratio=0.1), run_time=3) self.play( LaggedStart( *( FadeOutToPoint(letter.copy(), ffn_point) for letter in VGroup(chm_phrases, mv_phrases).family_members_with_points() ), lag_ratio=1e-2, run_time=3 ) ) self.wait() class DivyUpParameters(ShowMachineWithDials): def construct(self): # Show machine frame = self.frame self.set_floor_plane("xz") machine = VGroup(*self.get_blocks_and_dials()) blocks, llm_text, flat_dials, last_dials = machine machine.set_height(3.0) machine.to_edge(DOWN, buff=LARGE_BUFF) block_outlines = blocks.copy() block_outlines.set_fill(opacity=0) block_outlines.set_stroke(WHITE, 2) block_outlines.insert_n_curves(20) # last_dials.set_submobjects(last_dials[:3]) # Remove last_dials.set_stroke(opacity=1) for dial in last_dials: dial[0].set_stroke(width=1) dial[1].set_stroke(width=1) dial[3].set_stroke(width=(3, 0)) frame.reorient(-23, -13, 0, (-0.41, -1.71, -0.06), 4.95) self.play( FadeIn(blocks, shift=0.0, lag_ratio=0.01), LaggedStartMap(VShowPassingFlash, block_outlines.family_members_with_points(), time_width=2.0, lag_ratio=0.01, remover=True), LaggedStartMap(VFadeInThenOut, flat_dials, lag_ratio=0.001, remover=True), FadeIn(last_dials, time_span=(2, 3)), self.frame.animate.reorient(10, -2, 0, (-0.25, -1.58, -0.02), 4.61), run_time=3, ) self.remove(flat_dials) # Show individual blocks top_blocks = blocks[:3].copy() all_dials = VGroup(*last_dials) for block in top_blocks: dials = last_dials.copy() dials.rotate(self.machine_phi, RIGHT) dials.rotate(self.machine_theta, UP) dials.move_to(block) dials.set_stroke(opacity=1) block.add(dials) block.target = block.generate_target() dials.set_opacity(0) all_dials.add(*dials) block_targets = Group(block.target for block in top_blocks) block_targets.rotate(-self.machine_theta, UP) block_targets.rotate(-self.machine_phi, RIGHT) block_targets.set_height(2) block_targets.arrange(RIGHT, buff=1.5) block_targets.to_edge(UP) block_targets.set_shading(0.1, 0.1, 0.1) labels = VGroup( TexText(R"Word $\to$ Vector"), Text("Attention"), Text("Feedforward"), ) for label, block in zip(labels, block_targets): label.next_to(block, DOWN) self.add( blocks[0], top_blocks[0], blocks[1], top_blocks[1], blocks[2], top_blocks[2], blocks[3:], last_dials ) self.play( MoveToTarget(top_blocks[1], time_span=(0, 2)), MoveToTarget(top_blocks[2], time_span=(1, 3)), MoveToTarget(top_blocks[0], time_span=(2, 4)), Write(labels[1], time_span=(1.5, 2)), Write(labels[2], time_span=(2.5, 3)), Write(labels[0], time_span=(3.5, 4)), frame.animate.to_default_state(), run_time=4 ) self.wait() # Change all the parameters self.play( LaggedStart( (dial.animate_set_value(dial.get_random_value()) for dial in all_dials), lag_ratio=1 / len(all_dials), run_time=6 ), LaggedStart( (FlashAround(dial, buff=0, color=YELLOW) for dial in all_dials), lag_ratio=1 / len(all_dials), run_time=6 ), ) self.wait() # End clips class ShowPreviousVideos(InteractiveScene): def construct(self): # Backdrop background = FullScreenRectangle() self.add(background) line = Line(UP, DOWN).set_height(FRAME_HEIGHT) line.set_stroke(WHITE, 2) series_name = Text("Deep Learning Series", font_size=68) series_name.to_edge(UP, buff=0.35) self.add(series_name) # Show thumbnails thumbnails = Group( Group( Rectangle(16, 9).set_height(1).set_stroke(WHITE, 2), ImageMobject(f"https://img.youtube.com/vi/{slug}/maxresdefault.jpg", height=1) ) for slug in [ "aircAruvnKk", "IHZwWFHWa-w", "Ilg3gGewQ5U", "tIeHLnjs5U8", "wjZofJX0v4M", "eMlx5fFNoYc", "9-Jl0dxWQs8", ] ) thumbnails.arrange_in_grid(n_cols=4, buff=0.2) thumbnails.set_width(FRAME_WIDTH - 1) thumbnails.next_to(series_name, DOWN, buff=1.0) thumbnails[-3:].set_x(0) self.play(LaggedStartMap(FadeIn, thumbnails, shift=0.3 * UP, lag_ratio=0.35, run_time=4)) self.wait() # Rearrange left_x = -FRAME_WIDTH / 4 self.play( series_name.animate.set_x(left_x), thumbnails.animate.arrange_in_grid(n_cols=2, buff=0.25).set_height(6).set_x(left_x).to_edge(DOWN), ShowCreation(line, time_span=(1, 2)), run_time=2, ) self.wait() class EndScreen(PatreonEndScreen): title_text = "Where to dig deeper" thanks_words = """ Special thanks to these Patreon supporters """