Further development on VectorField

This commit is contained in:
Grant Sanderson 2024-11-15 09:07:46 -08:00
parent 64ae1364ca
commit 0ad5a0e76e

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import itertools as it import itertools as it
import numpy as np import numpy as np
from scipy.integrate import solve_ivp
from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH
from manimlib.constants import BLUE, WHITE from manimlib.constants import BLUE, WHITE
@ -26,7 +27,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, TypeVar, Tuple from typing import Callable, Iterable, Sequence, TypeVar, Tuple
from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array, Vect4Array from manimlib.typing import ManimColor, Vect3, VectN, Vect2Array, Vect3Array, Vect4Array
from manimlib.mobject.coordinate_systems import CoordinateSystem from manimlib.mobject.coordinate_systems import CoordinateSystem
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
@ -68,6 +69,16 @@ def get_rgb_gradient_function(
#### ####
def ode_solution_points(function, state0, time, dt=0.01):
solution = solve_ivp(
lambda t, state: function(state),
t_span=(0, time),
y0=state0,
t_eval=np.arange(0, time, dt)
)
return solution.y.T
def move_along_vector_field( def move_along_vector_field(
mobject: Mobject, mobject: Mobject,
func: Callable[[Vect3], Vect3] func: Callable[[Vect3], Vect3]
@ -128,12 +139,12 @@ def get_sample_coords(
class VectorField(VMobject): class VectorField(VMobject):
def __init__( def __init__(
self, self,
func: Callable[Sequence[float], Sequence[float]], func: Callable[[VectArray], VectArray],
coordinate_system: CoordinateSystem, coordinate_system: CoordinateSystem,
step_multiple: float = 0.5, step_multiple: float = 0.5,
magnitude_range: Optional[Tuple[float, float]] = None, magnitude_range: Optional[Tuple[float, float]] = None,
color_map_name: Optional[str] = "3b1b_colormap", color_map_name: Optional[str] = "3b1b_colormap",
color_map: Optional[Callable[Sequence[float]], Vect4Array] = None, color_map: Optional[Callable[[Sequence[float]], Vect4Array]] = None,
stroke_color: ManimColor = BLUE, stroke_color: ManimColor = BLUE,
stroke_opacity: float = 1.0, stroke_opacity: float = 1.0,
stroke_width: float = 2, stroke_width: float = 2,
@ -281,9 +292,9 @@ class VectorField(VMobject):
if self.color_map is not None: if self.color_map is not None:
self.get_stroke_colors() # Ensures the array is updated to appropriate length self.get_stroke_colors() # Ensures the array is updated to appropriate length
low, high = self.magnitude_range low, high = self.magnitude_range
self.data['stroke_rgba'][:] = self.color_map( self.data['stroke_rgba'][:, :3] = self.color_map(
inverse_interpolate(low, high, np.repeat(output_norms, 8)[:-1]) inverse_interpolate(low, high, np.repeat(output_norms, 8)[:-1])
) )[:, :3]
if self.norm_to_opacity_func is not None: if self.norm_to_opacity_func is not None:
self.get_stroke_opacities()[:] = self.norm_to_opacity_func( self.get_stroke_opacities()[:] = self.norm_to_opacity_func(
@ -310,6 +321,165 @@ class TimeVaryingVectorField(VectorField):
self.time += dt self.time += dt
class StreamLines(VGroup):
def __init__(
self,
func: Callable[[VectArray], VectArray],
coordinate_system: CoordinateSystem,
step_multiple: float = 0.5,
n_repeats: int = 1,
noise_factor: float | None = None,
# Config for drawing lines
solution_time: float = 3,
dt: float = 0.05,
arc_len: float = 3,
max_time_steps: int = 200,
n_samples_per_line: int = 10,
cutoff_norm: float = 15,
# Style info
stroke_width: float = 1.0,
stroke_color: ManimColor = WHITE,
stroke_opacity: float = 1,
color_by_magnitude: bool = True,
magnitude_range: Tuple[float, float] = (0, 2.0),
taper_stroke_width: bool = False,
color_map: str = "3b1b_colormap",
**kwargs
):
super().__init__(**kwargs)
self.func = func
self.coordinate_system = coordinate_system
self.step_multiple = step_multiple
self.n_repeats = n_repeats
self.noise_factor = noise_factor
self.solution_time = solution_time
self.dt = dt
self.arc_len = arc_len
self.max_time_steps = max_time_steps
self.n_samples_per_line = n_samples_per_line
self.cutoff_norm = cutoff_norm
self.stroke_width = stroke_width
self.stroke_color = stroke_color
self.stroke_opacity = stroke_opacity
self.color_by_magnitude = color_by_magnitude
self.magnitude_range = magnitude_range
self.taper_stroke_width = taper_stroke_width
self.color_map = color_map
self.draw_lines()
self.init_style()
def point_func(self, points: Vect3Array) -> Vect3:
in_coords = np.array(self.coordinate_system.p2c(points)).T
out_coords = self.func(in_coords)
origin = self.coordinate_system.get_origin()
return self.coordinate_system.c2p(*out_coords.T) - origin
def draw_lines(self) -> None:
lines = []
origin = self.coordinate_system.get_origin()
# Todo, it feels like coordinate system should just have
# the ODE solver built into it, no?
lines = []
for coords in self.get_sample_coords():
solution_coords = ode_solution_points(self.func, coords, self.solution_time, self.dt)
line = VMobject()
line.set_points_smoothly(self.coordinate_system.c2p(*solution_coords.T))
# TODO, account for arc length somehow?
line.virtual_time = self.solution_time
lines.append(line)
self.set_submobjects(lines)
def get_sample_coords(self):
cs = self.coordinate_system
sample_coords = get_sample_coords(cs, self.step_multiple)
noise_factor = self.noise_factor
if noise_factor is None:
noise_factor = cs.get_x_unit_size() * self.step_multiple * 0.5
return np.array([
coords + noise_factor * np.random.random(coords.shape)
for n in range(self.n_repeats)
for coords in sample_coords
])
def get_start_points(self) -> Vect3Array:
cs = self.coordinate_system
sample_coords = get_sample_coords(cs, self.step_multiple)
noise_factor = self.noise_factor
if noise_factor is None:
noise_factor = cs.get_x_unit_size() * self.step_multiple * 0.5
return np.array([
cs.c2p(*coords) + noise_factor * np.random.random(3)
for n in range(self.n_repeats)
for coords in sample_coords
])
def init_style(self) -> None:
if self.color_by_magnitude:
values_to_rgbs = get_vectorized_rgb_gradient_function(
*self.magnitude_range, self.color_map,
)
cs = self.coordinate_system
for line in self.submobjects:
norms = [
get_norm(self.func(*cs.p2c(point)))
for point in line.get_points()
]
rgbs = values_to_rgbs(norms)
rgbas = np.zeros((len(rgbs), 4))
rgbas[:, :3] = rgbs
rgbas[:, 3] = self.stroke_opacity
line.set_rgba_array(rgbas, "stroke_rgba")
else:
self.set_stroke(self.stroke_color, opacity=self.stroke_opacity)
if self.taper_stroke_width:
width = [0, self.stroke_width, 0]
else:
width = self.stroke_width
self.set_stroke(width=width)
class AnimatedStreamLines(VGroup):
def __init__(
self,
stream_lines: StreamLines,
lag_range: float = 4,
rate_multiple: float = 1.0,
line_anim_config: dict = dict(
rate_func=linear,
time_width=1.0,
),
**kwargs
):
super().__init__(**kwargs)
self.stream_lines = stream_lines
for line in stream_lines:
line.anim = VShowPassingFlash(
line,
run_time=line.virtual_time / rate_multiple,
**line_anim_config,
)
line.anim.begin()
line.time = -lag_range * np.random.random()
self.add(line.anim.mobject)
self.add_updater(lambda m, dt: m.update(dt))
def update(self, dt: float) -> None:
stream_lines = self.stream_lines
for line in stream_lines:
line.time += dt
adjusted_time = max(line.time, 0) % line.anim.run_time
line.anim.update(adjusted_time / line.anim.run_time)
class OldVectorField(VGroup): class OldVectorField(VGroup):
def __init__( def __init__(
self, self,
@ -371,155 +541,3 @@ class OldVectorField(VGroup):
opacity=self.opacity, opacity=self.opacity,
) )
return vect return vect
class StreamLines(VGroup):
def __init__(
self,
func: Callable[[float, float], Sequence[float]],
coordinate_system: CoordinateSystem,
step_multiple: float = 0.5,
n_repeats: int = 1,
noise_factor: float | None = None,
# Config for drawing lines
dt: float = 0.05,
arc_len: float = 3,
max_time_steps: int = 200,
n_samples_per_line: int = 10,
cutoff_norm: float = 15,
# Style info
stroke_width: float = 1.0,
stroke_color: ManimColor = WHITE,
stroke_opacity: float = 1,
color_by_magnitude: bool = True,
magnitude_range: Tuple[float, float] = (0, 2.0),
taper_stroke_width: bool = False,
color_map: str = "3b1b_colormap",
**kwargs
):
super().__init__(**kwargs)
self.func = func
self.coordinate_system = coordinate_system
self.step_multiple = step_multiple
self.n_repeats = n_repeats
self.noise_factor = noise_factor
self.dt = dt
self.arc_len = arc_len
self.max_time_steps = max_time_steps
self.n_samples_per_line = n_samples_per_line
self.cutoff_norm = cutoff_norm
self.stroke_width = stroke_width
self.stroke_color = stroke_color
self.stroke_opacity = stroke_opacity
self.color_by_magnitude = color_by_magnitude
self.magnitude_range = magnitude_range
self.taper_stroke_width = taper_stroke_width
self.color_map = color_map
self.draw_lines()
self.init_style()
def point_func(self, point: Vect3) -> Vect3:
in_coords = self.coordinate_system.p2c(point)
out_coords = self.func(*in_coords)
return self.coordinate_system.c2p(*out_coords)
def draw_lines(self) -> None:
lines = []
origin = self.coordinate_system.get_origin()
for point in self.get_start_points():
points = [point]
total_arc_len = 0
time = 0
for x in range(self.max_time_steps):
time += self.dt
last_point = points[-1]
new_point = last_point + self.dt * (self.point_func(last_point) - origin)
points.append(new_point)
total_arc_len += get_norm(new_point - last_point)
if get_norm(last_point) > self.cutoff_norm:
break
if total_arc_len > self.arc_len:
break
line = VMobject()
line.virtual_time = time
step = max(1, int(len(points) / self.n_samples_per_line))
line.set_points_as_corners(points[::step])
line.make_smooth(approx=True)
lines.append(line)
self.set_submobjects(lines)
def get_start_points(self) -> Vect3Array:
cs = self.coordinate_system
sample_coords = get_sample_coords(
cs, self.step_multiple,
)
noise_factor = self.noise_factor
if noise_factor is None:
noise_factor = cs.x_range[2] * self.step_multiple * 0.5
return np.array([
cs.c2p(*coords) + noise_factor * np.random.random(3)
for n in range(self.n_repeats)
for coords in sample_coords
])
def init_style(self) -> None:
if self.color_by_magnitude:
values_to_rgbs = get_vectorized_rgb_gradient_function(
*self.magnitude_range, self.color_map,
)
cs = self.coordinate_system
for line in self.submobjects:
norms = [
get_norm(self.func(*cs.p2c(point)))
for point in line.get_points()
]
rgbs = values_to_rgbs(norms)
rgbas = np.zeros((len(rgbs), 4))
rgbas[:, :3] = rgbs
rgbas[:, 3] = self.stroke_opacity
line.set_rgba_array(rgbas, "stroke_rgba")
else:
self.set_stroke(self.stroke_color, opacity=self.stroke_opacity)
if self.taper_stroke_width:
width = [0, self.stroke_width, 0]
else:
width = self.stroke_width
self.set_stroke(width=width)
class AnimatedStreamLines(VGroup):
def __init__(
self,
stream_lines: StreamLines,
lag_range: float = 4,
line_anim_config: dict = dict(
rate_func=linear,
time_width=1.0,
),
**kwargs
):
super().__init__(**kwargs)
self.stream_lines = stream_lines
for line in stream_lines:
line.anim = VShowPassingFlash(
line,
run_time=line.virtual_time,
**line_anim_config,
)
line.anim.begin()
line.time = -lag_range * np.random.random()
self.add(line.anim.mobject)
self.add_updater(lambda m, dt: m.update(dt))
def update(self, dt: float) -> None:
stream_lines = self.stream_lines
for line in stream_lines:
line.time += dt
adjusted_time = max(line.time, 0) % line.anim.run_time
line.anim.update(adjusted_time / line.anim.run_time)