3b1b-videos/_2024/transformers/objects.py
2024-02-22 11:46:00 -08:00

325 lines
No EOL
10 KiB
Python

from __future__ import annotations
from manim_imports_ext import *
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional
from manimlib.typing import Vect3, ManimColor
def softmax(logits, temperature=1.0):
    """
    Turn a collection of raw scores into a probability distribution.

    logits: sequence or array of real-valued scores.
    temperature: the logits are divided by this before exponentiating.
        A temperature of exactly 0 is treated as the limiting case and
        yields a one-hot vector at the position of the largest logit.

    Returns a numpy array; for non-empty input its entries are
    non-negative and sum to 1.
    """
    if temperature == 0:
        # Limiting case: all of the mass lands on the largest logit
        result = np.zeros(len(logits))
        result[np.argmax(logits)] = 1
        return result

    scaled = np.array(logits) / temperature
    if scaled.size > 0:
        # Shifting every entry by the same constant leaves the ratios of
        # the exponentials unchanged, but keeps np.exp from overflowing
        # to inf (which would turn the result into nan)
        scaled = scaled - scaled.max()
    exps = np.exp(scaled)
    return exps / exps.sum()
def value_to_color(
    value,
    low_positive_color=BLUE_E,
    high_positive_color=BLUE_B,
    low_negative_color=RED_E,
    high_negative_color=RED_B,
    min_value=0.0,
    max_value=10.0
):
    """
    Map a signed number to a color based on its sign and magnitude.

    The magnitude abs(value) is located within [min_value, max_value] and
    used to blend, via interpolate_color_by_hsl, between the "low" and
    "high" color of the matching sign. Values that are not strictly
    positive (including 0) use the negative color pair.

    Returns whatever interpolate_color_by_hsl returns for the chosen
    pair of colors.
    """
    alpha = float(inverse_interpolate(min_value, max_value, abs(value)))
    # Saturate at the end colors rather than extrapolating past them when
    # the magnitude falls outside of [min_value, max_value]
    alpha = min(max(alpha, 0.0), 1.0)
    if value > 0:
        colors = (low_positive_color, high_positive_color)
    else:
        colors = (low_negative_color, high_negative_color)
    return interpolate_color_by_hsl(*colors, alpha)
def show_matrix_vector_product(scene, matrix, vector, buff=0.25, x_max=999):
    """
    Animate a matrix-vector product on `scene`, one matrix row at a time.

    scene: the scene whose play/add methods drive the animation.
    matrix: a matrix mobject exposing `shape`, `ellipses_row` and
        `get_rows()` (e.g. a WeightMatrix).
    vector: the column vector mobject being multiplied; it must expose
        `elements`, and all but its last two submobjects are treated as
        its entries.
    buff: spacing used on either side of the equals sign.
    x_max: initial value of every entry of the result, also used as the
        symmetric bound of the result's value_range.

    Returns the tuple (eq, rhs) of the equals sign and the result vector.
    """
    # Show product
    eq = Tex("=")
    eq.set_width(0.5 * vector.get_width())
    n_rows = matrix.shape[0]
    rhs = NumericEmbedding(
        values=x_max * np.ones((n_rows, 1)),
        value_range=(-x_max, x_max),
        decimal_config=dict(include_sign=True),
        ellipses_index=matrix.ellipses_row,
    )
    # Match the entry size of the result to that of the input vector
    rhs.scale(vector.elements[0].get_height() / rhs.elements[0].get_height())
    eq.next_to(vector, RIGHT, buff=buff)
    rhs.next_to(eq, RIGHT, buff=buff)

    scene.play(FadeIn(eq), FadeIn(rhs.get_brackets()))

    # ellipses_row may be a negative, from-the-end index (the default in
    # WeightMatrix is -2), whereas the loop counter below is never
    # negative. Convert it to the equivalent non-negative index so that
    # the ellipsis row is actually recognized.
    ellipses_row = matrix.ellipses_row
    if ellipses_row is not None and -n_rows <= ellipses_row < 0:
        ellipses_row += n_rows

    last_rects = VGroup()
    # rhs[:-2] drops the last two submobjects (presumably the brackets),
    # leaving one entry per row
    for n, (row, entry) in enumerate(zip(matrix.get_rows(), rhs[:-2])):
        if n == ellipses_row:
            # Nothing to compute for the ellipsis row, just show it
            scene.add(entry)
        else:
            last_rects = matrix_row_vector_product(
                scene, row, vector, entry, last_rects
            )
    scene.play(FadeOut(last_rects))

    return eq, rhs
