Merge pull request #2250 from 3b1b/video-work

Video work
This commit is contained in:
Grant Sanderson 2024-11-25 13:44:08 -06:00 committed by GitHub
commit 530cb4f104
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 247 additions and 195 deletions

View file

@ -45,6 +45,12 @@ DEFAULT_X_RANGE = (-8.0, 8.0, 1.0)
DEFAULT_Y_RANGE = (-4.0, 4.0, 1.0)
def full_range_specifier(range_args):
if len(range_args) == 2:
return (*range_args, 1)
return range_args
class CoordinateSystem(ABC):
"""
Abstract class for Axes and NumberPlane
@ -57,8 +63,8 @@ class CoordinateSystem(ABC):
y_range: RangeSpecifier = DEFAULT_Y_RANGE,
num_sampled_graph_points_per_tick: int = 5,
):
self.x_range = x_range
self.y_range = y_range
self.x_range = full_range_specifier(x_range)
self.y_range = full_range_specifier(y_range)
self.num_sampled_graph_points_per_tick = num_sampled_graph_points_per_tick
@abstractmethod
@ -536,7 +542,7 @@ class ThreeDAxes(Axes):
):
Axes.__init__(self, x_range, y_range, **kwargs)
self.z_range = z_range
self.z_range = full_range_specifier(z_range)
self.z_axis = self.create_axis(
self.z_range,
axis_config=merge_dicts_recursively(

View file

@ -1950,12 +1950,14 @@ class Mobject(object):
def set_clip_plane(
self,
vect: Vect3 | None = None,
threshold: float | None = None
threshold: float | None = None,
recurse=True
) -> Self:
if vect is not None:
self.uniforms["clip_plane"][:3] = vect
if threshold is not None:
self.uniforms["clip_plane"][3] = threshold
for submob in self.get_family(recurse):
if vect is not None:
submob.uniforms["clip_plane"][:3] = vect
if threshold is not None:
submob.uniforms["clip_plane"][3] = threshold
return self
def deactivate_clip_plane(self) -> Self:

View file

@ -22,7 +22,12 @@ if TYPE_CHECKING:
@lru_cache()
def char_to_cahced_mob(char: str, **text_config):
return Text(char, **text_config)
if "\\" in char:
# This is for when the "character" is a LaTeX command
# like ^\circ or \dots
return Tex(char, **text_config)
else:
return Text(char, **text_config)
class DecimalNumber(VMobject):

View file

@ -48,6 +48,7 @@ class Tex(StringMobject):
if not tex_string.strip():
tex_string = R"\\"
self.font_size = font_size
self.tex_string = tex_string
self.alignment = alignment
self.template = template
@ -86,6 +87,10 @@ class Tex(StringMobject):
content, self.template, self.additional_preamble, self.tex_string
)
def _handle_scale_side_effects(self, scale_factor: float) -> Self:
self.font_size *= scale_factor
return self
# Parsing
@staticmethod

View file

@ -94,7 +94,7 @@ class Sphere(Surface):
def __init__(
self,
u_range: Tuple[float, float] = (0, TAU),
v_range: Tuple[float, float] = (1e-5, PI - 1e-5),
v_range: Tuple[float, float] = (0, PI),
resolution: Tuple[int, int] = (101, 51),
radius: float = 1.0,
**kwargs,
@ -158,7 +158,6 @@ class Cylinder(Surface):
**kwargs
)
def init_points(self):
super().init_points()
self.scale(self.radius)
@ -169,6 +168,20 @@ class Cylinder(Surface):
return np.array([np.cos(u), np.sin(u), v])
class Cone(Cylinder):
def __init__(
self,
u_range: Tuple[float, float] = (0, TAU),
v_range: Tuple[float, float] = (0, 1),
*args,
**kwargs,
):
super().__init__(u_range=u_range, v_range=v_range, *args, **kwargs)
def uv_func(self, u: float, v: float) -> np.ndarray:
return np.array([(1 - v) * np.cos(u), (1 - v) * np.sin(u), v])
class Line3D(Cylinder):
def __init__(
self,

View file

@ -94,6 +94,7 @@ class VMobject(Mobject):
# Could also be "no_joint", "bevel", "miter"
joint_type: str = "auto",
flat_stroke: bool = False,
scale_stroke_with_zoom: bool = False,
use_simple_quadratic_approx: bool = False,
# Measured in pixel widths
anti_alias_width: float = 1.5,
@ -110,6 +111,7 @@ class VMobject(Mobject):
self.long_lines = long_lines
self.joint_type = joint_type
self.flat_stroke = flat_stroke
self.scale_stroke_with_zoom = scale_stroke_with_zoom
self.use_simple_quadratic_approx = use_simple_quadratic_approx
self.anti_alias_width = anti_alias_width
self.fill_border_width = fill_border_width
@ -126,9 +128,12 @@ class VMobject(Mobject):
def init_uniforms(self):
super().init_uniforms()
self.uniforms["anti_alias_width"] = self.anti_alias_width
self.uniforms["joint_type"] = JOINT_TYPE_MAP[self.joint_type]
self.uniforms["flat_stroke"] = float(self.flat_stroke)
self.uniforms.update(
anti_alias_width=self.anti_alias_width,
joint_type=JOINT_TYPE_MAP[self.joint_type],
flat_stroke=float(self.flat_stroke),
scale_stroke_with_zoom=float(self.scale_stroke_with_zoom)
)
def add(self, *vmobjects: VMobject) -> Self:
if not all((isinstance(m, VMobject) for m in vmobjects)):
@ -399,6 +404,13 @@ class VMobject(Mobject):
def get_flat_stroke(self) -> bool:
return self.uniforms["flat_stroke"] == 1.0
def set_scale_stroke_with_zoom(self, scale_stroke_with_zoom: bool = True, recurse: bool = True) -> Self:
self.set_uniform(recurse, scale_stroke_with_zoom=float(scale_stroke_with_zoom))
pass
def get_scale_stroke_with_zoom(self) -> bool:
return self.uniforms["flat_stroke"] == 1.0
def set_joint_type(self, joint_type: str, recurse: bool = True) -> Self:
for mob in self.get_family(recurse):
mob.uniforms["joint_type"] = JOINT_TYPE_MAP[joint_type]

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import itertools as it
import numpy as np
from scipy.integrate import solve_ivp
from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH
from manimlib.constants import BLUE, WHITE
@ -15,6 +16,7 @@ from manimlib.utils.bezier import interpolate
from manimlib.utils.bezier import inverse_interpolate
from manimlib.utils.color import get_colormap_list
from manimlib.utils.color import rgb_to_color
from manimlib.utils.color import get_color_map
from manimlib.utils.dict_ops import merge_dicts_recursively
from manimlib.utils.iterables import cartesian_product
from manimlib.utils.rate_functions import linear
@ -25,7 +27,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, TypeVar, Tuple
from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array
from manimlib.typing import ManimColor, Vect3, VectN, Vect2Array, Vect3Array, Vect4Array
from manimlib.mobject.coordinate_systems import CoordinateSystem
from manimlib.mobject.mobject import Mobject
@ -33,6 +35,7 @@ if TYPE_CHECKING:
T = TypeVar("T")
#### Delete these two ###
def get_vectorized_rgb_gradient_function(
min_value: T,
max_value: T,
@ -52,6 +55,7 @@ def get_vectorized_rgb_gradient_function(
inter_alphas = inter_alphas.repeat(3).reshape((len(indices), 3))
result = interpolate(rgbs[indices], rgbs[next_indices], inter_alphas)
return result
return func
@ -62,6 +66,17 @@ def get_rgb_gradient_function(
) -> Callable[[float], Vect3]:
vectorized_func = get_vectorized_rgb_gradient_function(min_value, max_value, color_map)
return lambda value: vectorized_func(np.array([value]))[0]
####
def ode_solution_points(function, state0, time, dt=0.01):
solution = solve_ivp(
lambda t, state: function(state),
t_span=(0, time),
y0=state0,
t_eval=np.arange(0, time, dt)
)
return solution.y.T
def move_along_vector_field(
@ -106,16 +121,23 @@ def move_points_along_vector_field(
return mobject
def get_sample_points_from_coordinate_system(
def get_sample_coords(
coordinate_system: CoordinateSystem,
step_multiple: float
density: float = 1.0
) -> it.product[tuple[Vect3, ...]]:
ranges = []
for range_args in coordinate_system.get_all_ranges():
_min, _max, step = range_args
step *= step_multiple
step /= density
ranges.append(np.arange(_min, _max + step, step))
return it.product(*ranges)
return np.array(list(it.product(*ranges)))
def vectorize(pointwise_function: Callable[Tuple, Tuple]):
def v_func(coords_array: VectArray) -> VectArray:
return np.array([pointwise_function(*coords) for coords in coords_array])
return v_func
# Mobjects
@ -124,65 +146,69 @@ def get_sample_points_from_coordinate_system(
class VectorField(VMobject):
def __init__(
self,
func,
stroke_color: ManimColor = BLUE,
# Vectorized function: Takes in an array of coordinates, returns an array of outputs.
func: Callable[[VectArray], VectArray],
# Typically a set of Axes or NumberPlane
coordinate_system: CoordinateSystem,
density: float = 2.0,
magnitude_range: Optional[Tuple[float, float]] = None,
color: Optional[ManimColor] = None,
color_map_name: Optional[str] = "3b1b_colormap",
color_map: Optional[Callable[[Sequence[float]], Vect4Array]] = None,
stroke_opacity: float = 1.0,
center: Vect3 = ORIGIN,
sample_points: Optional[Vect3Array] = None,
x_density: float = 2.0,
y_density: float = 2.0,
z_density: float = 2.0,
width: float = 14.0,
height: float = 8.0,
depth: float = 0.0,
stroke_width: float = 2,
stroke_width: float = 3,
tip_width_ratio: float = 4,
tip_len_to_width: float = 0.01,
max_vect_len: float | None = None,
min_drawn_norm: float = 1e-2,
max_vect_len_to_step_size: float = 0.8,
flat_stroke: bool = False,
norm_to_opacity_func=None,
norm_to_rgb_func=None,
norm_to_opacity_func=None, # TODO, check on this
**kwargs
):
self.func = func
self.coordinate_system = coordinate_system
self.stroke_width = stroke_width
self.tip_width_ratio = tip_width_ratio
self.tip_len_to_width = tip_len_to_width
self.min_drawn_norm = min_drawn_norm
self.norm_to_opacity_func = norm_to_opacity_func
self.norm_to_rgb_func = norm_to_rgb_func
if max_vect_len is not None:
self.max_vect_len = max_vect_len
# Search for sample_points
self.sample_coords = get_sample_coords(coordinate_system, density)
self.update_sample_points()
if max_vect_len is None:
step_size = get_norm(self.sample_points[1] - self.sample_points[0])
self.max_displayed_vect_len = max_vect_len_to_step_size * step_size
else:
densities = np.array([x_density, y_density, z_density])
dims = np.array([width, height, depth])
self.max_vect_len = 1.0 / densities[dims > 0].mean()
self.max_displayed_vect_len = max_vect_len * coordinate_system.get_x_unit_size()
if sample_points is None:
self.sample_points = self.get_sample_points(
center, width, height, depth,
x_density, y_density, z_density
)
# Prepare the color map
if magnitude_range is None:
max_value = max(map(get_norm, func(self.sample_coords)))
magnitude_range = (0, max_value)
self.magnitude_range = magnitude_range
if color is not None:
self.color_map = None
else:
self.sample_points = sample_points
self.color_map = color_map or get_color_map(color_map_name)
self.init_base_stroke_width_array(len(self.sample_points))
self.init_base_stroke_width_array(len(self.sample_coords))
super().__init__(
stroke_color=stroke_color,
stroke_opacity=stroke_opacity,
flat_stroke=flat_stroke,
**kwargs
)
n_samples = len(self.sample_points)
self.set_points(np.zeros((8 * n_samples - 1, 3)))
self.set_stroke(width=stroke_width)
self.set_joint_type('no_joint')
self.set_stroke(color, stroke_width)
self.update_vectors()
def init_points(self):
n_samples = len(self.sample_coords)
self.set_points(np.zeros((8 * n_samples - 1, 3)))
self.set_joint_type('no_joint')
def get_sample_points(
self,
center: np.ndarray,
@ -211,8 +237,8 @@ class VectorField(VMobject):
arr[7::8] = 0
self.base_stroke_width_array = arr
def set_sample_points(self, sample_points: Vect3Array):
self.sample_points = sample_points
def set_sample_coords(self, sample_points: VectArray):
self.sample_coords = sample_coords
return self
def set_stroke(self, color=None, width=None, opacity=None, behind=None, flat=None, recurse=True):
@ -227,35 +253,40 @@ class VectorField(VMobject):
self.stroke_width = width
return self
def update_sample_points(self):
self.sample_points = self.coordinate_system.c2p(*self.sample_coords.T)
def update_vectors(self):
tip_width = self.tip_width_ratio * self.stroke_width
tip_len = self.tip_len_to_width * tip_width
samples = self.sample_points
# Get raw outputs and lengths
outputs = self.func(samples)
norms = np.linalg.norm(outputs, axis=1)[:, np.newaxis]
# Outputs in the coordinate system
outputs = self.func(self.sample_coords)
output_norms = np.linalg.norm(outputs, axis=1)[:, np.newaxis]
# How long should the arrows be drawn?
max_len = self.max_vect_len
# Corresponding vector values in global coordinates
out_vects = self.coordinate_system.c2p(*outputs.T) - self.coordinate_system.get_origin()
out_vect_norms = np.linalg.norm(out_vects, axis=1)[:, np.newaxis]
unit_outputs = np.zeros_like(out_vects)
np.true_divide(out_vects, out_vect_norms, out=unit_outputs, where=(out_vect_norms > 0))
# How long should the arrows be drawn, in global coordinates
max_len = self.max_displayed_vect_len
if max_len < np.inf:
drawn_norms = max_len * np.tanh(norms / max_len)
drawn_norms = max_len * np.tanh(out_vect_norms / max_len)
else:
drawn_norms = norms
drawn_norms = out_vect_norms
# What's the distance from the base of an arrow to
# the base of its head?
dist_to_head_base = np.clip(drawn_norms - tip_len, 0, np.inf)
dist_to_head_base = np.clip(drawn_norms - tip_len, 0, np.inf) # Mixing units!
# Set all points
unit_outputs = np.zeros_like(outputs)
np.true_divide(outputs, norms, out=unit_outputs, where=(norms > self.min_drawn_norm))
points = self.get_points()
points[0::8] = samples
points[2::8] = samples + dist_to_head_base * unit_outputs
points[0::8] = self.sample_points
points[2::8] = self.sample_points + dist_to_head_base * unit_outputs
points[4::8] = points[2::8]
points[6::8] = samples + drawn_norms * unit_outputs
points[6::8] = self.sample_points + drawn_norms * unit_outputs
for i in (1, 3, 5):
points[i::8] = 0.5 * (points[i - 1::8] + points[i + 1::8])
points[7::8] = points[6:-1:8]
@ -267,14 +298,16 @@ class VectorField(VMobject):
self.get_stroke_widths()[:] = width_scalars * width_arr
# Potentially adjust opacity and color
if self.color_map is not None:
self.get_stroke_colors() # Ensures the array is updated to appropriate length
low, high = self.magnitude_range
self.data['stroke_rgba'][:, :3] = self.color_map(
inverse_interpolate(low, high, np.repeat(output_norms, 8)[:-1])
)[:, :3]
if self.norm_to_opacity_func is not None:
self.get_stroke_opacities()[:] = self.norm_to_opacity_func(
np.repeat(norms, 8)[:-1]
)
if self.norm_to_rgb_func is not None:
self.get_stroke_colors()
self.data['stroke_rgba'][:, :3] = self.norm_to_rgb_func(
np.repeat(norms, 8)[:-1]
np.repeat(output_norms, 8)[:-1]
)
self.note_changed_data()
@ -285,90 +318,33 @@ class TimeVaryingVectorField(VectorField):
def __init__(
self,
# Takes in an array of points and a float for time
time_func,
time_func: Callable[[VectArray, float], VectArray],
coordinate_system: CoordinateSystem,
**kwargs
):
self.time = 0
super().__init__(func=lambda p: time_func(p, self.time), **kwargs)
def func(coords):
return time_func(coords, self.time)
super().__init__(func, coordinate_system, **kwargs)
self.add_updater(lambda m, dt: m.increment_time(dt))
always(self.update_vectors)
self.always.update_vectors()
def increment_time(self, dt):
self.time += dt
class OldVectorField(VGroup):
def __init__(
self,
func: Callable[[float, float], Sequence[float]],
coordinate_system: CoordinateSystem,
step_multiple: float = 0.5,
magnitude_range: Tuple[float, float] = (0, 2),
color_map: str = "3b1b_colormap",
# Takes in actual norm, spits out displayed norm
length_func: Callable[[float], float] = lambda norm: 0.45 * sigmoid(norm),
opacity: float = 1.0,
vector_config: dict = dict(),
**kwargs
):
super().__init__(**kwargs)
self.func = func
self.coordinate_system = coordinate_system
self.step_multiple = step_multiple
self.magnitude_range = magnitude_range
self.color_map = color_map
self.length_func = length_func
self.opacity = opacity
self.vector_config = dict(vector_config)
self.value_to_rgb = get_rgb_gradient_function(
*self.magnitude_range, self.color_map,
)
samples = get_sample_points_from_coordinate_system(
coordinate_system, self.step_multiple
)
self.add(*(
self.get_vector(coords)
for coords in samples
))
def get_vector(self, coords: Iterable[float], **kwargs) -> Arrow:
vector_config = merge_dicts_recursively(
self.vector_config,
kwargs
)
output = np.array(self.func(*coords))
norm = get_norm(output)
if norm > 0:
output *= self.length_func(norm) / norm
origin = self.coordinate_system.get_origin()
_input = self.coordinate_system.c2p(*coords)
_output = self.coordinate_system.c2p(*output)
vect = Arrow(
origin, _output, buff=0,
**vector_config
)
vect.shift(_input - origin)
vect.set_color(
rgb_to_color(self.value_to_rgb(norm)),
opacity=self.opacity,
)
return vect
class StreamLines(VGroup):
def __init__(
self,
func: Callable[[float, float], Sequence[float]],
func: Callable[[VectArray], VectArray],
coordinate_system: CoordinateSystem,
step_multiple: float = 0.5,
density: float = 1.0,
n_repeats: int = 1,
noise_factor: float | None = None,
# Config for drawing lines
solution_time: float = 3,
dt: float = 0.05,
arc_len: float = 3,
max_time_steps: int = 200,
@ -387,9 +363,10 @@ class StreamLines(VGroup):
super().__init__(**kwargs)
self.func = func
self.coordinate_system = coordinate_system
self.step_multiple = step_multiple
self.density = density
self.n_repeats = n_repeats
self.noise_factor = noise_factor
self.solution_time = solution_time
self.dt = dt
self.arc_len = arc_len
self.max_time_steps = max_time_steps
@ -406,48 +383,38 @@ class StreamLines(VGroup):
self.draw_lines()
self.init_style()
def point_func(self, point: Vect3) -> Vect3:
in_coords = self.coordinate_system.p2c(point)
out_coords = self.func(*in_coords)
return self.coordinate_system.c2p(*out_coords)
def point_func(self, points: Vect3Array) -> Vect3:
in_coords = np.array(self.coordinate_system.p2c(points)).T
out_coords = self.func(in_coords)
origin = self.coordinate_system.get_origin()
return self.coordinate_system.c2p(*out_coords.T) - origin
def draw_lines(self) -> None:
lines = []
origin = self.coordinate_system.get_origin()
for point in self.get_start_points():
points = [point]
total_arc_len = 0
time = 0
for x in range(self.max_time_steps):
time += self.dt
last_point = points[-1]
new_point = last_point + self.dt * (self.point_func(last_point) - origin)
points.append(new_point)
total_arc_len += get_norm(new_point - last_point)
if get_norm(last_point) > self.cutoff_norm:
break
if total_arc_len > self.arc_len:
break
# Todo, it feels like coordinate system should just have
# the ODE solver built into it, no?
lines = []
for coords in self.get_sample_coords():
solution_coords = ode_solution_points(self.func, coords, self.solution_time, self.dt)
line = VMobject()
line.virtual_time = time
step = max(1, int(len(points) / self.n_samples_per_line))
line.set_points_as_corners(points[::step])
line.make_smooth(approx=True)
line.set_points_smoothly(self.coordinate_system.c2p(*solution_coords.T))
# TODO, account for arc length somehow?
line.virtual_time = self.solution_time
lines.append(line)
self.set_submobjects(lines)
def get_start_points(self) -> Vect3Array:
def get_sample_coords(self):
cs = self.coordinate_system
sample_coords = get_sample_points_from_coordinate_system(
cs, self.step_multiple,
)
sample_coords = get_sample_coords(cs, self.density)
noise_factor = self.noise_factor
if noise_factor is None:
noise_factor = cs.x_range[2] * self.step_multiple * 0.5
noise_factor = (cs.get_x_unit_size() / self.density) * 0.5
return np.array([
cs.c2p(*coords) + noise_factor * np.random.random(3)
coords + noise_factor * np.random.random(coords.shape)
for n in range(self.n_repeats)
for coords in sample_coords
])
@ -483,6 +450,7 @@ class AnimatedStreamLines(VGroup):
self,
stream_lines: StreamLines,
lag_range: float = 4,
rate_multiple: float = 1.0,
line_anim_config: dict = dict(
rate_func=linear,
time_width=1.0,
@ -495,7 +463,7 @@ class AnimatedStreamLines(VGroup):
for line in stream_lines:
line.anim = VShowPassingFlash(
line,
run_time=line.virtual_time,
run_time=line.virtual_time / rate_multiple,
**line_anim_config,
)
line.anim.begin()

View file

@ -4,7 +4,7 @@ layout (triangles) in;
layout (triangle_strip, max_vertices = 64) out; // Related to MAX_STEPS below
uniform float anti_alias_width;
uniform float flat_stroke_float;
uniform float flat_stroke;
uniform float pixel_size;
uniform float joint_type;
uniform float frame_scale;
@ -62,13 +62,13 @@ vec3 rotate_vector(vec3 vect, vec3 unit_normal, float angle){
}
vec3 step_to_corner(vec3 point, vec3 tangent, vec3 unit_normal, float joint_angle, bool inside_curve, bool flat_stroke){
vec3 step_to_corner(vec3 point, vec3 tangent, vec3 unit_normal, float joint_angle, bool inside_curve, bool draw_flat){
/*
Step the the left of a curve.
First a perpendicular direction is calculated, then it is adjusted
so as to make a joint.
*/
vec3 unit_tan = normalize(flat_stroke ? tangent : project(tangent, unit_normal));
vec3 unit_tan = normalize(draw_flat ? tangent : project(tangent, unit_normal));
// Step to stroke width bound should be perpendicular
// both to the tangent and the normal direction
@ -78,11 +78,13 @@ vec3 step_to_corner(vec3 point, vec3 tangent, vec3 unit_normal, float joint_angl
// lines up very closely with the direction to the camera, treated here
// as the unit normal. To avoid those, this smoothly transitions to a step
// direction perpendicular to the true curve normal.
float alignment = abs(dot(normalize(tangent), unit_normal));
float alignment_threshold = 0.97; // This could maybe be chosen in a more principled way based on stroke width
if (alignment > alignment_threshold) {
vec3 perp = normalize(cross(v_unit_normal[1], tangent));
step = mix(step, project(step, perp), smoothstep(alignment_threshold, 1.0, alignment));
if(joint_angle != 0){
float alignment = abs(dot(normalize(tangent), unit_normal));
float alignment_threshold = 0.97; // This could maybe be chosen in a more principled way based on stroke width
if (alignment > alignment_threshold) {
vec3 perp = normalize(cross(v_unit_normal[1], tangent));
step = mix(step, project(step, perp), smoothstep(alignment_threshold, 1.0, alignment));
}
}
if (inside_curve || int(joint_type) == NO_JOINT) return step;
@ -93,7 +95,7 @@ vec3 step_to_corner(vec3 point, vec3 tangent, vec3 unit_normal, float joint_angl
if (abs(cos_angle) > COS_THRESHOLD) return step;
// Below here, figure out the adjustment to bevel or miter a joint
if (!flat_stroke){
if (!draw_flat){
// Figure out what joint product would be for everything projected onto
// the plane perpendicular to the normal direction (which here would be to_camera)
step = normalize(cross(unit_normal, unit_tan)); // Back to original step
@ -128,17 +130,17 @@ void emit_point_with_width(
float width,
vec4 joint_color,
bool inside_curve,
bool flat_stroke
bool draw_flat
){
// Find unit normal
vec3 unit_normal = flat_stroke ? v_unit_normal[1] : normalize(camera_position - point);
vec3 unit_normal = draw_flat ? v_unit_normal[1] : normalize(camera_position - point);
// Set styling
color = finalize_color(joint_color, point, unit_normal);
// Figure out the step from the point to the corners of the
// triangle strip around the polyline
vec3 step = step_to_corner(point, tangent, unit_normal, joint_angle, inside_curve, flat_stroke);
vec3 step = step_to_corner(point, tangent, unit_normal, joint_angle, inside_curve, draw_flat);
float aaw = max(anti_alias_width * pixel_size, 1e-8);
// Emit two corners
@ -163,7 +165,7 @@ void main() {
if (vec3(v_stroke_width[0], v_stroke_width[1], v_stroke_width[2]) == vec3(0.0, 0.0, 0.0)) return;
if (vec3(v_color[0].a, v_color[1].a, v_color[2].a) == vec3(0.0, 0.0, 0.0)) return;
bool flat_stroke = bool(flat_stroke_float) || bool(is_fixed_in_frame);
bool draw_flat = bool(flat_stroke) || bool(is_fixed_in_frame);
// Coefficients such that the quadratic bezier is c0 + c1 * t + c2 * t^2
vec3 c0 = verts[0];
@ -207,7 +209,7 @@ void main() {
emit_point_with_width(
point, tangent, joint_angle,
stroke_width, color,
inside_curve, flat_stroke
inside_curve, draw_flat
);
}
EndPrimitive();

View file

@ -2,6 +2,7 @@
uniform float frame_scale;
uniform float is_fixed_in_frame;
uniform float scale_stroke_with_zoom;
in vec3 point;
in vec4 stroke_rgba;
@ -22,7 +23,7 @@ const float STROKE_WIDTH_CONVERSION = 0.01;
void main(){
verts = point;
v_color = stroke_rgba;
v_stroke_width = STROKE_WIDTH_CONVERSION * stroke_width * mix(frame_scale, 1, is_fixed_in_frame);
v_stroke_width = STROKE_WIDTH_CONVERSION * stroke_width * mix(frame_scale, 1, scale_stroke_with_zoom);
v_joint_angle = joint_angle;
v_unit_normal = unit_normal;
}

View file

@ -11,8 +11,14 @@ out vec4 v_color;
#INSERT get_unit_normal.glsl
#INSERT finalize_color.glsl
const float EPSILON = 1e-10;
void main(){
emit_gl_Position(point);
vec3 normal = cross(normalize(du_point - point), normalize(dv_point - point));
v_color = finalize_color(rgba, point, normalize(normal));
vec3 du = (du_point - point);
vec3 dv = (dv_point - point);
vec3 normal = cross(du, dv);
float mag = length(normal);
vec3 unit_normal = (mag < EPSILON) ? vec3(0, 0, sign(point.z)) : normal / mag;
v_color = finalize_color(rgba, point, unit_normal);
}

View file

@ -5,6 +5,7 @@ from colour import hex2rgb
from colour import rgb2hex
import numpy as np
import random
from matplotlib import pyplot
from manimlib.constants import COLORMAP_3B1B
from manimlib.constants import WHITE
@ -14,8 +15,8 @@ from manimlib.utils.iterables import resize_with_interpolation
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Iterable, Sequence
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array
from typing import Iterable, Sequence, Callable
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array, NDArray
def color_to_rgb(color: ManimColor) -> Vect3:
@ -134,6 +135,33 @@ def random_bright_color(
))
def get_colormap_from_colors(colors: Iterable[ManimColor]) -> Callable[[Sequence[float]], Vect4Array]:
"""
Returns a funciton which takes in values between 0 and 1, and returns
a corresponding list of rgba values
"""
rgbas = np.array([color_to_rgba(color) for color in colors])
def func(values):
alphas = np.clip(values, 0, 1)
scaled_alphas = alphas * (len(rgbas) - 1)
indices = scaled_alphas.astype(int)
next_indices = np.clip(indices + 1, 0, len(rgbas) - 1)
inter_alphas = scaled_alphas % 1
inter_alphas = inter_alphas.repeat(4).reshape((len(indices), 4))
result = interpolate(rgbas[indices], rgbas[next_indices], inter_alphas)
return result
return func
def get_color_map(map_name: str) -> Callable[[Sequence[float]], Vect4Array]:
if map_name == "3b1b_colormap":
return get_colormap_from_colors(COLORMAP_3B1B)
return pyplot.get_cmap(map_name)
# Delete this?
def get_colormap_list(
map_name: str = "viridis",
n_colors: int = 9

View file

@ -1,6 +1,8 @@
from __future__ import annotations
import os
from pathlib import Path
import hashlib
import numpy as np
import validators
@ -35,9 +37,11 @@ def find_file(
if validators.url(file_name):
import urllib.request
from manimlib.utils.directories import get_downloads_dir
stem, name = os.path.split(file_name)
suffix = Path(file_name).suffix
file_hash = hashlib.sha256(file_name.encode('utf-8')).hexdigest()[:32]
folder = get_downloads_dir()
path = os.path.join(folder, name)
path = Path(folder, file_hash).with_suffix(suffix)
urllib.request.urlretrieve(file_name, path)
return path