2024-09-06 17:47:26 -05:00
from manim_imports_ext import *
2024-10-23 13:21:07 -05:00
from _2024 . transformers . generation import *
2024-09-06 17:47:26 -05:00
from _2024 . transformers . helpers import *
from _2024 . transformers . embedding import *
2024-10-23 13:21:07 -05:00
from _2024 . transformers . ml_basics import *
2024-09-06 17:47:26 -05:00
2024-11-15 12:10:17 -08:00
# Intro
class HoldUpThumbnail(TeacherStudentsScene):
    """Teacher raises a hand toward a bordered thumbnail while the students react."""

    def construct(self):
        # Thumbnail image with a white outline sitting flush on its edges (buff=0)
        thumbnail = ImageMobject(" /Users/grant/3Blue1Brown Dropbox/3Blue1Brown/videos/2024/transformers/Thumbnails/Chapter5_TN3.png ")
        outline = SurroundingRectangle(thumbnail, buff=0)
        outline.set_stroke(WHITE, 3)
        framed = Group(outline, thumbnail)
        framed.set_height(3)
        framed.move_to(self.hold_up_spot, DOWN)

        teacher = self.teacher
        students = self.students

        # The image fades in while everyone turns to look at it
        entrance = FadeIn(framed, UP)
        teacher_reaction = teacher.change(" raise_right_hand ", look_at=framed)
        student_reaction = self.change_students(
            " tease ", " happy ", " tease ",
            look_at=framed,
        )
        self.play(entrance, teacher_reaction, student_reaction)
        self.wait(4)
class IsThisUsefulToShare(TeacherStudentsScene):
    """Teacher asks a question; the students change modes twice in response."""

    def construct(self):
        teacher = self.teacher

        # Question plus the first round of student reactions, all looking at the screen
        question = teacher.says(" Do you find \n this useful? ")
        first_round = self.change_students(
            " pondering ", " hesitant ", " well ",
            look_at=self.screen,
        )
        self.play(question, first_round)
        self.wait(3)

        # Second round of reactions
        second_round = self.change_students(" thinking ", " pondering ", " tease ")
        self.play(second_round)
        self.wait(3)
class AskAboutAttention(TeacherStudentsScene):
    """The third student asks a question while the other two ponder the screen."""

    def construct(self):
        students = self.students
        teacher = self.teacher

        animations = [
            teacher.change(" tease "),
            students[2].says(
                " Can you explain what \n Attention does? ",
                mode=" raise_left_hand ",
                bubble_direction=LEFT,
            ),
        ]
        # Remaining students, in the same order as before: index 1, then index 0
        for index in (1, 0):
            animations.append(students[index].change(" pondering ", self.screen))

        self.play(*animations)
        self.wait(4)
# Version 1
2024-09-06 17:47:26 -05:00
# Scene built on SimpleAutogregression. It introduces a machine (a group of
# blocks plus a text label), feeds the seed text into it, shows one hard-coded
# predicted word, then a set of bar groups for the prediction, and finally
# extends the text through repeated selection cycles.
class PredictTheNextWord ( SimpleAutogregression ) :
# Class-level settings; presumably read by SimpleAutogregression helpers,
# which are not visible in this part of the file.
text_corner = 3.5 * UP + 6.5 * LEFT
machine_name = " Large \n Language \n Model "
seed_text = " Paris is a city in "
model = " gpt3 "
n_shown_predictions = 12
random_seed = 2
def construct ( self ) :
# Setup machine
text_mob , next_word_line , machine = self . init_text_and_machine ( )
machine . move_to ( ORIGIN )
machine [ 1 ] . set_backstroke ( BLACK , 3 )
text_group = VGroup ( text_mob , next_word_line )
# Saved here so the enlarged, top-edge layout below can be undone by the
# Restore(text_group) call in the auto-regression section.
text_group . save_state ( )
text_group . scale ( 1.5 )
text_group . match_x ( machine [ 0 ] ) . to_edge ( UP )
# Introduce the machine
in_arrow = Arrow ( text_group , machine [ 0 ] . get_top ( ) , thickness = 5 )
frame = self . frame
self . set_floor_plane ( " xz " )
blocks = machine [ 0 ]
llm_text = machine [ 1 ]
# Stroke-only copy of the blocks, used just for the passing-flash effect below.
block_outlines = blocks . copy ( )
block_outlines . set_fill ( opacity = 0 )
block_outlines . set_stroke ( GREY_B , 2 )
block_outlines . insert_n_curves ( 20 )
2024-10-23 13:21:07 -05:00
flat_dials , last_dials = self . get_machine_dials ( blocks )
2024-09-06 17:47:26 -05:00
# Empty the scene and start from a reoriented camera; the play call below
# animates the frame to a (0, 0, 0) orientation while the machine appears.
self . clear ( )
frame . reorient ( - 31 , - 4 , - 5 , ( - 0.24 , - 0.26 , - 0.06 ) , 3 )
# The time_span values (2-4 and 4-5) presumably place those two animations
# inside the shared 6-second run_time.
self . play (
FadeIn ( blocks , shift = 0.0 , lag_ratio = 0.01 ) ,
LaggedStartMap ( VShowPassingFlash , block_outlines . family_members_with_points ( ) , time_width = 2.0 , lag_ratio = 0.01 , remover = True ) ,
LaggedStartMap ( VFadeInThenOut , flat_dials , lag_ratio = 0.001 , remover = True ) ,
Write ( llm_text , time_span = ( 2 , 4 ) , stroke_color = WHITE ) ,
FadeIn ( last_dials , time_span = ( 4 , 5 ) ) ,
frame . animate . reorient ( 0 , 0 , 0 , ( - 0.17 , - 0.12 , 0.0 ) , 4.50 ) ,
run_time = 6 ,
)
# Make the faint dial grid a submobject of the last block.
blocks [ - 1 ] . add ( last_dials )
self . play (
frame . animate . to_default_state ( ) ,
FadeIn ( text_group , UP ) ,
GrowFromCenter ( in_arrow ) ,
run_time = 3
)
# Single word prediction
out_arrow = Vector ( 1.5 * RIGHT , thickness = 5 )
out_arrow . next_to ( machine [ 0 ] [ - 1 ] , RIGHT )
prediction = Text ( " France " , font_size = 72 )
prediction . next_to ( out_arrow , RIGHT )
self . animate_text_input (
text_mob , machine ,
position_text_over_machine = False ,
)
# Each letter of the prediction grows out of a point on the machine's right edge.
self . play (
LaggedStart (
( TransformFromCopy ( VectorizedPoint ( machine . get_right ( ) ) , letter )
for letter in prediction ) ,
lag_ratio = 0.05 ,
) ,
GrowArrow ( out_arrow )
)
self . wait ( )
# Swap the machine's submobject at index 2 for the output arrow created above.
machine . replace_submobject ( 2 , out_arrow )
# Probability distribution
self . play ( FadeOut ( prediction , DOWN ) )
bar_groups = self . animate_prediction_ouptut ( machine , self . cur_str )
self . wait ( )
# Show auto_regression
self . play (
Restore ( text_group ) ,
FadeOut ( in_arrow ) ,
)
seed_label = Text ( " Seed text " )
seed_label . set_color ( YELLOW )
seed_label . next_to ( text_mob , DOWN )
self . play (
FadeIn ( seed_label , rate_func = there_and_back_with_pause ) ,
FlashAround ( text_mob , time_width = 2 ) ,
frame . animate . reorient ( 0 , 0 , 0 , ( 0.7 , - 0.01 , 0.0 ) , 8.52 ) ,
run_time = 2 ,
)
self . animate_random_sample ( bar_groups )
# NOTE(review): new_text_mob is never read afterwards; the loop below keeps
# passing text_mob. Confirm that is intended.
new_text_mob = self . animate_word_addition (
bar_groups , text_mob , next_word_line ,
)
# More!
# Twenty further cycles; skip_anims is False only for n <= 5.
# NOTE(review): in this flattened copy it is not visible whether the
# wait(0.25) below belongs to the loop body. Confirm before restructuring.
for n in range ( 20 ) :
text_mob = self . new_selection_cycle (
text_mob , next_word_line , machine ,
quick = True ,
skip_anims = ( n > 5 ) ,
)
self . wait ( 0.25 )
2024-10-23 13:21:07 -05:00
# Build one 8x12 grid of Dial copies per block, each grid scaled to 90% of
# its block's width and centered on that block, with every dial set to a
# random value.
# Returns (flat_dials, last_dials): a single VGroup holding every dial from
# every grid, and a copy of the last block's grid with stroke opacity 0.1.
def get_machine_dials ( self , blocks ) :
dials = VGroup (
Dial ( ) . get_grid ( 8 , 12 ) . set_width ( 0.9 * block . get_width ( ) ) . move_to ( block )
for block in blocks
)
dials . set_stroke ( opacity = 0.5 )
for group in dials :
for dial in group :
dial . set_value ( dial . get_random_value ( ) )
# `it` is presumably itertools, brought in by one of the star imports.
flat_dials = VGroup ( * it . chain ( * dials ) )
last_dials = dials [ - 1 ] . copy ( )
last_dials . set_stroke ( opacity = 0.1 )
return flat_dials , last_dials
2024-09-06 17:47:26 -05:00
class LotsOfTextIntoTheMachine(PredictTheNextWord):
    """Many text snippets fade in around the screen and fade out as they move into the machine."""

    run_time = 25
    max_snippet_width = 3

    def construct(self):
        # Machine alone on screen, with a faint dial grid on its last block
        _, _, machine = self.init_text_and_machine()
        machine.scale(1.5)
        self.clear()
        self.add(machine)

        blocks, _title = machine[:2]
        last_block = blocks[-1]
        dials = Dial().get_grid(8, 12)
        dials.set_width(0.9 * last_block.get_width())
        dials.move_to(last_block)
        dials.set_stroke(opacity=0.1)
        last_block.add(dials)

        machine.center()
        machine[1].set_stroke(BLACK, 3)

        # One paragraph per snippet, placed along 12 compass directions (cycled)
        snippets = self.get_text_snippets()
        paragraphs = VGroup(
            get_paragraph(snippet.split(" "), line_len=25)
            for snippet in snippets
        )
        directions = compass_directions(12, start_vect=UR)
        for paragraph, direction in zip(paragraphs, it.cycle(directions)):
            paragraph.set_max_width(self.max_snippet_width)
            paragraph.move_to(5 * direction)
            paragraph.shift_onto_screen(buff=0.25)

        def absorb(paragraph):
            # Appear in place, then go transparent while moving to the machine's center
            return Succession(
                FadeIn(paragraph),
                paragraph.animate.set_opacity(0).move_to(machine.get_center()),
            )

        self.play(
            LaggedStart(
                (absorb(paragraph) for paragraph in paragraphs),
                lag_ratio=0.05,
                run_time=self.run_time,
            )
        )
        self.remove(paragraphs)
        self.wait()

    def get_text_snippets(self):
        """Return the pieces of the text file under DATA_DIR, split and shuffled in place."""
        source = Path(DATA_DIR, " pile_of_text.txt ")
        snippets = source.read_text().split(" \n ")
        random.shuffle(snippets)
        return snippets
class EvenMoreTextIntoMachine(LotsOfTextIntoTheMachine):
    """Variant of the parent scene that feeds random fixed-length word windows from a book."""

    run_time = 40
    max_snippet_width = 2.5
    n_examples = 300
    context_size = 25

    def get_text_snippets(self):
        """Return n_examples strings, each joining context_size consecutive words of the book."""
        text = Path(DATA_DIR, " tale_of_two_cities.txt ").read_text()
        text = text.replace(" \n ", " ")
        # Drop the empty strings that splitting can produce
        words = [piece for piece in text.split(" ") if piece]

        window = self.context_size
        last_start = len(words) - window - 1
        snippets = []
        for _ in range(self.n_examples):
            start = random.randint(0, last_start)
            snippets.append(" ".join(words[start:start + window]))
        return snippets
