# Mirror of https://github.com/3b1b/videos.git
# Synced 2025-09-18 21:38:53 +00:00 (2160 lines, 71 KiB, Python)
from manim_imports_ext import *
|
|
from _2024.transformers.generation import *
|
|
from _2024.transformers.helpers import *
|
|
from _2024.transformers.embedding import *
|
|
from _2024.transformers.ml_basics import *
|
|
|
|
|
|
# Intro
|
|
|
|
class HoldUpThumbnail(TeacherStudentsScene):
    """Teacher holds up a framed video thumbnail while the students react.

    The thumbnail location and display height are class attributes so the
    scene can be rendered on another machine by overriding ``thumbnail_path``
    instead of editing ``construct``.
    """

    # Default keeps the original location; override to point at a local copy.
    thumbnail_path = "/Users/grant/3Blue1Brown Dropbox/3Blue1Brown/videos/2024/transformers/Thumbnails/Chapter5_TN3.png"
    thumbnail_height = 3

    def construct(self):
        im = ImageMobject(self.thumbnail_path)
        # Thin white frame hugging the image edges
        im_group = Group(
            SurroundingRectangle(im, buff=0).set_stroke(WHITE, 3),
            im
        )
        im_group.set_height(self.thumbnail_height)
        im_group.move_to(self.hold_up_spot, DOWN)

        morty = self.teacher
        self.play(
            FadeIn(im_group, UP),
            morty.change("raise_right_hand", look_at=im_group),
            self.change_students("tease", "happy", "tease", look_at=im_group),
        )
        self.wait(4)
|
|
|
|
|
|
class IsThisUsefulToShare(TeacherStudentsScene):
    """The teacher asks the audience whether the video is useful; students react twice."""

    def construct(self):
        teacher_question = self.teacher.says("Do you find\nthis useful?")
        first_reactions = self.change_students(
            "pondering", "hesitant", "well",
            look_at=self.screen,
        )
        self.play(teacher_question, first_reactions)
        self.wait(3)

        second_reactions = self.change_students("thinking", "pondering", "tease")
        self.play(second_reactions)
        self.wait(3)
|
|
|
|
|
|
class AskAboutAttention(TeacherStudentsScene):
    """The right-most student asks what Attention does; the others look at the screen."""

    def construct(self):
        teacher = self.teacher
        students = self.students
        question = "Can you explain what\nAttention does?"
        anims = [
            teacher.change("tease"),
            students[2].says(question, mode="raise_left_hand", bubble_direction=LEFT),
        ]
        # Remaining students (index 1, then 0) ponder the screen
        anims.extend(students[i].change("pondering", self.screen) for i in (1, 0))
        self.play(*anims)
        self.wait(4)
|
|
|
|
|
|
# Version 1
|
|
|
|
class PredictTheNextWord(SimpleAutogregression):
    """Introduce a large language model as a next-word predictor.

    The scene builds the machine, predicts "France" for the seed text, shows
    the full output distribution, and then samples 20 more words
    autoregressively.
    """

    text_corner = 3.5 * UP + 6.5 * LEFT
    machine_name = "Large\nLanguage\nModel"
    seed_text = "Paris is a city in"
    model = "gpt3"
    n_shown_predictions = 12
    random_seed = 2

    def construct(self):
        # Setup machine
        text_mob, next_word_line, machine = self.init_text_and_machine()
        machine.move_to(ORIGIN)
        machine[1].set_backstroke(BLACK, 3)

        text_group = VGroup(text_mob, next_word_line)
        text_group.save_state()  # restored later once autoregression begins
        text_group.scale(1.5)
        text_group.match_x(machine[0]).to_edge(UP)

        # Introduce the machine
        in_arrow = Arrow(text_group, machine[0].get_top(), thickness=5)
        frame = self.frame
        self.set_floor_plane("xz")
        blocks = machine[0]
        llm_text = machine[1]
        block_outlines = blocks.copy()
        block_outlines.set_fill(opacity=0)
        block_outlines.set_stroke(GREY_B, 2)
        block_outlines.insert_n_curves(20)

        flat_dials, last_dials = self.get_machine_dials(blocks)

        self.clear()
        frame.reorient(-31, -4, -5, (-0.24, -0.26, -0.06), 3)
        self.play(
            FadeIn(blocks, shift=0.0, lag_ratio=0.01),
            LaggedStartMap(VShowPassingFlash, block_outlines.family_members_with_points(), time_width=2.0, lag_ratio=0.01, remover=True),
            LaggedStartMap(VFadeInThenOut, flat_dials, lag_ratio=0.001, remover=True),
            Write(llm_text, time_span=(2, 4), stroke_color=WHITE),
            FadeIn(last_dials, time_span=(4, 5)),
            frame.animate.reorient(0, 0, 0, (-0.17, -0.12, 0.0), 4.50),
            run_time=6,
        )
        blocks[-1].add(last_dials)
        self.play(
            frame.animate.to_default_state(),
            FadeIn(text_group, UP),
            GrowFromCenter(in_arrow),
            run_time=3
        )

        # Single word prediction
        out_arrow = Vector(1.5 * RIGHT, thickness=5)
        out_arrow.next_to(machine[0][-1], RIGHT)
        prediction = Text("France", font_size=72)
        prediction.next_to(out_arrow, RIGHT)

        self.animate_text_input(
            text_mob, machine,
            position_text_over_machine=False,
        )
        self.play(
            LaggedStart(
                (TransformFromCopy(VectorizedPoint(machine.get_right()), letter)
                 for letter in prediction),
                lag_ratio=0.05,
            ),
            GrowArrow(out_arrow)
        )
        self.wait()
        machine.replace_submobject(2, out_arrow)

        # Probability distribution
        self.play(FadeOut(prediction, DOWN))
        bar_groups = self.animate_prediction_ouptut(machine, self.cur_str)
        self.wait()

        # Show auto_regression
        self.play(
            Restore(text_group),
            FadeOut(in_arrow),
        )

        seed_label = Text("Seed text")
        seed_label.set_color(YELLOW)
        seed_label.next_to(text_mob, DOWN)

        self.play(
            FadeIn(seed_label, rate_func=there_and_back_with_pause),
            FlashAround(text_mob, time_width=2),
            frame.animate.reorient(0, 0, 0, (0.7, -0.01, 0.0), 8.52),
            run_time=2,
        )

        self.animate_random_sample(bar_groups)
        # Keep the mobject returned here, in the same way the loop below keeps the
        # return value of new_selection_cycle. The next cycle then starts from
        # the text that now includes the sampled word, not the stale seed text.
        text_mob = self.animate_word_addition(
            bar_groups, text_mob, next_word_line,
        )

        # More!
        for n in range(20):
            text_mob = self.new_selection_cycle(
                text_mob, next_word_line, machine,
                quick=True,
                skip_anims=(n > 5),
            )
            self.wait(0.25)

    def get_machine_dials(self, blocks):
        """Return ``(flat_dials, last_dials)`` for the given machine blocks.

        An 8x12 grid of randomly set dials is placed on every block.
        ``flat_dials`` holds every dial from every grid in one VGroup.
        ``last_dials`` is a copy of the last block's grid at stroke opacity 0.1.
        """
        dials = VGroup(
            Dial().get_grid(8, 12).set_width(0.9 * block.get_width()).move_to(block)
            for block in blocks
        )
        dials.set_stroke(opacity=0.5)
        for group in dials:
            for dial in group:
                dial.set_value(dial.get_random_value())
        flat_dials = VGroup(*it.chain(*dials))
        last_dials = dials[-1].copy()
        last_dials.set_stroke(opacity=0.1)

        return flat_dials, last_dials
|
|
|
|
|
|
class LotsOfTextIntoTheMachine(PredictTheNextWord):
    """Many text snippets fly in from around the screen and vanish into the machine."""

    run_time = 25
    max_snippet_width = 3

    def construct(self):
        # Add machine
        text_mob, next_word_line, machine = self.init_text_and_machine()
        machine.scale(1.5)
        self.clear()
        self.add(machine)

        blocks, title = machine[:2]
        dials = Dial().get_grid(8, 12).set_width(0.9 * blocks[-1].get_width()).move_to(blocks[-1])
        dials.set_stroke(opacity=0.1)
        blocks[-1].add(dials)

        machine.center()
        machine[1].set_stroke(BLACK, 3)

        # Feed in lots of text
        snippets = self.get_text_snippets()
        text_mobs = VGroup(get_paragraph(snippet.split(" "), line_len=25) for snippet in snippets)
        # Spread starting positions around 12 compass directions, cycling as needed
        directions = compass_directions(12, start_vect=UR)
        for text_mob, vect in zip(text_mobs, it.cycle(directions)):
            text_mob.set_max_width(self.max_snippet_width)
            text_mob.move_to(5 * vect).shift_onto_screen(buff=0.25)

        # Each snippet fades in, then shrinks to transparent at the machine center
        self.play(
            LaggedStart(
                (Succession(
                    FadeIn(text_mob),
                    text_mob.animate.set_opacity(0).move_to(machine.get_center()),
                )
                 for text_mob in text_mobs),
                lag_ratio=0.05,
                run_time=self.run_time
            )
        )
        self.remove(text_mobs)
        self.wait()

    def get_text_snippets(self):
        """Return the non-blank lines of ``pile_of_text.txt`` in shuffled order."""
        lines = Path(DATA_DIR, "pile_of_text.txt").read_text().splitlines()
        # Drop blank lines (a trailing newline, for example) so each lag slot
        # animates real text. splitlines() also removes '\r' from CRLF files.
        facts = [line for line in lines if line.strip()]
        random.shuffle(facts)
        return facts
|
|
|
|
|
|
class EvenMoreTextIntoMachine(LotsOfTextIntoTheMachine):
    """Like LotsOfTextIntoTheMachine, but with random passages from A Tale of Two Cities."""

    run_time = 40
    max_snippet_width = 2.5
    n_examples = 300
    context_size = 25

    def get_text_snippets(self):
        """Return ``n_examples`` random runs of ``context_size`` consecutive words.

        Words are the non-empty pieces of the book after newlines are turned
        into spaces and the text is split on single spaces.
        """
        book = Path(DATA_DIR, "tale_of_two_cities.txt").read_text()
        book = book.replace("\n", " ")
        words = [word for word in book.split(" ") if word]
        context_size = self.context_size
        # randint includes both endpoints, so the last full window starts at
        # len(words) - context_size. Clamp at 0 so texts shorter than one
        # window still yield (shorter) snippets instead of raising ValueError.
        max_start = max(len(words) - context_size, 0)
        starts = (random.randint(0, max_start) for _ in range(self.n_examples))
        return [" ".join(words[start:start + context_size]) for start in starts]
|
|
|
|
|
|
class WriteTransformer(InteractiveScene):
    """Write out the word "Transformer" in large type."""

    def construct(self):
        title = Text("Transformer", font_size=120)
        write_anim = Write(title)
        self.play(write_anim)
        self.wait()
|
|
|
|
|
|
class LabelVector(InteractiveScene):
    """Brace a 4-unit vertical span on its right side and label it "Vector"."""

    def construct(self):
        span = Line(UP, DOWN).set_height(4)
        brace = Brace(span, RIGHT)
        label = Text("Vector", font_size=72)
        label.next_to(brace, RIGHT)
        label.set_backstroke(BLACK, 5)
        self.play(GrowFromCenter(brace), Write(label))
        self.wait()
