mirror of
https://github.com/3b1b/videos.git
synced 2025-08-31 21:58:59 +00:00
325 lines
No EOL
10 KiB
Python
325 lines
No EOL
10 KiB
Python
from __future__ import annotations
|
|
|
|
from manim_imports_ext import *
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from typing import Optional
|
|
from manimlib.typing import Vect3, ManimColor
|
|
|
|
|
|
def softmax(logits, temperature=1.0):
    """
    Turn a sequence of logits into a probability distribution.

    The logits are divided by ``temperature`` before being exponentiated
    and normalized to sum to 1.  A temperature of exactly 0 is treated as
    the limiting case: all of the probability mass is placed on the
    largest logit.

    Returns a numpy array of probabilities.
    """
    if temperature == 0:
        result = np.zeros(len(logits))
        result[np.argmax(logits)] = 1
        return result

    scaled = np.array(logits) / temperature
    if scaled.size == 0:
        # Nothing to normalize (and max() of an empty array would raise)
        return scaled
    # Subtracting the max leaves the result mathematically unchanged, but
    # keeps np.exp from overflowing to inf (and the ratio from becoming
    # nan) when the logits are large.
    exps = np.exp(scaled - scaled.max())
    return exps / exps.sum()
|
|
|
|
|
|
def value_to_color(
    value,
    low_positive_color=BLUE_E,
    high_positive_color=BLUE_B,
    low_negative_color=RED_E,
    high_negative_color=RED_B,
    min_value=0.0,
    max_value=10.0
):
    """
    Map a signed number to a color based on its sign and magnitude.

    Strictly positive values are colored between ``low_positive_color``
    and ``high_positive_color``; every other value (zero included) is
    colored between ``low_negative_color`` and ``high_negative_color``.
    The position between the two colors is determined by where
    ``abs(value)`` falls between ``min_value`` and ``max_value``.
    """
    alpha = float(inverse_interpolate(min_value, max_value, abs(value)))
    # Clamp so that magnitudes outside of [min_value, max_value] saturate
    # at the end colors rather than extrapolating past them.
    alpha = min(max(alpha, 0.0), 1.0)
    if value > 0:
        colors = (low_positive_color, high_positive_color)
    else:
        colors = (low_negative_color, high_negative_color)
    return interpolate_color_by_hsl(*colors, alpha)
|
|
|
|
|
|
def show_matrix_vector_product(scene, matrix, vector, buff=0.25, x_max=999):
    """
    Animate the product of ``matrix`` and ``vector``, one row at a time.

    An equals sign and a result column (a NumericEmbedding with one entry
    per matrix row) are placed to the right of ``vector``.  Each result
    entry is then filled in by ``matrix_row_vector_product``, except for
    the entry on the matrix's ellipsis row, which is added directly.

    Returns the tuple ``(eq, rhs)`` of the equals sign and result column.
    """
    # Show product
    eq = Tex("=")
    eq.set_width(0.5 * vector.get_width())
    shape = (matrix.shape[0], 1)
    rhs = NumericEmbedding(
        values=x_max * np.ones(shape),
        value_range=(-x_max, x_max),
        decimal_config=dict(include_sign=True),
        ellipses_index=matrix.ellipses_row,
    )
    rhs.scale(vector.elements[0].get_height() / rhs.elements[0].get_height())
    eq.next_to(vector, RIGHT, buff=buff)
    rhs.next_to(eq, RIGHT, buff=buff)

    scene.play(FadeIn(eq), FadeIn(rhs.get_brackets()))

    # ellipses_row may be a negative index (e.g. -2), whereas the loop
    # counter below is never negative, so normalize before comparing.
    ellipses_row = matrix.ellipses_row
    if ellipses_row is not None:
        ellipses_row %= matrix.shape[0]

    last_rects = VGroup()
    # NOTE(review): rhs[:-2] presumably drops the two brackets at the end
    # of the submobject list, leaving only the entries -- confirm.
    for n, row, entry in zip(it.count(), matrix.get_rows(), rhs[:-2]):
        if n == ellipses_row:
            scene.add(entry)
        else:
            last_rects = matrix_row_vector_product(
                scene, row, vector, entry, last_rects
            )
    scene.play(FadeOut(last_rects))

    return eq, rhs
