Merge pull request #1986 from 3b1b/video-work

A few small performance improvements
This commit is contained in:
Grant Sanderson 2023-02-02 14:49:49 -08:00 committed by GitHub
commit d263fa23fa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 98 additions and 30 deletions

View file

@ -26,7 +26,7 @@ class OpeningManimExample(Scene):
matrix = [[1, 1], [0, 1]] matrix = [[1, 1], [0, 1]]
linear_transform_words = VGroup( linear_transform_words = VGroup(
Text("This is what the matrix"), Text("This is what the matrix"),
IntegerMatrix(matrix, include_background_rectangle=True, h_buff=1.0), IntegerMatrix(matrix, include_background_rectangle=True),
Text("looks like") Text("looks like")
) )
linear_transform_words.arrange(RIGHT) linear_transform_words.arrange(RIGHT)

View file

@ -76,7 +76,7 @@ class Matrix(VMobject):
self, self,
matrix: Sequence[Sequence[str | float | VMobject]], matrix: Sequence[Sequence[str | float | VMobject]],
v_buff: float = 0.8, v_buff: float = 0.8,
h_buff: float = 1.3, h_buff: float = 1.0,
bracket_h_buff: float = 0.2, bracket_h_buff: float = 0.2,
bracket_v_buff: float = 0.25, bracket_v_buff: float = 0.25,
add_background_rectangles_to_entries: bool = False, add_background_rectangles_to_entries: bool = False,

View file

@ -376,7 +376,7 @@ class Mobject(object):
return [self] return [self]
def family_members_with_points(self) -> list[Self]: def family_members_with_points(self) -> list[Self]:
return [m for m in self.get_family() if m.has_points()] return [m for m in self.family if len(m.data) > 0]
def get_ancestors(self, extended: bool = False) -> list[Mobject]: def get_ancestors(self, extended: bool = False) -> list[Mobject]:
""" """
@ -449,7 +449,9 @@ class Mobject(object):
return self return self
def set_submobjects(self, submobject_list: list[Mobject]) -> Self: def set_submobjects(self, submobject_list: list[Mobject]) -> Self:
self.remove(*self.submobjects, reassemble=False) if self.submobjects == submobject_list:
return self
self.clear()
self.add(*submobject_list) self.add(*submobject_list)
return self return self

View file