class WriteTransformer(InteractiveScene):
    """Write a single large word on screen."""

    def construct(self):
        title = Text(" Transformer ", font_size=120)
        write_title = Write(title)
        self.play(write_title)
        self.wait()
class LabelVector(InteractiveScene):
    """A brace on the right of a height-4 vertical line, with a text label beside it."""

    def construct(self):
        # The line is only a size reference for the brace; it is never added to the scene
        spine = Line(UP, DOWN)
        spine.set_height(4)
        brace = Brace(spine, RIGHT)

        label = Text(" Vector ", font_size=72)
        label.next_to(brace, RIGHT)
        label.set_backstroke(BLACK, 5)

        self.play(GrowFromCenter(brace), Write(label))
        self.wait()
class AdjustingTheMachine(InteractiveScene):
    """A MachineWithDials in front of ten stacked slabs, changed six times under a swaying camera."""

    def construct(self):
        # Camera and lighting
        camera_frame = self.frame
        self.set_floor_plane(" xz ")
        camera_frame.reorient(-28, -17, 0, ORIGIN, 8.91)
        self.camera.light_source.move_to([-10, 10, 10])

        # Machine face
        machine = MachineWithDials(n_rows=10, n_cols=12)
        machine.set_height(6)

        # Ten slabs matching the machine's width and height
        slabs = VCube().replicate(10)
        slabs.set_shape(machine.get_width(), machine.get_height(), 1.0)
        slabs.deactivate_depth_test()
        camera_position = self.frame.get_implied_camera_location()

        def negative_distance_to_camera(point):
            return -get_norm(point - camera_position)

        # Sort key is the negated distance from the camera location
        for slab in slabs:
            slab.sort(negative_distance_to_camera)
        slabs.set_fill(GREY_D, 1)
        slabs.set_shading(0.2, 0.5, 0.25)
        slabs.arrange(OUT, buff=0.5)
        slabs.move_to(machine, OUT)
        self.add(slabs)
        self.add(machine)

        def sway(frame_mob):
            # theta follows a cosine of scene time, amplitude 30 degrees
            return frame_mob.set_theta(-30 * DEGREES * math.cos(0.1 * self.time))

        camera_frame.clear_updaters()
        camera_frame.add_updater(sway)
        self.add(camera_frame)

        for _ in range(6):
            self.play(machine.random_change_animation(lag_factor=0.1))
class FirthQuote(InteractiveScene):
    """Firth's quote beside his portrait, then two phrases containing the same highlighted word."""

    def construct(self):
        # Portrait, caption and quotation
        quote = TexText(R" ``You shall know a word \\ by the company it keeps! ' ' ", font_size=60)
        # Portrait from https://www.cambridge.org/core/journals/bulletin-of-the-school-of-oriental-and-african-studies/article/john-rupert-firth/D926AFCBF99AD17D5C7A7A9C0558DFDC
        portrait = ImageMobject(" JohnRFirth ")
        portrait.set_height(6.5)
        portrait.to_corner(UL, buff=0.5)
        caption = Text(" John R. Firth ")
        caption.next_to(portrait, DOWN)
        quote.move_to(midpoint(portrait.get_right(), RIGHT_SIDE))
        quote.to_edge(UP)

        self.play(
            FadeIn(portrait, 0.25 * UP),
            FadeIn(caption, lag_ratio=0.1),
        )
        self.play(Write(quote))
        self.wait()

        # Two sentences sharing one word
        river_phrase = Text(" Down by the river bank ")
        money_phrase = Text(" Deposit a check at the bank ")
        phrases = VGroup(river_phrase, money_phrase)

        big_bank = Text(" bank ", font_size=90)
        big_bank.set_color(TEAL)
        big_bank.match_x(quote)
        big_bank.match_y(portrait)

        for sentence in phrases:
            sentence[" bank "].set_color(TEAL)
        phrases.arrange(DOWN, buff=1.0, aligned_edge=LEFT)
        phrases.next_to(quote, DOWN, buff=2.5)
        money_phrase.set_opacity(0.15)
        banks = VGroup(sentence[" bank "][0] for sentence in phrases)

        self.play(
            FadeIn(big_bank, scale=2, lag_ratio=0.25),
            quote.animate.scale(0.7, about_edge=UP).set_opacity(0.75),
        )
        self.wait()
        self.remove(big_bank)

        # Leading glyphs of each phrase fade in while the big word lands on each occurrence
        river_prefix = river_phrase[:len(" downbytheriver ")]
        money_prefix = money_phrase[:len(" depositacheckatthe ")]
        self.play(
            FadeIn(river_prefix, lag_ratio=0.1),
            FadeIn(money_prefix, lag_ratio=0.1),
            *(TransformFromCopy(big_bank, target) for target in banks),
        )
        self.wait()
        self.play(
            river_phrase.animate.set_opacity(0.5),
            money_phrase.animate.set_opacity(1),
        )
        self.wait()

        # Isolate both phrases
        self.play(LaggedStart(
            FadeOut(portrait, LEFT, scale=0.5),
            FadeOut(caption, LEFT, scale=0.5),
            FadeOut(quote, LEFT, scale=0.5),
            phrases.animate.set_opacity(1).arrange(DOWN, buff=3.5, aligned_edge=LEFT).move_to(0.5 * UP),
        ))
        self.wait()

        # Recreate: start again from the lone word, then rebuild both phrases around it
        small_bank = Text(" bank ", font_size=72)
        small_bank.set_color(TEAL)
        self.clear()
        self.add(small_bank)
        self.wait()
        self.remove(small_bank)

        fade_ins = [
            FadeIn(sentence[sentence.get_text().replace(" bank ", " ")])
            for sentence in phrases
        ]
        bank_moves = [
            TransformFromCopy(small_bank, sentence[" bank "][0])
            for sentence in phrases
        ]
        self.play(*fade_ins, *bank_moves)
        self.add(phrases)

        # Show influence: teal boxes on the shared word, blue boxes on context words
        query_rects = VGroup(SurroundingRectangle(occurrence) for occurrence in banks)
        query_rects.set_stroke(TEAL, 2)
        query_rects.set_fill(TEAL, 0.25)

        key_rects = VGroup(
            SurroundingRectangle(river_phrase[" river "]),
            SurroundingRectangle(money_phrase[" Deposit "]),
            SurroundingRectangle(money_phrase[" check "]),
        )
        key_rects.set_stroke(BLUE, 2)
        key_rects.set_fill(BLUE, 0.5)
        key_rects[2].match_height(key_rects[1], about_edge=UP, stretch=True)

        arrows = VGroup(
            Arrow(key_rects[0].get_top(), banks[0].get_top(), path_arc=-180 * DEGREES, buff=0.1),
            Arrow(key_rects[1].get_top(), banks[1].get_top(), path_arc=-90 * DEGREES),
            Arrow(key_rects[2].get_top(), banks[1].get_top(), path_arc=-90 * DEGREES),
        )
        arrows.set_color(BLUE)

        # Save the final boxes, then start each one as an invisible copy of the
        # query box it should emerge from; Restore below animates back.
        key_rects.save_state()
        for key_rect, query_index in zip(key_rects, (0, 1, 1)):
            key_rect.become(query_rects[query_index])
        key_rects.set_opacity(0)

        self.add(query_rects, phrases)
        self.play(FadeIn(query_rects, lag_ratio=0.25))
        self.wait()
        self.add(key_rects, phrases)
        self.play(Restore(key_rects, lag_ratio=0.1, path_arc=PI / 4, run_time=2))
        self.play(LaggedStartMap(Write, arrows, stroke_width=5, run_time=3))
        self.wait()

        # Show images under each occurrence of the word
        pictures = Group(
            ImageMobject(" RiverBank "),
            ImageMobject(" FederalReserve "),
        )
        for picture, anchor in zip(pictures, banks):
            picture.set_height(2.0)
            picture.next_to(anchor, DOWN, MED_SMALL_BUFF, aligned_edge=LEFT)

        self.play(
            LaggedStart(
                (FadeTransform(Group(anchor).copy(), picture)
                 for anchor, picture in zip(banks, pictures)),
                lag_ratio=0.5,
                group_type=Group,
            )
        )
        self.wait(2)
class DownByTheRiverHeader(InteractiveScene):
    """Static frame: a phrase with one word boxed in blue and a brace beneath the box."""

    def construct(self):
        phrase = Text(" Down by the river bank ... ")

        highlight = SurroundingRectangle(phrase[" bank "])
        highlight.set_fill(BLUE, 0.5)
        highlight.set_stroke(BLUE, 3)

        underbrace = Brace(highlight, DOWN, buff=SMALL_BUFF)

        # The box is passed first, ahead of the text
        self.add(highlight, phrase, underbrace)
class RiverBankProbParts(SimpleAutogregression):
    """Bar chart for a hand-written list of candidate words and scores."""

    seed_text = " Down by the river bank, "
    model = " gpt3 "

    def construct(self):
        _, _, machine = self.init_text_and_machine()
        machine.set_x(0)

        # (word, score) pairs; the scores go through softmax to become probabilities
        scored_words = [
            (" water ", 6),
            (" river ", 5),
            (" lake ", 4),
            (" grass ", 4),
            (" waves ", 3.5),
            (" shallows ", 3.25),
            (" pool ", 3),
            (" depths ", 3),
            (" foam ", 2.5),
            (" mist ", 2),
        ]
        words = [word for word, _ in scored_words]
        scores = [score for _, score in scored_words]
        probs = softmax(scores)

        bar_groups = self.get_distribution(words, probs, machine)
        self.clear()
        bar_groups.set_height(6)
        bar_groups.center()

        reveal = LaggedStartMap(FadeIn, bar_groups, shift=0.25 * DOWN, run_time=3)
        self.play(reveal)
        self.wait()
class FourStepsWithParameters(InteractiveScene):
    """Four titled boxes in a row, each then shown with its own grid of dials."""

    def construct(self):
        # Add rectangles and titles
        self.add(FullScreenRectangle(fill_color=GREY_E))

        boxes = Square().replicate(4)
        boxes.arrange(RIGHT, buff=0.25 * boxes[0].get_width())
        boxes.set_width(FRAME_WIDTH - 1.0)
        boxes.center()
        boxes.to_edge(UP, buff=0.5)
        boxes.set_fill(BLACK, 1)
        boxes.set_stroke(WHITE, 2)

        step_labels = [
            R" Text snippets \\ $ \ downarrow$ \\ Vectors ",
            R" Attention ",
            R" Feedforward ",
            R" Final prediction ",
        ]
        names = VGroup(*(TexText(label) for label in step_labels))
        for name, box in zip(names, boxes):
            name.scale(0.8)
            name.next_to(box, DOWN)

        self.add(boxes)
        self.play(LaggedStartMap(FadeIn, names, shift=0.25 * DOWN, lag_ratio=0.25))
        self.wait()

        # Show many dials: one 9x6 machine per box, hung directly beneath it
        machines = VGroup(
            MachineWithDials(
                width=box.get_width(),
                height=3.0,
                n_rows=9,
                n_cols=6,
            )
            for box in boxes
        )
        for machine, box in zip(machines, boxes):
            machine.next_to(box, DOWN, buff=0)
            machine[0].set_opacity(0)
            # Rescale so the dial grid is exactly as wide as the box
            width_ratio = box.get_width() / machine.dials.get_width()
            machine.scale(width_ratio, about_edge=UP)
            machine.dials.shift(0.25 * UP)
            for dial in machine.dials:
                dial.set_value(0)

        def grow_dials(machine):
            # Every dial grows out of the top-center point of its machine
            return LaggedStart(
                (GrowFromPoint(dial, machine.get_top()) for dial in machine.dials),
                lag_ratio=0.025,
            )

        self.play(
            LaggedStart(
                (grow_dials(machine) for machine in machines),
                lag_ratio=0.25,
            ),
            LaggedStartMap(FadeOut, names),
        )

        for _ in range(2):
            self.play(
                LaggedStart(
                    (machine.random_change_animation() for machine in machines),
                    lag_ratio=0.2,
                )
            )
