3b1b-manim/manimlib/mobject/vector_field.py

331 lines
10 KiB
Python
Raw Normal View History

from __future__ import annotations
import itertools as it
import numpy as np
2022-04-12 19:19:59 +08:00
from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH
from manimlib.constants import WHITE
2021-02-25 08:47:29 -08:00
from manimlib.animation.indication import VShowPassingFlash
from manimlib.mobject.geometry import Arrow
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.bezier import interpolate
2022-04-12 19:19:59 +08:00
from manimlib.utils.bezier import inverse_interpolate
2021-02-25 08:47:29 -08:00
from manimlib.utils.color import get_colormap_list
2022-12-16 18:59:23 -08:00
from manimlib.utils.dict_ops import merge_dicts_recursively
from manimlib.utils.rate_functions import linear
from manimlib.utils.simple_functions import sigmoid
from manimlib.utils.space_ops import get_norm
from typing import TYPE_CHECKING
if TYPE_CHECKING:
2022-12-16 10:16:13 -08:00
from typing import Callable, Iterable, Sequence, TypeVar, Tuple
2022-04-12 19:19:59 +08:00
import numpy.typing as npt
2022-12-16 10:16:13 -08:00
from manimlib.constants import ManimColor, np_vector
2022-04-12 19:19:59 +08:00
from manimlib.mobject.coordinate_systems import CoordinateSystem
2022-04-12 19:19:59 +08:00
from manimlib.mobject.mobject import Mobject
T = TypeVar("T")
def get_vectorized_rgb_gradient_function(
    min_value: T,
    max_value: T,
    color_map: str
) -> Callable[[npt.ArrayLike], np_vector]:
    """
    Build a function mapping an array of values to RGB colors.

    Each value is placed on a linear scale from ``min_value`` (first
    colormap entry) to ``max_value`` (last colormap entry); values outside
    that interval are clipped to its ends.  The color is then blended from
    the two colormap entries adjacent to that position.

    :param min_value: value mapped to the first color of the colormap.
    :param max_value: value mapped to the last color of the colormap.
    :param color_map: colormap name, resolved via ``get_colormap_list``.
    :return: a function taking an array-like of values (scalar, 1-D, or
        higher-dimensional) and returning an array with the same leading
        shape plus a trailing RGB axis.
    """
    rgbs = np.array(get_colormap_list(color_map))
    last_index = len(rgbs) - 1

    def func(values):
        alphas = inverse_interpolate(
            min_value, max_value, np.array(values)
        )
        alphas = np.clip(alphas, 0, 1)
        scaled_alphas = alphas * last_index
        # Integer part picks the lower colormap entry, the next one up is
        # the blend target (clipped so alpha == 1 stays in bounds).
        indices = scaled_alphas.astype(int)
        next_indices = np.clip(indices + 1, 0, last_index)
        # Fractional part is the blend weight.  A trailing length-1 axis
        # lets it broadcast against the (..., 3) RGB arrays for inputs of
        # any shape, instead of assuming a 1-D input.
        inter_alphas = np.expand_dims(scaled_alphas % 1, -1)
        return interpolate(rgbs[indices], rgbs[next_indices], inter_alphas)

    return func
def get_rgb_gradient_function(
    min_value: T,
    max_value: T,
    color_map: str
) -> Callable[[T], np_vector]:
    """
    Single-value counterpart of ``get_vectorized_rgb_gradient_function``.

    The returned function wraps its argument in a one-element batch, runs
    it through the vectorized gradient function, and unwraps the single
    RGB triple from the result.
    """
    batch_to_rgbs = get_vectorized_rgb_gradient_function(
        min_value, max_value, color_map
    )

    def value_to_rgb(value):
        return batch_to_rgbs([value])[0]

    return value_to_rgb
def move_along_vector_field(
    mobject: Mobject,
    func: Callable[[np_vector], np_vector]
) -> Mobject:
    """
    Attach an updater that drifts ``mobject`` along the vector field ``func``.

    Every update shifts the mobject by ``func(center) * dt``, where
    ``center`` is its current center (one explicit Euler step).

    :return: ``mobject`` itself, with the updater attached.
    """
    # The second parameter keeps the name ``dt``, as in the original lambda
    # (presumably the updater machinery keys on it to pass the frame time —
    # confirm before renaming).
    def drift(mob, dt):
        velocity = func(mob.get_center())
        return mob.shift(velocity * dt)

    mobject.add_updater(drift)
    return mobject
