# 3b1b-videos/_2024/transformers/ml_basics.py
# 2024-10-13 21:05:24 -05:00
#
# 2430 lines
# 81 KiB
# Python

from manim_imports_ext import *
from _2024.transformers.helpers import *
from _2024.transformers.embedding import *
from _2024.transformers.generation import *
class DialTest(InteractiveScene):
    """Scratch scene that exercises a single Dial and a MachineWithDials."""

    def construct(self):
        # Single dial: add it, then animate its value to 0.5 over one second
        lone_dial = Dial(radius=0.5)
        self.add(lone_dial)
        turn_dial = lone_dial.animate_set_value(0.5, run_time=1)
        self.play(turn_dial)

        # Whole machine: add it, then play its random-change animation
        dial_machine = MachineWithDials()
        self.add(dial_machine)
        shuffle_dials = dial_machine.random_change_animation()
        self.play(shuffle_dials)
class MLWithinDeepL(InteractiveScene):
    """Scene nesting "Transformers" inside "Deep Learning" inside "Machine Learning".

    The construct method builds titled boxes for several model families,
    wraps them in a Deep Learning box, places that inside a Machine Learning
    box next to two illustrative drawings, and finally shows a neural network
    that is replaced by a stack ("pile") of weight matrices.
    """

    def construct(self):
        # Organize boxes
        kw = dict(font_size=36, opacity=0.25)
        model_boxes = VGroup(
            self.get_titled_box("Multilayer Perceptrons", BLUE_D, **kw),
            self.get_titled_box("Convolutional Neural Networks", BLUE_D, **kw),
            self.get_titled_box("Transformers", BLUE, **kw),
        )
        # Stretch each rectangle to the current width of the whole group
        for box in model_boxes:
            box.box.set_width(model_boxes.get_width(), stretch=True)
        dots = Tex(R"\vdots", font_size=72)
        model_boxes.add(dots)
        model_boxes.arrange(DOWN, buff=0.1)
        dots.shift(0.2 * DOWN)
        transformer_box = model_boxes[2]

        dl_box = self.get_titled_box(
            "Deep Learning", TEAL,
            font_size=60,
            y_space=model_boxes.get_height() + 1.0,
            x_space=2.75,
            opacity=0.05
        )
        model_boxes.next_to(dl_box.title, DOWN)

        # Animate in word
        transformer_box.save_state()
        transformer_box.box.set_opacity(0)
        transformer_box.set_height(1)
        transformer_box.move_to(np.array([-1.58, -2.01, 0]))

        self.add(transformer_box)
        self.wait()
        self.add(dl_box, transformer_box)
        # NOTE(review): lag_ratio and run_time are passed to self.play here,
        # not to LaggedStart (its call closes before them) - confirm intended.
        self.play(LaggedStart(
            FadeIn(dl_box, scale=1.2),
            Restore(transformer_box),
            *(FadeIn(model_boxes[i]) for i in [0, 1, 3]),
        ), lag_ratio=0.75, run_time=2)
        self.wait()

        dl_box.add(model_boxes)
        self.add(dl_box)

        # Place within ML box
        ml_box = self.get_titled_box(
            "Machine Learning",
            GREEN,
            opacity=0.1,
            font_size=72,
            x_space=6.0,
            y_space=5.0
        )
        dl_box.target = dl_box.generate_target()
        blank_boxes = dl_box.box.replicate(2)
        inner_boxes = VGroup(*blank_boxes, dl_box.target)
        reg_drawing = self.get_regression_drawing()
        bayes_net = self.get_bayes_net_drawing()
        # Fit one drawing into each of the two blank boxes
        for drawing, box in zip([reg_drawing, bayes_net], blank_boxes):
            drawing.set_height(0.8 * box.get_height())
            drawing.move_to(box)
            box.add(drawing)
        inner_boxes.set_height(3.5)
        inner_boxes.arrange(RIGHT)
        inner_boxes.set_max_width(ml_box.get_width() - 0.5)
        inner_boxes.next_to(ml_box.title, DOWN, buff=1.0)

        self.add(ml_box, dl_box, blank_boxes)
        self.play(
            FadeIn(ml_box),
            MoveToTarget(dl_box),
            LaggedStartMap(FadeIn, blank_boxes, scale=2.0, lag_ratio=0.5)
        )
        self.wait()

        ml_box.add(dl_box, blank_boxes)

        # Learn from data
        words = Text("Learn from data", font_size=72)
        words.to_edge(UP, buff=MED_SMALL_BUFF)
        learn = words["Learn"][0]
        learn.save_state()
        learn.set_x(0)
        words["data"].set_color(YELLOW)

        ml_box.target = ml_box.generate_target()
        ml_box.target.scale(0.75)
        ml_box.target.to_edge(DOWN)
        arrow = Arrow(ml_box.target, words)

        self.play(
            MoveToTarget(ml_box),
            GrowFromCenter(arrow),
            TransformFromCopy(ml_box.title["Learn"][0], learn),
        )
        self.play(
            Restore(learn),
            FadeIn(words["from data"][0], lag_ratio=0.1, shift=0.2 * RIGHT),
        )
        self.wait()
        self.play(
            FadeOut(ml_box),
            FadeOut(arrow),
        )
        self.wait()

        # Go back to the box
        self.clear()
        ml_box.center()
        self.add(ml_box)

        # Pop out
        # Leave a copy of the DL box inside the ML box so the original can move
        ml_box.remove(dl_box)
        ml_box.add(dl_box.copy())
        ml_box.target = ml_box.generate_target()
        ml_box.target.scale(0.25).to_edge(LEFT)
        dl_box.target = dl_box.generate_target()
        dl_box.target.scale(2.0)
        # (The original line ended in a stray comma that built and discarded
        # a one-element tuple; the call itself is unchanged.)
        dl_box.target.next_to(ml_box.target, RIGHT, buff=0.75)

        # Lines joining the small in-place copy to the enlarged DL box
        lines = VGroup(*(
            Line(
                ml_box.target[-1].get_corner(RIGHT + v),
                dl_box.target.get_corner(LEFT + v)
            )
            for v in [UP, DOWN]
        ))
        lines.set_stroke(TEAL, 2)

        self.play(
            MoveToTarget(ml_box),
            MoveToTarget(dl_box),
            GrowFromPoint(lines[0], dl_box.get_corner(UR)),
            GrowFromPoint(lines[1], dl_box.get_corner(DR)),
            run_time=1.5,
        )
        self.wait()

        # Show a neural network
        network = NeuralNetwork([5, 10, 5])
        network.next_to(dl_box, RIGHT, buff=1.0)
        self.play(
            FadeIn(network.layers[0]),
            ShowCreation(network.lines[0], lag_ratio=0.01),
            FadeIn(network.layers[1], lag_ratio=0.5),
            run_time=2
        )
        self.play(
            ShowCreation(network.lines[1], lag_ratio=0.01),
            FadeIn(network.layers[2], lag_ratio=0.5),
            run_time=2
        )

        # Ambiently change the network
        for _ in range(6):
            self.play(
                network.animate.randomize_line_style().randomize_layer_values(),
                run_time=3,
                lag_ratio=1e-4
            )

        # Pile of matrices
        pile_words = Text("Pile of matrices")
        pile_words.next_to(network, UP)
        path_arc = -60 * DEGREES
        arrow = Arrow(dl_box.get_top(), pile_words.get_corner(UL), path_arc=path_arc)

        matrices = VGroup(*(
            WeightMatrix(shape=(8, 6), ellipses_row=None, ellipses_col=None)
            for _ in range(10)
        ))
        matrices.match_width(network)
        matrices.move_to(network, UP)
        matrices.shift(0.5 * DOWN)
        matrices.arrange(OUT, buff=0.25)
        matrices.move_to(network)
        # Every matrix except the last gets a target whose entries are
        # replaced by faint dots; the last one keeps its numbers.
        for matrix in matrices[:-1]:
            matrix.target = matrix.generate_target()
            for entry in matrix.target.get_entries():
                dot = Dot(radius=0.05)
                dot.set_fill(entry.get_fill_color(), opacity=0.25)
                dot.move_to(entry)
                entry.become(dot)
            matrix.target[-1].set_opacity(0.25)
        matrices[-1].get_entries().set_backstroke(BLACK, 8)

        # The original listed FadeOut(network, 2 * DOWN) and
        # FadeOut(network, DOWN) in the same play call; only the latter is
        # kept so the network is driven by a single animation.
        self.play(
            ShowCreation(arrow),
            FadeInFromPoint(pile_words, dl_box.title.get_center(), path_arc=path_arc),
            FadeOut(network, DOWN)
        )

        mat_shift = 0.5 * IN + 0.25 * DOWN
        self.play(
            LaggedStart(*(
                Succession(
                    FadeIn(matrix, shift=mat_shift),
                    MoveToTarget(matrix)
                )
                for matrix in matrices[:-1]
            ), lag_ratio=0.25, run_time=5),
            Animation(Point()),
            FadeIn(matrices[-1], shift=mat_shift, time_span=(3.75, 4.75))
        )
        self.wait()

    def get_titled_box(self, text, color, font_size=48, y_space=0.5, x_space=0.5, opacity=0.1):
        """Return a VGroup(box, title) with a title placed inside the top of a rectangle.

        The rectangle is `x_space` wider and `y_space` taller than the title
        text. Its fill is `color` blended toward BLACK by `opacity` and drawn
        fully opaque; its stroke is `color` with width 2. The rectangle and
        title are also exposed as the `.box` and `.title` attributes of the
        returned group.
        """
        title = Text(text, font_size=font_size)
        box = Rectangle(
            title.get_width() + x_space,
            title.get_height() + y_space
        )
        box.set_fill(interpolate_color(BLACK, color, opacity), 1)
        box.set_stroke(color, 2)
        title.next_to(box.get_top(), DOWN, buff=MED_SMALL_BUFF)
        result = VGroup(box, title)
        result.box = box
        result.title = title
        return result

    def get_regression_drawing(self):
        """Return a VGroup(axes, dots, line) sketching a noisy linear fit.

        The line is y = 2 + 0.5 * x; the 15 dots sit at uniformly random x in
        [0, 10] with standard-normal noise added to the line's y value.
        """
        axes = Axes((-1, 10), (-1, 10))
        m = 0.5
        y0 = 2
        line = axes.get_graph(lambda x: y0 + m * x)
        line.set_stroke(YELLOW, 2)
        dots = VGroup(
            Dot(axes.c2p(x, y0 + m * x + np.random.normal()))
            for x in np.random.uniform(0, 10, 15)
        )

        reg_drawing = VGroup(axes, dots, line)
        return reg_drawing

    def get_bayes_net_drawing(self):
        """Return a VGroup(nodes, edges): seven circles joined by six arrows."""
        radius = MED_SMALL_BUFF
        node = Circle(radius=radius)
        node.set_stroke(GREY_B, 2)
        node.shift(2 * DOWN)
        # (x, y) offsets of each node relative to the template circle
        nodes = VGroup(
            node.copy().shift(x * RIGHT + y * UP)
            for x, y in [
                (-1, 0),
                (1, 0),
                (-2, 2),
                (0, 2),
                (2, 2),
                (-2, 4),
                (0, 4),
            ]
        )
        # (start index, end index) into `nodes` for each arrow
        edge_index_pairs = [
            (2, 0),
            (3, 0),
            (3, 1),
            (4, 1),
            (5, 2),
            (6, 3),
        ]
        edges = VGroup()
        for i1, i2 in edge_index_pairs:
            n1, n2 = nodes[i1], nodes[i2]
            # buff=radius keeps the arrow ends off the circle centers
            edge = Arrow(
                n1.get_center(),
                n2.get_center(),
                buff=radius,
                color=WHITE,
                stroke_width=3
            )
            edges.add(edge)

        network = VGroup(nodes, edges)
        return network
