# Mirror of https://github.com/3b1b/videos.git
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
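# The point of the exercise: a space can hold only vector_len exactly
# perpendicular vectors, but vastly more *nearly* perpendicular ones. Below we
# find 10,000 such vectors in a 100-dimensional space by direct optimization.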
# List of vectors in some dimension, with many
# more vectors than there are dimensions
num_vectors = 10000
vector_len = 100
big_matrix = torch.randn(num_vectors, vector_len)
big_matrix /= big_matrix.norm(p=2, dim=1, keepdim=True)  # Normalize each row to unit length
big_matrix.requires_grad_(True)
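# Optional sanity check: every row should now have unit length
# before optimization begins.
with torch.no_grad():
    assert torch.allclose(big_matrix.norm(dim=1), torch.ones(num_vectors), atol=1e-5)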
# Set up an optimization loop to create nearly-perpendicular vectors
optimizer = torch.optim.Adam([big_matrix], lr=0.01)
num_steps = 250

losses = []

dot_diff_cutoff = 0.01  # Dot products within this margin of their target incur no loss
big_id = torch.eye(num_vectors, num_vectors)
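# How the loss below works: with Gram matrix G = big_matrix @ big_matrix.T,
# the first term sums relu(|G_ij - I_ij| - dot_diff_cutoff) over all pairs,
# a hinge that only penalizes dot products further than dot_diff_cutoff from
# their target (0 for distinct vectors, 1 on the diagonal). The second term
# re-penalizes the diagonal, weighted by num_vectors, to keep rows unit length.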
for step_num in tqdm(range(num_steps)):
    optimizer.zero_grad()

    dot_products = big_matrix @ big_matrix.T
    # Punish deviation from orthogonality beyond the cutoff
    diff = dot_products - big_id
    loss = (diff.abs() - dot_diff_cutoff).relu().sum()

    # Extra incentive to keep rows normalized
    loss += num_vectors * diff.diag().pow(2).sum()

    loss.backward()
    optimizer.step()
    losses.append(loss.item())
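# Optional check: report the worst remaining off-diagonal dot product,
# i.e. how far the least-perpendicular pair of distinct vectors still is from 0.
with torch.no_grad():
    gram = big_matrix @ big_matrix.T
    off_diag = gram[~torch.eye(num_vectors, dtype=torch.bool)]
    print(f"Largest |dot| between distinct vectors: {off_diag.abs().max().item():.4f}")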
# Loss curve
plt.plot(losses)
plt.grid(True)
plt.show()
# Angle distribution
dot_products = big_matrix @ big_matrix.T
norms = torch.sqrt(torch.diag(dot_products))
normed_dot_products = dot_products / torch.outer(norms, norms)
# Clamp to [-1, 1] so floating-point error cannot push acos outside its domain
angles_degrees = torch.rad2deg(torch.acos(normed_dot_products.detach().clamp(-1.0, 1.0)))

# Mask out the diagonal, since each vector has angle 0 with itself.
self_orthogonality_mask = ~(torch.eye(num_vectors, num_vectors).bool())

plt.hist(angles_degrees[self_orthogonality_mask].numpy().ravel(), bins=1000, range=(80, 100))
plt.grid(True)
plt.show()
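# Optional summary: the extreme pairwise angles. For 10,000 vectors in
# 100 dimensions they should cluster tightly around 90 degrees.
off_diag_angles = angles_degrees[self_orthogonality_mask]
print(f"Pairwise angles span {off_diag_angles.min().item():.2f} to {off_diag_angles.max().item():.2f} degrees")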