class ChatbotFeedback(InteractiveScene):
    """Generate several sampled answers to one prompt, marking each with a random check or cross."""

    random_seed = 404

    def construct(self):
        self.frame.set_height(10).move_to(DOWN)

        # Prompt at the top, with the speaker tag colored
        user_prompt = " User: How and when was the internet invented? "
        prompt_mob = Text(user_prompt)
        prompt_mob.to_edge(UP)
        prompt_mob[" User: "].set_color(BLUE)

        # The answer starts out as only the assistant tag; og_answer_mob keeps
        # a handle on that starting mobject so each round can return to it.
        tag = Text(" AI Assistant: ")
        self.answer_mob = tag
        self.answer_mob.next_to(prompt_mob, DOWN, buff=1.0, aligned_edge=LEFT)
        self.answer_mob.set_color(YELLOW)
        self.og_answer_mob = self.answer_mob
        self.add(prompt_mob, self.answer_mob)

        # Show multiple answers
        for _ in range(8):
            self.give_answer(prompt_mob)
            verdict = self.judge_answer()
            self.add(self.og_answer_mob)
            self.play(FadeOut(self.answer_mob), FadeOut(verdict))
            self.answer_mob = self.og_answer_mob

    def display_answer(self, text):
        """Replace the displayed answer with a paragraph built from `text`."""
        pieces = text.replace(" \n ", " ").split(" ")
        paragraph = get_paragraph(pieces)
        # The leading glyphs (as many as the original tag has) take the tag's style
        tag_length = len(self.og_answer_mob)
        paragraph[:tag_length].match_style(self.og_answer_mob)
        paragraph.move_to(self.og_answer_mob, UL)

        self.remove(self.answer_mob)
        self.answer_mob = paragraph
        self.add(self.answer_mob)

    def give_answer(self, prompt_mob, max_responses=100):
        """Grow the answer one sampled token at a time, for at most `max_responses` tokens."""
        answer = self.og_answer_mob.get_text()
        user_prompt = prompt_mob.get_text()
        for _ in range(max_responses):
            answer, finished = self.add_to_answer(user_prompt, answer)
            if finished:
                return
            self.display_answer(answer)
            self.wait(2 / 30)

    def judge_answer(self):
        """Show a random mark and a matching box around the answer; return both as a VGroup."""
        options = [
            Checkmark().set_color(GREEN),
            Exmark().set_color(RED),
        ]
        mark = random.choice(options)
        mark.scale(5)
        mark.next_to(self.answer_mob, RIGHT, aligned_edge=UP)

        box = SurroundingRectangle(self.answer_mob)
        box.match_color(mark)

        self.play(FadeIn(mark, scale=2), FadeIn(box, scale=1.05))
        self.wait()
        return VGroup(mark, box)

    def add_to_answer(self, user_prompt: str, answer: str):
        """Sample one more token for `answer`.

        Returns (answer, stop). `stop` is True when sampling raised IndexError
        or the sampled token equals the end-of-text marker string; in both
        cases `answer` is returned unchanged.
        """
        context = " \n \n ".join([user_prompt, answer])
        try:
            tokens, probs = gpt3_predict_next_token(context)
            weights = np.array(probs) / sum(probs)
            token = random.choices(tokens, weights)[0]
        except IndexError:
            return answer, True

        if token == ' <|endoftext|> ':
            return answer, True
        return answer + token, False
class ContrastWithEarlierFrame(InteractiveScene):
    """Static split-screen frame: a vertical divider with one title on each half."""

    def construct(self):
        # Full-height divider down the middle
        divider = Line(UP, DOWN)
        divider.set_height(FRAME_HEIGHT)
        self.add(divider)

        # Each title is a one-element column. Disabled extra rows, kept for reference:
        #   left:  Vector(0.75 * DOWN, thickness=4), Text("One word at a time")
        #   right: Vector(0.75 * DOWN, thickness=4), Text("All words in parallel")
        headings = (" Most earlier models ", " Transformers ")
        titles = VGroup(*(VGroup(Text(heading)) for heading in headings))

        for column, side in zip(titles, [LEFT, RIGHT]):
            column.arrange(DOWN, buff=0.2)
            column.scale(1.5)
            # Center over one half of the frame, then move to the top edge
            column.move_to(FRAME_WIDTH * side / 4)
            column.to_edge(UP)

        self.add(titles)
class SequentialProcessing(InteractiveScene):
    """An embedding vector steps across a sentence word by word, re-randomizing at each stop."""

    def construct(self):
        # Sentence broken into boxed words
        sentence = Text(" Down by the river bank, where I used to go fishing ... ")
        sentence.move_to(1.0 * DOWN)
        words = break_into_words(sentence)
        rects = get_piece_rectangles(words)
        blocks = VGroup(VGroup(rect, word) for rect, word in zip(rects, words))
        blocks.save_state()
        self.add(blocks)

        # Vector wandering over
        vect = NumericEmbedding()
        vect.set_width(1.0)
        vect.next_to(rects[0], UP)

        # Every block except the last gets a turn
        n_stops = len(blocks) - 1
        for index in range(n_stops):
            # Target state: all blocks faded except the one at `index`
            spotlight = blocks.saved_state.copy()
            blocks.target = spotlight
            blocks.target[:index].fade(0.75)
            blocks.target[index + 1:].fade(0.75)

            self.play(
                vect.animate.next_to(blocks[index], UP),
                MoveToTarget(blocks),
            )

            # One ContextAnimation per vector entry, aimed at the current word
            entry_animations = LaggedStart(
                (ContextAnimation(entry, blocks[index][1], lag_ratio=0.01)
                 for entry in vect.get_entries()),
                lag_ratio=0.01,
            )
            self.play(
                entry_animations,
                RandomizeMatrixEntries(vect),
                run_time=2,
            )