def move_submobjects_along_vector_field(
    mobject: Mobject,
    func: Callable[[np_vector], np_vector]
) -> Mobject:
    """
    Attach an updater that drifts each submobject of ``mobject`` along ``func``.

    On every update, each submobject whose center has
    ``|x| < FRAME_WIDTH`` and ``|y| < FRAME_HEIGHT`` is shifted by
    ``func(center) * dt``; submobjects outside that box are left alone.

    :return: ``mobject`` itself, with the updater attached.
    """
    def apply_nudge(mob, dt):
        for piece in mob:
            center = piece.get_center()
            x, y = center[:2]
            # NOTE(review): the bound is the full FRAME_WIDTH/FRAME_HEIGHT,
            # not half of it — presumably deliberate slack so pieces keep
            # moving a bit past the visible edge; confirm.
            if not (abs(x) < FRAME_WIDTH and abs(y) < FRAME_HEIGHT):
                continue
            piece.shift(func(center) * dt)

    mobject.add_updater(apply_nudge)
    return mobject
def move_points_along_vector_field(
    mobject: Mobject,
    func: Callable[[float, float], Iterable[float]],
    coordinate_system: CoordinateSystem
) -> Mobject:
    """
    Attach an updater that flows every point of ``mobject`` along ``func``.

    ``func`` is expressed in the coordinates of ``coordinate_system``.  On
    every update, each point ``p`` is moved by ``dt`` times the vector
    ``func(*cs.p2c(p))``, converted to a scene-space displacement.

    :return: ``mobject`` itself, with the updater attached.
    """
    cs = coordinate_system

    def apply_nudge(mob, dt):
        # c2p turns coordinates into a *position*; subtracting the current
        # origin turns the field output into a displacement.  The origin is
        # read each update so a coordinate system that moves after
        # registration doesn't introduce a spurious drift.
        origin = cs.get_origin()
        # Act on the mobject passed to the updater, not the one captured at
        # registration, so that if this updater ends up attached to another
        # mobject (e.g. a copy) it moves that mobject.
        mob.apply_function(
            lambda p: p + (cs.c2p(*func(*cs.p2c(p))) - origin) * dt
        )

    mobject.add_updater(apply_nudge)
    return mobject
def get_sample_points_from_coordinate_system(
    coordinate_system: CoordinateSystem,
    step_multiple: float
) -> Iterable[tuple[float, ...]]:
    """
    Return a lazy grid of coordinate tuples covering ``coordinate_system``.

    For every ``(min, max, step)`` triple from
    ``coordinate_system.get_all_ranges()``, samples run from ``min`` up to
    ``max`` (inclusive, within float tolerance) in increments of
    ``step * step_multiple``.  The result is the Cartesian product of those
    per-axis samples, one tuple per grid point.  Note that it is a one-shot
    ``itertools.product`` iterator.
    """
    ranges = []
    for _min, _max, step in coordinate_system.get_all_ranges():
        step *= step_multiple
        # Stop half a step past _max instead of a full step: with float
        # steps, `_max + step` can land a hair above a grid point, and
        # np.arange then emits an extra sample ~one step beyond _max
        # (e.g. np.arange(0, 0.2 + 0.1, 0.1) has four elements).
        ranges.append(np.arange(_min, _max + step / 2, step))
    return it.product(*ranges)