class ShowCross(InteractiveScene):
    """Draw a cross over a 5-unit square with a stroke that tapers at both ends."""

    def construct(self):
        # Test
        reference_square = Square(side_length=5)
        x_mark = Cross(reference_square)
        # Stroke width goes 0 -> 30 -> 0 along each stroke
        tapered_widths = [0, 30, 0]
        x_mark.set_stroke(width=tapered_widths)

        draw_cross = ShowCreation(x_mark)
        self.play(draw_cross)
        self.wait()
class FlashThroughImageData(InteractiveScene):
    """Flash through (image, label) examples, one per `time_per_example` seconds."""

    time_per_example = 0.1

    def construct(self):
        # Images
        examples = load_image_net_data()
        pointer = Vector(RIGHT)
        for image_path, caption in ProgressDisplay(examples):
            # Image on the left of the arrow
            picture = ImageMobject(str(image_path))
            picture.next_to(pointer, LEFT)

            # Only the text before the first comma is shown, on the right
            short_name = caption.split(",")[0]
            name_label = Text(short_name)
            name_label.use_winding_fill(False)
            name_label.next_to(pointer, RIGHT)

            self.add(picture, pointer, name_label)
            self.wait(self.time_per_example)
            self.remove(picture, name_label)
            self._release_textures(picture)

    @staticmethod
    def _release_textures(image):
        """Call release_texture on every texture id held by the image's shader wrapper."""
        if not hasattr(image, "shader_wrapper"):
            return
        # Presumably keeps texture memory from piling up across examples
        for tid in image.shader_wrapper.texture_names_to_ids.values():
            release_texture(tid)
class FlashThroughTextData2(InteractiveScene):
    """Flash random fixed-size word windows from a book as short paragraphs."""

    n_examples = 200
    time_per_example = 0.1
    window_size = 50
    line_len = 35
    ul_point = 5 * LEFT + 3 * UP

    def construct(self):
        # Test
        book_text = read_in_book(name="tale_of_two_cities")
        # Split on single whitespace characters and drop the empty pieces
        tokens = [tok for tok in re.split(r"\s", book_text) if tok]

        for _ in range(self.n_examples):
            # random.randint is inclusive at both ends
            start = random.randint(0, len(tokens) - self.window_size)
            stop = start + self.window_size
            window = tokens[start:stop]

            paragraph = get_paragraph(window, line_len=self.line_len)
            paragraph.move_to(self.ul_point, UL)

            # Highlight box around the last occurrence of the window's final word.
            # NOTE(review): the rectangle is built and styled but never added
            # to the scene - confirm whether it was meant to be shown.
            final_word = paragraph[window[-1]][-1]
            highlight = SurroundingRectangle(final_word, buff=0.1)
            highlight.set_stroke(YELLOW, 2)
            highlight.set_fill(YELLOW, 0.5)

            self.add(paragraph)
            self.wait(self.time_per_example)
            self.remove(paragraph)
class TweakedMachine(InteractiveScene):
    """Show a machine of dials whose values take small random steps over time."""

    n_tweaks = 200
    time_per_example = 0.1

    def construct(self):
        # Test
        color_config = {
            "low_negative_color": BLUE_E,
            "high_negative_color": BLUE_B,
        }
        machine = MachineWithDials(
            dial_config={"value_to_color_config": color_config}
        )
        machine.move_to(2 * DOWN)
        machine.set_width(4)

        down_arrow = Vector(DOWN, stroke_width=10)
        down_arrow.next_to(machine, UP)
        self.add(machine, down_arrow)

        # One running value per dial, updated in place on every tweak
        values = np.array([dial.get_random_value() for dial in machine.dials])
        step_count = self.n_tweaks
        for _ in range(step_count):
            step = 0.1 * np.random.uniform(-1, 1, values.shape)
            values += step
            # Values that leave [0, 1] are reset just inside the nearer end
            too_high = values > 1.0
            values[too_high] = 0.9
            too_low = values < 0.0
            values[too_low] = 0.1

            for dial, new_value in zip(machine.dials, values):
                dial.set_value(new_value)
            self.wait(self.time_per_example)