2024-10-23 13:21:07 -05:00
# Version 2
# Scene in three parts: a curled, textured script surface unfurls into a flat
# page; a machine is shown with one example input and prediction; then the
# machine is fed the script text repeatedly while the first six words of a
# hard-coded completion are appended one at a time.
class PartialScript ( SimpleAutogregression ) :
# Class-level settings; presumably read by SimpleAutogregression helpers,
# which are not visible in this part of the file.
machine_name = " Magic next \n word predictor "
machine_phi = 5 * DEGREES
machine_theta = 6 * DEGREES
def construct ( self ) :
# Set frame
frame = self . frame
self . set_floor_plane ( " xz " )
# Unfurl script
# curled_script_img is never added to the scene; it only serves as the
# size/position reference for flat_script further down.
curled_script_img = ImageMobject ( " HumanAIScript " )
curled_script_img . set_height ( 7 )
# The first submobject of each SVG is the cross-section curve of a curled surface.
curves = VGroup ( SVGMobject ( " JaggedCurl1 " ) [ 0 ] , SVGMobject ( " JaggedCurl2 " ) [ 0 ] )
for curve in curves :
curve . make_smooth ( approx = False )
curve . insert_n_curves ( 100 )
curve . set_stroke ( WHITE , 3 )
curve . set_fill ( opacity = 0 )
curve . set_height ( 5 )
# Rescale the second curve by the ratio of the two arc lengths.
curves [ 1 ] . scale ( curves [ 0 ] . get_arc_length ( ) / curves [ 1 ] . get_arc_length ( ) )
resolution = ( 2 , 200 ) # Change
surface_kw = dict ( u_range = ( - 6 , 6 ) , v_range = ( 0.05 , 0.95 ) , resolution = resolution )
# One surface per curve: the first two coordinates follow the curve at
# proportion v, and u becomes the third coordinate.
curled_script_templates = Group (
ParametricSurface (
lambda u , v : ( * curve . pfp ( v ) [ : 2 ] , u ) ,
* * surface_kw
)
for curve in curves
)
curled_script_templates [ 1 ] . rotate ( PI / 2 , UP )
curled_script_templates [ 0 ] . rotate ( - PI / 2 )
# Flat counterpart over the same parameter ranges: (u, v) maps to (u, v, 0).
flat_script_template = ParametricSurface (
lambda u , v : ( u , v , 0 ) ,
* * surface_kw
)
# Textured versions. The second curl and the flat page also get the
# "Torn" texture, which is cross-faded in before flattening.
curled_script0 = TexturedSurface ( curled_script_templates [ 0 ] , " HumanAIScript " )
curled_script1 = TexturedSurface ( curled_script_templates [ 1 ] , " HumanAIScript " )
curled_script1_torn = TexturedSurface ( curled_script_templates [ 1 ] , " HumanAIScriptTorn " )
flat_script = TexturedSurface ( flat_script_template , " HumanAIScriptTorn " )
flat_script . replace ( curled_script_img , stretch = True )
for script in [ curled_script0 , curled_script1 ] :
script . set_shading ( 0.25 , 0.25 , 0.35 )
curled_script1_torn . set_shading ( 0 , 0 , 0 )
flat_script . set_shading ( 0 , 0 , 0 )
frame . reorient ( 0 , - 1 , 0 , ( - 0.28 , 0.69 , 0.0 ) , 14.43 )
self . play (
TransformFromCopy ( curled_script0 , curled_script1 ) ,
frame . animate . reorient ( 56 , - 17 , 0 , ( - 0.2 , - 1.52 , - 2.39 ) , 20.05 ) ,
run_time = 3
)
self . play (
frame . animate . reorient ( - 6 , - 11 , 0 , ( 1.06 , - 1.22 , - 2.65 ) , 20.05 ) ,
run_time = 8 ,
)
# Cross-fade from the intact texture to the torn one on the same curled shape.
self . play (
FadeOut ( curled_script1 , shift = 1e-2 * IN ) ,
FadeIn ( curled_script1_torn , shift = 1e-2 * IN ) ,
)
self . play (
ReplacementTransform ( curled_script1_torn , flat_script ) ,
frame . animate . to_default_state ( ) ,
run_time = 2
)
self . wait ( )
# Show the machine
machine = self . get_transformer_drawing ( )
machine [ 1 ] . set_height ( 0.7 ) . set_stroke ( width = 2 )
# machine[1] is made fully transparent, and the machine's last submobject is dropped.
machine [ 1 ] . set_opacity ( 0 )
machine . remove ( machine [ - 1 ] )
machine . set_height ( 3 )
machine . to_edge ( RIGHT )
self . play (
flat_script . animate . set_height ( 5 ) . to_edge ( LEFT ) ,
FadeIn ( machine , lag_ratio = 0.01 )
)
self . add ( machine )
self . wait ( )
# Show example input and output
out_arrow = Vector ( DOWN , thickness = 6 )
out_arrow . next_to ( machine , DOWN )
in_arrow = out_arrow . copy ( ) . next_to ( machine , UP , SMALL_BUFF )
in_text = Text ( " To be or not to _ " )
# Stretch the last glyph (presumably the underscore) 3x along dimension 0,
# anchored at its left edge.
in_text [ - 1 ] . stretch ( 3 , 0 , about_edge = LEFT )
in_text . next_to ( in_arrow , UP )
prediction = Text ( " be " , font_size = 72 )
prediction . next_to ( out_arrow , DOWN )
self . play ( FadeIn ( in_text ) , GrowArrow ( in_arrow ) )
self . animate_text_input ( in_text , machine , position_text_over_machine = False )
self . play (
GrowArrow ( out_arrow ) ,
FadeIn ( prediction , DOWN ) ,
)
self . wait ( )
# Clear the board
script_text = self . get_text ( )
script_text . set_width ( 0.89 * flat_script . get_width ( ) )
script_text . next_to ( flat_script . get_top ( ) , DOWN , buff = 0.33 )
# 48 scaled by the ratio of script_text's first glyph height to the height
# of a fresh Text "H". Assumes 48 is Text's default font size and that the
# first glyph is comparable to an "H" - TODO confirm.
font_size = 48 * ( script_text [ 0 ] . get_height ( ) / Text ( " H " ) . get_height ( ) )
completion = " A transistor is a semiconductor device used to amplify or switch electronic signals. It consists of three layers of semiconductor material, either p-type or n-type, forming a structure with terminals called the emitter, base, and collector. "
words = completion . split ( " " )
# The whole completion is laid out once; the loop below slices individual
# words out of this paragraph.
paragraph = get_paragraph ( completion . split ( " " ) , font_size = font_size )
paragraph . next_to ( script_text , DOWN , aligned_edge = LEFT )
paragraph . set_color ( YELLOW )
self . play (
FadeIn ( script_text ) ,
FadeOut ( flat_script ) ,
FadeOut ( VGroup ( in_text , in_arrow , prediction ) ) ,
)
# Repeatedly add predictions
machine . scale ( 1.25 , about_edge = RIGHT )
out_arrow . next_to ( machine , DOWN , buff = 0.5 )
blocks = machine [ 0 ]
# An 11x16 dial grid sized to 95% of the last block's width, rotated, centered
# on that block, and given random values.
dials = Dial ( ) . get_grid ( 11 , 16 )
dials . set_width ( blocks [ - 1 ] . get_width ( ) * 0.95 )
dials . rotate ( 5 * DEGREES , RIGHT ) . rotate ( 10 * DEGREES , UP )
dials . move_to ( blocks [ - 1 ] )
dials . set_stroke ( opacity = 0.5 )
for dial in dials :
dial . set_value ( dial . get_random_value ( ) )
dials . set_z_index ( 2 )
self . add ( dials )
# curr_answer starts empty and collects the glyphs of each accepted word.
curr_answer = VGroup ( )
curr_answer . next_to ( script_text , DOWN )
# Handles the first six words of the completion.
# NOTE(review): indentation is lost in this copy, so the extent of each
# `if n > 2 :` body below cannot be determined here. Confirm which of the
# skip_animations assignments and the wait are conditional before editing.
for n in range ( 6 ) :
word = words [ n ]
prediction = Text ( words [ n ] , font_size = 72 )
prediction . next_to ( out_arrow , DOWN )
# Glyph slice for this word, offset by the number of glyphs already in curr_answer.
word_in_answer = paragraph [ len ( curr_answer ) : len ( curr_answer ) + len ( word ) ]
word_in_answer . set_color ( YELLOW )
# A copy of the prompt plus the answer so far is what moves to the machine.
mover = VGroup ( script_text , curr_answer ) . copy ( )
if n > 2 :
self . skip_animations = True
self . play (
mover . animate . set_height ( 1.8 ) . next_to ( machine , UP , SMALL_BUFF ) . set_anim_args ( path_arc = - 30 * DEGREES ) ,
)
self . animate_text_input (
mover , machine ,
position_text_over_machine = False ,
lag_ratio = 1e-3
)
self . play ( FadeIn ( prediction , DOWN , rate_func = rush_from , run_time = 0.5 ) )
if n > 2 :
self . skip_animations = False
self . wait ( 0.5 )
self . skip_animations = True
# The large prediction morphs into its place in the paragraph while the
# earlier answer text turns white.
self . play (
curr_answer . animate . set_color ( WHITE ) ,
Transform ( prediction , word_in_answer ) ,
FadeOut ( mover ) ,
)
curr_answer . add ( * word_in_answer )
self . add ( curr_answer )
self . remove ( prediction )
# Build and return the prompt Text: the "Human" and "AI assistant" labels are
# colored, the text is set to height 4 and moved to the top edge.
def get_text ( self ) :
script_text = Text ( """
Human :
Can you explain the history of
transistors and how they ' re relevant
to computers ? What is a transistor ,
and how exactly is it used to
perform computations ?
AI assistant :
""" , alignment= " LEFT " )
script_text [ " Human " ] . set_color ( BLUE )
script_text [ " AI assistant " ] . set_color ( TEAL )
script_text . set_height ( 4 ) . to_edge ( UP )
return script_text
# Not called from construct above. Adds a black-filled version of the prompt
# text over a full-screen #FCF5E5 rectangle, plus a black TearOff SVG placed
# below the text with a negative buff (so it overlaps the text's bottom edge).
# Presumably used to render the script textures - unverified.
def create_image ( self ) :
# Create image
script_text = self . get_text ( )
script_text . set_fill ( BLACK )
script_text [ " Human " ] . set_fill ( BLUE_D )
script_text [ " AI assistant " ] . set_fill ( TEAL_D )
self . add ( FullScreenRectangle ( fill_color = " #FCF5E5 " , fill_opacity = 1 ) )
self . add ( script_text )
# Add off test
tear_off = SVGMobject ( ' TearOff ' )
tear_off . set_stroke ( width = 0 )
tear_off . set_fill ( BLACK , 1 )
tear_off . set_width ( 7.5 )
tear_off . next_to ( script_text , DOWN , buff = - 0.2 )
self . add ( tear_off )
class ShowMachineWithDials ( PredictTheNextWord ) :
words = [ ' worst ' , ' age ' , ' worse ' , ' best ' , ' most ' , ' end ' , ' very ' , ' blur ' ]
logprobs = [ 4.0 , 2.15 , 1.89 , 1.4 , 0.1 , - 0.18 , - 0.23 , - 0.61 ]
def construct ( self ) :
# Show machine (same position as in PredictTheNextWord)
frame = self . frame
self . set_floor_plane ( " xz " )
blocks , llm_text , flat_dials , last_dials = self . get_blocks_and_dials ( )
self . clear ( )
self . add ( frame )
frame . reorient ( 0 , 0 , 0 , ( - 0.17 , - 0.12 , 0.0 ) , 4.50 )
self . add ( blocks , llm_text , last_dials )
# Prepare dial highlight
last_dials . target = last_dials . generate_target ( )
self . fix_dials ( last_dials . target )
small_rect = SurroundingRectangle ( last_dials [ 0 ] , buff = 0.025 )
small_rect . set_stroke ( BLUE , 2 )
big_rect = small_rect . copy ( ) . scale ( 4 )
big_rect . next_to ( blocks , UP , buff = SMALL_BUFF , aligned_edge = LEFT + OUT )
big_rect . shift ( 1.5 * RIGHT )
big_dial = last_dials [ 0 ] . copy ( ) . scale ( 4 ) . set_stroke ( opacity = 1 )
big_dial . move_to ( big_rect )
rect_lines = VGroup (
Line ( small_rect . get_corner ( UL ) , big_rect . get_corner ( DL ) ) ,
Line ( small_rect . get_corner ( UR ) , big_rect . get_corner ( DR ) ) ,
)
rect_lines . set_stroke ( WHITE , width = ( 1 , 3 ) )
highlighed_parameter_group = VGroup ( small_rect , rect_lines , big_rect , big_dial )
last_dials . set_stroke ( width = 1 , opacity = 1 )
self . play (
MoveToTarget ( last_dials ) ,
FadeOut ( llm_text ) ,
FadeIn ( small_rect ) ,
)
# Show an example input and output
example = self . get_example ( blocks )
in_text , in_arrow , out_arrow , bar_groups = example
logprobs = example . logprobs
true_probs = 100 * softmax ( logprobs )
bar_groups = self . get_output_distribution ( self . words , 0.1 * logprobs , out_arrow )
self . play (
LaggedStart (
ShowCreation ( rect_lines , lag_ratio = 0 ) ,
TransformFromCopy ( small_rect , big_rect ) ,
TransformFromCopy ( last_dials [ 0 ] , big_dial ) ,
FadeIn ( in_text ) ,
GrowArrow ( in_arrow ) ,
FadeIn ( bar_groups ) ,
GrowArrow ( out_arrow ) ,
) ,
frame . animate . reorient ( 0 , 0 , 0 , ( - 0.43 , 0.38 , 0.0 ) , 7.05 ) ,
run_time = 2
)
self . play (
last_dials [ 0 ] . animate_set_value ( 0.8 ) ,
big_dial . animate_set_value ( 0.8 ) ,
LaggedStart (
( dial . animate_set_value ( dial . get_random_value ( ) )
for dial in last_dials [ 1 : ] ) ,
lag_ratio = 1.0 / len ( last_dials ) ,
) ,
* (
self . bar_group_change_animation ( bg , value )
for bg , value in zip ( bar_groups [ : - 1 ] , true_probs )
) ,
run_time = 3
)
self . wait ( )
# Play around tweaking the parameters, and seeing the output change
self . play (
LaggedStart (
( dial . animate_set_value ( 0 )
for dial in last_dials [ : 12 ] ) ,
lag_ratio = 0.01 ,
) ,
big_dial . animate_set_value ( 0 ) ,
self . bar_group_change_animation ( bar_groups [ 0 ] , 50 ) ,
self . bar_group_change_animation ( bar_groups [ 1 ] , 34 ) ,
self . bar_group_change_animation ( bar_groups [ 2 ] , 5 ) ,
run_time = 4 ,
)
self . play (
LaggedStart (
( dial . animate_set_value ( 1 )
for dial in last_dials [ : 12 ] ) ,
lag_ratio = 0.01 ,
) ,
big_dial . animate_set_value ( 1 ) ,
self . bar_group_change_animation ( bar_groups [ 0 ] , 80 ) ,
self . bar_group_change_animation ( bar_groups [ 1 ] , 5 ) ,
self . bar_group_change_animation ( bar_groups [ 2 ] , 15 ) ,
run_time = 4 ,
)
self . wait ( )
# Mention randomness
random_words = Text ( " Initially random " )
random_words . next_to ( blocks , UP )
random_words . set_color ( RED )
out_dots = Tex ( R " ... " , font_size = 120 )
out_dots . next_to ( out_arrow , RIGHT )
self . play (
FadeOut ( big_rect ) ,
Uncreate ( rect_lines , lag_ratio = 0 ) ,
FadeOut ( small_rect ) ,
Transform ( big_dial , last_dials [ 0 ] )
)
self . play (
Write ( random_words ) ,
LaggedStart (
( dial . animate_set_value ( dial . get_random_value ( ) )
for dial in last_dials ) ,
lag_ratio = 0.5 / len ( last_dials ) ,
run_time = 2
) ,
FadeOut ( bar_groups ) ,
)
self . play ( Write ( out_dots ) )
self . wait ( )
self . play (
FadeOut ( dots ) ,
FadeOut ( random_words ) ,
FadeIn ( bar_groups ) ,
)
# Show many many parameters
example . save_state ( )
blocks . save_state ( )
last_dials . save_state ( )
all_dials = VGroup ( * flat_dials , * last_dials )
all_dials . generate_target ( )
all_dials . target . space_out_submobjects ( 3 )
new_dials = VGroup (
all_dials . target . copy ( ) . shift ( 3 * 2 * x * ( flat_dials . get_center ( ) - last_dials . get_center ( ) ) )
for x in range ( 1 , 9 )
)
self . play (
FadeOut ( example ) ,
FadeOut ( blocks ) ,
FadeIn ( flat_dials ) ,
FadeOut ( bar_groups ) ,
FadeOut ( out_arrow ) ,
)
self . play (
FadeOut ( highlighed_parameter_group ) ,
MoveToTarget ( all_dials ) ,
LaggedStart (
( TransformFromCopy ( all_dials . copy ( ) . set_opacity ( 0 ) , nd )
for nd in new_dials ) ,
lag_ratio = 0.05 ,
) ,
frame . animate . reorient ( - 9 , 0 , 0 , ( - 0.71 , - 0.07 , - 0.06 ) , 9.64 ) ,
run_time = 4
)
self . wait ( )
def get_blocks_and_dials(self):
    """
    Build the transformer drawing centered at the origin and store it on
    self.machine.

    Returns the tuple (blocks, llm_text, flat_dials, last_dials), where
    blocks and llm_text are the first two parts of the drawing and the
    dial groups come from self.get_machine_dials(blocks).
    """
    drawing = self.get_transformer_drawing()
    drawing.move_to(ORIGIN)
    self.machine = drawing

    blocks, llm_text = drawing[0], drawing[1]
    llm_text.set_backstroke(BLACK, 2)

    dial_groups = self.get_machine_dials(blocks)
    flat_dials, last_dials = dial_groups
    return blocks, llm_text, flat_dials, last_dials