# Mobjects
class VectorField(VGroup):
    """
    A grid of arrows visualizing a vector field.

    One arrow is placed at each sample point of ``coordinate_system`` (see
    ``get_sample_points_from_coordinate_system``).  Its direction follows
    ``func``, its displayed length is ``length_func(norm)``, and its color
    encodes the true magnitude via ``color_map`` over ``magnitude_range``.
    """

    def __init__(
        self,
        func: Callable[[float, float], Sequence[float]],
        coordinate_system: CoordinateSystem,
        step_multiple: float = 0.5,
        magnitude_range: Tuple[float, float] = (0, 2),
        color_map: str = "3b1b_colormap",
        # Takes in actual norm, spits out displayed norm
        length_func: Callable[[float], float] = lambda norm: 0.45 * sigmoid(norm),
        opacity: float = 1.0,
        vector_config: dict | None = None,
        **kwargs
    ):
        """
        :param func: field, taking coordinates and returning output coordinates.
        :param coordinate_system: system whose ranges define the sample grid
            and which converts coordinates to scene points.
        :param step_multiple: scale factor applied to each range's step.
        :param magnitude_range: norms mapped to the ends of ``color_map``.
        :param color_map: colormap name used for magnitude coloring.
        :param length_func: maps a vector's true norm to its drawn length.
        :param opacity: opacity applied to every arrow.
        :param vector_config: extra keyword arguments for each ``Arrow``.
            Defaults to a fresh empty dict per instance.
        """
        super().__init__(**kwargs)
        self.func = func
        self.coordinate_system = coordinate_system
        self.step_multiple = step_multiple
        self.magnitude_range = magnitude_range
        self.color_map = color_map
        self.length_func = length_func
        self.opacity = opacity
        # A per-instance dict, so mutating one field's config can't leak
        # into other fields (a shared `dict()` default would).
        self.vector_config = {} if vector_config is None else vector_config

        self.value_to_rgb = get_rgb_gradient_function(
            *self.magnitude_range, self.color_map,
        )

        samples = get_sample_points_from_coordinate_system(
            coordinate_system, self.step_multiple
        )
        self.add(*(
            self.get_vector(coords)
            for coords in samples
        ))

    def get_vector(self, coords: Iterable[float], **kwargs) -> Arrow:
        """
        Build the arrow for the field value at ``coords``.

        :param coords: coordinates of the sample point (unpacked twice, so
            a sequence rather than a one-shot iterator).
        :param kwargs: per-call overrides merged over ``self.vector_config``.
        :return: an arrow based at ``coords`` with length
            ``length_func(norm)``, colored by the unscaled norm.
        """
        vector_config = merge_dicts_recursively(
            self.vector_config,
            kwargs
        )

        # Force a float array: a field returning integers (e.g. when integer
        # sample coordinates are passed in) would otherwise give an int
        # array, and the in-place rescale below would raise a casting error.
        output = np.array(self.func(*coords), dtype=float)
        norm = get_norm(output)
        if norm > 0:
            output *= self.length_func(norm) / norm

        origin = self.coordinate_system.get_origin()
        _input = self.coordinate_system.c2p(*coords)
        _output = self.coordinate_system.c2p(*output)

        # Build the arrow at the origin, then translate it to the sample point.
        vect = Arrow(
            origin, _output, buff=0,
            **vector_config
        )
        vect.shift(_input - origin)
        vect.set_rgba_array([[*self.value_to_rgb(norm), self.opacity]])
        return vect
