3b1b-videos/_2024/transformers/generation.py
2024-10-23 13:21:07 -05:00

987 lines
No EOL
31 KiB
Python

from transformers.models.videomae import image_processing_videomae
from manim_imports_ext import *
from _2024.transformers.helpers import *
from transformers import GPT2Tokenizer
from transformers import GPT2LMHeadModel
from transformers import PreTrainedModel
import torch
import openai
import tiktoken
@lru_cache(maxsize=1)
def get_gpt2_tokenizer(model_name='gpt2'):
    """
    Return the pretrained GPT-2 tokenizer for `model_name`.

    The result is memoized with a cache of size one, so repeated calls
    with the same name reuse the already-loaded tokenizer.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    return tokenizer
@lru_cache(maxsize=1)
def get_gpt2_model(model_name='gpt2'):
    """
    Return the pretrained GPT-2 language-model head for `model_name`.

    The result is memoized with a cache of size one, so repeated calls
    with the same name reuse the already-loaded model.
    """
    lm_model = GPT2LMHeadModel.from_pretrained(model_name)
    return lm_model
def gpt2_predict_next_token(text, n_shown=7):
    """
    Ask GPT-2 for its most likely continuations of `text`.

    Returns a pair `(tokens, probs)`: the `n_shown` highest-scoring
    next tokens decoded to strings, most likely first, and the matching
    entries of `softmax` applied to the last-position scores.
    """
    gpt2_tokenizer = get_gpt2_tokenizer()
    gpt2_model = get_gpt2_model()

    # Encode the prompt as a batch of one sequence of token ids
    token_ids = gpt2_tokenizer.encode(
        text,
        add_special_tokens=False,
        return_tensors='pt',
    )

    # Inference only, so no gradients are needed
    with torch.no_grad():
        model_output = gpt2_model(token_ids)

    # First output, first batch entry, last sequence position
    last_scores = model_output[0][0, -1, :]

    # argsort is ascending, so the best candidates sit at the end;
    # flip them so the most likely token comes first
    ascending = torch.argsort(last_scores)
    best_first = ascending[-n_shown:].flip(0)

    decoded = [gpt2_tokenizer.decode(token_id) for token_id in best_first]
    shown_probs = softmax(last_scores)[best_first]
    return decoded, shown_probs
def gpt3_predict_next_token(text, n_shown=10, random_seed=0):
    """
    Query the OpenAI completions endpoint for next-token candidates.

    Requests a single token continuing `text`, reads the returned
    top-logprob table, and returns `(tokens, probs)` for the `n_shown`
    most probable candidates, most likely first. The API key is read
    from the OPENAI_KEY environment variable. `random_seed` is sent as
    the request's `user` field.
    """
    openai.api_key = os.getenv('OPENAI_KEY')
    completion = openai.Completion.create(
        # Or another model version, adjust as necessary
        engine="gpt-3.5-turbo-instruct",
        prompt=text,
        max_tokens=1,
        n=1,
        temperature=1.0,
        user=str(random_seed),
        # NOTE(review): the service may cap this below 50 (possibly 20) -- confirm
        logprobs=50
    )
    logprob_table = completion.choices[0]["logprobs"]["top_logprobs"][0]
    candidates, candidate_logprobs = zip(*logprob_table.items())

    # Convert log-probabilities to probabilities and rank descending
    candidate_probs = np.exp(candidate_logprobs)
    ranking = np.argsort(-candidate_probs)[:n_shown]

    best_tokens = [candidates[i] for i in ranking]
    return best_tokens, candidate_probs[ranking]
def clean_text(text):
    """
    Collapse every run of whitespace in `text` to a single space and
    drop leading/trailing whitespace.
    """
    # str.split() with no separator splits on whitespace runs and never
    # yields empty pieces, so re-joining gives the normalized string
    pieces = text.split()
    return " ".join(pieces)
def next_token_bar_chart(
    words, probs,
    reference_point=ORIGIN,
    font_size=24,
    width_100p=1.0,
    prob_exp=0.75,
    bar_height=0.25,
    bar_space_factor=0.5,
    buff=1.2,
    show_ellipses=True,
    use_percent=True,
):
    """
    Build a vertical bar chart of candidate next tokens.

    Each row is a VGroup of (word label, bar, probability label). A bar's
    width is `prob**prob_exp * width_100p`. The probability label is an
    integer percentage when `use_percent` is set, otherwise a decimal.
    When `show_ellipses` is set, a vertical-dots mobject is appended
    after the last row. The whole group is shifted so the bars' left
    edge sits `buff` to the right of `reference_point`.
    """
    number_font_size = 0.75 * font_size
    row_gap = bar_space_factor * bar_height

    labels = VGroup(Text(word, font_size=font_size) for word in words)
    bars = VGroup(
        Rectangle(prob**(prob_exp) * width_100p, bar_height)
        for prob, _ in zip(probs, labels)
    )
    bars.arrange(DOWN, aligned_edge=LEFT, buff=row_gap)
    bars.set_fill(opacity=1)
    bars.set_submobject_colors_by_gradient(TEAL, YELLOW)
    bars.set_stroke(WHITE, 1)

    rows = VGroup()
    for label, bar, prob in zip(labels, bars, probs):
        if use_percent:
            number = Integer(int(100 * prob), unit="%", font_size=number_font_size)
        else:
            number = DecimalNumber(prob, font_size=number_font_size)
        number.next_to(bar, RIGHT, buff=SMALL_BUFF)
        label.next_to(bar, LEFT)
        rows.add(VGroup(label, bar, number))

    if show_ellipses:
        # Dots go under the label of the last row
        dots = Tex(R"\vdots", font_size=font_size)
        dots.next_to(rows[-1][0], DOWN)
        rows.add(dots)

    offset = reference_point - bars.get_left() + buff * RIGHT
    rows.shift(offset)
    return rows
class SimpleAutogregression(InteractiveScene):
    """
    Scene that repeatedly feeds text into a "machine" drawing, shows the
    predicted next-token distribution, samples from it, and appends the
    sampled token to the text.
    """
    text_corner = 3.5 * UP + 0.75 * RIGHT
    line_len = 31
    font_size = 35
    n_shown_predictions = 12
    seed_text = "Behold, a wild pi creature, foraging in its native"
    seed_text_color = BLUE_B
    machine_name = "Transformer"
    machine_phi = 10 * DEGREES
    machine_theta = 12 * DEGREES
    n_predictions = 120
    skip_through = False
    random_seed = 0
    model = "gpt2"

    def construct(self):
        # Repeatedly generate
        text_mob, next_word_line, machine = self.init_text_and_machine()
        for n in range(self.n_predictions):
            text_mob = self.new_selection_cycle(
                text_mob, next_word_line, machine,
                # After the first handful of cycles, skip the input animation
                quick=(n > 10),
                skip_anims=self.skip_through,
            )

    def init_text_and_machine(self):
        """
        Create and add the seed text, the underline marking where the next
        word goes, and the machine drawing. Returns all three.
        """
        # Set up active text
        self.cur_str = self.seed_text
        text_mob = self.string_to_mob(self.cur_str)
        text_mob.set_color(self.seed_text_color)
        next_word_line = self.get_next_word_line(text_mob)

        # Set up Transformer as some sort of machine
        machine = self.get_transformer_drawing()
        machine.set_y(0).to_edge(LEFT, buff=-0.6)

        self.add(text_mob)
        self.add(next_word_line)
        self.add(machine)
        return text_mob, next_word_line, machine

    def string_to_mob(self, text):
        """
        Lay `text` out as a paragraph anchored at `self.text_corner`.

        A throwaway " l" is appended and then made invisible and collapsed
        to zero width, as an alignment hack.
        """
        text += " l"  # Dumb hack for alignment
        result = get_paragraph(
            text.replace("\n", " ").split(" "),
            self.line_len,
            self.font_size
        )
        result.move_to(self.text_corner, UL)
        result[-1].set_fill(BLACK, 0)  # Continue dumb hack
        result[-1].stretch(0, 0, about_edge=LEFT)
        return result

    def get_next_word_line(self, text_mob, char_len=7):
        """
        Return an underline as wide as the first `char_len` characters of
        `text_mob`, placed just right of its last character. It is made
        fully transparent when `self.skip_through` is set.
        """
        next_word_line = Underline(text_mob[:char_len])
        next_word_line.set_stroke(TEAL, 2)
        next_word_line.next_to(text_mob[-1], RIGHT, SMALL_BUFF, aligned_edge=DOWN)
        if self.skip_through:
            next_word_line.set_opacity(0)
        return next_word_line

    def get_transformer_drawing(self):
        """
        Build the machine drawing: a stack of ten prisms, a name label and
        an (initially invisible) output arrow, returned as
        VGroup(blocks, word, out_arrow).

        Also moves the camera light source and sets the frame's field of
        view as a side effect.
        """
        self.camera.light_source.move_to([-5, 5, 10])
        self.frame.set_field_of_view(20 * DEGREES)
        blocks = VGroup(
            VPrism(3, 2, 0.2)
            for n in range(10)
        )
        blocks.set_fill(GREY_D, 1)
        blocks.set_stroke(width=0)
        blocks.set_shading(0.25, 0.5, 0.2)
        blocks.arrange(OUT)
        blocks.move_to(ORIGIN, OUT)
        blocks.rotate(self.machine_phi, RIGHT, about_edge=OUT)
        blocks.rotate(self.machine_theta, UP, about_edge=OUT)

        blocks.deactivate_depth_test()
        for block in blocks:
            block.sort(lambda p: p[2])

        word = Text(self.machine_name, alignment="LEFT")
        word.next_to(blocks[-1], UP)
        word.shift(0.1 * UP + 0.4 * LEFT)
        # The final placement is centered on the front block
        word.move_to(blocks[-1])
        word.set_backstroke(BLACK, 5)
        out_arrow = Vector(
            0.5 * RIGHT, stroke_width=10,
            max_tip_length_to_length_ratio=0.5,
            max_width_to_length_ratio=12
        )
        out_arrow.next_to(blocks[-1], RIGHT, buff=SMALL_BUFF)
        out_arrow.set_opacity(0)

        result = VGroup(blocks, word, out_arrow)
        return result

    def get_distribution(
        self, words, probs, machine=None,
        font_size=24,
        width_100p=1.8,
        bar_height=0.25,
        show_ellipses=True,
        buff=1.5,
    ):
        """
        Build the bar chart of next-token candidates.

        Each row is VGroup(label, bar, percent label), with bar width
        proportional to the probability. If `show_ellipses` is set, a
        vertical-dots mobject is appended as the last element.

        When `machine` is given, the chart's bars start `buff` to the
        right of `machine[-1]` and the chart is top-aligned with the
        machine. When `machine` is omitted, the bars start `buff` to the
        right of the origin and no vertical alignment is applied, so that
        scenes without a machine drawing can still build a chart.
        """
        labels = VGroup(Text(word, font_size=font_size) for word in words)
        bars = VGroup(
            Rectangle(prob * width_100p, bar_height)
            for prob, label in zip(probs, labels)
        )
        bars.arrange(DOWN, aligned_edge=LEFT, buff=0.5 * bar_height)
        bars.set_fill(opacity=1)
        bars.set_submobject_colors_by_gradient(TEAL, YELLOW)
        bars.set_stroke(WHITE, 1)

        bar_groups = VGroup()
        for label, bar, prob in zip(labels, bars, probs):
            prob_label = Integer(int(100 * prob), unit="%", font_size=0.75 * font_size)
            prob_label.next_to(bar, RIGHT, buff=SMALL_BUFF)
            label.next_to(bar, LEFT)
            bar_groups.add(VGroup(label, bar, prob_label))

        if show_ellipses:
            ellipses = Tex(R"\vdots", font_size=font_size)
            ellipses.next_to(bar_groups[-1][0], DOWN)
            bar_groups.add(ellipses)

        if machine is None:
            anchor_point = ORIGIN
        else:
            anchor_point = machine[-1].get_right()
        bar_groups.shift(anchor_point - bars.get_left() + buff * RIGHT)
        if machine is not None:
            bar_groups.align_to(machine, UP)

        return bar_groups

    def animate_text_input(self, text_mob, machine, position_text_over_machine=True, added_anims=None, lag_ratio=0.02):
        """
        Animate a copy of `text_mob` flowing into the top of the machine
        while the machine's blocks flash in sequence.

        `added_anims` is an optional iterable of extra animations played
        at the start of the same LaggedStart.
        """
        # Avoid a shared mutable default; None means "no extra animations"
        if added_anims is None:
            added_anims = []
        blocks = machine[0]
        text_copy = text_mob.copy()
        if position_text_over_machine:
            text_copy.target = text_copy.generate_target()
            text_copy.target.set_max_width(4)
            text_copy.target.next_to(blocks[0], UP)
            text_copy.target.shift_onto_screen()
            self.play(MoveToTarget(text_copy, path_arc=-45 * DEGREES))
        self.play(LaggedStart(
            *added_anims,
            Transform(
                text_copy,
                VGroup(VectorizedPoint(machine.get_top())),
                lag_ratio=lag_ratio,
                run_time=1,
                path_arc=-45 * DEGREES,
                remover=True,
            ),
            LaggedStart(
                (
                    # The front block keeps its own color; the rest flash teal
                    block.animate.set_color(
                        block.get_color() if block is blocks[-1] else TEAL
                    ).set_anim_args(rate_func=there_and_back)
                    for block in blocks
                ),
                lag_ratio=0.1,
                run_time=1
            ),
            Animation(machine[1:]),
            lag_ratio=0.5
        ))

    def animate_prediction_ouptut(self, machine, cur_str):
        """
        Predict the next token for `cur_str`, build its bar chart, and fade
        the rows in from the right edge of the machine's front block.
        Returns the chart.
        """
        words, probs = self.predict_next_token(cur_str)
        bar_groups = self.get_distribution(words, probs, machine)
        self.play(
            LaggedStart(
                (FadeInFromPoint(bar_group, machine[0][-1].get_right())
                for bar_group in bar_groups),
                lag_ratio=0.025,
                group=bar_groups,
                run_time=1
            )
        )
        return bar_groups

    def animate_random_sample(self, bar_groups):
        """
        Animate a highlight rectangle jumping between rows, choosing rows
        at random with probability proportional to bar width, then insert
        the rectangle at the back (index 0) of `bar_groups`.

        The last element of `bar_groups` is excluded from sampling.
        """
        widths = np.array([group[1].get_width() for group in bar_groups[:-1]])
        dist = widths / widths.sum()
        seed = random.randint(0, 1000)
        buff = 0.025
        highlight_rect = SurroundingRectangle(bar_groups[0], buff=buff)
        highlight_rect.set_stroke(YELLOW, 2)
        highlight_rect.set_fill(YELLOW, 0.25)

        def highlight_randomly(rect, dist, alpha):
            # Re-seeding from alpha makes the choice a step function of
            # animation progress, so the rect hops rather than jittering
            np.random.seed(seed + int(10 * alpha))
            index = np.random.choice(np.arange(len(dist)), p=dist)
            rect.surround(bar_groups[index], buff=buff)
            rect.stretch(1.1, 0)

        self.play(
            UpdateFromAlphaFunc(highlight_rect, lambda rect, a: highlight_randomly(rect, dist, a)),
            Animation(bar_groups)
        )

        bar_groups.add_to_back(highlight_rect)

    def animate_word_addition(self, bar_groups, text_mob, next_word_line, force_unskip=False):
        """
        Append the selected token to the running text.

        If `bar_groups` starts with a highlight rectangle, the selected row
        is the one vertically closest to it; otherwise the first row is
        used. Returns the new text mobject and updates `self.cur_str`.
        """
        # Choose the highlighted_group
        if isinstance(bar_groups[0], Rectangle):
            # Use the highlight rect to find the group element
            bars = bar_groups[1:-1]
            diffs = [abs(bg.get_y() - bar_groups[0].get_y()) for bg in bars]
            bar_group = bar_groups[1:][np.argmin(diffs)]
        else:
            bar_group = bar_groups[0]

        # Animate selection
        word = bar_group[0].get_text()
        new_str = self.cur_str + word
        new_text_mob = self.string_to_mob(new_str)
        new_text_mob[:len(self.seed_text.replace(" ", ""))].set_color(self.seed_text_color)

        word_targets = new_text_mob[word.strip()]
        if len(word_targets) > 0:
            target = word_targets[-1]
        else:
            # Fall back to the trailing characters (before the hidden
            # alignment character) when the word cannot be looked up
            target = new_text_mob[-len(word) - 1:-1]

        self.add(bar_groups)
        self.play(
            FadeTransform(bar_group[0].copy(), target),
            Transform(
                next_word_line,
                self.get_next_word_line(new_text_mob),
            ),
        )
        if force_unskip:
            # Briefly render a highlight on the new word even while the
            # rest of the cycle is being skipped
            self.skip_animations = False
            target.save_state()
            target.set_fill(YELLOW)
            self.wait(0.5)
            target.restore()
            self.skip_animations = True
        self.play(
            FadeOut(bar_groups),
        )

        self.remove(text_mob)
        self.add(new_text_mob)

        self.cur_str = new_str

        return new_text_mob

    def new_selection_cycle(self, text_mob, next_word_line, machine, quick=False, skip_anims=False):
        """
        Run one predict / sample / append cycle and return the new text
        mobject.

        With `quick`, the chart is added directly without the text-input
        and fade-in animations. With `skip_anims`, `self.skip_animations`
        is switched on for the cycle.
        """
        if skip_anims:
            self.skip_animations = True

        if quick:
            words, probs = self.predict_next_token(self.cur_str)
            bar_groups = self.get_distribution(words, probs, machine)
            self.add(bar_groups)
        else:
            self.animate_text_input(text_mob, machine)
            bar_groups = self.animate_prediction_ouptut(machine, self.cur_str)
        # Sampling happens in both modes so the chosen token is random
        self.animate_random_sample(bar_groups)
        new_text_mob = self.animate_word_addition(
            bar_groups, text_mob, next_word_line,
            force_unskip=skip_anims
        )
        return new_text_mob

    def predict_next_token(self, text, n_shown=None):
        """
        Return `(tokens, probs)` for the most likely continuations of
        `text`.

        `n_shown` is the number of candidates requested; it defaults to
        `self.n_shown_predictions`. When `self.model` is "gpt3" the OpenAI
        backend is tried first; if it raises, a warning is issued and the
        local GPT-2 model is used instead.
        """
        import warnings

        if n_shown is None:
            n_shown = self.n_shown_predictions

        result = None
        if self.model == "gpt3":
            try:
                result = gpt3_predict_next_token(
                    text, n_shown, random_seed=self.random_seed
                )
            except Exception as error:
                # Best-effort: keep rendering with GPT-2, but say why
                warnings.warn(
                    f"GPT-3 prediction failed ({error!r}); falling back to GPT-2"
                )
        if result is None:
            result = gpt2_predict_next_token(text, n_shown)
        return result
class AltSimpleAutoRegression(SimpleAutogregression):
    """
    Single-prediction variant of SimpleAutogregression with a narrower
    text column.
    """
    n_predictions = 1
    line_len = 25

    def reposition_transformer_drawing(self, machine):
        """
        Move `machine` slightly right of center and add a rotated copy of
        its last part (the output arrow) at its upper-left as an input
        arrow. Returns `machine`.
        """
        machine.move_to(0.5 * RIGHT)

        incoming = machine[-1].copy()
        incoming.rotate(-45 * DEGREES)
        incoming.next_to(machine, UL)
        self.add(incoming)

        return machine
class AnnotateNextWord(SimpleAutogregression):
    """
    Scene that first points at the input text, then at the predicted
    distribution, and finally runs one selection cycle.
    """

    def construct(self):
        text_mob, next_word_line, machine = self.init_text_and_machine()
        self.add(machine, *machine[1:])
        words, probs = self.predict_next_token(self.cur_str)
        # get_distribution takes the machine itself and reads
        # machine[-1].get_right() internally, so pass the mobject rather
        # than a precomputed point
        bar_groups = self.get_distribution(words, probs, machine)
        self.add(bar_groups)

        # Initial text
        from manimlib.mobject.boolean_ops import Union
        highlight = Union(
            SurroundingRectangle(text_mob["Behold, a wild pi creature,"]),
            SurroundingRectangle(text_mob["foraging in its native"]),
        )
        highlight.set_stroke(BLUE, 3)
        arrow = Vector(LEFT, stroke_width=10)
        arrow.next_to(highlight, RIGHT).match_y(text_mob[0])

        dist_rect = SurroundingRectangle(bar_groups)
        dist_rect.set_stroke(YELLOW, 2)

        self.play(
            ShowCreation(highlight),
            GrowArrow(arrow)
        )
        self.wait()
        # Move attention from the text to the distribution
        self.play(
            arrow.animate.rotate(PI / 2).next_to(dist_rect, UP),
            ReplacementTransform(highlight, dist_rect),
        )
        self.wait()
        self.play(
            FadeOut(dist_rect),
            FadeOut(arrow),
        )

        # Flash through
        self.remove(bar_groups)
        text_mob = self.new_selection_cycle(
            text_mob, next_word_line, machine,
        )
class QuickerRegression(SimpleAutogregression):
    """SimpleAutogregression with `skip_through` switched on."""
    skip_through = True
class AutoregressionGPT3(SimpleAutogregression):
    """SimpleAutogregression with `model` set to "gpt3"."""
    model = "gpt3"
class QuickRegressionGPT3(SimpleAutogregression):
    """SimpleAutogregression with `model` set to "gpt3" and `skip_through` on."""
    model = "gpt3"
    skip_through = True
class GPT3CleverestAutocomplete(QuickRegressionGPT3):
    """
    Autocomplete scene with its own seed text; the first three cycles
    are animated, the rest request skipped animations.
    """
    seed_text = "To date, the cleverest thinker of all time was"
    n_predictions = 70

    def construct(self):
        # Test
        text_mob, next_word_line, machine = self.init_text_and_machine()
        for step in range(self.n_predictions):
            past_intro = (step > 2)
            text_mob = self.new_selection_cycle(
                text_mob,
                next_word_line,
                machine,
                skip_anims=past_intro,
            )
class GPT3OnLearningSimpler(QuickRegressionGPT3):
    """
    Text-only generation scene: each step samples a token, re-renders the
    whole passage with the new word highlighted, and optionally shows the
    distribution it was sampled from.
    """
    seed_text = "The most effective way to learn computer science is"
    text_corner = 3.5 * UP + 3 * LEFT
    line_len = 35
    font_size = 35
    n_predictions = 300
    time_per_prediction = 0.2
    random_seed = 313
    model = "gpt3"
    min_y = -3
    up_shift = 5 * UP
    show_dist = False
    # Read in construct(); defined here so the attribute always exists.
    # Subclasses may set it to False to leave the seed text uncolored.
    color_seed = True

    def construct(self):
        # Test
        cur_str = self.seed_text
        for n in range(self.n_predictions):
            self.clear()
            # NOTE(review): relies on predict_next_token accepting an
            # `n_shown` keyword -- confirm against the base class signature
            words, probs = self.predict_next_token(cur_str, n_shown=20)
            index = np.random.choice(np.arange(len(words)), p=(probs / probs.sum()))
            new_word = words[index]
            cur_str += new_word
            text_mob = self.string_to_mob(cur_str)

            # Color seed
            if self.color_seed:
                text_mob[:len(self.seed_text.replace(" ", ""))].set_color(BLUE)

            # Add to text, shift if necessary
            new_word_matches = text_mob[new_word.strip()]
            if len(new_word_matches) > 0:
                # Skip the highlight when the token cannot be found in the
                # rendered text (e.g. it is only whitespace)
                new_word_matches[-1].set_color(YELLOW)
            if text_mob.get_bottom()[1] < self.min_y:
                text_mob.shift(self.up_shift)
                # Rebind rather than `+=`: text_corner is a class-level
                # array, and an in-place add would mutate it for every
                # scene sharing that attribute
                self.text_corner = self.text_corner + self.up_shift
            self.add(text_mob)

            # Add the distribution
            if self.show_dist:
                # NOTE(review): relies on get_distribution accepting a call
                # without `machine` and with a `buff` keyword -- confirm
                # against the base class signature
                dist = self.get_distribution(
                    words[:self.n_shown_predictions],
                    probs[:self.n_shown_predictions],
                    buff=0
                )
                dist.set_height(4)
                dist.to_edge(DOWN)

                # Clamp so a sample beyond the shown rows marks the last element
                rect = SurroundingRectangle(dist[min(index, len(dist) - 1)])
                self.add(dist, rect)

            self.wait(self.time_per_prediction)
class GPT3OnLongPassages(GPT3OnLearningSimpler):
    """GPT3OnLearningSimpler with a different seed, fewer steps, and seed coloring off."""
    seed_text = "Writing long passages seems to involve more foresight and planning than what single-word prediction"
    color_seed = False
    n_predictions = 100
class GPT3CreaturePrediction(GPT3CleverestAutocomplete):
    """Single-step prediction from a short creature prompt."""
    n_predictions = 1
    seed_text = "the fluffy blue creature"
class GPT3CreaturePrediction2(GPT3CleverestAutocomplete):
    """Single-step prediction from a longer creature prompt."""
    n_predictions = 1
    seed_text = "the fluffy blue creature roamed the"
class LowTempExample(GPT3OnLearningSimpler):
    """
    Generation scene whose sampling distribution is reshaped by a
    temperature `temp`; this class uses temperature zero.
    """
    seed_text = "Once upon a time, there was a"
    model = "gpt3"
    min_y = 1
    up_shift = 2 * UP
    show_dist = True
    temp = 0
    n_predictions = 200
    time_per_prediction = 0.25

    def predict_next_token(self, text, n_shown=None):
        """
        Return the parent's predictions with probabilities adjusted for
        `self.temp`.

        At temperature zero all mass goes on the first candidate.
        Otherwise each probability is raised to `1 / temp` and the result
        is renormalized to sum to one.
        """
        # NOTE(review): forwards `n_shown` positionally -- confirm the
        # inherited predict_next_token accepts a second argument
        words, probs = super().predict_next_token(text, n_shown)

        if self.temp == 0:
            one_hot = np.zeros_like(probs)
            one_hot[0] = 1
            return words, one_hot

        reshaped = probs**(1 / self.temp)
        reshaped /= reshaped.sum()
        return words, reshaped
class HighTempExample(LowTempExample):
    """LowTempExample run at temperature 5."""
    model = "gpt3"
    temp = 5
class MidTempExample(LowTempExample):
    """LowTempExample run at temperature 1 with its own seed text."""
    model = "gpt3"
    temp = 1
    seed_text = "If you could see the underlying probability distributions a large language model uses when generating text, then"
class ChatBotPrompt(SimpleAutogregression):
    """
    Scene showing a system prompt, a user prompt and an assistant seed,
    then generating the assistant's reply token by token.
    """
    system_prompt = """