|
|
|
|
|
|
class AdjustingTheMachine(InteractiveScene):
    """A dial-covered machine with a stack of grey slabs, its dials repeatedly re-randomized."""

    def construct(self):
        frame = self.frame
        self.set_floor_plane("xz")
        frame.reorient(-28, -17, 0, ORIGIN, 8.91)
        self.camera.light_source.move_to([-10, 10, 10])

        machine = MachineWithDials(n_rows=10, n_cols=12)
        machine.set_height(6)
        slabs = self._build_slabs(machine)
        self.add(slabs, machine)

        # Rock the camera's theta back and forth as scene time advances
        def rock(mob):
            return mob.set_theta(-30 * DEGREES * math.cos(0.1 * self.time))

        frame.clear_updaters()
        frame.add_updater(rock)
        self.add(frame)
        for _ in range(6):
            self.play(machine.random_change_animation(lag_factor=0.1))

    def _build_slabs(self, machine):
        """Return ten grey cuboids sized to the machine, arranged along OUT and aligned to its OUT edge."""
        slabs = VCube().replicate(10)
        slabs.set_shape(machine.get_width(), machine.get_height(), 1.0)
        slabs.deactivate_depth_test()
        # With the depth test off, sort each cube's parts far-to-near from the camera
        cam_loc = self.frame.get_implied_camera_location()
        for slab in slabs:
            slab.sort(lambda p: -get_norm(p - cam_loc))
        slabs.set_fill(GREY_D, 1)
        slabs.set_shading(0.2, 0.5, 0.25)
        slabs.arrange(OUT, buff=0.5)
        slabs.move_to(machine, OUT)
        return slabs
|
|
|
|
|
|
class FirthQuote(InteractiveScene):
    """Firth's "company it keeps" quote, followed by two sentences that use "bank" differently.

    Shows the quote with a portrait, then "Down by the river bank" and
    "Deposit a check at the bank". Boxes and arrows mark the context words
    ("river"; "Deposit", "check") that inform each "bank", and one image is
    shown under each sentence.
    """

    def construct(self):
        # Show Quote
        quote = TexText(R"``You shall know a word\\by the company it keeps!''", font_size=60)
        image = ImageMobject("JohnRFirth")  # From https://www.cambridge.org/core/journals/bulletin-of-the-school-of-oriental-and-african-studies/article/john-rupert-firth/D926AFCBF99AD17D5C7A7A9C0558DFDC
        image.set_height(6.5)
        image.to_corner(UL, buff=0.5)
        name = Text("John R. Firth")
        name.next_to(image, DOWN)
        # Center the quote horizontally between the portrait and the right frame edge
        quote.move_to(midpoint(image.get_right(), RIGHT_SIDE))
        quote.to_edge(UP)

        self.play(
            FadeIn(image, 0.25 * UP),
            FadeIn(name, lag_ratio=0.1)
        )
        self.play(Write(quote))
        self.wait()

        # Show two sentences
        phrases = VGroup(
            Text("Down by the river bank"),
            Text("Deposit a check at the bank"),
        )
        # Large standalone "bank"; copies of it later move into both sentences
        bank = Text("bank", font_size=90)
        bank.set_color(TEAL)
        bank.match_x(quote).match_y(image)
        for phrase in phrases:
            phrase["bank"].set_color(TEAL)

        phrases.arrange(DOWN, buff=1.0, aligned_edge=LEFT)
        phrases.next_to(quote, DOWN, buff=2.5)
        phrases[1].set_opacity(0.15)  # second sentence starts dimmed
        # The "bank" part of each sentence
        banks = VGroup(
            phrase["bank"][0]
            for phrase in phrases
        )

        self.play(
            FadeIn(bank, scale=2, lag_ratio=0.25),
            quote.animate.scale(0.7, about_edge=UP).set_opacity(0.75)
        )
        self.wait()
        self.remove(bank)
        # Fade in the words before "bank" in each sentence while copies of the
        # standalone "bank" move into place. The slice lengths count letters
        # without spaces (assumes Text submobjects exclude whitespace; confirm).
        self.play(
            FadeIn(phrases[0][:len("downbytheriver")], lag_ratio=0.1),
            FadeIn(phrases[1][:len("depositacheckatthe")], lag_ratio=0.1),
            *(TransformFromCopy(bank, bank2) for bank2 in banks)
        )
        self.wait()
        # Move the emphasis from the first sentence to the second
        self.play(
            phrases[0].animate.set_opacity(0.5),
            phrases[1].animate.set_opacity(1),
        )
        self.wait()

        # Isolate both phrases
        self.play(LaggedStart(
            FadeOut(image, LEFT, scale=0.5),
            FadeOut(name, LEFT, scale=0.5),
            FadeOut(quote, LEFT, scale=0.5),
            phrases.animate.set_opacity(1).arrange(DOWN, buff=3.5, aligned_edge=LEFT).move_to(0.5 * UP),
        ))
        self.wait()

        # Recreate
        # Start again from a lone "bank" and rebuild both sentences around it
        word = Text("bank", font_size=72)
        word.set_color(TEAL)
        self.clear()

        self.add(word)
        self.wait()
        self.remove(word)
        self.play(
            *(
                # Everything in the sentence except "bank"
                FadeIn(phrase[phrase.get_text().replace("bank", "")])
                for phrase in phrases
            ),
            *(
                TransformFromCopy(word, phrase["bank"][0])
                for phrase in phrases
            )
        )
        self.add(phrases)

        # Show influence
        # Teal boxes on each "bank"; blue boxes on the words that inform it.
        # NOTE: the generator's `bank` is local to it and does not rebind the
        # outer `bank` Text above.
        query_rects = VGroup(
            SurroundingRectangle(bank)
            for bank in banks
        )
        query_rects.set_stroke(TEAL, 2)
        query_rects.set_fill(TEAL, 0.25)
        key_rects = VGroup(
            SurroundingRectangle(phrases[0]["river"]),
            SurroundingRectangle(phrases[1]["Deposit"]),
            SurroundingRectangle(phrases[1]["check"]),
        )
        key_rects.set_stroke(BLUE, 2)
        key_rects.set_fill(BLUE, 0.5)
        # Make the "check" box as tall as the "Deposit" box, keeping its top edge fixed
        key_rects[2].match_height(key_rects[1], about_edge=UP, stretch=True)
        arrows = VGroup(
            Arrow(key_rects[0].get_top(), banks[0].get_top(), path_arc=-180 * DEGREES, buff=0.1),
            Arrow(key_rects[1].get_top(), banks[1].get_top(), path_arc=-90 * DEGREES),
            Arrow(key_rects[2].get_top(), banks[1].get_top(), path_arc=-90 * DEGREES),
        )
        arrows.set_color(BLUE)

        # Save each key box's final state, then collapse it invisibly onto its
        # "bank" box, so the Restore below moves it out to its context word
        key_rects.save_state()
        key_rects[0].become(query_rects[0])
        key_rects[1].become(query_rects[1])
        key_rects[2].become(query_rects[1])
        key_rects.set_opacity(0)

        # Adding phrases again after the boxes presumably keeps the text drawn on top
        self.add(query_rects, phrases)
        self.play(FadeIn(query_rects, lag_ratio=0.25))
        self.wait()

        self.add(key_rects, phrases)
        self.play(Restore(key_rects, lag_ratio=0.1, path_arc=PI / 4, run_time=2))
        self.play(LaggedStartMap(Write, arrows, stroke_width=5, run_time=3))
        self.wait()

        # Show images
        images = Group(
            ImageMobject("RiverBank"),
            ImageMobject("FederalReserve"),
        )
        # Place each image under its sentence's "bank", left-aligned to it
        for image, bank in zip(images, banks):
            image.set_height(2.0)
            image.next_to(bank, DOWN, MED_SMALL_BUFF, aligned_edge=LEFT)

        self.play(
            LaggedStart(
                (FadeTransform(Group(word).copy(), image)
                 for word, image in zip(banks, images)),
                lag_ratio=0.5,
                group_type=Group,
            )
        )
        self.wait(2)
|
|
|
|
|
|
class DownByTheRiverHeader(InteractiveScene):
    """Static header: "Down by the river bank ..." with "bank" boxed in blue and braced below."""

    def construct(self):
        header = Text("Down by the river bank ...")
        highlight = SurroundingRectangle(header["bank"])
        highlight.set_fill(BLUE, 0.5)
        highlight.set_stroke(BLUE, 3)
        underbrace = Brace(highlight, DOWN, buff=SMALL_BUFF)
        # Box first, so the text draws on top of its fill
        self.add(highlight, header, underbrace)
|
|
|
|
|
|
class RiverBankProbParts(SimpleAutogregression):
    """Bar chart of hand-picked next-word probabilities after "Down by the river bank, "."""

    seed_text = "Down by the river bank, "
    model = "gpt3"

    def construct(self):
        _, _, machine = self.init_text_and_machine()
        machine.set_x(0)
        # (candidate word, unnormalized score); softmax turns the scores into probabilities
        scored_words = [
            ("water", 6),
            ("river", 5),
            ("lake", 4),
            ("grass", 4),
            ("waves", 3.5),
            ("shallows", 3.25),
            ("pool", 3),
            ("depths", 3),
            ("foam", 2.5),
            ("mist", 2),
        ]
        words = [word for word, _ in scored_words]
        probs = softmax([score for _, score in scored_words])
        bar_groups = self.get_distribution(words, probs, machine)

        self.clear()
        bar_groups.set_height(6).center()
        self.play(LaggedStartMap(FadeIn, bar_groups, shift=0.25 * DOWN, run_time=3))
        self.wait()
|
|
|
|
|
|
class FourStepsWithParameters(InteractiveScene):
    """Four framed model stages, whose captions are replaced by panels of tunable dials."""

    def construct(self):
        # Four black boxes across the top, each captioned with a stage name
        self.add(FullScreenRectangle(fill_color=GREY_E))
        boxes = Square().replicate(4)
        boxes.arrange(RIGHT, buff=0.25 * boxes[0].get_width())
        boxes.set_width(FRAME_WIDTH - 1.0)
        boxes.center().to_edge(UP, buff=0.5)
        boxes.set_fill(BLACK, 1)
        boxes.set_stroke(WHITE, 2)

        stage_titles = [
            R"Text snippets\\$\downarrow$\\Vectors",
            R"Attention",
            R"Feedforward",
            R"Final prediction",
        ]
        captions = VGroup(*(TexText(title) for title in stage_titles))
        for caption, box in zip(captions, boxes):
            caption.scale(0.8)
            caption.next_to(box, DOWN)

        self.add(boxes)
        self.play(LaggedStartMap(FadeIn, captions, shift=0.25 * DOWN, lag_ratio=0.25))
        self.wait()

        # Under each box, a dial panel whose dials all start at zero
        panels = VGroup(*(
            MachineWithDials(
                width=box.get_width(),
                height=3.0,
                n_rows=9,
                n_cols=6,
            )
            for box in boxes
        ))
        for panel, box in zip(panels, boxes):
            panel.next_to(box, DOWN, buff=0)
            panel[0].set_opacity(0)  # hide the panel's backing shape
            panel.scale(box.get_width() / panel.dials.get_width(), about_edge=UP)
            panel.dials.shift(0.25 * UP)
            for dial in panel.dials:
                dial.set_value(0)

        grow_dials = [
            LaggedStart(
                (GrowFromPoint(dial, panel.get_top()) for dial in panel.dials),
                lag_ratio=0.025,
            )
            for panel in panels
        ]
        self.play(
            LaggedStart(grow_dials, lag_ratio=0.25),
            LaggedStartMap(FadeOut, captions)
        )

        # Randomly turn every panel's dials, twice
        for _ in range(2):
            tweaks = [panel.random_change_animation() for panel in panels]
            self.play(LaggedStart(tweaks, lag_ratio=0.2))
