Merge branch 'master' of github.com:3b1b/manim

This commit is contained in:
Grant Sanderson 2024-08-01 07:55:11 -05:00
commit d644e3b184
31 changed files with 929 additions and 647 deletions

View file

@ -26,7 +26,7 @@ class OpeningManimExample(Scene):
matrix = [[1, 1], [0, 1]] matrix = [[1, 1], [0, 1]]
linear_transform_words = VGroup( linear_transform_words = VGroup(
Text("This is what the matrix"), Text("This is what the matrix"),
IntegerMatrix(matrix, include_background_rectangle=True), IntegerMatrix(matrix),
Text("looks like") Text("looks like")
) )
linear_transform_words.arrange(RIGHT) linear_transform_words.arrange(RIGHT)
@ -251,7 +251,7 @@ class TexIndexing(Scene):
self.play(FlashAround(part)) self.play(FlashAround(part))
self.wait() self.wait()
self.play(FadeOut(equation)) self.play(FadeOut(equation))
# Indexing by substrings like this may not work when # Indexing by substrings like this may not work when
# the order in which Latex draws symbols does not match # the order in which Latex draws symbols does not match
# the order in which they show up in the string. # the order in which they show up in the string.
@ -289,11 +289,11 @@ class UpdatersExample(Scene):
brace = always_redraw(Brace, square, UP) brace = always_redraw(Brace, square, UP)
label = TexText("Width = 0.00") label = TexText("Width = 0.00")
number = label.make_number_changable("0.00") number = label.make_number_changeable("0.00")
# This ensures that the method deicmal.next_to(square) # This ensures that the method deicmal.next_to(square)
# is called on every frame # is called on every frame
always(label.next_to, brace, UP) label.always.next_to(brace, UP)
# You could also write the following equivalent line # You could also write the following equivalent line
# label.add_updater(lambda m: m.next_to(brace, UP)) # label.add_updater(lambda m: m.next_to(brace, UP))
@ -302,7 +302,7 @@ class UpdatersExample(Scene):
# should be functions returning arguments to that method. # should be functions returning arguments to that method.
# The following line ensures thst decimal.set_value(square.get_y()) # The following line ensures thst decimal.set_value(square.get_y())
# is called every frame # is called every frame
f_always(number.set_value, square.get_width) number.f_always.set_value(square.get_width)
# You could also write the following equivalent line # You could also write the following equivalent line
# number.add_updater(lambda m: m.set_value(square.get_width())) # number.add_updater(lambda m: m.set_value(square.get_width()))
@ -359,7 +359,7 @@ class CoordinateSystemExample(Scene):
# Alternatively, you can specify configuration for just one # Alternatively, you can specify configuration for just one
# of them, like this. # of them, like this.
y_axis_config=dict( y_axis_config=dict(
numbers_with_elongated_ticks=[-2, 2], big_tick_numbers=[-2, 2],
) )
) )
# Keyword arguments of add_coordinate_labels can be used to # Keyword arguments of add_coordinate_labels can be used to
@ -515,7 +515,7 @@ class TexAndNumbersExample(Scene):
# on them. # on them.
tex = Tex("x^2 + y^2 = 4.00") tex = Tex("x^2 + y^2 = 4.00")
tex.next_to(axes, UP, buff=0.5) tex.next_to(axes, UP, buff=0.5)
value = tex.make_number_changable("4.00") value = tex.make_number_changeable("4.00")
# This will tie the right hand side of our equation to # This will tie the right hand side of our equation to
@ -537,10 +537,10 @@ class TexAndNumbersExample(Scene):
rate_func=there_and_back, rate_func=there_and_back,
) )
# By default, tex.make_number_changable replaces the first occurance # By default, tex.make_number_changeable replaces the first occurance
# of the number,but by passing replace_all=True it replaces all and # of the number,but by passing replace_all=True it replaces all and
# returns a group of the results # returns a group of the results
exponents = tex.make_number_changable("2", replace_all=True) exponents = tex.make_number_changeable("2", replace_all=True)
self.play( self.play(
LaggedStartMap( LaggedStartMap(
FlashAround, exponents, FlashAround, exponents,

View file

@ -43,7 +43,6 @@ from manimlib.mobject.probability import *
from manimlib.mobject.shape_matchers import * from manimlib.mobject.shape_matchers import *
from manimlib.mobject.svg.brace import * from manimlib.mobject.svg.brace import *
from manimlib.mobject.svg.drawings import * from manimlib.mobject.svg.drawings import *
from manimlib.mobject.svg.tex_mobject import *
from manimlib.mobject.svg.string_mobject import * from manimlib.mobject.svg.string_mobject import *
from manimlib.mobject.svg.svg_mobject import * from manimlib.mobject.svg.svg_mobject import *
from manimlib.mobject.svg.special_tex import * from manimlib.mobject.svg.special_tex import *

View file

@ -13,7 +13,8 @@ from manimlib.utils.bezier import interpolate
from manimlib.utils.iterables import remove_list_redundancies from manimlib.utils.iterables import remove_list_redundancies
from manimlib.utils.simple_functions import clip from manimlib.utils.simple_functions import clip
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Union, Iterable
AnimationType = Union[Animation, _AnimationBuilder]
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Optional from typing import Callable, Optional
@ -26,14 +27,16 @@ DEFAULT_LAGGED_START_LAG_RATIO = 0.05
class AnimationGroup(Animation): class AnimationGroup(Animation):
def __init__(self, def __init__(
*animations: Animation | _AnimationBuilder, self,
*args: AnimationType | Iterable[AnimationType],
run_time: float = -1, # If negative, default to sum of inputed animation runtimes run_time: float = -1, # If negative, default to sum of inputed animation runtimes
lag_ratio: float = 0.0, lag_ratio: float = 0.0,
group: Optional[Mobject] = None, group: Optional[Mobject] = None,
group_type: Optional[type] = None, group_type: Optional[type] = None,
**kwargs **kwargs
): ):
animations = args[0] if isinstance(args[0], Iterable) else args
self.animations = [prepare_animation(anim) for anim in animations] self.animations = [prepare_animation(anim) for anim in animations]
self.build_animations_with_timings(lag_ratio) self.build_animations_with_timings(lag_ratio)
self.max_end_time = max((awt[2] for awt in self.anims_with_timings), default=0) self.max_end_time = max((awt[2] for awt in self.anims_with_timings), default=0)
@ -177,4 +180,5 @@ class LaggedStartMap(LaggedStart):
*(anim_func(submob, **anim_kwargs) for submob in group), *(anim_func(submob, **anim_kwargs) for submob in group),
run_time=run_time, run_time=run_time,
lag_ratio=lag_ratio, lag_ratio=lag_ratio,
group=group
) )

View file

@ -118,6 +118,7 @@ class FadeTransform(Transform):
def ghost_to(self, source: Mobject, target: Mobject) -> None: def ghost_to(self, source: Mobject, target: Mobject) -> None:
source.replace(target, stretch=self.stretch, dim_to_match=self.dim_to_match) source.replace(target, stretch=self.stretch, dim_to_match=self.dim_to_match)
source.set_uniform(**target.get_uniforms())
source.set_opacity(0) source.set_opacity(0)
def get_all_mobjects(self) -> list[Mobject]: def get_all_mobjects(self) -> list[Mobject]:
@ -134,7 +135,8 @@ class FadeTransform(Transform):
Animation.clean_up_from_scene(self, scene) Animation.clean_up_from_scene(self, scene)
scene.remove(self.mobject) scene.remove(self.mobject)
self.mobject[0].restore() self.mobject[0].restore()
scene.add(self.to_add_on_completion) if not self.remover:
scene.add(self.to_add_on_completion)
class FadeTransformPieces(FadeTransform): class FadeTransformPieces(FadeTransform):

View file

@ -3,6 +3,7 @@ from __future__ import annotations
from manimlib.animation.animation import Animation from manimlib.animation.animation import Animation
from manimlib.mobject.numbers import DecimalNumber from manimlib.mobject.numbers import DecimalNumber
from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import interpolate
from manimlib.utils.simple_functions import clip
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -55,9 +56,9 @@ class CountInFrom(ChangingDecimal):
source_number: float | complex = 0, source_number: float | complex = 0,
**kwargs **kwargs
): ):
start_number = decimal_mob.number start_number = decimal_mob.get_value()
super().__init__( super().__init__(
decimal_mob, decimal_mob,
lambda a: interpolate(source_number, start_number, a), lambda a: interpolate(source_number, start_number, clip(a, 0, 1)),
**kwargs **kwargs
) )

View file

@ -65,7 +65,7 @@ class Transform(Animation):
self.target_copy = self.target_mobject.copy() self.target_copy = self.target_mobject.copy()
self.mobject.align_data_and_family(self.target_copy) self.mobject.align_data_and_family(self.target_copy)
super().begin() super().begin()
if not self.mobject.has_updaters: if not self.mobject.has_updaters():
self.mobject.lock_matching_data( self.mobject.lock_matching_data(
self.starting_mobject, self.starting_mobject,
self.target_copy, self.target_copy,

View file

@ -25,6 +25,7 @@ class CameraFrame(Mobject):
center_point: Vect3 = ORIGIN, center_point: Vect3 = ORIGIN,
# Field of view in the y direction # Field of view in the y direction
fovy: float = 45 * DEGREES, fovy: float = 45 * DEGREES,
euler_axes: str = "zxz",
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -35,6 +36,7 @@ class CameraFrame(Mobject):
self.default_orientation = Rotation.identity() self.default_orientation = Rotation.identity()
self.view_matrix = np.identity(4) self.view_matrix = np.identity(4)
self.camera_location = OUT # This will be updated by set_points self.camera_location = OUT # This will be updated by set_points
self.euler_axes = euler_axes
self.set_points(np.array([ORIGIN, LEFT, RIGHT, DOWN, UP])) self.set_points(np.array([ORIGIN, LEFT, RIGHT, DOWN, UP]))
self.set_width(frame_shape[0], stretch=True) self.set_width(frame_shape[0], stretch=True)
@ -62,7 +64,7 @@ class CameraFrame(Mobject):
orientation = self.get_orientation() orientation = self.get_orientation()
if all(orientation.as_quat() == [0, 0, 0, 1]): if all(orientation.as_quat() == [0, 0, 0, 1]):
return np.zeros(3) return np.zeros(3)
return orientation.as_euler("zxz")[::-1] return orientation.as_euler(self.euler_axes)[::-1]
def get_theta(self): def get_theta(self):
return self.get_euler_angles()[0] return self.get_euler_angles()[0]
@ -126,21 +128,44 @@ class CameraFrame(Mobject):
if all(eulers == 0): if all(eulers == 0):
rot = Rotation.identity() rot = Rotation.identity()
else: else:
rot = Rotation.from_euler("zxz", eulers[::-1]) rot = Rotation.from_euler(self.euler_axes, eulers[::-1])
self.set_orientation(rot) self.set_orientation(rot)
return self return self
def increment_euler_angles(
self,
dtheta: float | None = None,
dphi: float | None = None,
dgamma: float | None = None,
units: float = RADIANS
):
angles = self.get_euler_angles()
for i, value in enumerate([dtheta, dphi, dgamma]):
if value is not None:
angles[i] += value * units
self.set_euler_angles(*angles)
return self
def set_euler_axes(self, seq: str):
self.euler_axes = seq
def reorient( def reorient(
self, self,
theta_degrees: float | None = None, theta_degrees: float | None = None,
phi_degrees: float | None = None, phi_degrees: float | None = None,
gamma_degrees: float | None = None, gamma_degrees: float | None = None,
center: Vect3 | tuple[float, float, float] | None = None,
height: float | None = None
): ):
""" """
Shortcut for set_euler_angles, defaulting to taking Shortcut for set_euler_angles, defaulting to taking
in angles in degrees in angles in degrees
""" """
self.set_euler_angles(theta_degrees, phi_degrees, gamma_degrees, units=DEGREES) self.set_euler_angles(theta_degrees, phi_degrees, gamma_degrees, units=DEGREES)
if center is not None:
self.move_to(np.array(center))
if height is not None:
self.set_height(height)
return self return self
def set_theta(self, theta: float): def set_theta(self, theta: float):
@ -152,16 +177,20 @@ class CameraFrame(Mobject):
def set_gamma(self, gamma: float): def set_gamma(self, gamma: float):
return self.set_euler_angles(gamma=gamma) return self.set_euler_angles(gamma=gamma)
def increment_theta(self, dtheta: float): def increment_theta(self, dtheta: float, units=RADIANS):
self.rotate(dtheta, OUT) self.increment_euler_angles(dtheta=dtheta, units=units)
return self return self
def increment_phi(self, dphi: float): def increment_phi(self, dphi: float, units=RADIANS):
self.rotate(dphi, self.get_inverse_camera_rotation_matrix()[0]) self.increment_euler_angles(dphi=dphi, units=units)
return self return self
def increment_gamma(self, dgamma: float): def increment_gamma(self, dgamma: float, units=RADIANS):
self.rotate(dgamma, self.get_inverse_camera_rotation_matrix()[2]) self.increment_euler_angles(dgamma=dgamma, units=units)
return self
def add_ambient_rotation(self, angular_speed=1 * DEGREES):
self.add_updater(lambda m, dt: m.increment_theta(angular_speed * dt))
return self return self
@Mobject.affects_data @Mobject.affects_data

View file

@ -6,7 +6,7 @@ import colour
import importlib import importlib
import inspect import inspect
import os import os
from screeninfo import get_monitors import screeninfo
import sys import sys
import yaml import yaml
@ -433,7 +433,10 @@ def get_file_writer_config(args: Namespace, custom_config: dict) -> dict:
def get_window_config(args: Namespace, custom_config: dict, camera_config: dict) -> dict: def get_window_config(args: Namespace, custom_config: dict, camera_config: dict) -> dict:
# Default to making window half the screen size # Default to making window half the screen size
# but make it full screen if -f is passed in # but make it full screen if -f is passed in
monitors = get_monitors() try:
monitors = screeninfo.get_monitors()
except screeninfo.ScreenInfoError:
pass
mon_index = custom_config["window_monitor"] mon_index = custom_config["window_monitor"]
monitor = monitors[min(mon_index, len(monitors) - 1)] monitor = monitors[min(mon_index, len(monitors) - 1)]
aspect_ratio = camera_config["pixel_width"] / camera_config["pixel_height"] aspect_ratio = camera_config["pixel_width"] / camera_config["pixel_height"]

View file

@ -167,3 +167,4 @@ class TracingTail(TracedPath):
stroke_color=stroke_color, stroke_color=stroke_color,
**kwargs **kwargs
) )
self.add_updater(lambda m: m.set_stroke(width=stroke_width, opacity=stroke_opacity))