def get_example(self, blocks):
    """
    Assemble the example drawn around the machine blocks.

    Returns a VGroup (in_text, in_arrow, out_arrow, bar_groups). The
    numpy array built from self.logprobs is attached to the group as
    its `logprobs` attribute.
    """
    prompt = Text(" It was the best \n of times it was \n the _ ", alignment=" LEFT ")
    # Stretch the last piece of the text by a factor of 4 along x,
    # keeping its left edge fixed
    prompt[-1].stretch(4, 0, about_edge=LEFT)
    prompt.next_to(blocks, LEFT, LARGE_BUFF)
    arrow_in = Arrow(prompt, blocks)

    arrow_out = Vector(RIGHT)
    arrow_out.next_to(blocks[-1], RIGHT, buff=0.1)

    logprobs = np.array(self.logprobs)
    bars = self.get_output_distribution(self.words, logprobs, arrow_out)

    example = VGroup(prompt, arrow_in, arrow_out, bars)
    example.logprobs = logprobs
    return example
def fix_dials(self, dials):
    """
    Restyle every dial in place: the dial gets a stroke of width 1 at
    full opacity, then its needle gets a stroke width of (2, 0).

    Returns the same `dials` object that was passed in.
    """
    for each_dial in dials:
        each_dial.set_stroke(width=1, opacity=1)
        needle = each_dial.needle
        needle.set_stroke(width=(2, 0))
    return dials
def bar_group_change_animation(self, bar_group, new_value):
    """
    Return an AnimationGroup that moves one bar to `new_value`.

    `bar_group` is unpacked as (text, rect, value_mob). The rectangle is
    stretched horizontally about its left edge by the ratio of the new
    value to the current one, the number counts to `new_value`, and on
    every update the number is put back at its initial offset from the
    rectangle's right edge.
    """
    text, rect, value_mob = bar_group
    gap = value_mob.get_left() - rect.get_right()
    ratio = new_value / value_mob.get_value()

    def reattach_value(_mob):
        # The updater is registered on `text`, but it is the number that moves
        return value_mob.move_to(rect.get_right() + gap, LEFT)

    return AnimationGroup(
        rect.animate.stretch(ratio, 0, about_edge=LEFT),
        ChangeDecimalToValue(value_mob, new_value),
        UpdateFromFunc(text, reattach_value),
    )
def get_output_distribution(self, words, logprobs, out_arrow):
    """
    Build the bar groups for `words`, using softmax(logprobs) as the
    probabilities, and place them to the right of `out_arrow`.
    """
    probabilities = softmax(logprobs)
    bars = self.get_distribution(
        words,
        probabilities,
        self.machine,
        width_100p=1.0,
    )
    bars.next_to(out_arrow, RIGHT)
    return bars
class ShowSingleTrainingExample(ShowMachineWithDials):
    """
    Walk through one training example: show the sentence above the
    machine, feed it in, flag the first two output bars in red, then
    animate the dials and the bar values changing.
    """
    logprobs = [4.0, 6.15, 1.89, 1.4, 0.1, -0.18, -0.23, -0.61]

    def construct(self):
        # Add state from before
        frame = self.frame
        self.set_floor_plane(" xz ")
        blocks, llm_text, flat_dials, last_dials = self.get_blocks_and_dials()
        self.fix_dials(last_dials)
        example = self.get_example(blocks)
        in_text, in_arrow, out_arrow, bar_groups = example

        self.add(blocks, last_dials)

        # Show example up top
        parts = (" It was the best of times it was the ", " worst ")
        context_str, answer_str = parts
        sentence = Text(" ".join(parts))
        context = sentence[context_str][0]
        answer = sentence[answer_str][0]
        sentence.set_width(10)
        sentence.next_to(blocks, UP, buff=1.5)

        context_rect = SurroundingRectangle(context)
        context_rect.set_stroke(BLUE, 2)
        context_rect.set_fill(BLUE, 0.2)

        answer_rect = SurroundingRectangle(answer)
        # Match the answer rectangle's height and y-position to the context one
        answer_rect.match_height(context_rect, stretch=True)
        answer_rect.match_y(context_rect)
        answer_rect.set_stroke(YELLOW, 2)
        answer_rect.set_fill(YELLOW, 0.2)

        arc_arrow = Arrow(
            context_rect.get_top(),
            answer_rect.get_top(),
            path_arc=-90 * DEGREES,
            thickness=5,
        )
        arc_arrow.set_fill(border_width=1)

        frame.reorient(0, 0, 0, (-0.36, 0.97, 0.0), 7.52)
        self.play(FadeIn(sentence, UP))
        self.play(
            LaggedStartMap(DrawBorderThenFill, VGroup(context_rect, answer_rect)),
            FadeIn(arc_arrow),
        )

        # last_dials is removed here and re-enters through Animation(last_dials) below
        self.remove(last_dials)
        self.play(LaggedStart(
            AnimationGroup(
                TransformFromCopy(context, in_text[:-1]),
                TransformFromCopy(answer_rect, in_text[-1]),
                FadeIn(in_arrow)
            ),
            LaggedStart(
                (
                    # Every block except the last one flashes TEAL and returns
                    block.animate.set_color(
                        block.get_color() if block is blocks[-1] else TEAL
                    ).set_anim_args(rate_func=there_and_back)
                    for block in blocks
                ),
                group=blocks,
                lag_ratio=0.1,
                run_time=1
            ),
            Animation(last_dials),
            GrowArrow(out_arrow),
            LaggedStartMap(GrowFromPoint, bar_groups, point=out_arrow.get_start()),
            lag_ratio=0.3
        ))
        self.wait()

        # Flag bad prediction
        out_rects = VGroup(
            SurroundingRectangle(bar_group)
            for bar_group in bar_groups[:2]
        )
        out_rects.set_stroke(RED, 3)
        arrow_symbols = [R" \ uparrow ", R" \ downarrow "]
        annotations = VGroup(
            Tex(symbol, font_size=60).next_to(out_rect, LEFT, buff=SMALL_BUFF)
            for out_rect, symbol in zip(out_rects, arrow_symbols)
        )
        annotations.set_color(RED)

        self.play(
            FadeTransform(answer_rect.copy(), out_rects[0]),
            Write(annotations[0]),
        )
        self.wait()
        self.play(
            FadeTransform(*out_rects),
            FadeTransform(*annotations),
        )
        self.wait()
        self.play(
            FadeOut(out_rects[1]),
            FadeOut(annotations[1]),
        )

        # Adjust
        dial_lag = 1.0 / len(last_dials)
        self.play(
            LaggedStart(
                (dial.animate_set_value(dial.get_random_value())
                 for dial in last_dials),
                lag_ratio=dial_lag,
            ),
            LaggedStart(
                (FlashAround(dial, stroke_width=2, color=YELLOW, time_width=1, buff=0.025)
                 for dial in last_dials),
                lag_ratio=dial_lag,
            ),
            self.bar_group_change_animation(bar_groups[0], 70),
            self.bar_group_change_animation(bar_groups[1], 20),
            self.bar_group_change_animation(bar_groups[2], 8),
            run_time=6
        )
class ParameterWeight(InteractiveScene):
    """
    Write the first word of a yellow title centered in x, then restore it
    to its place in the title while the rest of the title fades in.
    """

    def construct(self):
        # Test
        title = Text(" Parameter / Weight ", font_size=72)
        title.to_edge(UP)
        title.set_color(YELLOW)

        first_word = title[" Parameter "][0]
        # Save the in-title position, then start the word at x = 0
        first_word.save_state()
        first_word.set_x(0)

        self.play(Write(first_word))
        self.wait()
        self.play(LaggedStart(
            Restore(first_word),
            FadeIn(title[" / Weight "]),
        ))
        self.wait()
class LargeInLargeLanguageModel(InteractiveScene):
    """
    Show the first word of the title centered in x, flash and color it
    yellow, then restore it while the remaining text is written.
    """

    def construct(self):
        # Test
        title = Text(" Large Language Model ", font_size=72)
        title.to_edge(UP)

        large = title[" Large "][0]
        # Save the in-title state before moving the word to x = 0
        large.save_state()
        large.set_x(0)

        self.add(large)
        self.play(
            FlashUnder(large),
            large.animate.set_color(YELLOW),
        )

        # Everything in the title after the pieces that make up `large`
        remainder = title[len(large):]
        self.play(
            Restore(large, path_arc=-30 * DEGREES),
            Write(remainder, time_span=(0.5, 1.5))
        )
        self.wait()
class ThousandsOfWords(InteractiveScene):
    """
    Display a 5000-character excerpt of a text file as a paragraph
    pinned to the top edge of the frame.
    """

    def construct(self):
        # Find passage
        novel_path = Path(" /Users/grant/3Blue1Brown Dropbox/3Blue1Brown/videos/2024/transformers/data/tale_of_two_cities.txt ")
        novel = novel_path.read_text()
        start_index = novel.index(" It was the best of times ")
        # Looked up but not used below; str.index raises ValueError if absent
        end_index = novel.index(" There were a king with a large jaw ")

        # Add text
        excerpt = novel[start_index:start_index + 5000]
        passage = excerpt.replace(" \n ", " ")
        paragraph = get_paragraph(passage.split(" "), line_len=150)
        paragraph.set_width(14)
        paragraph.to_edge(UP)
        self.add(paragraph)
class EnormousAmountOfTrainingText(PremiseOfMLWithText):
    """
    Tile a 9x9 grid of machine "screens", let the camera frame grow
    exponentially over time, and repeatedly swap training examples on
    the 25 screens nearest the center.
    """

    def construct(self):
        # Setup
        self.init_data()

        # n_rows = n_cols = 41
        n_rows = n_cols = 9
        screens = VGroup()
        for row in range(n_rows):
            for col in range(n_cols):
                tile = self.get_screen()
                offset = FRAME_WIDTH * row * RIGHT + FRAME_HEIGHT * col * DOWN
                tile.move_to(offset)
                screens.add(tile)
        screens.center()

        def distance_from_origin(tile):
            return get_norm(tile.machine.get_center())

        # Order the screens by how far their machine is from the origin
        screens.submobjects.sort(key=distance_from_origin)
        self.add(screens)

        # Add frame growth
        frame = self.frame
        frame.clear_updaters()
        frame.add_updater(lambda m: m.set_height(FRAME_HEIGHT * np.exp(0.2 * self.time)))

        # Show lots of new data
        inner_screens = screens[:25]
        n_examples = 20
        for _ in range(n_examples):
            self.play(LaggedStart(
                *(self.change_example_animation(tile, show_dial_change=True)
                  for tile in inner_screens),
                lag_ratio=0.1,
                run_time=0.5,
            ))

    def change_example_animation(self, screen, show_dial_change=True):
        """
        Return an AnimationGroup that fades out the screen's current
        training example and fades in a new one, optionally together
        with the machine's random_change_animation. The new example is
        recorded on screen.training_example.
        """
        new_example = VGroup(*self.new_input_output_example(*screen.arrows))
        fade_window = (0, 0.35)
        anims = [
            FadeOut(screen.training_example, time_span=fade_window),
            FadeIn(new_example, time_span=fade_window),
        ]
        if show_dial_change:
            anims.append(screen.machine.random_change_animation(run_time=0.5))
        screen.training_example = new_example
        return AnimationGroup(*anims)

    def get_screen(self):
        """
        Build one screen: a full-frame border, a machine with dials, two
        arrows on either side of the machine, and a training example.

        The parts are also exposed as attributes on the returned VGroup
        (border, machine, arrows, training_example).
        """
        border = FullScreenRectangle()
        border.set_fill(opacity=0)
        border.set_stroke(WHITE, 2)

        machine = MachineWithDials(width=3.5, height=2.5, n_rows=5, n_cols=7)
        machine.move_to(1.0 * RIGHT)

        arrows = Vector(RIGHT).replicate(2)
        in_arrow, out_arrow = arrows
        in_arrow.next_to(machine, LEFT)
        out_arrow.next_to(machine, RIGHT)

        training_example = VGroup(
            *self.new_input_output_example(in_arrow, out_arrow)
        )

        screen = VGroup(border, machine, arrows, training_example)
        screen.border = border
        screen.machine = machine
        screen.arrows = arrows
        screen.training_example = training_example
        return screen

    def new_input_output_example(self, in_arrow, out_arrow):
        """
        Same as the parent implementation, with the input scaled to 0.8
        about its right edge and the output scaled to 0.8 about its left.
        """
        pair = super().new_input_output_example(in_arrow, out_arrow)
        in_data, out_data = pair
        in_data.scale(0.8, about_edge=RIGHT)
        out_data.scale(0.8, about_edge=LEFT)
        return in_data, out_data