|
|
|
|
|
|
class ChatbotFeedback(InteractiveScene):
    """Sample several chatbot answers to one prompt and mark each with a random verdict.

    Each answer is built one sampled token at a time with
    ``gpt3_predict_next_token``. A green check or red X, with a matching
    frame, is then drawn next to it, and the answer is cleared.
    """

    random_seed = 404

    def construct(self):
        self.frame.set_height(10).move_to(DOWN)

        prompt_mob = Text("User: How and when was the internet invented?")
        prompt_mob.to_edge(UP)
        prompt_mob["User:"].set_color(BLUE)

        header = Text("AI Assistant:")
        header.next_to(prompt_mob, DOWN, buff=1.0, aligned_edge=LEFT)
        header.set_color(YELLOW)
        # og_answer_mob is the fixed header; answer_mob is whatever answer
        # mobject is currently displayed (initially the header itself).
        self.answer_mob = header
        self.og_answer_mob = header
        self.add(prompt_mob, header)

        for _ in range(8):
            self.give_answer(prompt_mob)
            verdict = self.judge_answer()
            self.add(self.og_answer_mob)
            self.play(FadeOut(self.answer_mob), FadeOut(verdict))
            self.answer_mob = self.og_answer_mob

    def display_answer(self, text):
        """Swap the displayed answer for a paragraph rendering of ``text``, anchored at the header."""
        paragraph = get_paragraph(text.replace("\n", " ").split(" "))
        n_header_parts = len(self.og_answer_mob)
        paragraph[:n_header_parts].match_style(self.og_answer_mob)
        paragraph.move_to(self.og_answer_mob, UL)
        self.remove(self.answer_mob)
        self.answer_mob = paragraph
        self.add(paragraph)

    def give_answer(self, prompt_mob, max_responses=100):
        """Grow the answer one sampled token at a time, redrawing after each token."""
        answer = self.og_answer_mob.get_text()
        user_prompt = prompt_mob.get_text()
        for _ in range(max_responses):
            answer, finished = self.add_to_answer(user_prompt, answer)
            if finished:
                return
            self.display_answer(answer)
            self.wait(2 / 30)

    def judge_answer(self):
        """Randomly mark the current answer good or bad; return the mark and its frame."""
        candidates = [
            Checkmark().set_color(GREEN),
            Exmark().set_color(RED),
        ]
        mark = random.choice(candidates)
        mark.scale(5)
        mark.next_to(self.answer_mob, RIGHT, aligned_edge=UP)
        outline = SurroundingRectangle(self.answer_mob)
        outline.match_color(mark)
        self.play(FadeIn(mark, scale=2), FadeIn(outline, scale=1.05))
        self.wait()
        return VGroup(mark, outline)

    def add_to_answer(self, user_prompt: str, answer: str):
        """Sample one next token for ``answer`` and return ``(answer, stop)``.

        ``stop`` is True when the end-of-text token is drawn, or when the
        prediction/sampling step raises IndexError. In both cases ``answer``
        is returned unchanged.
        """
        try:
            tokens, probs = gpt3_predict_next_token("\n\n".join([user_prompt, answer]))
            token = random.choices(tokens, np.array(probs) / sum(probs))[0]
        except IndexError:
            return answer, True

        if token == '<|endoftext|>':
            return answer, True
        return answer + token, False
|
|
|
|
|
|
class ContrastWithEarlierFrame(InteractiveScene):
    """Split screen: "Most earlier models" on the left, "Transformers" on the right.

    Subtitles ("One word at a time" / "All words in parallel", each under a
    down arrow) were drafted for these headings but are not shown.
    """

    def construct(self):
        divider = Line(UP, DOWN)
        divider.set_height(FRAME_HEIGHT)
        self.add(divider)

        headings = VGroup(
            VGroup(Text("Most earlier models")),
            VGroup(Text("Transformers")),
        )
        # Center each heading in its half of the frame, along the top edge
        for heading, side in zip(headings, [LEFT, RIGHT]):
            heading.arrange(DOWN, buff=0.2)
            heading.scale(1.5)
            heading.move_to(FRAME_WIDTH * side / 4)
            heading.to_edge(UP)

        self.add(headings)
|
|
|
|
|
|
class SequentialProcessing(InteractiveScene):
    """One vector moves through a sentence word by word and is re-randomized at each stop."""

    def construct(self):
        sentence = Text("Down by the river bank, where I used to go fishing ...")
        sentence.move_to(1.0 * DOWN)
        words = break_into_words(sentence)
        rects = get_piece_rectangles(words)
        blocks = VGroup(VGroup(rect, word) for rect, word in zip(rects, words))
        blocks.save_state()
        self.add(blocks)

        vect = NumericEmbedding()
        vect.set_width(1.0)
        vect.next_to(rects[0], UP)

        # Visit every block except the last, dimming all the others each time
        for index in range(len(blocks) - 1):
            self._dim_all_but(blocks, index)
            current = blocks[index]
            self.play(
                vect.animate.next_to(current, UP),
                MoveToTarget(blocks)
            )
            absorb = LaggedStart(
                (ContextAnimation(entry, current[1], lag_ratio=0.01)
                 for entry in vect.get_entries()),
                lag_ratio=0.01,
            )
            self.play(absorb, RandomizeMatrixEntries(vect), run_time=2)

    @staticmethod
    def _dim_all_but(blocks, index):
        """Set ``blocks.target`` to the saved layout with every block except ``index`` faded by 0.75."""
        target = blocks.saved_state.copy()
        target[:index].fade(0.75)
        target[index + 1:].fade(0.75)
        blocks.target = target
|
|
|
|
|
|
# Version 2
|
|
|
|
|
|
class PartialScript(SimpleAutogregression):
    """Unroll a curled Human/AI chat script, then fill in the AI's reply word by word.

    The scene uncurls a textured script page, shows the "magic next word
    predictor" with a sample input and output, and then feeds the growing
    script into the machine to predict the first six words of a canned
    completion.
    """

    machine_name = "Magic next\nword predictor"
    machine_phi = 5 * DEGREES
    machine_theta = 6 * DEGREES

    def construct(self):
        # Set frame
        frame = self.frame
        self.set_floor_plane("xz")

        # Unfurl script
        # Image used only as a sizing reference for the flat textured surface below
        curled_script_img = ImageMobject("HumanAIScript")
        curled_script_img.set_height(7)

        # Two curl profiles, smoothed and resampled, with the second scaled to match the first's arc length
        curves = VGroup(SVGMobject("JaggedCurl1")[0], SVGMobject("JaggedCurl2")[0])
        for curve in curves:
            curve.make_smooth(approx=False)
            curve.insert_n_curves(100)
            curve.set_stroke(WHITE, 3)
            curve.set_fill(opacity=0)
            curve.set_height(5)
        curves[1].scale(curves[0].get_arc_length() / curves[1].get_arc_length())

        resolution = (2, 200)  # Change
        surface_kw = dict(u_range=(-6, 6), v_range=(0.05, 0.95), resolution=resolution)
        # Extrude each curve profile along z (u) to make a curled page surface.
        # NOTE(review): the lambda closes over the generator's `curve`; correct
        # only if ParametricSurface samples uv_func during construction (before
        # the generator advances) — confirm if surfaces are ever re-sampled.
        curled_script_templates = Group(
            ParametricSurface(
                lambda u, v: (*curve.pfp(v)[:2], u),
                **surface_kw
            )
            for curve in curves
        )
        curled_script_templates[1].rotate(PI / 2, UP)
        curled_script_templates[0].rotate(-PI / 2)
        flat_script_template = ParametricSurface(
            lambda u, v: (u, v, 0),
            **surface_kw
        )
        curled_script0 = TexturedSurface(curled_script_templates[0], "HumanAIScript")
        curled_script1 = TexturedSurface(curled_script_templates[1], "HumanAIScript")
        curled_script1_torn = TexturedSurface(curled_script_templates[1], "HumanAIScriptTorn")
        flat_script = TexturedSurface(flat_script_template, "HumanAIScriptTorn")
        flat_script.replace(curled_script_img, stretch=True)

        # Curled pages get lighting; the torn and flat versions are unshaded
        for script in [curled_script0, curled_script1]:
            script.set_shading(0.25, 0.25, 0.35)
        curled_script1_torn.set_shading(0, 0, 0)
        flat_script.set_shading(0, 0, 0)

        frame.reorient(0, -1, 0, (-0.28, 0.69, 0.0), 14.43)
        self.play(
            TransformFromCopy(curled_script0, curled_script1),
            frame.animate.reorient(56, -17, 0, (-0.2, -1.52, -2.39), 20.05),
            run_time=3
        )
        self.play(
            frame.animate.reorient(-6, -11, 0, (1.06, -1.22, -2.65), 20.05),
            run_time=8,
        )
        # Cross-fade to the torn texture, then flatten the page
        self.play(
            FadeOut(curled_script1, shift=1e-2 * IN),
            FadeIn(curled_script1_torn, shift=1e-2 * IN),
        )
        self.play(
            ReplacementTransform(curled_script1_torn, flat_script),
            frame.animate.to_default_state(),
            run_time=2
        )
        self.wait()

        # Show the machine
        machine = self.get_transformer_drawing()
        machine[1].set_height(0.7).set_stroke(width=2)
        machine[1].set_opacity(0)  # hide the machine's title text
        machine.remove(machine[-1])
        machine.set_height(3)
        machine.to_edge(RIGHT)

        self.play(
            flat_script.animate.set_height(5).to_edge(LEFT),
            FadeIn(machine, lag_ratio=0.01)
        )
        self.add(machine)
        self.wait()

        # Show example input and output
        out_arrow = Vector(DOWN, thickness=6)
        out_arrow.next_to(machine, DOWN)
        in_arrow = out_arrow.copy().next_to(machine, UP, SMALL_BUFF)
        in_text = Text("To be or not to _")
        in_text[-1].stretch(3, 0, about_edge=LEFT)  # widen the trailing blank
        in_text.next_to(in_arrow, UP)
        prediction = Text("be", font_size=72)
        prediction.next_to(out_arrow, DOWN)

        self.play(FadeIn(in_text), GrowArrow(in_arrow))
        self.animate_text_input(in_text, machine, position_text_over_machine=False)
        self.play(
            GrowArrow(out_arrow),
            FadeIn(prediction, DOWN),
        )
        self.wait()

        # Clear the board
        # Replace the textured page with real Text laid over the same area
        script_text = self.get_text()
        script_text.set_width(0.89 * flat_script.get_width())
        script_text.next_to(flat_script.get_top(), DOWN, buff=0.33)

        # Font size whose glyph height matches the script's first glyph
        # (estimated by comparing against a size-48 "H")
        font_size = 48 * (script_text[0].get_height() / Text("H").get_height())
        completion = "A transistor is a semiconductor device used to amplify or switch electronic signals. It consists of three layers of semiconductor material, either p-type or n-type, forming a structure with terminals called the emitter, base, and collector."
        words = completion.split(" ")
        # Full completion, laid out under the script; pieces are pulled from it below
        paragraph = get_paragraph(completion.split(" "), font_size=font_size)
        paragraph.next_to(script_text, DOWN, aligned_edge=LEFT)
        paragraph.set_color(YELLOW)

        self.play(
            FadeIn(script_text),
            FadeOut(flat_script),
            FadeOut(VGroup(in_text, in_arrow, prediction)),
        )

        # Repeatedly add predictions
        machine.scale(1.25, about_edge=RIGHT)
        out_arrow.next_to(machine, DOWN, buff=0.5)

        # 11x16 grid of random dials overlaid on the last block, tilted to match its perspective
        blocks = machine[0]
        dials = Dial().get_grid(11, 16)
        dials.set_width(blocks[-1].get_width() * 0.95)
        dials.rotate(5 * DEGREES, RIGHT).rotate(10 * DEGREES, UP)
        dials.move_to(blocks[-1])
        dials.set_stroke(opacity=0.5)
        for dial in dials:
            dial.set_value(dial.get_random_value())
        dials.set_z_index(2)
        self.add(dials)

        # curr_answer collects the paragraph pieces revealed so far
        curr_answer = VGroup()
        curr_answer.next_to(script_text, DOWN)
        for n in range(6):
            word = words[n]
            prediction = Text(words[n], font_size=72)
            prediction.next_to(out_arrow, DOWN)
            # Next len(word) pieces of the paragraph. Assumes one submobject per
            # non-space character, so letter counts line up — confirm against get_paragraph.
            word_in_answer = paragraph[len(curr_answer):len(curr_answer) + len(word)]
            word_in_answer.set_color(YELLOW)
            mover = VGroup(script_text, curr_answer).copy()

            # From the 4th word on, skip everything except a short pause on the prediction
            if n > 2:
                self.skip_animations = True

            self.play(
                mover.animate.set_height(1.8).next_to(machine, UP, SMALL_BUFF).set_anim_args(path_arc=-30 * DEGREES),
            )
            self.animate_text_input(
                mover, machine,
                position_text_over_machine=False,
                lag_ratio=1e-3
            )
            self.play(FadeIn(prediction, DOWN, rate_func=rush_from, run_time=0.5))

            if n > 2:
                self.skip_animations = False
                self.wait(0.5)
                self.skip_animations = True

            # Move the prediction into its place in the answer; earlier words turn white
            self.play(
                curr_answer.animate.set_color(WHITE),
                Transform(prediction, word_in_answer),
                FadeOut(mover),
            )
            curr_answer.add(*word_in_answer)
            self.add(curr_answer)
            self.remove(prediction)

    def get_text(self):
        """Return the Human/AI chat prompt as left-aligned Text, height 4, at the top edge.

        "Human" is colored BLUE and "AI assistant" is colored TEAL.
        """
        script_text = Text("""
Human:
Can you explain the history of
transistors and how they're relevant
to computers? What is a transistor,
and how exactly is it used to
perform computations?

AI assistant:
""", alignment="LEFT")
        script_text["Human"].set_color(BLUE)
        script_text["AI assistant"].set_color(TEAL)

        script_text.set_height(4).to_edge(UP)
        return script_text

    def create_image(self):
        """Draw the script as dark text on a cream background, with a tear-off edge below it.

        Not called from construct. Presumably used to render the
        HumanAIScript texture images; confirm.
        """
        # Create image
        script_text = self.get_text()
        script_text.set_fill(BLACK)
        script_text["Human"].set_fill(BLUE_D)
        script_text["AI assistant"].set_fill(TEAL_D)
        self.add(FullScreenRectangle(fill_color="#FCF5E5", fill_opacity=1))
        self.add(script_text)

        # Add the tear-off edge (negative buff overlaps it slightly with the text)
        tear_off = SVGMobject('TearOff')
        tear_off.set_stroke(width=0)
        tear_off.set_fill(BLACK, 1)
        tear_off.set_width(7.5)
        tear_off.next_to(script_text, DOWN, buff=-0.2)
        self.add(tear_off)