What follows is a conversation between a user and a helpful,
very knowledgeable AI assistant.
"""
    user_prompt = "User: Give me some ideas for what to do when visiting Paris."
    ai_seed = "AI Assistant: "
    machine_name = "Large\nLanguage\nModel"

    line_len = 28
    font_size = 36
    color_seed = False

    n_predictions = 60
    model = "gpt3"
    random_seed = 12

    def construct(self):
        # Test
        text_mob, next_word_line, machine = self.init_text_and_machine()

        all_strs = list(map(clean_text, [self.system_prompt, self.user_prompt, self.ai_seed]))
        system_prompt, user_prompt, ai_seed = all_text = VGroup(
            get_paragraph(
                s.split(" "),
                font_size=self.font_size,
                line_len=self.line_len
            )
            for s in all_strs
        )
        all_text.arrange(DOWN, aligned_edge=LEFT, buff=0.75)
        all_text.move_to(self.text_corner, UL)

        # Replace the default seed text with the three-part prompt
        self.remove(text_mob)
        self.add(all_text)
        text_mob = ai_seed
        # New text is laid out from where the assistant seed starts
        self.text_corner = text_mob.get_corner(UL)
        next_word_line.next_to(ai_seed, RIGHT, aligned_edge=DOWN)

        # The model sees all three parts, separated by blank lines
        self.cur_str = "\n\n".join(all_strs)

        # Comment on system prompt
        sys_rect = SurroundingRectangle(system_prompt)
        sys_rect.set_stroke(GREEN, 2)

        self.play(
            ShowCreation(sys_rect),
            system_prompt.animate.set_color(GREEN_B)
        )
        self.wait()

        # Users prompt
        from manimlib.mobject.boolean_ops import Union

        # These substrings must match the text of `user_prompt` above;
        # the second one previously named a different city and so did
        # not occur in the rendered prompt
        top_line = user_prompt["Give me some ideas for what"]
        low_line = user_prompt["to do when visiting Paris."]
        user_rect = Union(
            SurroundingRectangle(low_line),
            SurroundingRectangle(top_line),
        )
        user_rect.set_stroke(BLUE, 2)

        sys_rect.insert_n_curves(100)
        self.play(
            ReplacementTransform(sys_rect, user_rect),
            top_line.animate.set_color(BLUE_B),
            low_line.animate.set_color(BLUE_B),
        )
        self.wait()
        self.play(
            FadeOut(user_rect),
        )

        # Run predictions
        text_mob = all_text
        # Keep a copy on screen, since each cycle removes the mobject
        # passed in as text_mob
        self.add(all_text.copy())
        for n in range(self.n_predictions):
            text_mob = self.new_selection_cycle(
                text_mob, next_word_line, machine,
                skip_anims=(n > 0),
            )

    def string_to_mob(self, text):
        """
        Render only the portion of `text` from the assistant seed onward,
        so the prompts above it are not re-rendered each cycle.
        """
        seed = self.ai_seed.strip()
        if seed in text:
            text = text[text.index(seed):]
        return super().string_to_mob(text)
class ChatBotPrompt2(ChatBotPrompt):
    """ChatBotPrompt with a user prompt about softmax temperature."""
    user_prompt = "User: Can you explain what temperature is, in the context of softmax?"
class ChatBotPrompt3(ChatBotPrompt):
    """ChatBotPrompt with a user prompt about visiting Munich."""
    user_prompt = "User: Can you give me some ideas for what to do while visiting Munich?"
class VoiceToTextExample(SimpleAutogregression):
    """
    Generic "input -> boxed model -> output" scene. Subclasses override
    `model_name`, `get_input` and `get_output`.
    """
    model_name = "voice-to-text"

    def construct(self):
        # Add model
        frame_box = Rectangle(4, 3)
        frame_box.set_stroke(WHITE, 2)

        title = Text(self.model_name, font_size=60)
        title.set_max_width(frame_box.get_width())
        title.next_to(frame_box, UP)

        machine = self.get_transformer_drawing()
        machine.center()
        machine.set_max_width(0.75 * frame_box.get_width())
        machine.move_to(frame_box)

        # One arrow into the box, one out of it
        arrows = Vector(0.75 * RIGHT, stroke_width=8).replicate(2)
        arrows[0].next_to(frame_box, LEFT, SMALL_BUFF)
        arrows[1].next_to(frame_box, RIGHT, SMALL_BUFF)

        model = Group(frame_box, title, arrows, machine)
        self.add(*model)
        self.add(Point())

        # Process input
        max_width = 3.75
        in_mob = self.get_input().set_max_width(max_width)
        out_mob = self.get_output().set_max_width(max_width)
        in_mob.next_to(arrows, LEFT)
        out_mob.next_to(arrows, RIGHT)

        self.add(in_mob)

        absorb_input = FadeOutToPoint(
            in_mob.copy(), machine.get_left(),
            path_arc=-45 * DEGREES,
            lag_ratio=0.01,
        )
        # Every block except the front one flashes teal
        flash_blocks = LaggedStart(
            (block.animate.set_color(TEAL).set_anim_args(rate_func=there_and_back)
            for block in machine[0][:-1]),
            lag_ratio=0.1,
            run_time=1,
        )
        emit_output = FadeInFromPoint(
            out_mob.copy(), machine.get_right(),
            path_arc=45 * DEGREES,
            lag_ratio=0.02
        )
        self.play(LaggedStart(
            absorb_input,
            flash_blocks,
            emit_output,
            lag_ratio=0.7
        ))
        self.wait()

    def get_input(self) -> Mobject:
        """Return the "AudioSnippet" image, stretched to 3.75 wide by 3 tall."""
        snippet = ImageMobject("AudioSnippet").set_width(3.75)
        snippet.set_height(3, stretch=True)
        return snippet

    def get_output(self) -> Mobject:
        """Return the left-aligned caption text shown as the model's output."""
        return Text("""
Some models take
in audio and
produce a transcript
""", alignment="LEFT")
class TextToVoiceExample(VoiceToTextExample):
    """
    Reverse of VoiceToTextExample: text goes in, and the parent's input
    mobject is reused as the output.
    """
    model_name = "text-to-voice"

    def get_input(self):
        """Return the left-aligned sentence shown as the model's input."""
        sentence = Text("""
This sentence comes from
a model going the other
way around, producing
synthetic speech just
from text.
""", alignment="LEFT")
        return sentence

    def get_output(self):
        """Return what the parent class uses as its input."""
        parent_input = super().get_input()
        return parent_input