class BadChatBot(InteractiveScene):
    """
    A bot with a speech bubble of straight lines that are then
    transformed into red scribbles, with the bot blinking in between.
    """

    def construct(self):
        # Add bot
        bot = self.get_bot()
        bot.set_height(3)

        lines = Line(LEFT, RIGHT).get_grid(4, 1, buff=0.25)
        lines.set_stroke(WHITE, 1)
        # Halve the width of the last line, keeping its left end in place
        lines[-1].stretch(0.5, 0, about_edge=LEFT)
        lines.set_width(3)

        bubble = SpeechBubble(lines, buff=MED_LARGE_BUFF)
        bubble.set_stroke(width=5)
        bubble.pin_to(bot).shift(DOWN)

        self.add(bot)
        self.play(Write(bubble, run_time=3))
        self.blink(bot)
        self.wait()

        # Make lines bad
        self.play(
            LaggedStart(
                (Transform(line, self.get_scribble(line))
                 for line in lines),
                lag_ratio=0.1,
                run_time=2
            )
        )
        for _ in range(2):
            self.blink(bot)
            self.wait(2)

    def get_scribble(self, line):
        """
        Return a red graph of a sum of five sines with random frequencies,
        placed between the endpoints of `line` and styled after it.
        """
        freqs = np.random.random(5)

        def wobble(x):
            return 0.05 * sum(math.sin(freq * TAU * x) for freq in freqs)

        scribble = FunctionGraph(wobble, x_range=(0, 5, 0.1))
        scribble.put_start_and_end_on(*line.get_start_and_end())
        scribble.match_style(line)
        scribble.set_stroke(color=RED)
        return scribble

    def get_bot(self):
        """
        Load the bot SVG, rebuild its first path from its first two
        subpaths, and replace the remaining subpaths with dots stored
        on `bot.eyes`.
        """
        bot = SVGMobject(" Bot ")
        body = bot[0]
        subpaths = body.get_subpaths()
        first_path = subpaths[0]
        # Join the first two subpaths, repeating the end point of the first
        body.set_points([*first_path, first_path[-1], *subpaths[1]])

        eyes = VGroup(
            Dot().replace(VMobject().set_points(subpath))
            for subpath in subpaths[2:]
        )
        bot.eyes = eyes
        bot.add(eyes)

        bot.set_stroke(width=0)
        bot.set_height(4)
        bot.set_fill(GREY_B)
        bot.set_shading(0.5, 0.5, 1)
        return bot

    def blink(self, bot):
        """Play one blink: squash the eyes to zero height and back."""
        blink_rate = squish_rate_func(there_and_back)
        self.play(
            bot.eyes.animate.stretch(0, 1).set_anim_args(rate_func=blink_rate)
        )
class WriteRLHF(InteractiveScene):
    """
    Expand the acronym in the step title into the full phrase: the
    initials are transformed from a copy of the acronym, then the
    remaining letters are written.
    """

    def construct(self):
        short_text = Text(" Step 2: RLHF ")
        long_text = Text(" Reinforcement Learning \n with Human Feedback ")
        long_text.next_to(short_text, UP, LARGE_BUFF)
        long_text.align_to(short_text, RIGHT)
        long_text.shift(RIGHT)

        # For each character, take the first glyph of its first match in the phrase
        initials = VGroup(
            long_text[letter[0]][0][0]
            for letter in " RLHF "
        )
        # Pull the initials out so Write only covers the other letters
        long_text.remove(*initials)

        self.add(short_text)
        self.wait()
        self.play(
            TransformFromCopy(short_text[" RLHF "][0], initials, lag_ratio=0.25),
            Write(long_text, time_span=(1.5, 3)),
            run_time=3
        )
        self.wait()
class RLHFWorker(InteractiveScene):
    """
    Static layout: a grey backdrop, a worker icon on the left, and a
    black rectangle with a white outline at the right edge.
    """

    def construct(self):
        # Test
        backdrop = FullScreenRectangle().set_fill(GREY_E, 1)
        self.add(backdrop)

        # worker = SVGMobject("computer_stall")
        worker = SVGMobject(" comp_worker ")
        worker.set_height(4)
        worker.move_to(4 * LEFT)
        worker.set_fill(GREY_C, 1)

        panel = Rectangle(7, 5)
        panel.to_edge(RIGHT)
        panel.set_stroke(WHITE, 2)
        panel.set_fill(BLACK, 1)

        self.add(worker)
        self.add(panel)
# Scene: a 3x2 grid of worker icons on the left and the machine (blocks plus
# last_dials) on the right, whose dials are animated to random values 8 times.
# NOTE(review): indentation is lost in this copy of the file, so it is unclear
# whether the final self.wait() runs inside the loop or once after it -- confirm.
class RLHFWorkers ( ShowMachineWithDials ) :
def construct ( self ) :
# Add workers
# Grey full-screen backdrop behind everything else
self . add ( FullScreenRectangle ( ) . set_fill ( GREY_E , 1 ) )
# One worker SVG replicated into a 3 x 2 grid
workers = SVGMobject ( " comp_worker " ) . get_grid ( 3 , 2 , buff = 0.5 )
workers . set_height ( 7 )
workers . to_edge ( LEFT )
workers . set_fill ( GREY_C , 1 )
self . add ( workers )
# Machine
# llm_text and flat_dials are returned by the helper but not used here
blocks , llm_text , flat_dials , last_dials = self . get_blocks_and_dials ( )
machine = VGroup ( blocks , last_dials )
machine . set_height ( 4 )
machine . center ( ) . to_edge ( RIGHT , buff = LARGE_BUFF )
last_dials . set_stroke ( opacity = 1 )
self . add ( machine )
# Each pass sets every dial in last_dials to a fresh random value,
# staggered so the total lag across all dials is half the run time
for _ in range ( 8 ) :
self . play ( LaggedStart (
( dial . animate_set_value ( dial . get_random_value ( ) )
for dial in last_dials ) ,
lag_ratio = 0.5 / len ( last_dials ) ,
run_time = 2
) )
self . wait ( )
class SerialProcessing(InteractiveScene):
    """
    Show a phrase split into words and, one word at a time, fade in a
    column vector beneath it with context animations drawn from the
    word and from the previously shown vector.
    """
    phrase = " It was the best of times it was the worst of times "
    phrase_center = 2 * UP

    def construct(self):
        # Set up words
        words = self.get_words()
        rects = get_piece_rectangles(words)

        self.add(rects)
        self.add(words)

        # Animate in the vectors
        vectors = VGroup(
            self.get_abstract_vector().next_to(word, DOWN, LARGE_BUFF)
            for word in words
        )

        # Before any vector exists, a point under the first rectangle
        # stands in for the "previous vector"
        previous = VGroup(VectorizedPoint(rects[0].get_bottom()))
        for word, vector in zip(words, vectors):
            self.play(
                FadeIn(vector, run_time=2),
                LaggedStart(
                    (ContextAnimation(
                        square, VGroup(*word, *previous),
                        direction=DOWN,
                        lag_ratio=0.01,
                        path_arc=30 * DEGREES
                    )
                     for square in vector),
                    lag_ratio=0.05,
                    run_time=2
                ),
                previous.animate.set_opacity(0.2)
            )
            previous = vector

    def get_words(self):
        """Return self.phrase broken into words, moved to self.phrase_center."""
        word_mobs = break_into_words(Text(self.phrase))
        word_mobs.move_to(self.phrase_center)
        return word_mobs

    def get_abstract_vector(self, values=None, default_length=10, elem_size=0.2):
        """
        Return a single column of squares, one per value, each filled
        with value_to_color(value, min_value=0, max_value=1).

        When `values` is None, `default_length` values are drawn
        uniformly from [-1, 1). The column is `elem_size` wide.
        """
        if values is None:
            values = np.random.uniform(-1, 1, default_length)

        column = Square().get_grid(len(values), 1, buff=0)
        column.set_width(elem_size)
        column.set_stroke(WHITE, 1)
        for cell, value in zip(column, values):
            fill_color = value_to_color(value, min_value=0, max_value=1)
            cell.set_fill(fill_color, opacity=1)
        return column
class ParallelProcessing(SerialProcessing):
    """
    Variant of SerialProcessing in which all vectors appear at once,
    with a line drawn from every word rectangle to every vector.
    """

    def construct(self):
        # Set up words
        words = self.get_words()
        rects = get_piece_rectangles(words)

        self.add(rects)
        self.add(words)

        # Animate in the vectors
        vectors = VGroup(
            self.get_abstract_vector().next_to(word, DOWN, buff=1.5)
            for word in words
        )
        # One line per (rectangle, vector) pair, each with a random stroke width
        lines = VGroup(
            Line(
                rect.get_bottom(), vector.get_top(),
                buff=0.05,
                stroke_color=WHITE,
                stroke_width=2 * random.random()**3
            )
            for rect in rects
            for vector in vectors
        )
        lines.shuffle()

        # Save each vector's final state, then collapse its squares
        # invisibly onto the word so Restore brings them back out
        for vector, word in zip(vectors, words):
            vector.save_state()
            for cell in vector:
                cell.move_to(word)
                cell.set_opacity(0)

        self.play(
            LaggedStartMap(ShowCreation, lines, lag_ratio=0.01),
            LaggedStartMap(Restore, vectors, lag_ratio=0)
        )
        self.play(lines.animate.set_stroke(opacity=0.25))
        self.wait()