|
|
|
|
|
|
class ShowMachineWithDials(PredictTheNextWord):
|
|
words = ['worst', 'age', 'worse', 'best', 'most', 'end', 'very', 'blur']
|
|
logprobs = [4.0, 2.15, 1.89, 1.4, 0.1, -0.18, -0.23, -0.61]
|
|
|
|
def construct(self):
|
|
# Show machine (same position as in PredictTheNextWord)
|
|
frame = self.frame
|
|
self.set_floor_plane("xz")
|
|
blocks, llm_text, flat_dials, last_dials = self.get_blocks_and_dials()
|
|
|
|
self.clear()
|
|
self.add(frame)
|
|
frame.reorient(0, 0, 0, (-0.17, -0.12, 0.0), 4.50)
|
|
self.add(blocks, llm_text, last_dials)
|
|
|
|
# Prepare dial highlight
|
|
last_dials.target = last_dials.generate_target()
|
|
self.fix_dials(last_dials.target)
|
|
|
|
small_rect = SurroundingRectangle(last_dials[0], buff=0.025)
|
|
small_rect.set_stroke(BLUE, 2)
|
|
big_rect = small_rect.copy().scale(4)
|
|
big_rect.next_to(blocks, UP, buff=SMALL_BUFF, aligned_edge=LEFT + OUT)
|
|
big_rect.shift(1.5 * RIGHT)
|
|
big_dial = last_dials[0].copy().scale(4).set_stroke(opacity=1)
|
|
big_dial.move_to(big_rect)
|
|
rect_lines = VGroup(
|
|
Line(small_rect.get_corner(UL), big_rect.get_corner(DL)),
|
|
Line(small_rect.get_corner(UR), big_rect.get_corner(DR)),
|
|
)
|
|
rect_lines.set_stroke(WHITE, width=(1, 3))
|
|
highlighed_parameter_group = VGroup(small_rect, rect_lines, big_rect, big_dial)
|
|
|
|
last_dials.set_stroke(width=1, opacity=1)
|
|
self.play(
|
|
MoveToTarget(last_dials),
|
|
FadeOut(llm_text),
|
|
FadeIn(small_rect),
|
|
)
|
|
|
|
# Show an example input and output
|
|
example = self.get_example(blocks)
|
|
in_text, in_arrow, out_arrow, bar_groups = example
|
|
logprobs = example.logprobs
|
|
true_probs = 100 * softmax(logprobs)
|
|
bar_groups = self.get_output_distribution(self.words, 0.1 * logprobs, out_arrow)
|
|
|
|
self.play(
|
|
LaggedStart(
|
|
ShowCreation(rect_lines, lag_ratio=0),
|
|
TransformFromCopy(small_rect, big_rect),
|
|
TransformFromCopy(last_dials[0], big_dial),
|
|
FadeIn(in_text),
|
|
GrowArrow(in_arrow),
|
|
FadeIn(bar_groups),
|
|
GrowArrow(out_arrow),
|
|
),
|
|
frame.animate.reorient(0, 0, 0, (-0.43, 0.38, 0.0), 7.05),
|
|
run_time=2
|
|
)
|
|
self.play(
|
|
last_dials[0].animate_set_value(0.8),
|
|
big_dial.animate_set_value(0.8),
|
|
LaggedStart(
|
|
(dial.animate_set_value(dial.get_random_value())
|
|
for dial in last_dials[1:]),
|
|
lag_ratio=1.0 / len(last_dials),
|
|
),
|
|
*(
|
|
self.bar_group_change_animation(bg, value)
|
|
for bg, value in zip(bar_groups[:-1], true_probs)
|
|
),
|
|
run_time=3
|
|
)
|
|
self.wait()
|
|
|
|
# Play around tweaking the parameters, and seeing the output change
|
|
self.play(
|
|
LaggedStart(
|
|
(dial.animate_set_value(0)
|
|
for dial in last_dials[:12]),
|
|
lag_ratio=0.01,
|
|
),
|
|
big_dial.animate_set_value(0),
|
|
self.bar_group_change_animation(bar_groups[0], 50),
|
|
self.bar_group_change_animation(bar_groups[1], 34),
|
|
self.bar_group_change_animation(bar_groups[2], 5),
|
|
run_time=4,
|
|
)
|
|
self.play(
|
|
LaggedStart(
|
|
(dial.animate_set_value(1)
|
|
for dial in last_dials[:12]),
|
|
lag_ratio=0.01,
|
|
),
|
|
big_dial.animate_set_value(1),
|
|
self.bar_group_change_animation(bar_groups[0], 80),
|
|
self.bar_group_change_animation(bar_groups[1], 5),
|
|
self.bar_group_change_animation(bar_groups[2], 15),
|
|
run_time=4,
|
|
)
|
|
self.wait()
|
|
|
|
# Mention randomness
|
|
random_words = Text("Initially random")
|
|
random_words.next_to(blocks, UP)
|
|
random_words.set_color(RED)
|
|
out_dots = Tex(R"...", font_size=120)
|
|
out_dots.next_to(out_arrow, RIGHT)
|
|
|
|
self.play(
|
|
FadeOut(big_rect),
|
|
Uncreate(rect_lines, lag_ratio=0),
|
|
FadeOut(small_rect),
|
|
Transform(big_dial, last_dials[0])
|
|
)
|
|
self.play(
|
|
Write(random_words),
|
|
LaggedStart(
|
|
(dial.animate_set_value(dial.get_random_value())
|
|
for dial in last_dials),
|
|
lag_ratio=0.5 / len(last_dials),
|
|
run_time=2
|
|
),
|
|
FadeOut(bar_groups),
|
|
)
|
|
self.play(Write(out_dots))
|
|
self.wait()
|
|
self.play(
|
|
FadeOut(dots),
|
|
FadeOut(random_words),
|
|
FadeIn(bar_groups),
|
|
)
|
|
|
|
# Show many many parameters
|
|
example.save_state()
|
|
blocks.save_state()
|
|
last_dials.save_state()
|
|
all_dials = VGroup(*flat_dials, *last_dials)
|
|
all_dials.generate_target()
|
|
all_dials.target.space_out_submobjects(3)
|
|
new_dials = VGroup(
|
|
all_dials.target.copy().shift(3 * 2 * x * (flat_dials.get_center() - last_dials.get_center()))
|
|
for x in range(1, 9)
|
|
)
|
|
|
|
self.play(
|
|
FadeOut(example),
|
|
FadeOut(blocks),
|
|
FadeIn(flat_dials),
|
|
FadeOut(bar_groups),
|
|
FadeOut(out_arrow),
|
|
)
|
|
self.play(
|
|
FadeOut(highlighed_parameter_group),
|
|
MoveToTarget(all_dials),
|
|
LaggedStart(
|
|
(TransformFromCopy(all_dials.copy().set_opacity(0), nd)
|
|
for nd in new_dials),
|
|
lag_ratio=0.05,
|
|
),
|
|
frame.animate.reorient(-9, 0, 0, (-0.71, -0.07, -0.06), 9.64),
|
|
run_time=4
|
|
)
|
|
self.wait()
|
|
|
|
def get_blocks_and_dials(self):
|
|
machine = self.get_transformer_drawing()
|
|
machine.move_to(ORIGIN)
|
|
self.machine = machine
|
|
|
|
blocks = machine[0]
|
|
llm_text = machine[1]
|
|
llm_text.set_backstroke(BLACK, 2)
|
|
flat_dials, last_dials = self.get_machine_dials(blocks)
|
|
return blocks, llm_text, flat_dials, last_dials
|
|
|
|
def get_example(self, blocks):
|
|
in_text = Text("It was the best\nof times it was\nthe _", alignment="LEFT")
|
|
in_text[-1].stretch(4, 0, about_edge=LEFT)
|
|
in_text.next_to(blocks, LEFT, LARGE_BUFF)
|
|
in_arrow = Arrow(in_text, blocks)
|
|
|
|
out_arrow = Vector(RIGHT)
|
|
out_arrow.next_to(blocks[-1], RIGHT, buff=0.1)
|
|
logprobs = np.array(self.logprobs)
|
|
bar_groups = self.get_output_distribution(self.words, logprobs, out_arrow)
|
|
example = VGroup(in_text, in_arrow, out_arrow, bar_groups)
|
|
example.logprobs = logprobs
|
|
return example
|
|
|
|
def fix_dials(self, dials):
    """
    Restyle every dial in place.

    Each dial gets a width-1, fully opaque stroke, and its ``needle`` gets a
    stroke whose width goes from 2 to 0. Returns ``dials`` unchanged for
    chaining.
    """
    for d in dials:
        needle = d.needle
        d.set_stroke(width=1, opacity=1)
        needle.set_stroke(width=(2, 0))
    return dials