class PremiseOfML(InteractiveScene):
    """Scene presenting a "model" box that maps input data to output data.

    The box is first filled with hand-written code, then with a grid of
    tunable dials. When `show_matrices` is true, `construct` continues past
    the data flash-through into the sections on arrays, weights, weighted
    sums and matrix-vector products.

    Subclasses customize the content by overriding `init_data`,
    `new_input_output_example`, `get_machine` and `get_code`.
    """
    # Point the machine is moved to
    box_center = RIGHT
    # Number of new input/output examples flashed through
    n_examples = 50
    random_seed = 316
    # If False, construct returns right after the data flash-through
    show_matrices = False

    def construct(self):
        """Build and play the whole scene, section by section."""
        self.init_data()

        # Set up input and output
        machine = self.get_machine()
        machine.set_width(4)
        machine.move_to(self.box_center)
        model_label = Text("Model", font_size=72)
        model_label.move_to(machine.box)
        in_arrow = Vector(RIGHT).next_to(machine, LEFT)
        out_arrow = Vector(RIGHT).next_to(machine, RIGHT)
        self.add(machine.box)
        self.add(in_arrow, out_arrow)
        self.add(model_label)

        # Show initial input and output
        in_data, out_data = self.new_input_output_example(in_arrow, out_arrow)
        in_word, out_word = [
            Text(word).next_to(machine, UP).match_x(mob).shift_onto_screen()
            for word, mob in [("Input", in_data), ("Output", out_data)]
        ]
        self.play(
            FadeIn(in_data, lag_ratio=0.001),
            FadeIn(in_word, 0.5 * UP),
        )
        self.play(FadeOutToPoint(in_data.copy(), machine.get_left(), lag_ratio=0.005, path_arc=-60 * DEGREES))
        self.play(
            FadeInFromPoint(out_data, machine.get_right(), lag_ratio=0.1, path_arc=60 * DEGREES),
            FadeIn(out_word, 0.5 * UP)
        )
        self.wait()

        # Show code
        model_label.target = model_label.generate_target()
        # Scale the "Model" label so its first character matches that of "Input"
        model_label.target.scale(in_word[0].get_height() / model_label[0].get_height())
        model_label.target.align_to(in_word, UP)
        code = self.get_code()
        code.set_height(machine.get_height() - MED_SMALL_BUFF)
        code.set_max_width(machine.get_width() - MED_SMALL_BUFF)
        code.move_to(machine, UP).shift(SMALL_BUFF * DOWN)
        self.play(
            MoveToTarget(model_label),
            ShowIncreasingSubsets(code, run_time=3),
        )
        self.wait()

        # Show tunable parameters
        param_label = Text("Tunable parameters")
        param_label.next_to(machine, UP)
        param_label.set_color(BLUE)
        self.play(
            FadeOut(code, 0.25 * DOWN, lag_ratio=0.01),
            Write(machine.dials, lag_ratio=0.001),
            FadeOut(model_label, 0.5 * UP),
            FadeIn(param_label, 0.5 * UP),
        )
        self.play(machine.rotate_all_dials())
        self.wait()

        # Show lots of new data
        for n in range(self.n_examples):
            new_in_data, new_out_data = self.new_input_output_example(in_arrow, out_arrow)
            self.add(in_data, out_data)
            time_span = (0, 0.35)
            self.play(
                machine.random_change_animation(run_time=0.5),
                FadeOut(in_data, time_span=time_span),
                FadeOut(out_data, time_span=time_span),
                FadeIn(new_in_data, time_span=time_span),
                FadeIn(new_out_data, time_span=time_span),
            )
            in_data, out_data = new_in_data, new_out_data

        # Everything below only runs for variants that set show_matrices
        if not self.show_matrices:
            return

        # Make room
        up_shift = 1.5 * UP
        down_shift = 1.75 * DOWN
        down_group = Group(in_arrow, machine, param_label, out_arrow, out_data, out_word)
        self.play(
            in_data.animate.scale(0.75).shift(up_shift + 0.5 * UP),
            UpdateFromFunc(out_data, lambda m: m.match_y(in_data)),
            in_word.animate.shift(up_shift),
            down_group.animate.shift(down_shift),
        )

        # Create pixels
        image = in_data
        pixels = create_pixels(in_data)

        # Show input array
        in_array = NumericEmbedding(shape=(10, 10), ellipses_col=-2)
        in_array.match_height(machine)
        in_array.next_to(in_arrow, LEFT)
        image.set_opacity(0.8)
        self.play(
            TransformFromCopy(
                pixels,
                VGroup(*(in_array.get_entries().family_members_with_points())),
                run_time=2,
                lag_ratio=1e-3
            ),
            FadeInFromPoint(in_array.get_brackets(), image.get_bottom()),
            Write(in_array.get_ellipses(), time_span=(1, 2))
        )
        self.play(image.animate.set_opacity(1))
        self.wait()

        # Show one dimensional array
        vector = NumericEmbedding(length=10)
        vector.replace(in_array, dim_to_match=1)
        vector.move_to(in_array, RIGHT)
        self.remove(in_array)
        self.play(
            TransformFromCopy(in_array.get_brackets(), vector.get_brackets()),
            TransformFromCopy(in_array.get_columns()[5], vector.get_columns()[0]),
            *map(FadeOut, in_array.get_columns()),
        )
        self.wait()
        # Then reverse it: bring the full 2d array back from the vector
        self.remove(vector)
        self.play(LaggedStart(
            TransformFromCopy(vector.get_brackets(), in_array.get_brackets()),
            TransformFromCopy(vector.get_columns()[0], in_array.get_columns()[5]),
            *(
                FadeIn(col, shift=col.get_center() - vector.get_center())
                for col in in_array.get_columns()
            )
        ))
        self.wait()

        # Show 3d tensor
        self.frame.set_field_of_view(30 * DEGREES)
        dot_array = in_array.copy()
        # Replace every entry and ellipsis of the copy with a small dot
        for entry in (*dot_array.get_entries(), *dot_array.get_ellipses()):
            dot = Dot(entry.get_center(), radius=0.06)
            entry.set_submobjects([dot])
        tensor = VGroup(*(
            dot_array.copy()
            for n in range(5)
        ))
        for layer in tensor:
            for dot in (*layer.get_entries(), *layer.get_ellipses()):
                dot.set_fill(
                    interpolate_color(GREY_C, GREY_B, random.random()),
                    opacity=0.5,
                )
                dot.set_backstroke(BLACK, 2)
        tensor.arrange(OUT, buff=0.25)
        tensor.move_to(in_array, RIGHT)
        tensor.rotate(5 * DEGREES, RIGHT)
        tensor.rotate(5 * DEGREES, UP)
        self.remove(in_array)
        self.play(TransformFromCopy(VGroup(in_array), tensor))
        self.play(Rotate(tensor, 20 * DEGREES, axis=UP, run_time=4))
        self.play(Transform(tensor, VGroup(in_array), remover=True))
        self.add(in_array)

        # Express output as an array of numbers
        values = np.random.uniform(0, 1, (10, 1))
        # One entry is set far above the rest
        values[5] = 9.7
        out_array = DecimalMatrix(values, ellipses_row=-2)
        out_array.match_height(machine)
        out_array.match_y(out_arrow)
        out_array.match_x(out_word)
        self.play(
            FadeInFromPoint(out_array, machine.get_right(), lag_ratio=1e-3),
            out_data.animate.scale(0.75).fade(0.5).rotate(-PI / 2).next_to(out_array, RIGHT, buff=0.25),
        )
        self.wait()

        # Describe parameters as weights
        weights_label = Text("Weights")
        weights_label.next_to(machine, UP, buff=0.5)
        weights_label.match_color(param_label)
        equiv = Tex(R"\Updownarrow")
        equiv.next_to(weights_label, UP)
        top_dials = machine.dials[:8]
        dial_rects = VGroup(*map(SurroundingRectangle, top_dials))
        dial_rects.set_stroke(TEAL, 2)
        dial_arrows = VGroup(*(
            Arrow(weights_label.get_bottom(), rect.get_top(), buff=0.05)
            for rect in dial_rects
        ))
        dial_arrows.set_stroke(TEAL)
        self.play(
            FadeIn(weights_label, scale=2),
            param_label.animate.next_to(equiv, UP),
            Write(equiv),
        )
        self.play(
            LaggedStart(*(
                VFadeInThenOut(VGroup(arrow, rect))
                for arrow, rect in zip(dial_arrows, dial_rects)
            ), lag_ratio=0.25, run_time=3)
        )
        self.wait()

        # Show weighted sum
        # Saved here, restored in the "Show a matrix" section below
        machine.dials.save_state()
        weights_label.set_backstroke(BLACK, 5)
        weights_label.target = weights_label.generate_target()
        weights_label.target.next_to(top_dials, DOWN, buff=0.25)
        weighted_sum = Tex(
            R"w_1 x_1 + w_2 x_2 + w_3 x_3 + \cdots + w_n x_n",
            font_size=42,
        )
        weighted_sum.next_to(machine, UP, buff=1.0)
        weight_parts = weighted_sum[re.compile(r"w_\d|w_n")]
        weight_parts.set_color(BLUE)
        data_parts = weighted_sum[re.compile(r"x_\d|x_n")]
        data_parts.set_color(GREY_A)
        # First three terms and the last one
        indices = [0, 1, 2, -1]
        dial_lines = VGroup(*(
            Line(top_dials[n].get_top(), weight_parts[n].get_bottom(), buff=0.1)
            for n in indices
        ))
        ellipses = weighted_sum[R"\cdots"]
        dial_lines.set_stroke(BLUE_B, 1)
        column = in_array.get_columns()[-1]
        col_rect = SurroundingRectangle(column)
        col_rect.set_stroke(YELLOW, 2)
        self.play(ShowCreation(col_rect))
        self.play(
            FadeOut(VGroup(param_label, equiv), UP),
            MoveToTarget(weights_label),
            machine.dials[8:].animate.fade(0.75),
            LaggedStart(*(
                TransformFromCopy(column[n], data_parts[n])
                for n in indices
            )),
            Group(in_data, in_word).animate.to_edge(LEFT, buff=0.25)
        )
        self.play(
            Write(weighted_sum["+"]),
            Write(weighted_sum[R"\cdots"]),
            LaggedStart(*(
                FadeTransform(top_dials[n].copy(), weight_parts[n])
                for n in indices
            )),
            LaggedStartMap(ShowCreation, dial_lines),
            run_time=1
        )
        self.wait()
        for x in range(3):
            self.play(*(
                dial.animate_set_value(dial.get_random_value())
                for dial in top_dials
            ))

        # Wrap a function around it
        func_wrapper = Tex(R"f()")
        # "f(" goes to the left of the sum, ")" to the right
        func_wrapper[:2].next_to(weighted_sum, LEFT, buff=SMALL_BUFF)
        func_wrapper[2].next_to(weighted_sum, RIGHT, buff=SMALL_BUFF)
        func_wrapper.set_color(PINK)
        nl_words = Text("Simple nonlinear\nfunction", font_size=42, alignment="LEFT")
        nl_words.next_to(func_wrapper, UP, buff=1.5, aligned_edge=LEFT)
        nl_words.match_color(func_wrapper)
        nl_arrow = Arrow(nl_words, func_wrapper[0].get_top())
        nl_arrow.match_color(nl_words)
        self.play(
            FadeIn(func_wrapper),
            FadeIn(nl_words, lag_ratio=0.1),
            ShowCreation(nl_arrow),
        )
        self.wait()

        # Show next layer
        weights_label.target = weights_label.generate_target()
        weights_label.target.next_to(weighted_sum, UP, buff=1.0)
        dial_lines.target = VGroup(*(
            Line(
                weights_label.target, weight_parts[index].get_top(),
                buff=SMALL_BUFF
            )
            for index in indices
        ))
        dial_lines.target.match_style(dial_lines)
        layer1 = NumericEmbedding(shape=(10, 5), ellipses_col=-2)
        layer1.match_height(in_array)
        layer1.next_to(in_arrow, RIGHT)
        mid_arrow = in_arrow.copy()
        mid_arrow.next_to(layer1, RIGHT)
        dots = Tex(R"\dots").next_to(mid_arrow, RIGHT)
        expr_rect = SurroundingRectangle(func_wrapper)
        expr_rect.set_stroke(PINK, 2)
        x01_rect = SurroundingRectangle(layer1.elements[0])
        x01_rect.match_style(expr_rect)
        # Lines joining the expression's box to the first element's box
        rect_lines = VGroup(*(
            Line(expr_rect.get_corner(DOWN + v), x01_rect.get_corner(UP + v))
            for v in [LEFT, RIGHT]
        ))
        rect_lines.match_style(expr_rect)
        self.play(LaggedStart(
            FadeOut(weights_label),
            FadeOut(dial_lines),
            FadeOut(nl_words),
            FadeOut(nl_arrow),
            FadeOut(col_rect),
            FadeOut(machine),
            FadeIn(expr_rect),
        ))
        self.play(
            TransformFromCopy(in_array.get_brackets(), layer1.get_brackets()),
            TransformFromCopy(in_arrow, mid_arrow),
            out_arrow.animate.next_to(dots, RIGHT),
            Write(dots),
        )
        self.play(
            TransformFromCopy(expr_rect, x01_rect),
            ShowCreation(rect_lines, lag_ratio=0),
            FadeInFromPoint(layer1.elements[0], expr_rect.get_center()),
        )
        self.play(ShowIncreasingSubsets(layer1[1:-1]))
        self.add(layer1)
        self.wait()

        # Highlight a subset of the data
        # Top-left 3x3 corner of the input array
        in_subset = VGroup(*(
            elem
            for row in in_array.get_rows()[:3]
            for elem in row[:3]
        ))
        in_subset_rects = VGroup(*map(SurroundingRectangle, in_subset))
        data_part_rects = VGroup(*map(SurroundingRectangle, data_parts))
        self.play(
            LaggedStartMap(ShowCreationThenFadeOut, in_subset_rects, lag_ratio=0.02),
            LaggedStartMap(ShowCreationThenFadeOut, data_part_rects, lag_ratio=0.04),
            run_time=3
        )
        self.wait()

        # Show added layers
        to_fade = VGroup(
            func_wrapper, expr_rect, rect_lines, x01_rect,
            weighted_sum
        )
        self.play(
            LaggedStartMap(FadeOut, to_fade, run_time=1),
            in_arrow.animate.scale(0.5, about_edge=LEFT),
            layer1.animate.rotate(70 * DEGREES, UP).next_to(in_arrow, RIGHT, buff=-0.25),
            mid_arrow.animate.scale(0.5).next_to(in_arrow, RIGHT, buff=0.75),
        )
        layer1_group = VGroup(layer1, mid_arrow)
        layer2_group, layer3_group = layer1_group.replicate(2)
        layer2_group.next_to(layer1_group, RIGHT, buff=SMALL_BUFF)
        layer3_group.next_to(layer2_group, RIGHT, buff=SMALL_BUFF)
        self.play(TransformFromCopy(layer1_group, layer2_group))
        self.play(
            TransformFromCopy(layer2_group, layer3_group),
            VGroup(dots, out_arrow).animate.next_to(layer3_group, RIGHT),
        )
        self.play(
            LaggedStart(*(
                dot.animate.shift(0.1 * UP).set_anim_args(rate_func=there_and_back)
                for dot in dots
            ), lag_ratio=0.25)
        )
        self.wait()

        # Bring back machine
        layers = VGroup(layer1_group, layer2_group, layer3_group, dots)
        self.play(
            FadeIn(machine, scale=0.8),
            FadeIn(weights_label, shift=DOWN),
            ShowCreation(dial_lines, lag_ratio=0.1),
            FadeIn(weighted_sum, shift=UP),
            FadeOut(layers, scale=0.8),
        )
        self.wait()
        self.play(
            machine.random_change_animation()
        )
        self.wait()

        # Show a matrix
        frame = self.frame
        matrix, vector, equals, rhs = get_full_matrix_vector_product()
        mat_prod_group = VGroup(matrix, vector, equals, rhs)
        mat_prod_group.next_to(machine, UP, buff=2.0)
        mat_prod_group.shift(0.5 * LEFT)
        # Two curves from the machine's top corners up to the bottom corners
        # of the matrix-vector expression
        p0 = machine.get_corner(UL)
        p1 = matrix.get_corner(DL)
        p2 = machine.get_corner(UR)
        p3 = rhs.get_corner(DR)
        brace = VGroup(
            CubicBezier(p0, p0 + 2 * UP, p1 + 2 * DOWN, p1 + 0.1 * DOWN),
            CubicBezier(p2, p2 + 2 * UP, p3 + 2 * DOWN, p3 + 0.1 * DOWN),
        )
        brace.set_stroke(WHITE, 5)
        self.play(LaggedStart(
            TransformFromCopy(data_parts, vector.get_columns()[0]),
            TransformFromCopy(weight_parts, matrix.get_rows()[0]),
            FadeTransform(weighted_sum, rhs.get_rows()[0]),
            frame.animate.set_height(10, about_edge=DOWN),
            FadeOut(in_data, DOWN),
            FadeOut(out_data, DOWN),
            in_word.animate.next_to(in_array, UP),
            FadeIn(matrix, lag_ratio=0.1),
            ShowCreation(brace, lag_ratio=0),
            weights_label.animate.set_height(0.5).next_to(matrix, UP, buff=MED_SMALL_BUFF),
            Uncreate(dial_lines, lag_ratio=0.1),
            FadeOut(col_rect),
            machine.dials.animate.restore(),
            FadeIn(vector.get_brackets()),
            FadeIn(rhs.get_brackets()),
            FadeIn(equals),
            run_time=3,
            lag_ratio=0.1,
        ))
        self.wait()

        # Animate matrix vector product
        # Faint copy of the first result row, shown while the helper runs
        ghost_row = rhs.get_rows()[0].copy()
        ghost_row.set_opacity(0.25)
        self.add(ghost_row)
        show_symbolic_matrix_vector_product(
            self, matrix, vector, rhs,
            run_time_per_row=1.5
        )
        self.remove(ghost_row)
        self.wait()

        # Associate weights with dials
        w_elems = matrix.get_entries()
        moving_dials = machine.dials[:len(w_elems)].copy()
        moving_dials.target = moving_dials.generate_target()
        for dial, w_elem in zip(moving_dials.target, w_elems):
            dial.move_to(w_elem)
            dial.scale(2)
        self.play(
            w_elems.animate.set_opacity(0.25),
            MoveToTarget(moving_dials, run_time=2),
        )
        self.play(
            LaggedStart(*(
                dial.animate_set_value(dial.get_random_value())
                for dial in moving_dials
            ), lag_ratio=0.02, run_time=3)
        )
        self.wait()
        self.play(
            FadeOut(moving_dials),
            w_elems.animate.set_opacity(1),
        )

        # Vector and data slice
        v_rect = SurroundingRectangle(vector.get_entries())
        self.play(
            ShowCreation(v_rect),
            ShowCreation(col_rect),
        )
        self.wait()
        self.play(
            FadeOut(v_rect),
            FadeOut(col_rect),
        )
        self.wait()

        # Show many matrices
        lhs = VGroup(matrix, vector)
        small_mat_product = Tex(R"W_{10} v_{11}")
        small_mat_product[R"W_{10}"].set_color(BLUE)
        w_index = small_mat_product.make_number_changeable("10")
        v_index = small_mat_product.make_number_changeable("11")
        small_mat_products = VGroup()
        n_rows, n_cols = 16, 8
        # One "W_k v_k" copy per grid cell, numbered from 1
        for n in range(n_rows * n_cols):
            w_index.set_value(n + 1)
            v_index.set_value(n + 1)
            new_prod = small_mat_product.copy()
            new_prod.arrange(RIGHT, buff=SMALL_BUFF, aligned_edge=DOWN)
            small_mat_products.add(new_prod)
        small_mat_products.arrange_in_grid(n_rows, n_cols, v_buff_ratio=2.0)
        small_mat_products.replace(machine.dials)
        mv_label = Text("matrix-vector products")
        mv_label.next_to(machine, UP, buff=1.0)
        # Last character starts invisible; it is faded in further below
        mv_label[-1].set_opacity(0)
        mv_top_label = Text("Many, many")
        mv_top_label.next_to(mv_label, UP)
        mv_arrows = VGroup(*(
            Arrow(mv_label.get_bottom(), smp.get_top(), buff=0.1)
            for smp in small_mat_products
        ))
        self.play(
            FadeTransform(mat_prod_group, small_mat_products[0]),
            Uncreate(brace, lag_ratio=0),
            FadeOut(machine.dials, run_time=0.5),
            FadeTransform(weights_label, mv_label),
            GrowFromPoint(mv_arrows[0], weights_label.get_bottom()),
            frame.animate.set_height(FRAME_HEIGHT).move_to(DOWN).set_anim_args(time_span=(1, 2)),
            run_time=2,
        )
        self.wait()
        self.remove(mv_arrows)
        self.play(
            FadeIn(mv_top_label, UP),
            mv_label[-1].animate.set_opacity(1),
            ShowIncreasingSubsets(small_mat_products, rate_func=linear, run_time=12, int_func=np.ceil),
            ShowSubmobjectsOneByOne(mv_arrows, rate_func=linear, run_time=12, int_func=np.ceil),
        )
        self.remove(mv_arrows)
        self.play(FadeOut(mv_arrows[-1]))
        self.wait()

    def init_data(self):
        """Load the examples that `new_input_output_example` samples from."""
        self.image_data = load_image_net_data()

    def new_input_output_example(self, in_arrow, out_arrow) -> tuple[Mobject, Mobject]:
        """Return a random (image, label) pair placed around the two arrows.

        The image is set to width 4 and put to the left of `in_arrow`; the
        label shows the text before the first comma, is capped at width 2.5
        and put to the right of `out_arrow`.
        """
        path, label_text = random.choice(self.image_data)
        image = ImageMobject(str(path))
        image.set_width(4)
        image.next_to(in_arrow, LEFT)
        label = Text(label_text.split(",")[0])
        label.set_max_width(2.5)
        label.next_to(out_arrow, RIGHT)
        return image, label

    def get_machine(self):
        """Return the machine mobject; construct reads its `.box` and `.dials`."""
        return MachineWithDials()

    def get_code(self):
        """Return a Code mobject holding a C++ snippet shown inside the model box."""
        # Test
        src = """
#include <opencv2/opencv.hpp>
#include <iostream>
using namespace cv;
using namespace std;
int main(int argc, char** argv) {
Mat image = imread(argv[1], IMREAD_GRAYSCALE);
if (image.empty()) {
cout << "Could not open image" << endl;
return -1;
}
// Blur the image to reduce noise
Mat blurredImage;
GaussianBlur(image, blurredImage, Size(5, 5), 0);
// Detect edges with Canny
Mat edges;
Canny(blurredImage, edges, 100, 200);
"""
        return Code(src, language="C++", alignment="LEFT")