|
|
|
|
|
|
def matrix_row_vector_product(scene, row, vector, entry, to_fade):
    """
    Animate the dot product of one matrix ``row`` with ``vector``.

    Yellow rectangles are progressively drawn around the paired elements
    of the row and the vector while ``entry`` is updated to show the
    running partial sum.  ``to_fade`` (typically the rectangles from the
    previous row) is faded out during the same animation.

    Returns a VGroup holding the row rectangles and vector rectangles.
    """
    def get_rect(elem):
        return SurroundingRectangle(elem, buff=0.1).set_stroke(YELLOW, 2)

    # NOTE(review): vector[:-2] presumably drops the two brackets at the
    # end of the submobject list -- confirm.
    row_rects = VGroup(*map(get_rect, row))
    vect_rects = VGroup(*map(get_rect, vector[:-2]))
    partial_values = [0]
    for e1, e2 in zip(row, vector[:-2]):
        # Only pairs where both elements are numbers contribute to the
        # sum; anything else (e.g. an ellipsis) counts as zero.
        if not (isinstance(e1, DecimalNumber) and isinstance(e2, DecimalNumber)):
            increment = 0
        else:
            increment = e1.get_value() * e2.get_value()
        partial_values.append(partial_values[-1] + increment)
    n_values = len(partial_values)

    scene.play(
        ShowIncreasingSubsets(row_rects),
        ShowIncreasingSubsets(vect_rects),
        # Step through the partial sums as alpha goes from 0 to 1,
        # capping the index so alpha=1 lands on the final total.
        UpdateFromAlphaFunc(entry, lambda m, a: m.set_value(
            partial_values[min(int(np.round(a * n_values)), n_values - 1)]
        )),
        FadeOut(to_fade),
        rate_func=linear,
    )

    return VGroup(row_rects, vect_rects)
|
|
|
|
|
|
def data_flying_animation(point, vect=2 * DOWN + RIGHT, color=GREY_C, max_opacity=0.75):
    """
    Build an animation of the word "Data" drifting from ``point`` to
    ``point + vect``, with its opacity rising to ``max_opacity`` and
    falling back again along the way.
    """
    label = Text("Data", color=color)

    def drift(mob, alpha):
        moved = mob.move_to(interpolate(point, point + vect, alpha))
        return moved.set_opacity(there_and_back(alpha) * max_opacity)

    return UpdateFromAlphaFunc(label, drift)
|
|
|
|
|
|
def data_modifying_matrix(scene, matrix, word_shape=(5, 10), alpha_maxes=(0.7, 0.9)):
    """
    Play a staggered set of "Data" words flying from a grid of points
    over ``matrix`` while the matrix entries are randomized.

    ``word_shape`` gives the grid size as (count along y, count along x),
    and ``alpha_maxes`` gives, in the same order, how far across the
    matrix extent the grid reaches.
    """
    left_x = matrix.get_x(LEFT)
    right_x = matrix.get_x(RIGHT)
    top_y = matrix.get_y(UP)
    bottom_y = matrix.get_y(DOWN)
    depth = matrix.get_z()

    x_alphas = np.linspace(0, alpha_maxes[1], word_shape[1])
    y_alphas = np.linspace(0, alpha_maxes[0], word_shape[0])

    # Column-major ordering: all y positions for the first x, then the next x, ...
    grid = []
    for x_alpha in x_alphas:
        for y_alpha in y_alphas:
            grid.append([
                interpolate(left_x, right_x, x_alpha),
                interpolate(top_y, bottom_y, y_alpha),
                depth,
            ])
    points = np.array(grid)

    flights = [data_flying_animation(p) for p in points]
    flying = LaggedStart(*flights, lag_ratio=1 / len(points), run_time=3)
    randomizing = RandomizeMatrixEntries(matrix, run_time=3)
    scene.play(flying, randomizing)