def matrix_row_vector_product(scene, row, vector, entry, to_fade):
    """
    Animate the dot product of one matrix row with a vector.

    Yellow rectangles are progressively drawn around the entries of `row`
    and of `vector` while `entry` counts up through the running partial
    sums. `to_fade` (typically the rectangles from the previous row) is
    faded out during the same animation.

    Returns a VGroup holding the row rectangles and the vector rectangles,
    so that the caller can fade them out later.
    """
    def get_rect(elem):
        return SurroundingRectangle(elem, buff=0.1).set_stroke(YELLOW, 2)

    # vector[:-2] drops the last two submobjects (presumably the brackets)
    vect_entries = vector[:-2]
    row_rects = VGroup(*map(get_rect, row))
    vect_rects = VGroup(*map(get_rect, vect_entries))

    partial_values = [0]
    for e1, e2 in zip(row, vect_entries):
        # Only pairs where both sides are actual numbers contribute.
        # Anything else (e.g. an ellipsis) counts as zero, instead of
        # having get_value called on something which isn't a number.
        if isinstance(e1, DecimalNumber) and isinstance(e2, DecimalNumber):
            increment = e1.get_value() * e2.get_value()
        else:
            increment = 0
        partial_values.append(partial_values[-1] + increment)
    n_values = len(partial_values)

    scene.play(
        ShowIncreasingSubsets(row_rects),
        ShowIncreasingSubsets(vect_rects),
        # Step through the partial sums as alpha goes from 0 to 1,
        # clamping the index so that alpha = 1 lands on the final sum
        UpdateFromAlphaFunc(entry, lambda m, a: m.set_value(
            partial_values[min(int(np.round(a * n_values)), n_values - 1)]
        )),
        FadeOut(to_fade),
        rate_func=linear,
    )

    return VGroup(row_rects, vect_rects)
def data_flying_animation(point, vect=2 * DOWN + RIGHT, color=GREY_C, max_opacity=0.75):
    """
    Build an animation of the word "Data" drifting from `point` to
    `point + vect`, with its opacity rising to `max_opacity` and back
    down again (following there_and_back) along the way.
    """
    label = Text("Data", color=color)

    def fly(mob, alpha):
        position = interpolate(point, point + vect, alpha)
        moved = mob.move_to(position)
        return moved.set_opacity(there_and_back(alpha) * max_opacity)

    return UpdateFromAlphaFunc(label, fly)
def data_modifying_matrix(scene, matrix, word_shape=(5, 10), alpha_maxes=(0.7, 0.9)):
    """
    Play a staggered grid of "Data" words flying across `matrix` while
    its entries are randomized.

    word_shape: (number of rows, number of columns) of the grid of
        starting points.
    alpha_maxes: how far down and how far across the matrix the grid
        extends, as proportions (vertical, horizontal) of its extent.
    """
    left_x = matrix.get_x(LEFT)
    right_x = matrix.get_x(RIGHT)
    top_y = matrix.get_y(UP)
    bottom_y = matrix.get_y(DOWN)
    depth = matrix.get_z()

    across = np.linspace(0, alpha_maxes[1], word_shape[1])
    down = np.linspace(0, alpha_maxes[0], word_shape[0])

    # Column by column, from the top-left corner of the matrix
    starts = []
    for x_alpha in across:
        for y_alpha in down:
            starts.append([
                interpolate(left_x, right_x, x_alpha),
                interpolate(top_y, bottom_y, y_alpha),
                depth,
            ])
    points = np.array(starts)

    duration = 3
    scene.play(
        LaggedStart(
            *(data_flying_animation(start) for start in points),
            lag_ratio=1 / len(points),
            run_time=duration,
        ),
        RandomizeMatrixEntries(matrix, run_time=duration),
    )