class PremiseOfMLWithText(PremiseOfML):
    """Text variant of PremiseOfML: a passage as input, the next word as output."""

    random_seed = 316

    def init_data(self):
        """Store the book's whitespace-separated, non-empty words in `self.all_words`."""
        book_text = read_in_book(name="tale_of_two_cities")
        self.all_words = [tok for tok in re.split(r"\s", book_text) if tok]

    def new_input_output_example(self, in_arrow, out_arrow):
        """Return (paragraph, word): 24 random consecutive words and the 25th."""
        window_size = 25
        vocabulary = self.all_words
        # random.randint includes both end points
        start = random.randint(0, len(vocabulary) - window_size)
        window = vocabulary[start:start + window_size]
        prompt_words, answer = window[:-1], window[-1]

        in_text = get_paragraph(prompt_words, line_len=25)
        in_text.set_max_width(4)
        in_text.next_to(in_arrow, LEFT)

        out_text = Text(answer)
        out_text.next_to(out_arrow, RIGHT)
        return in_text, out_text

    def get_machine(self):
        """Return the parent's machine with a VectorizedPoint added 0.5 below it."""
        machine = super().get_machine()
        anchor = VectorizedPoint()
        anchor.next_to(machine, DOWN, buff=0.5)
        machine.add(anchor)
        return machine

    def get_code(self):
        """Return a Code mobject holding a C++ text-processing snippet."""
        # Test
        src = """
using namespace std;
vector<string> findCapitalizedWords(const string& text) {
vector<string> capitalizedWords;
stringstream ss(text);
string word;
while (ss >> word) {
// Check for uppercase
if (!word.empty() && isupper(word[0])) {
capitalizedWords.push_back(word);
}
}
return capitalizedWords;
}
int main() {
string text;
cout << "Enter text: ";
getline(cin, text); // Using getline to read spaces
"""
        snippet = Code(src, language="C++", alignment="LEFT")
        return snippet
class PremiseOfMLWithMatrices(PremiseOfML):
    """PremiseOfML variant that continues into the matrix sections.

    With `n_examples = 0` the data flash-through loop is skipped, and
    `show_matrices = True` lets construct run past its early return.
    """
    # Skip to animation 9
    n_examples = 0
    show_matrices = True
    random_seed = 6
class LinearRegression(InteractiveScene):
    """Scene fitting a line to noisy 2d data, with dials for slope and y-intercept.

    Data points are scattered around y = 1 + 0.75 * x. A line driven by two
    ValueTrackers is drawn, labeled, moved away from that fit, and then
    nudged back toward it over several steps.
    """
    # Spelled `radom_seed` in the original, which did not match the
    # `random_seed` attribute set by sibling scenes such as PremiseOfML.
    random_seed = 1

    def construct(self):
        # Set up axes
        x_min, x_max = (-1, 12)
        y_min, y_max = (-1, 10)
        axes = Axes((x_min, x_max), (y_min, y_max), width=12, height=6)
        axes.to_edge(DOWN)
        self.add(axes)

        # Add data
        n_data_points = 30
        m = 0.75
        y0 = 1
        # x uniform in [2, x_max]; y on the line plus normal noise scaled by 0.75
        data = np.array([
            (x, y0 + m * x + 0.75 * np.random.normal(0, 1))
            for x in np.random.uniform(2, x_max, n_data_points)
        ])
        points = axes.c2p(data[:, 0], data[:, 1])
        dots = DotCloud(points)
        dots.set_color(YELLOW)
        dots.set_glow_factor(1)
        dots.set_radius(0.075)
        self.add(dots)

        # Make title
        title = Text("Linear Regression", font_size=72)
        title.to_edge(UP)

        # Show line
        m_tracker = ValueTracker(m)
        y0_tracker = ValueTracker(y0)
        line = Line()
        line.set_stroke(TEAL, 2)

        def update_line(line):
            """Redraw the line from x=0 to x=x_max using the trackers' current values."""
            curr_y0 = y0_tracker.get_value()
            curr_m = m_tracker.get_value()
            line.put_start_and_end_on(
                axes.c2p(0, curr_y0),
                axes.c2p(x_max, curr_y0 + curr_m * x_max),
            )

        line.add_updater(update_line)
        self.play(
            FadeIn(title, UP),
            ShowCreation(line),
        )
        self.wait()

        # Label inputs and outputs
        in_labels = VGroup(Text("Input"), Text("Square footage"))
        out_labels = VGroup(Text("Output"), Text("Price"))
        for in_label in in_labels:
            in_label.next_to(axes.x_axis, DOWN, buff=0.1, aligned_edge=RIGHT)
        for out_label in out_labels:
            out_label.rotate(90 * DEGREES)
            out_label.next_to(axes.y_axis, LEFT, aligned_edge=UP)
        self.play(LaggedStart(
            FadeIn(in_labels[0], lag_ratio=0.1),
            FadeIn(out_labels[0], lag_ratio=0.1),
            lag_ratio=0.5,
        ))
        self.wait()
        # Swap the generic labels for the concrete ones
        self.play(LaggedStart(
            FadeTransform(*in_labels),
            FadeTransform(*out_labels),
            lag_ratio=0.8,
        ))
        self.wait()

        # Emphasize line
        self.play(
            VShowPassingFlash(
                line.copy().set_stroke(BLUE, 8).scale(1.1).insert_n_curves(100),
                time_width=1.5,
                run_time=2
            ),
        )
        self.wait()

        # Add line parameter updaters
        words = ["slope", "y-intercept"]
        value_ranges = [(0, 2, 0.2), (-2, 3, 0.5)]
        # Each label is (dial, "<name> = ", number)
        m_label, y0_label = labels = VGroup(
            VGroup(
                Dial(value_range=value_range),
                Text(f"{text} = "),
                DecimalNumber(),
            )
            for text, value_range in zip(words, value_ranges)
        )
        for label, tracker in zip(labels, [m_tracker, y0_tracker]):
            label[0].set_height(2 * label[2].get_height())
            label.arrange(RIGHT)
            # Keep both the dial and the number in sync with the tracker
            label[0].f_always.set_value(tracker.get_value)
            label[2].f_always.set_value(tracker.get_value)
        labels.arrange(DOWN, aligned_edge=LEFT)
        labels.next_to(axes.y_axis, RIGHT, buff=1.0)
        labels.to_edge(UP)
        self.play(
            FadeOut(title, UP),
            FadeIn(m_label, UP),
        )
        self.play(
            m_tracker.animate.set_value(1.5),
            run_time=2,
        )
        self.play(FadeIn(y0_label, UP))
        self.play(
            y0_tracker.animate.set_value(-2),
            run_time=2
        )
        self.wait()

        # Tweak line parameters
        for _ in range(10):
            # alpha in [0, 0.5] moves part of the way to the true (m, y0);
            # alpha in (1.5, 2) overshoots past it
            alpha = random.random()
            if alpha > 0.5:
                alpha += 1
            new_m = interpolate(m_tracker.get_value(), m, alpha)
            new_y0 = interpolate(y0_tracker.get_value(), y0, alpha)
            self.play(LaggedStart(
                m_tracker.animate.set_value(new_m),
                y0_tracker.animate.set_value(new_y0),
                run_time=1.5,
                lag_ratio=0.25,
            ))
            self.wait(0.5)
