# Mirror of https://github.com/3b1b/videos.git
# Synced 2025-08-31 21:58:59 +00:00
# 819 lines, 25 KiB, Python
from __future__ import annotations
|
|
|
|
from manim_imports_ext import *
|
|
|
|
from typing import TYPE_CHECKING
|
|
import warnings
|
|
# import datasets
|
|
|
|
DATA_DIR = Path(get_output_dir(), "2024/transformers/data/")
|
|
WORD_FILE = Path(DATA_DIR, "OWL3_Dictionary.txt")
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from typing import Optional
|
|
from manimlib.typing import Vect3, ManimColor
|
|
|
|
|
|
def get_paragraph(words, line_len=40, font_size=48):
    """
    Handle word wrapping

    Greedily packs the stripped `words` into lines whose summed word
    lengths (the joining spaces are not counted) do not exceed `line_len`,
    then returns the lines as a single left-aligned Text mobject.

    A word that is by itself longer than `line_len` is given a line of its
    own rather than stalling the wrapping.
    """
    words = [word.strip() for word in words]
    lines = []
    current_line = []
    current_len = 0
    for word in words:
        # Only break when the current line already holds a word; otherwise
        # an over-long word could never be placed and no progress would be made.
        if current_line and current_len + len(word) > line_len:
            lines.append(current_line)
            current_line = []
            current_len = 0
        current_line.append(word)
        current_len += len(word)
    # Whatever remains (possibly nothing, for empty input) is the last line
    lines.append(current_line)
    text = "\n".join(" ".join(line).strip() for line in lines)
    return Text(text, alignment="LEFT", font_size=font_size)
|
|
|
|
|
|
def softmax(logits, temperature=1.0):
    """
    Softmax of `logits` scaled by `temperature`.

    When temperature is 0, or when the exponentials contain inf/nan, the
    result is instead a one-hot array (same dtype as the shifted logits)
    marking the position of the largest shifted logit.
    """
    shifted = np.array(logits)
    with warnings.catch_warnings():
        # Ignore all warnings within this block
        warnings.filterwarnings('ignore')
        # Subtract the max for numerical stability
        shifted = shifted - np.max(shifted)
        scaled = np.divide(shifted, temperature, where=temperature != 0)
        exps = np.exp(scaled)

    if not np.isfinite(exps).all() or temperature == 0:
        one_hot = np.zeros_like(shifted)
        one_hot[np.argmax(shifted)] = 1
        return one_hot

    total = np.sum(exps)
    return exps / total
|
|
|
|
|
|
def value_to_color(
    value,
    low_positive_color=BLUE_E,
    high_positive_color=BLUE_B,
    low_negative_color=RED_E,
    high_negative_color=RED_B,
    min_value=0.0,
    max_value=10.0
):
    """
    Map a signed number to a color.

    The magnitude of `value` is located between `min_value` and `max_value`
    (clipped to [0, 1]) and used to blend, in HSL, between the low and high
    color of the positive pair (value >= 0) or of the negative pair.
    """
    magnitude = abs(value)
    raw_alpha = float(inverse_interpolate(min_value, max_value, magnitude))
    alpha = clip(raw_alpha, 0, 1)
    low_color, high_color = (
        (low_positive_color, high_positive_color)
        if value >= 0
        else (low_negative_color, high_negative_color)
    )
    return interpolate_color_by_hsl(low_color, high_color, alpha)
|
|
|
|
|
|
def read_in_book(name="tale_of_two_cities"):
    """
    Return the contents of the ".txt" file called `name` inside DATA_DIR.
    """
    book_path = Path(DATA_DIR, name).with_suffix(".txt")
    return book_path.read_text()
|
|
|
|
|
|
def load_image_net_data(dataset_name="image_net_1k"):
    """
    Return a list of (image_path, label) pairs for a dataset kept under
    ~/Documents/<dataset_name>.

    If the "images" directory does not exist yet, the dataset is loaded
    from disk with the `datasets` package, every image is saved as
    "<index>.jpeg" and the matching labels are written to
    "image_labels.txt". In all cases the pairs are then built from that
    label file.
    """
    data_path = Path(Path.home(), "Documents", dataset_name)
    image_dir = Path(data_path, "images")
    label_category_path = Path(DATA_DIR, "image_categories.txt")
    image_label_path = Path(data_path, "image_labels.txt")

    if not os.path.exists(image_dir):
        # Imported here because the module-level import is commented out;
        # without it the name `datasets` is undefined in this branch.
        import datasets

        # Do everything that can fail before creating the directory, since
        # the directory's existence is what marks the export as done.
        image_data = datasets.load_from_disk(str(data_path))
        indices = range(len(image_data))
        categories = label_category_path.read_text().split("\n")
        labels = [categories[image_data[index]['label']] for index in indices]

        os.makedirs(image_dir)
        image_label_path.write_text("\n".join(labels))
        for index in ProgressDisplay(indices):
            image = image_data[index]['image']
            image.save(str(Path(image_dir, f"{index}.jpeg")))

    labels = image_label_path.read_text().split("\n")
    return [
        (Path(image_dir, f"{index}.jpeg"), label)
        for index, label in enumerate(labels)
    ]
