from manim_imports_ext import *
from _2024.transformers.helpers import *

from transformers import GPT2Tokenizer
from transformers import GPT2LMHeadModel
from transformers import PreTrainedModel
import torch
import openai
import tiktoken

# Standard-library and numpy names used below; these may already be provided
# by the star imports above, but are spelled out here for clarity.
import os
import random
import re
from functools import lru_cache

import numpy as np


@lru_cache(maxsize=1)
def get_gpt2_tokenizer(model_name='gpt2'):
    return GPT2Tokenizer.from_pretrained(model_name)


@lru_cache(maxsize=1)
def get_gpt2_model(model_name='gpt2'):
    return GPT2LMHeadModel.from_pretrained(model_name)


def gpt2_predict_next_token(text, n_shown=7):
    tokenizer = get_gpt2_tokenizer()
    model = get_gpt2_model()
    # Encode the input text
    indexed_tokens = tokenizer.encode(
        text, add_special_tokens=False, return_tensors='pt'
    )

    # Predict all tokens
    with torch.no_grad():
        outputs = model(indexed_tokens)
        # Pull out the first batch, and the last token prediction
        predictions = outputs[0][0, -1, :]

    # Get the predicted next token
    indices = torch.argsort(predictions)
    top_indices = reversed(indices[-n_shown:])
    tokens = list(map(tokenizer.decode, top_indices))
    probs = softmax(predictions)[top_indices]

    return tokens, probs

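# Rough usage sketch for the helper above (the actual tokens and probabilities
# depend on the downloaded GPT-2 weights):
#
#   tokens, probs = gpt2_predict_next_token("Behold, a wild pi creature", n_shown=3)
#   # tokens is a list of 3 decoded token strings, most likely first;
#   # probs holds the matching softmax probabilities in the same order.
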
def gpt3_predict_next_token(text, n_shown=10, random_seed=0):
    openai.api_key = os.getenv('OPENAI_KEY')
    response = openai.Completion.create(
        # Or another model version, adjust as necessary
        engine="gpt-3.5-turbo-instruct",
        prompt=text,
        max_tokens=1,
        n=1,
        temperature=1.0,
        user=str(random_seed),
        logprobs=50  # I think this is actually set to a max of 20?
    )
    top_logprob_dict = response.choices[0]["logprobs"]["top_logprobs"][0]
    tokens, logprobs = zip(*top_logprob_dict.items())
    probs = np.exp(logprobs)
    indices = np.argsort(-probs)
    shown_tokens = [tokens[i] for i in indices[:n_shown]]
    return shown_tokens, probs[indices[:n_shown]]

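# Note: gpt3_predict_next_token assumes the legacy (pre-1.0) openai client and
# an OPENAI_KEY environment variable. If the request fails for any reason,
# predict_next_token in the scenes below quietly falls back to local GPT-2.
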
def clean_text(text):
    return " ".join(filter(lambda s: s.strip(), re.split(r"\s", text)))

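# For reference: clean_text collapses all runs of whitespace to single spaces,
# e.g. clean_text("a\n  b\tc") == "a b c".
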
def next_token_bar_chart(
    words, probs,
    reference_point=ORIGIN,
    font_size=24,
    width_100p=1.0,
    prob_exp=0.75,
    bar_height=0.25,
    bar_space_factor=0.5,
    buff=1.2,
    show_ellipses=True,
    use_percent=True,
):
    labels = VGroup(Text(word, font_size=font_size) for word in words)
    bars = VGroup(
        Rectangle(prob**(prob_exp) * width_100p, bar_height)
        for prob, label in zip(probs, labels)
    )
    bars.arrange(DOWN, aligned_edge=LEFT, buff=bar_space_factor * bar_height)
    bars.set_fill(opacity=1)
    bars.set_submobject_colors_by_gradient(TEAL, YELLOW)
    bars.set_stroke(WHITE, 1)

    bar_groups = VGroup()
    for label, bar, prob in zip(labels, bars, probs):
        if use_percent:
            prob_label = Integer(int(100 * prob), unit="%", font_size=0.75 * font_size)
        else:
            prob_label = DecimalNumber(prob, font_size=0.75 * font_size)
        prob_label.next_to(bar, RIGHT, buff=SMALL_BUFF)
        label.next_to(bar, LEFT)
        bar_groups.add(VGroup(label, bar, prob_label))

    if show_ellipses:
        ellipses = Tex(R"\vdots", font_size=font_size)
        ellipses.next_to(bar_groups[-1][0], DOWN)
        bar_groups.add(ellipses)

    bar_groups.shift(reference_point - bars.get_left() + buff * RIGHT)

    return bar_groups

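# Example (sketch): next_token_bar_chart(["dog", "cat"], [0.6, 0.4]) returns a
# VGroup of (word label, bar, percent label) rows plus a trailing ellipsis,
# with bar widths scaled by prob**prob_exp and anchored `buff` to the right of
# `reference_point`.
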
class SimpleAutogregression(InteractiveScene):
    text_corner = 3.5 * UP + 0.75 * RIGHT
    line_len = 31
    font_size = 35
    n_shown_predictions = 12
    seed_text = "Behold, a wild pi creature, foraging in its native"
    seed_text_color = BLUE_B
    machine_name = "Transformer"
    machine_phi = 10 * DEGREES
    machine_theta = 12 * DEGREES
    n_predictions = 120
    skip_through = False
    random_seed = 0
    model = "gpt2"

    def construct(self):
        # Repeatedly generate
        text_mob, next_word_line, machine = self.init_text_and_machine()
        for n in range(self.n_predictions):
            text_mob = self.new_selection_cycle(
                text_mob, next_word_line, machine,
                quick=(n > 10),
                skip_anims=self.skip_through,
            )

    def init_text_and_machine(self):
        # Set up active text
        self.cur_str = self.seed_text
        text_mob = self.string_to_mob(self.cur_str)
        text_mob.set_color(self.seed_text_color)
        next_word_line = self.get_next_word_line(text_mob)

        # Set up Transformer as some sort of machine
        machine = self.get_transformer_drawing()
        machine.set_y(0).to_edge(LEFT, buff=-0.6)

        self.add(text_mob)
        self.add(next_word_line)
        self.add(machine)

        return text_mob, next_word_line, machine

    def string_to_mob(self, text):
        text += " l"  # Dumb hack for alignment
        result = get_paragraph(
            text.replace("\n", " ").split(" "),
            self.line_len,
            self.font_size
        )
        result.move_to(self.text_corner, UL)
        result[-1].set_fill(BLACK, 0)  # Continue dumb hack
        result[-1].stretch(0, 0, about_edge=LEFT)
        return result

    def get_next_word_line(self, text_mob, char_len=7):
        next_word_line = Underline(text_mob[:char_len])
        next_word_line.set_stroke(TEAL, 2)
        next_word_line.next_to(text_mob[-1], RIGHT, SMALL_BUFF, aligned_edge=DOWN)
        if self.skip_through:
            next_word_line.set_opacity(0)
        return next_word_line

    def get_transformer_drawing(self):
        self.camera.light_source.move_to([-5, 5, 10])
        self.frame.set_field_of_view(20 * DEGREES)
        blocks = VGroup(
            VPrism(3, 2, 0.2)
            for n in range(10)
        )
        blocks.set_fill(GREY_D, 1)
        blocks.set_stroke(width=0)
        blocks.set_shading(0.25, 0.5, 0.2)
        blocks.arrange(OUT)
        blocks.move_to(ORIGIN, OUT)
        blocks.rotate(self.machine_phi, RIGHT, about_edge=OUT)
        blocks.rotate(self.machine_theta, UP, about_edge=OUT)

        blocks.deactivate_depth_test()
        for block in blocks:
            block.sort(lambda p: p[2])

        word = Text(self.machine_name, alignment="LEFT")
        word.next_to(blocks[-1], UP)
        word.shift(0.1 * UP + 0.4 * LEFT)
        word.move_to(blocks[-1])
        word.set_backstroke(BLACK, 5)
        out_arrow = Vector(
            0.5 * RIGHT, stroke_width=10,
            max_tip_length_to_length_ratio=0.5,
            max_width_to_length_ratio=12
        )
        out_arrow.next_to(blocks[-1], RIGHT, buff=SMALL_BUFF)
        out_arrow.set_opacity(0)

        result = VGroup(blocks, word, out_arrow)
        return result

    def get_distribution(
        self, words, probs, machine=None,
        font_size=24,
        width_100p=1.8,
        bar_height=0.25,
        buff=1.5,
        show_ellipses=True
    ):
        labels = VGroup(Text(word, font_size=font_size) for word in words)
        bars = VGroup(
            Rectangle(prob * width_100p, bar_height)
            for prob, label in zip(probs, labels)
        )
        bars.arrange(DOWN, aligned_edge=LEFT, buff=0.5 * bar_height)
        bars.set_fill(opacity=1)
        bars.set_submobject_colors_by_gradient(TEAL, YELLOW)
        bars.set_stroke(WHITE, 1)

        bar_groups = VGroup()
        for label, bar, prob in zip(labels, bars, probs):
            prob_label = Integer(int(100 * prob), unit="%", font_size=0.75 * font_size)
            prob_label.next_to(bar, RIGHT, buff=SMALL_BUFF)
            label.next_to(bar, LEFT)
            bar_groups.add(VGroup(label, bar, prob_label))

        if show_ellipses:
            ellipses = Tex(R"\vdots", font_size=font_size)
            ellipses.next_to(bar_groups[-1][0], DOWN)
            bar_groups.add(ellipses)

        # Position next to the machine drawing, if one is given. Some scenes
        # below build the distribution on its own and place it themselves.
        if machine is not None:
            arrow_point = machine[-1].get_right()
            bar_groups.shift(arrow_point - bars.get_left() + buff * RIGHT)
            bar_groups.align_to(machine, UP)

        return bar_groups

    def animate_text_input(self, text_mob, machine, position_text_over_machine=True, added_anims=[], lag_ratio=0.02):
        blocks = machine[0]
        text_copy = text_mob.copy()
        if position_text_over_machine:
            text_copy.target = text_copy.generate_target()
            text_copy.target.set_max_width(4)
            text_copy.target.next_to(blocks[0], UP)
            text_copy.target.shift_onto_screen()
            self.play(MoveToTarget(text_copy, path_arc=-45 * DEGREES))
        self.play(LaggedStart(
            *added_anims,
            Transform(
                text_copy,
                VGroup(VectorizedPoint(machine.get_top())),
                lag_ratio=lag_ratio,
                run_time=1,
                path_arc=-45 * DEGREES,
                remover=True,
            ),
            LaggedStart(
                (
                    block.animate.set_color(
                        block.get_color() if block is blocks[-1] else TEAL
                    ).set_anim_args(rate_func=there_and_back)
                    for block in blocks
                ),
                lag_ratio=0.1,
                run_time=1
            ),
            Animation(machine[1:]),
            lag_ratio=0.5
        ))

    def animate_prediction_ouptut(self, machine, cur_str):
        words, probs = self.predict_next_token(cur_str)
        bar_groups = self.get_distribution(words, probs, machine)
        self.play(
            LaggedStart(
                (FadeInFromPoint(bar_group, machine[0][-1].get_right())
                for bar_group in bar_groups),
                lag_ratio=0.025,
                group=bar_groups,
                run_time=1
            )
        )
        return bar_groups

    def animate_random_sample(self, bar_groups):
        widths = np.array([group[1].get_width() for group in bar_groups[:-1]])
        dist = widths / widths.sum()
        seed = random.randint(0, 1000)
        buff = 0.025
        highlight_rect = SurroundingRectangle(bar_groups[0], buff=buff)
        highlight_rect.set_stroke(YELLOW, 2)
        highlight_rect.set_fill(YELLOW, 0.25)

        def highlight_randomly(rect, dist, alpha):
            np.random.seed(seed + int(10 * alpha))
            index = np.random.choice(np.arange(len(dist)), p=dist)
            rect.surround(bar_groups[index], buff=buff)
            rect.stretch(1.1, 0)

        self.play(
            UpdateFromAlphaFunc(highlight_rect, lambda rect, a: highlight_randomly(rect, dist, a)),
            Animation(bar_groups)
        )

        bar_groups.add_to_back(highlight_rect)

    def animate_word_addition(self, bar_groups, text_mob, next_word_line, force_unskip=False):
        # Choose the highlighted group
        bar_group = None
        if isinstance(bar_groups[0], Rectangle):
            # Use the highlight rect to find the group element
            bars = bar_groups[1:-1]
            diffs = [abs(bg.get_y() - bar_groups[0].get_y()) for bg in bars]
            bar_group = bar_groups[1:][np.argmin(diffs)]
        if bar_group is None:
            bar_group = bar_groups[0]

        # Animate selection
        word = bar_group[0].get_text()
        new_str = self.cur_str + word
        new_text_mob = self.string_to_mob(new_str)
        new_text_mob[:len(self.seed_text.replace(" ", ""))].set_color(self.seed_text_color)

        word_targets = new_text_mob[word.strip()]
        if len(word_targets) > 0:
            target = word_targets[-1]
        else:
            target = new_text_mob[-len(word) - 1:-1]

        # target = new_text_mob[-len(word):]

        self.add(bar_groups)
        self.play(
            FadeTransform(bar_group[0].copy(), target),
            Transform(
                next_word_line,
                self.get_next_word_line(new_text_mob),
            ),
        )
        if force_unskip:
            self.skip_animations = False
            target.save_state()
            target.set_fill(YELLOW)
            self.wait(0.5)
            target.restore()
            self.skip_animations = True
        self.play(
            FadeOut(bar_groups),
        )

        self.remove(text_mob)
        self.add(new_text_mob)

        self.cur_str = new_str

        return new_text_mob

    def new_selection_cycle(self, text_mob, next_word_line, machine, quick=False, skip_anims=False):
        if skip_anims:
            self.skip_animations = True

        if quick:
            words, probs = self.predict_next_token(self.cur_str)
            bar_groups = self.get_distribution(words, probs, machine)
            self.add(bar_groups)
        else:
            self.animate_text_input(text_mob, machine)
            bar_groups = self.animate_prediction_ouptut(machine, self.cur_str)
        self.animate_random_sample(bar_groups)
        new_text_mob = self.animate_word_addition(
            bar_groups, text_mob, next_word_line,
            force_unskip=skip_anims
        )
        return new_text_mob

    #

    def predict_next_token(self, text, n_shown=None):
        # Default to the scene-level setting; some subclasses pass n_shown explicitly
        result = None
        if n_shown is None:
            n_shown = self.n_shown_predictions
        if self.model == "gpt3":
            try:
                result = gpt3_predict_next_token(
                    text, n_shown, random_seed=self.random_seed
                )
            except Exception as e:
                pass
        if result is None:
            result = gpt2_predict_next_token(text, n_shown)
        return result

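# The scenes in this file are meant to be rendered with manimgl in the usual
# 3b1b setup, e.g. something along the lines of:
#
#   manimgl <this_file>.py SimpleAutogregression
#
# (The exact invocation depends on how the repo and custom config are set up.)
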
class AltSimpleAutoRegression(SimpleAutogregression):
    n_predictions = 1
    line_len = 25

    def reposition_transformer_drawing(self, machine):
        machine.move_to(0.5 * RIGHT)
        in_arrow = machine[-1].copy()
        in_arrow.rotate(-45 * DEGREES)
        in_arrow.next_to(machine, UL)
        self.add(in_arrow)
        return machine

class AnnotateNextWord(SimpleAutogregression):
    def construct(self):
        text_mob, next_word_line, machine = self.init_text_and_machine()
        self.add(machine, *machine[1:])
        words, probs = self.predict_next_token(self.cur_str)
        bar_groups = self.get_distribution(words, probs, machine)

        self.add(bar_groups)

        # Initial text
        from manimlib.mobject.boolean_ops import Union
        highlight = Union(
            SurroundingRectangle(text_mob["Behold, a wild pi creature,"]),
            SurroundingRectangle(text_mob["foraging in its native"]),
        )
        highlight.set_stroke(BLUE, 3)
        arrow = Vector(LEFT, stroke_width=10)
        arrow.next_to(highlight, RIGHT).match_y(text_mob[0])

        dist_rect = SurroundingRectangle(bar_groups)
        dist_rect.set_stroke(YELLOW, 2)

        self.play(
            ShowCreation(highlight),
            GrowArrow(arrow)
        )
        self.wait()
        self.play(
            arrow.animate.rotate(PI / 2).next_to(dist_rect, UP),
            ReplacementTransform(highlight, dist_rect),
        )
        self.wait()
        self.play(
            FadeOut(dist_rect),
            FadeOut(arrow),
        )

        # Flash through
        self.remove(bar_groups)
        text_mob = self.new_selection_cycle(
            text_mob, next_word_line, machine,
        )

class QuickerRegression(SimpleAutogregression):
    skip_through = True


class AutoregressionGPT3(SimpleAutogregression):
    model = "gpt3"


class QuickRegressionGPT3(SimpleAutogregression):
    skip_through = True
    model = "gpt3"

class GPT3CleverestAutocomplete(QuickRegressionGPT3):
    seed_text = "To date, the cleverest thinker of all time was"
    n_predictions = 70

    def construct(self):
        # Test
        text_mob, next_word_line, machine = self.init_text_and_machine()
        for n in range(self.n_predictions):
            text_mob = self.new_selection_cycle(
                text_mob, next_word_line, machine,
                skip_anims=(n > 2),
            )

class GPT3OnLearningSimpler(QuickRegressionGPT3):
    seed_text = "The most effective way to learn computer science is"
    text_corner = 3.5 * UP + 3 * LEFT
    line_len = 35
    font_size = 35
    n_predictions = 300
    time_per_prediction = 0.2
    random_seed = 313
    model = "gpt3"
    min_y = -3
    up_shift = 5 * UP
    show_dist = False
    color_seed = True  # Whether to color the seed text (checked in construct below)

    def construct(self):
        # Test
        cur_str = self.seed_text
        text_mob = VGroup()
        for n in range(self.n_predictions):
            self.clear()
            words, probs = self.predict_next_token(cur_str, n_shown=20)
            index = np.random.choice(np.arange(len(words)), p=(probs / probs.sum()))
            new_word = words[index]
            cur_str += new_word
            text_mob = self.string_to_mob(cur_str)

            # Color seed
            if self.color_seed:
                text_mob[:len(self.seed_text.replace(" ", ""))].set_color(BLUE)

            # Add to text, shift if necessary
            text_mob[new_word.strip()][-1].set_color(YELLOW)
            if text_mob.get_bottom()[1] < self.min_y:
                text_mob.shift(self.up_shift)
                self.text_corner += self.up_shift
            self.add(text_mob)

            # Add the distribution
            if self.show_dist:
                dist = self.get_distribution(
                    words[:self.n_shown_predictions],
                    probs[:self.n_shown_predictions],
                    buff=0
                )
                dist.set_height(4)
                dist.to_edge(DOWN)
                rect = SurroundingRectangle(dist[min(index, len(dist) - 1)])
                self.add(dist, rect)

            self.wait(self.time_per_prediction)

class GPT3OnLongPassages(GPT3OnLearningSimpler):
    seed_text = "Writing long passages seems to involve more foresight and planning than what single-word prediction"
    n_predictions = 100
    color_seed = False


class GPT3CreaturePrediction(GPT3CleverestAutocomplete):
    seed_text = "the fluffy blue creature"
    n_predictions = 1


class GPT3CreaturePrediction2(GPT3CleverestAutocomplete):
    seed_text = "the fluffy blue creature roamed the"
    n_predictions = 1

class LowTempExample(GPT3OnLearningSimpler):
    seed_text = "Once upon a time, there was a"
    model = "gpt3"
    min_y = 1
    up_shift = 2 * UP
    show_dist = True
    temp = 0
    n_predictions = 200
    time_per_prediction = 0.25

    def predict_next_token(self, text, n_shown=None):
        words, probs = super().predict_next_token(text, n_shown)
        if self.temp == 0:
            probs = np.zeros_like(probs)
            probs[0] = 1
        else:
            probs = probs**(1 / self.temp)
            probs /= probs.sum()
        return words, probs

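# Sketch of what the reweighting above does to a toy distribution with temp = 0.5:
#   [0.5, 0.3, 0.2] -> raised to 1/temp = 2 -> [0.25, 0.09, 0.04]
#   -> renormalized -> [~0.66, ~0.24, ~0.11]   (the distribution sharpens)
# With temp > 1 the same operation flattens the distribution instead, and
# temp = 0 is special-cased above to always pick the single most likely token.
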
class HighTempExample(LowTempExample):
    temp = 5
    model = "gpt3"


class MidTempExample(LowTempExample):
    seed_text = "If you could see the underlying probability distributions a large language model uses when generating text, then"
    temp = 1
    model = "gpt3"

class ChatBotPrompt(SimpleAutogregression):
    system_prompt = """
        What follows is a conversation between a user and a helpful,
        very knowledgeable AI assistant.
    """
    user_prompt = "User: Give me some ideas for what to do when visiting Paris."
    ai_seed = "AI Assistant: "
    machine_name = "Large\nLanguage\nModel"

    line_len = 28
    font_size = 36
    color_seed = False

    n_predictions = 60
    model = "gpt3"
    random_seed = 12

    def construct(self):
        # Test
        text_mob, next_word_line, machine = self.init_text_and_machine()

        all_strs = list(map(clean_text, [self.system_prompt, self.user_prompt, self.ai_seed]))

        system_prompt, user_prompt, ai_seed = all_text = VGroup(
            get_paragraph(
                s.split(" "),
                font_size=self.font_size,
                line_len=self.line_len
            )
            for s in all_strs
        )
        all_text.arrange(DOWN, aligned_edge=LEFT, buff=0.75)
        all_text.move_to(self.text_corner, UL)
        self.remove(text_mob)
        self.add(all_text)

        text_mob = ai_seed
        self.text_corner = text_mob.get_corner(UL)
        next_word_line.next_to(ai_seed, RIGHT, aligned_edge=DOWN)

        self.cur_str = "\n\n".join(all_strs)

        # Comment on system prompt
        sys_rect = SurroundingRectangle(system_prompt)
        sys_rect.set_stroke(GREEN, 2)

        self.play(
            ShowCreation(sys_rect),
            system_prompt.animate.set_color(GREEN_B)
        )
        self.wait()

        # User's prompt
        from manimlib.mobject.boolean_ops import Union

        top_line = user_prompt["Give me some ideas for what"]
        low_line = user_prompt["to do when visiting Paris."]
        user_rect = Union(
            SurroundingRectangle(low_line),
            SurroundingRectangle(top_line),
        )
        user_rect.set_stroke(BLUE, 2)

        sys_rect.insert_n_curves(100)
        self.play(
            ReplacementTransform(sys_rect, user_rect),
            top_line.animate.set_color(BLUE_B),
            low_line.animate.set_color(BLUE_B),
        )
        self.wait()
        self.play(
            FadeOut(user_rect),
        )

        # Run predictions
        text_mob = all_text
        self.add(all_text.copy())
        for n in range(self.n_predictions):
            text_mob = self.new_selection_cycle(
                text_mob, next_word_line, machine,
                skip_anims=(n > 0),
            )

    def string_to_mob(self, text):
        seed = self.ai_seed.strip()
        if seed in text:
            text = text[text.index(seed):]
        return super().string_to_mob(text)

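# For reference, the "chat" in ChatBotPrompt is still plain next-token
# prediction: self.cur_str is just the three cleaned pieces joined by blank
# lines, roughly
#
#   What follows is a conversation between a user and a helpful, very
#   knowledgeable AI assistant.
#
#   User: Give me some ideas for what to do when visiting Paris.
#
#   AI Assistant:
#
# and the model simply keeps extending that transcript one token at a time.
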
class ChatBotPrompt2(ChatBotPrompt):
    user_prompt = "User: Can you explain what temperature is, in the context of softmax?"


class ChatBotPrompt3(ChatBotPrompt):
    user_prompt = "User: Can you give me some ideas for what to do while visiting Munich?"

class VoiceToTextExample(SimpleAutogregression):
    model_name = "voice-to-text"

    def construct(self):
        # Add model
        box = Rectangle(4, 3)
        box.set_stroke(WHITE, 2)
        name = Text(self.model_name, font_size=60)
        name.set_max_width(box.get_width())
        name.next_to(box, UP)
        machine = self.get_transformer_drawing()
        machine.center()
        machine.set_max_width(0.75 * box.get_width())
        machine.move_to(box)
        arrows = Vector(0.75 * RIGHT, stroke_width=8).replicate(2)
        arrows[0].next_to(box, LEFT, SMALL_BUFF)
        arrows[1].next_to(box, RIGHT, SMALL_BUFF)
        model = Group(box, name, arrows, machine)

        self.add(*model)
        self.add(Point())

        # Process input
        max_width = 3.75
        in_mob = self.get_input().set_max_width(max_width)
        out_mob = self.get_output().set_max_width(max_width)
        in_mob.next_to(arrows, LEFT)
        out_mob.next_to(arrows, RIGHT)

        self.add(in_mob)
        self.play(LaggedStart(
            FadeOutToPoint(
                in_mob.copy(), machine.get_left(),
                path_arc=-45 * DEGREES,
                lag_ratio=0.01,
            ),
            LaggedStart(
                (block.animate.set_color(TEAL).set_anim_args(rate_func=there_and_back)
                for block in machine[0][:-1]),
                lag_ratio=0.1,
                run_time=1,
            ),
            FadeInFromPoint(
                out_mob.copy(), machine.get_right(),
                path_arc=45 * DEGREES,
                lag_ratio=0.02
            ),
            lag_ratio=0.7
        ))
        self.wait()

    def get_input(self) -> Mobject:
        result = ImageMobject("AudioSnippet").set_width(3.75)
        result.set_height(3, stretch=True)
        return result

    def get_output(self) -> Mobject:
        return Text("""
            Some models take
            in audio and
            produce a transcript
        """, alignment="LEFT")

class TextToVoiceExample(VoiceToTextExample):
    model_name = "text-to-voice"

    def get_input(self):
        return Text("""
            This sentence comes from
            a model going the other
            way around, producing
            synthetic speech just
            from text.
        """, alignment="LEFT")

    def get_output(self):
        return super().get_input()

class TextToImage(VoiceToTextExample):
    model_name = "text-to-image"
    prompt = """
        1960s photograph of a cute fluffy blue wild pi
        creature, a creature whose body is shaped like
        the symbol π, who is foraging in its native territory,
        staring back at the camera with an exotic scene
        in the background.
    """
    image_name = "PiCreatureDalle3_5"

    def get_clean_prompt(self):
        return clean_text(self.prompt)

    def get_input(self):
        return get_paragraph(self.get_clean_prompt().split(" "), line_len=25)

    def get_output(self):
        return ImageMobject(self.image_name)

    def generate_output(self):
        # Test
        self.prompt = """
            1960s photograph of a cute fluffy blue wild pi
            creature, a creature whose face bears a subtle resemblance
            to the shape of the symbol π, who is foraging in its native
            territory, staring back at the camera with an exotic scene
            in the background.
        """

        self.prompt = "abstract depiction of furry fluffiness"

        openai.api_key = os.getenv('OPENAI_KEY')
        prompt = self.get_clean_prompt()

        response = openai.Image.create(
            model="dall-e-3",
            prompt=prompt,
            size="1024x1024",
            quality="standard",
            n=1,
        )

        image_url = response.data[0].url
        print(prompt)
        print(image_url)

        response = openai.Image.create_variation(
            image=open("/Users/grant/3Blue1Brown Dropbox/3Blue1Brown/images/raster/PiCreatureDalle3_17.png", "rb"),
            n=1,
            size="1024x1024"
        )

class TranslationExample(VoiceToTextExample):
    model_name = "machine translation"

    def get_input(self):
        return Text("Attention is all\nyou need")

    def get_output(self):
        return Group(Point(), *Text("注意力就是你所需要的一切"))

class PredictionVsGeneration(SimpleAutogregression):
    model = "gpt2"

    def construct(self):
        # Setup
        self.add(FullScreenRectangle())
        morty = Mortimer()
        morty.to_edge(DOWN)
        morty.body.insert_n_curves(100)
        self.add(morty)

        # Words
        words = VGroup(Text("Prediction"), Text("Generation"))
        words.scale(1.5)
        for vect, word in zip([UL, UR], words):
            word.next_to(morty, vect)
            word.shift(0.5 * UP)

        # Create prediction object
        seed_text = "The goal of predicting the next"
        self.n_shown_predictions = 8
        tokens, probs = self.predict_next_token(seed_text)
        dist = self.get_distribution(tokens, probs)
        brace = Brace(dist, LEFT, SMALL_BUFF)
        words = Text(seed_text, font_size=36).next_to(brace, LEFT)
        prediction = VGroup(words, brace, dist)
        prediction.set_width(FRAME_WIDTH / 2 - 1)
        prediction.next_to(morty, UL)
        prediction.shift(0.5 * UP).shift_onto_screen()
        self.add(prediction)

        # Animations
        self.play(
            morty.change("raise_right_hand", prediction),
            FadeIn(prediction[0], UP),
            GrowFromCenter(prediction[1]),
            LaggedStart(
                (FadeInFromPoint(bar, prediction[1].get_center())
                for bar in prediction[2]),
                lag_ratio=0.05,
            )
        )
        self.play(Blink(morty))
        self.play(
            morty.change("raise_left_hand", 3 * UR),
        )
        self.wait()
        self.play(Blink(morty))
        self.wait()

class ManyParallelPredictions(SimpleAutogregression):
    line_len = 200
    n_shown_predictions = 8
    model = "gpt3"

    def construct(self):
        # Setup
        self.fake_machine = VectorizedPoint().replicate(3)
        full_string = "Harry Potter was a highly unusual boy"

        # Draw last layer vectors
        last_layer = VGroup(
            NumericEmbedding(length=10)
            for n in range(12)
        )
        last_layer.arrange(RIGHT, buff=0.35 * last_layer[0].get_width())
        last_layer.set_height(3)
        last_layer.to_edge(DOWN)
        # self.add(last_layer)

        rects = VGroup(map(SurroundingRectangle, last_layer))
        rects.set_stroke(YELLOW, 2)
        arrows = VGroup(Vector(0.5 * UP).next_to(rect, UP, buff=0.1) for rect in rects)
        arrows.set_stroke(YELLOW)

        # Show prediction groups
        words = full_string.split(" ")
        substrings = [
            " ".join(words[:n + 1])
            for n in range(len(words))
        ]

        predictions = VGroup(
            self.get_prediction_group(substring)
            for substring in substrings
        )
        predictions[0].to_edge(UP, buff=1.25).align_to(rects[1], LEFT)
        for prediction, arrow, rect in zip(predictions, arrows, rects):
            prediction.move_to(predictions[0], LEFT)
            arrow.become(Arrow(
                rect.get_top(),
                prediction[1].get_left(),
            ))
            arrow.set_stroke(YELLOW)

        last_group = VGroup(
            rects[0].copy().set_opacity(0),
            arrows[0].copy().set_opacity(0),
            predictions[0].copy().set_opacity(0),
        )
        for rect, arrow, prediction in zip(rects, arrows, predictions):
            self.remove(last_group)
            self.play(
                TransformFromCopy(last_group[0], rect),
                TransformFromCopy(last_group[1], arrow),
                TransformMatchingStrings(last_group[2][0].copy(), prediction[0], run_time=1),
                FadeTransform(last_group[2][1].copy(), prediction[1]),
                FadeTransform(last_group[2][2].copy(), prediction[2]),
            )
            self.wait()
            last_group = VGroup(rect, arrow, prediction)

    def get_prediction_group(self, text):
        words, probs = self.predict_next_token(text)
        dist = self.get_distribution(
            words, probs,
            width_100p=2.0
        )
        dist.set_max_height(2.5)
        brace = Brace(dist, LEFT)
        prefix = Text(text, font_size=30)
        prefix.next_to(brace, LEFT)

        result = VGroup(prefix, brace, dist)

        return result

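# Note: get_prediction_group above makes a separate model call for each prefix
# of full_string. In an actual transformer forward pass, the distributions for
# every prefix come out of a single pass (one per position), which is the point
# this scene illustrates with the row of final-layer vectors.
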
class PeekUnderTheHood(SimpleAutogregression):
    def construct(self):
        # Add parts
        text_mob, next_word_line, machine = self.init_text_and_machine()
        blocks, label, arrow = machine
        self.remove(text_mob, next_word_line)

        # Zoom in
        self.camera.light_source.move_to([-15, 5, 10])
        self.set_floor_plane("xz")

        blocks.rotate(-5 * DEGREES, UP, about_edge=OUT)
        blocks.rotate(-10 * DEGREES, RIGHT, about_edge=OUT)
        blocks.target = blocks.generate_target()
        blocks.target.set_height(5)
        blocks.target.center()
        blocks.target[5:].set_opacity(0.3)

        self.play(
            self.frame.animate.reorient(-23, -12, 0, (1.79, -0.56, 1.27), 8.40).set_anim_args(run_time=3),
            MoveToTarget(blocks, run_time=3),
            FadeOut(arrow, RIGHT),
            FadeOut(label, 2 * OUT),
        )
        self.wait()

        blocks[5:].set_opacity(0.3)

        # Add matrices
        matrices = VGroup(WeightMatrix(shape=(8, 8)) for x in range(9))
        matrices.arrange_in_grid(h_buff_ratio=0.25, v_buff_ratio=0.4)
        matrices.match_width(blocks)
        index = 6
        matrices.move_to(blocks[index], OUT)
        self.add(matrices, blocks[index:])