class ShowGPT3Numbers(InteractiveScene):
def construct(self):
# Title
gpt3_label = Text("GPT-3", font="Consolas", font_size=72)
openai_logo = SVGMobject("OpenAI.svg")
openai_logo.set_fill(WHITE)
openai_logo.set_height(2.0 * gpt3_label.get_height())
title = VGroup(openai_logo, gpt3_label)
title.arrange(RIGHT)
title.to_edge(UP)
self.add(title)
# 175b weights
n_param = 175_181_291_520
weights_count = Integer(n_param, color=BLUE)
weights_text = VGroup(Text("Total parameters:"), weights_count)
weights_text.arrange(RIGHT, buff=MED_SMALL_BUFF)
weights_text.next_to(title, DOWN, buff=1.0)
weights_arrow = Arrow(weights_count, gpt3_label, stroke_width=6, buff=0.2)
param_shape = (8, 24)
pre_dials = Dial().get_grid(*param_shape)
dial_matrix = MobjectMatrix(
pre_dials, *param_shape,
ellipses_row=-2,
ellipses_col=-2,
)
dial_matrix.set_width(FRAME_WIDTH)
dial_matrix.next_to(weights_text, DOWN, buff=MED_SMALL_BUFF)
dials = dial_matrix.get_entries()
dots = dial_matrix.get_ellipses()
self.play(
FadeIn(weights_text[:-1], time_span=(0, 3)),
CountInFrom(weights_count, 0),
GrowArrow(weights_arrow, time_span=(0, 3)),
LaggedStartMap(FadeIn, pre_dials, scale=3, lag_ratio=0.1),
run_time=10,
)
self.play(
LaggedStart(
(dial.animate_set_value(dial.get_random_value())
for dial in dials),
lag_ratio=1.0 / len(dials),
run_time=5
)
)
self.wait()
# Change name to weights
new_name = Text("Total weights: ")
new_name.move_to(weights_text[0], RIGHT)
self.play(
Transform(weights_text[0]["Total"][0], new_name["Total"][0]),
Transform(weights_text[0]["parameters:"][0], new_name["weights:"][0]),
)
self.wait()
# Organize dials into matrices
mat_text = Text("Organized into 27,938 matrices")
mat_text["27,938"].set_color(TEAL)
mat_text.next_to(weights_text, DOWN, buff=MED_SMALL_BUFF)
mat_text.shift((weights_count.get_x(LEFT) - mat_text["27,938"].get_x(LEFT)) * RIGHT)
mat_grid_shape = n, m = (3, 7)
matrices = VGroup(
WeightMatrix(shape=(5, 5))
for n in range(np.product(mat_grid_shape))
)
matrices.arrange_in_grid(
*mat_grid_shape,
v_buff_ratio=0.3,
h_buff_ratio=0.2,
)
matrices.set_width(FRAME_WIDTH - 1)
mat_dots = VGroup(
*(
Tex(R"\dots").next_to(mat, RIGHT)
for mat in matrices[m - 1::m]
),
*(
Tex(R"\vdots").next_to(mat, DOWN)
for mat in matrices[-m:]
)
)
matrices_group = VGroup(matrices, mat_dots)
matrices_group.set_width(FRAME_WIDTH - 1)
matrices_group.next_to(mat_text, DOWN, buff=0.5)
matrices_group.set_x(0)
all_entries = VGroup(
entry
for mat in matrices
for row in mat.get_rows()
for entry in row
)
pre_entries = []
height = all_entries[0].get_height()
for n, entry in enumerate(all_entries):
index = n * len(dials) // len(all_entries)
dial = dials[min(index, len(dials) - 1)].copy()
dial.target = dial.generate_target()
dial.target.set_height(height)
dial.target.move_to(entry)
pre_entries.append(dial)
pre_entries = VGroup(*pre_entries)
self.remove(dial_matrix)
lag_ratio = 1 / len(all_entries)
self.play(
Write(mat_text),
LaggedStartMap(MoveToTarget, pre_entries, lag_ratio=lag_ratio),
TransformFromCopy(dots, mat_dots),
*(FadeIn(mat.get_brackets()) for mat in matrices)
)
self.play(
FadeOut(pre_entries, lag_ratio=0.2 * lag_ratio),
FadeIn(all_entries, lag_ratio=0.2 * lag_ratio),
run_time=2
)
self.add(matrices)
self.wait()
# Show 8 different categories
count_text = VGroup(weights_text, mat_text)
title_scale_factor = 0.75
count_text.target = count_text.generate_target()
count_text.target.scale(title_scale_factor)
count_text.target.to_edge(UP, MED_SMALL_BUFF).to_edge(LEFT)
h_line = Line(LEFT, RIGHT)
h_line.set_width(FRAME_WIDTH)
h_line.next_to(count_text.target, DOWN).set_x(0)
h_line.insert_n_curves(10)
h_line.set_stroke(width=[0, 3, 3, 3, 0])
category_names = VGroup(*map(TexText, [
"Embedding",
"Key",
"Query",
# "Value", # Dumb alignment hack
# "Output",
R"Value$_\downarrow$",
R"Value$_\uparrow$",
"Up-projection",
"Down-projection",
"Unembedding",
]))
# category_names[3][-1].set_fill(BLACK) # Dumb alignment hack
category_names.arrange(DOWN, buff=MED_LARGE_BUFF, aligned_edge=LEFT)
category_names.set_height(5.5)
category_names.next_to(h_line, DOWN, buff=MED_LARGE_BUFF)
category_names.to_edge(LEFT, buff=0.5)
category_names.set_fill(border_width=0.2)
mat_index = 0
counts = [1, * 6 * [3], 1]
mat_groups = VGroup()
for name, count, dots in zip(category_names, counts, mat_dots):
new_mat_index = mat_index + count
mat_group = matrices[mat_index:new_mat_index]
mat_index = new_mat_index
mat_group.target = mat_group.generate_target()
if len(mat_group) > 1:
mat_group.target.add(*mat_group.copy())
mat_group.target.arrange(RIGHT, buff=LARGE_BUFF)
mat_group.target.set_height(0.25)
mat_group.target.next_to(category_names, RIGHT)
mat_group.target.match_y(name)
dots.target = dots.generate_target()
if dots.get_width() < dots.get_height():
dots.target.rotate(90 * DEGREES)
dots.target.next_to(mat_group.target, RIGHT)
mat_groups.add(mat_group)
mat_dots[0].target.set_opacity(0)
mat_dots[7].target.set_opacity(0)
n_groups = len(category_names)
self.play(LaggedStart(
MoveToTarget(count_text),
title.animate.scale(title_scale_factor).next_to(count_text.target, RIGHT, LARGE_BUFF),
FadeOut(weights_arrow),
GrowFromCenter(h_line),
FadeIn(category_names),
LaggedStart(map(MoveToTarget, mat_groups), lag_ratio=0.05),
LaggedStart(map(MoveToTarget, mat_dots[:n_groups]), lag_ratio=0.05),
LaggedStart(map(FadeOut, mat_dots[n_groups:]), lag_ratio=0.05),
FadeOut(matrices[sum(counts):]),
))
# Add lines
h_lines = Line(LEFT, RIGHT).set_width(13).replicate(n_groups)
h_lines.set_stroke(WHITE, 1, 0.5)
for name, line in zip(category_names, h_lines):
line.next_to(name, DOWN, buff=0.1, aligned_edge=LEFT)
name.line = line
v_line = Line(
mat_groups.get_corner(DL) + 0.5 * DOWN,
mat_groups.get_corner(UL) + 0.25 * UP,
)
v_line.shift(SMALL_BUFF * LEFT)
v_line.match_style(h_lines)
self.play(
Write(h_lines),
Write(v_line),
)
self.wait()
# Prepare expressions for parameter counts
const_to_value = {
"n_vocab": 50_257,
"d_embed": 12_288,
"d_query": 128,
"d_value": 128,
"n_heads": 96,
"n_layers": 96,
"n_neurons": 4 * 12_288,
}
const_lists = [
["d_embed", "n_vocab"],
["d_query", "d_embed", "n_heads", "n_layers",],
["d_query", "d_embed", "n_heads", "n_layers",],
["d_value", "d_embed", "n_heads", "n_layers",],
["d_embed", "d_value", "n_heads", "n_layers"],
["n_neurons", "d_embed", "n_layers"],
["d_embed", "n_neurons", "n_layers"],
["n_vocab", "d_embed"],
]
def get_product_expression(category, consts, font_size=30, suffix=None):
    """
    Build the labeled product "a * b * ... = total" for one table row.

    category: the row's name mobject; its `.line` attribute (the rule
        placed under that name) fixes the vertical position.
    consts: list of keys into `const_to_value`, in display order.
    font_size: font size of the expression; the per-factor numbers above
        it use 0.8 times this.
    suffix: optional text shown to the right of the expression.

    Returns VGroup(expr, counts), plus a trailing label when `suffix` is
    given. `expr.rhs` is the (blue) formatted total, and `counts` holds one
    Integer per factor, each horizontally aligned with its factor's name.
    """
    values = [const_to_value[const] for const in consts]
    # np.product was deprecated in NumPy 1.25 and removed in 2.0; np.prod is
    # the supported name. The accumulator is pinned to int64 because the
    # larger products here exceed the int32 range that is NumPy's default
    # integer on some platforms.
    total = int(np.prod(values, dtype=np.int64))
    total_str = "{:,}".format(total)

    expr = Text(
        " * ".join(consts) + " = " + total_str,
        font_size=font_size,
    )
    expr.next_to(v_line, RIGHT)
    expr.align_to(category.line, DOWN)
    expr.shift(0.25 * expr.get_height() * UP)
    expr.rhs = expr[total_str]
    expr.rhs.set_color(BLUE)

    # One small number above each factor name
    factor_counts = VGroup(
        Integer(
            const_to_value[const],
            font_size=0.8 * font_size,
        )
        for const in consts
    )
    factor_counts.next_to(expr, UP, buff=0.05)
    for count, const in zip(factor_counts, consts):
        count.match_x(expr[const])
    factor_counts.set_fill(GREY_B)

    group = VGroup(expr, factor_counts)
    if suffix is not None:
        label = Text(suffix)
        label.match_height(expr)
        label.next_to(expr, RIGHT, buff=MED_SMALL_BUFF)
        group.add(label)
    return group
product_expressions = VGroup(
get_product_expression(category, consts)
for category, consts in zip(category_names, const_lists)
)
exprs = [pe[0] for pe in product_expressions]
counts = [pe[1] for pe in product_expressions]
# Embedding
def highlight_category(*indices):
    # Build a target copy of the category names in which everything is
    # dimmed except the names at `indices`, and return the animation that
    # moves the names to that target.
    target = category_names.generate_target()
    category_names.target = target
    target.set_fill(opacity=0.15, border_width=0)
    for i in indices:
        target[i].set_fill(opacity=1, border_width=0.5)
    return MoveToTarget(category_names)