|
|
|
|
|
|
def show_matrix_vector_product(scene, matrix, vector, buff=0.25, x_max=999, fix_in_frame=False):
    """
    Play, on `scene`, the row-by-row evaluation of `matrix` times `vector`.

    An "=" sign and a NumericEmbedding result column are placed to the
    right of `vector`; each result entry is then filled in by
    matrix_row_vector_product, except the one in the matrix's ellipses row,
    which is simply added to the scene.

    Returns the pair (eq, rhs).
    """
    # Equals sign and result column
    eq = Tex("=")
    eq.set_width(0.5 * vector.get_width())
    n_outputs = matrix.shape[0]
    rhs = NumericEmbedding(
        values=x_max * np.ones((n_outputs, 1)),
        value_range=(-x_max, x_max),
        decimal_config=dict(include_sign=True, edge_to_fix=ORIGIN),
        ellipses_row=matrix.ellipses_row,
    )
    # Match the entry height of the result to that of the input vector
    height_ratio = vector.elements[0].get_height() / rhs.elements[0].get_height()
    rhs.scale(height_ratio)
    eq.next_to(vector, RIGHT, buff=buff)
    rhs.next_to(eq, RIGHT, buff=buff)
    if fix_in_frame:
        for mob in (eq, rhs):
            mob.fix_in_frame()

    scene.play(FadeIn(eq), FadeIn(rhs.get_brackets()))

    # Walk through the rows, carrying the previous highlights forward so
    # that they fade while the next row is being shown
    rects = VGroup()
    n_rows = len(matrix.rows)
    ellipses_row = matrix.ellipses_row
    for index, (row, entry) in enumerate(zip(matrix.get_rows(), rhs[:-2])):
        if ellipses_row is not None and index == (ellipses_row % n_rows):
            scene.add(entry)
            continue
        rects = matrix_row_vector_product(
            scene, row, vector, entry, rects,
            fix_in_frame=fix_in_frame
        )
    scene.play(FadeOut(rects))

    return eq, rhs
|
|
|
|
|
|
def matrix_row_vector_product(scene, row, vector, entry, to_fade, fix_in_frame=False):
    """
    Animate one dot product between a matrix `row` and `vector`.

    Yellow rectangles are revealed one at a time around the elements of
    `row` and around `vector[:-2]` (presumably everything but the brackets)
    while `entry` steps through the running partial sums and `to_fade`
    fades out. Element pairs that are not both DecimalNumbers (e.g.
    ellipses) contribute nothing to the sum.

    Returns a VGroup holding the two groups of rectangles so the caller
    can fade them later.
    """
    def get_rect(elem):
        return SurroundingRectangle(elem, buff=0.1, is_fixed_in_frame=fix_in_frame).set_stroke(YELLOW, 2)

    vect_elems = vector[:-2]
    row_rects = VGroup(*map(get_rect, row))
    vect_rects = VGroup(*map(get_rect, vect_elems))

    partial_values = [0]
    for e1, e2 in zip(row, vect_elems):
        # Both factors must be numeric; previously a precedence slip
        # ("not a and b") let non-numeric elements reach get_value()
        if isinstance(e1, DecimalNumber) and isinstance(e2, DecimalNumber):
            # Use the values as displayed, i.e. rounded to the shown places
            val1 = round(e1.get_value(), e1.num_decimal_places)
            val2 = round(e2.get_value(), e2.num_decimal_places)
            increment = val1 * val2
        else:
            increment = 0
        partial_values.append(partial_values[-1] + increment)
    n_values = len(partial_values)

    def show_partial_sum(mob, alpha):
        # Map animation progress onto an index into the partial sums
        index = min(int(np.round(alpha * n_values)), n_values - 1)
        return mob.set_value(partial_values[index])

    scene.play(
        ShowIncreasingSubsets(row_rects),
        ShowIncreasingSubsets(vect_rects),
        UpdateFromAlphaFunc(entry, show_partial_sum),
        FadeOut(to_fade),
        rate_func=linear,
    )

    return VGroup(row_rects, vect_rects)
