Update the Vector Field interface

This commit is contained in:
Grant Sanderson 2024-11-12 11:21:19 -08:00
parent b84376d6fd
commit 64ae1364ca
3 changed files with 108 additions and 62 deletions

View file

@ -45,6 +45,12 @@ DEFAULT_X_RANGE = (-8.0, 8.0, 1.0)
DEFAULT_Y_RANGE = (-4.0, 4.0, 1.0)
def full_range_specifier(range_args):
if len(range_args) == 2:
return (*range_args, 1)
return range_args
class CoordinateSystem(ABC):
"""
Abstract class for Axes and NumberPlane
@ -57,8 +63,8 @@ class CoordinateSystem(ABC):
y_range: RangeSpecifier = DEFAULT_Y_RANGE,
num_sampled_graph_points_per_tick: int = 5,
):
self.x_range = x_range
self.y_range = y_range
self.x_range = full_range_specifier(x_range)
self.y_range = full_range_specifier(y_range)
self.num_sampled_graph_points_per_tick = num_sampled_graph_points_per_tick
@abstractmethod
@ -536,7 +542,7 @@ class ThreeDAxes(Axes):
):
Axes.__init__(self, x_range, y_range, **kwargs)
self.z_range = z_range
self.z_range = full_range_specifier(z_range)
self.z_axis = self.create_axis(
self.z_range,
axis_config=merge_dicts_recursively(

View file

@ -15,6 +15,7 @@ from manimlib.utils.bezier import interpolate
from manimlib.utils.bezier import inverse_interpolate
from manimlib.utils.color import get_colormap_list
from manimlib.utils.color import rgb_to_color
from manimlib.utils.color import get_color_map
from manimlib.utils.dict_ops import merge_dicts_recursively
from manimlib.utils.iterables import cartesian_product
from manimlib.utils.rate_functions import linear
@ -25,7 +26,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, TypeVar, Tuple
from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array
from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array, Vect4Array
from manimlib.mobject.coordinate_systems import CoordinateSystem
from manimlib.mobject.mobject import Mobject
@ -33,6 +34,7 @@ if TYPE_CHECKING:
T = TypeVar("T")
#### Delete these two ###
def get_vectorized_rgb_gradient_function(
min_value: T,
max_value: T,
@ -52,6 +54,7 @@ def get_vectorized_rgb_gradient_function(
inter_alphas = inter_alphas.repeat(3).reshape((len(indices), 3))
result = interpolate(rgbs[indices], rgbs[next_indices], inter_alphas)
return result
return func
@ -62,6 +65,7 @@ def get_rgb_gradient_function(
) -> Callable[[float], Vect3]:
vectorized_func = get_vectorized_rgb_gradient_function(min_value, max_value, color_map)
return lambda value: vectorized_func(np.array([value]))[0]
####
def move_along_vector_field(
@ -106,16 +110,16 @@ def move_points_along_vector_field(
return mobject
def get_sample_points_from_coordinate_system(
def get_sample_coords(
coordinate_system: CoordinateSystem,
step_multiple: float
step_multiple: float = 1.0
) -> it.product[tuple[Vect3, ...]]:
ranges = []
for range_args in coordinate_system.get_all_ranges():
_min, _max, step = range_args
step *= step_multiple
ranges.append(np.arange(_min, _max + step, step))
return it.product(*ranges)
return np.array(list(it.product(*ranges)))
# Mobjects
@ -124,51 +128,53 @@ def get_sample_points_from_coordinate_system(
class VectorField(VMobject):
def __init__(
self,
func,
func: Callable[Sequence[float], Sequence[float]],
coordinate_system: CoordinateSystem,
step_multiple: float = 0.5,
magnitude_range: Optional[Tuple[float, float]] = None,
color_map_name: Optional[str] = "3b1b_colormap",
color_map: Optional[Callable[Sequence[float]], Vect4Array] = None,
stroke_color: ManimColor = BLUE,
stroke_opacity: float = 1.0,
center: Vect3 = ORIGIN,
sample_points: Optional[Vect3Array] = None,
x_density: float = 2.0,
y_density: float = 2.0,
z_density: float = 2.0,
width: float = 14.0,
height: float = 8.0,
depth: float = 0.0,
stroke_width: float = 2,
tip_width_ratio: float = 4,
tip_len_to_width: float = 0.01,
max_vect_len: float | None = None,
min_drawn_norm: float = 1e-2,
flat_stroke: bool = False,
norm_to_opacity_func=None,
norm_to_rgb_func=None,
norm_to_opacity_func=None, # TODO, check on this
**kwargs
):
self.func = func
self.coordinate_system = coordinate_system
self.stroke_width = stroke_width
self.tip_width_ratio = tip_width_ratio
self.tip_len_to_width = tip_len_to_width
self.min_drawn_norm = min_drawn_norm
self.norm_to_opacity_func = norm_to_opacity_func
self.norm_to_rgb_func = norm_to_rgb_func
if max_vect_len is not None:
self.max_vect_len = max_vect_len
# Search for sample_points
self.sample_coords = get_sample_coords(coordinate_system, step_multiple)
self.update_sample_points()
if max_vect_len is None:
self.max_displayed_vect_len = get_norm(self.sample_points[1] - self.sample_points[0])
else:
densities = np.array([x_density, y_density, z_density])
dims = np.array([width, height, depth])
self.max_vect_len = 1.0 / densities[dims > 0].mean()
self.max_displayed_vect_len = max_vect_len * coordinate_system.get_x_unit_size()
if sample_points is None:
self.sample_points = self.get_sample_points(
center, width, height, depth,
x_density, y_density, z_density
)
# Prepare the color map
if magnitude_range is None:
max_value = max(map(get_norm, func(self.sample_coords)))
magnitude_range = (0, max_value)
self.magnitude_range = magnitude_range
if color_map is not None:
self.color_map = color_map
elif color_map_name is not None:
self.color_map = get_color_map(color_map_name)
else:
self.sample_points = sample_points
self.color_map = None
self.init_base_stroke_width_array(len(self.sample_points))
self.init_base_stroke_width_array(len(self.sample_coords))
super().__init__(
stroke_color=stroke_color,
@ -177,7 +183,7 @@ class VectorField(VMobject):
**kwargs
)
n_samples = len(self.sample_points)
n_samples = len(self.sample_coords)
self.set_points(np.zeros((8 * n_samples - 1, 3)))
self.set_stroke(width=stroke_width)
self.set_joint_type('no_joint')
@ -211,8 +217,8 @@ class VectorField(VMobject):
arr[7::8] = 0
self.base_stroke_width_array = arr
def set_sample_points(self, sample_points: Vect3Array):
self.sample_points = sample_points
def set_sample_coords(self, sample_points: VectArray):
self.sample_coords = sample_coords
return self
def set_stroke(self, color=None, width=None, opacity=None, behind=None, flat=None, recurse=True):
@ -227,35 +233,40 @@ class VectorField(VMobject):
self.stroke_width = width
return self
def update_sample_points(self):
self.sample_points = self.coordinate_system.c2p(*self.sample_coords.T)
def update_vectors(self):
tip_width = self.tip_width_ratio * self.stroke_width
tip_len = self.tip_len_to_width * tip_width
samples = self.sample_points
# Get raw outputs and lengths
outputs = self.func(samples)
norms = np.linalg.norm(outputs, axis=1)[:, np.newaxis]
# Outputs in the coordinate system
outputs = self.func(self.sample_coords)
output_norms = np.linalg.norm(outputs, axis=1)[:, np.newaxis]
# How long should the arrows be drawn?
max_len = self.max_vect_len
# Corresponding vector values in global coordinates
out_vects = self.coordinate_system.c2p(*outputs.T) - self.coordinate_system.get_origin()
out_vect_norms = np.linalg.norm(out_vects, axis=1)[:, np.newaxis]
unit_outputs = np.zeros_like(out_vects)
np.true_divide(out_vects, out_vect_norms, out=unit_outputs, where=(out_vect_norms > 0))
# How long should the arrows be drawn, in global coordinates
max_len = self.max_displayed_vect_len
if max_len < np.inf:
drawn_norms = max_len * np.tanh(norms / max_len)
drawn_norms = max_len * np.tanh(out_vect_norms / max_len)
else:
drawn_norms = norms
drawn_norms = out_vect_norms
# What's the distance from the base of an arrow to
# the base of its head?
dist_to_head_base = np.clip(drawn_norms - tip_len, 0, np.inf)
dist_to_head_base = np.clip(drawn_norms - tip_len, 0, np.inf) # Mixing units!
# Set all points
unit_outputs = np.zeros_like(outputs)
np.true_divide(outputs, norms, out=unit_outputs, where=(norms > self.min_drawn_norm))
points = self.get_points()
points[0::8] = samples
points[2::8] = samples + dist_to_head_base * unit_outputs
points[0::8] = self.sample_points
points[2::8] = self.sample_points + dist_to_head_base * unit_outputs
points[4::8] = points[2::8]
points[6::8] = samples + drawn_norms * unit_outputs
points[6::8] = self.sample_points + drawn_norms * unit_outputs
for i in (1, 3, 5):
points[i::8] = 0.5 * (points[i - 1::8] + points[i + 1::8])
points[7::8] = points[6:-1:8]
@ -267,14 +278,16 @@ class VectorField(VMobject):
self.get_stroke_widths()[:] = width_scalars * width_arr
# Potentially adjust opacity and color
if self.color_map is not None:
self.get_stroke_colors() # Ensures the array is updated to appropriate length
low, high = self.magnitude_range
self.data['stroke_rgba'][:] = self.color_map(
inverse_interpolate(low, high, np.repeat(output_norms, 8)[:-1])
)
if self.norm_to_opacity_func is not None:
self.get_stroke_opacities()[:] = self.norm_to_opacity_func(
np.repeat(norms, 8)[:-1]
)
if self.norm_to_rgb_func is not None:
self.get_stroke_colors()
self.data['stroke_rgba'][:, :3] = self.norm_to_rgb_func(
np.repeat(norms, 8)[:-1]
np.repeat(output_norms, 8)[:-1]
)
self.note_changed_data()
@ -321,11 +334,11 @@ class OldVectorField(VGroup):
self.opacity = opacity
self.vector_config = dict(vector_config)
self.value_to_rgb = get_rgb_gradient_function(
self.value_to_rgb = get_vectorized_rgb_gradient_function(
*self.magnitude_range, self.color_map,
)
samples = get_sample_points_from_coordinate_system(
samples = get_sample_coords(
coordinate_system, self.step_multiple
)
self.add(*(
@ -438,7 +451,7 @@ class StreamLines(VGroup):
def get_start_points(self) -> Vect3Array:
cs = self.coordinate_system
sample_coords = get_sample_points_from_coordinate_system(
sample_coords = get_sample_coords(
cs, self.step_multiple,
)

View file

@ -14,8 +14,8 @@ from manimlib.utils.iterables import resize_with_interpolation
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Iterable, Sequence
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array
from typing import Iterable, Sequence, Callable
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array, NDArray
def color_to_rgb(color: ManimColor) -> Vect3:
@ -134,6 +134,33 @@ def random_bright_color(
))
def get_colormap_from_colors(colors: Iterable[ManimColor]) -> Callable[[Sequence[float]], Vect4Array]:
"""
Returns a funciton which takes in values between 0 and 1, and returns
a corresponding list of rgba values
"""
rgbas = np.array([color_to_rgba(color) for color in colors])
def func(values):
alphas = np.clip(values, 0, 1)
scaled_alphas = alphas * (len(rgbas) - 1)
indices = scaled_alphas.astype(int)
next_indices = np.clip(indices + 1, 0, len(rgbas) - 1)
inter_alphas = scaled_alphas % 1
inter_alphas = inter_alphas.repeat(4).reshape((len(indices), 4))
result = interpolate(rgbas[indices], rgbas[next_indices], inter_alphas)
return result
return func
def get_color_map(map_name: str) -> Callable[[Sequence[float]], Vect4Array]:
if map_name == "3b1b_colormap":
return get_colormap_from_colors(COLORMAP_3B1B)
return plt.get_cmap(map_name)
# Delete this?
def get_colormap_list(
map_name: str = "viridis",
n_colors: int = 9