Merge pull request #1973 from 3b1b/video-work

Refactor render out of camera (plus winding fill blending fix)
This commit is contained in:
Grant Sanderson 2023-01-26 16:54:44 -08:00 committed by GitHub
commit adfef48418
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 656 additions and 558 deletions

View file

@ -65,7 +65,6 @@ class Animation(object):
self.rate_func = squish_rate_func( self.rate_func = squish_rate_func(
self.rate_func, start / self.run_time, end / self.run_time, self.rate_func, start / self.run_time, end / self.run_time,
) )
self.mobject.refresh_shader_data()
self.mobject.set_animating_status(True) self.mobject.set_animating_status(True)
self.starting_mobject = self.create_starting_mobject() self.starting_mobject = self.create_starting_mobject()
if self.suspend_mobject_updating: if self.suspend_mobject_updating:

View file

@ -30,15 +30,6 @@ class ShowPartial(Animation, ABC):
self.should_match_start = should_match_start self.should_match_start = should_match_start
super().__init__(mobject, **kwargs) super().__init__(mobject, **kwargs)
def begin(self) -> None:
super().begin()
if not self.should_match_start:
self.mobject.lock_matching_data(self.mobject, self.starting_mobject)
def finish(self) -> None:
super().finish()
self.mobject.unlock_data()
def interpolate_submobject( def interpolate_submobject(
self, self,
submob: VMobject, submob: VMobject,
@ -114,11 +105,9 @@ class DrawBorderThenFill(Animation):
self.outline = self.get_outline() self.outline = self.get_outline()
super().begin() super().begin()
self.mobject.match_style(self.outline) self.mobject.match_style(self.outline)
self.mobject.lock_matching_data(self.mobject, self.outline)
def finish(self) -> None: def finish(self) -> None:
super().finish() super().finish()
self.mobject.unlock_data()
self.mobject.refresh_joint_products() self.mobject.refresh_joint_products()
def get_outline(self) -> VMobject: def get_outline(self) -> VMobject:
@ -146,9 +135,6 @@ class DrawBorderThenFill(Animation):
if index == 1 and self.sm_to_index[hash(submob)] == 0: if index == 1 and self.sm_to_index[hash(submob)] == 0:
# First time crossing over # First time crossing over
submob.set_data(outline.data) submob.set_data(outline.data)
submob.unlock_data()
if not self.mobject.has_updaters:
submob.lock_matching_data(submob, start)
submob.needs_new_triangulation = False submob.needs_new_triangulation = False
self.sm_to_index[hash(submob)] = 1 self.sm_to_index[hash(submob)] = 1

View file

@ -70,6 +70,8 @@ class Transform(Animation):
def finish(self) -> None: def finish(self) -> None:
super().finish() super().finish()
self.mobject.unlock_data() self.mobject.unlock_data()
if self.target_mobject is not None:
self.mobject.become(self.target_mobject)
def create_target(self) -> Mobject: def create_target(self) -> Mobject:
# Has no meaningful effect here, but may be useful # Has no meaningful effect here, but may be useful

View file

@ -1,188 +1,24 @@
from __future__ import annotations from __future__ import annotations
import itertools as it
import math
import moderngl import moderngl
import numpy as np import numpy as np
import OpenGL.GL as gl import OpenGL.GL as gl
from PIL import Image from PIL import Image
from scipy.spatial.transform import Rotation
from manimlib.camera.camera_frame import CameraFrame
from manimlib.constants import BLACK from manimlib.constants import BLACK
from manimlib.constants import DEGREES, RADIANS
from manimlib.constants import DEFAULT_FPS from manimlib.constants import DEFAULT_FPS
from manimlib.constants import DEFAULT_PIXEL_HEIGHT, DEFAULT_PIXEL_WIDTH from manimlib.constants import DEFAULT_PIXEL_HEIGHT, DEFAULT_PIXEL_WIDTH
from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH from manimlib.constants import FRAME_WIDTH
from manimlib.constants import DOWN, LEFT, ORIGIN, OUT, RIGHT, UP
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Point from manimlib.mobject.mobject import Point
from manimlib.utils.color import color_to_rgba from manimlib.utils.color import color_to_rgba
from manimlib.utils.simple_functions import fdiv
from manimlib.utils.space_ops import normalize
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from manimlib.shader_wrapper import ShaderWrapper
from manimlib.typing import ManimColor, Vect3 from manimlib.typing import ManimColor, Vect3
from manimlib.window import Window from manimlib.window import Window
from typing import Any, Iterable
class CameraFrame(Mobject):
def __init__(
self,
frame_shape: tuple[float, float] = (FRAME_WIDTH, FRAME_HEIGHT),
center_point: Vect3 = ORIGIN,
focal_dist_to_height: float = 2.0,
**kwargs,
):
self.frame_shape = frame_shape
self.center_point = center_point
self.focal_dist_to_height = focal_dist_to_height
self.view_matrix = np.identity(4)
super().__init__(**kwargs)
def init_uniforms(self) -> None:
super().init_uniforms()
# As a quaternion
self.uniforms["orientation"] = Rotation.identity().as_quat()
self.uniforms["focal_dist_to_height"] = self.focal_dist_to_height
def init_points(self) -> None:
self.set_points([ORIGIN, LEFT, RIGHT, DOWN, UP])
self.set_width(self.frame_shape[0], stretch=True)
self.set_height(self.frame_shape[1], stretch=True)
self.move_to(self.center_point)
def set_orientation(self, rotation: Rotation):
self.uniforms["orientation"][:] = rotation.as_quat()
return self
def get_orientation(self):
return Rotation.from_quat(self.uniforms["orientation"])
def to_default_state(self):
self.center()
self.set_height(FRAME_HEIGHT)
self.set_width(FRAME_WIDTH)
self.set_orientation(Rotation.identity())
return self
def get_euler_angles(self):
return self.get_orientation().as_euler("zxz")[::-1]
def get_theta(self):
return self.get_euler_angles()[0]
def get_phi(self):
return self.get_euler_angles()[1]
def get_gamma(self):
return self.get_euler_angles()[2]
def get_inverse_camera_rotation_matrix(self):
return self.get_orientation().as_matrix().T
def get_view_matrix(self):
"""
Returns a 4x4 for the affine transformation mapping a point
into the camera's internal coordinate system
"""
result = self.view_matrix
result[:] = np.identity(4)
result[:3, 3] = -self.get_center()
rotation = np.identity(4)
rotation[:3, :3] = self.get_inverse_camera_rotation_matrix()
result[:] = np.dot(rotation, result)
return result
def rotate(self, angle: float, axis: np.ndarray = OUT, **kwargs):
rot = Rotation.from_rotvec(angle * normalize(axis))
self.set_orientation(rot * self.get_orientation())
return self
def set_euler_angles(
self,
theta: float | None = None,
phi: float | None = None,
gamma: float | None = None,
units: float = RADIANS
):
eulers = self.get_euler_angles() # theta, phi, gamma
for i, var in enumerate([theta, phi, gamma]):
if var is not None:
eulers[i] = var * units
self.set_orientation(Rotation.from_euler("zxz", eulers[::-1]))
return self
def reorient(
self,
theta_degrees: float | None = None,
phi_degrees: float | None = None,
gamma_degrees: float | None = None,
):
"""
Shortcut for set_euler_angles, defaulting to taking
in angles in degrees
"""
self.set_euler_angles(theta_degrees, phi_degrees, gamma_degrees, units=DEGREES)
return self
def set_theta(self, theta: float):
return self.set_euler_angles(theta=theta)
def set_phi(self, phi: float):
return self.set_euler_angles(phi=phi)
def set_gamma(self, gamma: float):
return self.set_euler_angles(gamma=gamma)
def increment_theta(self, dtheta: float):
self.rotate(dtheta, OUT)
return self
def increment_phi(self, dphi: float):
self.rotate(dphi, self.get_inverse_camera_rotation_matrix()[0])
return self
def increment_gamma(self, dgamma: float):
self.rotate(dgamma, self.get_inverse_camera_rotation_matrix()[2])
return self
def set_focal_distance(self, focal_distance: float):
self.uniforms["focal_dist_to_height"] = focal_distance / self.get_height()
return self
def set_field_of_view(self, field_of_view: float):
self.uniforms["focal_dist_to_height"] = 2 * math.tan(field_of_view / 2)
return self
def get_shape(self):
return (self.get_width(), self.get_height())
def get_center(self) -> np.ndarray:
# Assumes first point is at the center
return self.get_points()[0]
def get_width(self) -> float:
points = self.get_points()
return points[2, 0] - points[1, 0]
def get_height(self) -> float:
points = self.get_points()
return points[4, 1] - points[3, 1]
def get_focal_distance(self) -> float:
return self.uniforms["focal_dist_to_height"] * self.get_height()
def get_field_of_view(self) -> float:
return 2 * math.atan(self.uniforms["focal_dist_to_height"] / 2)
def get_implied_camera_location(self) -> np.ndarray:
to_camera = self.get_inverse_camera_rotation_matrix()[2]
dist = self.get_focal_distance()
return self.get_center() + dist * to_camera
class Camera(object): class Camera(object):
@ -224,109 +60,43 @@ class Camera(object):
self.background_rgba: list[float] = list(color_to_rgba( self.background_rgba: list[float] = list(color_to_rgba(
background_color, background_opacity background_color, background_opacity
)) ))
self.perspective_uniforms = dict() self.uniforms = dict()
self.init_frame(**frame_config) self.init_frame(**frame_config)
self.init_context(window) self.init_context(window)
self.init_shaders()
self.init_textures()
self.init_light_source() self.init_light_source()
self.refresh_perspective_uniforms()
self.init_fill_fbo(self.ctx) # Experimental
# A cached map from mobjects to their associated list of render groups
# so that these render groups are not regenerated unnecessarily for static
# mobjects
self.mob_to_render_groups = {}
def init_frame(self, **config) -> None: def init_frame(self, **config) -> None:
self.frame = CameraFrame(**config) self.frame = CameraFrame(**config)
def init_context(self, window: Window | None = None) -> None: def init_context(self, window: Window | None = None) -> None:
self.window = window
if window is None: if window is None:
self.ctx = moderngl.create_standalone_context() self.ctx = moderngl.create_standalone_context()
self.fbo = self.get_fbo(self.samples) self.fbo = self.get_fbo(self.samples)
else: else:
self.ctx = window.ctx self.ctx = window.ctx
self.fbo = self.ctx.detect_framebuffer() self.window_fbo = self.ctx.detect_framebuffer()
self.fbo_for_files = self.get_fbo(self.samples)
self.fbo = self.window_fbo
self.fbo.use() self.fbo.use()
self.set_ctx_blending()
self.ctx.enable(moderngl.PROGRAM_POINT_SIZE) self.ctx.enable(moderngl.PROGRAM_POINT_SIZE)
self.ctx.enable(moderngl.BLEND)
# This is the frame buffer we'll draw into when emitting frames # This is the frame buffer we'll draw into when emitting frames
self.draw_fbo = self.get_fbo(samples=0) self.draw_fbo = self.get_fbo(samples=0)
def init_fill_fbo(self, ctx: moderngl.context.Context):
# Experimental
size = self.get_pixel_shape()
self.fill_texture = ctx.texture(
size=size,
components=4,
# Important to make sure floating point (not fixed point) is
# used so that alpha values are not clipped
dtype='f2',
)
# TODO, depth buffer is not really used yet
fill_depth = ctx.depth_renderbuffer(size)
self.fill_fbo = ctx.framebuffer(self.fill_texture, fill_depth)
self.fill_prog = ctx.program(
vertex_shader='''
#version 330
in vec2 texcoord;
out vec2 v_textcoord;
void main() {
gl_Position = vec4((2.0 * texcoord - 1.0), 0.0, 1.0);
v_textcoord = texcoord;
}
''',
fragment_shader='''
#version 330
uniform sampler2D Texture;
in vec2 v_textcoord;
out vec4 frag_color;
void main() {
frag_color = texture(Texture, v_textcoord);
frag_color = abs(frag_color);
if(frag_color.a == 0) discard;
//TODO, set gl_FragDepth;
}
''',
)
tid = self.n_textures
self.fill_texture.use(tid)
self.fill_prog['Texture'].value = tid
self.n_textures += 1
verts = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
self.fill_texture_vao = ctx.simple_vertex_array(
self.fill_prog,
ctx.buffer(verts.astype('f4').tobytes()),
'texcoord',
)
def set_ctx_blending(self, enable: bool = True) -> None:
if enable:
self.ctx.enable(moderngl.BLEND)
else:
self.ctx.disable(moderngl.BLEND)
def set_ctx_depth_test(self, enable: bool = True) -> None:
if enable:
self.ctx.enable(moderngl.DEPTH_TEST)
else:
self.ctx.disable(moderngl.DEPTH_TEST)
def set_ctx_clip_plane(self, enable: bool = True) -> None:
if enable:
gl.glEnable(gl.GL_CLIP_DISTANCE0)
def init_light_source(self) -> None: def init_light_source(self) -> None:
self.light_source = Point(self.light_source_position) self.light_source = Point(self.light_source_position)
def use_window_fbo(self, use: bool = True):
assert(self.window is not None)
if use:
self.fbo = self.window_fbo
else:
self.fbo = self.fbo_for_files
# Methods associated with the frame buffer # Methods associated with the frame buffer
def get_fbo( def get_fbo(
self, self,
@ -351,9 +121,6 @@ class Camera(object):
# Copy blocks from fbo into draw_fbo using Blit # Copy blocks from fbo into draw_fbo using Blit
gl.glBindFramebuffer(gl.GL_READ_FRAMEBUFFER, self.fbo.glo) gl.glBindFramebuffer(gl.GL_READ_FRAMEBUFFER, self.fbo.glo)
gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, self.draw_fbo.glo) gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, self.draw_fbo.glo)
if self.window is not None:
src_viewport = self.window.viewport
else:
src_viewport = self.fbo.viewport src_viewport = self.fbo.viewport
gl.glBlitFramebuffer( gl.glBlitFramebuffer(
*src_viewport, *src_viewport,
@ -443,153 +210,18 @@ class Camera(object):
# Rendering # Rendering
def capture(self, *mobjects: Mobject) -> None: def capture(self, *mobjects: Mobject) -> None:
self.refresh_perspective_uniforms() self.refresh_uniforms()
for mobject in mobjects:
for render_group in self.get_render_group_list(mobject):
self.render(render_group)
def render(self, render_group: dict[str, Any]) -> None:
shader_wrapper = render_group["shader_wrapper"]
shader_program = render_group["prog"]
primitive = int(shader_wrapper.render_primitive)
self.set_shader_uniforms(shader_program, shader_wrapper)
self.set_ctx_depth_test(shader_wrapper.depth_test)
self.set_ctx_clip_plane(shader_wrapper.use_clip_plane)
if shader_wrapper.is_fill:
self.render_fill(render_group["vao"], primitive, shader_wrapper.vert_indices)
else:
render_group["vao"].render(primitive)
if render_group["single_use"]:
self.release_render_group(render_group)
def render_fill(self, vao, render_primitive: int, indices: np.ndarray):
"""
VMobject fill is handled in a special way, where emited triangles
must be blended with moderngl.FUNC_SUBTRACT so as to effectively compute
a winding number around each pixel. This is rendered to a separate texture,
then that texture is overlayed onto the current fbo
"""
winding = (len(indices) == 0)
vao.program['winding'].value = winding
if not winding:
vao.render(moderngl.TRIANGLES)
return
self.fill_fbo.clear()
self.fill_fbo.use()
self.ctx.blend_func = (moderngl.ONE, moderngl.ONE)
vao.render(render_primitive)
self.ctx.blend_func = moderngl.DEFAULT_BLENDING
self.fbo.use() self.fbo.use()
self.fill_texture_vao.render(moderngl.TRIANGLE_STRIP) for mobject in mobjects:
mobject.render(self.ctx, self.uniforms)
def get_render_group_list(self, mobject: Mobject) -> Iterable[dict[str, Any]]: def refresh_uniforms(self) -> None:
if mobject.is_changing():
return self.generate_render_group_list(mobject)
# Otherwise, cache result for later use
key = id(mobject)
if key not in self.mob_to_render_groups:
self.mob_to_render_groups[key] = list(self.generate_render_group_list(mobject))
return self.mob_to_render_groups[key]
def generate_render_group_list(self, mobject: Mobject) -> Iterable[dict[str, Any]]:
return (
self.get_render_group(sw, single_use=mobject.is_changing())
for sw in mobject.get_shader_wrapper_list()
)
def get_render_group(
self,
shader_wrapper: ShaderWrapper,
single_use: bool = True
) -> dict[str, Any]:
# Data buffer
vert_data = shader_wrapper.vert_data
indices = shader_wrapper.vert_indices
if len(indices) == 0:
ibo = None
elif single_use:
ibo = self.ctx.buffer(indices.astype(np.uint32))
else:
ibo = self.ctx.buffer(indices.astype(np.uint32))
# # The vao.render call is strangely longer
# # when an index buffer is used, so if the
# # mobject is not changing, meaning only its
# # uniforms are being updated, just create
# # a larger data array based on the indices
# # and don't bother with the ibo
# vert_data = vert_data[indices]
# ibo = None
vbo = self.ctx.buffer(vert_data)
# Program and vertex array
shader_program, vert_format = self.get_shader_program(shader_wrapper)
attributes = shader_wrapper.vert_attributes
vao = self.ctx.vertex_array(
program=shader_program,
content=[(vbo, vert_format, *attributes)],
index_buffer=ibo,
)
return {
"vbo": vbo,
"ibo": ibo,
"vao": vao,
"prog": shader_program,
"shader_wrapper": shader_wrapper,
"single_use": single_use,
}
def release_render_group(self, render_group: dict[str, Any]) -> None:
for key in ["vbo", "ibo", "vao"]:
if render_group[key] is not None:
render_group[key].release()
def refresh_static_mobjects(self) -> None:
for render_group in it.chain(*self.mob_to_render_groups.values()):
self.release_render_group(render_group)
self.mob_to_render_groups = {}
# Shaders
def init_shaders(self) -> None:
# Initialize with the null id going to None
self.id_to_shader_program: dict[int, tuple[moderngl.Program, str] | None] = {hash(""): None}
def get_shader_program(
self,
shader_wrapper: ShaderWrapper
) -> tuple[moderngl.Program, str] | None:
sid = shader_wrapper.get_program_id()
if sid not in self.id_to_shader_program:
# Create shader program for the first time, then cache
# in the id_to_shader_program dictionary
program = self.ctx.program(**shader_wrapper.get_program_code())
vert_format = moderngl.detect_format(program, shader_wrapper.vert_attributes)
self.id_to_shader_program[sid] = (program, vert_format)
return self.id_to_shader_program[sid]
def set_shader_uniforms(
self,
shader: moderngl.Program,
shader_wrapper: ShaderWrapper
) -> None:
for name, path in shader_wrapper.texture_paths.items():
tid = self.get_texture_id(path)
shader[name].value = tid
for name, value in it.chain(self.perspective_uniforms.items(), shader_wrapper.uniforms.items()):
if name in shader:
if isinstance(value, np.ndarray) and value.ndim > 0:
value = tuple(value)
shader[name].value = value
def refresh_perspective_uniforms(self) -> None:
frame = self.frame frame = self.frame
view_matrix = frame.get_view_matrix() view_matrix = frame.get_view_matrix()
light_pos = self.light_source.get_location() light_pos = self.light_source.get_location()
cam_pos = self.frame.get_implied_camera_location() cam_pos = self.frame.get_implied_camera_location()
self.perspective_uniforms.update( self.uniforms.update(
frame_shape=frame.get_shape(), frame_shape=frame.get_shape(),
pixel_size=self.get_pixel_size(), pixel_size=self.get_pixel_size(),
view=tuple(view_matrix.T.flatten()), view=tuple(view_matrix.T.flatten()),
@ -598,32 +230,6 @@ class Camera(object):
focal_distance=frame.get_focal_distance(), focal_distance=frame.get_focal_distance(),
) )
def init_textures(self) -> None:
self.n_textures: int = 0
self.path_to_texture: dict[
str, tuple[int, moderngl.Texture]
] = {}
def get_texture_id(self, path: str) -> int:
if path not in self.path_to_texture:
tid = self.n_textures
self.n_textures += 1
im = Image.open(path).convert("RGBA")
texture = self.ctx.texture(
size=im.size,
components=len(im.getbands()),
data=im.tobytes(),
)
texture.use(location=tid)
self.path_to_texture[path] = (tid, texture)
return self.path_to_texture[path][0]
def release_texture(self, path: str):
tid_and_texture = self.path_to_texture.pop(path, None)
if tid_and_texture:
tid_and_texture[1].release()
return self
# Mostly just defined so old scenes don't break # Mostly just defined so old scenes don't break
class ThreeDCamera(Camera): class ThreeDCamera(Camera):

View file

@ -0,0 +1,173 @@
from __future__ import annotations
import math
import numpy as np
from scipy.spatial.transform import Rotation
from manimlib.constants import DEGREES, RADIANS
from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH
from manimlib.constants import DOWN, LEFT, ORIGIN, OUT, RIGHT, UP
from manimlib.mobject.mobject import Mobject
from manimlib.utils.space_ops import normalize
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.typing import Vect3
class CameraFrame(Mobject):
def __init__(
self,
frame_shape: tuple[float, float] = (FRAME_WIDTH, FRAME_HEIGHT),
center_point: Vect3 = ORIGIN,
focal_dist_to_height: float = 2.0,
**kwargs,
):
self.frame_shape = frame_shape
self.center_point = center_point
self.focal_dist_to_height = focal_dist_to_height
self.view_matrix = np.identity(4)
super().__init__(**kwargs)
def init_uniforms(self) -> None:
super().init_uniforms()
# As a quaternion
self.uniforms["orientation"] = Rotation.identity().as_quat()
self.uniforms["focal_dist_to_height"] = self.focal_dist_to_height
def init_points(self) -> None:
self.set_points([ORIGIN, LEFT, RIGHT, DOWN, UP])
self.set_width(self.frame_shape[0], stretch=True)
self.set_height(self.frame_shape[1], stretch=True)
self.move_to(self.center_point)
def set_orientation(self, rotation: Rotation):
self.uniforms["orientation"][:] = rotation.as_quat()
return self
def get_orientation(self):
return Rotation.from_quat(self.uniforms["orientation"])
def to_default_state(self):
self.center()
self.set_height(FRAME_HEIGHT)
self.set_width(FRAME_WIDTH)
self.set_orientation(Rotation.identity())
return self
def get_euler_angles(self):
return self.get_orientation().as_euler("zxz")[::-1]
def get_theta(self):
return self.get_euler_angles()[0]
def get_phi(self):
return self.get_euler_angles()[1]
def get_gamma(self):
return self.get_euler_angles()[2]
def get_inverse_camera_rotation_matrix(self):
return self.get_orientation().as_matrix().T
def get_view_matrix(self):
"""
Returns a 4x4 for the affine transformation mapping a point
into the camera's internal coordinate system
"""
result = self.view_matrix
result[:] = np.identity(4)
result[:3, 3] = -self.get_center()
rotation = np.identity(4)
rotation[:3, :3] = self.get_inverse_camera_rotation_matrix()
result[:] = np.dot(rotation, result)
return result
def rotate(self, angle: float, axis: np.ndarray = OUT, **kwargs):
rot = Rotation.from_rotvec(angle * normalize(axis))
self.set_orientation(rot * self.get_orientation())
return self
def set_euler_angles(
self,
theta: float | None = None,
phi: float | None = None,
gamma: float | None = None,
units: float = RADIANS
):
eulers = self.get_euler_angles() # theta, phi, gamma
for i, var in enumerate([theta, phi, gamma]):
if var is not None:
eulers[i] = var * units
self.set_orientation(Rotation.from_euler("zxz", eulers[::-1]))
return self
def reorient(
self,
theta_degrees: float | None = None,
phi_degrees: float | None = None,
gamma_degrees: float | None = None,
):
"""
Shortcut for set_euler_angles, defaulting to taking
in angles in degrees
"""
self.set_euler_angles(theta_degrees, phi_degrees, gamma_degrees, units=DEGREES)
return self
def set_theta(self, theta: float):
return self.set_euler_angles(theta=theta)
def set_phi(self, phi: float):
return self.set_euler_angles(phi=phi)
def set_gamma(self, gamma: float):
return self.set_euler_angles(gamma=gamma)
def increment_theta(self, dtheta: float):
self.rotate(dtheta, OUT)
return self
def increment_phi(self, dphi: float):
self.rotate(dphi, self.get_inverse_camera_rotation_matrix()[0])
return self
def increment_gamma(self, dgamma: float):
self.rotate(dgamma, self.get_inverse_camera_rotation_matrix()[2])
return self
def set_focal_distance(self, focal_distance: float):
self.uniforms["focal_dist_to_height"] = focal_distance / self.get_height()
return self
def set_field_of_view(self, field_of_view: float):
self.uniforms["focal_dist_to_height"] = 2 * math.tan(field_of_view / 2)
return self
def get_shape(self):
return (self.get_width(), self.get_height())
def get_center(self) -> np.ndarray:
# Assumes first point is at the center
return self.get_points()[0]
def get_width(self) -> float:
points = self.get_points()
return points[2, 0] - points[1, 0]
def get_height(self) -> float:
points = self.get_points()
return points[4, 1] - points[3, 1]
def get_focal_distance(self) -> float:
return self.uniforms["focal_dist_to_height"] * self.get_height()
def get_field_of_view(self) -> float:
return 2 * math.atan(self.uniforms["focal_dist_to_height"] / 2)
def get_implied_camera_location(self) -> np.ndarray:
to_camera = self.get_inverse_camera_rotation_matrix()[2]
dist = self.get_focal_distance()
return self.get_center() + dist * to_camera

View file

@ -414,10 +414,7 @@ def get_window_config(args: Namespace, custom_config: dict, camera_config: dict)
if not (args.full_screen or custom_config["full_screen"]): if not (args.full_screen or custom_config["full_screen"]):
window_width //= 2 window_width //= 2
window_height = int(window_width / aspect_ratio) window_height = int(window_width / aspect_ratio)
return dict( return dict(size=(window_width, window_height))
full_size=(camera_config["pixel_width"], camera_config["pixel_height"]),
size=(window_width, window_height),
)
def get_camera_config(args: Namespace, custom_config: dict) -> dict: def get_camera_config(args: Namespace, custom_config: dict) -> dict:

View file

@ -51,6 +51,7 @@ if TYPE_CHECKING:
from typing import Callable, Iterable, Union, Tuple from typing import Callable, Iterable, Union, Tuple
import numpy.typing as npt import numpy.typing as npt
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array
from moderngl.context import Context
TimeBasedUpdater = Callable[["Mobject", float], "Mobject" | None] TimeBasedUpdater = Callable[["Mobject", float], "Mobject" | None]
NonTimeUpdater = Callable[["Mobject"], "Mobject" | None] NonTimeUpdater = Callable[["Mobject"], "Mobject" | None]
@ -101,6 +102,8 @@ class Mobject(object):
self.saved_state = None self.saved_state = None
self.target = None self.target = None
self.bounding_box: Vect3Array = np.zeros((3, 3)) self.bounding_box: Vect3Array = np.zeros((3, 3))
self._shaders_initialized: bool = False
self._data_has_changed: bool = True
self.init_data() self.init_data()
self._data_defaults = np.ones(1, dtype=self.data.dtype) self._data_defaults = np.ones(1, dtype=self.data.dtype)
@ -109,7 +112,6 @@ class Mobject(object):
self.init_event_listners() self.init_event_listners()
self.init_points() self.init_points()
self.init_colors() self.init_colors()
self.init_shader_data()
if self.depth_test: if self.depth_test:
self.apply_depth_test() self.apply_depth_test()
@ -141,11 +143,6 @@ class Mobject(object):
# Typically implemented in subclass, unlpess purposefully left blank # Typically implemented in subclass, unlpess purposefully left blank
pass pass
def set_data(self, data: np.ndarray):
assert(data.dtype == self.data.dtype)
self.data = data
return self
def set_uniforms(self, uniforms: dict): def set_uniforms(self, uniforms: dict):
for key, value in uniforms.items(): for key, value in uniforms.items():
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
@ -158,8 +155,36 @@ class Mobject(object):
# Borrowed from https://github.com/ManimCommunity/manim/ # Borrowed from https://github.com/ManimCommunity/manim/
return _AnimationBuilder(self) return _AnimationBuilder(self)
# Only these methods should directly affect points def note_changed_data(self, recurse_up: bool = True):
self._data_has_changed = True
if recurse_up:
for mob in self.parents:
mob.note_changed_data()
def affects_data(func: Callable):
@wraps(func)
def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs)
self.note_changed_data()
return wrapper
def affects_family_data(func: Callable):
@wraps(func)
def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs)
for mob in self.family_members_with_points():
mob.note_changed_data()
return self
return wrapper
# Only these methods should directly affect points
@affects_data
def set_data(self, data: np.ndarray):
assert(data.dtype == self.data.dtype)
self.data = data.copy()
return self
@affects_data
def resize_points( def resize_points(
self, self,
new_length: int, new_length: int,
@ -175,11 +200,13 @@ class Mobject(object):
self.refresh_bounding_box() self.refresh_bounding_box()
return self return self
@affects_data
def set_points(self, points: Vect3Array): def set_points(self, points: Vect3Array):
self.resize_points(len(points), resize_func=resize_preserving_order) self.resize_points(len(points), resize_func=resize_preserving_order)
self.data["point"][:] = points self.data["point"][:] = points
return self return self
@affects_data
def append_points(self, new_points: Vect3Array): def append_points(self, new_points: Vect3Array):
n = self.get_num_points() n = self.get_num_points()
self.resize_points(n + len(new_points)) self.resize_points(n + len(new_points))
@ -190,11 +217,13 @@ class Mobject(object):
self.refresh_bounding_box() self.refresh_bounding_box()
return self return self
@affects_family_data
def reverse_points(self): def reverse_points(self):
for mob in self.get_family(): for mob in self.get_family():
mob.data = mob.data[::-1] mob.data = mob.data[::-1]
return self return self
@affects_family_data
def apply_points_function( def apply_points_function(
self, self,
func: Callable[[np.ndarray], np.ndarray], func: Callable[[np.ndarray], np.ndarray],
@ -328,6 +357,7 @@ class Mobject(object):
def split(self) -> list[Mobject]: def split(self) -> list[Mobject]:
return self.submobjects return self.submobjects
@affects_data
def assemble_family(self): def assemble_family(self):
sub_families = (sm.get_family() for sm in self.submobjects) sub_families = (sm.get_family() for sm in self.submobjects)
self.family = [self, *it.chain(*sub_families)] self.family = [self, *it.chain(*sub_families)]
@ -557,11 +587,9 @@ class Mobject(object):
return self return self
def deepcopy(self): def deepcopy(self):
try: result = copy.deepcopy(self)
# Often faster than deepcopy result._shaders_initialized = False
return pickle.loads(pickle.dumps(self)) result._data_has_changed = True
except AttributeError:
return copy.deepcopy(self)
@stash_mobject_pointers @stash_mobject_pointers
def copy(self, deep: bool = False): def copy(self, deep: bool = False):
@ -591,6 +619,7 @@ class Mobject(object):
# won't have changed, just directly match. # won't have changed, just directly match.
result.non_time_updaters = list(self.non_time_updaters) result.non_time_updaters = list(self.non_time_updaters)
result.time_based_updaters = list(self.time_based_updaters) result.time_based_updaters = list(self.time_based_updaters)
result._data_has_changed = True
family = self.get_family() family = self.get_family()
for attr, value in list(self.__dict__.items()): for attr, value in list(self.__dict__.items()):
@ -654,7 +683,6 @@ class Mobject(object):
for attr, value in list(mobject.__dict__.items()): for attr, value in list(mobject.__dict__.items()):
if isinstance(value, Mobject) and value in family2: if isinstance(value, Mobject) and value in family2:
setattr(self, attr, family1[family2.index(value)]) setattr(self, attr, family1[family2.index(value)])
self.refresh_bounding_box(recurse_down=True)
if match_updaters: if match_updaters:
self.match_updaters(mobject) self.match_updaters(mobject)
return self return self
@ -1214,6 +1242,7 @@ class Mobject(object):
# Color functions # Color functions
@affects_family_data
def set_rgba_array( def set_rgba_array(
self, self,
rgba_array: npt.ArrayLike, rgba_array: npt.ArrayLike,
@ -1252,6 +1281,7 @@ class Mobject(object):
mob.set_rgba_array(rgba_array) mob.set_rgba_array(rgba_array)
return self return self
@affects_family_data
def set_rgba_array_by_color( def set_rgba_array_by_color(
self, self,
color: ManimColor | Iterable[ManimColor] | None = None, color: ManimColor | Iterable[ManimColor] | None = None,
@ -1618,9 +1648,6 @@ class Mobject(object):
def align_data(self, mobject: Mobject) -> None: def align_data(self, mobject: Mobject) -> None:
for mob1, mob2 in zip(self.get_family(), mobject.get_family()): for mob1, mob2 in zip(self.get_family(), mobject.get_family()):
# In case any data arrays get resized when aligned to shader data
mob1.refresh_shader_data()
mob2.refresh_shader_data()
mob1.align_points(mob2) mob1.align_points(mob2)
def align_points(self, mobject: Mobject): def align_points(self, mobject: Mobject):
@ -1690,6 +1717,8 @@ class Mobject(object):
path_func: Callable[[np.ndarray, np.ndarray, float], np.ndarray] = straight_path path_func: Callable[[np.ndarray, np.ndarray, float], np.ndarray] = straight_path
): ):
keys = [k for k in self.data.dtype.names if k not in self.locked_data_keys] keys = [k for k in self.data.dtype.names if k not in self.locked_data_keys]
if keys:
self.note_changed_data()
for key in keys: for key in keys:
func = path_func if key in self.pointlike_data_keys else interpolate func = path_func if key in self.pointlike_data_keys else interpolate
md1 = mobject1.data[key] md1 = mobject1.data[key]
@ -1700,6 +1729,8 @@ class Mobject(object):
self.data[key] = func(md1, md2, alpha) self.data[key] = func(md1, md2, alpha)
for key in self.uniforms: for key in self.uniforms:
if key not in mobject1.uniforms or key not in mobject2.uniforms:
continue
self.uniforms[key] = interpolate( self.uniforms[key] = interpolate(
mobject1.uniforms[key], mobject1.uniforms[key],
mobject2.uniforms[key], mobject2.uniforms[key],
@ -1731,8 +1762,6 @@ class Mobject(object):
""" """
if self.has_updaters: if self.has_updaters:
return return
# Be sure shader data has most up to date information
self.refresh_shader_data()
self.locked_data_keys = set(keys) self.locked_data_keys = set(keys)
def lock_matching_data(self, mobject1: Mobject, mobject2: Mobject): def lock_matching_data(self, mobject1: Mobject, mobject2: Mobject):
@ -1842,10 +1871,10 @@ class Mobject(object):
# For shader data # For shader data
def init_shader_data(self): def init_shader_data(self, ctx: Context):
# TODO, only call this when needed?
self.shader_indices = np.zeros(0) self.shader_indices = np.zeros(0)
self.shader_wrapper = ShaderWrapper( self.shader_wrapper = ShaderWrapper(
ctx=ctx,
vert_data=self.data, vert_data=self.data,
shader_folder=self.shader_folder, shader_folder=self.shader_folder,
texture_paths=self.texture_paths, texture_paths=self.texture_paths,
@ -1854,20 +1883,25 @@ class Mobject(object):
) )
def refresh_shader_wrapper_id(self): def refresh_shader_wrapper_id(self):
if self._shaders_initialized:
self.shader_wrapper.refresh_id() self.shader_wrapper.refresh_id()
return self return self
def get_shader_wrapper(self) -> ShaderWrapper: def get_shader_wrapper(self, ctx: Context) -> ShaderWrapper:
if not self._shaders_initialized:
self.init_shader_data(ctx)
self._shaders_initialized = True
self.shader_wrapper.vert_data = self.get_shader_data() self.shader_wrapper.vert_data = self.get_shader_data()
self.shader_wrapper.vert_indices = self.get_shader_vert_indices() self.shader_wrapper.vert_indices = self.get_shader_vert_indices()
self.shader_wrapper.uniforms = self.get_uniforms() self.shader_wrapper.uniforms.update(self.get_uniforms())
self.shader_wrapper.depth_test = self.depth_test self.shader_wrapper.depth_test = self.depth_test
return self.shader_wrapper return self.shader_wrapper
def get_shader_wrapper_list(self) -> list[ShaderWrapper]: def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
shader_wrappers = it.chain( shader_wrappers = it.chain(
[self.get_shader_wrapper()], [self.get_shader_wrapper(ctx)],
*[sm.get_shader_wrapper_list() for sm in self.submobjects] *[sm.get_shader_wrapper_list(ctx) for sm in self.submobjects]
) )
batches = batch_by_property(shader_wrappers, lambda sw: sw.get_id()) batches = batch_by_property(shader_wrappers, lambda sw: sw.get_id())
@ -1884,15 +1918,24 @@ class Mobject(object):
def get_shader_data(self): def get_shader_data(self):
return self.data return self.data
def refresh_shader_data(self):
pass
def get_uniforms(self): def get_uniforms(self):
return self.uniforms return self.uniforms
def get_shader_vert_indices(self): def get_shader_vert_indices(self):
return self.shader_indices return self.shader_indices
def render(self, ctx: Context, camera_uniforms: dict):
if self._data_has_changed:
self.shader_wrappers = self.get_shader_wrapper_list(ctx)
for shader_wrapper in self.shader_wrappers:
shader_wrapper.generate_vao()
self._data_has_changed = False
for shader_wrapper in self.shader_wrappers:
shader_wrapper.uniforms.update(self.get_uniforms())
shader_wrapper.uniforms.update(camera_uniforms)
shader_wrapper.pre_render()
shader_wrapper.render()
# Event Handlers # Event Handlers
""" """
Event handling follows the Event Bubbling model of DOM in javascript. Event handling follows the Event Bubbling model of DOM in javascript.
@ -2005,6 +2048,8 @@ class Group(Mobject):
raise Exception("All submobjects must be of type Mobject") raise Exception("All submobjects must be of type Mobject")
Mobject.__init__(self, **kwargs) Mobject.__init__(self, **kwargs)
self.add(*mobjects) self.add(*mobjects)
if any(m.is_fixed_in_frame for m in mobjects):
self.fix_in_frame()
def __add__(self, other: Mobject | Group): def __add__(self, other: Mobject | Group):
assert(isinstance(other, Mobject)) assert(isinstance(other, Mobject))

View file

@ -320,8 +320,9 @@ class VMobjectFromSVGPath(VMobject):
self.set_points(self.get_points_without_null_curves()) self.set_points(self.get_points_without_null_curves())
# So triangulation doesn't get messed up # So triangulation doesn't get messed up
self.subdivide_intersections() self.subdivide_intersections()
# Always default to orienting outward # Always default to orienting outward, account
if self.get_unit_normal()[2] < 0: # for the fact that this will get flipped in SVG.__init__
if self.get_unit_normal()[2] > 0:
self.reverse_points() self.reverse_points()
# Save for future use # Save for future use
PATH_TO_POINTS[path_string] = self.get_points().copy() PATH_TO_POINTS[path_string] = self.get_points().copy()

View file

@ -5,6 +5,7 @@ import numpy as np
from manimlib.constants import GREY_C, YELLOW from manimlib.constants import GREY_C, YELLOW
from manimlib.constants import ORIGIN, NULL_POINTS from manimlib.constants import ORIGIN, NULL_POINTS
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.types.point_cloud_mobject import PMobject from manimlib.mobject.types.point_cloud_mobject import PMobject
from manimlib.utils.iterables import resize_with_interpolation from manimlib.utils.iterables import resize_with_interpolation
@ -94,6 +95,7 @@ class DotCloud(PMobject):
self.center() self.center()
return self return self
@Mobject.affects_data
def set_radii(self, radii: npt.ArrayLike): def set_radii(self, radii: npt.ArrayLike):
n_points = self.get_num_points() n_points = self.get_num_points()
radii = np.array(radii).reshape((len(radii), 1)) radii = np.array(radii).reshape((len(radii), 1))
@ -104,6 +106,7 @@ class DotCloud(PMobject):
def get_radii(self) -> np.ndarray: def get_radii(self) -> np.ndarray:
return self.data["radius"] return self.data["radius"]
@Mobject.affects_data
def set_radius(self, radius: float): def set_radius(self, radius: float):
data = self.data if self.get_num_points() > 0 else self._data_defaults data = self.data if self.get_num_points() > 0 else self._data_defaults
data["radius"][:] = radius data["radius"][:] = radius

View file

@ -47,6 +47,7 @@ class ImageMobject(Mobject):
self.set_width(2 * size[0] / size[1], stretch=True) self.set_width(2 * size[0] / size[1], stretch=True)
self.set_height(self.height) self.set_height(self.height)
@Mobject.affects_data
def set_opacity(self, opacity: float, recurse: bool = True): def set_opacity(self, opacity: float, recurse: bool = True):
self.data["opacity"][:, 0] = resize_with_interpolation( self.data["opacity"][:, 0] = resize_with_interpolation(
np.array(listify(opacity)), np.array(listify(opacity)),

View file

@ -5,7 +5,6 @@ import numpy as np
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
from manimlib.utils.color import color_gradient from manimlib.utils.color import color_gradient
from manimlib.utils.color import color_to_rgba from manimlib.utils.color import color_to_rgba
from manimlib.utils.iterables import resize_array
from manimlib.utils.iterables import resize_with_interpolation from manimlib.utils.iterables import resize_with_interpolation
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -52,6 +51,7 @@ class PMobject(Mobject):
self.add_points([point], rgbas, color, opacity) self.add_points([point], rgbas, color, opacity)
return self return self
@Mobject.affects_data
def set_color_by_gradient(self, *colors: ManimColor): def set_color_by_gradient(self, *colors: ManimColor):
self.data["rgba"][:] = np.array(list(map( self.data["rgba"][:] = np.array(list(map(
color_to_rgba, color_to_rgba,
@ -59,17 +59,20 @@ class PMobject(Mobject):
))) )))
return self return self
@Mobject.affects_data
def match_colors(self, pmobject: PMobject): def match_colors(self, pmobject: PMobject):
self.data["rgba"][:] = resize_with_interpolation( self.data["rgba"][:] = resize_with_interpolation(
pmobject.data["rgba"], self.get_num_points() pmobject.data["rgba"], self.get_num_points()
) )
return self return self
@Mobject.affects_data
def filter_out(self, condition: Callable[[np.ndarray], bool]): def filter_out(self, condition: Callable[[np.ndarray], bool]):
for mob in self.family_members_with_points(): for mob in self.family_members_with_points():
mob.data = mob.data[~np.apply_along_axis(condition, 1, mob.get_points())] mob.data = mob.data[~np.apply_along_axis(condition, 1, mob.get_points())]
return self return self
@Mobject.affects_data
def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]): def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]):
""" """
function is any map from R^3 to R function is any map from R^3 to R
@ -81,6 +84,7 @@ class PMobject(Mobject):
mob.data[:] = mob.data[indices] mob.data[:] = mob.data[indices]
return self return self
@Mobject.affects_data
def ingest_submobjects(self): def ingest_submobjects(self):
self.data = np.vstack([ self.data = np.vstack([
sm.data for sm in self.get_family() sm.data for sm in self.get_family()
@ -91,6 +95,7 @@ class PMobject(Mobject):
index = alpha * (self.get_num_points() - 1) index = alpha * (self.get_num_points() - 1)
return self.get_points()[int(index)] return self.get_points()[int(index)]
@Mobject.affects_data
def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float): def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float):
lower_index = int(a * pmobject.get_num_points()) lower_index = int(a * pmobject.get_num_points())
upper_index = int(b * pmobject.get_num_points()) upper_index = int(b * pmobject.get_num_points())

View file

@ -72,6 +72,7 @@ class Surface(Mobject):
# To be implemented in subclasses # To be implemented in subclasses
return (u, v, 0.0) return (u, v, 0.0)
@Mobject.affects_data
def init_points(self): def init_points(self):
dim = self.dim dim = self.dim
nu, nv = self.resolution nu, nv = self.resolution
@ -130,6 +131,7 @@ class Surface(Mobject):
) )
return normalize_along_axis(normals, 1) return normalize_along_axis(normals, 1)
@Mobject.affects_data
def pointwise_become_partial( def pointwise_become_partial(
self, self,
smobject: "Surface", smobject: "Surface",
@ -218,12 +220,10 @@ class Surface(Mobject):
self.uniforms["clip_plane"][:3] = vect self.uniforms["clip_plane"][:3] = vect
if threshold is not None: if threshold is not None:
self.uniforms["clip_plane"][3] = threshold self.uniforms["clip_plane"][3] = threshold
self.shader_wrapper.use_clip_plane = True
return self return self
def deactivate_clip_plane(self): def deactivate_clip_plane(self):
self.uniforms["clip_plane"][:] = 0 self.uniforms["clip_plane"][:] = 0
self.shader_wrapper.use_clip_plane = False
return self return self
def get_shader_vert_indices(self) -> np.ndarray: def get_shader_vert_indices(self) -> np.ndarray:
@ -300,6 +300,7 @@ class TexturedSurface(Surface):
**kwargs **kwargs
) )
@Mobject.affects_data
def init_points(self): def init_points(self):
surf = self.uv_surface surf = self.uv_surface
nu, nv = surf.resolution nu, nv = surf.resolution
@ -317,6 +318,7 @@ class TexturedSurface(Surface):
super().init_uniforms() super().init_uniforms()
self.uniforms["num_textures"] = self.num_textures self.uniforms["num_textures"] = self.num_textures
@Mobject.affects_data
def set_opacity(self, opacity: float | Iterable[float]): def set_opacity(self, opacity: float | Iterable[float]):
op_arr = np.array(listify(opacity)) op_arr = np.array(listify(opacity))
self.data["opacity"][:, 0] = resize_with_interpolation(op_arr, len(self.data)) self.data["opacity"][:, 0] = resize_with_interpolation(op_arr, len(self.data))

View file

@ -40,12 +40,14 @@ from manimlib.utils.space_ops import midpoint
from manimlib.utils.space_ops import normalize_along_axis from manimlib.utils.space_ops import normalize_along_axis
from manimlib.utils.space_ops import z_to_vector from manimlib.utils.space_ops import z_to_vector
from manimlib.shader_wrapper import ShaderWrapper from manimlib.shader_wrapper import ShaderWrapper
from manimlib.shader_wrapper import FillShaderWrapper
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, Tuple from typing import Callable, Iterable, Tuple
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array
from moderngl.context import Context
DEFAULT_STROKE_COLOR = GREY_A DEFAULT_STROKE_COLOR = GREY_A
DEFAULT_FILL_COLOR = GREY_C DEFAULT_FILL_COLOR = GREY_C
@ -231,6 +233,7 @@ class VMobject(Mobject):
self.set_stroke(color, width, background=background) self.set_stroke(color, width, background=background)
return self return self
@Mobject.affects_family_data
def set_style( def set_style(
self, self,
fill_color: ManimColor | Iterable[ManimColor] | None = None, fill_color: ManimColor | Iterable[ManimColor] | None = None,
@ -414,6 +417,7 @@ class VMobject(Mobject):
def get_joint_type(self) -> float: def get_joint_type(self) -> float:
return self.uniforms["joint_type"] return self.uniforms["joint_type"]
@Mobject.affects_family_data
def use_winding_fill(self, value: bool = True, recurse: bool = True): def use_winding_fill(self, value: bool = True, recurse: bool = True):
for submob in self.get_family(recurse): for submob in self.get_family(recurse):
submob._use_winding_fill = value submob._use_winding_fill = value
@ -654,7 +658,7 @@ class VMobject(Mobject):
return self return self
def add_subpath(self, points: Vect3Array): def add_subpath(self, points: Vect3Array):
assert(len(points) % 2 == 1) assert(len(points) % 2 == 1 or len(points) == 0)
if not self.has_points(): if not self.has_points():
self.set_points(points) self.set_points(points)
return self return self
@ -832,7 +836,7 @@ class VMobject(Mobject):
# If both have fill, and they have the same shape, just # If both have fill, and they have the same shape, just
# give them the same triangulation so that it's not recalculated # give them the same triangulation so that it's not recalculated
# needlessly throughout an animation # needlessly throughout an animation
if self._use_winding_fill and self.has_fill() \ if not self._use_winding_fill and self.has_fill() \
and vmobject.has_fill() and self.has_same_shape_as(vmobject): and vmobject.has_fill() and self.has_same_shape_as(vmobject):
vmobject.triangulation = self.triangulation vmobject.triangulation = self.triangulation
return self return self
@ -881,6 +885,8 @@ class VMobject(Mobject):
def invisible_copy(self): def invisible_copy(self):
result = self.copy() result = self.copy()
if not result.has_fill() or result.get_num_points() == 0:
return result
result.append_vectorized_mobject(self.copy().reverse_points()) result.append_vectorized_mobject(self.copy().reverse_points())
result.set_opacity(0) result.set_opacity(0)
return result return result
@ -937,8 +943,9 @@ class VMobject(Mobject):
def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float): def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float):
assert(isinstance(vmobject, VMobject)) assert(isinstance(vmobject, VMobject))
vm_points = vmobject.get_points() vm_points = vmobject.get_points()
self.data["joint_product"] = vmobject.data["joint_product"]
if a <= 0 and b >= 1: if a <= 0 and b >= 1:
self.set_points(vm_points) self.set_points(vm_points, refresh_joints=False)
return self return self
num_curves = vmobject.get_num_curves() num_curves = vmobject.get_num_curves()
@ -971,7 +978,9 @@ class VMobject(Mobject):
# Keep new_points i2:i3 as they are # Keep new_points i2:i3 as they are
new_points[i3:i4] = high_tup new_points[i3:i4] = high_tup
new_points[i4:] = high_tup[2] new_points[i4:] = high_tup[2]
self.set_points(new_points) self.data["joint_product"][:i1] = [0, 0, 0, 1]
self.data["joint_product"][i4:] = [0, 0, 0, 1]
self.set_points(new_points, refresh_joints=False)
return self return self
def get_subcurve(self, a: float, b: float) -> VMobject: def get_subcurve(self, a: float, b: float) -> VMobject:
@ -1046,8 +1055,10 @@ class VMobject(Mobject):
null2 = (iti[0::3] - 1 == iti[1::3]) & (iti[0::3] - 2 == iti[2::3]) null2 = (iti[0::3] - 1 == iti[1::3]) & (iti[0::3] - 2 == iti[2::3])
inner_tri_indices = iti[~(null1 | null2).repeat(3)] inner_tri_indices = iti[~(null1 | null2).repeat(3)]
outer_tri_indices = self.get_outer_vert_indices() ovi = self.get_outer_vert_indices()
tri_indices = np.hstack([outer_tri_indices, inner_tri_indices]) # Flip outer triangles with negative orientation
ovi[0::3][concave_parts], ovi[2::3][concave_parts] = ovi[2::3][concave_parts], ovi[0::3][concave_parts]
tri_indices = np.hstack([ovi, inner_tri_indices])
self.triangulation = tri_indices self.triangulation = tri_indices
self.needs_new_triangulation = False self.needs_new_triangulation = False
return tri_indices return tri_indices
@ -1069,6 +1080,7 @@ class VMobject(Mobject):
return self.data["joint_product"] return self.data["joint_product"]
self.needs_new_joint_products = False self.needs_new_joint_products = False
self._data_has_changed = True
points = self.get_points() points = self.get_points()
@ -1107,6 +1119,11 @@ class VMobject(Mobject):
self.data["joint_product"][:, 3] = (vect_to_vert * vect_from_vert).sum(1) self.data["joint_product"][:, 3] = (vect_to_vert * vect_from_vert).sum(1)
return self.data["joint_product"] return self.data["joint_product"]
def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject):
for mob in [self, vmobject1, vmobject2]:
mob.get_joint_products()
super().lock_matching_data(vmobject1, vmobject2)
def triggers_refreshed_triangulation(func: Callable): def triggers_refreshed_triangulation(func: Callable):
@wraps(func) @wraps(func)
def wrapper(self, *args, refresh=True, **kwargs): def wrapper(self, *args, refresh=True, **kwargs):
@ -1117,10 +1134,12 @@ class VMobject(Mobject):
return self return self
return wrapper return wrapper
@triggers_refreshed_triangulation def set_points(self, points: Vect3Array, refresh_joints: bool = True):
def set_points(self, points: Vect3Array):
assert(len(points) == 0 or len(points) % 2 == 1) assert(len(points) == 0 or len(points) % 2 == 1)
super().set_points(points) super().set_points(points)
self.refresh_triangulation()
if refresh_joints:
self.get_joint_products(refresh=True)
return self return self
@triggers_refreshed_triangulation @triggers_refreshed_triangulation
@ -1164,7 +1183,7 @@ class VMobject(Mobject):
self.refresh_joint_products() self.refresh_joint_products()
# For shaders # For shaders
def init_shader_data(self): def init_shader_data(self, ctx: Context):
dtype = self.shader_dtype dtype = self.shader_dtype
fill_dtype, stroke_dtype = ( fill_dtype, stroke_dtype = (
np.dtype([ np.dtype([
@ -1175,27 +1194,39 @@ class VMobject(Mobject):
) )
fill_data = np.zeros(0, dtype=fill_dtype) fill_data = np.zeros(0, dtype=fill_dtype)
stroke_data = np.zeros(0, dtype=stroke_dtype) stroke_data = np.zeros(0, dtype=stroke_dtype)
self.fill_shader_wrapper = ShaderWrapper( self.fill_shader_wrapper = FillShaderWrapper(
ctx=ctx,
vert_data=fill_data, vert_data=fill_data,
uniforms=self.uniforms, uniforms=self.uniforms,
shader_folder=self.fill_shader_folder, shader_folder=self.fill_shader_folder,
render_primitive=self.fill_render_primitive, render_primitive=self.fill_render_primitive,
is_fill=True,
) )
self.stroke_shader_wrapper = ShaderWrapper( self.stroke_shader_wrapper = ShaderWrapper(
ctx=ctx,
vert_data=stroke_data, vert_data=stroke_data,
uniforms=self.uniforms, uniforms=self.uniforms,
shader_folder=self.stroke_shader_folder, shader_folder=self.stroke_shader_folder,
render_primitive=self.stroke_render_primitive, render_primitive=self.stroke_render_primitive,
) )
self.back_stroke_shader_wrapper = self.stroke_shader_wrapper.copy() self.back_stroke_shader_wrapper = self.stroke_shader_wrapper.copy()
self.shader_wrappers = [
self.back_stroke_shader_wrapper,
self.fill_shader_wrapper,
self.stroke_shader_wrapper,
]
def refresh_shader_wrapper_id(self): def refresh_shader_wrapper_id(self):
for wrapper in self.get_shader_wrapper_list(): if not self._shaders_initialized:
return self
for wrapper in self.shader_wrappers:
wrapper.refresh_id() wrapper.refresh_id()
return self return self
def get_shader_wrapper_list(self) -> list[ShaderWrapper]: def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
if not self._shaders_initialized:
self.init_shader_data(ctx)
self._shaders_initialized = True
family = self.family_members_with_points() family = self.family_members_with_points()
if not family: if not family:
return [] return []
@ -1236,13 +1267,10 @@ class VMobject(Mobject):
for sw in shader_wrappers: for sw in shader_wrappers:
# Assume uniforms of the first family member # Assume uniforms of the first family member
sw.uniforms = family[0].get_uniforms() sw.uniforms.update(family[0].get_uniforms())
sw.depth_test = family[0].depth_test sw.depth_test = family[0].depth_test
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0] return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
def refresh_shader_data(self):
self.get_shader_wrapper_list()
class VGroup(VMobject): class VGroup(VMobject):
def __init__(self, *vmobjects: VMobject, **kwargs): def __init__(self, *vmobjects: VMobject, **kwargs):

View file

@ -317,14 +317,12 @@ class InteractiveScene(Scene):
mob.refresh_bounding_box() mob.refresh_bounding_box()
else: else:
self.add_to_selection(mob) self.add_to_selection(mob)
self.refresh_static_mobjects()
def clear_selection(self): def clear_selection(self):
for mob in self.selection: for mob in self.selection:
mob.set_animating_status(False) mob.set_animating_status(False)
mob.refresh_bounding_box() mob.refresh_bounding_box()
self.selection.set_submobjects([]) self.selection.set_submobjects([])
self.refresh_static_mobjects()
def disable_interaction(self, *mobjects: Mobject): def disable_interaction(self, *mobjects: Mobject):
for mob in mobjects: for mob in mobjects:

View file

@ -189,14 +189,13 @@ class Scene(object):
"Press `command + q` or `esc` to quit" "Press `command + q` or `esc` to quit"
) )
self.skip_animations = False self.skip_animations = False
self.refresh_static_mobjects()
while not self.is_window_closing(): while not self.is_window_closing():
self.update_frame(1 / self.camera.fps) self.update_frame(1 / self.camera.fps)
def embed( def embed(
self, self,
close_scene_on_exit: bool = True, close_scene_on_exit: bool = True,
show_animation_progress: bool = True, show_animation_progress: bool = False,
) -> None: ) -> None:
if not self.preview: if not self.preview:
return # Embed is only relevant with a preview return # Embed is only relevant with a preview
@ -251,7 +250,6 @@ class Scene(object):
# Operation to run after each ipython command # Operation to run after each ipython command
def post_cell_func(): def post_cell_func():
self.refresh_static_mobjects()
if not self.is_window_closing(): if not self.is_window_closing():
self.update_frame(dt=0, ignore_skipping=True) self.update_frame(dt=0, ignore_skipping=True)
self.save_state() self.save_state()
@ -562,8 +560,6 @@ class Scene(object):
self.real_animation_start_time = time.time() self.real_animation_start_time = time.time()
self.virtual_animation_start_time = self.time self.virtual_animation_start_time = self.time
self.refresh_static_mobjects()
def post_play(self): def post_play(self):
if not self.skip_animations: if not self.skip_animations:
self.file_writer.end_animation() self.file_writer.end_animation()
@ -574,9 +570,6 @@ class Scene(object):
self.num_plays += 1 self.num_plays += 1
def refresh_static_mobjects(self) -> None:
self.camera.refresh_static_mobjects()
def begin_animations(self, animations: Iterable[Animation]) -> None: def begin_animations(self, animations: Iterable[Animation]) -> None:
for animation in animations: for animation in animations:
animation.begin() animation.begin()
@ -651,7 +644,6 @@ class Scene(object):
self.emit_frame() self.emit_frame()
if stop_condition is not None and stop_condition(): if stop_condition is not None and stop_condition():
break break
self.refresh_static_mobjects()
self.post_play() self.post_play()
def hold_loop(self): def hold_loop(self):
@ -711,15 +703,18 @@ class Scene(object):
if self.undo_stack: if self.undo_stack:
self.redo_stack.append(self.get_state()) self.redo_stack.append(self.get_state())
self.restore_state(self.undo_stack.pop()) self.restore_state(self.undo_stack.pop())
self.refresh_static_mobjects()
def redo(self): def redo(self):
if self.redo_stack: if self.redo_stack:
self.undo_stack.append(self.get_state()) self.undo_stack.append(self.get_state())
self.restore_state(self.redo_stack.pop()) self.restore_state(self.redo_stack.pop())
self.refresh_static_mobjects()
def checkpoint_paste(self, skip: bool = False, record: bool = False): def checkpoint_paste(
self,
skip: bool = False,
record: bool = False,
progress_bar: bool = True
):
""" """
Used during interactive development to run (or re-run) Used during interactive development to run (or re-run)
a block of scene code. a block of scene code.
@ -746,21 +741,21 @@ class Scene(object):
prev_skipping = self.skip_animations prev_skipping = self.skip_animations
self.skip_animations = skip self.skip_animations = skip
prev_progress = self.show_animation_progress
self.show_animation_progress = progress_bar
if record: if record:
# Resize window so rendering happens at the appropriate size self.camera.use_window_fbo(False)
self.window.size = self.camera.get_pixel_shape()
self.window.swap_buffers()
self.update_frame()
self.file_writer.begin_insert() self.file_writer.begin_insert()
shell.run_cell(pasted) shell.run_cell(pasted)
if record: if record:
self.file_writer.end_insert() self.file_writer.end_insert()
# Put window back to how it started self.camera.use_window_fbo(True)
self.window.to_default_position()
self.skip_animations = prev_skipping self.skip_animations = prev_skipping
self.show_animation_progress = prev_progress
def checkpoint(self, key: str): def checkpoint(self, key: str):
self.checkpoint_states[key] = self.get_state() self.checkpoint_states[key] = self.get_state()

View file

@ -293,13 +293,10 @@ class SceneFileWriter(object):
self.write_to_movie = True self.write_to_movie = True
self.init_output_directories() self.init_output_directories()
movie_path = self.get_movie_file_path() movie_path = self.get_movie_file_path()
folder, file = os.path.split(movie_path) count = 0
scene_name, ext = file.split(".") while os.path.exists(name := movie_path.replace(".", f"_insert_{count}.")):
n_inserts = len(list(filter( count += 1
lambda f: f.startswith(scene_name + "_insert"), self.inserted_file_path = name
os.listdir(folder)
)))
self.inserted_file_path = movie_path.replace(".", f"_insert_{n_inserts}.")
self.open_movie_pipe(self.inserted_file_path) self.open_movie_pipe(self.inserted_file_path)
def end_insert(self): def end_insert(self):

View file

@ -4,16 +4,23 @@ import copy
import os import os
import re import re
import OpenGL.GL as gl
import moderngl import moderngl
import numpy as np import numpy as np
from functools import lru_cache
from manimlib.utils.iterables import resize_array from manimlib.utils.iterables import resize_array
from manimlib.utils.shaders import get_shader_code_from_file from manimlib.utils.shaders import get_shader_code_from_file
from manimlib.utils.shaders import get_shader_program
from manimlib.utils.shaders import image_path_to_texture
from manimlib.utils.shaders import get_texture_id
from manimlib.utils.shaders import get_fill_palette
from manimlib.utils.shaders import release_texture
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import List from typing import List, Optional
# Mobjects that should be rendered with # Mobjects that should be rendered with
@ -26,29 +33,60 @@ if TYPE_CHECKING:
class ShaderWrapper(object): class ShaderWrapper(object):
def __init__( def __init__(
self, self,
ctx: moderngl.context.Context,
vert_data: np.ndarray, vert_data: np.ndarray,
vert_indices: np.ndarray | None = None, vert_indices: Optional[np.ndarray] = None,
shader_folder: str | None = None, shader_folder: Optional[str] = None,
uniforms: dict[str, float | np.ndarray] | None = None, # A dictionary mapping names of uniform variables uniforms: Optional[dict[str, float | np.ndarray]] = None, # A dictionary mapping names of uniform variables
texture_paths: dict[str, str] | None = None, # A dictionary mapping names to filepaths for textures. texture_paths: Optional[dict[str, str]] = None, # A dictionary mapping names to filepaths for textures.
depth_test: bool = False, depth_test: bool = False,
use_clip_plane: bool = False,
render_primitive: int = moderngl.TRIANGLE_STRIP, render_primitive: int = moderngl.TRIANGLE_STRIP,
is_fill: bool = False,
): ):
self.ctx = ctx
self.vert_data = vert_data self.vert_data = vert_data
self.vert_indices = (vert_indices or np.zeros(0)).astype(int) self.vert_indices = (vert_indices or np.zeros(0)).astype(int)
self.vert_attributes = vert_data.dtype.names self.vert_attributes = vert_data.dtype.names
self.shader_folder = shader_folder self.shader_folder = shader_folder
self.uniforms = uniforms or dict() self.uniforms = dict(uniforms or {})
self.texture_paths = texture_paths or dict()
self.depth_test = depth_test self.depth_test = depth_test
self.use_clip_plane = use_clip_plane self.render_primitive = render_primitive
self.render_primitive = str(render_primitive)
self.is_fill = is_fill
self.init_program_code() self.init_program_code()
self.init_program()
if texture_paths is not None:
self.init_textures(texture_paths)
self.refresh_id() self.refresh_id()
self.vbo = None
self.ibo = None
self.vao = None
def init_program_code(self) -> None:
def get_code(name: str) -> str | None:
return get_shader_code_from_file(
os.path.join(self.shader_folder, f"{name}.glsl")
)
self.program_code: dict[str, str | None] = {
"vertex_shader": get_code("vert"),
"geometry_shader": get_code("geom"),
"fragment_shader": get_code("frag"),
}
def init_program(self):
if not self.shader_folder:
self.program = None
self.vert_format = None
return
self.program = get_shader_program(self.ctx, **self.program_code)
self.vert_format = moderngl.detect_format(self.program, self.vert_attributes)
def init_textures(self, texture_paths: dict[str, str]):
for name, path in texture_paths.items():
texture = image_path_to_texture(path, self.ctx)
tid = get_texture_id(texture)
self.uniforms[name] = tid
def __eq__(self, shader_wrapper: ShaderWrapper): def __eq__(self, shader_wrapper: ShaderWrapper):
return all(( return all((
np.all(self.vert_data == shader_wrapper.vert_data), np.all(self.vert_data == shader_wrapper.vert_data),
@ -58,22 +96,20 @@ class ShaderWrapper(object):
np.all(self.uniforms[key] == shader_wrapper.uniforms[key]) np.all(self.uniforms[key] == shader_wrapper.uniforms[key])
for key in self.uniforms for key in self.uniforms
), ),
all(
self.texture_paths[key] == shader_wrapper.texture_paths[key]
for key in self.texture_paths
),
self.depth_test == shader_wrapper.depth_test, self.depth_test == shader_wrapper.depth_test,
self.render_primitive == shader_wrapper.render_primitive, self.render_primitive == shader_wrapper.render_primitive,
)) ))
def copy(self): def copy(self):
result = copy.copy(self) result = copy.copy(self)
result.ctx = self.ctx
result.vert_data = self.vert_data.copy() result.vert_data = self.vert_data.copy()
result.vert_indices = self.vert_indices.copy() result.vert_indices = self.vert_indices.copy()
if self.uniforms: if self.uniforms:
result.uniforms = {key: np.array(value) for key, value in self.uniforms.items()} result.uniforms = {key: np.array(value) for key, value in self.uniforms.items()}
if self.texture_paths: result.vao = None
result.texture_paths = dict(self.texture_paths) result.vbo = None
result.ibo = None
return result return result
def is_valid(self) -> bool: def is_valid(self) -> bool:
@ -94,7 +130,6 @@ class ShaderWrapper(object):
return "|".join(map(str, [ return "|".join(map(str, [
self.program_id, self.program_id,
self.uniforms, self.uniforms,
self.texture_paths,
self.depth_test, self.depth_test,
self.render_primitive, self.render_primitive,
])) ]))
@ -109,29 +144,33 @@ class ShaderWrapper(object):
for name in ("vertex", "geometry", "fragment") for name in ("vertex", "geometry", "fragment")
))) )))
def init_program_code(self) -> None:
def get_code(name: str) -> str | None:
return get_shader_code_from_file(
os.path.join(self.shader_folder, f"{name}.glsl")
)
self.program_code: dict[str, str | None] = {
"vertex_shader": get_code("vert"),
"geometry_shader": get_code("geom"),
"fragment_shader": get_code("frag"),
}
def get_program_code(self) -> dict[str, str | None]:
return self.program_code
def replace_code(self, old: str, new: str) -> None: def replace_code(self, old: str, new: str) -> None:
code_map = self.program_code code_map = self.program_code
for (name, code) in code_map.items(): for (name, code) in code_map.items():
if code_map[name] is None: if code_map[name] is None:
continue continue
code_map[name] = re.sub(old, new, code_map[name]) code_map[name] = re.sub(old, new, code_map[name])
self.init_program()
self.refresh_id() self.refresh_id()
# Changing context
def use_clip_plane(self):
if "clip_plane" not in self.uniforms:
return False
return any(self.uniforms["clip_plane"])
def set_ctx_depth_test(self, enable: bool = True) -> None:
if enable:
self.ctx.enable(moderngl.DEPTH_TEST)
else:
self.ctx.disable(moderngl.DEPTH_TEST)
def set_ctx_clip_plane(self, enable: bool = True) -> None:
if enable:
gl.glEnable(gl.GL_CLIP_DISTANCE0)
# Adding data
def combine_with(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper: def combine_with(self, *shader_wrappers: ShaderWrapper) -> ShaderWrapper:
if len(shader_wrappers) > 0: if len(shader_wrappers) > 0:
data_list = [self.vert_data, *(sw.vert_data for sw in shader_wrappers)] data_list = [self.vert_data, *(sw.vert_data for sw in shader_wrappers)]
@ -173,3 +212,86 @@ class ShaderWrapper(object):
n_verts = new_n_verts n_verts = new_n_verts
n_points += len(data) n_points += len(data)
return self return self
# Related to data and rendering
def pre_render(self):
self.set_ctx_depth_test(self.depth_test)
self.set_ctx_clip_plane(self.use_clip_plane())
self.update_program_uniforms()
def render(self):
assert(self.vao is not None)
self.vao.render()
def update_program_uniforms(self):
if self.program is None:
return
for name, value in self.uniforms.items():
if name in self.program:
if isinstance(value, np.ndarray) and value.ndim > 0:
value = tuple(value)
self.program[name].value = value
def get_vertex_buffer_object(self, refresh: bool = True):
if refresh:
self.vbo = self.ctx.buffer(self.vert_data)
return self.vbo
def get_index_buffer_object(self, refresh: bool = True):
if refresh and len(self.vert_indices) > 0:
self.ibo = self.ctx.buffer(self.vert_indices.astype(np.uint32))
return self.ibo
def generate_vao(self, refresh: bool = True):
self.release()
# Data buffer
vbo = self.get_vertex_buffer_object(refresh)
ibo = self.get_index_buffer_object(refresh)
# Vertex array object
self.vao = self.ctx.vertex_array(
program=self.program,
content=[(vbo, self.vert_format, *self.vert_attributes)],
index_buffer=ibo,
mode=self.render_primitive,
)
return self.vao
def release(self):
for obj in (self.vbo, self.ibo, self.vao):
if obj is not None:
obj.release()
self.vbo = None
self.ibo = None
self.vao = None
class FillShaderWrapper(ShaderWrapper):
def __init__(
self,
ctx: moderngl.context.Context,
*args,
**kwargs
):
super().__init__(ctx, *args, **kwargs)
def render(self):
vao = self.vao
assert(vao is not None)
winding = (len(self.vert_indices) == 0)
vao.program['winding'].value = winding
if not winding:
vao.render(moderngl.TRIANGLES)
return
original_fbo = self.ctx.fbo
texture_fbo, texture_vao = get_fill_palette(self.ctx)
texture_fbo.clear()
texture_fbo.use()
vao.render()
original_fbo.use()
self.ctx.blend_func = (moderngl.ONE, moderngl.ONE_MINUS_SRC_ALPHA)
texture_vao.render(moderngl.TRIANGLE_STRIP)
self.ctx.blend_func = (moderngl.DEFAULT_BLENDING)

View file

@ -6,20 +6,41 @@ in vec4 color;
in float fill_all; in float fill_all;
in float orientation; in float orientation;
in vec2 uv_coords; in vec2 uv_coords;
in vec3 point;
out vec4 frag_color; out vec4 frag_color;
#INSERT finalize_color.glsl
void main() { void main() {
if (color.a == 0) discard; if (color.a == 0) discard;
frag_color = color; frag_color = finalize_color(color, point, vec3(0.0, 0.0, 1.0));
/*
We want negatively oriented triangles to be canceled with positively
oriented ones. The easiest way to do this is to give them negative alpha,
and change the blend function to just add them. However, this messes with
usual blending, so instead the following line is meant to let this canceling
work even for the normal blending equation:
if(winding && orientation > 0) frag_color *= -1; (1 - alpha) * dst + alpha * src
We want the effect of blending with a positively oriented triangle followed
by a negatively oriented one to return to whatever the original frag value
was. You can work out this will work if the alpha for negative orientations
is changed to -alpha / (1 - alpha). This has a singularity at alpha = 1,
so we cap it at a value very close to 1. Effectively, the purpose of this
cap is to make sure the original fragment color can be recovered even after
blending with an alpha = 1 color.
*/
float a = 0.999 * frag_color.a;
if(winding && orientation < 0) a = -a / (1 - a);
frag_color.a = a;
if (bool(fill_all)) return; if (bool(fill_all)) return;
float x = uv_coords.x; float x = uv_coords.x;
float y = uv_coords.y; float y = uv_coords.y;
float Fxy = (y - x * x); float Fxy = (y - x * x);
if(!winding && orientation > 0) Fxy *= -1; if(!winding && orientation < 0) Fxy *= -1;
if(Fxy < 0) discard; if(Fxy < 0) discard;
} }

View file

@ -13,6 +13,7 @@ in float v_vert_index[3];
out vec4 color; out vec4 color;
out float fill_all; out float fill_all;
out float orientation; out float orientation;
out vec3 point;
// uv space is where the curve coincides with y = x^2 // uv space is where the curve coincides with y = x^2
out vec2 uv_coords; out vec2 uv_coords;
@ -26,16 +27,18 @@ const vec2 SIMPLE_QUADRATIC[3] = vec2[3](
// Analog of import for manim only // Analog of import for manim only
#INSERT get_gl_Position.glsl #INSERT get_gl_Position.glsl
#INSERT get_unit_normal.glsl #INSERT get_unit_normal.glsl
#INSERT finalize_color.glsl
void emit_triangle(vec3 points[3], vec4 v_color[3]){ void emit_triangle(vec3 points[3], vec4 v_color[3]){
vec3 unit_normal = get_unit_normal(points[0], points[1], points[2]); vec3 unit_normal = get_unit_normal(points[0], points[1], points[2]);
orientation = sign(unit_normal.z); orientation = winding ? sign(unit_normal.z) : 1.0;
for(int i = 0; i < 3; i++){ for(int i = 0; i < 3; i++){
uv_coords = SIMPLE_QUADRATIC[i]; uv_coords = SIMPLE_QUADRATIC[i];
color = finalize_color(v_color[i], points[i], unit_normal); color = v_color[i];
point = points[i];
// Pure black will be used to discard fragments later
if(winding && color.rgb == vec3(0.0)) color.rgb += vec3(0.01);
gl_Position = get_gl_Position(points[i]); gl_Position = get_gl_Position(points[i]);
EmitVertex(); EmitVertex();
} }

View file

@ -3,17 +3,66 @@ from __future__ import annotations
import os import os
import re import re
from functools import lru_cache from functools import lru_cache
import moderngl
from PIL import Image
import numpy as np
from manimlib.constants import DEFAULT_PIXEL_HEIGHT
from manimlib.constants import DEFAULT_PIXEL_WIDTH
from manimlib.utils.directories import get_shader_dir from manimlib.utils.directories import get_shader_dir
from manimlib.utils.file_ops import find_file from manimlib.utils.file_ops import find_file
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Sequence from typing import Sequence, Optional, Tuple
from moderngl.vertex_array import VertexArray
from moderngl.framebuffer import Framebuffer
@lru_cache(maxsize=12) ID_TO_TEXTURE: dict[int, moderngl.Texture] = dict()
@lru_cache()
def image_path_to_texture(path: str, ctx: moderngl.Context) -> moderngl.Texture:
im = Image.open(path).convert("RGBA")
return ctx.texture(
size=im.size,
components=len(im.getbands()),
data=im.tobytes(),
)
def get_texture_id(texture: moderngl.Texture) -> int:
tid = 0
while tid in ID_TO_TEXTURE:
tid += 1
ID_TO_TEXTURE[tid] = texture
texture.use(location=tid)
return tid
def release_texture(texture_id: int):
texture = ID_TO_TEXTURE.pop(texture_id, None)
if texture is not None:
texture.release()
@lru_cache()
def get_shader_program(
ctx: moderngl.context.Context,
vertex_shader: str,
fragment_shader: Optional[str] = None,
geometry_shader: Optional[str] = None,
) -> moderngl.Program:
return ctx.program(
vertex_shader=vertex_shader,
fragment_shader=fragment_shader,
geometry_shader=geometry_shader,
)
@lru_cache()
def get_shader_code_from_file(filename: str) -> str | None: def get_shader_code_from_file(filename: str) -> str | None:
if not filename: if not filename:
return None return None
@ -49,3 +98,70 @@ def get_colormap_code(rgb_list: Sequence[float]) -> str:
for rgb in rgb_list for rgb in rgb_list
) )
return f"vec3[{len(rgb_list)}]({data})" return f"vec3[{len(rgb_list)}]({data})"
@lru_cache()
def get_fill_palette(ctx) -> Tuple[Framebuffer, VertexArray]:
"""
Creates a texture, loaded into a frame buffer, and a vao
which can display that texture as a simple quad onto a screen.
"""
size = (2 * DEFAULT_PIXEL_WIDTH, 2 * DEFAULT_PIXEL_HEIGHT)
# Important to make sure dtype is floating point (not fixed point)
# so that alpha values can be negative and are not clipped
texture = ctx.texture(size=size, components=4, dtype='f4')
depth_buffer = ctx.depth_renderbuffer(size) # TODO, currently not used
texture_fbo = ctx.framebuffer(texture, depth_buffer)
simple_program = ctx.program(
vertex_shader='''
#version 330
in vec2 texcoord;
out vec2 v_textcoord;
void main() {
gl_Position = vec4((2.0 * texcoord - 1.0), 0.0, 1.0);
v_textcoord = texcoord;
}
''',
fragment_shader='''
#version 330
uniform sampler2D Texture;
uniform float v_nudge;
uniform float h_nudge;
in vec2 v_textcoord;
out vec4 frag_color;
void main() {
// Apply poor man's anti-aliasing
vec2 tc0 = v_textcoord + vec2(0, 0);
vec2 tc1 = v_textcoord + vec2(0, h_nudge);
vec2 tc2 = v_textcoord + vec2(v_nudge, 0);
vec2 tc3 = v_textcoord + vec2(v_nudge, h_nudge);
frag_color =
0.25 * texture(Texture, tc0) +
0.25 * texture(Texture, tc1) +
0.25 * texture(Texture, tc2) +
0.25 * texture(Texture, tc3);
if(distance(frag_color.rgb, vec3(0.0)) < 1e-3) discard;
//TODO, set gl_FragDepth;
}
''',
)
simple_program['Texture'].value = get_texture_id(texture)
# Half pixel width/height
simple_program['h_nudge'].value = 0.5 / size[0]
simple_program['v_nudge'].value = 0.5 / size[1]
verts = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
fill_texture_vao = ctx.simple_vertex_array(
simple_program,
ctx.buffer(verts.astype('f4').tobytes()),
'texcoord',
)
return (texture_fbo, fill_texture_vao)

View file

@ -26,12 +26,10 @@ class Window(PygletWindow):
self, self,
scene: Scene, scene: Scene,
size: tuple[int, int] = (1280, 720), size: tuple[int, int] = (1280, 720),
full_size: tuple[int, int] = (1920, 1080),
samples = 0 samples = 0
): ):
super().__init__(size=full_size, samples=samples) super().__init__(size=size, samples=samples)
self.full_size = full_size
self.default_size = size self.default_size = size
self.default_position = self.find_initial_position(size) self.default_position = self.find_initial_position(size)
self.scene = scene self.scene = scene