|
|
|
|
def bar_group_change_animation(self, bar_group, new_value):
    """
    Return an animation moving one bar of a distribution chart to ``new_value``.

    ``bar_group`` is unpacked as (label, rect, value_mob). The animation does
    three things:

    * stretches the rect horizontally from its left edge by
      new_value / current_value;
    * counts the number up (or down) to ``new_value``;
    * keeps the number at its original offset from the rect's right end.

    NOTE: the scale factor divides by the number's current value, so a bar
    currently showing 0 cannot be rescaled this way.
    """
    label, bar, number = bar_group
    gap = number.get_left() - bar.get_right()
    ratio = new_value / number.get_value()

    def keep_number_beside_bar(_mob):
        # Attached to `label`, but it repositions `number`; the argument is unused
        number.move_to(bar.get_right() + gap, LEFT)

    grow = bar.animate.stretch(ratio, 0, about_edge=LEFT)
    count = ChangeDecimalToValue(number, new_value)
    follow = UpdateFromFunc(label, keep_number_beside_bar)
    return AnimationGroup(grow, count, follow)
|
|
|
|
def get_output_distribution(self, words, logprobs, out_arrow):
    """
    Turn ``logprobs`` into a bar chart placed to the right of ``out_arrow``.

    The scores are passed through ``softmax`` and laid out with
    ``self.get_distribution(..., self.machine, width_100p=1.0)``, so
    ``self.machine`` must already be set.
    """
    chart = self.get_distribution(
        words, softmax(logprobs), self.machine, width_100p=1.0
    )
    chart.next_to(out_arrow, RIGHT)
    return chart
|
|
|
|
|
|
class ShowSingleTrainingExample(ShowMachineWithDials):
    """
    Walk one training example through the machine.

    Sequence:
    1. "It was the best of times it was the worst" is shown above the machine.
    2. Its prefix and the answer box become the input text.
    3. The output bars grow in.
    4. The first two bars are flagged in turn (an up arrow, then a down arrow).
    5. Every dial is randomized while bars 0-2 animate to 70, 20 and 8.
    """
    # Raw scores handed to get_output_distribution (softmaxed there)
    logprobs = [4.0, 6.15, 1.89, 1.4, 0.1, -0.18, -0.23, -0.61]

    def construct(self):
        # Add state from before
        frame = self.frame
        self.set_floor_plane("xz")

        blocks, _llm_text, _flat_dials, last_dials = self.get_blocks_and_dials()
        self.fix_dials(last_dials)
        example = self.get_example(blocks)
        in_text, in_arrow, out_arrow, bar_groups = example

        self.add(blocks, last_dials)

        # Show example up top
        prompt_str = "It was the best of times it was the"
        answer_str = "worst"
        sentence = Text(f"{prompt_str} {answer_str}")
        start = sentence[prompt_str][0]
        end = sentence[answer_str][0]
        sentence.set_width(10)
        sentence.next_to(blocks, UP, buff=1.5)

        start_rect = SurroundingRectangle(start).set_stroke(BLUE, 2).set_fill(BLUE, 0.2)
        end_rect = SurroundingRectangle(end)
        # Give the answer's box the prompt box's height and vertical position
        end_rect.match_height(start_rect, stretch=True).match_y(start_rect)
        end_rect.set_stroke(YELLOW, 2).set_fill(YELLOW, 0.2)
        arrow = Arrow(
            start_rect.get_top(), end_rect.get_top(),
            path_arc=-90 * DEGREES, thickness=5,
        )
        arrow.set_fill(border_width=1)

        frame.reorient(0, 0, 0, (-0.36, 0.97, 0.0), 7.52)
        self.play(FadeIn(sentence, UP))
        self.play(
            LaggedStartMap(DrawBorderThenFill, VGroup(start_rect, end_rect)),
            FadeIn(arrow),
        )
        # last_dials is removed here and then kept in the next play via
        # Animation(last_dials); presumably this re-adds it above the blocks
        self.remove(last_dials)

        feed_input = AnimationGroup(
            TransformFromCopy(start, in_text[:-1]),
            TransformFromCopy(end_rect, in_text[-1]),
            FadeIn(in_arrow)
        )
        # Pulse each block to TEAL and back; the last block keeps its own color
        pulse_blocks = LaggedStart(
            (
                block.animate.set_color(
                    block.get_color() if block is blocks[-1] else TEAL
                ).set_anim_args(rate_func=there_and_back)
                for block in blocks
            ),
            group=blocks,
            lag_ratio=0.1,
            run_time=1
        )
        self.play(LaggedStart(
            feed_input,
            pulse_blocks,
            Animation(last_dials),
            GrowArrow(out_arrow),
            LaggedStartMap(GrowFromPoint, bar_groups, point=out_arrow.get_start()),
            lag_ratio=0.3
        ))
        self.wait()

        # Flag bad prediction
        out_rects = VGroup(SurroundingRectangle(bg) for bg in bar_groups[:2])
        out_rects.set_stroke(RED, 3)
        arrow_marks = VGroup(
            Tex(symbol, font_size=60).next_to(rect, LEFT, buff=SMALL_BUFF)
            for rect, symbol in zip(out_rects, [R"\uparrow", R"\downarrow"])
        )
        arrow_marks.set_color(RED)

        self.play(
            FadeTransform(end_rect.copy(), out_rects[0]),
            Write(arrow_marks[0]),
        )
        self.wait()
        # Move the highlight and mark from the first bar to the second
        self.play(
            FadeTransform(out_rects[0], out_rects[1]),
            FadeTransform(arrow_marks[0], arrow_marks[1]),
        )
        self.wait()
        self.play(
            FadeOut(out_rects[1]),
            FadeOut(arrow_marks[1]),
        )

        # Adjust
        stagger = 1.0 / len(last_dials)
        self.play(
            LaggedStart(
                (dial.animate_set_value(dial.get_random_value()) for dial in last_dials),
                lag_ratio=stagger,
            ),
            LaggedStart(
                (
                    FlashAround(dial, stroke_width=2, color=YELLOW, time_width=1, buff=0.025)
                    for dial in last_dials
                ),
                lag_ratio=stagger,
            ),
            *(
                self.bar_group_change_animation(bar_group, value)
                for bar_group, value in zip(bar_groups, (70, 20, 8))
            ),
            run_time=6
        )
|
|
|
|
|
|
class ParameterWeight(InteractiveScene):
    """Write "Parameter" centered, then move it back into "Parameter / Weight"."""

    def construct(self):
        # Test
        title = Text("Parameter / Weight", font_size=72).to_edge(UP).set_color(YELLOW)
        word = title["Parameter"][0]
        # Remember its spot in the full title, then start it horizontally centered
        word.save_state()
        word.set_x(0)

        self.play(Write(word))
        self.wait()
        self.play(LaggedStart(Restore(word), FadeIn(title["/ Weight"])))
        self.wait()
|
|
|
|
|
|
class LargeInLargeLanguageModel(InteractiveScene):
    """Highlight "Large" on its own, then arc it into place while the rest is written."""

    def construct(self):
        # Test
        phrase = Text("Large Language Model", font_size=72)
        phrase.to_edge(UP)
        first_word = phrase["Large"][0]
        first_word.save_state()
        first_word.set_x(0)

        self.add(first_word)
        self.play(FlashUnder(first_word), first_word.animate.set_color(YELLOW))

        # Everything after the characters of "Large"
        remainder = phrase[len(first_word):]
        self.play(
            Restore(first_word, path_arc=-30 * DEGREES),
            Write(remainder, time_span=(0.5, 1.5))
        )
        self.wait()
|
|
|
|
|
|
class ThousandsOfWords(InteractiveScene):
    """
    Show a long excerpt of A Tale of Two Cities as a paragraph.

    The excerpt is ``n_chars`` characters starting at ``passage_start``. It is
    laid out 14 units wide and aligned to the top edge.
    """
    novel_path = "/Users/grant/3Blue1Brown Dropbox/3Blue1Brown/videos/2024/transformers/data/tale_of_two_cities.txt"
    passage_start = "It was the best of times"
    n_chars = 5000

    def construct(self):
        # Find passage
        # Decode explicitly as UTF-8 rather than the platform's default encoding
        novel = Path(self.novel_path).read_text(encoding="utf-8")
        start_index = novel.index(self.passage_start)

        # Add text
        passage = novel[start_index:start_index + self.n_chars].replace("\n", " ")
        text = get_paragraph(passage.split(" "), line_len=150)
        text.set_width(14)
        text.to_edge(UP)
        self.add(text)
|
|
|
|
|
|
class EnormousAmountOfTrainingText(PremiseOfMLWithText):
    """
    Tile a grid of "screens" (a dial machine with an input/output example).

    The camera frame grows exponentially with scene time, while the 25
    screens nearest the center keep cycling through new examples.
    """

    def construct(self):
        # Setup
        self.init_data()
        # n_rows = n_cols = 41
        n_rows = n_cols = 9
        # Creation order (row outer, col inner) matters. The sort below is
        # stable, so screens at equal distance keep this order, and that
        # decides which of them fall inside screens[:25].
        # NOTE: `row` steps along x and `col` along y; this is harmless while
        # the grid is square.
        screens = VGroup(
            self.get_screen().move_to(FRAME_WIDTH * row * RIGHT + FRAME_HEIGHT * col * DOWN)
            for row in range(n_rows)
            for col in range(n_cols)
        )
        screens.center()
        # Nearest-to-center first, by the position of each screen's machine
        screens.submobjects.sort(key=lambda scr: get_norm(scr.machine.get_center()))

        self.add(screens)

        # Add frame growth
        frame = self.frame
        frame.clear_updaters()
        # Frame height scales by e^(0.2 t) with scene time t
        frame.add_updater(lambda m: m.set_height(FRAME_HEIGHT * np.exp(0.2 * self.time)))

        # Show lots of new data
        central_screens = screens[:25]
        n_examples = 20
        for _ in range(n_examples):
            self.play(LaggedStart(
                *(
                    self.change_example_animation(scr, show_dial_change=True)
                    for scr in central_screens
                ),
                lag_ratio=0.1,
                run_time=0.5,
            ))

    def change_example_animation(self, screen, show_dial_change=True):
        """
        Return an animation that swaps ``screen``'s training example for a new one.

        The old example fades out and the new one fades in during the first
        0.35 of the animation. If ``show_dial_change`` is set, a random dial
        change on the machine is added. ``screen.training_example`` is updated
        immediately.
        """
        replacement = VGroup(*self.new_input_output_example(*screen.arrows))
        fade_window = (0, 0.35)
        anims = [
            FadeOut(screen.training_example, time_span=fade_window),
            FadeIn(replacement, time_span=fade_window),
        ]
        if show_dial_change:
            anims.append(screen.machine.random_change_animation(run_time=0.5))
        screen.training_example = replacement
        return AnimationGroup(*anims)

    def get_screen(self):
        """
        Build one full-frame screen and return it as a VGroup.

        Contents: a white border, a 5x7 dial machine, input/output arrows on
        either side of it, and a training example between them. The parts are
        exposed as ``.border``, ``.machine``, ``.arrows`` and
        ``.training_example``.
        """
        border = FullScreenRectangle().set_fill(opacity=0).set_stroke(WHITE, 2)

        machine = MachineWithDials(width=3.5, height=2.5, n_rows=5, n_cols=7)
        machine.move_to(1.0 * RIGHT)
        arrows = Vector(RIGHT).replicate(2)
        in_arrow, out_arrow = arrows
        in_arrow.next_to(machine, LEFT)
        out_arrow.next_to(machine, RIGHT)
        example = VGroup(*self.new_input_output_example(in_arrow, out_arrow))

        screen = VGroup(border, machine, arrows, example)
        screen.border = border
        screen.machine = machine
        screen.arrows = arrows
        screen.training_example = example
        return screen

    def new_input_output_example(self, in_arrow, out_arrow):
        """Call the parent's version, then shrink the input and output by 0.8 toward the arrows."""
        source, target = super().new_input_output_example(in_arrow, out_arrow)
        source.scale(0.8, about_edge=RIGHT)
        target.scale(0.8, about_edge=LEFT)
        return source, target