self.play(
FadeOut(mat_groups),
FadeOut(mat_dots[1:7]),
highlight_category(0)
)
self.play(
FadeIn(exprs[0]),
FadeIn(counts[0], 0.25 * UP),
)
self.wait()
# Unembedding
total = Integer(2 * 12_288 * 50_257)
total.to_edge(RIGHT, buff=1.0)
total.set_color(BLUE)
total_box = SurroundingRectangle(total, buff=0.25)
total_box.set_fill(BLACK, 1)
total_box.set_stroke(WHITE, 2)
lines = VGroup(*(Line(exprs[i].get_right(), total_box) for i in [0, 7]))
lines.set_stroke(BLUE, 2)
self.play(
highlight_category(0, 7),
TransformMatchingStrings(exprs[0].copy(), exprs[7]),
TransformFromCopy(counts[0][0].copy(), counts[7][1]),
TransformFromCopy(counts[0][1].copy(), counts[7][0]),
run_time=2
)
self.wait()
self.play(
ShowCreation(lines, lag_ratio=0),
FadeIn(total_box),
FadeTransform(exprs[0][-11:].copy(), total),
FadeTransform(exprs[7][-11:].copy(), total),
)
self.wait()
self.play(FlashAround(weights_count, time_width=1.5, run_time=2))
self.wait()
self.play(
FadeOut(lines),
FadeOut(total_box),
FadeOut(total),
)
self.wait()
# Attention matrices
covered_categories = [0, 7]
att_categories = [1, 2, 3, 4]
per_head_factors = [
["d_query", "d_embed"],
["d_query", "d_embed"],
["d_value", "d_embed"],
["d_embed", "d_value"],
]
per_head_exprs = VGroup(
get_product_expression(name, factors, suffix="per head")
for name, factors in zip(category_names[1:5], per_head_factors)
)
per_layer_exprs = VGroup(
get_product_expression(name, factors + ["n_heads"], suffix="per layer")
for name, factors in zip(category_names[1:5], per_head_factors)
)
full_att_exprs = product_expressions[1:5]
for group in [per_head_exprs, per_layer_exprs, full_att_exprs]:
sum_box = SurroundingRectangle(
VGroup(expr[0].rhs for expr in group)
)
sum_box.set_stroke(BLUE, 2)
sum_label = Integer(sum(
np.product(list(count.get_value() for count in expr[1]))
for expr in group
))
sum_label.set_color(BLUE)
sum_label.next_to(sum_box, DOWN)
sum_box.add(sum_label)
group.sum_box = sum_box
self.play(
*(
product_expressions[i].animate.set_fill(opacity=0.25, border_width=0)
for i in covered_categories
),
highlight_category(att_categories[0]),
FadeIn(per_head_exprs[0], shift=0.5 * RIGHT)
)
self.wait()
self.play(
LaggedStartMap(FadeIn, per_head_exprs[1:], shift=0.5 *DOWN, lag_ratio=0.5),
highlight_category(*att_categories),
)
self.wait()
self.play(FadeIn(per_head_exprs.sum_box, run_time=3, rate_func=there_and_back_with_pause))
self.wait()
self.play(
FadeOut(per_head_exprs),
FadeIn(per_layer_exprs),
)
self.wait()
self.play(FadeIn(per_layer_exprs.sum_box, run_time=3, rate_func=there_and_back_with_pause))
self.wait()
self.play(
FadeOut(per_layer_exprs),
FadeIn(full_att_exprs),
)
self.wait()
self.play(FadeIn(full_att_exprs.sum_box))
self.wait()
# Compare with total weights
total_weights_rect = SurroundingRectangle(weights_count)
total_weights_rect.set_stroke(BLUE_B, 2)
box = full_att_exprs.sum_box.copy()
box.remove(box.submobjects[0])
self.play(Transform(box, total_weights_rect))
self.wait()
self.play(
FadeOut(box),
FadeOut(full_att_exprs.sum_box),
)
self.wait()
# MLP matrices
mlp_categories = [5, 6]
mlp_exprs = product_expressions[5:7]
per_layer_exprs = VGroup(
get_product_expression(category_names[i], const_lists[i][:2], suffix="per layer")
for i in mlp_categories
)
self.play(
full_att_exprs.animate.set_fill(opacity=0.25, border_width=0),
highlight_category(*mlp_categories),
)
self.wait()
self.play(FadeIn(per_layer_exprs[0]))
self.wait()
self.play(
TransformMatchingStrings(per_layer_exprs[0][0].copy(), per_layer_exprs[1][0]),
TransformFromCopy(per_layer_exprs[0][1][0], per_layer_exprs[1][1][1]),
TransformFromCopy(per_layer_exprs[0][1][1], per_layer_exprs[1][1][0]),
TransformFromCopy(per_layer_exprs[0][2], per_layer_exprs[1][2]),
run_time=1
)
self.wait()
self.play(
FadeOut(per_layer_exprs),
FadeIn(mlp_exprs),
)
self.wait()
# Sum up MLP right hand sides
rhs_rect = SurroundingRectangle(VGroup(expr[0].rhs for expr in mlp_exprs))
rhs_rect.set_stroke(BLUE, 2)
rhs_rect.stretch(1.2, 1, about_edge=DOWN)
c2v = const_to_value
mlp_total = Integer(2 * c2v["n_neurons"] * c2v["d_embed"] * c2v["n_layers"])
mlp_total.next_to(rhs_rect)
mlp_total.set_color(BLUE)
mlp_total_rect = BackgroundRectangle(mlp_total)
mlp_total_rect.set_fill(BLACK, 1)
self.play(
FadeIn(rhs_rect),
FadeIn(mlp_total_rect),
FadeTransform(mlp_exprs[0][0].rhs.copy(), mlp_total),
FadeTransform(mlp_exprs[1][0].rhs.copy(), mlp_total),
)
self.wait()
# Align all right hand sides
self.play(
category_names.animate.set_fill(opacity=1, border_width=0.5),
product_expressions.animate.set_fill(opacity=1, border_width=0.5),
)
all_rhss = VGroup(
VGroup(expr[0]["="][0], expr[0].rhs)
for expr in product_expressions
)
all_rhss.target = all_rhss.generate_target()
for mob in all_rhss.target:
mob.align_to(product_expressions, RIGHT)
mob.shift(0.5 * RIGHT)
all_rhss_rect = SurroundingRectangle(all_rhss.target)
all_rhss_rect.match_style(rhs_rect)
self.play(
FadeOut(mlp_total_rect, RIGHT),
FadeOut(mlp_total, RIGHT),
ReplacementTransform(rhs_rect, all_rhss_rect),
MoveToTarget(all_rhss)
)
self.wait()
# Move weights count
self.play(LaggedStart(
h_line.animate.scale(0.5, about_edge=LEFT),
weights_text.animate.arrange(DOWN).scale(1.5).next_to(all_rhss_rect, UP),
FadeOut(mat_text, LEFT),
title.animate.to_edge(LEFT, buff=2.5),
lag_ratio=0.2,
run_time=2
))
self.wait()
class DistinguishWeightsAndData(InteractiveScene):
    """
    Numbers fade in scattered over the frame, separate onto two sides, and
    settle into weight matrices (left, under "Weights") and data vectors
    (right, under "Data"); a subtitle is then added under each title.
    """
    def construct(self):
        # Titles, one centered in each half of the frame, each underlined
        titles = VGroup(*(
            Text(label, font_size=60)
            for label in ("Weights", "Data")
        ))
        weights_title, data_title = titles
        weights_title.set_color(BLUE)
        data_title.set_color(GREY_B)
        for side, heading in zip((-1, 1), titles):
            heading.set_x(side * FRAME_WIDTH / 4)
            heading.to_edge(UP, buff=0.25)
            rule = Underline(heading, stretch_factor=1.5)
            rule.match_color(heading)
            # Place the rule slightly below the bottom of the first letter
            rule.set_y(heading[0].get_y(DOWN) - 0.1)
            heading.add(rule)

        divider = Line(UP, DOWN).set_height(4.5)
        divider.to_edge(UP, buff=0)
        divider.set_stroke(GREY_A, 2)

        # Matrices for the weights side, vectors for the data side
        matrices = VGroup(*(
            WeightMatrix(
                shape=(6, 8),
                ellipses_row=None,
                ellipses_col=None,
            )
            for _ in range(4)
        ))
        matrices.arrange_in_grid(v_buff=1, h_buff=1)
        vectors = VGroup(*(
            NumericEmbedding(length=8, ellipses_row=None)
            for _ in range(8)
        ))
        vectors.arrange(RIGHT)

        tensors = VGroup(matrices, vectors)
        for block, heading in zip(tensors, titles):
            block.set_height(2.5)
            block.next_to(heading, DOWN, buff=0.5)

        # Flatten entries and brackets so they can be animated separately
        def collect_parts(group, get_parts):
            return VGroup(*(
                part
                for tensor in group
                for part in get_parts(tensor)
            ))

        mat_nums = collect_parts(matrices, lambda m: m.get_entries())
        mat_braces = collect_parts(matrices, lambda m: m.get_brackets())
        vec_nums = collect_parts(vectors, lambda v: v.get_entries())
        vec_braces = collect_parts(vectors, lambda v: v.get_brackets())

        def random_point(x_min, x_max, y_min, y_max):
            x = random.uniform(x_min, x_max)
            y = random.uniform(y_min, y_max)
            return np.array([x, y, 0])

        # Give every number four states:
        #   0: its place in the matrix/vector
        #   1: a random spot on its own side (matrix entries left, vector entries right)
        #   2: a random spot anywhere
        #   3: another random spot, invisible (the starting state)
        whole_frame = (-8, 8, -4, 4)
        all_nums = VGroup(*mat_nums, *vec_nums)
        all_nums.shuffle()
        for num in all_nums:
            states = num.replicate(4)
            _, on_own_side, anywhere, hidden = states
            for state in (on_own_side, anywhere, hidden):
                state.set_height(0.15)
            side = 1 if num in vec_nums else -1
            on_own_side.move_to(random_point(6.5 * side, side, 0, 3.5))
            anywhere.move_to(random_point(*whole_frame))
            hidden.move_to(random_point(*whole_frame))
            hidden.set_opacity(0)
            num.states = states
            num.become(hidden)

        self.add(all_nums)

        # Animations
        stagger = 1 / len(all_nums)

        def send_to_state(nums, index, run_time, **transform_kw):
            # Staggered transform of each number to one of its stored states
            return LaggedStart(
                *(
                    Transform(num, num.states[index], **transform_kw)
                    for num in nums
                ),
                lag_ratio=stagger,
                run_time=run_time,
            )

        self.play(send_to_state(all_nums, 2, run_time=3, path_arc=PI))
        self.wait()
        self.play(
            LaggedStart(
                send_to_state(mat_nums, 1, run_time=2),
                send_to_state(vec_nums, 1, run_time=2),
                lag_ratio=0.5,
            ),
            ShowCreation(divider),
        )
        self.play(
            Write(weights_title),
            send_to_state(mat_nums, 0, run_time=2),
            FadeIn(mat_braces, lag_ratio=0.1, time_span=(1, 2)),
        )
        self.play(
            Write(data_title),
            send_to_state(vec_nums, 0, run_time=2),
            FadeIn(vec_braces, lag_ratio=0.1, time_span=(1, 2)),
        )
        self.wait()

        # Add subtitles, pushing each group of tensors down to make room
        captions = VGroup(
            Text("What defines the model", font_size=40),
            Text("What the model processes", font_size=40),
        )
        for caption, heading, block in zip(captions, titles, tensors):
            caption.next_to(heading, DOWN)
            self.play(
                FadeIn(caption, lag_ratio=0.1),
                block.animate.next_to(caption, DOWN, buff=0.5),
            )
            self.wait()