class ManyComputationsPerUnitTimeV2(InteractiveScene):
    """
    Put a box of flickering arithmetic lines into the first one-second
    tick of a minute timeline, then pan down through a stack of
    progressively longer timelines, each shown shrinking into one tick
    interval of the next.
    """

    def construct(self):
        # Add computations
        box = Rectangle(5, 5)
        label = Text(" 1 Billion computations per Second ")
        label.next_to(box, UP)

        self.add(box)
        self.add(label)

        comps = self.get_computations(box)
        self.add(comps)
        self.wait(3)

        # Place box into minute interval
        width = FRAME_WIDTH - 1
        number_lines = VGroup(
            minute_line := NumberLine((0, 60, 1), width=width, big_tick_spacing=10),
            hour_line := NumberLine((0, 60, 1), width=width, big_tick_spacing=10),
            day_line := NumberLine((0, 24, 1), width=width, big_tick_spacing=6),
            month_line := NumberLine((0, 31, 1), width=width),
            year_line := NumberLine((0, 12, 1), width=width),
            y100_line := NumberLine((0, 100, 1), width=width),
            y10k_line := NumberLine((0, 100, 1), width=width),
            y1M_line := NumberLine((0, 100, 1), width=width),
            y100M_line := NumberLine((0, 100, 1), width=width),
        )
        number_lines.move_to(DOWN)

        first_ticks = minute_line.ticks[:2]
        sec_brace = Brace(first_ticks, DOWN, buff=0, tex_string=R" \ underbrace { \ qquad \ qquad} ")
        sec_label = Text(" Second ", font_size=30).next_to(sec_brace, DOWN, SMALL_BUFF)

        self.play(
            ShowCreation(minute_line, lag_ratio=0.01),
            box.animate.match_width(first_ticks).move_to(first_ticks.get_center(), DOWN).set_stroke(width=1),
            TransformFromCopy(label[" Second "][0], sec_label),
            GrowFromCenter(sec_brace),
            run_time=2
        )

        # Add other boxes
        minute_label = self.get_timeline_full_label(number_lines[1], " Minute ")
        new_boxes = VGroup(
            box.copy().move_to(tick.get_center(), DL)
            for tick in minute_line.ticks[1:-1]
        )
        # Each copy remembers its slot on the timeline, then starts on top
        # of the original box so Restore spreads the copies out
        for new_box in new_boxes:
            new_box.save_state()
            new_box.move_to(box)

        computations = VGroup(
            self.get_computations(new_box, n_iterations=1)
            for new_box in new_boxes
        )
        # computations = VGroup() # If needed

        self.add(computations)
        self.play(
            FadeIn(minute_label, DOWN),
            LaggedStartMap(Restore, new_boxes, lag_ratio=0.1),
            run_time=2
        )
        self.wait(2)

        # Add labels
        minute_line.add(minute_label)
        names = [" Hour ", " Day ", " Month ", " Year ", " 100 Years ", " 10,000 Years ", " 1,000,000 Years ", " 100,000,000 Years "]
        for line, name in zip(number_lines[1:], names):
            line.label = self.get_timeline_full_label(line, name)
            line.add(line.label)

        # Arrange all lines
        number_lines[1:].arrange(DOWN, buff=2.0)
        number_lines[1:].next_to(minute_line, DOWN, buff=2.0)

        scale_lines = VGroup()
        for nl1, nl2 in zip(number_lines, number_lines[1:]):
            # The interval between the two middle ticks of the next line
            # is where the current line gets shrunk to
            n = len(nl2.ticks) // 2
            mini_line = Line(nl2.ticks[n - 1].get_center(), nl2.ticks[n].get_center())
            pair = VGroup(
                DashedLine(nl1.get_start(), mini_line.get_start()),
                DashedLine(nl1.get_end(), mini_line.get_end()),
            )
            pair.set_stroke(WHITE, 2)
            nl1.target = nl1.copy()
            nl1.target.replace(mini_line, dim_to_match=0)
            nl1.target.shift(mini_line.pfp(0.5) - nl1.target.pfp(0.5))
            scale_lines.add(pair)

        # Start panning down
        lag_ratio = 1.5
        self.play(
            LaggedStart(
                *(AnimationGroup(*(ShowCreation(sl) for sl in pair)) for pair in scale_lines),
                lag_ratio=lag_ratio,
            ),
            LaggedStart(
                *(FadeIn(nl) for nl in number_lines[1:]),
                lag_ratio=lag_ratio,
            ),
            LaggedStart(
                *(TransformFromCopy(nl, nl.target) for nl in number_lines[:-1]),
                lag_ratio=lag_ratio,
            ),
            self.frame.animate.set_y(number_lines[-1].get_y() + 2).set_width(18).set_anim_args(
                rate_func=lambda t: interpolate(smooth(t), linear(t), there_and_back_with_pause(t, pause_ratio=0.8))
            ),
            run_time=30
        )
        self.play(self.frame.animate.reorient(0, 0, 0, (-0.03, -11.55, 0.0), 31.76), run_time=4)
        self.wait(4)

    def fade_in_bigger_interval(self, new_interval, prev_interval, fader, scale_factor, added_anims=None):
        """
        Bring in `new_interval` by restoring it from a state scaled up by
        `scale_factor` about the zero point of `prev_interval`, while
        `prev_interval` and `fader` shrink by the same factor and `fader`
        goes transparent. `fader` is removed from the scene afterwards.

        `added_anims` is an optional iterable of extra animations played
        alongside; None (the default) means no extra animations.
        """
        # A None default avoids sharing one list object between calls
        if added_anims is None:
            added_anims = []

        pivot = prev_interval.n2p(0)
        new_interval.save_state()
        new_interval.scale(scale_factor, about_point=pivot)
        new_interval[:-1].set_opacity(0)
        new_interval[-1].set_fill(BLACK)
        self.play(
            Restore(new_interval),
            prev_interval.animate.scale(1.0 / scale_factor, about_point=pivot).set_fill(border_width=0),
            fader.animate.scale(1.0 / scale_factor, about_point=pivot).set_opacity(0),
            *added_anims,
            run_time=4,
            rate_func=rush_from
        )
        self.remove(fader)

    def get_timeline_full_label(self, timeline, name):
        """
        Return a Text label for `name` (font size 72) placed directly
        below `timeline`.
        """
        # The original also built a brace above the timeline, but it was
        # never returned and the label's final position does not use it.
        label = Text(name, font_size=72)
        label.next_to(timeline, DOWN)
        return label

    def get_computations(self, box, n_lines=10, n_iterations=3, n_digits=4, cycle_time=0.5):
        """
        Return a VGroup of random arithmetic lines ("x op y = result")
        that tracks `box`.

        `n_iterations` clusters of `n_lines` lines are created, each
        cluster fitted inside and centered on the box. An updater fades
        lines in and out cyclically based on self.time (one full cycle
        lasts cycle_time * n_iterations) and keeps the group sized to
        0.9 of the box height and centered on the box.
        """
        # Try adding lines
        lines = VGroup()
        for _ in range(n_iterations):
            cluster = VGroup()
            for _ in range(n_lines):
                x = random.uniform(0, 10**(n_digits))
                y = random.uniform(0, 10**(n_digits))
                if random.choice([True, False]):
                    comb = x * y
                    sym = Tex(R" \ times ")
                else:
                    comb = x + y
                    sym = Tex(R" + ")
                line = VGroup(
                    DecimalNumber(x, num_decimal_places=3), sym,
                    DecimalNumber(y, num_decimal_places=3), Tex(" = "),
                    DecimalNumber(comb, num_decimal_places=3)
                )
                line.arrange(RIGHT, buff=SMALL_BUFF)
                lines.add(line)
                cluster.add(line)

            cluster.arrange(DOWN, buff=MED_LARGE_BUFF, aligned_edge=LEFT)
            cluster.set_max_height(0.9 * box.get_height())
            cluster.set_max_width(0.9 * box.get_width())
            cluster.move_to(box)

        # Add updater
        def update_lines(lines):
            sigma = 0.12
            alpha = (self.time / (cycle_time * n_iterations)) % 1
            step = 1.0 / len(lines)
            for n, line in enumerate(lines):
                # Cyclic distance between alpha and this line's slot in [0, 1)
                x = min((
                    abs(a - n * step)
                    for a in (alpha - 1, alpha, alpha + 1)
                ))
                # Gaussian falloff: lines near alpha are opaque, others fade
                y = np.exp(-x**2 / sigma**2)
                line.set_fill(opacity=y)
            lines.set_height(0.9 * box.get_height())
            lines.move_to(box)

        lines.clear_updaters()
        lines.add_updater(update_lines)

        return lines

    def old(self):
        """
        Earlier version of the scaling sequence, kept for reference.

        As written it cannot run: it uses names that are not defined in
        this method (sec_brace, sec_label, box, comps, new_boxes,
        computations, number_lines, label, y1M_line, millenium_line,
        year_line), and it unpacks two values from
        get_timeline_full_label, which returns a single label.
        """
        # Repeatedly scale down
        to_fade = VGroup(sec_brace, sec_label, box, comps, new_boxes, computations)
        scale_factors = [60, 24, 365, 1000]
        for new_int, prev_int, scale_factor in zip(number_lines[1:], number_lines[0:], scale_factors):
            self.fade_in_bigger_interval(
                new_int, prev_int, to_fade, scale_factor,
                added_anims=[label.animate.set_opacity(0)],
            )
            self.wait(2)
            to_fade = prev_int

        # Multiply last line by 100
        self.fade_in_bigger_interval(
            y1M_line, millenium_line, year_line, 1000,
            added_anims=[self.frame.animate.reorient(0, 0, 0, (-3.51, -5.18, 0.0), 12.93)],
        )

        lines = Line(LEFT, RIGHT).replicate(100)
        lines.match_width(y1M_line)
        lines.arrange_to_fit_height(10)
        lines.sort(lambda p: -p[1])
        lines.set_stroke(WHITE, 1)
        lines.move_to(y1M_line[0].get_center(), UP)

        side_brace, label100M = self.get_timeline_full_label(y1M_line, " 100,000,000 Years ")
        side_brace.rotate(PI / 2)
        side_brace.match_height(lines)
        side_brace.next_to(lines, LEFT)
        label100M.next_to(side_brace, LEFT)

        self.play(
            LaggedStart(
                (TransformFromCopy(lines[0].copy().set_opacity(0), line)
                 for line in lines),
                lag_ratio=0.03,
                run_time=2
            ),
            FadeIn(side_brace, scale=10, shift=2 * DOWN, time_span=(1, 2)),
            FadeIn(label100M, time_span=(1, 2)),
        )
        self.wait()
class VectorLabel(InteractiveScene):
    """A centered brace with a large text label on its left, flashed in yellow."""

    def construct(self):
        # Test
        brace = Brace(Line(4 * UP, ORIGIN), LEFT)
        brace.center()
        brace.set_stroke(WHITE, 3)

        label = Text(" Vector ", font_size=90)
        label.next_to(brace, LEFT, MED_SMALL_BUFF)
        label.shift(SMALL_BUFF * UP)

        self.play(
            GrowFromCenter(brace),
            Write(label)
        )
        self.play(FlashUnder(label, color=YELLOW))
        self.wait()
class ParameterToVectorAnnotation(InteractiveScene):
    """
    A column of ten dials, each with a short arrow to its right, whose
    values are then animated to a fixed list of numbers.
    """

    def construct(self):
        # Test
        dials = VGroup(Dial(value_range=(-10, 10, 1)) for _ in range(10))
        dials.arrange(DOWN)
        dials.set_height(5)

        # One target value per dial, in order
        values = [1, 4.3, 2, 0.9, -1.5, 2.9, -1.2, 7.8, 0, -2.3]

        arrows = VGroup(
            Vector(0.5 * RIGHT, thickness=2).next_to(dial, RIGHT, buff=SMALL_BUFF)
            for dial in dials
        )

        self.play(
            Write(dials, lag_ratio=0.01),
            LaggedStartMap(GrowArrow, arrows),
        )
        self.play(LaggedStart(
            (dial.animate_set_value(target)
             for dial, target in zip(dials, values)),
            lag_ratio=0.05,
        ))
        self.wait()