2019-03-19 17:30:37 -07:00
class StreamLines(VGroup):
    """
    A family of curves tracing the flow of a vector field.

    Each line starts at a (jittered) sample point of ``coordinate_system``
    and is integrated forward with explicit Euler steps of size ``dt``.
    Integration stops once the line is longer than ``arc_len``, once the
    previous point is farther than ``cutoff_norm`` from the scene origin,
    or after ``max_time_steps`` steps.
    """

    def __init__(
        self,
        func: Callable[[float, float], Sequence[float]],
        coordinate_system: CoordinateSystem,
        step_multiple: float = 0.5,
        n_repeats: int = 1,
        noise_factor: float | None = None,
        # Config for drawing lines
        dt: float = 0.05,
        arc_len: float = 3,
        max_time_steps: int = 200,
        n_samples_per_line: int = 10,
        cutoff_norm: float = 15,
        # Style info
        stroke_width: float = 1.0,
        stroke_color: ManimColor = WHITE,
        stroke_opacity: float = 1,
        color_by_magnitude: bool = True,
        magnitude_range: Tuple[float, float] = (0, 2.0),
        taper_stroke_width: bool = False,
        color_map: str = "3b1b_colormap",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.func = func
        self.coordinate_system = coordinate_system
        self.step_multiple = step_multiple
        self.n_repeats = n_repeats
        self.noise_factor = noise_factor
        self.dt = dt
        self.arc_len = arc_len
        self.max_time_steps = max_time_steps
        self.n_samples_per_line = n_samples_per_line
        self.cutoff_norm = cutoff_norm
        self.stroke_width = stroke_width
        self.stroke_color = stroke_color
        self.stroke_opacity = stroke_opacity
        self.color_by_magnitude = color_by_magnitude
        self.magnitude_range = magnitude_range
        self.taper_stroke_width = taper_stroke_width
        self.color_map = color_map

        self.draw_lines()
        self.init_style()

    def point_func(self, point: np_vector) -> np_vector:
        """Evaluate ``func`` at a scene point; return its output as a scene point."""
        in_coords = self.coordinate_system.p2c(point)
        out_coords = self.func(*in_coords)
        return self.coordinate_system.c2p(*out_coords)

    def draw_lines(self) -> None:
        """Integrate one line per start point and install them as submobjects."""
        lines = []
        # point_func yields a position; subtracting the origin turns it into
        # a displacement in scene space.
        origin = self.coordinate_system.get_origin()
        for point in self.get_start_points():
            points = [point]
            total_arc_len = 0
            time = 0
            for _ in range(self.max_time_steps):
                time += self.dt
                last_point = points[-1]
                new_point = last_point + self.dt * (self.point_func(last_point) - origin)
                points.append(new_point)
                total_arc_len += get_norm(new_point - last_point)
                # NOTE: tests the point *before* this step, so the first
                # point beyond the cutoff is still kept.
                if get_norm(last_point) > self.cutoff_norm:
                    break
                if total_arc_len > self.arc_len:
                    break
            line = VMobject()
            # Simulated duration of this line (AnimatedStreamLines uses it
            # as the run time of the line's flash animation).
            line.virtual_time = time
            # Thin the Euler points to roughly n_samples_per_line corners.
            step = max(1, int(len(points) / self.n_samples_per_line))
            line.set_points_as_corners(points[::step])
            line.make_approximately_smooth()
            lines.append(line)
        self.set_submobjects(lines)

    def get_start_points(self) -> np_vector:
        """
        Return the start points of all lines: every grid sample, repeated
        ``n_repeats`` times, each offset by random noise in
        ``[0, noise_factor)`` per axis.
        """
        cs = self.coordinate_system
        # Materialize the samples: the helper returns a one-shot
        # itertools.product iterator, which the nested loop below would
        # exhaust on the first repeat, silently ignoring n_repeats > 1.
        sample_coords = list(get_sample_points_from_coordinate_system(
            cs, self.step_multiple,
        ))

        noise_factor = self.noise_factor
        if noise_factor is None:
            noise_factor = cs.x_range[2] * self.step_multiple * 0.5

        return np.array([
            cs.c2p(*coords) + noise_factor * np.random.random(3)
            for _ in range(self.n_repeats)
            for coords in sample_coords
        ])

    def init_style(self) -> None:
        """
        Color each line by field magnitude along its points, or use a
        uniform ``stroke_color``; then apply the stroke width, optionally
        tapered to zero at both ends.
        """
        if self.color_by_magnitude:
            values_to_rgbs = get_vectorized_rgb_gradient_function(
                *self.magnitude_range, self.color_map,
            )
            cs = self.coordinate_system
            for line in self.submobjects:
                norms = [
                    get_norm(self.func(*cs.p2c(point)))
                    for point in line.get_points()
                ]
                rgbs = values_to_rgbs(norms)
                rgbas = np.zeros((len(rgbs), 4))
                rgbas[:, :3] = rgbs
                rgbas[:, 3] = self.stroke_opacity
                line.set_rgba_array(rgbas, "stroke_rgba")
        else:
            self.set_stroke(self.stroke_color, opacity=self.stroke_opacity)

        if self.taper_stroke_width:
            width = [0, self.stroke_width, 0]
        else:
            width = self.stroke_width
        self.set_stroke(width=width)
class AnimatedStreamLines(VGroup):
    """
    Continuously animate a ``StreamLines`` object as flashes running along its lines.

    Every line gets its own looping ``VShowPassingFlash`` whose run time is
    the line's ``virtual_time``.  Each line's clock starts at a random
    negative offset (up to ``lag_range`` seconds) so the lines don't all
    pulse in unison.
    """

    def __init__(
        self,
        stream_lines: StreamLines,
        lag_range: float = 4,
        line_anim_config: dict = dict(
            rate_func=linear,
            time_width=1.0,
        ),
        **kwargs
    ):
        super().__init__(**kwargs)
        self.stream_lines = stream_lines

        for line in stream_lines:
            line.anim = VShowPassingFlash(
                line,
                run_time=line.virtual_time,
                **line_anim_config,
            )
            line.anim.begin()
            line.time = -lag_range * np.random.random()
            # The animation's mobject is what is displayed in this group.
            self.add(line.anim.mobject)

        # The per-frame work happens in the `update` override below; this
        # updater presumably just marks the group as time-dependent for
        # the updater machinery — confirm before removing.
        self.add_updater(lambda m, dt: m.update(dt))

    def update(self, dt: float = 0, recurse: bool = True) -> AnimatedStreamLines:
        """
        Advance every line's clock by ``dt`` and set its flash animation to
        the matching point of its loop.  Lines whose clock is still negative
        (in their initial lag) stay at alpha 0.

        The signature mirrors the base ``update(dt=0, recurse=True)`` so
        callers that pass ``recurse`` (e.g. a parent group recursing into
        its submobjects) don't raise TypeError.  ``recurse`` is accepted
        only for that compatibility and is otherwise ignored, as before.

        :return: ``self``.
        """
        for line in self.stream_lines:
            line.time += dt
            adjusted_time = max(line.time, 0) % line.anim.run_time
            line.anim.update(adjusted_time / line.anim.run_time)
        return self