|
|
|
|
|
|
class ContextAnimation(LaggedStart):
    """
    Staggered flashes passing along arcs drawn from each of ``sources``
    to ``target``.

    Each arc connects the ``direction`` edges of a source and the target,
    gets a random bright color within ``hue_range`` and a random stroke
    width between ``min_stroke_width`` and ``max_stroke_width``.  The
    arcs are shuffled before being animated.  When ``lag_ratio`` is None
    it defaults to ``0.5 / number of arcs``.
    """
    def __init__(
        self,
        target,
        sources,
        direction=UP,
        hue_range=(0.1, 0.3),
        time_width=2,
        min_stroke_width=0,
        max_stroke_width=5,
        lag_ratio=None,
        run_time=3,
        fix_in_frame=False,
        path_arc=PI / 2,
        **kwargs,
    ):
        arcs = VGroup()
        for source in sources:
            # Arcs bend the opposite way for sources left of the target
            flip = -1 if source.get_x() < target.get_x() else 1
            sign = direction[1] * flip

            start_point = source.get_edge_center(direction)
            end_point = target.get_edge_center(direction)
            arc_angle = sign * path_arc
            # Color is drawn before width, keeping random calls in order
            color = random_bright_color(hue_range=hue_range)
            width = interpolate(
                min_stroke_width,
                max_stroke_width,
                random.random()**2,
            )
            arcs.add(Line(
                start_point,
                end_point,
                path_arc=arc_angle,
                stroke_color=color,
                stroke_width=width,
            ))

        if fix_in_frame:
            arcs.fix_in_frame()
        arcs.shuffle()

        if lag_ratio is None:
            lag_ratio = 0.5 / len(arcs)

        flashes = [
            VShowPassingFlash(arc, time_width=time_width)
            for arc in arcs
        ]
        super().__init__(
            *flashes,
            lag_ratio=lag_ratio,
            run_time=run_time,
            **kwargs,
        )