class TextToImage(VoiceToTextExample):
    """
    Text-prompt-in, image-out version of the boxed-model scene, plus a
    helper that requests images from the OpenAI image endpoints.
    """
    model_name = "text-to-image"
    prompt = """
1960s photograph of a cute fluffy blue wild pi
creature, a creature whose body is shaped like
the symbol π, who is foraging in its native territory,
staring back at the camera with an exotic scene
in the background.
"""
    image_name = "PiCreatureDalle3_5"

    def get_clean_prompt(self):
        """Return `self.prompt` with whitespace normalized by clean_text."""
        return clean_text(self.prompt)

    def get_input(self):
        """Return the cleaned prompt laid out as a paragraph."""
        return get_paragraph(self.get_clean_prompt().split(" "), line_len=25)

    def get_output(self):
        """Return the image named by `self.image_name`."""
        return ImageMobject(self.image_name)

    def generate_output(self):
        """
        Request a generated image for a prompt and print the prompt and
        resulting URL, then request a variation of a local image file.

        Overwrites `self.prompt` on the instance. The API key is read
        from the OPENAI_KEY environment variable.
        """
        # Test
        # This first prompt is immediately replaced by the one below
        self.prompt = """
1960s photograph of a cute fluffy blue wild pi
creature, a creature whose face bears a subtle resemblence
to the shape of the symbol π, who is foraging in its native
territory, staring back at the camera with an exotic scene
in the background.
"""
        self.prompt = "abstract depiction of furry fluffiness"

        openai.api_key = os.getenv('OPENAI_KEY')
        prompt = self.get_clean_prompt()

        response = openai.Image.create(
            model="dall-e-3",
            prompt=prompt,
            size="1024x1024",
            quality="standard",
            n=1,
        )
        image_url = response.data[0].url
        print(prompt)
        print(image_url)

        # Open the source image in a context manager so the file handle
        # is closed once the request has been made
        source_path = "/Users/grant/3Blue1Brown Dropbox/3Blue1Brown/images/raster/PiCreatureDalle3_17.png"
        with open(source_path, "rb") as source_image:
            response = openai.Image.create_variation(
                image=source_image,
                n=1,
                size="1024x1024"
            )