View file

@ -603,9 +603,6 @@ class ThreeDAxes(Axes):
**kwargs **kwargs
) -> ParametricSurface: ) -> ParametricSurface:
surface = ParametricSurface(func, color=color, opacity=opacity, **kwargs) surface = ParametricSurface(func, color=color, opacity=opacity, **kwargs)
xu = self.x_axis.get_unit_size()
yu = self.y_axis.get_unit_size()
zu = self.z_axis.get_unit_size()
axes = [self.x_axis, self.y_axis, self.z_axis] axes = [self.x_axis, self.y_axis, self.z_axis]
for dim, axis in zip(range(3), axes): for dim, axis in zip(range(3), axes):
surface.stretch(axis.get_unit_size(), dim, about_point=ORIGIN) surface.stretch(axis.get_unit_size(), dim, about_point=ORIGIN)

View file

@ -4,85 +4,38 @@ import itertools as it
import numpy as np import numpy as np
from manimlib.constants import DEFAULT_MOBJECT_TO_MOBJECT_BUFFER from manimlib.constants import DOWN, LEFT, RIGHT, ORIGIN
from manimlib.constants import DOWN, LEFT, RIGHT, UP from manimlib.constants import DEGREES
from manimlib.constants import WHITE
from manimlib.mobject.numbers import DecimalNumber from manimlib.mobject.numbers import DecimalNumber
from manimlib.mobject.numbers import Integer
from manimlib.mobject.shape_matchers import BackgroundRectangle
from manimlib.mobject.svg.tex_mobject import Tex from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.svg.tex_mobject import TexText
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VMobject
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Sequence from typing import Sequence, Union, Tuple, Optional
import numpy.typing as npt from manimlib.typing import ManimColor, Vect3, VectNArray, Self
from manimlib.mobject.mobject import Mobject
from manimlib.typing import ManimColor, Vect3, Self
StringMatrixType = Union[Sequence[Sequence[str]], np.ndarray[int, np.dtype[np.str_]]]
VECTOR_LABEL_SCALE_FACTOR = 0.8 FloatMatrixType = Union[Sequence[Sequence[float]], VectNArray]
VMobjectMatrixType = Sequence[Sequence[VMobject]]
GenericMatrixType = Union[FloatMatrixType, StringMatrixType, VMobjectMatrixType]
def matrix_to_tex_string(matrix: npt.ArrayLike) -> str:
matrix = np.array(matrix).astype("str")
if matrix.ndim == 1:
matrix = matrix.reshape((matrix.size, 1))
n_rows, n_cols = matrix.shape
prefix = R"\left[ \begin{array}{%s}" % ("c" * n_cols)
suffix = R"\end{array} \right]"
rows = [
" & ".join(row)
for row in matrix
]
return prefix + R" \\ ".join(rows) + suffix
def matrix_to_mobject(matrix: npt.ArrayLike) -> Tex:
return Tex(matrix_to_tex_string(matrix))
def vector_coordinate_label(
vector_mob: VMobject,
integer_labels: bool = True,
n_dim: int = 2,
color: ManimColor = WHITE
) -> Matrix:
vect = np.array(vector_mob.get_end())
if integer_labels:
vect = np.round(vect).astype(int)
vect = vect[:n_dim]
vect = vect.reshape((n_dim, 1))
label = Matrix(vect, add_background_rectangles_to_entries=True)
label.scale(VECTOR_LABEL_SCALE_FACTOR)
shift_dir = np.array(vector_mob.get_end())
if shift_dir[0] >= 0: # Pointing right
shift_dir -= label.get_left() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER * LEFT
else: # Pointing left
shift_dir -= label.get_right() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER * RIGHT
label.shift(shift_dir)
label.set_color(color)
label.rect = BackgroundRectangle(label)
label.add_to_back(label.rect)
return label
class Matrix(VMobject): class Matrix(VMobject):
def __init__( def __init__(
self, self,
matrix: Sequence[Sequence[str | float | VMobject]], matrix: GenericMatrixType,
v_buff: float = 0.8, v_buff: float = 0.5,
h_buff: float = 1.0, h_buff: float = 0.5,
bracket_h_buff: float = 0.2, bracket_h_buff: float = 0.2,
bracket_v_buff: float = 0.25, bracket_v_buff: float = 0.25,
add_background_rectangles_to_entries: bool = False, height: float | None = None,
include_background_rectangle: bool = False, element_config: dict = dict(),
element_alignment_corner: Vect3 = DOWN, element_alignment_corner: Vect3 = DOWN,
**kwargs ellipses_row: Optional[int] = None,
ellipses_col: Optional[int] = None,
): ):
""" """
Matrix can either include numbers, tex_strings, Matrix can either include numbers, tex_strings,
@ -90,83 +43,109 @@ class Matrix(VMobject):
""" """
super().__init__() super().__init__()
mob_matrix = self.matrix_to_mob_matrix(matrix, **kwargs) self.mob_matrix = self.create_mobject_matrix(
self.mob_matrix = mob_matrix matrix, v_buff, h_buff, element_alignment_corner,
**element_config
)
self.organize_mob_matrix(mob_matrix, v_buff, h_buff, element_alignment_corner) # Create helpful groups for the elements
self.elements = VGroup(*it.chain(*mob_matrix)) n_cols = len(self.mob_matrix[0])
self.add(self.elements) self.elements = [elem for row in self.mob_matrix for elem in row]
self.add_brackets(bracket_v_buff, bracket_h_buff) self.columns = VGroup(*(
VGroup(*(row[i] for row in self.mob_matrix))
for i in range(n_cols)
))
self.rows = VGroup(*(VGroup(*row) for row in self.mob_matrix))
if height is not None:
self.rows.set_height(height - 2 * bracket_v_buff)
self.brackets = self.create_brackets(self.rows, bracket_v_buff, bracket_h_buff)
self.ellipses = []
# Add elements and brackets
self.add(*self.elements)
self.add(*self.brackets)
self.center() self.center()
if add_background_rectangles_to_entries:
for mob in self.elements:
mob.add_background_rectangle()
if include_background_rectangle:
self.add_background_rectangle()
# Potentially add ellipses
self.swap_entries_for_ellipses(
ellipses_row,
ellipses_col,
)
def element_to_mobject(self, element: str | float | VMobject, **config) -> VMobject: def copy(self, deep: bool = False):
if isinstance(element, VMobject): result = super().copy(deep)
return element self_family = self.get_family()
return Tex(str(element), **config) copy_family = result.get_family()
for attr in ["elements", "ellipses"]:
setattr(result, attr, [
copy_family[self_family.index(mob)]
for mob in getattr(self, attr)
])
return result
def matrix_to_mob_matrix( def create_mobject_matrix(
self, self,
matrix: Sequence[Sequence[str | float | VMobject]], matrix: GenericMatrixType,
**config
) -> list[list[VMobject]]:
return [
[
self.element_to_mobject(item, **config)
for item in row
]
for row in matrix
]
def organize_mob_matrix(
self,
matrix: list[list[VMobject]],
v_buff: float, v_buff: float,
h_buff: float, h_buff: float,
aligned_corner: Vect3, aligned_corner: Vect3,
) -> Self: **element_config
for i, row in enumerate(matrix): ) -> VMobjectMatrixType:
"""
Creates and organizes the matrix of mobjects
"""
mob_matrix = [
[
self.element_to_mobject(element, **element_config)
for element in row
]
for row in matrix
]
max_width = max(elem.get_width() for row in mob_matrix for elem in row)
max_height = max(elem.get_height() for row in mob_matrix for elem in row)
x_step = (max_width + h_buff) * RIGHT
y_step = (max_height + v_buff) * DOWN
for i, row in enumerate(mob_matrix):
for j, elem in enumerate(row): for j, elem in enumerate(row):
mob = matrix[i][j] elem.move_to(i * y_step + j * x_step, aligned_corner)
mob.move_to( return mob_matrix
i * v_buff * DOWN + j * h_buff * RIGHT,
aligned_corner
)
return self
def add_brackets(self, v_buff: float, h_buff: float) -> Self: def element_to_mobject(self, element, **config) -> VMobject:
height = len(self.mob_matrix) if isinstance(element, VMobject):
return element
elif isinstance(element, float | complex):
return DecimalNumber(element, **config)
else:
return Tex(str(element), **config)
def create_brackets(self, rows, v_buff: float, h_buff: float) -> VGroup:
brackets = Tex("".join(( brackets = Tex("".join((
R"\left[\begin{array}{c}", R"\left[\begin{array}{c}",
*height * [R"\quad \\"], *len(rows) * [R"\quad \\"],
R"\end{array}\right]", R"\end{array}\right]",
))) )))
brackets.set_height(self.get_height() + v_buff) brackets.set_height(rows.get_height() + v_buff)
l_bracket = brackets[:len(brackets) // 2] l_bracket = brackets[:len(brackets) // 2]
r_bracket = brackets[len(brackets) // 2:] r_bracket = brackets[len(brackets) // 2:]
l_bracket.next_to(self, LEFT, h_buff) l_bracket.next_to(rows, LEFT, h_buff)
r_bracket.next_to(self, RIGHT, h_buff) r_bracket.next_to(rows, RIGHT, h_buff)
brackets.set_submobjects([l_bracket, r_bracket]) return VGroup(l_bracket, r_bracket)
self.brackets = brackets
self.add(*brackets) def get_column(self, index: int):
return self if not 0 <= index < len(self.columns):
raise IndexError(f"Index {index} out of bound for matrix with {len(self.columns)} columns")
return self.columns[index]
def get_row(self, index: int):
if not 0 <= index < len(self.rows):
raise IndexError(f"Index {index} out of bound for matrix with {len(self.rows)} rows")
return self.rows[index]
def get_columns(self) -> VGroup: def get_columns(self) -> VGroup:
return VGroup(*[ return self.columns
VGroup(*[row[i] for row in self.mob_matrix])
for i in range(len(self.mob_matrix[0]))
])
def get_rows(self) -> VGroup: def get_rows(self) -> VGroup:
return VGroup(*[ return self.rows
VGroup(*row)
for row in self.mob_matrix
])
def set_column_colors(self, *colors: ManimColor) -> Self: def set_column_colors(self, *colors: ManimColor) -> Self:
columns = self.get_columns() columns = self.get_columns()
@ -179,61 +158,138 @@ class Matrix(VMobject):
mob.add_background_rectangle() mob.add_background_rectangle()
return self return self
def get_mob_matrix(self) -> list[list[Mobject]]: def swap_entry_for_dots(self, entry, dots):
dots.move_to(entry)
entry.become(dots)
if entry in self.elements:
self.elements.remove(entry)
if entry not in self.ellipses:
self.ellipses.append(entry)
def swap_entries_for_ellipses(
self,
row_index: Optional[int] = None,
col_index: Optional[int] = None,
height_ratio: float = 0.65,
width_ratio: float = 0.4
):
rows = self.get_rows()
cols = self.get_columns()
avg_row_height = rows.get_height() / len(rows)
vdots_height = height_ratio * avg_row_height
avg_col_width = cols.get_width() / len(cols)
hdots_width = width_ratio * avg_col_width
use_vdots = row_index is not None and -len(rows) <= row_index < len(rows)
use_hdots = col_index is not None and -len(cols) <= col_index < len(cols)
if use_vdots:
for column in cols:
# Add vdots
dots = Tex(R"\vdots")
dots.set_height(vdots_height)
self.swap_entry_for_dots(column[row_index], dots)
if use_hdots:
for row in rows:
# Add hdots
dots = Tex(R"\hdots")
dots.set_width(hdots_width)
self.swap_entry_for_dots(row[col_index], dots)
if use_vdots and use_hdots:
rows[row_index][col_index].rotate(-45 * DEGREES)
return self
def get_mob_matrix(self) -> VMobjectMatrixType:
return self.mob_matrix return self.mob_matrix
def get_entries(self) -> VGroup: def get_entries(self) -> VGroup:
return self.elements return VGroup(*self.elements)
def get_brackets(self) -> VGroup: def get_brackets(self) -> VGroup:
return self.brackets return VGroup(*self.brackets)
def get_ellipses(self) -> VGroup:
return VGroup(*self.ellipses)
class DecimalMatrix(Matrix): class DecimalMatrix(Matrix):
def element_to_mobject(self, element: float, num_decimal_places: int = 1, **config) -> DecimalNumber:
return DecimalNumber(element, num_decimal_places=num_decimal_places, **config)
class IntegerMatrix(Matrix):
def __init__( def __init__(
self, self,
matrix: npt.ArrayLike, matrix: FloatMatrixType,
element_alignment_corner: Vect3 = UP, num_decimal_places: int = 2,
**kwargs decimal_config: dict = dict(),
**config
): ):
super().__init__(matrix, element_alignment_corner=element_alignment_corner, **kwargs) self.float_matrix = matrix
super().__init__(
matrix,
element_config=dict(
num_decimal_places=num_decimal_places,
**decimal_config
),
**config
)
def element_to_mobject(self, element: int, **config) -> Integer: def element_to_mobject(self, element, **decimal_config) -> DecimalNumber:
return Integer(element, **config) return DecimalNumber(element, **decimal_config)
class IntegerMatrix(DecimalMatrix):
def __init__(
self,
matrix: FloatMatrixType,
num_decimal_places: int = 0,
decimal_config: dict = dict(),
**config
):
super().__init__(matrix, num_decimal_places, decimal_config, **config)
class TexMatrix(Matrix):
def __init__(
self,
matrix: StringMatrixType,
tex_config: dict = dict(),
**config,
):
super().__init__(
matrix,
element_config=tex_config,
**config
)
class MobjectMatrix(Matrix): class MobjectMatrix(Matrix):
def __init__(
self,
group: VGroup,
n_rows: int | None = None,
n_cols: int | None = None,
height: float = 4.0,
element_alignment_corner=ORIGIN,
**config,
):
# Have fallback defaults of n_rows and n_cols
n_mobs = len(group)
if n_rows is None:
n_rows = int(np.sqrt(n_mobs)) if n_cols is None else n_mobs // n_cols
if n_cols is None:
n_cols = n_mobs // n_rows
if len(group) < n_rows * n_cols:
raise Exception("Input to MobjectMatrix must have at least n_rows * n_cols entries")
mob_matrix = [
[group[n * n_cols + k] for k in range(n_cols)]
for n in range(n_rows)
]
config.update(
height=height,
element_alignment_corner=element_alignment_corner,
)
super().__init__(mob_matrix, **config)
def element_to_mobject(self, element: VMobject, **config) -> VMobject: def element_to_mobject(self, element: VMobject, **config) -> VMobject:
return element return element
def get_det_text(
matrix: Matrix,
determinant: int | str | None = None,
background_rect: bool = False,
initial_scale_factor: int = 2
) -> VGroup:
parens = Tex("()")
parens.scale(initial_scale_factor)
parens.stretch_to_fit_height(matrix.get_height())
l_paren, r_paren = parens.split()
l_paren.next_to(matrix, LEFT, buff=0.1)
r_paren.next_to(matrix, RIGHT, buff=0.1)
det = TexText("det")
det.scale(initial_scale_factor)
det.next_to(l_paren, LEFT, buff=0.1)
if background_rect:
det.add_background_rectangle()
det_text = VGroup(det, l_paren, r_paren)
if determinant is not None:
eq = Tex("=")
eq.next_to(r_paren, RIGHT, buff=0.1)
result = Tex(str(determinant))
result.next_to(eq, RIGHT, buff=0.2)
det_text.add(eq, result)
return det_text

View file

@ -39,20 +39,23 @@ from manimlib.utils.iterables import resize_with_interpolation
from manimlib.utils.bezier import integer_interpolate from manimlib.utils.bezier import integer_interpolate
from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import interpolate
from manimlib.utils.paths import straight_path from manimlib.utils.paths import straight_path
from manimlib.utils.simple_functions import get_parameters
from manimlib.utils.shaders import get_colormap_code from manimlib.utils.shaders import get_colormap_code
from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import angle_of_vector
from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import rotation_matrix_transpose from manimlib.utils.space_ops import rotation_matrix_transpose
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar, Generic, Iterable
SubmobjectType = TypeVar('SubmobjectType', bound='Mobject')
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, Iterator, Union, Tuple, Optional from typing import Callable, Iterator, Union, Tuple, Optional, Any
import numpy.typing as npt import numpy.typing as npt
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self
from moderngl.context import Context from moderngl.context import Context
T = TypeVar('T')
TimeBasedUpdater = Callable[["Mobject", float], "Mobject" | None] TimeBasedUpdater = Callable[["Mobject", float], "Mobject" | None]
NonTimeUpdater = Callable[["Mobject"], "Mobject" | None] NonTimeUpdater = Callable[["Mobject"], "Mobject" | None]
Updater = Union[TimeBasedUpdater, NonTimeUpdater] Updater = Union[TimeBasedUpdater, NonTimeUpdater]
@ -88,21 +91,20 @@ class Mobject(object):
self.opacity = opacity self.opacity = opacity
self.shading = shading self.shading = shading
self.texture_paths = texture_paths self.texture_paths = texture_paths
self._is_fixed_in_frame = is_fixed_in_frame
self.depth_test = depth_test self.depth_test = depth_test
# Internal state # Internal state
self.submobjects: list[Mobject] = [] self.submobjects: list[Mobject] = []
self.parents: list[Mobject] = [] self.parents: list[Mobject] = []
self.family: list[Mobject] = [self] self.family: list[Mobject] | None = [self]
self.locked_data_keys: set[str] = set() self.locked_data_keys: set[str] = set()
self.const_data_keys: set[str] = set() self.const_data_keys: set[str] = set()
self.locked_uniform_keys: set[str] = set() self.locked_uniform_keys: set[str] = set()
self.needs_new_bounding_box: bool = True
self._is_animating: bool = False
self.saved_state = None self.saved_state = None
self.target = None self.target = None
self.bounding_box: Vect3Array = np.zeros((3, 3)) self.bounding_box: Vect3Array = np.zeros((3, 3))
self._is_animating: bool = False
self._needs_new_bounding_box: bool = True
self._shaders_initialized: bool = False self._shaders_initialized: bool = False
self._data_has_changed: bool = True self._data_has_changed: bool = True
self.shader_code_replacements: dict[str, str] = dict() self.shader_code_replacements: dict[str, str] = dict()
@ -117,6 +119,8 @@ class Mobject(object):
if self.depth_test: if self.depth_test:
self.apply_depth_test() self.apply_depth_test()
if is_fixed_in_frame:
self.fix_in_frame()
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
@ -134,7 +138,7 @@ class Mobject(object):
def init_uniforms(self): def init_uniforms(self):
self.uniforms: UniformDict = { self.uniforms: UniformDict = {
"is_fixed_in_frame": float(self._is_fixed_in_frame), "is_fixed_in_frame": 0.0,
"shading": np.array(self.shading, dtype=float), "shading": np.array(self.shading, dtype=float),
} }
@ -154,9 +158,47 @@ class Mobject(object):
@property @property
def animate(self) -> _AnimationBuilder: def animate(self) -> _AnimationBuilder:
# Borrowed from https://github.com/ManimCommunity/manim/ """
Methods called with Mobject.animate.method() can be passed
into a Scene.play call, as if you were calling
ApplyMethod(mobject.method)
Borrowed from https://github.com/ManimCommunity/manim/
"""
return _AnimationBuilder(self) return _AnimationBuilder(self)
@property
def always(self) -> _UpdaterBuilder:
"""
Methods called with mobject.always.method(*args, **kwargs)
will result in the call mobject.method(*args, **kwargs)
on every frame
"""
return _UpdaterBuilder(self)
@property
def f_always(self) -> _FunctionalUpdaterBuilder:
"""
Similar to Mobject.always, but with the intent that arguments
are functions returning the corresponding type fit for the method
Methods called with
mobject.f_always.method(
func1, func2, ...,
kwarg1=kw_func1,
kwarg2=kw_func2,
...
)
will result in the call
mobject.method(
func1(), func2(), ...,
kwarg1=kw_func1(),
kwarg2=kw_func2(),
...
)
on every frame
"""
return _FunctionalUpdaterBuilder(self)
def note_changed_data(self, recurse_up: bool = True) -> Self: def note_changed_data(self, recurse_up: bool = True) -> Self:
self._data_has_changed = True self._data_has_changed = True
if recurse_up: if recurse_up:
@ -164,20 +206,23 @@ class Mobject(object):
mob.note_changed_data() mob.note_changed_data()
return self return self
def affects_data(func: Callable): @staticmethod
def affects_data(func: Callable[..., T]) -> Callable[..., T]:
@wraps(func) @wraps(func)
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs) result = func(self, *args, **kwargs)
self.note_changed_data() self.note_changed_data()
return result
return wrapper return wrapper
def affects_family_data(func: Callable): @staticmethod
def affects_family_data(func: Callable[..., T]) -> Callable[..., T]:
@wraps(func) @wraps(func)
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs) result = func(self, *args, **kwargs)
for mob in self.family_members_with_points(): for mob in self.family_members_with_points():
mob.note_changed_data() mob.note_changed_data()
return self return result
return wrapper return wrapper
# Only these methods should directly affect points # Only these methods should directly affect points
@ -285,9 +330,9 @@ class Mobject(object):
return len(self.get_points()) > 0 return len(self.get_points()) > 0
def get_bounding_box(self) -> Vect3Array: def get_bounding_box(self) -> Vect3Array:
if self.needs_new_bounding_box: if self._needs_new_bounding_box:
self.bounding_box[:] = self.compute_bounding_box() self.bounding_box[:] = self.compute_bounding_box()
self.needs_new_bounding_box = False self._needs_new_bounding_box = False
return self.bounding_box return self.bounding_box
def compute_bounding_box(self) -> Vect3Array: def compute_bounding_box(self) -> Vect3Array:
@ -314,7 +359,7 @@ class Mobject(object):
recurse_up: bool = True recurse_up: bool = True
) -> Self: ) -> Self:
for mob in self.get_family(recurse_down): for mob in self.get_family(recurse_down):
mob.needs_new_bounding_box = True mob._needs_new_bounding_box = True
if recurse_up: if recurse_up:
for parent in self.parents: for parent in self.parents:
parent.refresh_bounding_box() parent.refresh_bounding_box()
@ -347,7 +392,7 @@ class Mobject(object):
# Family matters # Family matters
def __getitem__(self, value: int | slice) -> Self: def __getitem__(self, value: int | slice) -> Mobject:
if isinstance(value, slice): if isinstance(value, slice):
GroupClass = self.get_group_class() GroupClass = self.get_group_class()
return GroupClass(*self.split().__getitem__(value)) return GroupClass(*self.split().__getitem__(value))
@ -363,23 +408,26 @@ class Mobject(object):
return self.submobjects return self.submobjects
@affects_data @affects_data
def assemble_family(self) -> Self: def note_changed_family(self, only_changed_order=False) -> Self:
sub_families = (sm.get_family() for sm in self.submobjects) self.family = None
self.family = [self, *it.chain(*sub_families)] if not only_changed_order:
self.refresh_has_updater_status() self.refresh_has_updater_status()
self.refresh_bounding_box() self.refresh_bounding_box()
for parent in self.parents: for parent in self.parents:
parent.assemble_family() parent.note_changed_family()
return self return self
def get_family(self, recurse: bool = True) -> list[Self]: def get_family(self, recurse: bool = True) -> list[Mobject]:
if recurse: if not recurse:
return self.family
else:
return [self] return [self]
if self.family is None:
# Reconstruct and save
sub_families = (sm.get_family() for sm in self.submobjects)
self.family = [self, *it.chain(*sub_families)]
return self.family
def family_members_with_points(self) -> list[Self]: def family_members_with_points(self) -> list[Mobject]:
return [m for m in self.family if len(m.data) > 0] return [m for m in self.get_family() if len(m.data) > 0]
def get_ancestors(self, extended: bool = False) -> list[Mobject]: def get_ancestors(self, extended: bool = False) -> list[Mobject]:
""" """
@ -410,7 +458,7 @@ class Mobject(object):
self.submobjects.append(mobject) self.submobjects.append(mobject)
if self not in mobject.parents: if self not in mobject.parents:
mobject.parents.append(self) mobject.parents.append(self)
self.assemble_family() self.note_changed_family()
return self return self
def remove( def remove(
@ -426,7 +474,7 @@ class Mobject(object):
if parent in child.parents: if parent in child.parents:
child.parents.remove(parent) child.parents.remove(parent)
if reassemble: if reassemble:
parent.assemble_family() parent.note_changed_family()
return self return self
def clear(self) -> Self: def clear(self) -> Self:
@ -443,12 +491,12 @@ class Mobject(object):
old_submob.parents.remove(self) old_submob.parents.remove(self)
self.submobjects[index] = new_submob self.submobjects[index] = new_submob
new_submob.parents.append(self) new_submob.parents.append(self)
self.assemble_family() self.note_changed_family()
return self return self
def insert_submobject(self, index: int, new_submob: Mobject) -> Self: def insert_submobject(self, index: int, new_submob: Mobject) -> Self:
self.submobjects.insert(index, new_submob) self.submobjects.insert(index, new_submob)
self.assemble_family() self.note_changed_family()
return self return self
def set_submobjects(self, submobject_list: list[Mobject]) -> Self: def set_submobjects(self, submobject_list: list[Mobject]) -> Self:
@ -495,12 +543,11 @@ class Mobject(object):
fill_rows_first: bool = True fill_rows_first: bool = True
) -> Self: ) -> Self:
submobs = self.submobjects submobs = self.submobjects
if n_rows is None and n_cols is None: n_submobs = len(submobs)
n_rows = int(np.sqrt(len(submobs)))
if n_rows is None: if n_rows is None:
n_rows = len(submobs) // n_cols n_rows = int(np.sqrt(n_submobs)) if n_cols is None else n_submobs // n_cols
if n_cols is None: if n_cols is None:
n_cols = len(submobs) // n_rows n_cols = n_submobs // n_rows
if buff is not None: if buff is not None:
h_buff = buff h_buff = buff
@ -561,7 +608,7 @@ class Mobject(object):
self.submobjects.sort(key=submob_func) self.submobjects.sort(key=submob_func)
else: else:
self.submobjects.sort(key=lambda m: point_to_num_func(m.get_center())) self.submobjects.sort(key=lambda m: point_to_num_func(m.get_center()))
self.assemble_family() self.note_changed_family(only_changed_order=True)
return self return self
def shuffle(self, recurse: bool = False) -> Self: def shuffle(self, recurse: bool = False) -> Self:
@ -569,17 +616,18 @@ class Mobject(object):
for submob in self.submobjects: for submob in self.submobjects:
submob.shuffle(recurse=True) submob.shuffle(recurse=True)
random.shuffle(self.submobjects) random.shuffle(self.submobjects)
self.assemble_family() self.note_changed_family(only_changed_order=True)
return self return self
def reverse_submobjects(self) -> Self: def reverse_submobjects(self) -> Self:
self.submobjects.reverse() self.submobjects.reverse()
self.assemble_family() self.note_changed_family(only_changed_order=True)
return self return self
# Copying and serialization # Copying and serialization
def stash_mobject_pointers(func: Callable): @staticmethod
def stash_mobject_pointers(func: Callable[..., T]) -> Callable[..., T]:
@wraps(func) @wraps(func)
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
uncopied_attrs = ["parents", "target", "saved_state"] uncopied_attrs = ["parents", "target", "saved_state"]
@ -637,8 +685,7 @@ class Mobject(object):
# Similarly, instead of calling match_updaters, since we know the status # Similarly, instead of calling match_updaters, since we know the status
# won't have changed, just directly match. # won't have changed, just directly match.
result.non_time_updaters = list(self.non_time_updaters) result.updaters = list(self.updaters)
result.time_based_updaters = list(self.time_based_updaters)
result._data_has_changed = True result._data_has_changed = True
result._shaders_initialized = False result._shaders_initialized = False
@ -646,7 +693,7 @@ class Mobject(object):
for attr, value in self.__dict__.items(): for attr, value in self.__dict__.items():
if isinstance(value, Mobject) and value is not self: if isinstance(value, Mobject) and value is not self:
if value in family: if value in family:
setattr(result, attr, result.family[self.family.index(value)]) setattr(result, attr, result.family[family.index(value)])
elif isinstance(value, np.ndarray): elif isinstance(value, np.ndarray):
setattr(result, attr, value.copy()) setattr(result, attr, value.copy())
return result return result
@ -698,7 +745,7 @@ class Mobject(object):
sm1.texture_paths = sm2.texture_paths sm1.texture_paths = sm2.texture_paths
sm1.depth_test = sm2.depth_test sm1.depth_test = sm2.depth_test
sm1.render_primitive = sm2.render_primitive sm1.render_primitive = sm2.render_primitive
sm1.needs_new_bounding_box = sm2.needs_new_bounding_box sm1._needs_new_bounding_box = sm2._needs_new_bounding_box
# Make sure named family members carry over # Make sure named family members carry over
for attr, value in list(mobject.__dict__.items()): for attr, value in list(mobject.__dict__.items()):
if isinstance(value, Mobject) and value in family2: if isinstance(value, Mobject) and value in family2:
@ -782,78 +829,57 @@ class Mobject(object):
# Updating # Updating
def init_updaters(self): def init_updaters(self):
self.time_based_updaters: list[TimeBasedUpdater] = [] self.updaters: list[Updater] = list()
self.non_time_updaters: list[NonTimeUpdater] = [] self._has_updaters_in_family: Optional[bool] = False
self.has_updaters: bool = False
self.updating_suspended: bool = False self.updating_suspended: bool = False
def update(self, dt: float = 0, recurse: bool = True) -> Self: def update(self, dt: float = 0, recurse: bool = True) -> Self:
if not self.has_updaters or self.updating_suspended: if not self.has_updaters() or self.updating_suspended:
return self return self
if recurse: if recurse:
for submob in self.submobjects: for submob in self.submobjects:
submob.update(dt, recurse) submob.update(dt, recurse)
for updater in self.time_based_updaters: for updater in self.updaters:
updater(self, dt) # This is hacky, but if an updater takes dt as an arg,
for updater in self.non_time_updaters: # it will be passed the change in time from here
updater(self) if "dt" in updater.__code__.co_varnames:
updater(self, dt=dt)
else:
updater(self)
return self return self
def get_time_based_updaters(self) -> list[TimeBasedUpdater]:
return self.time_based_updaters
def has_time_based_updater(self) -> bool:
return len(self.time_based_updaters) > 0
def get_updaters(self) -> list[Updater]: def get_updaters(self) -> list[Updater]:
return self.time_based_updaters + self.non_time_updaters return self.updaters
def get_family_updaters(self) -> list[Updater]: def add_updater(self, update_func: Updater, call: bool = True) -> Self:
return list(it.chain(*[sm.get_updaters() for sm in self.get_family()])) self.updaters.append(update_func)
if call:
def add_updater(
self,
update_function: Updater,
index: int | None = None,
call_updater: bool = True
) -> Self:
if "dt" in get_parameters(update_function):
updater_list = self.time_based_updaters
else:
updater_list = self.non_time_updaters
if index is None:
updater_list.append(update_function)
else:
updater_list.insert(index, update_function)
self.refresh_has_updater_status()
for parent in self.parents:
parent.has_updaters = True
if call_updater:
self.update(dt=0) self.update(dt=0)
self.refresh_has_updater_status()
return self return self
def remove_updater(self, update_function: Updater) -> Self: def insert_updater(self, update_func: Updater, index=0):
for updater_list in [self.time_based_updaters, self.non_time_updaters]: self.updaters.insert(index, update_func)
while update_function in updater_list: self.refresh_has_updater_status()
updater_list.remove(update_function) return self
def remove_updater(self, update_func: Updater) -> Self:
while update_func in self.updaters:
self.updaters.remove(update_func)
self.refresh_has_updater_status() self.refresh_has_updater_status()
return self return self
def clear_updaters(self, recurse: bool = True) -> Self: def clear_updaters(self, recurse: bool = True) -> Self:
self.time_based_updaters = [] for mob in self.get_family(recurse):
self.non_time_updaters = [] mob.updaters = []
if recurse: mob._has_updaters_in_family = False
for submob in self.submobjects: for parent in self.get_ancestors():
submob.clear_updaters() parent._has_updaters_in_family = False
self.refresh_has_updater_status()
return self return self
def match_updaters(self, mobject: Mobject) -> Self: def match_updaters(self, mobject: Mobject) -> Self:
self.clear_updaters() self.updaters = list(mobject.updaters)
for updater in mobject.get_updaters(): self.refresh_has_updater_status()
self.add_updater(updater)
return self return self
def suspend_updating(self, recurse: bool = True) -> Self: def suspend_updating(self, recurse: bool = True) -> Self:
@ -874,14 +900,24 @@ class Mobject(object):
self.update(dt=0, recurse=recurse) self.update(dt=0, recurse=recurse)
return self return self
def has_updaters(self) -> bool:
if self._has_updaters_in_family is None:
# Recompute and save
self._has_updaters_in_family = bool(self.updaters) or any(
sm.has_updaters() for sm in self.submobjects
)
return self._has_updaters_in_family
def refresh_has_updater_status(self) -> Self: def refresh_has_updater_status(self) -> Self:
self.has_updaters = any(mob.get_updaters() for mob in self.get_family()) self._has_updaters_in_family = None
for parent in self.parents:
parent.refresh_has_updater_status()
return self return self
# Check if mark as static or not for camera # Check if mark as static or not for camera
def is_changing(self) -> bool: def is_changing(self) -> bool:
return self._is_animating or self.has_updaters return self._is_animating or self.has_updaters()
def set_animating_status(self, is_animating: bool, recurse: bool = True) -> Self: def set_animating_status(self, is_animating: bool, recurse: bool = True) -> Self:
for mob in (*self.get_family(recurse), *self.get_ancestors()): for mob in (*self.get_family(recurse), *self.get_ancestors()):
@ -1368,7 +1404,7 @@ class Mobject(object):
return rgb_to_hex(self.data["rgba"][0, :3]) return rgb_to_hex(self.data["rgba"][0, :3])
def get_opacity(self) -> float: def get_opacity(self) -> float:
return self.data["rgba"][0, 3] return float(self.data["rgba"][0, 3])
def set_color_by_gradient(self, *colors: ManimColor) -> Self: def set_color_by_gradient(self, *colors: ManimColor) -> Self:
if self.has_points(): if self.has_points():
@ -1816,13 +1852,13 @@ class Mobject(object):
interpolate can skip this, and so that it's not interpolate can skip this, and so that it's not
read into the shader_wrapper objects needlessly read into the shader_wrapper objects needlessly
""" """
if self.has_updaters: if self.has_updaters():
return self return self
self.locked_data_keys = set(keys) self.locked_data_keys = set(keys)
return self return self
def lock_uniforms(self, keys: Iterable[str]) -> Self: def lock_uniforms(self, keys: Iterable[str]) -> Self:
if self.has_updaters: if self.has_updaters():
return self return self
self.locked_uniform_keys = set(keys) self.locked_uniform_keys = set(keys)
return self return self
@ -1864,7 +1900,8 @@ class Mobject(object):
# Operations touching shader uniforms # Operations touching shader uniforms
def affects_shader_info_id(func: Callable): @staticmethod
def affects_shader_info_id(func: Callable[..., T]) -> Callable[..., T]:
@wraps(func) @wraps(func)
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
result = func(self, *args, **kwargs) result = func(self, *args, **kwargs)
@ -2126,19 +2163,29 @@ class Mobject(object):
raise Exception(message.format(caller_name)) raise Exception(message.format(caller_name))
class Group(Mobject): class Group(Mobject, Generic[SubmobjectType]):
def __init__(self, *mobjects: Mobject, **kwargs): def __init__(self, *mobjects: SubmobjectType | Iterable[SubmobjectType], **kwargs):
if not all([isinstance(m, Mobject) for m in mobjects]): super().__init__(**kwargs)
raise Exception("All submobjects must be of type Mobject") self._ingest_args(*mobjects)
Mobject.__init__(self, **kwargs)
self.add(*mobjects) def _ingest_args(self, *args: Mobject | Iterable[Mobject]):
if any(m.is_fixed_in_frame() for m in mobjects): if len(args) == 0:
self.fix_in_frame() return
if all(isinstance(mob, Mobject) for mob in args):
self.add(*args)
elif isinstance(args[0], Iterable):
self.add(*args[0])
else:
raise Exception(f"Invalid argument to Group of type {type(args[0])}")
def __add__(self, other: Mobject | Group) -> Self: def __add__(self, other: Mobject | Group) -> Self:
assert(isinstance(other, Mobject)) assert isinstance(other, Mobject)
return self.add(other) return self.add(other)
# This is just here to make linters happy with references to things like Group(...)[0]
def __getitem__(self, index) -> SubmobjectType:
return super().__getitem__(index)
class Point(Mobject): class Point(Mobject):
def __init__( def __init__(
@ -2245,3 +2292,35 @@ def override_animate(method):
return animation_method return animation_method
return decorator return decorator
class _UpdaterBuilder:
def __init__(self, mobject: Mobject):
self.mobject = mobject
def __getattr__(self, method_name: str):
def add_updater(*method_args, **method_kwargs):
self.mobject.add_updater(
lambda m: getattr(m, method_name)(*method_args, **method_kwargs)
)
return self
return add_updater
class _FunctionalUpdaterBuilder:
def __init__(self, mobject: Mobject):
self.mobject = mobject
def __getattr__(self, method_name: str):
def add_updater(*method_args, **method_kwargs):
self.mobject.add_updater(
lambda m: getattr(m, method_name)(
*(arg() for arg in method_args),
**{
key: value()
for key, value in method_kwargs.items()
}
)
)
return self
return add_updater

View file

@ -16,7 +16,7 @@ from manimlib.utils.simple_functions import fdiv
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Iterable from typing import Iterable, Optional
from manimlib.typing import ManimColor, Vect3, Vect3Array, VectN, RangeSpecifier from manimlib.typing import ManimColor, Vect3, Vect3Array, VectN, RangeSpecifier
@ -28,13 +28,14 @@ class NumberLine(Line):
stroke_width: float = 2.0, stroke_width: float = 2.0,
# How big is one one unit of this number line in terms of absolute spacial distance # How big is one one unit of this number line in terms of absolute spacial distance
unit_size: float = 1.0, unit_size: float = 1.0,
width: float | None = None, width: Optional[float] = None,
include_ticks: bool = True, include_ticks: bool = True,
tick_size: float = 0.1, tick_size: float = 0.1,
longer_tick_multiple: float = 1.5, longer_tick_multiple: float = 1.5,
tick_offset: float = 0.0, tick_offset: float = 0.0,
# Change name # Change name
numbers_with_elongated_ticks: list[float] = [], big_tick_spacing: Optional[float] = None,
big_tick_numbers: list[float] = [],
include_numbers: bool = False, include_numbers: bool = False,
line_to_number_direction: Vect3 = DOWN, line_to_number_direction: Vect3 = DOWN,
line_to_number_buff: float = MED_SMALL_BUFF, line_to_number_buff: float = MED_SMALL_BUFF,
@ -54,7 +55,14 @@ class NumberLine(Line):
self.tick_size = tick_size self.tick_size = tick_size
self.longer_tick_multiple = longer_tick_multiple self.longer_tick_multiple = longer_tick_multiple
self.tick_offset = tick_offset self.tick_offset = tick_offset
self.numbers_with_elongated_ticks = list(numbers_with_elongated_ticks) if big_tick_spacing is not None:
self.big_tick_numbers = np.arange(
x_range[0],
x_range[1] + big_tick_spacing,
big_tick_spacing,
)
else:
self.big_tick_numbers = list(big_tick_numbers)
self.line_to_number_direction = line_to_number_direction self.line_to_number_direction = line_to_number_direction
self.line_to_number_buff = line_to_number_buff self.line_to_number_buff = line_to_number_buff
self.include_tip = include_tip self.include_tip = include_tip
@ -101,7 +109,7 @@ class NumberLine(Line):
ticks = VGroup() ticks = VGroup()
for x in self.get_tick_range(): for x in self.get_tick_range():
size = self.tick_size size = self.tick_size
if np.isclose(self.numbers_with_elongated_ticks, x).any(): if np.isclose(self.big_tick_numbers, x).any():
size *= self.longer_tick_multiple size *= self.longer_tick_multiple
ticks.add(self.get_tick(x, size)) ticks.add(self.get_tick(x, size))
self.add(ticks) self.add(ticks)
@ -210,7 +218,7 @@ class UnitInterval(NumberLine):
self, self,
x_range: RangeSpecifier = (0, 1, 0.1), x_range: RangeSpecifier = (0, 1, 0.1),
unit_size: float = 10, unit_size: float = 10,
numbers_with_elongated_ticks: list[float] = [0, 1], big_tick_numbers: list[float] = [0, 1],
decimal_number_config: dict = dict( decimal_number_config: dict = dict(
num_decimal_places=1, num_decimal_places=1,
) )
@ -218,6 +226,6 @@ class UnitInterval(NumberLine):
super().__init__( super().__init__(
x_range=x_range, x_range=x_range,
unit_size=unit_size, unit_size=unit_size,
numbers_with_elongated_ticks=numbers_with_elongated_ticks, big_tick_numbers=big_tick_numbers,
decimal_number_config=decimal_number_config, decimal_number_config=decimal_number_config,
) )

View file

@ -1,4 +1,5 @@
from __future__ import annotations from __future__ import annotations
from functools import lru_cache
import numpy as np import numpy as np
@ -17,6 +18,11 @@ if TYPE_CHECKING:
T = TypeVar("T", bound=VMobject) T = TypeVar("T", bound=VMobject)
@lru_cache()
def char_to_cahced_mob(char: str, **text_config):
return Text(char, **text_config)
class DecimalNumber(VMobject): class DecimalNumber(VMobject):
def __init__( def __init__(
self, self,
@ -46,7 +52,6 @@ class DecimalNumber(VMobject):
self.edge_to_fix = edge_to_fix self.edge_to_fix = edge_to_fix
self.font_size = font_size self.font_size = font_size
self.text_config = dict(text_config) self.text_config = dict(text_config)
self.char_to_mob_map = dict()
super().__init__( super().__init__(
color=color, color=color,
@ -59,36 +64,44 @@ class DecimalNumber(VMobject):
self.init_colors() self.init_colors()
def set_submobjects_from_number(self, number: float | complex) -> None: def set_submobjects_from_number(self, number: float | complex) -> None:
# Create the submobject list
self.number = number self.number = number
self.set_submobjects([]) self.num_string = self.get_num_string(number)
self.text_config["font_size"] = self.get_font_size()
num_string = self.num_string = self.get_num_string(number)
self.add(*map(self.char_to_mob, num_string))
# Add non-numerical bits # Submob_templates will be a list of cached Tex and Text mobjects,
# with the intent of calling .copy or .become on them
submob_templates = list(map(self.char_to_mob, self.num_string))
if self.show_ellipsis: if self.show_ellipsis:
dots = self.char_to_mob("...") dots = self.char_to_mob("...")
dots.arrange(RIGHT, buff=2 * dots[0].get_width()) dots.arrange(RIGHT, buff=2 * dots[0].get_width())
self.add(dots) submob_templates.append(dots)
if self.unit is not None: if self.unit is not None:
self.unit_sign = Tex(self.unit, font_size=self.get_font_size()) submob_templates.append(self.char_to_mob(self.unit))
self.add(self.unit_sign)
self.arrange( # Set internals
buff=self.digit_buff_per_font_unit * self.get_font_size(), font_size = self.get_font_size()
aligned_edge=DOWN if len(submob_templates) == len(self.submobjects):
) for sm, smt in zip(self.submobjects, submob_templates):
sm.become(smt)
sm.scale(font_size / smt.font_size)
else:
self.set_submobjects([
smt.copy().scale(font_size / smt.font_size)
for smt in submob_templates
])
# Handle alignment of parts that should be aligned digit_buff = self.digit_buff_per_font_unit * font_size
# to the bottom self.arrange(RIGHT, buff=digit_buff, aligned_edge=DOWN)
for i, c in enumerate(num_string):
if c == "" and len(num_string) > i + 1: # Handle alignment of special characters
for i, c in enumerate(self.num_string):
if c == "" and len(self.num_string) > i + 1:
self[i].align_to(self[i + 1], UP) self[i].align_to(self[i + 1], UP)
self[i].shift(self[i + 1].get_height() * DOWN / 2) self[i].shift(self[i + 1].get_height() * DOWN / 2)
elif c == ",": elif c == ",":
self[i].shift(self[i].get_height() * DOWN / 2) self[i].shift(self[i].get_height() * DOWN / 2)
if self.unit and self.unit.startswith("^"): if self.unit and self.unit.startswith("^"):
self.unit_sign.align_to(self, UP) self[-1].align_to(self, UP)
if self.include_background_rectangle: if self.include_background_rectangle:
self.add_background_rectangle() self.add_background_rectangle()
@ -111,12 +124,8 @@ class DecimalNumber(VMobject):
num_string = num_string.replace("-", "") num_string = num_string.replace("-", "")
return num_string return num_string
def char_to_mob(self, char: str) -> Tex | Text: def char_to_mob(self, char: str) -> Text:
if char not in self.char_to_mob_map: return char_to_cahced_mob(char, **self.text_config)
self.char_to_mob_map[char] = Text(char, **self.text_config)
result = self.char_to_mob_map[char].copy()
result.scale(self.get_font_size() / result.font_size)
return result
def init_uniforms(self) -> None: def init_uniforms(self) -> None:
super().init_uniforms() super().init_uniforms()
@ -171,7 +180,8 @@ class DecimalNumber(VMobject):
self.set_submobjects_from_number(number) self.set_submobjects_from_number(number)
self.move_to(move_to_point, self.edge_to_fix) self.move_to(move_to_point, self.edge_to_fix)
self.set_style(**style) self.set_style(**style)
self.fix_in_frame(self._is_fixed_in_frame) for submob in self.get_family():
submob.uniforms.update(self.uniforms)
return self return self
def _handle_scale_side_effects(self, scale_factor: float) -> Self: def _handle_scale_side_effects(self, scale_factor: float) -> Self:

View file

@ -30,6 +30,8 @@ class SurroundingRectangle(Rectangle):
super().__init__(color=color, **kwargs) super().__init__(color=color, **kwargs)
self.buff = buff self.buff = buff
self.surround(mobject) self.surround(mobject)
if mobject.is_fixed_in_frame():
self.fix_in_frame()
def surround(self, mobject, buff=None) -> Self: def surround(self, mobject, buff=None) -> Self:
self.mobject = mobject self.mobject = mobject
@ -120,13 +122,9 @@ class Underline(Line):
stretch_factor=1.2, stretch_factor=1.2,
**kwargs **kwargs
): ):
super().__init__( super().__init__(LEFT, RIGHT, **kwargs)
LEFT, RIGHT, if not isinstance(stroke_width, (float, int)):
stroke_color=stroke_color, self.insert_n_curves(len(stroke_width) - 2)
stroke_width=stroke_width,
**kwargs
)
self.insert_n_curves(30)
self.set_stroke(stroke_color, stroke_width) self.set_stroke(stroke_color, stroke_width)
self.set_width(mobject.get_width() * stretch_factor) self.set_width(mobject.get_width() * stretch_factor)
self.next_to(mobject, DOWN, buff=buff) self.next_to(mobject, DOWN, buff=buff)

View file

@ -91,8 +91,9 @@ class Brace(Tex):
return text_mob return text_mob
def get_tex(self, *tex: str, **kwargs) -> Tex: def get_tex(self, *tex: str, **kwargs) -> Tex:
tex_mob = Tex(*tex) buff = kwargs.pop("buff", SMALL_BUFF)
self.put_at_tip(tex_mob, **kwargs) tex_mob = Tex(*tex, **kwargs)
self.put_at_tip(tex_mob, buff=buff)
return tex_mob return tex_mob
def get_tip(self) -> np.ndarray: def get_tip(self) -> np.ndarray:

View file

@ -2,6 +2,7 @@ from __future__ import annotations
import numpy as np import numpy as np
import itertools as it import itertools as it
import random
from manimlib.animation.composition import AnimationGroup from manimlib.animation.composition import AnimationGroup
from manimlib.animation.rotation import Rotating from manimlib.animation.rotation import Rotating
@ -24,6 +25,7 @@ from manimlib.constants import LEFT
from manimlib.constants import LEFT from manimlib.constants import LEFT
from manimlib.constants import MED_LARGE_BUFF from manimlib.constants import MED_LARGE_BUFF
from manimlib.constants import MED_SMALL_BUFF from manimlib.constants import MED_SMALL_BUFF
from manimlib.constants import LARGE_BUFF
from manimlib.constants import ORIGIN from manimlib.constants import ORIGIN
from manimlib.constants import OUT from manimlib.constants import OUT
from manimlib.constants import PI from manimlib.constants import PI
@ -41,6 +43,7 @@ from manimlib.constants import WHITE
from manimlib.constants import YELLOW from manimlib.constants import YELLOW
from manimlib.constants import TAU from manimlib.constants import TAU
from manimlib.mobject.boolean_ops import Difference from manimlib.mobject.boolean_ops import Difference
from manimlib.mobject.boolean_ops import Union
from manimlib.mobject.geometry import Arc from manimlib.mobject.geometry import Arc
from manimlib.mobject.geometry import Circle from manimlib.mobject.geometry import Circle
from manimlib.mobject.geometry import Dot from manimlib.mobject.geometry import Dot
@ -51,6 +54,7 @@ from manimlib.mobject.geometry import Square
from manimlib.mobject.geometry import AnnularSector from manimlib.mobject.geometry import AnnularSector
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
from manimlib.mobject.numbers import Integer from manimlib.mobject.numbers import Integer
from manimlib.mobject.shape_matchers import SurroundingRectangle
from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.svg.tex_mobject import Tex from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.svg.tex_mobject import TexText from manimlib.mobject.svg.tex_mobject import TexText
@ -59,9 +63,13 @@ from manimlib.mobject.three_dimensions import Prismify
from manimlib.mobject.three_dimensions import VCube from manimlib.mobject.three_dimensions import VCube
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.svg.text_mobject import Text
from manimlib.utils.bezier import interpolate
from manimlib.utils.iterables import adjacent_pairs
from manimlib.utils.rate_functions import linear from manimlib.utils.rate_functions import linear
from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import angle_of_vector
from manimlib.utils.space_ops import compass_directions from manimlib.utils.space_ops import compass_directions
from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import midpoint from manimlib.utils.space_ops import midpoint
from manimlib.utils.space_ops import rotate_vector from manimlib.utils.space_ops import rotate_vector
@ -344,66 +352,76 @@ class ClockPassesTime(AnimationGroup):
) )
class Bubble(SVGMobject): class Bubble(VGroup):
file_name: str = "Bubbles_speech.svg" file_name: str = "Bubbles_speech.svg"
bubble_center_adjustment_factor = 0.125
def __init__( def __init__(
self, self,
content: str | VMobject | None = None,
buff: float = 1.0,
filler_shape: Tuple[float, float] = (3.0, 2.0),
pin_point: Vect3 | None = None,
direction: Vect3 = LEFT, direction: Vect3 = LEFT,
center_point: Vect3 = ORIGIN, add_content: bool = True,
content_scale_factor: float = 0.7,
height: float = 4.0,
width: float = 8.0,
max_height: float | None = None,
max_width: float | None = None,
bubble_center_adjustment_factor: float = 0.125,
fill_color: ManimColor = BLACK, fill_color: ManimColor = BLACK,
fill_opacity: float = 0.8, fill_opacity: float = 0.8,
stroke_color: ManimColor = WHITE, stroke_color: ManimColor = WHITE,
stroke_width: float = 3.0, stroke_width: float = 3.0,
**kwargs **kwargs
): ):
self.direction = LEFT # Possibly updated below by self.flip() super().__init__(**kwargs)
self.bubble_center_adjustment_factor = bubble_center_adjustment_factor self.direction = direction
self.content_scale_factor = content_scale_factor
super().__init__( if content is None:
fill_color=fill_color, content = Rectangle(*filler_shape)
fill_opacity=fill_opacity, content.set_fill(opacity=0)
stroke_color=stroke_color, content.set_stroke(width=0)
stroke_width=stroke_width, elif isinstance(content, str):
**kwargs content = Text(content)
) self.content = content
self.center() self.body = self.get_body(content, direction, buff)
self.set_height(height, stretch=True) self.body.set_fill(fill_color, fill_opacity)
self.set_width(width, stretch=True) self.body.set_stroke(stroke_color, stroke_width)
if max_height: self.add(self.body)
self.set_max_height(max_height)
if max_width: if add_content:
self.set_max_width(max_width) self.add(self.content)
if pin_point is not None:
self.pin_to(pin_point)
def get_body(self, content: VMobject, direction: Vect3, buff: float) -> VMobject:
body = SVGMobject(self.file_name)
if direction[0] > 0: if direction[0] > 0:
self.flip() body.flip()
# Resize
self.content = VMobject() width = content.get_width()
height = content.get_height()
target_width = width + min(buff, height)
target_height = 1.35 * (height + buff) # Magic number?
body.set_shape(target_width, target_height)
body.move_to(content)
body.shift(self.bubble_center_adjustment_factor * body.get_height() * DOWN)
return body
def get_tip(self): def get_tip(self):
# TODO, find a better way return self.get_corner(DOWN + self.direction)
return self.get_corner(DOWN + self.direction) - 0.6 * self.direction
def get_bubble_center(self): def get_bubble_center(self):
factor = self.bubble_center_adjustment_factor factor = self.bubble_center_adjustment_factor
return self.get_center() + factor * self.get_height() * UP return self.get_center() + factor * self.get_height() * UP
def move_tip_to(self, point): def move_tip_to(self, point):
mover = VGroup(self) self.shift(point - self.get_tip())
if self.content is not None:
mover.add(self.content)
mover.shift(point - self.get_tip())
return self return self
def flip(self, axis=UP): def flip(self, axis=UP, only_body=True, **kwargs):
super().flip(axis=axis) super().flip(axis=axis, **kwargs)
if only_body:
# Flip in place, don't use kwargs
self.content.flip(axis=axis)
if abs(axis[1]) > 0: if abs(axis[1]) > 0:
self.direction = -np.array(self.direction) self.direction = -np.array(self.direction)
return self return self
@ -418,9 +436,9 @@ class Bubble(SVGMobject):
self.move_tip_to(mob_center + vector_from_center) self.move_tip_to(mob_center + vector_from_center)
return self return self
def position_mobject_inside(self, mobject): def position_mobject_inside(self, mobject, buff=MED_LARGE_BUFF):
mobject.set_max_width(self.content_scale_factor * self.get_width()) mobject.set_max_width(self.body.get_width() - 2 * buff)
mobject.set_max_height(self.content_scale_factor * self.get_height() / 1.5) mobject.set_max_height(self.body.get_height() / 1.5 - 2 * buff)
mobject.shift(self.get_bubble_center() - mobject.get_center()) mobject.shift(self.get_bubble_center() - mobject.get_center())
return mobject return mobject
@ -429,26 +447,110 @@ class Bubble(SVGMobject):
self.content = mobject self.content = mobject
return self.content return self.content
def write(self, *text): def write(self, text):
self.add_content(TexText(*text)) self.add_content(Text(text))
return self return self
def resize_to_content(self, buff=0.75): def resize_to_content(self, buff=1.0): # TODO
width = self.content.get_width() self.body.match_points(self.get_body(
height = self.content.get_height() self.content, self.direction, buff
target_width = width + min(buff, height) ))
target_height = 1.35 * (self.content.get_height() + buff)
tip_point = self.get_tip()
self.stretch_to_fit_width(target_width, about_point=tip_point)
self.stretch_to_fit_height(target_height, about_point=tip_point)
self.position_mobject_inside(self.content)
def clear(self): def clear(self):
self.add_content(VMobject()) self.remove(self.content)
return self return self
class SpeechBubble(Bubble): class SpeechBubble(Bubble):
def __init__(
self,
content: str | VMobject | None = None,
buff: float = MED_SMALL_BUFF,
filler_shape: Tuple[float, float] = (2.0, 1.0),
stem_height_to_bubble_height: float = 0.5,
stem_top_x_props: Tuple[float, float] = (0.2, 0.3),
**kwargs
):
self.stem_height_to_bubble_height = stem_height_to_bubble_height
self.stem_top_x_props = stem_top_x_props
super().__init__(content, buff, filler_shape, **kwargs)
def get_body(self, content: VMobject, direction: Vect3, buff: float) -> VMobject:
rect = SurroundingRectangle(content, buff=buff)
rect.round_corners()
lp = rect.get_corner(DL)
rp = rect.get_corner(DR)
stem_height = self.stem_height_to_bubble_height * rect.get_height()
low_prop, high_prop = self.stem_top_x_props
triangle = Polygon(
interpolate(lp, rp, low_prop),
interpolate(lp, rp, high_prop),
lp + stem_height * DOWN,
)
result = Union(rect, triangle)
result.insert_n_curves(20)
if direction[0] > 0:
result.flip()
return result
class ThoughtBubble(Bubble):
def __init__(
self,
content: str | VMobject | None = None,
buff: float = SMALL_BUFF,
filler_shape: Tuple[float, float] = (2.0, 1.0),
bulge_radius: float = 0.35,
bulge_overlap: float = 0.25,
noise_factor: float = 0.1,
circle_radii: list[float] = [0.1, 0.15, 0.2],
**kwargs
):
self.bulge_radius = bulge_radius
self.bulge_overlap = bulge_overlap
self.noise_factor = noise_factor
self.circle_radii = circle_radii
super().__init__(content, buff, filler_shape, **kwargs)
def get_body(self, content: VMobject, direction: Vect3, buff: float) -> VMobject:
rect = SurroundingRectangle(content, buff)
perimeter = rect.get_arc_length()
radius = self.bulge_radius
step = (1 - self.bulge_overlap) * (2 * radius)
nf = self.noise_factor
corners = [rect.get_corner(v) for v in [DL, UL, UR, DR]]
points = []
for c1, c2 in adjacent_pairs(corners):
n_alphas = int(get_norm(c1 - c2) / step) + 1
for alpha in np.linspace(0, 1, n_alphas):
points.append(interpolate(
c1, c2, alpha + nf * (step / n_alphas) * (random.random() - 0.5)
))
cloud = Union(rect, *(
# Add bulges
Circle(radius=radius * (1 + nf * random.random())).move_to(point)
for point in points
))
cloud.set_stroke(WHITE, 2)
circles = VGroup(Circle(radius=radius) for radius in self.circle_radii)
circ_buff = 0.25 * self.circle_radii[0]
circles.arrange(UR, buff=circ_buff)
circles[1].shift(circ_buff * DR)
circles.next_to(cloud, DOWN, 4 * circ_buff, aligned_edge=LEFT)
circles.set_stroke(WHITE, 2)
result = VGroup(*circles, cloud)
if direction[0] > 0:
result.flip()
return result
class OldSpeechBubble(Bubble):
file_name: str = "Bubbles_speech.svg" file_name: str = "Bubbles_speech.svg"
@ -456,17 +558,16 @@ class DoubleSpeechBubble(Bubble):
file_name: str = "Bubbles_double_speech.svg" file_name: str = "Bubbles_double_speech.svg"
class ThoughtBubble(Bubble): class OldThoughtBubble(Bubble):
file_name: str = "Bubbles_thought.svg" file_name: str = "Bubbles_thought.svg"
def __init__(self, **kwargs): def get_body(self, content: VMobject, direction: Vect3, buff: float) -> VMobject:
Bubble.__init__(self, **kwargs) body = super().get_body(content, direction, buff)
self.submobjects.sort( body.sort(lambda p: p[1])
key=lambda m: m.get_bottom()[1] return body
)
def make_green_screen(self): def make_green_screen(self):
self.submobjects[-1].set_fill(GREEN_SCREEN, opacity=1) self.body[-1].set_fill(GREEN_SCREEN, opacity=1)
return self return self

View file

@ -231,7 +231,7 @@ class Tex(StringMobject):
)) ))
return re.findall(pattern, self.string) return re.findall(pattern, self.string)
def make_number_changable( def make_number_changeable(
self, self,
value: float | int | str, value: float | int | str,
index: int = 0, index: int = 0,
@ -241,7 +241,7 @@ class Tex(StringMobject):
substr = str(value) substr = str(value)
parts = self.select_parts(substr) parts = self.select_parts(substr)
if len(parts) == 0: if len(parts) == 0:
log.warning(f"{value} not found in Tex.make_number_changable call") log.warning(f"{value} not found in Tex.make_number_changeable call")
return VMobject() return VMobject()
if index > len(parts) - 1: if index > len(parts) - 1:
log.warning(f"Requested {index}th occurance of {value}, but only {len(parts)} exist") log.warning(f"Requested {index}th occurance of {value}, but only {len(parts)} exist")

View file

@ -71,5 +71,5 @@ class ImageMobject(Mobject):
rgb = self.image.getpixel(( rgb = self.image.getpixel((
int((pw - 1) * x_alpha), int((pw - 1) * x_alpha),
int((ph - 1) * y_alpha), int((ph - 1) * y_alpha),
)) ))[:3]
return np.array(rgb) / 255 return np.array(rgb) / 255

View file

@ -133,20 +133,28 @@ class Surface(Mobject):
if len(indices) == 0: if len(indices) == 0:
return np.zeros((3, 0)) return np.zeros((3, 0))
left = indices - 1 # For each point, find two adjacent points at indices
right = indices + 1 # step1 and step2, such that crossing points[step1] - points
up = indices - nv # with points[step1] - points gives a normal vector
down = indices + nv step1 = indices + 1
step2 = indices + nu
left[0] = indices[0] # Right edge
right[-1] = indices[-1] step1[nu - 1::nu] = indices[nu - 1::nu] + nu
up[:nv] = indices[:nv] step2[nu - 1::nu] = indices[nu - 1::nu] - 1
down[-nv:] = indices[-nv:]
# Bottom edge
step1[-nu:] = indices[-nu:] - nu
step2[-nu:] = indices[-nu:] + 1
# Lower right point
step1[-1] = indices[-1] - 1
step2[-1] = indices[-1] - nu
points = self.get_points() points = self.get_points()
crosses = cross( crosses = cross(
points[right] - points[left], points[step1] - points,
points[up] - points[down], points[step2] - points,
) )
self.data["normal"] = normalize_along_axis(crosses, 1) self.data["normal"] = normalize_along_axis(crosses, 1)
return self.data["normal"] return self.data["normal"]

View file

@ -13,6 +13,7 @@ from manimlib.constants import JOINT_TYPE_MAP
from manimlib.constants import ORIGIN, OUT from manimlib.constants import ORIGIN, OUT
from manimlib.constants import TAU from manimlib.constants import TAU
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
from manimlib.mobject.mobject import Point from manimlib.mobject.mobject import Point
from manimlib.utils.bezier import bezier from manimlib.utils.bezier import bezier
from manimlib.utils.bezier import get_quadratic_approximation_of_cubic from manimlib.utils.bezier import get_quadratic_approximation_of_cubic
@ -40,7 +41,6 @@ from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import get_unit_normal from manimlib.utils.space_ops import get_unit_normal
from manimlib.utils.space_ops import line_intersects_path from manimlib.utils.space_ops import line_intersects_path
from manimlib.utils.space_ops import midpoint from manimlib.utils.space_ops import midpoint
from manimlib.utils.space_ops import normalize_along_axis
from manimlib.utils.space_ops import rotation_between_vectors from manimlib.utils.space_ops import rotation_between_vectors
from manimlib.utils.space_ops import poly_line_length from manimlib.utils.space_ops import poly_line_length
from manimlib.utils.space_ops import z_to_vector from manimlib.utils.space_ops import z_to_vector
@ -48,15 +48,18 @@ from manimlib.shader_wrapper import ShaderWrapper
from manimlib.shader_wrapper import FillShaderWrapper from manimlib.shader_wrapper import FillShaderWrapper
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Generic, TypeVar, Iterable
SubVmobjectType = TypeVar('SubVmobjectType', bound='VMobject')
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, Tuple, Any from typing import Callable, Tuple, Any
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array, Self from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array, Self
from moderngl.context import Context from moderngl.context import Context
DEFAULT_STROKE_COLOR = GREY_A DEFAULT_STROKE_COLOR = GREY_A
DEFAULT_FILL_COLOR = GREY_C DEFAULT_FILL_COLOR = GREY_C
class VMobject(Mobject): class VMobject(Mobject):
fill_shader_folder: str = "quadratic_bezier_fill" fill_shader_folder: str = "quadratic_bezier_fill"
stroke_shader_folder: str = "quadratic_bezier_stroke" stroke_shader_folder: str = "quadratic_bezier_stroke"
@ -97,7 +100,7 @@ class VMobject(Mobject):
flat_stroke: bool = True, flat_stroke: bool = True,
use_simple_quadratic_approx: bool = False, use_simple_quadratic_approx: bool = False,
# Measured in pixel widths # Measured in pixel widths
anti_alias_width: float = 1.0, anti_alias_width: float = 1.5,
fill_border_width: float = 0.5, fill_border_width: float = 0.5,
use_winding_fill: bool = True, use_winding_fill: bool = True,
**kwargs **kwargs
@ -187,9 +190,10 @@ class VMobject(Mobject):
recurse: bool = True recurse: bool = True
) -> Self: ) -> Self:
self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse) self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse)
if border_width is not None: if border_width is None:
for mob in self.get_family(recurse): border_width = 0 if self.get_fill_opacity() < 1 else 0.5
mob.data["fill_border_width"] = border_width for mob in self.get_family(recurse):
mob.data["fill_border_width"] = border_width
self.note_changed_fill() self.note_changed_fill()
return self return self
@ -1415,17 +1419,23 @@ class VMobject(Mobject):
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0] return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
class VGroup(VMobject): class VGroup(Group, VMobject, Generic[SubVmobjectType]):
def __init__(self, *vmobjects: VMobject, **kwargs): def __init__(self, *vmobjects: SubVmobjectType | Iterable[SubVmobjectType], **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.add(*vmobjects) if any(isinstance(vmob, Mobject) and not isinstance(vmob, VMobject) for vmob in vmobjects):
if vmobjects: raise Exception("Only VMobjects can be passed into VGroup")
self.uniforms.update(vmobjects[0].uniforms) self._ingest_args(*vmobjects)
if self.submobjects:
self.uniforms.update(self.submobjects[0].uniforms)
def __add__(self, other: VMobject) -> Self: def __add__(self, other: VMobject) -> Self:
assert(isinstance(other, VMobject)) assert isinstance(other, VMobject)
return self.add(other) return self.add(other)
# This is just here to make linters happy with references to things like VGroup(...)[0]
def __getitem__(self, index) -> SubVmobjectType:
return super().__getitem__(index)
class VectorizedPoint(Point, VMobject): class VectorizedPoint(Point, VMobject):
def __init__( def __init__(

View file

@ -48,6 +48,7 @@ RESIZE_KEY = 't'
COLOR_KEY = 'c' COLOR_KEY = 'c'
INFORMATION_KEY = 'i' INFORMATION_KEY = 'i'
CURSOR_KEY = 'k' CURSOR_KEY = 'k'
COPY_FRAME_POSITION_KEY = 'p'
# Note, a lot of the functionality here is still buggy and very much a work in progress. # Note, a lot of the functionality here is still buggy and very much a work in progress.
@ -504,8 +505,8 @@ class InteractiveScene(Scene):
self.toggle_selection_mode() self.toggle_selection_mode()
elif char == "s" and modifiers == COMMAND_MODIFIER: elif char == "s" and modifiers == COMMAND_MODIFIER:
self.save_selection_to_file() self.save_selection_to_file()
elif char == PAN_3D_KEY and modifiers == COMMAND_MODIFIER: elif char == "d" and modifiers == SHIFT_MODIFIER:
self.copy_frame_anim_call() self.copy_frame_positioning()
elif symbol in ARROW_SYMBOLS: elif symbol in ARROW_SYMBOLS:
self.nudge_selection( self.nudge_selection(
vect=[LEFT, UP, RIGHT, DOWN][ARROW_SYMBOLS.index(symbol)], vect=[LEFT, UP, RIGHT, DOWN][ARROW_SYMBOLS.index(symbol)],
@ -615,16 +616,18 @@ class InteractiveScene(Scene):
self.clear_selection() self.clear_selection()
# Copying code to recreate state # Copying code to recreate state
def copy_frame_anim_call(self): def copy_frame_positioning(self):
frame = self.frame frame = self.frame
center = frame.get_center() center = frame.get_center()
height = frame.get_height() height = frame.get_height()
angles = frame.get_euler_angles() angles = frame.get_euler_angles()
call = f"self.frame.animate.reorient" call = f"reorient("
call += str(tuple((angles / DEGREES).astype(int))) theta, phi, gamma = (angles / DEGREES).astype(int)
call += f"{theta}, {phi}, {gamma}"
if any(center != 0): if any(center != 0):
call += f".move_to({list(np.round(center, 2))})" call += f", {tuple(np.round(center, 2))}"
if height != FRAME_HEIGHT: if height != FRAME_HEIGHT:
call += ".set_height({:.2f})".format(height) call += ", {:.2f}".format(height)
call += ")"
pyperclip.copy(call) pyperclip.copy(call)

View file

@ -7,6 +7,7 @@ import platform
import pyperclip import pyperclip
import random import random
import time import time
import re
from functools import wraps from functools import wraps
from IPython.terminal import pt_inputhooks from IPython.terminal import pt_inputhooks
@ -44,9 +45,11 @@ from manimlib.utils.iterables import batch_by_property
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable from typing import Callable, Iterable, TypeVar
from manimlib.typing import Vect3 from manimlib.typing import Vect3
T = TypeVar('T')
from PIL.Image import Image from PIL.Image import Image
from manimlib.animation.animation import Animation from manimlib.animation.animation import Animation
@ -210,7 +213,8 @@ class Scene(object):
show_animation_progress: bool = False, show_animation_progress: bool = False,
) -> None: ) -> None:
if not self.preview: if not self.preview:
return # Embed is only relevant with a preview # Embed is only relevant with a preview
return
self.stop_skipping() self.stop_skipping()
self.update_frame() self.update_frame()
self.save_state() self.save_state()
@ -236,6 +240,8 @@ class Scene(object):
i2g=self.i2g, i2g=self.i2g,
i2m=self.i2m, i2m=self.i2m,
checkpoint_paste=self.checkpoint_paste, checkpoint_paste=self.checkpoint_paste,
touch=lambda: shell.enable_gui("manim"),
notouch=lambda: shell.enable_gui(None),
) )
# Enables gui interactions during the embed # Enables gui interactions during the embed
@ -257,20 +263,19 @@ class Scene(object):
# namespace, since this is just a shell session anyway. # namespace, since this is just a shell session anyway.
shell.events.register( shell.events.register(
"pre_run_cell", "pre_run_cell",
lambda: shell.user_global_ns.update(shell.user_ns) lambda *args, **kwargs: shell.user_global_ns.update(shell.user_ns)
) )
# Operation to run after each ipython command # Operation to run after each ipython command
def post_cell_func(): def post_cell_func(*args, **kwargs):
if not self.is_window_closing(): if not self.is_window_closing():
self.update_frame(dt=0, ignore_skipping=True) self.update_frame(dt=0, ignore_skipping=True)
self.save_state()
shell.events.register("post_run_cell", post_cell_func) shell.events.register("post_run_cell", post_cell_func)
# Flash border, and potentially play sound, on exceptions # Flash border, and potentially play sound, on exceptions
def custom_exc(shell, etype, evalue, tb, tb_offset=None): def custom_exc(shell, etype, evalue, tb, tb_offset=None):
# still show the error don't just swallow it # Show the error don't just swallow it
shell.showtraceback((etype, evalue, tb), tb_offset=tb_offset) shell.showtraceback((etype, evalue, tb), tb_offset=tb_offset)
if self.embed_error_sound: if self.embed_error_sound:
os.system("printf '\a'") os.system("printf '\a'")
@ -342,17 +347,9 @@ class Scene(object):
mobject.update(dt) mobject.update(dt)
def should_update_mobjects(self) -> bool: def should_update_mobjects(self) -> bool:
return self.always_update_mobjects or any([ return self.always_update_mobjects or any(
len(mob.get_family_updaters()) > 0 mob.has_updaters() for mob in self.mobjects
for mob in self.mobjects )
])
def has_time_based_updaters(self) -> bool:
return any([
sm.has_time_based_updater()
for mob in self.mobjects()
for sm in mob.get_family()
])
# Related to time # Related to time
@ -399,7 +396,8 @@ class Scene(object):
for batch, key in batches for batch, key in batches
] ]
def affects_mobject_list(func: Callable): @staticmethod
def affects_mobject_list(func: Callable[..., T]) -> Callable[..., T]:
@wraps(func) @wraps(func)
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs) func(self, *args, **kwargs)
@ -774,13 +772,31 @@ class Scene(object):
) )
pasted = pyperclip.paste() pasted = pyperclip.paste()
line0 = pasted.lstrip().split("\n")[0] lines = pasted.split("\n")
if line0.startswith("#"):
if line0 not in self.checkpoint_states:
self.checkpoint(line0)
else:
self.revert_to_checkpoint(line0)
# Commented lines trigger saved checkpoints
if lines[0].lstrip().startswith("#"):
if lines[0] not in self.checkpoint_states:
self.checkpoint(lines[0])
else:
self.revert_to_checkpoint(lines[0])
# Copied methods of a scene are handled specially
# A bit hacky, yes, but convenient
method_pattern = r"^def\s+([a-zA-Z_]\w*)\s*\(self.*\):"
method_names = re.findall(method_pattern ,lines[0].strip())
if method_names:
method_name = method_names[0]
indent = " " * lines[0].index(lines[0].strip())
pasted = "\n".join([
# Remove self from function signature
re.sub(r"self(,\s*)?", "", lines[0]),
*lines[1:],
# Attach to scene via self.func_name = func_name
f"{indent}self.{method_name} = {method_name}"
])
# Keep track of skipping and progress bar status
prev_skipping = self.skip_animations prev_skipping = self.skip_animations
self.skip_animations = skip self.skip_animations = skip
@ -836,6 +852,13 @@ class Scene(object):
return self.window and (self.window.is_closing or self.quit_interaction) return self.window and (self.window.is_closing or self.quit_interaction)
# Event handling # Event handling
def set_floor_plane(self, plane: str = "xy"):
if plane == "xy":
self.frame.set_euler_axes("zxz")
elif plane == "xz":
self.frame.set_euler_axes("zxy")
else:
raise Exception("Only `xz` and `xy` are valid floor planes")
def on_mouse_motion( def on_mouse_motion(
self, self,
@ -1023,7 +1046,7 @@ class ThreeDScene(Scene):
default_frame_orientation = (-30, 70) default_frame_orientation = (-30, 70)
always_depth_test = True always_depth_test = True
def add(self, *mobjects, set_depth_test: bool = True): def add(self, *mobjects: Mobject, set_depth_test: bool = True):
for mob in mobjects: for mob in mobjects:
if set_depth_test and not mob.is_fixed_in_frame() and self.always_depth_test: if set_depth_test and not mob.is_fixed_in_frame() and self.always_depth_test:
mob.apply_depth_test() mob.apply_depth_test()

View file

@ -9,9 +9,8 @@ const float Y_SCALE = 2.0 / DEFAULT_FRAME_HEIGHT;
void emit_gl_Position(vec3 point){ void emit_gl_Position(vec3 point){
vec4 result = vec4(point, 1.0); vec4 result = vec4(point, 1.0);
if(!bool(is_fixed_in_frame)){ // This allow for smooth transitions between objects fixed and unfixed from frame
result = view * result; result = mix(view * result, result, is_fixed_in_frame);
}
// Essentially a projection matrix // Essentially a projection matrix
result.x *= X_SCALE; result.x *= X_SCALE;
result.y *= Y_SCALE; result.y *= Y_SCALE;

View file

@ -1,23 +0,0 @@
uniform float is_fixed_in_frame;
uniform mat4 view;
uniform float focal_distance;
const float DEFAULT_FRAME_HEIGHT = 8.0;
const float ASPECT_RATIO = 16.0 / 9.0;
const float X_SCALE = 2.0 / DEFAULT_FRAME_HEIGHT / ASPECT_RATIO;
const float Y_SCALE = 2.0 / DEFAULT_FRAME_HEIGHT;
void emit_gl_Position(vec3 point){
vec4 result = vec4(point, 1.0);
if(!bool(is_fixed_in_frame)){
result = view * result;
}
// Essentially a projection matrix
result.x *= X_SCALE;
result.y *= Y_SCALE;
result.z /= focal_distance;
result.w = 1.0 - result.z;
// Flip and scale to prevent premature clipping
result.z *= -0.1;
gl_Position = result;
}

View file

@ -1,66 +1,20 @@
#version 330 #version 330
in vec2 uv_coords; // Value between -1 and 1
in float scaled_signed_dist_to_curve;
in float uv_stroke_width; in float scaled_anti_alias_width;
in float uv_anti_alias_width;
in vec4 color; in vec4 color;
in float is_linear;
out vec4 frag_color; out vec4 frag_color;
const float QUICK_DIST_WIDTH = 0.2;
float dist_to_curve(){
// In the linear case, the curve will have
// been set to equal the x axis
if(bool(is_linear)) return abs(uv_coords.y);
// Otherwise, find the distance from uv_coords to the curve y = x^2
float x0 = uv_coords.x;
float y0 = uv_coords.y;
// This is a quick approximation for computing
// the distance to the curve.
// Evaluate F(x, y) = y - x^2
// divide by its gradient's magnitude
float Fxy = y0 - x0 * x0;
float approx_dist = abs(Fxy) * inversesqrt(1.0 + 4 * x0 * x0);
if(approx_dist < QUICK_DIST_WIDTH) return approx_dist;
// Otherwise, solve for the minimal distance.
// The distance squared between (x0, y0) and a point (x, x^2) looks like
//
// (x0 - x)^2 + (y0 - x^2)^2 = x^4 + (1 - 2y0)x^2 - 2x0 * x + (x0^2 + y0^2)
//
// Setting the derivative equal to zero (and rescaling) looks like
//
// x^3 + (0.5 - y0) * x - 0.5 * x0 = 0
//
// Adapted from https://www.shadertoy.com/view/ws3GD7
x0 = abs(x0);
float p = (0.5 - y0) / 3.0; // p / 3 in usual Cardano's formula notation
float q = 0.25 * x0; // -q / 2 in usual Cardano's formula notation
float disc = q*q + p*p*p;
float r = sqrt(abs(disc));
float x = (disc > 0.0) ?
// 1 root
pow(q + r, 1.0 / 3.0) + pow(abs(q - r), 1.0 / 3.0) * sign(-p) :
// 3 roots
2.0 * cos(atan(r, q) / 3.0) * sqrt(-p);
return length(vec2(x0 - x, y0 - x * x));
}
void main() { void main() {
if (uv_stroke_width == 0) discard; if(scaled_anti_alias_width < 0) discard;
frag_color = color; frag_color = color;
// sdf for the region around the curve we wish to color. // sdf for the region around the curve we wish to color.
float signed_dist = dist_to_curve() - 0.5 * uv_stroke_width; float signed_dist_to_region = abs(scaled_signed_dist_to_curve) - 1.0;
frag_color.a *= smoothstep(
frag_color.a *= smoothstep(0.5, -0.5, signed_dist / uv_anti_alias_width); 0, -scaled_anti_alias_width,
signed_dist_to_region
);
} }

View file

@ -1,12 +1,13 @@
#version 330 #version 330
layout (triangles) in; layout (triangles) in;
layout (triangle_strip, max_vertices = 6) out; layout (triangle_strip, max_vertices = 64) out; // Related to MAX_STEPS below
uniform float anti_alias_width; uniform float anti_alias_width;
uniform float flat_stroke; uniform float flat_stroke;
uniform float pixel_size; uniform float pixel_size;
uniform float joint_type; uniform float joint_type;
uniform float frame_scale;
in vec3 verts[3]; in vec3 verts[3];
@ -15,12 +16,8 @@ in float v_stroke_width[3];
in vec4 v_color[3]; in vec4 v_color[3];
out vec4 color; out vec4 color;
out float uv_stroke_width; out float scaled_anti_alias_width;
out float uv_anti_alias_width; out float scaled_signed_dist_to_curve;
out float is_linear;
out vec2 uv_coords;
// Codes for joint types // Codes for joint types
const int NO_JOINT = 0; const int NO_JOINT = 0;
@ -32,11 +29,13 @@ const int MITER_JOINT = 3;
// two vectors is larger than this, we // two vectors is larger than this, we
// consider them aligned // consider them aligned
const float COS_THRESHOLD = 0.99; const float COS_THRESHOLD = 0.99;
// Used to determine how many lines to break the curve into
const float POLYLINE_FACTOR = 30;
const int MAX_STEPS = 32;
vec3 unit_normal = vec3(0.0, 0.0, 1.0); vec3 unit_normal = vec3(0.0, 0.0, 1.0);
#INSERT emit_gl_Position.glsl #INSERT emit_gl_Position.glsl
#INSERT get_xyz_to_uv.glsl
#INSERT finalize_color.glsl #INSERT finalize_color.glsl
@ -54,6 +53,30 @@ vec4 normalized_joint_product(vec4 joint_product){
} }
vec3 point_on_curve(float t){
return verts[0] + 2 * (verts[1] - verts[0]) * t + (verts[0] - 2 * verts[1] + verts[2]) * t * t;
}
vec3 tangent_on_curve(float t){
return 2 * (verts[1] - verts[0]) + 2 * (verts[0] - 2 * verts[1] + verts[2]) * t;
}
void compute_subdivision(out int n_steps, out float subdivision[MAX_STEPS]){
// Crude estimate for the number of polyline segments to use, based
// on the area spanned by the control points
float area = 0.5 * length(v_joint_product[1].xzy);
int count = 2 + int(round(POLYLINE_FACTOR * sqrt(area) / frame_scale));
n_steps = min(count, MAX_STEPS);
for(int i = 0; i < MAX_STEPS; i++){
if (i >= n_steps) break;
subdivision[i] = float(i) / (n_steps - 1);
}
}
void create_joint( void create_joint(
vec4 joint_product, vec4 joint_product,
vec3 unit_tan, vec3 unit_tan,
@ -83,73 +106,58 @@ void create_joint(
changing_c1 = static_c1 + shift * unit_tan; changing_c1 = static_c1 + shift * unit_tan;
} }
vec3 get_perp(int index, vec4 joint_product, vec3 point, vec3 tangent, float aaw){
vec3 left_step(vec3 point, vec3 tangent, vec4 joint_product){
/* /*
Perpendicular vectors to the left of the curve Perpendicular vectors to the left of the curve
*/ */
float buff = 0.5 * v_stroke_width[index] + aaw;
// Add correction for sharp angles to prevent weird bevel effects // Add correction for sharp angles to prevent weird bevel effects
if(joint_product.w < -0.75) buff *= 4 * (joint_product.w + 1.0); float mult = 1.0;
if(joint_product.w < -0.75) mult *= 4 * (joint_product.w + 1.0);
vec3 normal = get_joint_unit_normal(joint_product); vec3 normal = get_joint_unit_normal(joint_product);
// Set global unit normal // Set global unit normal
unit_normal = normal; unit_normal = normal;
// Choose the "outward" normal direction // Choose the "outward" normal direction
if(normal.z < 0) normal *= -1; if(normal.z < 0) normal *= -1;
if(bool(flat_stroke)){ if(bool(flat_stroke)){
return buff * normalize(cross(normal, tangent)); return mult * normalize(cross(normal, tangent));
}else{ }else{
return buff * normalize(cross(camera_position - point, tangent)); return mult * normalize(cross(camera_position - point, tangent));
} }
} }
// This function is responsible for finding the corners of
// a bounding region around the bezier curve, which can be void emit_point_with_width(
// emitted as a triangle fan, with vertices vaguely close vec3 point,
// to control points so that the passage of vert data to vec3 tangent,
// frag shaders is most natural. vec4 joint_product,
void get_corners( float width,
// Control points for a bezier curve vec4 joint_color
vec3 p0,
vec3 p1,
vec3 p2,
// Unit tangent vectors at p0 and p2
vec3 v01,
vec3 v12,
// Anti-alias width
float aaw,
out vec3 corners[6]
){ ){
bool linear = bool(is_linear); vec3 unit_tan = normalize(tangent);
vec4 jp0 = normalized_joint_product(v_joint_product[0]); vec4 normed_join_product = normalized_joint_product(joint_product);
vec4 jp2 = normalized_joint_product(v_joint_product[2]); vec3 perp = 0.5 * width * left_step(point, unit_tan, normed_join_product);
vec3 p0_perp = get_perp(0, jp0, p0, v01, aaw);
vec3 p2_perp = get_perp(2, jp2, p2, v12, aaw);
vec3 p1_perp = 0.5 * (p0_perp + p2_perp);
if(linear){
p1_perp *= (0.5 * v_stroke_width[1] + aaw) / length(p1_perp);
}
// The order of corners should be for a triangle_strip. vec3 corners[2] = vec3[2](point + perp, point - perp);
vec3 c0 = p0 + p0_perp; create_joint(
vec3 c1 = p0 - p0_perp; normed_join_product, unit_tan, length(perp),
vec3 c2 = p1 + p1_perp; corners[0], corners[0],
vec3 c3 = p1 - p1_perp; corners[1], corners[1]
vec3 c4 = p2 + p2_perp; );
vec3 c5 = p2 - p2_perp;
// Move the inner middle control point to make
// room for the curve
// float orientation = dot(unit_normal, v_joint_product[1].xyz);
float orientation = v_joint_product[1].z;
if(!linear && orientation >= 0.0) c2 = 0.5 * (c0 + c4);
else if(!linear && orientation < 0.0) c3 = 0.5 * (c1 + c5);
// Account for previous and next control points color = finalize_color(joint_color, point, unit_normal);
if(bool(flat_stroke)){ if (width == 0) scaled_anti_alias_width = -1.0; // Signal to discard in frag
create_joint(jp0, v01, length(p0_perp), c1, c1, c0, c0); else scaled_anti_alias_width = 2.0 * anti_alias_width * pixel_size / width;
create_joint(jp2, -v12, length(p2_perp), c5, c5, c4, c4);
}
corners = vec3[6](c0, c1, c2, c3, c4, c5); // Emit two corners
// The frag shader will receive a value from -1 to 1,
// reflecting where in the stroke that point is
scaled_signed_dist_to_curve = -1.0;
emit_gl_Position(corners[0]);
EmitVertex();
scaled_signed_dist_to_curve = +1.0;
emit_gl_Position(corners[1]);
EmitVertex();
} }
void main() { void main() {
@ -157,53 +165,40 @@ void main() {
// the first anchor is set equal to that anchor // the first anchor is set equal to that anchor
if (verts[0] == verts[1]) return; if (verts[0] == verts[1]) return;
vec3 p0 = verts[0]; // Compute subdivision
vec3 p1 = verts[1]; int n_steps;
vec3 p2 = verts[2]; float subdivision[MAX_STEPS];
vec3 v01 = normalize(p1 - p0); compute_subdivision(n_steps, subdivision);
vec3 v12 = normalize(p2 - p1); vec3 points[MAX_STEPS];
for (int i = 0; i < MAX_STEPS; i++){
if (i >= n_steps) break;
vec4 jp1 = normalized_joint_product(v_joint_product[1]); points[i] = point_on_curve(subdivision[i]);
is_linear = float(jp1.w > COS_THRESHOLD);
// We want to change the coordinates to a space where the curve
// coincides with y = x^2, between some values x0 and x2. Or, in
// the case of a linear curve just put it on the x-axis
mat4 xyz_to_uv;
float uv_scale_factor;
if(!bool(is_linear)){
bool too_steep;
xyz_to_uv = get_xyz_to_uv(p0, p1, p2, 2.0, too_steep);
is_linear = float(too_steep);
uv_scale_factor = length(xyz_to_uv[0].xyz);
} }
float scaled_aaw = anti_alias_width * pixel_size; // Compute joint products
vec3 corners[6]; vec4 joint_products[MAX_STEPS];
get_corners(p0, p1, p2, v01, v12, scaled_aaw, corners); joint_products[0] = v_joint_product[0];
joint_products[0].xyz *= -1;
joint_products[n_steps - 1] = v_joint_product[2];
for (int i = 1; i < MAX_STEPS; i++){
if (i >= n_steps - 1) break;
vec3 v1 = points[i] - points[i - 1];
vec3 v2 = points[i + 1] - points[i];
joint_products[i].xyz = cross(v1, v2);
joint_products[i].w = dot(v1, v2);
}
// Emit each corner // Emit vertex pairs aroudn subdivided points
float max_sw = max(v_stroke_width[0], v_stroke_width[2]); for (int i = 0; i < MAX_STEPS; i++){
for(int i = 0; i < 6; i++){ if (i >= n_steps) break;
float stroke_width = v_stroke_width[i / 2]; float t = subdivision[i];
emit_point_with_width(
if(bool(is_linear)){ points[i],
float sign = vec2(-1, 1)[i % 2]; tangent_on_curve(t),
// In this case, we only really care about joint_products[i],
// the v coordinate mix(v_stroke_width[0], v_stroke_width[2], t),
uv_coords = vec2(0, sign * (0.5 * stroke_width + scaled_aaw)); mix(v_color[0], v_color[2], t)
uv_anti_alias_width = scaled_aaw; );
uv_stroke_width = stroke_width;
}else{
uv_coords = (xyz_to_uv * vec4(corners[i], 1.0)).xy;
uv_stroke_width = uv_scale_factor * stroke_width;
uv_anti_alias_width = uv_scale_factor * scaled_aaw;
}
color = finalize_color(v_color[i / 2], corners[i], unit_normal);
emit_gl_Position(corners[i]);
EmitVertex();
} }
EndPrimitive(); EndPrimitive();
} }

View file

@ -20,8 +20,7 @@ const float STROKE_WIDTH_CONVERSION = 0.01;
void main(){ void main(){
verts = point; verts = point;
v_stroke_width = STROKE_WIDTH_CONVERSION * stroke_width; v_stroke_width = STROKE_WIDTH_CONVERSION * stroke_width * mix(frame_scale, 1, is_fixed_in_frame);
v_stroke_width *= mix(frame_scale, 1, is_fixed_in_frame);
v_joint_product = joint_product; v_joint_product = joint_product;
v_color = stroke_rgba; v_color = stroke_rgba;
} }

View file

@ -4,6 +4,7 @@ from colour import Color
from colour import hex2rgb from colour import hex2rgb
from colour import rgb2hex from colour import rgb2hex
import numpy as np import numpy as np
import random
from manimlib.constants import COLORMAP_3B1B from manimlib.constants import COLORMAP_3B1B
from manimlib.constants import WHITE from manimlib.constants import WHITE
@ -102,6 +103,16 @@ def interpolate_color(
return rgb_to_color(rgb) return rgb_to_color(rgb)
def interpolate_color_by_hsl(
color1: ManimColor,
color2: ManimColor,
alpha: float
) -> Color:
hsl1 = np.array(Color(color1).get_hsl())
hsl2 = np.array(Color(color2).get_hsl())
return Color(hsl=interpolate(hsl1, hsl2, alpha))
def average_color(*colors: ManimColor) -> Color: def average_color(*colors: ManimColor) -> Color:
rgbs = np.array(list(map(color_to_rgb, colors))) rgbs = np.array(list(map(color_to_rgb, colors)))
return rgb_to_color(np.sqrt((rgbs**2).mean(0))) return rgb_to_color(np.sqrt((rgbs**2).mean(0)))
@ -111,9 +122,16 @@ def random_color() -> Color:
return Color(rgb=tuple(np.random.random(3))) return Color(rgb=tuple(np.random.random(3)))
def random_bright_color() -> Color: def random_bright_color(
color = random_color() hue_range: tuple[float, float] = (0.0, 1.0),
return average_color(color, Color(WHITE)) saturation_range: tuple[float, float] = (0.5, 0.8),
luminance_range: tuple[float, float] = (0.5, 1.0),
) -> Color:
return Color(hsl=(
interpolate(*hue_range, random.random()),
interpolate(*saturation_range, random.random()),
interpolate(*luminance_range, random.random()),
))
def get_colormap_list( def get_colormap_list(

View file

@ -3,6 +3,7 @@ from __future__ import annotations
from colour import Color from colour import Color
import numpy as np import numpy as np
import random
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -83,6 +84,12 @@ def listify(obj: object) -> list:
return [obj] return [obj]
def shuffled(iterable: Iterable) -> list:
as_list = list(iterable)
random.shuffle(as_list)
return as_list
def resize_array(nparray: np.ndarray, length: int) -> np.ndarray: def resize_array(nparray: np.ndarray, length: int) -> np.ndarray:
if len(nparray) == length: if len(nparray) == length:
return nparray return nparray

View file

@ -9,7 +9,7 @@ import numpy as np
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, TypeVar from typing import Callable, TypeVar, Iterable
from manimlib.typing import FloatArray from manimlib.typing import FloatArray
Scalable = TypeVar("Scalable", float, FloatArray) Scalable = TypeVar("Scalable", float, FloatArray)
@ -30,11 +30,11 @@ def gen_choose(n: int, r: int) -> int:
def get_num_args(function: Callable) -> int: def get_num_args(function: Callable) -> int:
return len(get_parameters(function)) return function.__code__.co_argcount
def get_parameters(function: Callable) -> list: def get_parameters(function: Callable) -> Iterable[str]:
return list(inspect.signature(function).parameters.keys()) return inspect.signature(function).parameters.keys()
# Just to have a less heavyweight name for this extremely common operation # Just to have a less heavyweight name for this extremely common operation
# #