|
|
|
|
|
|
class BadChatBot(InteractiveScene):
    """A bot with a speech bubble of placeholder lines that turn into red scribbles."""

    def construct(self):
        # Add bot
        bot = self.get_bot()
        bot.set_height(3)

        text_lines = Line(LEFT, RIGHT).get_grid(4, 1, buff=0.25)
        text_lines.set_stroke(WHITE, 1)
        # Last line is half-length, anchored on the left
        text_lines[-1].stretch(0.5, 0, about_edge=LEFT)
        text_lines.set_width(3)
        bubble = SpeechBubble(text_lines, buff=MED_LARGE_BUFF)
        bubble.set_stroke(width=5)
        bubble.pin_to(bot).shift(DOWN)

        self.add(bot)
        self.play(Write(bubble, run_time=3))
        self.blink(bot)
        self.wait()

        # Make lines bad
        garble = (Transform(line, self.get_scribble(line)) for line in text_lines)
        self.play(LaggedStart(garble, lag_ratio=0.1, run_time=2))
        for _ in range(2):
            self.blink(bot)
            self.wait(2)

    def get_scribble(self, line):
        """
        Return a red wavy graph spanning ``line``'s endpoints and styled like it.

        The wave is 0.05 times a sum of five sines with random frequencies
        from ``np.random.random``, sampled over x in [0, 5] with step 0.1.
        """
        freqs = np.random.random(5)

        def wiggle(x):
            return 0.05 * sum(math.sin(f * TAU * x) for f in freqs)

        scribble = FunctionGraph(wiggle, x_range=(0, 5, 0.1))
        scribble.put_start_and_end_on(*line.get_start_and_end())
        scribble.match_style(line)
        scribble.set_stroke(color=RED)
        return scribble

    def get_bot(self):
        """
        Load the "Bot" SVG and restructure its first path.

        The first two subpaths are concatenated back into that path. Each
        remaining subpath is replaced by a Dot sized to it, and those dots are
        stored as ``bot.eyes`` (used by ``blink``).
        """
        bot = SVGMobject("Bot")
        outline = bot[0]
        pieces = outline.get_subpaths()
        first, second = pieces[0], pieces[1]
        # Concatenate the first two subpaths, repeating the seam point
        outline.set_points([*first, first[-1], *second])
        bot.eyes = VGroup(
            Dot().replace(VMobject().set_points(path)) for path in pieces[2:]
        )
        bot.add(bot.eyes)
        bot.set_stroke(width=0)

        bot.set_height(4)
        bot.set_fill(GREY_B)
        bot.set_shading(0.5, 0.5, 1)
        return bot

    def blink(self, bot):
        """Play a blink: squash ``bot.eyes`` to zero height and back."""
        squash_and_return = squish_rate_func(there_and_back)
        self.play(
            bot.eyes.animate.stretch(0, 1).set_anim_args(rate_func=squash_and_return)
        )
|
|
|
|
|
|
class WriteRLHF(InteractiveScene):
    """Expand "RLHF" into "Reinforcement Learning with Human Feedback"."""

    def construct(self):
        heading = Text("Step 2: RLHF")
        expansion = Text("Reinforcement Learning\nwith Human Feedback")
        expansion.next_to(heading, UP, LARGE_BUFF)
        expansion.align_to(heading, RIGHT).shift(RIGHT)
        # First occurrence of each initial letter inside the expansion
        initials = VGroup(expansion[char][0][0] for char in "RLHF")
        # Detach the initials so Write only draws the other letters
        expansion.remove(*initials)

        self.add(heading)
        self.wait()
        self.play(
            TransformFromCopy(heading["RLHF"][0], initials, lag_ratio=0.25),
            Write(expansion, time_span=(1.5, 3)),
            run_time=3
        )
        self.wait()
|
|
|
|
|
|
class RLHFWorker(InteractiveScene):
    """Static layout: a worker icon on the left and a black panel on the right."""

    def construct(self):
        # Test
        self.add(FullScreenRectangle().set_fill(GREY_E, 1))
        # Alternative asset: SVGMobject("computer_stall")
        worker = SVGMobject("comp_worker").set_height(4).move_to(4 * LEFT).set_fill(GREY_C, 1)

        panel = Rectangle(7, 5).to_edge(RIGHT).set_stroke(WHITE, 2).set_fill(BLACK, 1)

        self.add(worker)
        self.add(panel)
|
|
|
|
|
|
class RLHFWorkers(ShowMachineWithDials):
    """A 3x2 grid of workers beside the machine, whose dials are randomized eight times."""

    def construct(self):
        # Add workers
        self.add(FullScreenRectangle().set_fill(GREY_E, 1))
        workers = SVGMobject("comp_worker").get_grid(3, 2, buff=0.5)
        workers.set_height(7).to_edge(LEFT).set_fill(GREY_C, 1)
        self.add(workers)

        # Machine
        blocks, _llm_text, _flat_dials, dials = self.get_blocks_and_dials()
        machine = VGroup(blocks, dials)
        machine.set_height(4)
        machine.center().to_edge(RIGHT, buff=LARGE_BUFF)
        dials.set_stroke(opacity=1)
        self.add(machine)

        stagger = 0.5 / len(dials)
        for _ in range(8):
            self.play(LaggedStart(
                (d.animate_set_value(d.get_random_value()) for d in dials),
                lag_ratio=stagger,
                run_time=2
            ))
            self.wait()
|
|
|
|
|
|
class SerialProcessing(InteractiveScene):
    """
    Turn the words of ``phrase`` into abstract vectors one at a time.

    Each new vector animates in from its word together with the previous
    vector, and the previous vector then dims to 0.2 opacity.
    """
    phrase = "It was the best of times it was the worst of times"
    phrase_center = 2 * UP

    def construct(self):
        # Set up words
        words = self.get_words()
        rects = get_piece_rectangles(words)
        self.add(rects)
        self.add(words)

        # Animate in the vectors
        vectors = VGroup(
            self.get_abstract_vector().next_to(word, DOWN, LARGE_BUFF)
            for word in words
        )
        # The first word has no previous vector; use a point under its rectangle
        prev = VGroup(VectorizedPoint(rects[0].get_bottom()))

        for word, vect in zip(words, vectors):
            squares_in = LaggedStart(
                (
                    ContextAnimation(
                        square, VGroup(*word, *prev),
                        direction=DOWN,
                        lag_ratio=0.01,
                        path_arc=30 * DEGREES
                    )
                    for square in vect
                ),
                lag_ratio=0.05,
                run_time=2
            )
            self.play(
                FadeIn(vect, run_time=2),
                squares_in,
                prev.animate.set_opacity(0.2)
            )
            prev = vect

    def get_words(self):
        """Split ``self.phrase`` into word mobjects centered on ``self.phrase_center``."""
        words = break_into_words(Text(self.phrase))
        words.move_to(self.phrase_center)
        return words

    def get_abstract_vector(self, values=None, default_length=10, elem_size=0.2):
        """
        Draw a column of squares, one per entry of ``values``.

        When ``values`` is omitted, ``default_length`` values are drawn from
        ``np.random.uniform(-1, 1)``. The column is ``elem_size`` wide with a
        white 1px stroke, and each square is filled with
        ``value_to_color(value, min_value=0, max_value=1)``.
        """
        if values is None:
            values = np.random.uniform(-1, 1, default_length)
        column = Square().get_grid(len(values), 1, buff=0)
        column.set_width(elem_size)
        column.set_stroke(WHITE, 1)
        for cell, val in zip(column, values):
            cell.set_fill(value_to_color(val, min_value=0, max_value=1), opacity=1)
        return column
|
|
|
|
|
|
class ParallelProcessing(SerialProcessing):
    """Grow all word vectors at once, with lines linking every word to every vector."""

    def construct(self):
        # Set up words
        words = self.get_words()
        rects = get_piece_rectangles(words)
        self.add(rects)
        self.add(words)

        # Animate in the vectors
        vectors = VGroup(
            self.get_abstract_vector().next_to(word, DOWN, buff=1.5)
            for word in words
        )

        # One connector per (word rect, vector) pair, each with a random thin stroke
        connectors = VGroup()
        for rect in rects:
            for vect in vectors:
                connectors.add(Line(
                    rect.get_bottom(), vect.get_top(),
                    buff=0.05,
                    stroke_color=WHITE,
                    stroke_width=2 * random.random()**3
                ))
        connectors.shuffle()

        # Collapse each vector's cells onto its word, invisibly, so Restore grows them out
        for vect, word in zip(vectors, words):
            vect.save_state()
            for cell in vect:
                cell.move_to(word)
                cell.set_opacity(0)

        self.play(
            LaggedStartMap(ShowCreation, connectors, lag_ratio=0.01),
            LaggedStartMap(Restore, vectors, lag_ratio=0)
        )
        self.play(connectors.animate.set_stroke(opacity=0.25))
        self.wait()