class SoftmaxBreakdown(InteractiveScene):
    """
    Step-by-step presentation of softmax: an example probability
    distribution and its constraints, a matrix-vector output that violates
    them, the exponentiate-and-normalize formula, interactive changes to the
    inputs, a temperature parameter, and finally the "logits" and
    "probabilities" labels.
    """
    def construct(self):
        # Show example probability distribution
        word_strs = ['Dumbledore', 'Flitwick', 'Mcgonagall', 'Quirrell', 'Snape', 'Sprout', 'Trelawney']
        words = VGroup(*(Text(word_str, font_size=30) for word_str in word_strs))
        # Raw scores; their softmax gives the bar heights of the chart
        values = np.array([-0.8, -5.0, 0.5, 1.5, 3.4, -2.3, 2.5])
        prob_values = softmax(values)
        chart = BarChart(prob_values, width=10)
        chart.bars.set_stroke(width=1)

        probs = VGroup(*(DecimalNumber(pv) for pv in prob_values))
        probs.arrange(DOWN, buff=0.25)
        # Target state: each number at half size, sitting on top of its bar
        probs.generate_target()
        for prob, bar in zip(probs.target, chart.bars):
            prob.scale(0.5)
            prob.next_to(bar, UP)

        # Slanted word labels hanging below the bottom of each bar
        for word, bar in zip(words, chart.bars):
            word.scale(0.75)
            height = word.get_height()
            word.move_to(bar.get_bottom(), LEFT)
            word.rotate(-45 * DEGREES, about_point=bar.get_bottom())
            word.shift(height * DOWN)

        # Save the finished chart, then flatten the bars and hide it, so that
        # Restore(chart) below animates it into view
        chart.save_state()
        for bar in chart.bars:
            bar.stretch(0, 1, about_edge=DOWN)
        chart.set_opacity(0)

        seq_title = Text("Sequence of numbers", font_size=60)
        seq_title.next_to(probs, LEFT, buff=0.75)
        seq_title.set_color(YELLOW)
        prob_title = Text("Probability distribution", font_size=60)
        prob_title.set_color(chart.bars[3].get_color())
        prob_title.center().to_edge(UP)

        self.play(
            LaggedStartMap(FadeIn, probs, shift=0.25 * DOWN, lag_ratio=0.3),
            FadeIn(seq_title),
            run_time=1
        )
        self.wait()
        self.play(
            Restore(chart, lag_ratio=0.1),
            MoveToTarget(probs),
            FadeTransform(seq_title, prob_title),
        )
        self.wait()
        self.play(
            LaggedStartMap(FadeIn, words),
        )
        self.wait()

        # Show constraint between 0 and 1
        index = 3
        bar = chart.bars[index]
        bar.save_state()
        prob = probs[index]
        prob.bar = bar

        # While the bar is stretched, keep its number equal to the bar height
        # as a fraction of the y-axis extent above the x-axis, and keep the
        # number sized like its neighbor and attached to the top of the bar
        max_height = chart.y_axis.get_y(UP) - chart.x_axis.get_y()
        prob.f_always.set_value(lambda: prob.bar.get_height() / max_height)
        prob.always.match_height(probs[1])
        prob.always.next_to(prob.bar, UP)

        # Dashed line at the top of the y-axis, plus a thicker copy that
        # starts on the x-axis and sweeps up to meet it
        one_line = DashedLine(*chart.x_axis.get_start_and_end())
        one_line.set_stroke(RED, 2)
        one_line.align_to(chart.y_axis, UP)

        low_line = one_line.copy()
        low_line.set_stroke(PINK, 5)
        low_line.match_y(chart.x_axis)

        self.play(FadeIn(low_line), FadeIn(one_line), FadeOut(prob_title))
        self.play(low_line.animate.match_y(one_line))
        self.play(FadeOut(low_line))
        self.wait()
        self.play(
            FadeIn(one_line, time_span=(0, 1)),
            bar.animate.set_height(max_height, about_edge=DOWN, stretch=True),
            run_time=2,
        )
        self.play(
            bar.animate.set_height(1e-4, about_edge=DOWN, stretch=True),
            run_time=2,
        )
        self.play(Restore(bar))
        self.wait()
        # Stop tying the number to the bar
        prob.clear_updaters()

        # Show sum
        prob_copies = probs.copy()
        prob_copies.scale(1.5)
        prob_copies.arrange(RIGHT, buff=1.0)
        prob_copies.to_edge(UP)
        prob_copies.shift(LEFT)
        # A plus sign centered between each adjacent pair of numbers
        plusses = VGroup(*(
            Tex("+").move_to(VGroup(p1, p2))
            for p1, p2 in zip(prob_copies, prob_copies[1:])
        ))
        equals = Tex("=").next_to(prob_copies, RIGHT)
        rhs = DecimalNumber(1.00)
        rhs.next_to(equals, RIGHT)

        self.play(
            TransformFromCopy(probs, prob_copies),
            Write(plusses),
            Write(equals),
            FadeOut(one_line),
        )
        self.play(
            LaggedStart(*(
                FadeTransform(pc.copy(), rhs)
                for pc in prob_copies
            ), lag_ratio=0.07)
        )
        self.wait()

        # Grouped here, before `rhs` and `equals` are rebound further down
        sum_group = VGroup(*prob_copies, *plusses, equals, rhs)
        chart_group = VGroup(chart, probs, words)

        # Show example matrix vector output
        n = len(words)
        vector = NumericEmbedding(length=n, ellipses_row=None)
        in_values = np.array([e.get_value() for e in vector.elements])
        # Each row is a random vector rescaled so that its dot product with
        # the input vector equals the corresponding entry of `values`
        rows = []
        for value in values:
            row = np.random.uniform(-1, 1, len(in_values))
            row *= value / np.dot(row, in_values)
            rows.append(row)
        matrix_values = np.array(rows)
        matrix = WeightMatrix(
            values=matrix_values,
            ellipses_row=None,
            ellipses_col=None,
            num_decimal_places=2,
        )
        for mob in matrix, vector:
            mob.set_height(4)
        vector.to_edge(UP).set_x(2.5)
        matrix.next_to(vector, LEFT)

        self.play(LaggedStart(
            chart_group.animate.scale(0.35).to_corner(DL),
            FadeOut(sum_group, UP),
            FadeIn(matrix, UP),
            FadeIn(vector, UP),
        ))
        # Helper defined elsewhere; `rhs` is rebound to its second return
        # value, which is used below as the output vector of the product
        eq, rhs = show_matrix_vector_product(self, matrix, vector, x_max=9)
        self.wait()

        # Comment on output
        rhs_rect = SurroundingRectangle(rhs)
        rhs_words = Text("Not at all a\nprobability distribution!")
        rhs_words.next_to(rhs_rect, DOWN)

        # One rectangle per output entry that is negative / greater than 1
        neg_rects = VGroup(*(
            SurroundingRectangle(entry)
            for entry in rhs.get_entries()
            if entry.get_value() < 0
        ))
        gt1_rects = VGroup(*(
            SurroundingRectangle(entry)
            for entry in rhs.get_entries()
            if entry.get_value() > 1
        ))
        VGroup(rhs_rect, neg_rects).set_stroke(RED, 4)
        gt1_rects.set_stroke(BLUE, 4)
        # Attach a matching-color label to each rectangle, to the right of
        # the output vector at the rectangle's height
        for rect in (*neg_rects, *gt1_rects):
            neg = rect in neg_rects
            rect.word = Text("Negative" if neg else "> 1", font_size=36)
            rect.word.match_color(rect)
            rect.word.next_to(rhs, RIGHT)
            rect.word.match_y(rect)
        neg_words = VGroup(*(r.word for r in neg_rects))
        gt1_words = VGroup(*(r.word for r in gt1_rects))

        # Sum of all output entries, shown below the vector
        sum_arrow = Vector(DOWN).next_to(rhs, DOWN)
        sum_sym = Tex(R"\sum", font_size=36).next_to(sum_arrow, LEFT)
        sum_num = DecimalNumber(sum(e.get_value() for e in rhs.get_entries()))
        sum_num.next_to(sum_arrow, DOWN)

        self.play(
            ShowCreation(rhs_rect),
            FadeIn(rhs_words),
        )
        self.wait()
        self.play(
            ReplacementTransform(VGroup(rhs_rect), neg_rects),
            LaggedStart(*(FadeIn(rect.word, 0.5 * RIGHT) for rect in neg_rects)),
        )
        self.wait()
        self.play(
            ReplacementTransform(neg_rects, gt1_rects),
            FadeTransformPieces(neg_words, gt1_words),
        )
        self.wait()
        self.play(
            LaggedStart(
                FadeOut(rhs_words),
                FadeOut(gt1_rects),
                FadeOut(gt1_words),
            ),
            GrowArrow(sum_arrow),
            FadeIn(sum_num, DOWN),
            FadeIn(sum_sym),
        )
        self.wait()
        self.play(*map(FadeOut, [sum_arrow, sum_sym, sum_num]))

        # Preview softmax application
        # Layout, left to right: rhs -> in_arrow -> softmax box -> out_arrow -> output -> bars
        rhs.generate_target()
        rhs.target.to_edge(LEFT, buff=1.5)
        rhs.target.set_y(0)

        softmax_box = Rectangle(width=5, height=6.5)
        softmax_box.set_stroke(BLUE, 2)
        softmax_box.set_fill(BLUE_E, 0.5)
        in_arrow, out_arrow = Vector(RIGHT).replicate(2)
        in_arrow.next_to(rhs.target, RIGHT)
        softmax_box.next_to(in_arrow, RIGHT)
        out_arrow.next_to(softmax_box, RIGHT)

        softmax_label = Text("softmax", font_size=60)
        softmax_label.move_to(softmax_box)

        # Softmax of the product's entries, displayed as a single column
        rhs_values = np.array([e.get_value() for e in rhs.get_entries()])
        dist = softmax(rhs_values)
        output = DecimalMatrix(dist.reshape((dist.shape[0], 1)))
        output.match_height(rhs)
        output.next_to(out_arrow, RIGHT)

        # Copies of the chart's bars, turned sideways, one beside each output entry
        bars = chart.bars.copy()
        for bar, entry in zip(bars, output.get_entries()):
            bar.rotate(-PI / 2)
            bar.stretch(2, 0)
            bar.next_to(output)
            bar.match_y(entry)

        self.play(LaggedStart(
            FadeOut(matrix, 2 * LEFT),
            FadeOut(vector, 3 * LEFT),
            FadeOut(eq, 3.5 * LEFT),
            FadeOut(chart_group, DL),
            GrowArrow(in_arrow),
            FadeIn(softmax_box, RIGHT),
            FadeIn(softmax_label, RIGHT),
            MoveToTarget(rhs),
            GrowArrow(out_arrow),
            FadeIn(output, RIGHT),
            TransformFromCopy(chart.bars, bars),
        ), lag_ratio=0.2, run_time=2)
        self.wait()

        # Highlight larger and smaller parts
        rhs_entries = rhs.get_entries()
        changer = VGroup(rhs_entries, output.get_entries(), bars)
        changer.save_state()
        # Spotlight rows 4, 3, 2, 1 in turn: start each time from a copy of
        # the saved state and fade every row except the current one
        for index in range(4, 0, -1):
            changer.target = changer.saved_state.copy()
            changer.target.set_fill(border_width=0)
            for group in changer.target:
                for j, elem in enumerate(group):
                    if j != index:
                        elem.fade(0.8)
            self.play(MoveToTarget(changer))
            self.wait()
        self.play(Restore(changer))
        # Swap the temporary grouping for the original mobjects in the scene
        self.remove(changer)
        self.add(rhs, output, bars)
        self.wait()

        # Swap out for variables
        # Symbols x_1, x_2, ... placed on top of the numeric entries
        variables = VGroup(*(
            Tex(f"x_{{{n}}}", font_size=48).move_to(elem)
            for n, elem in enumerate(rhs_entries, start=1)
        ))

        self.remove(rhs_entries)
        self.play(
            LaggedStart(*(
                TransformFromCopy(entry, variable, path_arc=PI / 2)
                for entry, variable in zip(rhs_entries, variables)
            ), lag_ratio=0.1, run_time=1.0)
        )
        self.wait()

        # Exponentiate each part
        exp_parts = VGroup(*(
            Tex(f"e^{{{var.get_tex()}}}", font_size=48).move_to(var)
            for var in variables
        ))
        exp_parts.align_to(softmax_box, LEFT)
        exp_parts.shift(0.75 * RIGHT)
        exp_parts.space_out_submobjects(1.5)

        gt0s = VGroup(
            Tex(R"> 0").next_to(exp_part, aligned_edge=DOWN)
            for exp_part in exp_parts
        )

        self.play(
            softmax_label.animate.next_to(softmax_box, UP, buff=0.15),
            LaggedStart(*(
                TransformMatchingStrings(var.copy(), exp_part)
                for var, exp_part in zip(variables, exp_parts)
            ), run_time=1, lag_ratio=0.01)
        )
        self.play(LaggedStartMap(FadeIn, gt0s, shift=0.5 * RIGHT, lag_ratio=0.25, run_time=1))
        self.wait()
        self.play(FadeOut(gt0s))

        # Compute the sum
        exp_sum = Tex(R"\sum_{n=0}^{N-1} e^{x_{n}}", font_size=42)
        exp_sum[R"e^{x_{n}}"].scale(1.5, about_edge=LEFT)
        exp_sum.next_to(softmax_box.get_right(), LEFT, buff=0.75)

        # One line from each exponentiated term to the sum
        lines = VGroup(*(Line(exp_part.get_right(), exp_sum.get_left(), buff=0.1) for exp_part in exp_parts))
        lines.set_stroke(TEAL, 2)

        self.play(
            LaggedStart(*(
                FadeTransform(exp_part.copy(), exp_sum)
                for exp_part in exp_parts
            ), lag_ratio=0.01),
            LaggedStartMap(ShowCreation, lines, lag_ratio=0.01),
            run_time=1
        )
        self.wait()
        self.play(FadeOut(lines))

        # Divide each part by the sum
        # For every term, build a small "/ sum" copy to sit on its right
        lil_denoms = VGroup()
        for exp_part in exp_parts:
            slash = Tex("/").match_height(exp_sum)
            slash.next_to(exp_sum, LEFT, buff=0)
            denom = VGroup(slash, exp_sum).copy()
            denom.set_height(exp_part.get_height() * 1.5)
            denom.next_to(exp_part, RIGHT, buff=0)
            lil_denoms.add(denom)
        lil_denoms.align_to(softmax_box.get_center(), LEFT)

        # NOTE(review): this second `lines` group is styled but not used in
        # any animation visible in this method
        lines = VGroup(*(Line(exp_sum.get_left(), denom.get_center()) for denom in lil_denoms))
        lines.set_stroke(TEAL, 1)

        self.remove(exp_sum)
        self.play(
            exp_parts.animate.next_to(lil_denoms, LEFT, buff=0),
            LaggedStart(*(
                FadeTransform(exp_sum.copy(), denom)
                for denom in lil_denoms
            ), lag_ratio=0.01),
        )
        self.wait()

        # Resize box
        sm_terms = VGroup(*(
            VGroup(exp_part, denom)
            for exp_part, denom in zip(exp_parts, lil_denoms)
        ))
        sm_terms.generate_target()

        target_height = 5.0
        full_output = Group(output, bars)
        full_output.generate_target()
        full_output.target.set_height(target_height, about_edge=RIGHT)
        full_output.target.shift(1.5 * LEFT)

        equals = Tex("=")
        equals.next_to(full_output.target, LEFT)

        softmax_box.generate_target()
        softmax_box.target.set_width(3.0, stretch=True)
        # Scale the box target and the terms' target together so the terms stay inside the box
        VGroup(softmax_box.target, sm_terms.target).set_height(target_height + 0.5).next_to(equals, LEFT)

        # The target is copied before the next line, so it still holds the
        # entries as they were. The live entries then take on the look of
        # the x_n symbols, and the stand-alone symbols leave the scene.
        rhs.generate_target()
        rhs_entries.become(variables)
        self.remove(variables)
        rhs.target.set_height(target_height)
        rhs.target.next_to(softmax_box.target, LEFT, buff=1.5)

        self.play(
            softmax_label.animate.next_to(softmax_box.target, UP),
            MoveToTarget(softmax_box),
            MoveToTarget(sm_terms),
            MoveToTarget(full_output),
            MoveToTarget(rhs),
            FadeTransform(out_arrow, equals),
            in_arrow.animate.become(
                Arrow(rhs.target, softmax_box.target).match_style(in_arrow)
            ),
        )
        self.wait()

        # Set up updaters
        output_entries = output.get_entries()
        # Bar width per unit of output value, taken from the current layout
        bar_width_ratio = bars.get_width() / max(o.get_value() for o in output_entries)
        temp_tracker = ValueTracker(1)

        def update_outs(output_entries):
            # Recompute the outputs from the current input entries. The second
            # argument to softmax is the tracker's value (presumably the
            # temperature; softmax is defined elsewhere).
            inputs = [entry.get_value() for entry in rhs_entries]
            outputs = softmax(inputs, temp_tracker.get_value())
            for entry, output in zip(output_entries, outputs):
                entry.set_value(output)

        def update_bars(bars):
            # Width proportional to the displayed value, never below 1e-3
            for bar, entry in zip(bars, output_entries):
                width = max(bar_width_ratio * entry.get_value(), 1e-3)
                bar.set_width(width, about_edge=LEFT, stretch=True)

        output_entries.clear_updaters().save_state()
        bars.clear_updaters().save_state()
        output_entries.add_updater(update_outs)
        bars.add_updater(update_bars)

        self.add(bars, output_entries)

        # Tweak values
        # (entry index, new value) pairs, applied one after another
        index_value_pairs = [
            (6, 4.0),
            (4, 4.2),
            (2, 4.0),
            (0, 6.0),
            (4, 9.9)
        ]
        # index_value_pairs = [ # For emphasizing a max
        # (3, 8.5),
        # (6, 8.0),
        # (2, 8.1),
        # (0, 9.0),
        # ]
        for index, value in index_value_pairs:
            entry = rhs_entries[index]
            rect = SurroundingRectangle(entry)
            # Blue outline when the value goes up, red otherwise
            rect.set_stroke(BLUE if value > entry.get_value() else RED, 3)
            self.play(
                ChangeDecimalToValue(entry, value),
                FadeIn(rect, time_span=(0, 1)),
                run_time=4
            )
            self.play(FadeOut(rect))

        # Add temperature
        frame = self.frame
        temp_color = RED
        new_title = Text("softmax with temperature")
        new_title["temperature"].set_color(temp_color)
        get_t = temp_tracker.get_value

        # Number line from 0 to 10 above the box, with a marker and a
        # "T = ..." label that both follow the tracker
        t_line = NumberLine(
            (0, 10, 0.2),
            tick_size=0.025,
            big_tick_spacing=1,
            longer_tick_multiple=2.0,
            width=4
        )
        t_line.set_stroke(width=1.5)
        t_line.next_to(softmax_box, UP)
        t_tri = ArrowTip(angle=-90 * DEGREES)
        t_tri.set_color(temp_color)
        t_tri.set_height(0.2)
        t_label = Tex("T = 0.00", font_size=36)
        t_label.rhs = t_label.make_number_changeable("0.00")
        t_label["T"].set_color(temp_color)

        t_tri.add_updater(lambda m: m.move_to(t_line.n2p(get_t()), DOWN))
        t_label.add_updater(lambda m: m.rhs.set_value(get_t()))
        t_label.add_updater(lambda m: m.next_to(t_tri, UP, buff=0.1, aligned_edge=LEFT))
        # Run the updaters once so the title can be placed relative to the label
        t_label.update()

        new_title.next_to(t_label, UP, buff=0.5).match_x(softmax_box)

        self.play(
            frame.animate.move_to(0.75 * UP),
            TransformMatchingStrings(softmax_label, new_title),
            FadeIn(t_line),
            FadeIn(t_tri),
            FadeIn(t_label),
            run_time=1
        )

        # Change formula
        template = Tex(R"e^{x_{0} / T} / \sum_{n=0}^{N - 1} e^{x_n / T}")
        template["T"].set_color(temp_color)
        template["/"][1].scale(1.9, about_edge=LEFT)
        template[R"\sum_{n=0}^{N - 1}"][0].scale(0.7, about_edge=RIGHT)
        index_part = template.make_number_changeable("0")

        # Reuse the one template: fit it over each existing term, set its
        # index, and keep a copy (collecting the T symbols of every copy)
        new_sm_terms = VGroup()
        all_Ts = VGroup()
        for n, term in enumerate(sm_terms, start=1):
            template.replace(term, dim_to_match=1)
            index_part.set_value(n)
            new_term = template.copy()
            all_Ts.add(*new_term["T"])
            new_sm_terms.add(new_term)

        self.play(
            LaggedStart(*(
                FadeTransform(old_term, new_term)
                for old_term, new_term in zip(sm_terms, new_sm_terms)
            )),
            LaggedStart(*(
                TransformFromCopy(t_label[0], t_mob[0])
                for t_mob in all_Ts
            )),
        )
        self.wait()

        # Oscillate between values
        for value in [4, 10, 2]:
            self.play(temp_tracker.animate.set_value(value), run_time=8)
            self.wait()
        self.play(temp_tracker.animate.set_value(0), run_time=3)
        # With the tracker at 0, flash row 4 of the input and of the output
        max_rects = VGroup(
            SurroundingRectangle(rhs.get_entries()[4]),
            SurroundingRectangle(VGroup(output.get_entries()[4], bars[4])),
        )
        self.play(LaggedStartMap(ShowCreationThenFadeOut, max_rects))
        self.wait()
        for value in [5, 1, 7]:
            self.play(temp_tracker.animate.set_value(value), run_time=4)
            self.wait()

        # Describe logits
        # `matrix` below is local to the generator expression; it does not
        # rebind the weight matrix defined earlier
        prob_arrows, logit_arrows = (
            VGroup(*(
                Vector(-vect).next_to(entry, vect, buff=0.25)
                for entry in matrix.get_entries()
            ))
            for matrix, vect in [(output, RIGHT), (rhs, LEFT)]
        )
        prob_arrows.next_to(bars, RIGHT)
        prob_rects = VGroup(*map(SurroundingRectangle, output.get_entries()))
        logit_rects = VGroup(*map(SurroundingRectangle, rhs.get_entries()))
        VGroup(prob_rects, logit_rects).set_stroke(width=1)

        prob_words = Text("Probabilities")
        prob_words.next_to(output, UP, buff=0.25)
        logit_words = Text("Logits")
        logit_words.next_to(rhs, UP, buff=0.25)

        logit_group = VGroup(logit_arrows, logit_words, logit_rects)
        logit_group.set_color(TEAL)
        prob_group = VGroup(prob_arrows, prob_words, prob_rects)
        prob_group.set_color(YELLOW)

        # The arrows are built and colored, but only the word and the
        # rectangles of each group are animated here
        for arrows, word, rects in [prob_group, logit_group]:
            self.play(
                t_line.animate.set_y(3.35),
                Write(word),
                Write(rects, stroke_width=5, stroke_color=rects[0].get_stroke_color(), lag_ratio=0.3, run_time=3),
            )
            self.wait()