|
|
|
|
|
|
def get_full_matrix_vector_product(
    mat_sym="w",
    vect_sym="x",
    n_rows=5,
    n_cols=5,
    mat_sym_color=BLUE,
    height=3.0,
    ellipses_row=-2,
    ellipses_col=-2,
):
    """
    Build the mobjects for a symbolic matrix-vector product.

    Returns (matrix, vector, equals, rhs): an `n_rows` x `n_cols` TexMatrix
    with entries `mat_sym_{i, j}`, a column TexMatrix with entries
    `vect_sym_{j}`, an "=" sign, and a TexMatrix whose entries are the
    individual product terms with "+" signs added between neighbours.
    The last row index is labelled "m" and the last column index "n".

    With the default (square) arguments the result is the same as before;
    the row/column counts and the symbols are now honoured for other
    arguments as well.
    """
    # Row labels 1, 2, ..., "m" and column labels 1, 2, ..., "n"
    m_indices = list(map(str, [*range(1, n_rows), "m"]))
    n_indices = list(map(str, [*range(1, n_cols), "n"]))

    matrix = TexMatrix(
        [
            [Rf"{mat_sym}_{{{m}, {n}}}" for n in n_indices]
            for m in m_indices
        ],
        ellipses_row=ellipses_row,
        ellipses_col=ellipses_col,
    )
    matrix.set_height(height)
    matrix.get_entries().set_color(mat_sym_color)

    # The vector has one entry per matrix column, so its ellipsis follows
    # the matrix's column ellipsis
    vector = TexMatrix(
        [[Rf"{vect_sym}_{{{n}}}"] for n in n_indices],
        ellipses_row=ellipses_col,
    )
    vector.match_height(matrix)
    vector.next_to(matrix, RIGHT)
    equals = Tex("=", font_size=72)
    equals.next_to(vector, RIGHT)

    result_terms = [
        [Rf"{mat_sym}_{{{m}, {n}}} {vect_sym}_{{{n}}}" for n in n_indices]
        for m in m_indices
    ]
    rhs = TexMatrix(
        result_terms,
        ellipses_row=ellipses_row,
        ellipses_col=ellipses_col,
    )
    rhs.match_height(matrix)
    rhs.next_to(equals, RIGHT)

    # Resolve (possibly negative) ellipses indices once; None means no ellipses
    skip_row = None if ellipses_row is None else ellipses_row % len(m_indices)
    skip_col = None if ellipses_col is None else ellipses_col % len(n_indices)
    for m, row in enumerate(rhs.get_rows()):
        if m == skip_row:
            continue
        for n, entry in enumerate(row):
            if n != skip_col:
                # NOTE(review): assumes the matrix symbol with its subscript
                # occupies the first 4 glyphs, as for a one-letter `mat_sym`
                entry[:4].set_color(mat_sym_color)
        # Put a "+" halfway between each pair of neighbouring terms
        for e1, e2 in zip(row, row[1:]):
            plus = Tex("+")
            plus.match_height(e1)
            points = [e1.get_right(), e2.get_left()]
            plus.move_to(midpoint(*points))
            plus.align_to(e1, UP)
            e2.add(plus)

    return matrix, vector, equals, rhs
|
|
|
|
|
|
def show_symbolic_matrix_vector_product(scene, matrix, vector, rhs, run_time_per_row=0.75):
    """
    Play, one matrix row at a time, yellow rectangles appearing around the
    row's entries and around the first column of `vector` while the
    matching row of `rhs` is revealed. Each row's rectangles fade out as
    the next row plays, and the final ones fade at the end.
    """
    def outline_each(mobjects):
        return VGroup(*(SurroundingRectangle(mob) for mob in mobjects))

    previous_rects = VGroup()
    for mat_row, rhs_row in zip(matrix.get_rows(), rhs.get_rows()):
        mat_rects = outline_each(mat_row)
        vect_rects = outline_each(vector.get_columns()[0])
        current_rects = VGroup(mat_rects, vect_rects)
        current_rects.set_stroke(YELLOW, 2)

        reveals = [
            ShowIncreasingSubsets(group, rate_func=linear)
            for group in (mat_rects, vect_rects, rhs_row)
        ]
        scene.play(
            FadeOut(previous_rects),
            *reveals,
            run_time=run_time_per_row,
        )
        previous_rects = current_rects
    scene.play(FadeOut(previous_rects))
|
|
|
|
|
|
def data_flying_animation(
    point,
    vect=2 * DOWN + RIGHT,
    color=GREY_C,
    max_opacity=0.75,
    font_size=48,
    fix_in_frame=False
):
    """
    Return an animation of the word "Data" travelling from `point` to
    `point + vect`, with its opacity following there_and_back scaled by
    `max_opacity`.
    """
    word = Text("Data", color=color, font_size=font_size)
    if fix_in_frame:
        word.fix_in_frame()

    def fly(mob, alpha):
        position = interpolate(point, point + vect, alpha)
        opacity = there_and_back(alpha) * max_opacity
        return mob.move_to(position).set_opacity(opacity)

    return UpdateFromAlphaFunc(word, fly)