|
|
|
|
|
|
class ManyComputationsPerUnitTimeV2(InteractiveScene):
    """
    Illustrate "1 Billion computations per Second".

    A box of flickering arithmetic shrinks onto the first second of a
    minute-long number line. A cascade of number lines then zooms out through
    Hour, Day, Month, Year, and so on up to "100,000,000 Years".
    """

    def construct(self):
        # Add computations
        box = Rectangle(5, 5)
        label = Text("1 Billion computations per Second")
        label.next_to(box, UP)
        self.add(box)
        self.add(label)

        comps = self.get_computations(box)
        self.add(comps)
        self.wait(3)

        # Place box into minute interval
        width = FRAME_WIDTH - 1
        number_lines = VGroup(
            minute_line := NumberLine((0, 60, 1), width=width, big_tick_spacing=10),
            hour_line := NumberLine((0, 60, 1), width=width, big_tick_spacing=10),
            day_line := NumberLine((0, 24, 1), width=width, big_tick_spacing=6),
            month_line := NumberLine((0, 31, 1), width=width),
            year_line := NumberLine((0, 12, 1), width=width),
            y100_line := NumberLine((0, 100, 1), width=width),
            y10k_line := NumberLine((0, 100, 1), width=width),
            y1M_line := NumberLine((0, 100, 1), width=width),
            y100M_line := NumberLine((0, 100, 1), width=width),
        )
        number_lines.move_to(DOWN)

        first_ticks = minute_line.ticks[:2]
        sec_brace = Brace(first_ticks, DOWN, buff=0, tex_string=R"\underbrace{\qquad\qquad}")
        sec_label = Text("Second", font_size=30).next_to(sec_brace, DOWN, SMALL_BUFF)

        self.play(
            ShowCreation(minute_line, lag_ratio=0.01),
            box.animate.match_width(first_ticks).move_to(first_ticks.get_center(), DOWN).set_stroke(width=1),
            TransformFromCopy(label["Second"][0], sec_label),
            GrowFromCenter(sec_brace),
            run_time=2
        )

        # Add other boxes
        # Label the minute line itself. hour_line is built with identical
        # arguments and has only moved together with minute_line so far, so
        # the label position matches labeling either one.
        minute_label = self.get_timeline_full_label(minute_line, "Minute")
        new_boxes = VGroup(
            box.copy().move_to(tick.get_center(), DL)
            for tick in minute_line.ticks[1:-1]
        )
        for new_box in new_boxes:
            new_box.save_state()
            new_box.move_to(box)
        # NOTE(review): these are built while every new box still sits on `box`,
        # and nothing here moves them with the Restore below. Confirm that
        # stacking them at the first-second box is intended.
        computations = VGroup(
            self.get_computations(new_box, n_iterations=1)
            for new_box in new_boxes
        )
        # computations = VGroup()  # If needed

        self.add(computations)
        self.play(
            FadeIn(minute_label, DOWN),
            LaggedStartMap(Restore, new_boxes, lag_ratio=0.1),
            run_time=2
        )
        self.wait(2)

        # Add labels
        minute_line.add(minute_label)
        names = ["Hour", "Day", "Month", "Year", "100 Years", "10,000 Years", "1,000,000 Years", "100,000,000 Years"]
        for line, name in zip(number_lines[1:], names):
            line.label = self.get_timeline_full_label(line, name)
            line.add(line.label)

        # Arrange all lines
        number_lines[1:].arrange(DOWN, buff=2.0)
        number_lines[1:].next_to(minute_line, DOWN, buff=2.0)

        scale_lines = VGroup()
        for nl1, nl2 in zip(number_lines, number_lines[1:]):
            # Middle unit interval of the next (coarser) line
            n = len(nl2.ticks) // 2
            mini_line = Line(nl2.ticks[n - 1].get_center(), nl2.ticks[n].get_center())
            # Dashed guides from this line's endpoints down to that interval
            pair = VGroup(
                DashedLine(nl1.get_start(), mini_line.get_start()),
                DashedLine(nl1.get_end(), mini_line.get_end()),
            )
            pair.set_stroke(WHITE, 2)
            # A copy of this line fitted into that interval
            nl1.target = nl1.copy()
            nl1.target.replace(mini_line, dim_to_match=0)
            nl1.target.shift(mini_line.pfp(0.5) - nl1.target.pfp(0.5))
            scale_lines.add(pair)

        # Start panning down
        lag_ratio = 1.5
        self.play(
            LaggedStart(
                *(AnimationGroup(*(ShowCreation(sl) for sl in pair)) for pair in scale_lines),
                lag_ratio=lag_ratio,
            ),
            LaggedStart(
                *(FadeIn(nl) for nl in number_lines[1:]),
                lag_ratio=lag_ratio,
            ),
            LaggedStart(
                *(TransformFromCopy(nl, nl.target) for nl in number_lines[:-1]),
                lag_ratio=lag_ratio,
            ),
            self.frame.animate.set_y(number_lines[-1].get_y() + 2).set_width(18).set_anim_args(
                rate_func=lambda t: interpolate(smooth(t), linear(t), there_and_back_with_pause(t, pause_ratio=0.8))
            ),
            run_time=30
        )
        self.play(self.frame.animate.reorient(0, 0, 0, (-0.03, -11.55, 0.0), 31.76), run_time=4)
        self.wait(4)

    def fade_in_bigger_interval(self, new_interval, prev_interval, fader, scale_factor, added_anims=None):
        """
        Zoom out from ``prev_interval`` to ``new_interval``.

        All scaling is about ``prev_interval.n2p(0)``:
        * ``new_interval`` starts scaled up by ``scale_factor``, with every
          submobject but the last made transparent and the last filled black,
          then animates back to its saved state;
        * ``prev_interval`` shrinks by ``scale_factor``;
        * ``fader`` shrinks and fades out, and is removed afterwards.

        Args:
            added_anims: optional extra animations to play alongside
                (defaults to none). ``None`` replaces the former shared
                mutable ``[]`` default.
        """
        extra_anims = [] if added_anims is None else added_anims
        pivot = prev_interval.n2p(0)
        new_interval.save_state()
        new_interval.scale(scale_factor, about_point=pivot)
        new_interval[:-1].set_opacity(0)
        new_interval[-1].set_fill(BLACK)

        self.play(
            Restore(new_interval),
            prev_interval.animate.scale(1.0 / scale_factor, about_point=pivot).set_fill(border_width=0),
            fader.animate.scale(1.0 / scale_factor, about_point=pivot).set_opacity(0),
            *extra_anims,
            run_time=4,
            rate_func=rush_from
        )
        self.remove(fader)

    def get_timeline_full_label(self, timeline, name):
        """
        Return ``Text(name, font_size=72)`` placed just below ``timeline``.

        The previous version also built a Tex brace that never reached the
        result and had an unreachable ``return VGroup(brace, label)``; both
        are dropped, and the returned label is unchanged.
        """
        label = Text(name, font_size=72)
        label.next_to(timeline, DOWN)
        return label

    def get_computations(self, box, n_lines=10, n_iterations=3, n_digits=4, cycle_time=0.5):
        """
        Build random arithmetic lines that fit inside ``box``, with a sweeping highlight.

        There are ``n_lines * n_iterations`` lines of the form "x + y = z" or
        "x × y = z", with operands uniform in [0, 10**n_digits]. An updater
        sweeps a Gaussian-shaped opacity bump through the lines, wrapping every
        ``cycle_time * n_iterations`` seconds of scene time.
        """
        # Try adding lines
        lines = VGroup()
        for iteration in range(n_iterations):
            cluster = VGroup()
            for n in range(n_lines):
                x = random.uniform(0, 10**(n_digits))
                y = random.uniform(0, 10**(n_digits))
                if random.choice([True, False]):
                    comb = x * y
                    sym = Tex(R"\times")
                else:
                    comb = x + y
                    sym = Tex(R"+")
                line = VGroup(
                    DecimalNumber(x, num_decimal_places=3), sym,
                    DecimalNumber(y, num_decimal_places=3), Tex("="),
                    DecimalNumber(comb, num_decimal_places=3)
                )
                line.arrange(RIGHT, buff=SMALL_BUFF)
                lines.add(line)
                cluster.add(line)
            cluster.arrange(DOWN, buff=MED_LARGE_BUFF, aligned_edge=LEFT)
            cluster.set_max_height(0.9 * box.get_height())
            cluster.set_max_width(0.9 * box.get_width())
            cluster.move_to(box)

        # Add updater
        def update_lines(lines):
            sigma = 0.12
            # Phase in [0, 1) of the highlight's sweep
            alpha = (self.time / (cycle_time * n_iterations)) % 1
            step = 1.0 / len(lines)
            for n, line in enumerate(lines):
                # Circular distance between the sweep phase and this line's slot
                x = min((
                    abs(a - n * step)
                    for a in (alpha - 1, alpha, alpha + 1)
                ))
                y = np.exp(-x**2 / sigma**2)
                line.set_fill(opacity=y)

        lines.set_height(0.9 * box.get_height())
        lines.move_to(box)

        lines.clear_updaters()
        lines.add_updater(update_lines)

        return lines

    def old(self):
        # NOTE(review): stale code kept for reference; it is not called from
        # construct. It uses names that are not defined in this scope
        # (sec_brace, sec_label, box, comps, new_boxes, computations,
        # number_lines, label, millenium_line, year_line, y1M_line), so calling
        # it raises NameError. It also unpacks get_timeline_full_label's
        # result as (brace, label), but that method returns a single Text.
        # Repeatedly scale down
        to_fade = VGroup(sec_brace, sec_label, box, comps, new_boxes, computations)
        scale_factors = [60, 24, 365, 1000]
        for new_int, prev_int, scale_factor in zip(number_lines[1:], number_lines[0:], scale_factors):
            self.fade_in_bigger_interval(
                new_int, prev_int, to_fade, scale_factor,
                added_anims=[label.animate.set_opacity(0)],
            )
            self.wait(2)
            to_fade = prev_int

        # Multiply last line by 100
        self.fade_in_bigger_interval(
            y1M_line, millenium_line, year_line, 1000,
            added_anims=[self.frame.animate.reorient(0, 0, 0, (-3.51, -5.18, 0.0), 12.93)],
        )

        lines = Line(LEFT, RIGHT).replicate(100)
        lines.match_width(y1M_line)
        lines.arrange_to_fit_height(10)
        lines.sort(lambda p: -p[1])
        lines.set_stroke(WHITE, 1)
        lines.move_to(y1M_line[0].get_center(), UP)

        side_brace, label100M = self.get_timeline_full_label(y1M_line, "100,000,000 Years")
        side_brace.rotate(PI / 2)
        side_brace.match_height(lines)
        side_brace.next_to(lines, LEFT)
        label100M.next_to(side_brace, LEFT)

        self.play(
            LaggedStart(
                (TransformFromCopy(lines[0].copy().set_opacity(0), line)
                 for line in lines),
                lag_ratio=0.03,
                run_time=2
            ),
            FadeIn(side_brace, scale=10, shift=2 * DOWN, time_span=(1, 2)),
            FadeIn(label100M, time_span=(1, 2)),
        )
        self.wait()
|
|
|
|
|
|
class VectorLabel(InteractiveScene):
    """A left-facing brace labeled "Vector", followed by a yellow flash under the label."""

    def construct(self):
        # Test
        brace = Brace(Line(4 * UP, ORIGIN), LEFT).center().set_stroke(WHITE, 3)
        word = Text("Vector", font_size=90)
        word.next_to(brace, LEFT, MED_SMALL_BUFF).shift(SMALL_BUFF * UP)

        self.play(GrowFromCenter(brace), Write(word))
        self.play(FlashUnder(word, color=YELLOW))
        self.wait()
|
|
|
|
|
|
class ParameterToVectorAnnotation(InteractiveScene):
    """A column of ten dials, each with an arrow to its right, turned to fixed values."""

    def construct(self):
        # Test
        dials = VGroup(Dial(value_range=(-10, 10, 1)) for _ in range(10))
        dials.arrange(DOWN).set_height(5)

        targets = [1, 4.3, 2, 0.9, -1.5, 2.9, -1.2, 7.8, 0, -2.3]
        arrows = VGroup(
            Vector(0.5 * RIGHT, thickness=2).next_to(d, RIGHT, buff=SMALL_BUFF)
            for d in dials
        )

        self.play(
            Write(dials, lag_ratio=0.01),
            LaggedStartMap(GrowArrow, arrows),
        )
        self.play(LaggedStart(
            (d.animate_set_value(v) for d, v in zip(dials, targets)),
            lag_ratio=0.05,
        ))
        self.wait()