class CostFunction(InteractiveScene):
def construct(self):
# Add graph
axes = Axes((0, 1, 0.1), (0, 5, 1), width=10, height=6)
axes.center().to_edge(LEFT)
axes.x_axis.add_numbers(num_decimal_places=1)
axes.y_axis.add_numbers(num_decimal_places=0, direction=LEFT)
x_label = Tex("p")
x_label.next_to(axes.x_axis.get_right(), UR)
axes.add(x_label)
graph = axes.get_graph(lambda x: -np.log(x), x_range=(0.001, 10, 0.01))
graph.set_color(RED)
expr = Tex(R"\text{Cost} = -\log(p)", font_size=60)
expr.next_to(axes.i2gp(0.1, graph), UR, buff=0.1)
self.add(axes, graph, expr)
# Add sample phrase
phrase = Text("Watching 3Blue1Brown makes you smarter")
phrase.scale(0.75)
phrase.to_edge(UP)
phrase.align_to(axes.c2p(0.1, 0), LEFT)
pieces = break_into_tokens(phrase)
pieces[-1].set_opacity(0.0)
rects = get_piece_rectangles(pieces, leading_spaces=True, h_buff=0)
self.add(rects, pieces)
# Add predictions
arrow = Vector(0.5 * DOWN)
arrow.next_to(rects[-1], DOWN, SMALL_BUFF)
index = 0
tokens, probs = gpt3_predict_next_token(phrase.get_text()[:-len(" smarter")])
bar_chart = next_token_bar_chart(
tokens[:8], probs[:8],
width_100p=7.0,
bar_space_factor=1.0,
use_percent=False,
)
bar_chart.next_to(arrow, DOWN)
bar_chart.shift(1.25 * RIGHT)
bar_chart.set_opacity(0.5)
bar_chart[index].set_opacity(1.0)
rect = SurroundingRectangle(bar_chart[index])
self.add(arrow, bar_chart, rect)
# Animate in graph
self.play(
ShowCreation(graph, run_time=3),
Write(expr, run_time=2),
)
self.wait()
# Show point on the graph
line = axes.get_line_from_axis_to_point(0, axes.i2gp(probs[index], graph), line_func=Line)
line.set_stroke(YELLOW)
self.play(FadeTransform(rect.copy(), line))
self.wait()