class ContextAnimation(LaggedStart):
    """
    Flashes of light passing along curved arcs from each of `sources`
    to `target`, played in a shuffled, staggered order.

    direction: the edge (e.g. UP or DOWN) of the source and target mobjects
        from which the arcs leave and at which they arrive.
    hue_range: passed to random_bright_color to choose each arc's color.
    time_width: passed on to each VShowPassingFlash.
    min_stroke_width, max_stroke_width: range of the randomly chosen
        stroke width of each arc.
    lag_ratio: if None, defaults to 0.5 divided by the number of arcs.
    fix_in_frame: whether to call fix_in_frame on the arcs.
    path_arc: magnitude of the arc angle of each connecting line.
    """
    def __init__(
        self,
        target,
        sources,
        direction=UP,
        hue_range=(0.1, 0.3),
        time_width=2,
        min_stroke_width=0,
        max_stroke_width=5,
        lag_ratio=None,
        run_time=3,
        fix_in_frame=False,
        path_arc=PI / 2,
        **kwargs,
    ):
        arcs = VGroup()
        for source in sources:
            # Flip the bend of the arc based on which side of the target
            # the source is on, and on the vertical sense of direction
            sign = direction[1] * (-1)**int(source.get_x() < target.get_x())
            arcs.add(Line(
                source.get_edge_center(direction),
                target.get_edge_center(direction),
                path_arc=sign * path_arc,
                stroke_color=random_bright_color(hue_range=hue_range),
                # Squaring a uniform random number skews the widths
                # towards the thin end of the range
                stroke_width=interpolate(
                    min_stroke_width,
                    max_stroke_width,
                    random.random()**2,
                )
            ))
        if fix_in_frame:
            arcs.fix_in_frame()
        arcs.shuffle()

        if lag_ratio is None:
            # max(..., 1) keeps an empty collection of sources from
            # raising a ZeroDivisionError
            lag_ratio = 0.5 / max(len(arcs), 1)

        super().__init__(
            *(
                VShowPassingFlash(arc, time_width=time_width)
                for arc in arcs
            ),
            lag_ratio=lag_ratio,
            run_time=run_time,
            **kwargs,
        )