|
|
|
|
|
|
class ThreeWordsToOne(InteractiveScene):
    """
    Animate "Computer History Museum" between three words and one phrase.

    Three colored, spread-out words merge into one phrase shown with a photo,
    then split back into three boxed words, each with an icon.
    """

    def construct(self):
        # Test
        # Reference still, loaded and sized but not added to the scene
        image = ImageMobject("CHMTopText")
        image.set_height(FRAME_HEIGHT)
        # self.add(image)

        phrase = Text("Computer History Museum", font_size=61)
        words = VGroup(phrase[w][0] for w in phrase.get_text().split(" "))
        words.move_to([0, 2.627, 0])
        merged = words.copy()
        merged.shift(DOWN)
        words[0].shift(0.13 * LEFT)
        words[2].shift(0.4 * RIGHT)
        colors = ["#63DCF7", "#90C9FA", "#85D4FE"]
        for w, c in zip(words, colors):
            w.set_color(c)

        words.save_state()

        self.add(words)
        self.wait()

        # Back to unity
        rect = SurroundingRectangle(merged).set_color(RED)
        chm_image = ImageMobject("/Users/grant/3Blue1Brown Dropbox/3Blue1Brown/videos/2024/transformers/chm/images/CHM_Exterior.jpeg")
        chm_image.match_width(rect)
        chm_image.next_to(rect, DOWN)

        self.play(Transform(words, merged))
        self.play(
            ShowCreation(rect),
            FadeIn(chm_image, DOWN)
        )
        self.wait()

        # Three pieces
        rects = VGroup(
            SurroundingRectangle(w).set_fill(c, 0.2).set_stroke(c, 2)
            for w, c in zip(words.saved_state, colors)
        )
        words.set_z_index(1)

        icons = VGroup(
            SVGMobject("GenericComputer.svg"),
            SVGMobject("History.svg"),
            SVGMobject("Museum.svg"),
        )
        for w, icon in zip(words.saved_state, icons):
            icon.set_fill(w.get_color(), 1, border_width=1)
            icon.set_height(1)
            icon.next_to(w, DOWN)

        # Copies of the photo become the icons; remove the original
        self.remove(chm_image)
        self.play(
            ReplacementTransform(VGroup(rect), rects),
            Restore(words),
            *(FadeTransform(chm_image.copy(), icon) for icon in icons)
        )
        self.wait()
|
|
|
|
|
|
class ExamplePhraseHeader(InteractiveScene):
    """Header phrase whose "?????" blank is highlighted by a yellow box."""

    def construct(self):
        # Test
        phrase = Text("The Computer History Museum\nis located in ?????")
        phrase.to_edge(UP)
        # (Removed: an unused SurroundingRectangle around the whole phrase
        # that was built but never added to the scene.)

        q_marks = phrase["?????"][0]
        # Hide the first and last of the five marks (indices 0 and 4); the box
        # below still spans all five
        q_marks[::4].set_fill(opacity=0)
        q_rect = SurroundingRectangle(q_marks)
        q_rect.set_fill(YELLOW, 0.25)
        q_rect.set_stroke(YELLOW, 2)

        self.add(q_rect)
        self.add(phrase)
|
|
|
|
|
|
class TrainingDataCHM(InteractiveScene):
    """
    Show a "Training Data" list of text snippets and route key phrases to a point.

    "Computer History Museum" (red) and "Mountain View" (pink) are highlighted
    in each snippet. Copies of their letters then fade out toward a point on
    the right, along curved arrows.
    """

    def construct(self):
        # Test
        passages = [
            "The Computer History Museum (CHM) is a museum ... located in Mountain View...",
            "Computer History Museum ... 1401 N. Shoreline Blvd. Mountain View...",
            "Things to do in Mountain View ... the Computer History Museum ...",
            "While I was in Mountain View ... stopped by the Computer History Museum ...",
        ]
        items = VGroup(
            get_paragraph(p.split(" "), line_len=35, font_size=30)
            for p in passages
        )
        items.arrange(DOWN, buff=LARGE_BUFF, aligned_edge=LEFT)
        items.to_corner(DL).shift(0.5 * UP)
        dots = Tex(R"\vdots").next_to(items, DOWN, MED_LARGE_BUFF)
        dots.shift_onto_screen(buff=MED_SMALL_BUFF)
        items.add(dots)

        title = Text("Training Data").next_to(items, UP, buff=LARGE_BUFF)
        title.shift_onto_screen(buff=MED_SMALL_BUFF)
        underline = Underline(title)

        chm_phrases = VGroup(item["Computer History Museum"] for item in items)
        mv_phrases = VGroup(item["Mountain View"] for item in items)

        self.play(
            FadeIn(title),
            ShowCreation(underline),
            LaggedStartMap(FadeIn, items, shift=DOWN, lag_ratio=0.15)
        )
        self.wait()
        for phrases, color in [(chm_phrases, RED), (mv_phrases, PINK)]:
            self.play(phrases.animate.set_color(color).set_anim_args(lag_ratio=0.1))
            self.wait()

        # Arrows to ffn
        ffn_point = 3 * RIGHT + DOWN
        # One arc angle per snippet (the trailing dots are excluded): -40, -20, 0, 20
        arcs = range(-40, 40, 20)
        arrows = VGroup(
            Arrow(
                item.get_right(),
                interpolate(item.get_right(), ffn_point, 0.6),
                path_arc=arc * DEGREES,
            )
            for item, arc in zip(items[:-1], arcs)
        )
        arrows.set_fill(border_width=1)
        self.play(Write(arrows, lag_ratio=0.1), run_time=3)

        letters = VGroup(chm_phrases, mv_phrases).family_members_with_points()
        self.play(LaggedStart(
            *(FadeOutToPoint(letter.copy(), ffn_point) for letter in letters),
            lag_ratio=1e-2,
            run_time=3
        ))
        self.wait()
|
|
|
|
|
|
class DivyUpParameters(ShowMachineWithDials):
    """
    Show how the machine's parameters are divided among its parts.

    The machine is revealed, then copies of its first three blocks rise into
    a labeled row ("Word → Vector", "Attention", "Feedforward"). Finally the
    machine's dials and the copies' dials are all randomized.
    """

    def construct(self):
        # Show machine
        frame = self.frame
        self.set_floor_plane("xz")

        machine = VGroup(*self.get_blocks_and_dials())
        blocks, _llm_text, flat_dials, last_dials = machine
        machine.set_height(3.0)
        machine.to_edge(DOWN, buff=LARGE_BUFF)

        outlines = blocks.copy()
        outlines.set_fill(opacity=0)
        outlines.set_stroke(WHITE, 2)
        outlines.insert_n_curves(20)

        # last_dials.set_submobjects(last_dials[:3])  # Remove
        last_dials.set_stroke(opacity=1)
        for dial in last_dials:
            dial[0].set_stroke(width=1)
            dial[1].set_stroke(width=1)
            dial[3].set_stroke(width=(3, 0))

        frame.reorient(-23, -13, 0, (-0.41, -1.71, -0.06), 4.95)
        self.play(
            FadeIn(blocks, shift=0.0, lag_ratio=0.01),
            LaggedStartMap(VShowPassingFlash, outlines.family_members_with_points(), time_width=2.0, lag_ratio=0.01, remover=True),
            LaggedStartMap(VFadeInThenOut, flat_dials, lag_ratio=0.001, remover=True),
            FadeIn(last_dials, time_span=(2, 3)),
            frame.animate.reorient(10, -2, 0, (-0.25, -1.58, -0.02), 4.61),
            run_time=3,
        )
        self.remove(flat_dials)

        # Show individual blocks
        top_blocks = blocks[:3].copy()
        all_dials = VGroup(*last_dials)
        for block in top_blocks:
            block_dials = last_dials.copy()
            # Rotate by the machine's (phi, theta); the targets below undo this
            block_dials.rotate(self.machine_phi, RIGHT)
            block_dials.rotate(self.machine_theta, UP)
            block_dials.move_to(block)
            block_dials.set_stroke(opacity=1)
            block.add(block_dials)
            # Snapshot before hiding: the target (block.target) keeps opaque
            # dials, while the dials on the block itself start invisible
            block.generate_target()
            block_dials.set_opacity(0)
            all_dials.add(*block_dials)

        block_targets = Group(block.target for block in top_blocks)
        block_targets.rotate(-self.machine_theta, UP)
        block_targets.rotate(-self.machine_phi, RIGHT)
        block_targets.set_height(2)
        block_targets.arrange(RIGHT, buff=1.5)
        block_targets.to_edge(UP)
        block_targets.set_shading(0.1, 0.1, 0.1)

        labels = VGroup(
            TexText(R"Word $\to$ Vector"),
            Text("Attention"),
            Text("Feedforward"),
        )
        for label, target in zip(labels, block_targets):
            label.next_to(target, DOWN)

        # Add order sets draw order: each copy sits right after its source block
        self.add(
            blocks[0], top_blocks[0],
            blocks[1], top_blocks[1],
            blocks[2], top_blocks[2],
            blocks[3:], last_dials
        )
        self.play(
            MoveToTarget(top_blocks[1], time_span=(0, 2)),
            MoveToTarget(top_blocks[2], time_span=(1, 3)),
            MoveToTarget(top_blocks[0], time_span=(2, 4)),
            Write(labels[1], time_span=(1.5, 2)),
            Write(labels[2], time_span=(2.5, 3)),
            Write(labels[0], time_span=(3.5, 4)),
            frame.animate.to_default_state(),
            run_time=4
        )
        self.wait()

        # Change all the parameters
        per_dial_lag = 1 / len(all_dials)
        self.play(
            LaggedStart(
                (dial.animate_set_value(dial.get_random_value()) for dial in all_dials),
                lag_ratio=per_dial_lag,
                run_time=6
            ),
            LaggedStart(
                (FlashAround(dial, buff=0, color=YELLOW) for dial in all_dials),
                lag_ratio=per_dial_lag,
                run_time=6
            ),
        )
        self.wait()
|
|
|
|
|
|
# End clips
|
|
|
|
|
|
class ShowPreviousVideos(InteractiveScene):
    """
    Present the earlier videos of the "Deep Learning Series".

    The seven thumbnails fade in under the title. The title and a two-column
    grid of thumbnails then move to the left half of the frame as a vertical
    divider is drawn.
    """

    def construct(self):
        # Backdrop
        self.add(FullScreenRectangle())
        divider = Line(UP, DOWN).set_height(FRAME_HEIGHT).set_stroke(WHITE, 2)

        series_name = Text("Deep Learning Series", font_size=68)
        series_name.to_edge(UP, buff=0.35)
        self.add(series_name)

        # Show thumbnails
        slugs = [
            "aircAruvnKk",
            "IHZwWFHWa-w",
            "Ilg3gGewQ5U",
            "tIeHLnjs5U8",
            "wjZofJX0v4M",
            "eMlx5fFNoYc",
            "9-Jl0dxWQs8",
        ]
        thumbnails = Group(
            Group(
                Rectangle(16, 9).set_height(1).set_stroke(WHITE, 2),
                ImageMobject(f"https://img.youtube.com/vi/{slug}/maxresdefault.jpg", height=1)
            )
            for slug in slugs
        )
        thumbnails.arrange_in_grid(n_cols=4, buff=0.2)
        thumbnails.set_width(FRAME_WIDTH - 1)
        thumbnails.next_to(series_name, DOWN, buff=1.0)
        # Center the incomplete last row (the final 3 of 7 thumbnails)
        thumbnails[-3:].set_x(0)

        self.play(LaggedStartMap(FadeIn, thumbnails, shift=0.3 * UP, lag_ratio=0.35, run_time=4))
        self.wait()

        # Rearrange
        left_x = -FRAME_WIDTH / 4
        self.play(
            series_name.animate.set_x(left_x),
            thumbnails.animate.arrange_in_grid(n_cols=2, buff=0.25).set_height(6).set_x(left_x).to_edge(DOWN),
            ShowCreation(divider, time_span=(1, 2)),
            run_time=2,
        )
        self.wait()
|
|
|
|
|
|
class EndScreen(PatreonEndScreen):
|
|
title_text = "Where to dig deeper"
|
|
thanks_words = """
|
|
Special thanks to these Patreon supporters
|
|
"""
|