class TranslationExample(VoiceToTextExample):
    """Boxed-model scene with English text in and Chinese text out."""
    model_name = "machine translation"

    def get_input(self):
        """Return the English source phrase."""
        source_phrase = Text("Attention is all\nyou need")
        return source_phrase

    def get_output(self):
        """Return a Group holding a Point followed by each character of the translated text."""
        translated = Text("注意力就是你所需要的一切")
        return Group(Point(), *translated)
class PredictionVsGeneration(SimpleAutogregression):
    """
    Scene with a pi creature presenting a next-token prediction for a
    fixed seed sentence.
    """
    model = "gpt2"

    def construct(self):
        # Setup
        self.add(FullScreenRectangle())
        morty = Mortimer()
        morty.to_edge(DOWN)
        morty.body.insert_n_curves(100)
        self.add(morty)

        # Words
        # These titles are positioned but not added to the scene here
        titles = VGroup(Text("Prediction"), Text("Generation"))
        titles.scale(1.5)
        for vect, word in zip([UL, UR], titles):
            word.next_to(morty, vect)
            word.shift(0.5 * UP)

        # Create prediction object
        seed_text = "The goal of predicting the next"
        self.n_shown_predictions = 8
        tokens, probs = self.predict_next_token(seed_text)
        # get_distribution positions the chart relative to a machine
        # drawing; this scene has none, so hand it a placeholder group.
        # The chart is repositioned below, so the anchor does not matter.
        placeholder_machine = VectorizedPoint().replicate(3)
        dist = self.get_distribution(tokens, probs, placeholder_machine)
        brace = Brace(dist, LEFT, SMALL_BUFF)
        words = Text(seed_text, font_size=36).next_to(brace, LEFT)
        prediction = VGroup(words, brace, dist)
        prediction.set_width(FRAME_WIDTH / 2 - 1)
        prediction.next_to(morty, UL)
        prediction.shift(0.5 * UP).shift_onto_screen()

        self.add(prediction)

        # Animations
        self.play(
            morty.change("raise_right_hand", prediction),
            FadeIn(prediction[0], UP),
            GrowFromCenter(prediction[1]),
            LaggedStart(
                (FadeInFromPoint(bar, prediction[1].get_center())
                for bar in prediction[2]),
                lag_ratio=0.05,
            )
        )
        self.play(Blink(morty))
        self.play(
            morty.change("raise_left_hand", 3 * UR),
        )
        self.wait()
        self.play(Blink(morty))
        self.wait()