|
|
|
|
|
|
class LabeledArrow(Arrow):
    """
    Arrow which can build a text label placed next to its end point.

    When ``label_text`` is given, the label takes the arrow's color, is
    rotated by ``label_rotation`` about the RIGHT axis, and is positioned
    next to the arrow's end in ``direction`` (by default, the direction
    the arrow points).  The label is stored on ``self.label``, which is
    None when no text is given.
    """
    def __init__(
        self,
        *args,
        label_text: Optional[str] = None,
        font_size: float = 24,
        label_buff: float = 0.1,
        direction: Optional[Vect3] = None,
        label_rotation: float = PI / 2,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.label = None
        if label_text is None:
            return

        tail, tip = self.get_start_and_end()
        text = Text(label_text, font_size=font_size)
        text.set_fill(self.get_color())
        text.set_backstroke()
        text.rotate(label_rotation, RIGHT)
        # Default to placing the label along the arrow's own direction
        offset_dir = normalize(tip - tail) if direction is None else direction
        text.next_to(tip, offset_dir, buff=label_buff)
        self.label = text
|
|
|
|
|
|
class WeightMatrix(DecimalMatrix):
    """
    DecimalMatrix whose entries are colored according to their values.

    If ``values`` is not given, entries are drawn uniformly at random
    from ``value_range`` with the given ``shape``; if it is given, its
    own shape overrides ``shape``.  Entry colors come from
    ``value_to_color`` using the four positive/negative color parameters.
    """
    def __init__(
        self,
        values: Optional[np.ndarray] = None,
        shape: tuple[int, int] = (6, 8),
        value_range: tuple[float, float] = (-9.9, 9.9),
        ellipses_row: Optional[int] = -2,
        ellipses_col: Optional[int] = -2,
        num_decimal_places: int = 1,
        bracket_h_buff: float = 0.1,
        decimal_config: Optional[dict] = None,
        low_positive_color: ManimColor = BLUE_E,
        high_positive_color: ManimColor = BLUE_B,
        low_negative_color: ManimColor = RED_E,
        high_negative_color: ManimColor = RED_B,
    ):
        # Build the default config per call, rather than sharing a single
        # mutable dict between every instance.
        if decimal_config is None:
            decimal_config = dict(include_sign=True)

        self.value_range = value_range
        if values is None:
            values = np.random.uniform(*self.value_range, size=shape)
        else:
            shape = values.shape
        self.shape = shape

        self.low_positive_color = low_positive_color
        self.high_positive_color = high_positive_color
        self.low_negative_color = low_negative_color
        self.high_negative_color = high_negative_color
        self.ellipses_row = ellipses_row
        self.ellipses_col = ellipses_col

        super().__init__(
            values,
            num_decimal_places=num_decimal_places,
            bracket_h_buff=bracket_h_buff,
            decimal_config=decimal_config,
            ellipses_row=ellipses_row,
            ellipses_col=ellipses_col,
        )
        self.reset_entry_colors()

    def reset_entry_colors(self):
        """
        Recolor every entry based on its current value, using the larger
        end of ``value_range`` as the magnitude of full brightness.

        Returns self.
        """
        for entry in self.get_entries():
            entry.set_fill(color=value_to_color(
                entry.get_value(),
                self.low_positive_color,
                self.high_positive_color,
                self.low_negative_color,
                self.high_negative_color,
                0, max(self.value_range),
            ))
        return self
|
|
|
|
|
|
class NumericEmbedding(WeightMatrix):
    """
    Single-column WeightMatrix, colored between ``dark_color`` and
    ``light_color`` by magnitude, regardless of sign.

    A one-dimensional ``values`` array is reshaped into a column.  When
    neither ``values`` nor ``shape`` is given, the shape defaults to
    ``(length, 1)``.  ``ellipses_index`` is the row replaced by an
    ellipsis; no column ellipsis is used.
    """
    def __init__(
        self,
        values: Optional[np.ndarray] = None,
        shape: Optional[tuple[int, int]] = None,
        length: int = 7,
        num_decimal_places: int = 1,
        ellipses_index: int = -2,
        value_range: tuple[float, float] = (0, 9.9),
        bracket_h_buff: float = 0.1,
        decimal_config: Optional[dict] = None,
        dark_color: ManimColor = GREY_C,
        light_color: ManimColor = WHITE,
    ):
        # Build the default config per call, rather than sharing a single
        # mutable dict between every instance.
        if decimal_config is None:
            decimal_config = dict()
        if values is not None:
            if len(values.shape) == 1:
                values = values.reshape((values.shape[0], 1))
            shape = values.shape
        if shape is None:
            shape = (length, 1)
        super().__init__(
            values,
            shape=shape,
            value_range=value_range,
            num_decimal_places=num_decimal_places,
            bracket_h_buff=bracket_h_buff,
            decimal_config=decimal_config,
            low_positive_color=dark_color,
            high_positive_color=light_color,
            low_negative_color=dark_color,
            high_negative_color=light_color,
            ellipses_row=ellipses_index,
            ellipses_col=None,
        )
|
|
|
|
|
|
class RandomizeMatrixEntries(Animation):
    """
    Animate each entry of ``matrix`` from its current value to a new
    value drawn uniformly at random from ``matrix.value_range``,
    recoloring the entries on every update.
    """
    def __init__(self, matrix, **kwargs):
        entries = matrix.get_entries()
        low = matrix.value_range[0]
        high = matrix.value_range[1]

        self.matrix = matrix
        self.entries = entries
        self.start_values = [entry.get_value() for entry in entries]
        self.target_values = np.random.uniform(low, high, len(entries))
        super().__init__(matrix, **kwargs)

    def interpolate_mobject(self, alpha: float) -> None:
        n_entries = len(self.entries)
        triples = zip(self.entries, self.start_values, self.target_values)
        for index, (entry, start, target) in enumerate(triples):
            # Each entry gets its own alpha, so lag_ratio can stagger them
            sub_alpha = self.get_sub_alpha(alpha, index, n_entries)
            entry.set_value(interpolate(start, target, sub_alpha))
        self.matrix.reset_entry_colors()
|
|
|
|
|
|
class Test(InteractiveScene):
    """Scratch scene: show a 10x10 WeightMatrix and randomize its entries."""
    def construct(self):
        # Test
        weights = WeightMatrix(shape=(10, 10), ellipses_row=-2, ellipses_col=-2)
        weights.set_width(12)
        self.add(weights)

        randomize = RandomizeMatrixEntries(weights, lag_ratio=0.02)
        self.play(randomize)
|
|
|
|
|
|
class EmbeddingSequence(MobjectMatrix):
    # Placeholder subclass: adds nothing to MobjectMatrix yet.
    pass
|
|
|
|
|
|
class AbstractEmbeddingSequence(MobjectMatrix):
    # Placeholder subclass: adds nothing to MobjectMatrix yet.
    pass