class LabeledArrow(Arrow):
    """
    An Arrow which optionally builds a text label placed beyond its tip.

    The label is stored as `self.label` (None when no label_text is
    given). This constructor does not add it as a submobject of the arrow.
    """
    def __init__(
        self,
        *args,
        label_text: Optional[str] = None,
        font_size: float = 24,
        label_buff: float = 0.1,
        direction: Optional[Vect3] = None,
        label_rotation: float = PI / 2,
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        if label_text is None:
            self.label = None
            return

        tail, tip = self.get_start_and_end()
        text = Text(label_text, font_size=font_size)
        text.set_fill(self.get_color())
        text.set_backstroke()
        # Rotation is about the x-axis
        text.rotate(label_rotation, RIGHT)
        if direction is None:
            # Default to continuing along the direction of the arrow
            direction = normalize(tip - tail)
        text.next_to(tip, direction, buff=label_buff)
        self.label = text
class WeightMatrix(DecimalMatrix):
    """
    A DecimalMatrix whose entries are colored by sign and magnitude.

    values: the numbers to display. If None, values are drawn uniformly
        at random from value_range with the given shape.
    shape: (rows, columns), only used when values is None.
    value_range: (low, high) bounds used both for random values and to
        set the scale of the entry colors.
    ellipses_row, ellipses_col: forwarded to DecimalMatrix, and stored
        as attributes of the same names.
    low_*/high_* colors: end colors for small and large magnitudes of
        positive and negative entries, see value_to_color.
    """
    def __init__(
        self,
        values: Optional[np.ndarray] = None,
        shape: tuple[int, int] = (6, 8),
        value_range: tuple[float, float] = (-9.9, 9.9),
        ellipses_row: Optional[int] = -2,
        ellipses_col: Optional[int] = -2,
        num_decimal_places: int = 1,
        bracket_h_buff: float = 0.1,
        decimal_config=dict(include_sign=True),
        low_positive_color: ManimColor = BLUE_E,
        high_positive_color: ManimColor = BLUE_B,
        low_negative_color: ManimColor = RED_E,
        high_negative_color: ManimColor = RED_B,
    ):
        if values is None:
            values = np.random.uniform(*value_range, size=shape)
        else:
            # Explicit values take precedence over the shape argument
            shape = values.shape

        # These are read by reset_entry_colors, so they must be in place
        # before it is called below
        self.shape = shape
        self.value_range = value_range
        self.low_positive_color = low_positive_color
        self.high_positive_color = high_positive_color
        self.low_negative_color = low_negative_color
        self.high_negative_color = high_negative_color
        self.ellipses_row = ellipses_row
        self.ellipses_col = ellipses_col

        super().__init__(
            values,
            num_decimal_places=num_decimal_places,
            bracket_h_buff=bracket_h_buff,
            decimal_config=decimal_config,
            ellipses_row=ellipses_row,
            ellipses_col=ellipses_col,
        )
        self.reset_entry_colors()

    def reset_entry_colors(self):
        """
        Recolor every entry according to its current value. Returns self.
        """
        # Colors depend on the magnitude of each entry, so the scale is
        # the largest magnitude in the range. Using the upper bound alone
        # breaks for ranges like (-9.9, 0) or (-9.9, 2), where the
        # negative side reaches further than the positive one.
        max_magnitude = max(abs(bound) for bound in self.value_range)
        for entry in self.get_entries():
            entry.set_fill(color=value_to_color(
                entry.get_value(),
                self.low_positive_color,
                self.high_positive_color,
                self.low_negative_color,
                self.high_negative_color,
                0, max_magnitude,
            ))
        return self
class NumericEmbedding(WeightMatrix):
    """
    A WeightMatrix specialized to display a vector of numbers, using the
    same dark-to-light color pair for both positive and negative entries.

    values: the numbers to display, as an array or a plain sequence.
        One-dimensional input is treated as a single column.
    shape: (rows, columns), only used when values is None. Defaults to
        (length, 1).
    length: number of rows when neither values nor shape is given.
    ellipses_index: row to replace with an ellipsis; forwarded to
        WeightMatrix as ellipses_row (ellipses_col is always None).
    dark_color, light_color: colors for low and high magnitude entries.
    """
    def __init__(
        self,
        values: Optional[np.ndarray] = None,
        shape: Optional[tuple[int, int]] = None,
        length: int = 7,
        num_decimal_places: int = 1,
        ellipses_index: int = -2,
        value_range: tuple[float, float] = (0, 9.9),
        bracket_h_buff: float = 0.1,
        decimal_config=dict(),
        dark_color: ManimColor = GREY_C,
        light_color: ManimColor = WHITE,
    ):
        if values is not None:
            # Allow plain lists or tuples as well as arrays
            values = np.asarray(values)
            if values.ndim == 1:
                # Interpret a flat sequence as a column vector
                values = values.reshape((values.shape[0], 1))
            shape = values.shape
        if shape is None:
            shape = (length, 1)
        super().__init__(
            values,
            shape=shape,
            value_range=value_range,
            num_decimal_places=num_decimal_places,
            bracket_h_buff=bracket_h_buff,
            decimal_config=decimal_config,
            low_positive_color=dark_color,
            high_positive_color=light_color,
            low_negative_color=dark_color,
            high_negative_color=light_color,
            ellipses_row=ellipses_index,
            ellipses_col=None,
        )
class RandomizeMatrixEntries(Animation):
    """
    Animate each entry of a matrix from its current value to a new value
    drawn uniformly at random from the matrix's value_range, recoloring
    the entries along the way.

    The matrix is expected to provide get_entries(), value_range and
    reset_entry_colors(), as WeightMatrix does.
    """
    def __init__(self, matrix, **kwargs):
        low, high = matrix.value_range[0], matrix.value_range[1]
        entries = matrix.get_entries()

        self.matrix = matrix
        self.entries = entries
        self.start_values = [entry.get_value() for entry in entries]
        self.target_values = np.random.uniform(low, high, len(entries))
        super().__init__(matrix, **kwargs)

    def interpolate_mobject(self, alpha: float) -> None:
        n_entries = len(self.entries)
        triples = zip(self.entries, self.start_values, self.target_values)
        for index, (entry, start, target) in enumerate(triples):
            # Each entry gets its own alpha from get_sub_alpha, based on
            # its index among the entries
            sub_alpha = self.get_sub_alpha(alpha, index, n_entries)
            entry.set_value(interpolate(start, target, sub_alpha))
        # Colors depend on the values, so refresh them after updating
        self.matrix.reset_entry_colors()
class Test(InteractiveScene):
    """
    Scratch scene: shows a 10x10 WeightMatrix and randomizes its entries.
    """
    def construct(self):
        # Test
        weights = WeightMatrix(shape=(10, 10), ellipses_row=-2, ellipses_col=-2)
        weights.set_width(12)

        self.add(weights)
        randomize = RandomizeMatrixEntries(weights, lag_ratio=0.02)
        self.play(randomize)
class EmbeddingSequence(MobjectMatrix):
    """Placeholder subclass of MobjectMatrix; adds no behavior of its own."""
class AbstractEmbeddingSequence(MobjectMatrix):
    """Placeholder subclass of MobjectMatrix; adds no behavior of its own."""