|
|
|
|
|
|
def get_data_modifying_matrix_anims(
    matrix,
    word_shape=(5, 10),
    alpha_maxes=(0.7, 0.9),
    shift_vect=2 * DOWN + RIGHT,
    run_time=3,
    fix_in_frame=False,
    font_size=48,
):
    """
    Return two animations meant to be played together: a LaggedStart of
    "Data" words flying from a grid of points laid over `matrix`, and a
    RandomizeMatrixEntries of the matrix itself.

    The grid has word_shape[1] positions across and word_shape[0] positions
    down, covering the fractions alpha_maxes[1] and alpha_maxes[0] of the
    matrix extent measured from its left and top edges.
    """
    left_x, right_x = matrix.get_x(LEFT), matrix.get_x(RIGHT)
    top_y, bottom_y = matrix.get_y(UP), matrix.get_y(DOWN)
    z = matrix.get_z()

    across_alphas = np.linspace(0, alpha_maxes[1], word_shape[1])
    down_alphas = np.linspace(0, alpha_maxes[0], word_shape[0])
    points = np.array([
        [
            interpolate(left_x, right_x, a1),
            interpolate(top_y, bottom_y, a2),
            z,
        ]
        for a1 in across_alphas
        for a2 in down_alphas
    ])

    flying_words = LaggedStart(
        (data_flying_animation(p, vect=shift_vect, fix_in_frame=fix_in_frame, font_size=font_size)
         for p in points),
        lag_ratio=1 / len(points),
        run_time=run_time
    )
    randomize = RandomizeMatrixEntries(matrix, run_time=run_time)
    return [flying_words, randomize]
|
|
|
|
|
|
def data_modifying_matrix(scene, matrix, *args, **kwargs):
    """
    Play on `scene` the animations built by get_data_modifying_matrix_anims,
    to which all remaining arguments are forwarded.
    """
    scene.play(*get_data_modifying_matrix_anims(matrix, *args, **kwargs))
|
|
|
|
|
|
def create_pixels(image_mob, pixel_width=0.1):
    """
    Return a VGroup of stroke-less squares of side `pixel_width` tiling
    `image_mob` from its upper-left corner, each colored with the image's
    rgb value at the square's upper-left point. All squares are placed at
    z = 0.
    """
    left, top, _ = image_mob.get_corner(UL)
    right, bottom, _ = image_mob.get_corner(DR)
    xs = np.arange(left, right, pixel_width)
    ys = np.arange(top, bottom, -pixel_width)
    # Row by row from the top, left to right within a row
    points = np.array([[x, y, 0] for y in ys for x in xs])

    template = Square(pixel_width).set_fill(WHITE, 1).set_stroke(width=0)

    def pixel_at(point):
        pixel = template.copy().move_to(point, UL)
        return pixel.set_color(Color(rgb=image_mob.point_to_rgb(point)))

    return VGroup(pixel_at(point) for point in points)
|
|
|
|
|
|
def get_network_connections(layer1, layer2, max_width=2.0, opacity_exp=1.0):
    """
    Return a VGroup with one Line for every (node in layer1, node in layer2)
    pair, each given a random color, stroke width and opacity.
    """
    radius = layer1[0].get_width() / 2

    def random_connection(n1, n2):
        line = Line(n1.get_center(), n2.get_center(), buff=radius)
        # Random draws are made in the order color, width, opacity
        color = value_to_color(random.uniform(-10, 10))
        width = max_width * random.random()
        opacity = random.random()**opacity_exp
        return line.set_stroke(color=color, width=width, opacity=opacity)

    return VGroup(
        random_connection(n1, n2)
        for n1 in layer1
        for n2 in layer2
    )
|
|
|
|
|
|
def get_vector_pair(angle_in_degrees=90, length=1.0, colors=(BLUE, BLUE)):
    """
    Return VGroup(v1, v2, arc, label): a vector of the given length along
    RIGHT, a copy of it rotated by the angle about the origin, a small arc
    spanning that angle, and a degree label showing `angle_in_degrees`.
    """
    angle = angle_in_degrees * DEGREES

    base = Vector(length * RIGHT)
    rotated = base.copy().rotate(angle, about_point=ORIGIN)
    base.set_color(colors[0])
    rotated.set_color(colors[1])

    arc = Arc(radius=0.2, angle=angle)
    arc.set_stroke(WHITE, 2)

    # The "180" is only a placeholder that is made changeable and overwritten
    label = Tex(R"180^\circ", font_size=24)
    number = label.make_number_changeable("180")
    number.set_value(angle_in_degrees)
    arc_middle = arc.pfp(0.5)
    label.next_to(arc_middle, normalize(arc.pfp(0.5)), buff=SMALL_BUFF)

    return VGroup(base, rotated, arc, label)