@ -111,6 +111,8 @@ class VMobject(Mobject):
self.anti_alias_width = anti_alias_width self.anti_alias_width = anti_alias_width
self.fill_border_width = fill_border_width self.fill_border_width = fill_border_width
self._use_winding_fill = use_winding_fill self._use_winding_fill = use_winding_fill
self._has_fill = False
self._has_stroke = False
self.needs_new_triangulation = True self.needs_new_triangulation = True
self.triangulation = np.zeros(0, dtype='i4') self.triangulation = np.zeros(0, dtype='i4')
@ -134,6 +136,16 @@ class VMobject(Mobject):
return super().add(*vmobjects) return super().add(*vmobjects)
# Colors # Colors
def note_changed_fill(self) -> Self:
for submob in self.get_family():
submob._has_fill = submob.has_fill()
return self
def note_changed_stroke(self) -> Self:
for submob in self.get_family():
submob._has_stroke = submob.has_stroke()
return self
def init_colors(self): def init_colors(self):
self.set_fill( self.set_fill(
color=self.fill_color, color=self.fill_color,
@ -164,6 +176,10 @@ class VMobject(Mobject):
for name in names: for name in names:
super().set_rgba_array(rgba_array, name, recurse) super().set_rgba_array(rgba_array, name, recurse)
if name == "fill_rgba":
self.note_changed_fill()
elif name == "stroke_rgba":
self.note_changed_stroke()
return self return self
def set_fill( def set_fill(
@ -177,6 +193,7 @@ class VMobject(Mobject):
if border_width is not None: if border_width is not None:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
mob.data["fill_border_width"] = border_width mob.data["fill_border_width"] = border_width
self.note_changed_fill()
return self return self
def set_stroke( def set_stroke(
@ -202,6 +219,8 @@ class VMobject(Mobject):
if background is not None: if background is not None:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
mob.stroke_behind = background mob.stroke_behind = background
self.note_changed_stroke()
return self return self
def set_backstroke( def set_backstroke(
@ -255,6 +274,8 @@ class VMobject(Mobject):
if shading is not None: if shading is not None:
mob.set_shading(*shading, recurse=False) mob.set_shading(*shading, recurse=False)
self.note_changed_fill()
self.note_changed_stroke()
return self return self
def get_style(self) -> dict[str, Any]: def get_style(self) -> dict[str, Any]:
@ -378,10 +399,12 @@ class VMobject(Mobject):
return self.uniforms["anti_alias_width"] return self.uniforms["anti_alias_width"]
def has_stroke(self) -> bool: def has_stroke(self) -> bool:
return any(self.data['stroke_width']) and any(self.data['stroke_rgba'][:, 3]) data = self.data if len(self.data) > 0 else self._data_defaults
return any(data['stroke_width']) and any(data['stroke_rgba'][:, 3])
def has_fill(self) -> bool: def has_fill(self) -> bool:
return any(self.data['fill_rgba'][:, 3]) data = self.data if len(self.data) > 0 else self._data_defaults
return any(data['fill_rgba'][:, 3])
def get_opacity(self) -> float: def get_opacity(self) -> float:
if self.has_fill(): if self.has_fill():
@ -971,6 +994,10 @@ class VMobject(Mobject):
*args, **kwargs *args, **kwargs
) -> Self: ) -> Self:
super().interpolate(mobject1, mobject2, alpha, *args, **kwargs) super().interpolate(mobject1, mobject2, alpha, *args, **kwargs)
self._has_stroke = mobject1._has_stroke or mobject2._has_stroke
self._has_fill = mobject1._has_fill or mobject2._has_fill
if self.has_fill() and not self._use_winding_fill: if self.has_fill() and not self._use_winding_fill:
tri1 = mobject1.get_triangulation() tri1 = mobject1.get_triangulation()
tri2 = mobject2.get_triangulation() tri2 = mobject2.get_triangulation()
@ -1026,13 +1053,22 @@ class VMobject(Mobject):
vmob.pointwise_become_partial(self, a, b) vmob.pointwise_become_partial(self, a, b)
return vmob return vmob
def resize_points(
self,
new_length: int,
resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array
) -> Self:
super().resize_points(new_length, resize_func)
n_curves = self.get_num_curves()
# Creates the pattern (0, 1, 2, 2, 3, 4, 4, 5, 6, ...)
self.outer_vert_indices = (np.arange(1, 3 * n_curves + 1) * 2) // 3
return self
def get_outer_vert_indices(self) -> np.ndarray: def get_outer_vert_indices(self) -> np.ndarray:
""" """
Returns the pattern (0, 1, 2, 2, 3, 4, 4, 5, 6, ...) Returns the pattern (0, 1, 2, 2, 3, 4, 4, 5, 6, ...)
""" """
n_curves = self.get_num_curves()
if len(self.outer_vert_indices) != 3 * n_curves:
self.outer_vert_indices = (np.arange(1, 3 * n_curves + 1) * 2) // 3
return self.outer_vert_indices return self.outer_vert_indices
# Data for shaders that may need refreshing # Data for shaders that may need refreshing
@ -1202,6 +1238,8 @@ class VMobject(Mobject):
@triggers_refreshed_triangulation @triggers_refreshed_triangulation
def set_data(self, data: np.ndarray) -> Self: def set_data(self, data: np.ndarray) -> Self:
super().set_data(data) super().set_data(data)
self.note_changed_fill()
self.note_changed_stroke()
return self return self
# TODO, how to be smart about tangents here? # TODO, how to be smart about tangents here?
@ -1291,19 +1329,19 @@ class VMobject(Mobject):
for submob in family: for submob in family:
submob.get_joint_products() submob.get_joint_products()
indices = submob.get_outer_vert_indices() indices = submob.get_outer_vert_indices()
has_fill = submob.has_fill() has_fill = submob._has_fill
has_stroke = submob.has_stroke() has_stroke = submob._has_stroke
back_stroke = has_stroke and submob.stroke_behind back_stroke = has_stroke and submob.stroke_behind
front_stroke = has_stroke and not submob.stroke_behind front_stroke = has_stroke and not submob.stroke_behind
if back_stroke: if back_stroke:
back_stroke_datas.append(submob.data[stroke_names][indices]) back_stroke_datas.append(submob.data[stroke_names][indices])
if front_stroke: if front_stroke:
stroke_datas.append(submob.data[stroke_names][indices]) stroke_datas.append(submob.data[stroke_names][indices])
if has_fill and self._use_winding_fill: if has_fill and submob._use_winding_fill:
data = submob.data[fill_names] data = submob.data[fill_names]
data["base_point"][:] = data["point"][0] data["base_point"][:] = data["point"][0]
fill_datas.append(data[indices]) fill_datas.append(data[indices])
if has_fill and not self._use_winding_fill: if has_fill and not submob._use_winding_fill:
fill_datas.append(submob.data[fill_names]) fill_datas.append(submob.data[fill_names])
fill_indices.append(submob.get_triangulation()) fill_indices.append(submob.get_triangulation())
if has_fill and not front_stroke: if has_fill and not front_stroke:

View file

@ -7,6 +7,7 @@ import platform
import pyperclip import pyperclip
import random import random
import time import time
from functools import wraps
from IPython.terminal import pt_inputhooks from IPython.terminal import pt_inputhooks
from IPython.terminal.embed import InteractiveShellEmbed from IPython.terminal.embed import InteractiveShellEmbed
@ -38,6 +39,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.scene.scene_file_writer import SceneFileWriter from manimlib.scene.scene_file_writer import SceneFileWriter
from manimlib.utils.family_ops import extract_mobject_family_members from manimlib.utils.family_ops import extract_mobject_family_members
from manimlib.utils.family_ops import recursive_mobject_remove from manimlib.utils.family_ops import recursive_mobject_remove
from manimlib.utils.iterables import batch_by_property
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -119,6 +121,7 @@ class Scene(object):
self.file_writer = SceneFileWriter(self, **self.file_writer_config) self.file_writer = SceneFileWriter(self, **self.file_writer_config)
self.mobjects: list[Mobject] = [self.camera.frame] self.mobjects: list[Mobject] = [self.camera.frame]
self.render_groups: list[Mobject] = []
self.id_to_mobject_map: dict[int, Mobject] = dict() self.id_to_mobject_map: dict[int, Mobject] = dict()
self.num_plays: int = 0 self.num_plays: int = 0
self.time: float = 0 self.time: float = 0
@ -298,7 +301,7 @@ class Scene(object):
def get_image(self) -> Image: def get_image(self) -> Image:
if self.window is not None: if self.window is not None:
self.camera.use_window_fbo(False) self.camera.use_window_fbo(False)
self.camera.capture(*self.mobjects) self.camera.capture(*self.render_groups)
image = self.camera.get_image() image = self.camera.get_image()
if self.window is not None: if self.window is not None:
self.camera.use_window_fbo(True) self.camera.use_window_fbo(True)
@ -319,7 +322,7 @@ class Scene(object):
if self.window: if self.window:
self.window.clear() self.window.clear()
self.camera.capture(*self.mobjects) self.camera.capture(*self.render_groups)
if self.window: if self.window:
self.window.swap_buffers() self.window.swap_buffers()
@ -378,6 +381,30 @@ class Scene(object):
def get_mobject_family_members(self) -> list[Mobject]: def get_mobject_family_members(self) -> list[Mobject]:
return extract_mobject_family_members(self.mobjects) return extract_mobject_family_members(self.mobjects)
def assemble_render_groups(self):
"""
Rendering can be more efficient when mobjects of the
same type are grouped together, so this function creates
Groups of all clusters of adjacent Mobjects in the scene
"""
batches = batch_by_property(self.mobjects, lambda m: str(type(m)))
for group in self.render_groups:
group.clear()
self.render_groups = [
batch[0].get_group_class()(*batch)
for batch, key in batches
]
def affects_mobject_list(func: Callable):
@wraps(func)
def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs)
self.assemble_render_groups()
return self
return wrapper
@affects_mobject_list
def add(self, *new_mobjects: Mobject): def add(self, *new_mobjects: Mobject):
""" """
Mobjects will be displayed, from background to Mobjects will be displayed, from background to
@ -404,6 +431,7 @@ class Scene(object):
)) ))
return self return self
@affects_mobject_list
def replace(self, mobject: Mobject, *replacements: Mobject): def replace(self, mobject: Mobject, *replacements: Mobject):
if mobject in self.mobjects: if mobject in self.mobjects:
index = self.mobjects.index(mobject) index = self.mobjects.index(mobject)
@ -414,6 +442,7 @@ class Scene(object):
] ]
return self return self
@affects_mobject_list
def remove(self, *mobjects_to_remove: Mobject): def remove(self, *mobjects_to_remove: Mobject):
""" """
Removes anything in mobjects from scenes mobject list, but in the event that one Removes anything in mobjects from scenes mobject list, but in the event that one
@ -431,11 +460,13 @@ class Scene(object):
self.add(*mobjects) self.add(*mobjects)
return self return self
@affects_mobject_list
def bring_to_back(self, *mobjects: Mobject): def bring_to_back(self, *mobjects: Mobject):
self.remove(*mobjects) self.remove(*mobjects)
self.mobjects = list(mobjects) + self.mobjects self.mobjects = list(mobjects) + self.mobjects
return self return self
@affects_mobject_list
def clear(self): def clear(self):
self.mobjects = [] self.mobjects = []
return self return self
@ -608,6 +639,7 @@ class Scene(object):
else: else:
self.update_mobjects(0) self.update_mobjects(0)
@affects_mobject_list
def play( def play(
self, self,
*proto_animations: Animation | _AnimationBuilder, *proto_animations: Animation | _AnimationBuilder,

View file

@ -6,16 +6,12 @@ in vec4 color;
in float fill_all; in float fill_all;
in float orientation; in float orientation;
in vec2 uv_coords; in vec2 uv_coords;
in vec3 point;
in vec3 unit_normal;
out vec4 frag_color; out vec4 frag_color;
#INSERT finalize_color.glsl
void main() { void main() {
if (color.a == 0) discard; if (color.a == 0) discard;
frag_color = finalize_color(color, point, unit_normal); frag_color = color;
/* /*
We want negatively oriented triangles to be canceled with positively We want negatively oriented triangles to be canceled with positively
oriented ones. The easiest way to do this is to give them negative alpha, oriented ones. The easiest way to do this is to give them negative alpha,
@ -33,9 +29,11 @@ void main() {
cap is to make sure the original fragment color can be recovered even after cap is to make sure the original fragment color can be recovered even after
blending with an (alpha = 1) color. blending with an (alpha = 1) color.
*/ */
float a = 0.95 * frag_color.a; if(winding){
if(winding && orientation < 0) a = -a / (1 - a); float a = 0.95 * frag_color.a;
frag_color.a = a; if(orientation < 0) a = -a / (1 - a);
frag_color.a = a;
}
if (bool(fill_all)) return; if (bool(fill_all)) return;

View file

@ -14,8 +14,6 @@ in vec3 v_unit_normal[3];
out vec4 color; out vec4 color;
out float fill_all; out float fill_all;
out float orientation; out float orientation;
out vec3 point;
out vec3 unit_normal;
// uv space is where the curve coincides with y = x^2 // uv space is where the curve coincides with y = x^2
out vec2 uv_coords; out vec2 uv_coords;
@ -28,9 +26,12 @@ const vec2 SIMPLE_QUADRATIC[3] = vec2[3](
// Analog of import for manim only // Analog of import for manim only
#INSERT emit_gl_Position.glsl #INSERT emit_gl_Position.glsl
#INSERT finalize_color.glsl
void emit_triangle(vec3 points[3], vec4 v_color[3]){ void emit_triangle(vec3 points[3], vec4 v_color[3]){
vec3 unit_normal = v_unit_normal[1];
orientation = sign(determinant(mat3( orientation = sign(determinant(mat3(
unit_normal, unit_normal,
points[1] - points[0], points[1] - points[0],
@ -39,8 +40,7 @@ void emit_triangle(vec3 points[3], vec4 v_color[3]){
for(int i = 0; i < 3; i++){ for(int i = 0; i < 3; i++){
uv_coords = SIMPLE_QUADRATIC[i]; uv_coords = SIMPLE_QUADRATIC[i];
color = v_color[i]; color = finalize_color(v_color[i], points[i], unit_normal);
point = points[i];
emit_gl_Position(points[i]); emit_gl_Position(points[i]);
EmitVertex(); EmitVertex();
} }
@ -61,8 +61,6 @@ void main(){
// the first anchor is set equal to that anchor // the first anchor is set equal to that anchor
if (verts[0] == verts[1]) return; if (verts[0] == verts[1]) return;
unit_normal = v_unit_normal[1];
if(winding){ if(winding){
// Emit main triangle // Emit main triangle
fill_all = 1.0; fill_all = 1.0;

View file

@ -16,7 +16,7 @@ def num_tex_symbols(tex: str) -> int:
# \begin{array}{cc}, etc. # \begin{array}{cc}, etc.
pattern = "|".join( pattern = "|".join(
rf"(\\{s})" + r"(\{\w+\})?(\{\w+\})?(\[\w+\])?" rf"(\\{s})" + r"(\{\w+\})?(\{\w+\})?(\[\w+\])?"
for s in ["begin", "end", "phantom"] for s in ["begin", "end", "phantom", "text"]
) )
tex = re.sub(pattern, "", tex) tex = re.sub(pattern, "", tex)