diff --git a/manimlib/mobject/coordinate_systems.py b/manimlib/mobject/coordinate_systems.py index 38bab433..6863fee5 100644 --- a/manimlib/mobject/coordinate_systems.py +++ b/manimlib/mobject/coordinate_systems.py @@ -45,6 +45,12 @@ DEFAULT_X_RANGE = (-8.0, 8.0, 1.0) DEFAULT_Y_RANGE = (-4.0, 4.0, 1.0) +def full_range_specifier(range_args): + if len(range_args) == 2: + return (*range_args, 1) + return range_args + + class CoordinateSystem(ABC): """ Abstract class for Axes and NumberPlane @@ -57,8 +63,8 @@ class CoordinateSystem(ABC): y_range: RangeSpecifier = DEFAULT_Y_RANGE, num_sampled_graph_points_per_tick: int = 5, ): - self.x_range = x_range - self.y_range = y_range + self.x_range = full_range_specifier(x_range) + self.y_range = full_range_specifier(y_range) self.num_sampled_graph_points_per_tick = num_sampled_graph_points_per_tick @abstractmethod @@ -536,7 +542,7 @@ class ThreeDAxes(Axes): ): Axes.__init__(self, x_range, y_range, **kwargs) - self.z_range = z_range + self.z_range = full_range_specifier(z_range) self.z_axis = self.create_axis( self.z_range, axis_config=merge_dicts_recursively( diff --git a/manimlib/mobject/vector_field.py b/manimlib/mobject/vector_field.py index df784190..0a1eb94a 100644 --- a/manimlib/mobject/vector_field.py +++ b/manimlib/mobject/vector_field.py @@ -15,6 +15,7 @@ from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import inverse_interpolate from manimlib.utils.color import get_colormap_list from manimlib.utils.color import rgb_to_color +from manimlib.utils.color import get_color_map from manimlib.utils.dict_ops import merge_dicts_recursively from manimlib.utils.iterables import cartesian_product from manimlib.utils.rate_functions import linear @@ -25,7 +26,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Callable, Iterable, Sequence, TypeVar, Tuple - from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array + from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array, Vect4Array from manimlib.mobject.coordinate_systems import CoordinateSystem from manimlib.mobject.mobject import Mobject @@ -33,6 +34,7 @@ if TYPE_CHECKING: T = TypeVar("T") +#### Delete these two ### def get_vectorized_rgb_gradient_function( min_value: T, max_value: T, @@ -52,6 +54,7 @@ def get_vectorized_rgb_gradient_function( inter_alphas = inter_alphas.repeat(3).reshape((len(indices), 3)) result = interpolate(rgbs[indices], rgbs[next_indices], inter_alphas) return result + return func @@ -62,6 +65,7 @@ def get_rgb_gradient_function( ) -> Callable[[float], Vect3]: vectorized_func = get_vectorized_rgb_gradient_function(min_value, max_value, color_map) return lambda value: vectorized_func(np.array([value]))[0] +#### def move_along_vector_field( @@ -106,16 +110,16 @@ def move_points_along_vector_field( return mobject -def get_sample_points_from_coordinate_system( +def get_sample_coords( coordinate_system: CoordinateSystem, - step_multiple: float + step_multiple: float = 1.0 ) -> it.product[tuple[Vect3, ...]]: ranges = [] for range_args in coordinate_system.get_all_ranges(): _min, _max, step = range_args step *= step_multiple ranges.append(np.arange(_min, _max + step, step)) - return it.product(*ranges) + return np.array(list(it.product(*ranges))) # Mobjects @@ -124,51 +128,53 @@ def get_sample_points_from_coordinate_system( class VectorField(VMobject): def __init__( self, - func, + func: Callable[Sequence[float], Sequence[float]], + coordinate_system: CoordinateSystem, + step_multiple: float = 0.5, + magnitude_range: Optional[Tuple[float, float]] = None, + color_map_name: Optional[str] = "3b1b_colormap", + color_map: Optional[Callable[Sequence[float]], Vect4Array] = None, stroke_color: ManimColor = BLUE, stroke_opacity: float = 1.0, - center: Vect3 = ORIGIN, - sample_points: Optional[Vect3Array] = None, - x_density: float = 2.0, - y_density: float = 2.0, - z_density: float = 2.0, - width: float = 14.0, - height: float = 8.0, - depth: float = 0.0, stroke_width: float = 2, tip_width_ratio: float = 4, tip_len_to_width: float = 0.01, max_vect_len: float | None = None, - min_drawn_norm: float = 1e-2, flat_stroke: bool = False, - norm_to_opacity_func=None, - norm_to_rgb_func=None, + norm_to_opacity_func=None, # TODO, check on this **kwargs ): self.func = func + self.coordinate_system = coordinate_system self.stroke_width = stroke_width self.tip_width_ratio = tip_width_ratio self.tip_len_to_width = tip_len_to_width - self.min_drawn_norm = min_drawn_norm self.norm_to_opacity_func = norm_to_opacity_func - self.norm_to_rgb_func = norm_to_rgb_func - if max_vect_len is not None: - self.max_vect_len = max_vect_len + # Search for sample_points + self.sample_coords = get_sample_coords(coordinate_system, step_multiple) + self.update_sample_points() + + if max_vect_len is None: + self.max_displayed_vect_len = get_norm(self.sample_points[1] - self.sample_points[0]) else: - densities = np.array([x_density, y_density, z_density]) - dims = np.array([width, height, depth]) - self.max_vect_len = 1.0 / densities[dims > 0].mean() + self.max_displayed_vect_len = max_vect_len * coordinate_system.get_x_unit_size() - if sample_points is None: - self.sample_points = self.get_sample_points( - center, width, height, depth, - x_density, y_density, z_density - ) + # Prepare the color map + if magnitude_range is None: + max_value = max(map(get_norm, func(self.sample_coords))) + magnitude_range = (0, max_value) + + self.magnitude_range = magnitude_range + + if color_map is not None: + self.color_map = color_map + elif color_map_name is not None: + self.color_map = get_color_map(color_map_name) else: - self.sample_points = sample_points + self.color_map = None - self.init_base_stroke_width_array(len(self.sample_points)) + self.init_base_stroke_width_array(len(self.sample_coords)) super().__init__( stroke_color=stroke_color, @@ -177,7 +183,7 @@ class VectorField(VMobject): **kwargs ) - n_samples = len(self.sample_points) + n_samples = len(self.sample_coords) self.set_points(np.zeros((8 * n_samples - 1, 3))) self.set_stroke(width=stroke_width) self.set_joint_type('no_joint') @@ -211,8 +217,8 @@ class VectorField(VMobject): arr[7::8] = 0 self.base_stroke_width_array = arr - def set_sample_points(self, sample_points: Vect3Array): - self.sample_points = sample_points + def set_sample_coords(self, sample_points: VectArray): + self.sample_coords = sample_coords return self def set_stroke(self, color=None, width=None, opacity=None, behind=None, flat=None, recurse=True): @@ -227,35 +233,40 @@ class VectorField(VMobject): self.stroke_width = width return self + def update_sample_points(self): + self.sample_points = self.coordinate_system.c2p(*self.sample_coords.T) + def update_vectors(self): tip_width = self.tip_width_ratio * self.stroke_width tip_len = self.tip_len_to_width * tip_width - samples = self.sample_points - # Get raw outputs and lengths - outputs = self.func(samples) - norms = np.linalg.norm(outputs, axis=1)[:, np.newaxis] + # Outputs in the coordinate system + outputs = self.func(self.sample_coords) + output_norms = np.linalg.norm(outputs, axis=1)[:, np.newaxis] - # How long should the arrows be drawn? - max_len = self.max_vect_len + # Corresponding vector values in global coordinates + out_vects = self.coordinate_system.c2p(*outputs.T) - self.coordinate_system.get_origin() + out_vect_norms = np.linalg.norm(out_vects, axis=1)[:, np.newaxis] + unit_outputs = np.zeros_like(out_vects) + np.true_divide(out_vects, out_vect_norms, out=unit_outputs, where=(out_vect_norms > 0)) + + # How long should the arrows be drawn, in global coordinates + max_len = self.max_displayed_vect_len if max_len < np.inf: - drawn_norms = max_len * np.tanh(norms / max_len) + drawn_norms = max_len * np.tanh(out_vect_norms / max_len) else: - drawn_norms = norms + drawn_norms = out_vect_norms # What's the distance from the base of an arrow to # the base of its head? - dist_to_head_base = np.clip(drawn_norms - tip_len, 0, np.inf) + dist_to_head_base = np.clip(drawn_norms - tip_len, 0, np.inf) # Mixing units! # Set all points - unit_outputs = np.zeros_like(outputs) - np.true_divide(outputs, norms, out=unit_outputs, where=(norms > self.min_drawn_norm)) - points = self.get_points() - points[0::8] = samples - points[2::8] = samples + dist_to_head_base * unit_outputs + points[0::8] = self.sample_points + points[2::8] = self.sample_points + dist_to_head_base * unit_outputs points[4::8] = points[2::8] - points[6::8] = samples + drawn_norms * unit_outputs + points[6::8] = self.sample_points + drawn_norms * unit_outputs for i in (1, 3, 5): points[i::8] = 0.5 * (points[i - 1::8] + points[i + 1::8]) points[7::8] = points[6:-1:8] @@ -267,14 +278,16 @@ class VectorField(VMobject): self.get_stroke_widths()[:] = width_scalars * width_arr # Potentially adjust opacity and color + if self.color_map is not None: + self.get_stroke_colors() # Ensures the array is updated to appropriate length + low, high = self.magnitude_range + self.data['stroke_rgba'][:] = self.color_map( + inverse_interpolate(low, high, np.repeat(output_norms, 8)[:-1]) + ) + if self.norm_to_opacity_func is not None: self.get_stroke_opacities()[:] = self.norm_to_opacity_func( - np.repeat(norms, 8)[:-1] - ) - if self.norm_to_rgb_func is not None: - self.get_stroke_colors() - self.data['stroke_rgba'][:, :3] = self.norm_to_rgb_func( - np.repeat(norms, 8)[:-1] + np.repeat(output_norms, 8)[:-1] ) self.note_changed_data() @@ -321,11 +334,11 @@ class OldVectorField(VGroup): self.opacity = opacity self.vector_config = dict(vector_config) - self.value_to_rgb = get_rgb_gradient_function( + self.value_to_rgb = get_vectorized_rgb_gradient_function( *self.magnitude_range, self.color_map, ) - samples = get_sample_points_from_coordinate_system( + samples = get_sample_coords( coordinate_system, self.step_multiple ) self.add(*( @@ -438,7 +451,7 @@ class StreamLines(VGroup): def get_start_points(self) -> Vect3Array: cs = self.coordinate_system - sample_coords = get_sample_points_from_coordinate_system( + sample_coords = get_sample_coords( cs, self.step_multiple, ) diff --git a/manimlib/utils/color.py b/manimlib/utils/color.py index f9ceb21a..a511375e 100644 --- a/manimlib/utils/color.py +++ b/manimlib/utils/color.py @@ -14,8 +14,8 @@ from manimlib.utils.iterables import resize_with_interpolation from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Iterable, Sequence - from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array + from typing import Iterable, Sequence, Callable + from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array, NDArray def color_to_rgb(color: ManimColor) -> Vect3: @@ -134,6 +134,33 @@ def random_bright_color( )) +def get_colormap_from_colors(colors: Iterable[ManimColor]) -> Callable[[Sequence[float]], Vect4Array]: + """ + Returns a funciton which takes in values between 0 and 1, and returns + a corresponding list of rgba values + """ + rgbas = np.array([color_to_rgba(color) for color in colors]) + + def func(values): + alphas = np.clip(values, 0, 1) + scaled_alphas = alphas * (len(rgbas) - 1) + indices = scaled_alphas.astype(int) + next_indices = np.clip(indices + 1, 0, len(rgbas) - 1) + inter_alphas = scaled_alphas % 1 + inter_alphas = inter_alphas.repeat(4).reshape((len(indices), 4)) + result = interpolate(rgbas[indices], rgbas[next_indices], inter_alphas) + return result + + return func + + +def get_color_map(map_name: str) -> Callable[[Sequence[float]], Vect4Array]: + if map_name == "3b1b_colormap": + return get_colormap_from_colors(COLORMAP_3B1B) + return plt.get_cmap(map_name) + + +# Delete this? def get_colormap_list( map_name: str = "viridis", n_colors: int = 9