class ThreeWordsToOne(InteractiveScene):
    """
    Three colored words are first merged into one boxed phrase with a
    photo underneath, then restored to separate pieces, each with its
    own rectangle and icon.
    """

    def construct(self):
        # Test
        image = ImageMobject(" CHMTopText ")
        image.set_height(FRAME_HEIGHT)
        # self.add(image)

        phrase = Text(" Computer History Museum ", font_size=61)
        words = VGroup(
            phrase[word][0]
            for word in phrase.get_text().split(" ")
        )
        words.move_to([0, 2.627, 0])

        # Copy taken before the outer words are nudged apart
        og_words = words.copy()
        og_words.shift(DOWN)

        words[0].shift(0.13 * LEFT)
        words[2].shift(0.4 * RIGHT)
        colors = [" #63DCF7 ", " #90C9FA ", " #85D4FE "]
        for word, color in zip(words, colors):
            word.set_color(color)
        words.save_state()

        self.add(words)
        self.wait()

        # Back to unity
        rect = SurroundingRectangle(og_words)
        rect.set_color(RED)
        chm_image = ImageMobject(" /Users/grant/3Blue1Brown Dropbox/3Blue1Brown/videos/2024/transformers/chm/images/CHM_Exterior.jpeg ")
        chm_image.match_width(rect)
        chm_image.next_to(rect, DOWN)

        self.play(Transform(words, og_words))
        self.play(
            ShowCreation(rect),
            FadeIn(chm_image, DOWN)
        )
        self.wait()

        # Three pieces
        # Rectangles and icons are laid out against the saved (separated)
        # state of the words, which Restore returns to below
        saved_words = words.saved_state
        rects = VGroup(
            SurroundingRectangle(word).set_fill(color, 0.2).set_stroke(color, 2)
            for word, color in zip(saved_words, colors)
        )
        words.set_z_index(1)

        icons = VGroup(
            SVGMobject(" GenericComputer.svg "),
            SVGMobject(" History.svg "),
            SVGMobject(" Museum.svg "),
        )
        for word, icon in zip(words.saved_state, icons):
            icon.set_fill(word.get_color(), 1, border_width=1)
            icon.set_height(1)
            icon.next_to(word, DOWN)

        self.remove(chm_image)
        self.play(
            ReplacementTransform(VGroup(rect), rects),
            Restore(words),
            *(
                FadeTransform(chm_image.copy(), icon)
                for icon in icons
            )
        )
        self.wait()
class ExamplePhraseHeader(InteractiveScene):
    """
    Static header: the example phrase at the top edge with a yellow
    rectangle drawn behind its question marks.
    """

    def construct(self):
        # Test
        phrase = Text(" The Computer History Museum \n is located in ????? ")
        phrase.to_edge(UP)
        # Built but not added to the scene
        rect = SurroundingRectangle(phrase).set_stroke(WHITE, 2)

        q_marks = phrase[" ????? "][0]
        # Make every fourth piece of the selection transparent
        q_marks[::4].set_fill(opacity=0)

        q_rect = SurroundingRectangle(q_marks)
        q_rect.set_fill(YELLOW, 0.25)
        q_rect.set_stroke(YELLOW, 2)

        # The rectangle is added first so the text is drawn on top of it
        self.add(q_rect)
        self.add(phrase)
class TrainingDataCHM(InteractiveScene):
    """
    List example training passages, highlight two recurring phrases in
    them, then draw arrows toward a point and fade copies of the
    highlighted letters into that point.
    """

    def construct(self):
        # Test
        passages = [
            " The Computer History Museum (CHM) is a museum ... located in Mountain View... ",
            " Computer History Museum ... 1401 N. Shoreline Blvd. Mountain View... ",
            " Things to do in Mountain View ... the Computer History Museum ... ",
            " While I was in Mountain View ... stopped by the Computer History Museum ... ",
        ]
        items = VGroup(
            get_paragraph(passage.split(" "), line_len=35, font_size=30)
            for passage in passages
        )
        items.arrange(DOWN, buff=LARGE_BUFF, aligned_edge=LEFT)
        items.to_corner(DL)
        items.shift(0.5 * UP)

        dots = Tex(R" \ vdots ")
        dots.next_to(items, DOWN, MED_LARGE_BUFF)
        dots.shift_onto_screen(buff=MED_SMALL_BUFF)
        # From here on the dots are the last element of `items`
        items.add(dots)

        title = Text(" Training Data ")
        title.next_to(items, UP, buff=LARGE_BUFF)
        title.shift_onto_screen(buff=MED_SMALL_BUFF)
        underline = Underline(title)

        chm_phrases = VGroup(item[" Computer History Museum "] for item in items)
        mv_phrases = VGroup(item[" Mountain View "] for item in items)

        self.play(
            FadeIn(title),
            ShowCreation(underline),
            LaggedStartMap(FadeIn, items, shift=DOWN, lag_ratio=0.15)
        )
        self.wait()
        self.play(chm_phrases.animate.set_color(RED).set_anim_args(lag_ratio=0.1))
        self.wait()
        self.play(mv_phrases.animate.set_color(PINK).set_anim_args(lag_ratio=0.1))
        self.wait()

        # Arrows to ffn
        ffn_point = 3 * RIGHT + DOWN
        arc_angles = range(-40, 40, 20)
        # One arrow per passage (the trailing dots are skipped), each
        # covering 60% of the way from the passage to ffn_point
        arrows = VGroup(
            Arrow(
                item.get_right(),
                interpolate(item.get_right(), ffn_point, 0.6),
                path_arc=arc * DEGREES,
            )
            for item, arc in zip(items[:-1], arc_angles)
        )
        arrows.set_fill(border_width=1)

        self.play(Write(arrows, lag_ratio=0.1), run_time=3)
        self.play(
            LaggedStart(
                *(
                    FadeOutToPoint(letter.copy(), ffn_point)
                    for letter in VGroup(chm_phrases, mv_phrases).family_members_with_points()
                ),
                lag_ratio=1e-2,
                run_time=3
            )
        )
        self.wait()
class DivyUpParameters(ShowMachineWithDials):
    """
    Show the machine, lift copies of its first three blocks to the top
    of the frame with labels, then set every dial to a random value.
    """

    def construct(self):
        # Show machine
        frame = self.frame
        self.set_floor_plane(" xz ")
        machine = VGroup(*self.get_blocks_and_dials())
        blocks, llm_text, flat_dials, last_dials = machine
        machine.set_height(3.0)
        machine.to_edge(DOWN, buff=LARGE_BUFF)

        # Unfilled white outlines of the blocks, used only for the flash below
        block_outlines = blocks.copy()
        block_outlines.set_fill(opacity=0)
        block_outlines.set_stroke(WHITE, 2)
        block_outlines.insert_n_curves(20)

        # last_dials.set_submobjects(last_dials[:3]) # Remove
        last_dials.set_stroke(opacity=1)
        for dial in last_dials:
            dial[0].set_stroke(width=1)
            dial[1].set_stroke(width=1)
            dial[3].set_stroke(width=(3, 0))

        frame.reorient(-23, -13, 0, (-0.41, -1.71, -0.06), 4.95)
        self.play(
            FadeIn(blocks, shift=0.0, lag_ratio=0.01),
            LaggedStartMap(VShowPassingFlash, block_outlines.family_members_with_points(), time_width=2.0, lag_ratio=0.01, remover=True),
            LaggedStartMap(VFadeInThenOut, flat_dials, lag_ratio=0.001, remover=True),
            FadeIn(last_dials, time_span=(2, 3)),
            self.frame.animate.reorient(10, -2, 0, (-0.25, -1.58, -0.02), 4.61),
            run_time=3,
        )
        self.remove(flat_dials)

        # Show individual blocks
        top_blocks = blocks[:3].copy()
        all_dials = VGroup(*last_dials)
        for block in top_blocks:
            block_dials = last_dials.copy()
            block_dials.rotate(self.machine_phi, RIGHT)
            block_dials.rotate(self.machine_theta, UP)
            block_dials.move_to(block)
            block_dials.set_stroke(opacity=1)
            block.add(block_dials)
            # The target is generated while the dials are still visible;
            # only afterwards are the dials on the source block hidden
            block.target = block.generate_target()
            block_dials.set_opacity(0)
            all_dials.add(*block_dials)

        block_targets = Group(block.target for block in top_blocks)
        # Undo the rotations applied above
        block_targets.rotate(-self.machine_theta, UP)
        block_targets.rotate(-self.machine_phi, RIGHT)
        block_targets.set_height(2)
        block_targets.arrange(RIGHT, buff=1.5)
        block_targets.to_edge(UP)
        block_targets.set_shading(0.1, 0.1, 0.1)

        labels = VGroup(
            TexText(R" Word $ \ to$ Vector "),
            Text(" Attention "),
            Text(" Feedforward "),
        )
        for label, target in zip(labels, block_targets):
            label.next_to(target, DOWN)

        # Interleave each machine block with its copy
        self.add(
            blocks[0], top_blocks[0],
            blocks[1], top_blocks[1],
            blocks[2], top_blocks[2],
            blocks[3:], last_dials
        )
        self.play(
            MoveToTarget(top_blocks[1], time_span=(0, 2)),
            MoveToTarget(top_blocks[2], time_span=(1, 3)),
            MoveToTarget(top_blocks[0], time_span=(2, 4)),
            Write(labels[1], time_span=(1.5, 2)),
            Write(labels[2], time_span=(2.5, 3)),
            Write(labels[0], time_span=(3.5, 4)),
            frame.animate.to_default_state(),
            run_time=4
        )
        self.wait()

        # Change all the parameters
        self.play(
            LaggedStart(
                (dial.animate_set_value(dial.get_random_value())
                 for dial in all_dials),
                lag_ratio=1 / len(all_dials),
                run_time=6
            ),
            LaggedStart(
                (FlashAround(dial, buff=0, color=YELLOW)
                 for dial in all_dials),
                lag_ratio=1 / len(all_dials),
                run_time=6
            ),
        )
        self.wait()
2024-11-15 12:10:17 -08:00
# End clips
class ShowPreviousVideos(InteractiveScene):
    """
    Fade in a grid of video thumbnails under a series title, then move
    the title and a two-column grid to the left half of the frame while
    a vertical line is drawn.
    """

    def construct(self):
        # Backdrop
        background = FullScreenRectangle()
        self.add(background)

        divider = Line(UP, DOWN).set_height(FRAME_HEIGHT)
        divider.set_stroke(WHITE, 2)

        series_name = Text(" Deep Learning Series ", font_size=68)
        series_name.to_edge(UP, buff=0.35)
        self.add(series_name)

        # Show thumbnails
        slugs = [
            " aircAruvnKk ",
            " IHZwWFHWa-w ",
            " Ilg3gGewQ5U ",
            " tIeHLnjs5U8 ",
            " wjZofJX0v4M ",
            " eMlx5fFNoYc ",
            " 9-Jl0dxWQs8 ",
        ]
        # Each thumbnail is a 16:9 outline paired with the image at the same height
        thumbnails = Group(
            Group(
                Rectangle(16, 9).set_height(1).set_stroke(WHITE, 2),
                ImageMobject(f" https://img.youtube.com/vi/ { slug } /maxresdefault.jpg ", height=1)
            )
            for slug in slugs
        )
        thumbnails.arrange_in_grid(n_cols=4, buff=0.2)
        thumbnails.set_width(FRAME_WIDTH - 1)
        thumbnails.next_to(series_name, DOWN, buff=1.0)
        # Center the last three thumbnails horizontally
        thumbnails[-3:].set_x(0)

        self.play(LaggedStartMap(FadeIn, thumbnails, shift=0.3 * UP, lag_ratio=0.35, run_time=4))
        self.wait()

        # Rearrange
        left_x = -FRAME_WIDTH / 4
        self.play(
            series_name.animate.set_x(left_x),
            thumbnails.animate.arrange_in_grid(n_cols=2, buff=0.25).set_height(6).set_x(left_x).to_edge(DOWN),
            ShowCreation(divider, time_span=(1, 2)),
            run_time=2,
        )
        self.wait()
class EndScreen ( PatreonEndScreen ) :
title_text = " Where to dig deeper "
thanks_words = """
Special thanks to these Patreon supporters
"""