|
|
|
|
|
|
class NeuralNetwork(VGroup):
    """
    A drawing of a fully connected network: one column of dots per entry
    of `layer_sizes`, with a line between every pair of dots in
    neighbouring columns. The columns are stored in `self.layers` and the
    lines, grouped per pair of neighbouring layers, in `self.lines`.
    """
    def __init__(
        self,
        layer_sizes=(6, 12, 6),
        neuron_radius=0.1,
        v_buff_ratio=1.0,
        h_buff_ratio=7.0,
        max_stroke_width=2.0,
        stroke_decay=2.0,
    ):
        # `layer_sizes` defaults to a tuple rather than a list so that the
        # default cannot be shared and mutated between calls
        self.max_stroke_width = max_stroke_width
        self.stroke_decay = stroke_decay
        layers = VGroup(*(
            Dot(radius=neuron_radius).get_grid(n, 1, v_buff_ratio=v_buff_ratio)
            for n in layer_sizes
        ))
        layers.arrange(RIGHT, buff=h_buff_ratio * layers[0].get_width())

        lines = VGroup(*(
            VGroup(*(
                Line(
                    n1.get_center(),
                    n2.get_center(),
                    buff=n1.get_width() / 2,
                )
                for n1, n2 in it.product(l1, l2)
            ))
            for l1, l2 in zip(layers, layers[1:])
        ))

        super().__init__(layers, lines)
        self.layers = layers
        self.lines = lines

        self.randomize_layer_values()
        self.randomize_line_style()

    def randomize_layer_values(self):
        """
        Give every connecting line a random color and a random stroke width
        of at most `max_stroke_width` (skewed toward thin lines by
        `stroke_decay`).

        NOTE(review): despite its name this restyles the lines, not the
        layers; the name is kept so existing callers keep working.
        """
        for group in self.lines:
            for line in group:
                line.set_stroke(
                    value_to_color(random.uniform(-10, 10)),
                    self.max_stroke_width * random.random()**self.stroke_decay,
                )
        return self

    def randomize_line_style(self):
        """
        Give every dot a thin white stroke and a white fill of random opacity.

        NOTE(review): despite its name this restyles the dots, not the
        lines; the name is kept so existing callers keep working.
        """
        for layer in self.layers:
            for dot in layer:
                dot.set_stroke(WHITE, 1)
                dot.set_fill(WHITE, random.random())
        return self
|
|
|
|
|
|
class ContextAnimation(LaggedStart):
    """
    A lagged sequence of flashes travelling along curved lines from each of
    `sources` to `target`. Line thickness is interpolated between
    `min_stroke_width` and `max_stroke_width` by the matching entry of
    `strengths` (random when not given).
    """
    def __init__(
        self,
        target,
        sources,
        direction=UP,
        hue_range=(0.1, 0.3),
        time_width=2,
        min_stroke_width=0,
        max_stroke_width=5,
        lag_ratio=None,
        strengths=None,
        run_time=3,
        fix_in_frame=False,
        path_arc=PI / 2,
        **kwargs,
    ):
        if strengths is None:
            strengths = np.random.random(len(sources))**2

        arcs = VGroup()
        for source, strength in zip(sources, strengths):
            # The bend is mirrored for sources to the left of the target
            flip = -1 if source.get_x() < target.get_x() else 1
            start = source.get_edge_center(direction)
            end = target.get_edge_center(direction)
            arc = Line(
                start,
                end,
                path_arc=direction[1] * flip * path_arc,
                stroke_color=random_bright_color(hue_range=hue_range),
                stroke_width=interpolate(
                    min_stroke_width,
                    max_stroke_width,
                    strength,
                )
            )
            arcs.add(arc)
        if fix_in_frame:
            arcs.fix_in_frame()
        arcs.shuffle()

        if lag_ratio is None:
            lag_ratio = 0.5 / len(arcs)

        flashes = (
            VShowPassingFlash(arc, time_width=time_width)
            for arc in arcs
        )
        super().__init__(
            *flashes,
            lag_ratio=lag_ratio,
            run_time=run_time,
            **kwargs,
        )