class ManyParallelPredictions(SimpleAutogregression):
    """
    Scene stepping through the prefixes of a sentence, showing the
    predicted next-token distribution for each prefix in turn.
    """
    line_len = 200
    n_shown_predictions = 8
    model = "gpt3"

    def construct(self):
        # Setup
        # Stand-in for the machine drawing, used to position distributions
        self.fake_machine = VectorizedPoint().replicate(3)
        full_string = "Harry Potter was a highly unusual boy"

        # Draw last layer vectors
        last_layer = VGroup(
            NumericEmbedding(length=10)
            for n in range(12)
        )
        last_layer.arrange(RIGHT, buff=0.35 * last_layer[0].get_width())
        last_layer.set_height(3)
        last_layer.to_edge(DOWN)
        # The vectors themselves are used only for layout and are not added

        rects = VGroup(map(SurroundingRectangle, last_layer))
        rects.set_stroke(YELLOW, 2)
        arrows = VGroup(Vector(0.5 * UP).next_to(rect, UP, buff=0.1) for rect in rects)
        arrows.set_stroke(YELLOW)

        # Show prediction groups
        words = full_string.split(" ")
        # One prefix per word: "Harry", "Harry Potter", ...
        substrings = [
            " ".join(words[:n + 1])
            for n in range(len(words))
        ]
        predictions = VGroup(
            self.get_prediction_group(substring)
            for substring in substrings
        )
        predictions[0].to_edge(UP, buff=1.25).align_to(rects[1], LEFT)
        for prediction, arrow, rect in zip(predictions, arrows, rects):
            prediction.move_to(predictions[0], LEFT)
            # Point each arrow from its rectangle to its prediction's brace
            arrow.become(Arrow(
                rect.get_top(),
                prediction[1].get_left(),
            ))
            arrow.set_stroke(YELLOW)

        # Invisible starting state for the first transition
        last_group = VGroup(
            rects[0].copy().set_opacity(0),
            arrows[0].copy().set_opacity(0),
            predictions[0].copy().set_opacity(0),
        )
        for rect, arrow, prediction in zip(rects, arrows, predictions):
            self.remove(last_group)
            self.play(
                TransformFromCopy(last_group[0], rect),
                TransformFromCopy(last_group[1], arrow),
                TransformMatchingStrings(last_group[2][0].copy(), prediction[0], run_time=1),
                FadeTransform(last_group[2][1].copy(), prediction[1]),
                FadeTransform(last_group[2][2].copy(), prediction[2]),
            )
            self.wait()
            last_group = VGroup(rect, arrow, prediction)

    def get_prediction_group(self, text):
        """
        Return VGroup(prefix text, brace, distribution) for the next-token
        prediction following `text`.

        Expects `self.fake_machine` to have been set (construct does this
        before calling).
        """
        words, probs = self.predict_next_token(text)
        # get_distribution needs a machine to position against; pass the
        # stand-in created in construct
        dist = self.get_distribution(
            words, probs, self.fake_machine,
            width_100p=2.0
        )
        dist.set_max_height(2.5)
        brace = Brace(dist, LEFT)
        prefix = Text(text, font_size=30)
        prefix.next_to(brace, LEFT)

        result = VGroup(prefix, brace, dist)

        return result
