diff --git a/manimlib/mobject/vector_field.py b/manimlib/mobject/vector_field.py index 5355cd74..876db25a 100644 --- a/manimlib/mobject/vector_field.py +++ b/manimlib/mobject/vector_field.py @@ -150,16 +150,17 @@ class VectorField(VMobject): func: Callable[[VectArray], VectArray], # Typically a set of Axes or NumberPlane coordinate_system: CoordinateSystem, - density: float = 2.0, # Describe as a density instead? + density: float = 2.0, magnitude_range: Optional[Tuple[float, float]] = None, + color: Optional[ManimColor] = None, color_map_name: Optional[str] = "3b1b_colormap", color_map: Optional[Callable[[Sequence[float]], Vect4Array]] = None, - stroke_color: ManimColor = BLUE, stroke_opacity: float = 1.0, - stroke_width: float = 2, + stroke_width: float = 3, tip_width_ratio: float = 4, tip_len_to_width: float = 0.01, max_vect_len: float | None = None, + max_vect_len_to_step_size: float = 0.8, flat_stroke: bool = False, norm_to_opacity_func=None, # TODO, check on this **kwargs @@ -176,7 +177,8 @@ class VectorField(VMobject): self.update_sample_points() if max_vect_len is None: - self.max_displayed_vect_len = get_norm(self.sample_points[1] - self.sample_points[0]) + step_size = get_norm(self.sample_points[1] - self.sample_points[0]) + self.max_displayed_vect_len = max_vect_len_to_step_size * step_size else: self.max_displayed_vect_len = max_vect_len * coordinate_system.get_x_unit_size() @@ -187,27 +189,25 @@ class VectorField(VMobject): self.magnitude_range = magnitude_range - if color_map is not None: - self.color_map = color_map - elif color_map_name is not None: - self.color_map = get_color_map(color_map_name) - else: + if color is not None: self.color_map = None + else: + self.color_map = color_map or get_color_map(color_map_name) self.init_base_stroke_width_array(len(self.sample_coords)) super().__init__( - stroke_color=stroke_color, stroke_opacity=stroke_opacity, flat_stroke=flat_stroke, **kwargs ) + self.set_stroke(color, stroke_width) + self.update_vectors() + def init_points(self): n_samples = len(self.sample_coords) self.set_points(np.zeros((8 * n_samples - 1, 3))) - self.set_stroke(width=stroke_width) self.set_joint_type('no_joint') - self.update_vectors() def get_sample_points( self,