|
|
|
|
|
|
class LabeledArrow(Arrow):
    """
    An Arrow which, when `label_text` is given, also creates a Text label
    positioned past the arrow's end and stores it in `self.label`
    (None otherwise). The label is not added as a submobject here.
    """
    def __init__(
        self,
        *args,
        label_text: Optional[str] = None,
        font_size: float = 24,
        label_buff: float = 0.1,
        direction: Optional[Vect3] = None,
        label_rotation: float = PI / 2,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        if label_text is None:
            self.label = None
            return

        start, end = self.get_start_and_end()
        label = Text(label_text, font_size=font_size)
        label.set_fill(self.get_color())
        label.set_backstroke()
        label.rotate(label_rotation, RIGHT)
        # By default the label continues in the direction the arrow points
        if direction is None:
            direction = normalize(end - start)
        label.next_to(end, direction, buff=label_buff)
        self.label = label
|
|
|
|
|
|
class WeightMatrix(DecimalMatrix):
    """
    A DecimalMatrix whose entries are colored according to their value via
    value_to_color. When `values` is not given, entries are drawn uniformly
    at random from `value_range` with the given `shape`.
    """
    def __init__(
        self,
        values: Optional[np.ndarray] = None,
        shape: tuple[int, int] = (6, 8),
        value_range: tuple[float, float] = (-9.9, 9.9),
        ellipses_row: Optional[int] = -2,
        ellipses_col: Optional[int] = -2,
        num_decimal_places: int = 1,
        bracket_h_buff: float = 0.1,
        decimal_config: Optional[dict] = None,
        low_positive_color: ManimColor = BLUE_E,
        high_positive_color: ManimColor = BLUE_B,
        low_negative_color: ManimColor = RED_E,
        high_negative_color: ManimColor = RED_B,
    ):
        # A fresh dict per instance, instead of one default dict shared by
        # every call and handed on to the parent class
        if decimal_config is None:
            decimal_config = dict(include_sign=True)

        # Explicit values determine the shape
        if values is not None:
            shape = values.shape
        self.shape = shape
        self.value_range = value_range
        self.low_positive_color = low_positive_color
        self.high_positive_color = high_positive_color
        self.low_negative_color = low_negative_color
        self.high_negative_color = high_negative_color
        self.ellipses_row = ellipses_row
        self.ellipses_col = ellipses_col

        if values is None:
            values = np.random.uniform(*self.value_range, size=shape)

        super().__init__(
            values,
            num_decimal_places=num_decimal_places,
            bracket_h_buff=bracket_h_buff,
            decimal_config=decimal_config,
            ellipses_row=ellipses_row,
            ellipses_col=ellipses_col,
        )
        self.reset_entry_colors()

    def reset_entry_colors(self):
        """
        Recolor every entry from its current value, using this matrix's
        four colors and the range 0 to max(value_range). Returns self.
        """
        for entry in self.get_entries():
            entry.set_fill(color=value_to_color(
                entry.get_value(),
                self.low_positive_color,
                self.high_positive_color,
                self.low_negative_color,
                self.high_negative_color,
                0, max(self.value_range),
            ))
        return self
|
|
|
|
|
|
class NumericEmbedding(WeightMatrix):
    """
    A WeightMatrix that by default is a single column of `length` entries,
    colored between `dark_color` and `light_color` by magnitude regardless
    of sign. One-dimensional `values` are reshaped into a column.
    """
    def __init__(
        self,
        values: Optional[np.ndarray] = None,
        shape: Optional[tuple[int, int]] = None,
        length: int = 7,
        num_decimal_places: int = 1,
        ellipses_row: int = -2,
        ellipses_col: int = -2,
        value_range: tuple[float, float] = (-9.9, 9.9),
        bracket_h_buff: float = 0.1,
        decimal_config: Optional[dict] = None,
        dark_color: ManimColor = GREY_C,
        light_color: ManimColor = WHITE,
        **kwargs,
    ):
        # A fresh dict per instance instead of one shared default dict
        if decimal_config is None:
            decimal_config = dict(include_sign=True)

        if values is not None:
            # Treat a flat array as a column vector
            if len(values.shape) == 1:
                values = values.reshape((values.shape[0], 1))
            shape = values.shape
        if shape is None:
            shape = (length, 1)
        super().__init__(
            values,
            shape=shape,
            value_range=value_range,
            num_decimal_places=num_decimal_places,
            bracket_h_buff=bracket_h_buff,
            decimal_config=decimal_config,
            low_positive_color=dark_color,
            high_positive_color=light_color,
            low_negative_color=dark_color,
            high_negative_color=light_color,
            ellipses_row=ellipses_row,
            ellipses_col=ellipses_col,
            **kwargs,
        )

        # No sign on zeros
        for entry in self.get_entries():
            if entry.get_value() == 0:
                entry[0].set_opacity(0)
|
|
|
|
|
|
class EmbeddingArray(VGroup):
    """
    A row of shape[1] NumericEmbedding columns, each of length shape[0],
    enclosed in one large pair of brackets. Submobjects are, in order,
    `self.embeddings`, `self.dots` and `self.brackets`. Unless `dots_index`
    is None, the embedding at that index is swapped for dots.
    """
    def __init__(
        self,
        shape=(10, 9),
        height=4,
        dots_index=-4,
        buff_ratio=0.4,
        bracket_color=GREY_B,
        backstroke_width=3,
        add_background_rectangle=False,
    ):
        super().__init__()

        # Embedding columns, laid out left to right
        embeddings = VGroup(
            NumericEmbedding(length=shape[0])
            for _ in range(shape[1])
        )
        embeddings.set_height(height)
        spacing = buff_ratio * embeddings[0].get_width()
        embeddings.arrange(RIGHT, buff=spacing)

        # Optional background rectangle behind each column
        if add_background_rectangle:
            for embedding in embeddings:
                embedding.add_background_rectangle()

        # Brackets: an empty array is typeset, then split into its two halves
        filler_rows = (shape[1] // 3) * R"\quad \\"
        bracket_tex = R"\left[\begin{array}{c}" + filler_rows + R"\end{array}\right]"
        brackets = Tex(bracket_tex)
        brackets.set_height(1.1 * embeddings.get_height())
        half = len(brackets) // 2
        left_bracket = brackets[:half]
        right_bracket = brackets[half:]
        left_bracket.next_to(embeddings, LEFT, buff=0)
        right_bracket.next_to(embeddings, RIGHT, buff=0)
        brackets.set_fill(bracket_color)

        # Assemble the result
        dots = VGroup()
        self.add(embeddings, dots, brackets)
        self.embeddings = embeddings
        self.dots = dots
        self.brackets = brackets
        self.set_backstroke(BLACK, backstroke_width)

        if dots_index is not None:
            self.swap_embedding_for_dots(dots_index)

    def swap_embedding_for_dots(self, dots_index=-4):
        """
        Remove the embedding at `dots_index` from `self.embeddings` and add
        dots, centered where it was, to `self.dots`. Returns self.
        """
        replaced = self.embeddings[dots_index]
        dots = Tex(R"\dots", font_size=60)
        dots.set_width(0.75 * replaced.get_width())
        dots.move_to(replaced)
        self.embeddings.remove(replaced)
        self.dots.add(dots)
        return self
|
|
|
|
|
|
class RandomizeMatrixEntries(Animation):
    """
    Animate every entry of `matrix` from its current value to a target
    drawn uniformly at random from `matrix.value_range`, recoloring the
    entries as they change.
    """
    def __init__(self, matrix, **kwargs):
        entries = matrix.get_entries()
        low = matrix.value_range[0]
        high = matrix.value_range[1]
        self.matrix = matrix
        self.entries = entries
        self.start_values = [entry.get_value() for entry in entries]
        self.target_values = np.random.uniform(low, high, len(entries))
        super().__init__(matrix, **kwargs)

    def interpolate_mobject(self, alpha: float) -> None:
        n_entries = len(self.entries)
        triples = zip(self.entries, self.start_values, self.target_values)
        for index, (entry, start, target) in enumerate(triples):
            sub_alpha = self.get_sub_alpha(alpha, index, n_entries)
            entry.set_value(interpolate(start, target, sub_alpha))
        # Colors depend on the values, so refresh them after the update
        self.matrix.reset_entry_colors()
|
|
|
|
|
|
class AbstractEmbeddingSequence(MobjectMatrix):
    """
    Placeholder subclass of MobjectMatrix; it adds no behavior of its own.
    """
|
|
|
|
|
|
class Dial(VGroup):
    """
    A dial made of an arc, tick marks and a tapered needle.

    `value_range` is (low, high, step). The needle points to where a value
    sits between low and high along the arc, and is colored through
    value_to_color (extra arguments for it come from
    `value_to_color_config`).
    """
    def __init__(
        self,
        radius=0.5,
        relative_tick_size=0.2,
        value_range=(0, 1, 0.1),
        initial_value=0,
        arc_angle=270 * DEGREES,
        stroke_width=2,
        stroke_color=WHITE,
        needle_color=BLUE,
        needle_stroke_width=5.0,
        value_to_color_config=None,
        set_anim_streak_color=TEAL,
        set_anim_streak_width=4,
        set_value_anim_streak_density=6,
        **kwargs
    ):
        super().__init__(**kwargs)
        # A fresh dict per dial instead of one default dict shared by all
        if value_to_color_config is None:
            value_to_color_config = {}
        self.value_range = value_range
        self.value_to_color_config = value_to_color_config
        self.set_anim_streak_color = set_anim_streak_color
        self.set_anim_streak_width = set_anim_streak_width
        self.set_value_anim_streak_density = set_value_anim_streak_density

        # Main dial
        self.arc = Arc(arc_angle / 2, -arc_angle, radius=radius)
        self.arc.rotate(90 * DEGREES, about_point=ORIGIN)

        low, high, step = value_range
        # The small tolerance keeps float error in the division (for example
        # 0.3 / 0.1 == 2.9999999999999996) from dropping the last tick
        n_values = int(1 + (high - low) / step + 1e-9)
        tick_points = map(self.arc.pfp, np.linspace(0, 1, n_values))
        self.ticks = VGroup(*(
            Line((1.0 - relative_tick_size) * point, point)
            for point in tick_points
        ))
        self.bottom_point = VectorizedPoint(radius * DOWN)
        for mob in self.arc, self.ticks:
            mob.set_stroke(stroke_color, stroke_width)

        self.add(self.arc, self.ticks, self.bottom_point)

        # Needle, tapering from `needle_stroke_width` down to zero
        self.needle = Line()
        self.needle.set_stroke(
            color=needle_color,
            width=[needle_stroke_width, 0]
        )
        self.add(self.needle)

        # Initialize
        self.set_value(initial_value)

    def value_to_point(self, value):
        """
        Return the point on the arc for `value`, located proportionally
        between the low and high ends of `value_range`.
        """
        low, high, _ = self.value_range
        alpha = inverse_interpolate(low, high, value)
        return self.arc.pfp(alpha)

    def set_value(self, value):
        """
        Point the needle from the dial's center to the arc position of
        `value`, and recolor it according to `value`.
        """
        self.needle.put_start_and_end_on(
            self.get_center(),
            self.value_to_point(value)
        )
        self.needle.set_color(value_to_color(
            value,
            min_value=self.value_range[0],
            max_value=self.value_range[1],
            **self.value_to_color_config
        ))

    def animate_set_value(self, value, **kwargs):
        """
        Return an AnimationGroup that swings the needle to `value` while
        streaks flash along concentric arcs covering the swept angle.

        Any "path_arc" in `kwargs` is discarded, since the arc is computed
        here; the remaining `kwargs` go to each sub-animation.
        """
        kwargs.pop("path_arc", None)
        center = self.get_center()
        points = [self.needle.get_end(), self.value_to_point(value)]
        vects = [point - center for point in points]
        # Express both angles in the range [-TAU / 4, 3 * TAU / 4)
        angle1, angle2 = [
            (angle_of_vector(vect) + TAU / 4) % TAU - TAU / 4
            for vect in vects
        ]
        path_arc = angle2 - angle1

        density = self.set_value_anim_streak_density
        radii = np.linspace(0, 0.5 * self.get_width(), density + 1)[1:]
        diff_arcs = VGroup(*(
            Arc(
                angle1, path_arc,
                radius=radius,
                arc_center=center,
            )
            for radius in radii
        ))
        diff_arcs.set_stroke(self.set_anim_streak_color, self.set_anim_streak_width)

        return AnimationGroup(
            self.animate.set_value(value).set_anim_args(path_arc=path_arc, **kwargs),
            *(
                VShowPassingFlash(diff_arc, time_width=1.5, **kwargs)
                for diff_arc in diff_arcs
            )
        )

    def get_random_value(self):
        """
        Return a random value between the low and high ends of `value_range`.
        """
        low, high, _ = self.value_range
        return interpolate(low, high, random.random())
|
|
|
|
|
|
class MachineWithDials(VGroup):
    """
    A filled rectangle (`self.box`) holding an `n_rows` x `n_cols` grid of
    Dial mobjects (`self.dials`), each set to a random value.

    `dial_config` entries override those of `default_dial_config`.
    """
    default_dial_config = dict(
        stroke_width=1.0,
        needle_stroke_width=5.0,
        relative_tick_size=0.25,
        set_anim_streak_width=2,
    )

    def __init__(
        self,
        width=5.0,
        height=4.0,
        n_rows=6,
        n_cols=8,
        dial_buff_ratio=0.5,
        stroke_color=WHITE,
        stroke_width=1,
        fill_color=GREY_D,
        fill_opacity=1.0,
        dial_config=None,
    ):
        super().__init__()
        box = Rectangle(width, height)
        box.set_stroke(stroke_color, stroke_width)
        box.set_fill(fill_color, fill_opacity)
        self.box = box

        # Merge with a dict display so that overriding one of the defaults
        # replaces it; dict(**a, **b) raises TypeError on a repeated key
        dial_config = {**self.default_dial_config, **(dial_config or {})}
        dials = Dial(**dial_config).get_grid(n_rows, n_cols, buff_ratio=dial_buff_ratio)
        buff = dials[0].get_width() * dial_buff_ratio
        dials.set_width(box.get_width() - buff)
        # Limit the height by the box's height (this used the box's width)
        dials.set_max_height(box.get_height() - buff)
        dials.move_to(box)
        for dial in dials:
            dial.set_value(dial.get_random_value())
        self.dials = dials

        self.add(box, dials)

    def random_change_animation(self, lag_factor=0.5, run_time=3.0, **kwargs):
        """
        Return a LaggedStart in which every dial animates to a new random
        value. Extra `kwargs` are passed to LaggedStart.
        """
        return LaggedStart(
            *(
                dial.animate_set_value(dial.get_random_value())
                for dial in self.dials
            ),
            lag_ratio=lag_factor / len(self.dials),
            run_time=run_time,
            **kwargs
        )

    def rotate_all_dials(self, run_time=2, lag_factor=1.0):
        """
        Return a LaggedStart in which every needle, taken in random order,
        makes one full turn about its dial's center.
        """
        shuffled_dials = list(self.dials)
        random.shuffle(shuffled_dials)
        return LaggedStart(
            *(
                Rotate(dial.needle, TAU, about_point=dial.get_center())
                for dial in shuffled_dials
            ),
            lag_ratio=lag_factor / len(self.dials),
            # `run_time` was accepted but never forwarded
            run_time=run_time,
        )
|