class PeekUnderTheHood(SimpleAutogregression):
    """
    Scene that zooms in on the machine's blocks and places a grid of
    weight matrices inside the stack.
    """

    def construct(self):
        # Add parts
        text_mob, next_word_line, machine = self.init_text_and_machine()
        blocks, label, arrow = machine
        self.remove(text_mob, next_word_line)

        # Zoom in
        self.camera.light_source.move_to([-15, 5, 10])
        self.set_floor_plane("xz")
        blocks.rotate(-5 * DEGREES, UP, about_edge=OUT)
        blocks.rotate(-10 * DEGREES, RIGHT, about_edge=OUT)

        # Target state: enlarged, centered, with the blocks from index 5
        # onward mostly transparent
        enlarged = blocks.generate_target()
        blocks.target = enlarged
        enlarged.set_height(5)
        enlarged.center()
        enlarged[5:].set_opacity(0.3)

        zoom_time = 3
        self.play(
            self.frame.animate.reorient(-23, -12, 0, (1.79, -0.56, 1.27), 8.40).set_anim_args(run_time=zoom_time),
            MoveToTarget(blocks, run_time=zoom_time),
            FadeOut(arrow, RIGHT),
            FadeOut(label, 2 * OUT),
        )
        self.wait()

        blocks[5:].set_opacity(0.3)

        # Add matrices
        matrices = VGroup(WeightMatrix(shape=(8, 8)) for x in range(9))
        matrices.arrange_in_grid(h_buff_ratio=0.25, v_buff_ratio=0.4)
        matrices.match_width(blocks)
        index = 6
        matrices.move_to(blocks[index], OUT)
        # Add the matrices first, then the blocks from `index` onward
        self.add(matrices, blocks[index:])