Merge branch 'master' of github.com:3b1b/manim

This commit is contained in:
Grant Sanderson 2024-08-01 07:55:11 -05:00
commit d644e3b184
31 changed files with 929 additions and 647 deletions

View file

@ -26,7 +26,7 @@ class OpeningManimExample(Scene):
matrix = [[1, 1], [0, 1]]
linear_transform_words = VGroup(
Text("This is what the matrix"),
IntegerMatrix(matrix, include_background_rectangle=True),
IntegerMatrix(matrix),
Text("looks like")
)
linear_transform_words.arrange(RIGHT)
@ -251,7 +251,7 @@ class TexIndexing(Scene):
self.play(FlashAround(part))
self.wait()
self.play(FadeOut(equation))
# Indexing by substrings like this may not work when
# the order in which Latex draws symbols does not match
# the order in which they show up in the string.
@ -289,11 +289,11 @@ class UpdatersExample(Scene):
brace = always_redraw(Brace, square, UP)
label = TexText("Width = 0.00")
number = label.make_number_changable("0.00")
number = label.make_number_changeable("0.00")
# This ensures that the method deicmal.next_to(square)
# is called on every frame
always(label.next_to, brace, UP)
label.always.next_to(brace, UP)
# You could also write the following equivalent line
# label.add_updater(lambda m: m.next_to(brace, UP))
@ -302,7 +302,7 @@ class UpdatersExample(Scene):
# should be functions returning arguments to that method.
# The following line ensures thst decimal.set_value(square.get_y())
# is called every frame
f_always(number.set_value, square.get_width)
number.f_always.set_value(square.get_width)
# You could also write the following equivalent line
# number.add_updater(lambda m: m.set_value(square.get_width()))
@ -359,7 +359,7 @@ class CoordinateSystemExample(Scene):
# Alternatively, you can specify configuration for just one
# of them, like this.
y_axis_config=dict(
numbers_with_elongated_ticks=[-2, 2],
big_tick_numbers=[-2, 2],
)
)
# Keyword arguments of add_coordinate_labels can be used to
@ -515,7 +515,7 @@ class TexAndNumbersExample(Scene):
# on them.
tex = Tex("x^2 + y^2 = 4.00")
tex.next_to(axes, UP, buff=0.5)
value = tex.make_number_changable("4.00")
value = tex.make_number_changeable("4.00")
# This will tie the right hand side of our equation to
@ -537,10 +537,10 @@ class TexAndNumbersExample(Scene):
rate_func=there_and_back,
)
# By default, tex.make_number_changable replaces the first occurance
# By default, tex.make_number_changeable replaces the first occurance
# of the number,but by passing replace_all=True it replaces all and
# returns a group of the results
exponents = tex.make_number_changable("2", replace_all=True)
exponents = tex.make_number_changeable("2", replace_all=True)
self.play(
LaggedStartMap(
FlashAround, exponents,

View file

@ -43,7 +43,6 @@ from manimlib.mobject.probability import *
from manimlib.mobject.shape_matchers import *
from manimlib.mobject.svg.brace import *
from manimlib.mobject.svg.drawings import *
from manimlib.mobject.svg.tex_mobject import *
from manimlib.mobject.svg.string_mobject import *
from manimlib.mobject.svg.svg_mobject import *
from manimlib.mobject.svg.special_tex import *

View file

@ -13,7 +13,8 @@ from manimlib.utils.bezier import interpolate
from manimlib.utils.iterables import remove_list_redundancies
from manimlib.utils.simple_functions import clip
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union, Iterable
AnimationType = Union[Animation, _AnimationBuilder]
if TYPE_CHECKING:
from typing import Callable, Optional
@ -26,14 +27,16 @@ DEFAULT_LAGGED_START_LAG_RATIO = 0.05
class AnimationGroup(Animation):
def __init__(self,
*animations: Animation | _AnimationBuilder,
def __init__(
self,
*args: AnimationType | Iterable[AnimationType],
run_time: float = -1, # If negative, default to sum of inputed animation runtimes
lag_ratio: float = 0.0,
group: Optional[Mobject] = None,
group_type: Optional[type] = None,
**kwargs
):
animations = args[0] if isinstance(args[0], Iterable) else args
self.animations = [prepare_animation(anim) for anim in animations]
self.build_animations_with_timings(lag_ratio)
self.max_end_time = max((awt[2] for awt in self.anims_with_timings), default=0)
@ -177,4 +180,5 @@ class LaggedStartMap(LaggedStart):
*(anim_func(submob, **anim_kwargs) for submob in group),
run_time=run_time,
lag_ratio=lag_ratio,
group=group
)

View file

@ -118,6 +118,7 @@ class FadeTransform(Transform):
def ghost_to(self, source: Mobject, target: Mobject) -> None:
source.replace(target, stretch=self.stretch, dim_to_match=self.dim_to_match)
source.set_uniform(**target.get_uniforms())
source.set_opacity(0)
def get_all_mobjects(self) -> list[Mobject]:
@ -134,7 +135,8 @@ class FadeTransform(Transform):
Animation.clean_up_from_scene(self, scene)
scene.remove(self.mobject)
self.mobject[0].restore()
scene.add(self.to_add_on_completion)
if not self.remover:
scene.add(self.to_add_on_completion)
class FadeTransformPieces(FadeTransform):

View file

@ -3,6 +3,7 @@ from __future__ import annotations
from manimlib.animation.animation import Animation
from manimlib.mobject.numbers import DecimalNumber
from manimlib.utils.bezier import interpolate
from manimlib.utils.simple_functions import clip
from typing import TYPE_CHECKING
@ -55,9 +56,9 @@ class CountInFrom(ChangingDecimal):
source_number: float | complex = 0,
**kwargs
):
start_number = decimal_mob.number
start_number = decimal_mob.get_value()
super().__init__(
decimal_mob,
lambda a: interpolate(source_number, start_number, a),
lambda a: interpolate(source_number, start_number, clip(a, 0, 1)),
**kwargs
)

View file

@ -65,7 +65,7 @@ class Transform(Animation):
self.target_copy = self.target_mobject.copy()
self.mobject.align_data_and_family(self.target_copy)
super().begin()
if not self.mobject.has_updaters:
if not self.mobject.has_updaters():
self.mobject.lock_matching_data(
self.starting_mobject,
self.target_copy,

View file

@ -25,6 +25,7 @@ class CameraFrame(Mobject):
center_point: Vect3 = ORIGIN,
# Field of view in the y direction
fovy: float = 45 * DEGREES,
euler_axes: str = "zxz",
**kwargs,
):
super().__init__(**kwargs)
@ -35,6 +36,7 @@ class CameraFrame(Mobject):
self.default_orientation = Rotation.identity()
self.view_matrix = np.identity(4)
self.camera_location = OUT # This will be updated by set_points
self.euler_axes = euler_axes
self.set_points(np.array([ORIGIN, LEFT, RIGHT, DOWN, UP]))
self.set_width(frame_shape[0], stretch=True)
@ -62,7 +64,7 @@ class CameraFrame(Mobject):
orientation = self.get_orientation()
if all(orientation.as_quat() == [0, 0, 0, 1]):
return np.zeros(3)
return orientation.as_euler("zxz")[::-1]
return orientation.as_euler(self.euler_axes)[::-1]
def get_theta(self):
return self.get_euler_angles()[0]
@ -126,21 +128,44 @@ class CameraFrame(Mobject):
if all(eulers == 0):
rot = Rotation.identity()
else:
rot = Rotation.from_euler("zxz", eulers[::-1])
rot = Rotation.from_euler(self.euler_axes, eulers[::-1])
self.set_orientation(rot)
return self
def increment_euler_angles(
self,
dtheta: float | None = None,
dphi: float | None = None,
dgamma: float | None = None,
units: float = RADIANS
):
angles = self.get_euler_angles()
for i, value in enumerate([dtheta, dphi, dgamma]):
if value is not None:
angles[i] += value * units
self.set_euler_angles(*angles)
return self
def set_euler_axes(self, seq: str):
self.euler_axes = seq
def reorient(
self,
theta_degrees: float | None = None,
phi_degrees: float | None = None,
gamma_degrees: float | None = None,
center: Vect3 | tuple[float, float, float] | None = None,
height: float | None = None
):
"""
Shortcut for set_euler_angles, defaulting to taking
in angles in degrees
"""
self.set_euler_angles(theta_degrees, phi_degrees, gamma_degrees, units=DEGREES)
if center is not None:
self.move_to(np.array(center))
if height is not None:
self.set_height(height)
return self
def set_theta(self, theta: float):
@ -152,16 +177,20 @@ class CameraFrame(Mobject):
def set_gamma(self, gamma: float):
return self.set_euler_angles(gamma=gamma)
def increment_theta(self, dtheta: float):
self.rotate(dtheta, OUT)
def increment_theta(self, dtheta: float, units=RADIANS):
self.increment_euler_angles(dtheta=dtheta, units=units)
return self
def increment_phi(self, dphi: float):
self.rotate(dphi, self.get_inverse_camera_rotation_matrix()[0])
def increment_phi(self, dphi: float, units=RADIANS):
self.increment_euler_angles(dphi=dphi, units=units)
return self
def increment_gamma(self, dgamma: float):
self.rotate(dgamma, self.get_inverse_camera_rotation_matrix()[2])
def increment_gamma(self, dgamma: float, units=RADIANS):
self.increment_euler_angles(dgamma=dgamma, units=units)
return self
def add_ambient_rotation(self, angular_speed=1 * DEGREES):
self.add_updater(lambda m, dt: m.increment_theta(angular_speed * dt))
return self
@Mobject.affects_data

View file

@ -6,7 +6,7 @@ import colour
import importlib
import inspect
import os
from screeninfo import get_monitors
import screeninfo
import sys
import yaml
@ -433,7 +433,10 @@ def get_file_writer_config(args: Namespace, custom_config: dict) -> dict:
def get_window_config(args: Namespace, custom_config: dict, camera_config: dict) -> dict:
# Default to making window half the screen size
# but make it full screen if -f is passed in
monitors = get_monitors()
try:
monitors = screeninfo.get_monitors()
except screeninfo.ScreenInfoError:
pass
mon_index = custom_config["window_monitor"]
monitor = monitors[min(mon_index, len(monitors) - 1)]
aspect_ratio = camera_config["pixel_width"] / camera_config["pixel_height"]

View file

@ -167,3 +167,4 @@ class TracingTail(TracedPath):
stroke_color=stroke_color,
**kwargs
)
self.add_updater(lambda m: m.set_stroke(width=stroke_width, opacity=stroke_opacity))

View file

@ -603,9 +603,6 @@ class ThreeDAxes(Axes):
**kwargs
) -> ParametricSurface:
surface = ParametricSurface(func, color=color, opacity=opacity, **kwargs)
xu = self.x_axis.get_unit_size()
yu = self.y_axis.get_unit_size()
zu = self.z_axis.get_unit_size()
axes = [self.x_axis, self.y_axis, self.z_axis]
for dim, axis in zip(range(3), axes):
surface.stretch(axis.get_unit_size(), dim, about_point=ORIGIN)

View file

@ -4,85 +4,38 @@ import itertools as it
import numpy as np
from manimlib.constants import DEFAULT_MOBJECT_TO_MOBJECT_BUFFER
from manimlib.constants import DOWN, LEFT, RIGHT, UP
from manimlib.constants import WHITE
from manimlib.constants import DOWN, LEFT, RIGHT, ORIGIN
from manimlib.constants import DEGREES
from manimlib.mobject.numbers import DecimalNumber
from manimlib.mobject.numbers import Integer
from manimlib.mobject.shape_matchers import BackgroundRectangle
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.svg.tex_mobject import TexText
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Sequence
import numpy.typing as npt
from manimlib.mobject.mobject import Mobject
from manimlib.typing import ManimColor, Vect3, Self
from typing import Sequence, Union, Tuple, Optional
from manimlib.typing import ManimColor, Vect3, VectNArray, Self
VECTOR_LABEL_SCALE_FACTOR = 0.8
def matrix_to_tex_string(matrix: npt.ArrayLike) -> str:
matrix = np.array(matrix).astype("str")
if matrix.ndim == 1:
matrix = matrix.reshape((matrix.size, 1))
n_rows, n_cols = matrix.shape
prefix = R"\left[ \begin{array}{%s}" % ("c" * n_cols)
suffix = R"\end{array} \right]"
rows = [
" & ".join(row)
for row in matrix
]
return prefix + R" \\ ".join(rows) + suffix
def matrix_to_mobject(matrix: npt.ArrayLike) -> Tex:
return Tex(matrix_to_tex_string(matrix))
def vector_coordinate_label(
vector_mob: VMobject,
integer_labels: bool = True,
n_dim: int = 2,
color: ManimColor = WHITE
) -> Matrix:
vect = np.array(vector_mob.get_end())
if integer_labels:
vect = np.round(vect).astype(int)
vect = vect[:n_dim]
vect = vect.reshape((n_dim, 1))
label = Matrix(vect, add_background_rectangles_to_entries=True)
label.scale(VECTOR_LABEL_SCALE_FACTOR)
shift_dir = np.array(vector_mob.get_end())
if shift_dir[0] >= 0: # Pointing right
shift_dir -= label.get_left() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER * LEFT
else: # Pointing left
shift_dir -= label.get_right() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER * RIGHT
label.shift(shift_dir)
label.set_color(color)
label.rect = BackgroundRectangle(label)
label.add_to_back(label.rect)
return label
StringMatrixType = Union[Sequence[Sequence[str]], np.ndarray[int, np.dtype[np.str_]]]
FloatMatrixType = Union[Sequence[Sequence[float]], VectNArray]
VMobjectMatrixType = Sequence[Sequence[VMobject]]
GenericMatrixType = Union[FloatMatrixType, StringMatrixType, VMobjectMatrixType]
class Matrix(VMobject):
def __init__(
self,
matrix: Sequence[Sequence[str | float | VMobject]],
v_buff: float = 0.8,
h_buff: float = 1.0,
matrix: GenericMatrixType,
v_buff: float = 0.5,
h_buff: float = 0.5,
bracket_h_buff: float = 0.2,
bracket_v_buff: float = 0.25,
add_background_rectangles_to_entries: bool = False,
include_background_rectangle: bool = False,
height: float | None = None,
element_config: dict = dict(),
element_alignment_corner: Vect3 = DOWN,
**kwargs
ellipses_row: Optional[int] = None,
ellipses_col: Optional[int] = None,
):
"""
Matrix can either include numbers, tex_strings,
@ -90,83 +43,109 @@ class Matrix(VMobject):
"""
super().__init__()
mob_matrix = self.matrix_to_mob_matrix(matrix, **kwargs)
self.mob_matrix = mob_matrix
self.mob_matrix = self.create_mobject_matrix(
matrix, v_buff, h_buff, element_alignment_corner,
**element_config
)
self.organize_mob_matrix(mob_matrix, v_buff, h_buff, element_alignment_corner)
self.elements = VGroup(*it.chain(*mob_matrix))
self.add(self.elements)
self.add_brackets(bracket_v_buff, bracket_h_buff)
# Create helpful groups for the elements
n_cols = len(self.mob_matrix[0])
self.elements = [elem for row in self.mob_matrix for elem in row]
self.columns = VGroup(*(
VGroup(*(row[i] for row in self.mob_matrix))
for i in range(n_cols)
))
self.rows = VGroup(*(VGroup(*row) for row in self.mob_matrix))
if height is not None:
self.rows.set_height(height - 2 * bracket_v_buff)
self.brackets = self.create_brackets(self.rows, bracket_v_buff, bracket_h_buff)
self.ellipses = []
# Add elements and brackets
self.add(*self.elements)
self.add(*self.brackets)
self.center()
if add_background_rectangles_to_entries:
for mob in self.elements:
mob.add_background_rectangle()
if include_background_rectangle:
self.add_background_rectangle()
# Potentially add ellipses
self.swap_entries_for_ellipses(
ellipses_row,
ellipses_col,
)
def element_to_mobject(self, element: str | float | VMobject, **config) -> VMobject:
if isinstance(element, VMobject):
return element
return Tex(str(element), **config)
def copy(self, deep: bool = False):
result = super().copy(deep)
self_family = self.get_family()
copy_family = result.get_family()
for attr in ["elements", "ellipses"]:
setattr(result, attr, [
copy_family[self_family.index(mob)]
for mob in getattr(self, attr)
])
return result
def matrix_to_mob_matrix(
def create_mobject_matrix(
self,
matrix: Sequence[Sequence[str | float | VMobject]],
**config
) -> list[list[VMobject]]:
return [
[
self.element_to_mobject(item, **config)
for item in row
]
for row in matrix
]
def organize_mob_matrix(
self,
matrix: list[list[VMobject]],
matrix: GenericMatrixType,
v_buff: float,
h_buff: float,
aligned_corner: Vect3,
) -> Self:
for i, row in enumerate(matrix):
**element_config
) -> VMobjectMatrixType:
"""
Creates and organizes the matrix of mobjects
"""
mob_matrix = [
[
self.element_to_mobject(element, **element_config)
for element in row
]
for row in matrix
]
max_width = max(elem.get_width() for row in mob_matrix for elem in row)
max_height = max(elem.get_height() for row in mob_matrix for elem in row)
x_step = (max_width + h_buff) * RIGHT
y_step = (max_height + v_buff) * DOWN
for i, row in enumerate(mob_matrix):
for j, elem in enumerate(row):
mob = matrix[i][j]
mob.move_to(
i * v_buff * DOWN + j * h_buff * RIGHT,
aligned_corner
)
return self
elem.move_to(i * y_step + j * x_step, aligned_corner)
return mob_matrix
def add_brackets(self, v_buff: float, h_buff: float) -> Self:
height = len(self.mob_matrix)
def element_to_mobject(self, element, **config) -> VMobject:
if isinstance(element, VMobject):
return element
elif isinstance(element, float | complex):
return DecimalNumber(element, **config)
else:
return Tex(str(element), **config)
def create_brackets(self, rows, v_buff: float, h_buff: float) -> VGroup:
brackets = Tex("".join((
R"\left[\begin{array}{c}",
*height * [R"\quad \\"],
*len(rows) * [R"\quad \\"],
R"\end{array}\right]",
)))
brackets.set_height(self.get_height() + v_buff)
brackets.set_height(rows.get_height() + v_buff)
l_bracket = brackets[:len(brackets) // 2]
r_bracket = brackets[len(brackets) // 2:]
l_bracket.next_to(self, LEFT, h_buff)
r_bracket.next_to(self, RIGHT, h_buff)
brackets.set_submobjects([l_bracket, r_bracket])
self.brackets = brackets
self.add(*brackets)
return self
l_bracket.next_to(rows, LEFT, h_buff)
r_bracket.next_to(rows, RIGHT, h_buff)
return VGroup(l_bracket, r_bracket)
def get_column(self, index: int):
if not 0 <= index < len(self.columns):
raise IndexError(f"Index {index} out of bound for matrix with {len(self.columns)} columns")
return self.columns[index]
def get_row(self, index: int):
if not 0 <= index < len(self.rows):
raise IndexError(f"Index {index} out of bound for matrix with {len(self.rows)} rows")
return self.rows[index]
def get_columns(self) -> VGroup:
return VGroup(*[
VGroup(*[row[i] for row in self.mob_matrix])
for i in range(len(self.mob_matrix[0]))
])
return self.columns
def get_rows(self) -> VGroup:
return VGroup(*[
VGroup(*row)
for row in self.mob_matrix
])
return self.rows
def set_column_colors(self, *colors: ManimColor) -> Self:
columns = self.get_columns()
@ -179,61 +158,138 @@ class Matrix(VMobject):
mob.add_background_rectangle()
return self
def get_mob_matrix(self) -> list[list[Mobject]]:
def swap_entry_for_dots(self, entry, dots):
dots.move_to(entry)
entry.become(dots)
if entry in self.elements:
self.elements.remove(entry)
if entry not in self.ellipses:
self.ellipses.append(entry)
def swap_entries_for_ellipses(
self,
row_index: Optional[int] = None,
col_index: Optional[int] = None,
height_ratio: float = 0.65,
width_ratio: float = 0.4
):
rows = self.get_rows()
cols = self.get_columns()
avg_row_height = rows.get_height() / len(rows)
vdots_height = height_ratio * avg_row_height
avg_col_width = cols.get_width() / len(cols)
hdots_width = width_ratio * avg_col_width
use_vdots = row_index is not None and -len(rows) <= row_index < len(rows)
use_hdots = col_index is not None and -len(cols) <= col_index < len(cols)
if use_vdots:
for column in cols:
# Add vdots
dots = Tex(R"\vdots")
dots.set_height(vdots_height)
self.swap_entry_for_dots(column[row_index], dots)
if use_hdots:
for row in rows:
# Add hdots
dots = Tex(R"\hdots")
dots.set_width(hdots_width)
self.swap_entry_for_dots(row[col_index], dots)
if use_vdots and use_hdots:
rows[row_index][col_index].rotate(-45 * DEGREES)
return self
def get_mob_matrix(self) -> VMobjectMatrixType:
return self.mob_matrix
def get_entries(self) -> VGroup:
return self.elements
return VGroup(*self.elements)
def get_brackets(self) -> VGroup:
return self.brackets
return VGroup(*self.brackets)
def get_ellipses(self) -> VGroup:
return VGroup(*self.ellipses)
class DecimalMatrix(Matrix):
def element_to_mobject(self, element: float, num_decimal_places: int = 1, **config) -> DecimalNumber:
return DecimalNumber(element, num_decimal_places=num_decimal_places, **config)
class IntegerMatrix(Matrix):
def __init__(
self,
matrix: npt.ArrayLike,
element_alignment_corner: Vect3 = UP,
**kwargs
matrix: FloatMatrixType,
num_decimal_places: int = 2,
decimal_config: dict = dict(),
**config
):
super().__init__(matrix, element_alignment_corner=element_alignment_corner, **kwargs)
self.float_matrix = matrix
super().__init__(
matrix,
element_config=dict(
num_decimal_places=num_decimal_places,
**decimal_config
),
**config
)
def element_to_mobject(self, element: int, **config) -> Integer:
return Integer(element, **config)
def element_to_mobject(self, element, **decimal_config) -> DecimalNumber:
return DecimalNumber(element, **decimal_config)
class IntegerMatrix(DecimalMatrix):
def __init__(
self,
matrix: FloatMatrixType,
num_decimal_places: int = 0,
decimal_config: dict = dict(),
**config
):
super().__init__(matrix, num_decimal_places, decimal_config, **config)
class TexMatrix(Matrix):
def __init__(
self,
matrix: StringMatrixType,
tex_config: dict = dict(),
**config,
):
super().__init__(
matrix,
element_config=tex_config,
**config
)
class MobjectMatrix(Matrix):
def __init__(
self,
group: VGroup,
n_rows: int | None = None,
n_cols: int | None = None,
height: float = 4.0,
element_alignment_corner=ORIGIN,
**config,
):
# Have fallback defaults of n_rows and n_cols
n_mobs = len(group)
if n_rows is None:
n_rows = int(np.sqrt(n_mobs)) if n_cols is None else n_mobs // n_cols
if n_cols is None:
n_cols = n_mobs // n_rows
if len(group) < n_rows * n_cols:
raise Exception("Input to MobjectMatrix must have at least n_rows * n_cols entries")
mob_matrix = [
[group[n * n_cols + k] for k in range(n_cols)]
for n in range(n_rows)
]
config.update(
height=height,
element_alignment_corner=element_alignment_corner,
)
super().__init__(mob_matrix, **config)
def element_to_mobject(self, element: VMobject, **config) -> VMobject:
return element
def get_det_text(
matrix: Matrix,
determinant: int | str | None = None,
background_rect: bool = False,
initial_scale_factor: int = 2
) -> VGroup:
parens = Tex("()")
parens.scale(initial_scale_factor)
parens.stretch_to_fit_height(matrix.get_height())
l_paren, r_paren = parens.split()
l_paren.next_to(matrix, LEFT, buff=0.1)
r_paren.next_to(matrix, RIGHT, buff=0.1)
det = TexText("det")
det.scale(initial_scale_factor)
det.next_to(l_paren, LEFT, buff=0.1)
if background_rect:
det.add_background_rectangle()
det_text = VGroup(det, l_paren, r_paren)
if determinant is not None:
eq = Tex("=")
eq.next_to(r_paren, RIGHT, buff=0.1)
result = Tex(str(determinant))
result.next_to(eq, RIGHT, buff=0.2)
det_text.add(eq, result)
return det_text

View file

@ -39,20 +39,23 @@ from manimlib.utils.iterables import resize_with_interpolation
from manimlib.utils.bezier import integer_interpolate
from manimlib.utils.bezier import interpolate
from manimlib.utils.paths import straight_path
from manimlib.utils.simple_functions import get_parameters
from manimlib.utils.shaders import get_colormap_code
from manimlib.utils.space_ops import angle_of_vector
from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import rotation_matrix_transpose
from typing import TYPE_CHECKING
from typing import TypeVar, Generic, Iterable
SubmobjectType = TypeVar('SubmobjectType', bound='Mobject')
if TYPE_CHECKING:
from typing import Callable, Iterable, Iterator, Union, Tuple, Optional
from typing import Callable, Iterator, Union, Tuple, Optional, Any
import numpy.typing as npt
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self
from moderngl.context import Context
T = TypeVar('T')
TimeBasedUpdater = Callable[["Mobject", float], "Mobject" | None]
NonTimeUpdater = Callable[["Mobject"], "Mobject" | None]
Updater = Union[TimeBasedUpdater, NonTimeUpdater]
@ -88,21 +91,20 @@ class Mobject(object):
self.opacity = opacity
self.shading = shading
self.texture_paths = texture_paths
self._is_fixed_in_frame = is_fixed_in_frame
self.depth_test = depth_test
# Internal state
self.submobjects: list[Mobject] = []
self.parents: list[Mobject] = []
self.family: list[Mobject] = [self]
self.family: list[Mobject] | None = [self]
self.locked_data_keys: set[str] = set()
self.const_data_keys: set[str] = set()
self.locked_uniform_keys: set[str] = set()
self.needs_new_bounding_box: bool = True
self._is_animating: bool = False
self.saved_state = None
self.target = None
self.bounding_box: Vect3Array = np.zeros((3, 3))
self._is_animating: bool = False
self._needs_new_bounding_box: bool = True
self._shaders_initialized: bool = False
self._data_has_changed: bool = True
self.shader_code_replacements: dict[str, str] = dict()
@ -117,6 +119,8 @@ class Mobject(object):
if self.depth_test:
self.apply_depth_test()
if is_fixed_in_frame:
self.fix_in_frame()
def __str__(self):
return self.__class__.__name__
@ -134,7 +138,7 @@ class Mobject(object):
def init_uniforms(self):
self.uniforms: UniformDict = {
"is_fixed_in_frame": float(self._is_fixed_in_frame),
"is_fixed_in_frame": 0.0,
"shading": np.array(self.shading, dtype=float),
}
@ -154,9 +158,47 @@ class Mobject(object):
@property
def animate(self) -> _AnimationBuilder:
# Borrowed from https://github.com/ManimCommunity/manim/
"""
Methods called with Mobject.animate.method() can be passed
into a Scene.play call, as if you were calling
ApplyMethod(mobject.method)
Borrowed from https://github.com/ManimCommunity/manim/
"""
return _AnimationBuilder(self)
@property
def always(self) -> _UpdaterBuilder:
"""
Methods called with mobject.always.method(*args, **kwargs)
will result in the call mobject.method(*args, **kwargs)
on every frame
"""
return _UpdaterBuilder(self)
@property
def f_always(self) -> _FunctionalUpdaterBuilder:
"""
Similar to Mobject.always, but with the intent that arguments
are functions returning the corresponding type fit for the method
Methods called with
mobject.f_always.method(
func1, func2, ...,
kwarg1=kw_func1,
kwarg2=kw_func2,
...
)
will result in the call
mobject.method(
func1(), func2(), ...,
kwarg1=kw_func1(),
kwarg2=kw_func2(),
...
)
on every frame
"""
return _FunctionalUpdaterBuilder(self)
def note_changed_data(self, recurse_up: bool = True) -> Self:
self._data_has_changed = True
if recurse_up:
@ -164,20 +206,23 @@ class Mobject(object):
mob.note_changed_data()
return self
def affects_data(func: Callable):
@staticmethod
def affects_data(func: Callable[..., T]) -> Callable[..., T]:
@wraps(func)
def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs)
result = func(self, *args, **kwargs)
self.note_changed_data()
return result
return wrapper
def affects_family_data(func: Callable):
@staticmethod
def affects_family_data(func: Callable[..., T]) -> Callable[..., T]:
@wraps(func)
def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs)
result = func(self, *args, **kwargs)
for mob in self.family_members_with_points():
mob.note_changed_data()
return self
return result
return wrapper
# Only these methods should directly affect points
@ -285,9 +330,9 @@ class Mobject(object):
return len(self.get_points()) > 0
def get_bounding_box(self) -> Vect3Array:
if self.needs_new_bounding_box:
if self._needs_new_bounding_box:
self.bounding_box[:] = self.compute_bounding_box()
self.needs_new_bounding_box = False
self._needs_new_bounding_box = False
return self.bounding_box
def compute_bounding_box(self) -> Vect3Array:
@ -314,7 +359,7 @@ class Mobject(object):
recurse_up: bool = True
) -> Self:
for mob in self.get_family(recurse_down):
mob.needs_new_bounding_box = True
mob._needs_new_bounding_box = True
if recurse_up:
for parent in self.parents:
parent.refresh_bounding_box()
@ -347,7 +392,7 @@ class Mobject(object):
# Family matters
def __getitem__(self, value: int | slice) -> Self:
def __getitem__(self, value: int | slice) -> Mobject:
if isinstance(value, slice):
GroupClass = self.get_group_class()
return GroupClass(*self.split().__getitem__(value))
@ -363,23 +408,26 @@ class Mobject(object):
return self.submobjects
@affects_data
def assemble_family(self) -> Self:
sub_families = (sm.get_family() for sm in self.submobjects)
self.family = [self, *it.chain(*sub_families)]
self.refresh_has_updater_status()
self.refresh_bounding_box()
def note_changed_family(self, only_changed_order=False) -> Self:
self.family = None
if not only_changed_order:
self.refresh_has_updater_status()
self.refresh_bounding_box()
for parent in self.parents:
parent.assemble_family()
parent.note_changed_family()
return self
def get_family(self, recurse: bool = True) -> list[Self]:
if recurse:
return self.family
else:
def get_family(self, recurse: bool = True) -> list[Mobject]:
if not recurse:
return [self]
if self.family is None:
# Reconstruct and save
sub_families = (sm.get_family() for sm in self.submobjects)
self.family = [self, *it.chain(*sub_families)]
return self.family
def family_members_with_points(self) -> list[Self]:
return [m for m in self.family if len(m.data) > 0]
def family_members_with_points(self) -> list[Mobject]:
return [m for m in self.get_family() if len(m.data) > 0]
def get_ancestors(self, extended: bool = False) -> list[Mobject]:
"""
@ -410,7 +458,7 @@ class Mobject(object):
self.submobjects.append(mobject)
if self not in mobject.parents:
mobject.parents.append(self)
self.assemble_family()
self.note_changed_family()
return self
def remove(
@ -426,7 +474,7 @@ class Mobject(object):
if parent in child.parents:
child.parents.remove(parent)
if reassemble:
parent.assemble_family()
parent.note_changed_family()
return self
def clear(self) -> Self:
@ -443,12 +491,12 @@ class Mobject(object):
old_submob.parents.remove(self)
self.submobjects[index] = new_submob
new_submob.parents.append(self)
self.assemble_family()
self.note_changed_family()
return self
def insert_submobject(self, index: int, new_submob: Mobject) -> Self:
self.submobjects.insert(index, new_submob)
self.assemble_family()
self.note_changed_family()
return self
def set_submobjects(self, submobject_list: list[Mobject]) -> Self:
@ -495,12 +543,11 @@ class Mobject(object):
fill_rows_first: bool = True
) -> Self:
submobs = self.submobjects
if n_rows is None and n_cols is None:
n_rows = int(np.sqrt(len(submobs)))
n_submobs = len(submobs)
if n_rows is None:
n_rows = len(submobs) // n_cols
n_rows = int(np.sqrt(n_submobs)) if n_cols is None else n_submobs // n_cols
if n_cols is None:
n_cols = len(submobs) // n_rows
n_cols = n_submobs // n_rows
if buff is not None:
h_buff = buff
@ -561,7 +608,7 @@ class Mobject(object):
self.submobjects.sort(key=submob_func)
else:
self.submobjects.sort(key=lambda m: point_to_num_func(m.get_center()))
self.assemble_family()
self.note_changed_family(only_changed_order=True)
return self
def shuffle(self, recurse: bool = False) -> Self:
@ -569,17 +616,18 @@ class Mobject(object):
for submob in self.submobjects:
submob.shuffle(recurse=True)
random.shuffle(self.submobjects)
self.assemble_family()
self.note_changed_family(only_changed_order=True)
return self
def reverse_submobjects(self) -> Self:
self.submobjects.reverse()
self.assemble_family()
self.note_changed_family(only_changed_order=True)
return self
# Copying and serialization
def stash_mobject_pointers(func: Callable):
@staticmethod
def stash_mobject_pointers(func: Callable[..., T]) -> Callable[..., T]:
@wraps(func)
def wrapper(self, *args, **kwargs):
uncopied_attrs = ["parents", "target", "saved_state"]
@ -637,8 +685,7 @@ class Mobject(object):
# Similarly, instead of calling match_updaters, since we know the status
# won't have changed, just directly match.
result.non_time_updaters = list(self.non_time_updaters)
result.time_based_updaters = list(self.time_based_updaters)
result.updaters = list(self.updaters)
result._data_has_changed = True
result._shaders_initialized = False
@ -646,7 +693,7 @@ class Mobject(object):
for attr, value in self.__dict__.items():
if isinstance(value, Mobject) and value is not self:
if value in family:
setattr(result, attr, result.family[self.family.index(value)])
setattr(result, attr, result.family[family.index(value)])
elif isinstance(value, np.ndarray):
setattr(result, attr, value.copy())
return result
@ -698,7 +745,7 @@ class Mobject(object):
sm1.texture_paths = sm2.texture_paths
sm1.depth_test = sm2.depth_test
sm1.render_primitive = sm2.render_primitive
sm1.needs_new_bounding_box = sm2.needs_new_bounding_box
sm1._needs_new_bounding_box = sm2._needs_new_bounding_box
# Make sure named family members carry over
for attr, value in list(mobject.__dict__.items()):
if isinstance(value, Mobject) and value in family2:
@ -782,78 +829,57 @@ class Mobject(object):
# Updating
def init_updaters(self):
self.time_based_updaters: list[TimeBasedUpdater] = []
self.non_time_updaters: list[NonTimeUpdater] = []
self.has_updaters: bool = False
self.updaters: list[Updater] = list()
self._has_updaters_in_family: Optional[bool] = False
self.updating_suspended: bool = False
def update(self, dt: float = 0, recurse: bool = True) -> Self:
if not self.has_updaters or self.updating_suspended:
if not self.has_updaters() or self.updating_suspended:
return self
if recurse:
for submob in self.submobjects:
submob.update(dt, recurse)
for updater in self.time_based_updaters:
updater(self, dt)
for updater in self.non_time_updaters:
updater(self)
for updater in self.updaters:
# This is hacky, but if an updater takes dt as an arg,
# it will be passed the change in time from here
if "dt" in updater.__code__.co_varnames:
updater(self, dt=dt)
else:
updater(self)
return self
def get_time_based_updaters(self) -> list[TimeBasedUpdater]:
return self.time_based_updaters
def has_time_based_updater(self) -> bool:
return len(self.time_based_updaters) > 0
def get_updaters(self) -> list[Updater]:
return self.time_based_updaters + self.non_time_updaters
return self.updaters
def get_family_updaters(self) -> list[Updater]:
return list(it.chain(*[sm.get_updaters() for sm in self.get_family()]))
def add_updater(
self,
update_function: Updater,
index: int | None = None,
call_updater: bool = True
) -> Self:
if "dt" in get_parameters(update_function):
updater_list = self.time_based_updaters
else:
updater_list = self.non_time_updaters
if index is None:
updater_list.append(update_function)
else:
updater_list.insert(index, update_function)
self.refresh_has_updater_status()
for parent in self.parents:
parent.has_updaters = True
if call_updater:
def add_updater(self, update_func: Updater, call: bool = True) -> Self:
self.updaters.append(update_func)
if call:
self.update(dt=0)
self.refresh_has_updater_status()
return self
def remove_updater(self, update_function: Updater) -> Self:
for updater_list in [self.time_based_updaters, self.non_time_updaters]:
while update_function in updater_list:
updater_list.remove(update_function)
def insert_updater(self, update_func: Updater, index=0):
self.updaters.insert(index, update_func)
self.refresh_has_updater_status()
return self
def remove_updater(self, update_func: Updater) -> Self:
while update_func in self.updaters:
self.updaters.remove(update_func)
self.refresh_has_updater_status()
return self
def clear_updaters(self, recurse: bool = True) -> Self:
self.time_based_updaters = []
self.non_time_updaters = []
if recurse:
for submob in self.submobjects:
submob.clear_updaters()
self.refresh_has_updater_status()
for mob in self.get_family(recurse):
mob.updaters = []
mob._has_updaters_in_family = False
for parent in self.get_ancestors():
parent._has_updaters_in_family = False
return self
def match_updaters(self, mobject: Mobject) -> Self:
self.clear_updaters()
for updater in mobject.get_updaters():
self.add_updater(updater)
self.updaters = list(mobject.updaters)
self.refresh_has_updater_status()
return self
def suspend_updating(self, recurse: bool = True) -> Self:
@ -874,14 +900,24 @@ class Mobject(object):
self.update(dt=0, recurse=recurse)
return self
def has_updaters(self) -> bool:
if self._has_updaters_in_family is None:
# Recompute and save
self._has_updaters_in_family = bool(self.updaters) or any(
sm.has_updaters() for sm in self.submobjects
)
return self._has_updaters_in_family
def refresh_has_updater_status(self) -> Self:
self.has_updaters = any(mob.get_updaters() for mob in self.get_family())
self._has_updaters_in_family = None
for parent in self.parents:
parent.refresh_has_updater_status()
return self
# Check if mark as static or not for camera
def is_changing(self) -> bool:
return self._is_animating or self.has_updaters
return self._is_animating or self.has_updaters()
def set_animating_status(self, is_animating: bool, recurse: bool = True) -> Self:
for mob in (*self.get_family(recurse), *self.get_ancestors()):
@ -1368,7 +1404,7 @@ class Mobject(object):
return rgb_to_hex(self.data["rgba"][0, :3])
def get_opacity(self) -> float:
return self.data["rgba"][0, 3]
return float(self.data["rgba"][0, 3])
def set_color_by_gradient(self, *colors: ManimColor) -> Self:
if self.has_points():
@ -1816,13 +1852,13 @@ class Mobject(object):
interpolate can skip this, and so that it's not
read into the shader_wrapper objects needlessly
"""
if self.has_updaters:
if self.has_updaters():
return self
self.locked_data_keys = set(keys)
return self
def lock_uniforms(self, keys: Iterable[str]) -> Self:
if self.has_updaters:
if self.has_updaters():
return self
self.locked_uniform_keys = set(keys)
return self
@ -1864,7 +1900,8 @@ class Mobject(object):
# Operations touching shader uniforms
def affects_shader_info_id(func: Callable):
@staticmethod
def affects_shader_info_id(func: Callable[..., T]) -> Callable[..., T]:
@wraps(func)
def wrapper(self, *args, **kwargs):
result = func(self, *args, **kwargs)
@ -2126,19 +2163,29 @@ class Mobject(object):
raise Exception(message.format(caller_name))
class Group(Mobject):
def __init__(self, *mobjects: Mobject, **kwargs):
if not all([isinstance(m, Mobject) for m in mobjects]):
raise Exception("All submobjects must be of type Mobject")
Mobject.__init__(self, **kwargs)
self.add(*mobjects)
if any(m.is_fixed_in_frame() for m in mobjects):
self.fix_in_frame()
class Group(Mobject, Generic[SubmobjectType]):
def __init__(self, *mobjects: SubmobjectType | Iterable[SubmobjectType], **kwargs):
super().__init__(**kwargs)
self._ingest_args(*mobjects)
def _ingest_args(self, *args: Mobject | Iterable[Mobject]):
if len(args) == 0:
return
if all(isinstance(mob, Mobject) for mob in args):
self.add(*args)
elif isinstance(args[0], Iterable):
self.add(*args[0])
else:
raise Exception(f"Invalid argument to Group of type {type(args[0])}")
def __add__(self, other: Mobject | Group) -> Self:
assert(isinstance(other, Mobject))
assert isinstance(other, Mobject)
return self.add(other)
# This is just here to make linters happy with references to things like Group(...)[0]
def __getitem__(self, index) -> SubmobjectType:
return super().__getitem__(index)
class Point(Mobject):
def __init__(
@ -2245,3 +2292,35 @@ def override_animate(method):
return animation_method
return decorator
class _UpdaterBuilder:
def __init__(self, mobject: Mobject):
self.mobject = mobject
def __getattr__(self, method_name: str):
def add_updater(*method_args, **method_kwargs):
self.mobject.add_updater(
lambda m: getattr(m, method_name)(*method_args, **method_kwargs)
)
return self
return add_updater
class _FunctionalUpdaterBuilder:
def __init__(self, mobject: Mobject):
self.mobject = mobject
def __getattr__(self, method_name: str):
def add_updater(*method_args, **method_kwargs):
self.mobject.add_updater(
lambda m: getattr(m, method_name)(
*(arg() for arg in method_args),
**{
key: value()
for key, value in method_kwargs.items()
}
)
)
return self
return add_updater

View file

@ -16,7 +16,7 @@ from manimlib.utils.simple_functions import fdiv
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Iterable
from typing import Iterable, Optional
from manimlib.typing import ManimColor, Vect3, Vect3Array, VectN, RangeSpecifier
@ -28,13 +28,14 @@ class NumberLine(Line):
stroke_width: float = 2.0,
# How big is one one unit of this number line in terms of absolute spacial distance
unit_size: float = 1.0,
width: float | None = None,
width: Optional[float] = None,
include_ticks: bool = True,
tick_size: float = 0.1,
longer_tick_multiple: float = 1.5,
tick_offset: float = 0.0,
# Change name
numbers_with_elongated_ticks: list[float] = [],
big_tick_spacing: Optional[float] = None,
big_tick_numbers: list[float] = [],
include_numbers: bool = False,
line_to_number_direction: Vect3 = DOWN,
line_to_number_buff: float = MED_SMALL_BUFF,
@ -54,7 +55,14 @@ class NumberLine(Line):
self.tick_size = tick_size
self.longer_tick_multiple = longer_tick_multiple
self.tick_offset = tick_offset
self.numbers_with_elongated_ticks = list(numbers_with_elongated_ticks)
if big_tick_spacing is not None:
self.big_tick_numbers = np.arange(
x_range[0],
x_range[1] + big_tick_spacing,
big_tick_spacing,
)
else:
self.big_tick_numbers = list(big_tick_numbers)
self.line_to_number_direction = line_to_number_direction
self.line_to_number_buff = line_to_number_buff
self.include_tip = include_tip
@ -101,7 +109,7 @@ class NumberLine(Line):
ticks = VGroup()
for x in self.get_tick_range():
size = self.tick_size
if np.isclose(self.numbers_with_elongated_ticks, x).any():
if np.isclose(self.big_tick_numbers, x).any():
size *= self.longer_tick_multiple
ticks.add(self.get_tick(x, size))
self.add(ticks)
@ -210,7 +218,7 @@ class UnitInterval(NumberLine):
self,
x_range: RangeSpecifier = (0, 1, 0.1),
unit_size: float = 10,
numbers_with_elongated_ticks: list[float] = [0, 1],
big_tick_numbers: list[float] = [0, 1],
decimal_number_config: dict = dict(
num_decimal_places=1,
)
@ -218,6 +226,6 @@ class UnitInterval(NumberLine):
super().__init__(
x_range=x_range,
unit_size=unit_size,
numbers_with_elongated_ticks=numbers_with_elongated_ticks,
big_tick_numbers=big_tick_numbers,
decimal_number_config=decimal_number_config,
)

View file

@ -1,4 +1,5 @@
from __future__ import annotations
from functools import lru_cache
import numpy as np
@ -17,6 +18,11 @@ if TYPE_CHECKING:
T = TypeVar("T", bound=VMobject)
@lru_cache()
def char_to_cahced_mob(char: str, **text_config):
return Text(char, **text_config)
class DecimalNumber(VMobject):
def __init__(
self,
@ -46,7 +52,6 @@ class DecimalNumber(VMobject):
self.edge_to_fix = edge_to_fix
self.font_size = font_size
self.text_config = dict(text_config)
self.char_to_mob_map = dict()
super().__init__(
color=color,
@ -59,36 +64,44 @@ class DecimalNumber(VMobject):
self.init_colors()
def set_submobjects_from_number(self, number: float | complex) -> None:
# Create the submobject list
self.number = number
self.set_submobjects([])
self.text_config["font_size"] = self.get_font_size()
num_string = self.num_string = self.get_num_string(number)
self.add(*map(self.char_to_mob, num_string))
self.num_string = self.get_num_string(number)
# Add non-numerical bits
# Submob_templates will be a list of cached Tex and Text mobjects,
# with the intent of calling .copy or .become on them
submob_templates = list(map(self.char_to_mob, self.num_string))
if self.show_ellipsis:
dots = self.char_to_mob("...")
dots.arrange(RIGHT, buff=2 * dots[0].get_width())
self.add(dots)
submob_templates.append(dots)
if self.unit is not None:
self.unit_sign = Tex(self.unit, font_size=self.get_font_size())
self.add(self.unit_sign)
submob_templates.append(self.char_to_mob(self.unit))
self.arrange(
buff=self.digit_buff_per_font_unit * self.get_font_size(),
aligned_edge=DOWN
)
# Set internals
font_size = self.get_font_size()
if len(submob_templates) == len(self.submobjects):
for sm, smt in zip(self.submobjects, submob_templates):
sm.become(smt)
sm.scale(font_size / smt.font_size)
else:
self.set_submobjects([
smt.copy().scale(font_size / smt.font_size)
for smt in submob_templates
])
# Handle alignment of parts that should be aligned
# to the bottom
for i, c in enumerate(num_string):
if c == "" and len(num_string) > i + 1:
digit_buff = self.digit_buff_per_font_unit * font_size
self.arrange(RIGHT, buff=digit_buff, aligned_edge=DOWN)
# Handle alignment of special characters
for i, c in enumerate(self.num_string):
if c == "" and len(self.num_string) > i + 1:
self[i].align_to(self[i + 1], UP)
self[i].shift(self[i + 1].get_height() * DOWN / 2)
elif c == ",":
self[i].shift(self[i].get_height() * DOWN / 2)
if self.unit and self.unit.startswith("^"):
self.unit_sign.align_to(self, UP)
self[-1].align_to(self, UP)
if self.include_background_rectangle:
self.add_background_rectangle()
@ -111,12 +124,8 @@ class DecimalNumber(VMobject):
num_string = num_string.replace("-", "")
return num_string
def char_to_mob(self, char: str) -> Tex | Text:
if char not in self.char_to_mob_map:
self.char_to_mob_map[char] = Text(char, **self.text_config)
result = self.char_to_mob_map[char].copy()
result.scale(self.get_font_size() / result.font_size)
return result
def char_to_mob(self, char: str) -> Text:
return char_to_cahced_mob(char, **self.text_config)
def init_uniforms(self) -> None:
super().init_uniforms()
@ -171,7 +180,8 @@ class DecimalNumber(VMobject):
self.set_submobjects_from_number(number)
self.move_to(move_to_point, self.edge_to_fix)
self.set_style(**style)
self.fix_in_frame(self._is_fixed_in_frame)
for submob in self.get_family():
submob.uniforms.update(self.uniforms)
return self
def _handle_scale_side_effects(self, scale_factor: float) -> Self:

View file

@ -30,6 +30,8 @@ class SurroundingRectangle(Rectangle):
super().__init__(color=color, **kwargs)
self.buff = buff
self.surround(mobject)
if mobject.is_fixed_in_frame():
self.fix_in_frame()
def surround(self, mobject, buff=None) -> Self:
self.mobject = mobject
@ -120,13 +122,9 @@ class Underline(Line):
stretch_factor=1.2,
**kwargs
):
super().__init__(
LEFT, RIGHT,
stroke_color=stroke_color,
stroke_width=stroke_width,
**kwargs
)
self.insert_n_curves(30)
super().__init__(LEFT, RIGHT, **kwargs)
if not isinstance(stroke_width, (float, int)):
self.insert_n_curves(len(stroke_width) - 2)
self.set_stroke(stroke_color, stroke_width)
self.set_width(mobject.get_width() * stretch_factor)
self.next_to(mobject, DOWN, buff=buff)

View file

@ -91,8 +91,9 @@ class Brace(Tex):
return text_mob
def get_tex(self, *tex: str, **kwargs) -> Tex:
tex_mob = Tex(*tex)
self.put_at_tip(tex_mob, **kwargs)
buff = kwargs.pop("buff", SMALL_BUFF)
tex_mob = Tex(*tex, **kwargs)
self.put_at_tip(tex_mob, buff=buff)
return tex_mob
def get_tip(self) -> np.ndarray:

View file

@ -2,6 +2,7 @@ from __future__ import annotations
import numpy as np
import itertools as it
import random
from manimlib.animation.composition import AnimationGroup
from manimlib.animation.rotation import Rotating
@ -24,6 +25,7 @@ from manimlib.constants import LEFT
from manimlib.constants import LEFT
from manimlib.constants import MED_LARGE_BUFF
from manimlib.constants import MED_SMALL_BUFF
from manimlib.constants import LARGE_BUFF
from manimlib.constants import ORIGIN
from manimlib.constants import OUT
from manimlib.constants import PI
@ -41,6 +43,7 @@ from manimlib.constants import WHITE
from manimlib.constants import YELLOW
from manimlib.constants import TAU
from manimlib.mobject.boolean_ops import Difference
from manimlib.mobject.boolean_ops import Union
from manimlib.mobject.geometry import Arc
from manimlib.mobject.geometry import Circle
from manimlib.mobject.geometry import Dot
@ -51,6 +54,7 @@ from manimlib.mobject.geometry import Square
from manimlib.mobject.geometry import AnnularSector
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.numbers import Integer
from manimlib.mobject.shape_matchers import SurroundingRectangle
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.svg.tex_mobject import TexText
@ -59,9 +63,13 @@ from manimlib.mobject.three_dimensions import Prismify
from manimlib.mobject.three_dimensions import VCube
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.svg.text_mobject import Text
from manimlib.utils.bezier import interpolate
from manimlib.utils.iterables import adjacent_pairs
from manimlib.utils.rate_functions import linear
from manimlib.utils.space_ops import angle_of_vector
from manimlib.utils.space_ops import compass_directions
from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import midpoint
from manimlib.utils.space_ops import rotate_vector
@ -344,66 +352,76 @@ class ClockPassesTime(AnimationGroup):
)
class Bubble(SVGMobject):
class Bubble(VGroup):
file_name: str = "Bubbles_speech.svg"
bubble_center_adjustment_factor = 0.125
def __init__(
self,
content: str | VMobject | None = None,
buff: float = 1.0,
filler_shape: Tuple[float, float] = (3.0, 2.0),
pin_point: Vect3 | None = None,
direction: Vect3 = LEFT,
center_point: Vect3 = ORIGIN,
content_scale_factor: float = 0.7,
height: float = 4.0,
width: float = 8.0,
max_height: float | None = None,
max_width: float | None = None,
bubble_center_adjustment_factor: float = 0.125,
add_content: bool = True,
fill_color: ManimColor = BLACK,
fill_opacity: float = 0.8,
stroke_color: ManimColor = WHITE,
stroke_width: float = 3.0,
**kwargs
):
self.direction = LEFT # Possibly updated below by self.flip()
self.bubble_center_adjustment_factor = bubble_center_adjustment_factor
self.content_scale_factor = content_scale_factor
super().__init__(**kwargs)
self.direction = direction
super().__init__(
fill_color=fill_color,
fill_opacity=fill_opacity,
stroke_color=stroke_color,
stroke_width=stroke_width,
**kwargs
)
if content is None:
content = Rectangle(*filler_shape)
content.set_fill(opacity=0)
content.set_stroke(width=0)
elif isinstance(content, str):
content = Text(content)
self.content = content
self.center()
self.set_height(height, stretch=True)
self.set_width(width, stretch=True)
if max_height:
self.set_max_height(max_height)
if max_width:
self.set_max_width(max_width)
self.body = self.get_body(content, direction, buff)
self.body.set_fill(fill_color, fill_opacity)
self.body.set_stroke(stroke_color, stroke_width)
self.add(self.body)
if add_content:
self.add(self.content)
if pin_point is not None:
self.pin_to(pin_point)
def get_body(self, content: VMobject, direction: Vect3, buff: float) -> VMobject:
body = SVGMobject(self.file_name)
if direction[0] > 0:
self.flip()
self.content = VMobject()
body.flip()
# Resize
width = content.get_width()
height = content.get_height()
target_width = width + min(buff, height)
target_height = 1.35 * (height + buff) # Magic number?
body.set_shape(target_width, target_height)
body.move_to(content)
body.shift(self.bubble_center_adjustment_factor * body.get_height() * DOWN)
return body
def get_tip(self):
# TODO, find a better way
return self.get_corner(DOWN + self.direction) - 0.6 * self.direction
return self.get_corner(DOWN + self.direction)
def get_bubble_center(self):
factor = self.bubble_center_adjustment_factor
return self.get_center() + factor * self.get_height() * UP
def move_tip_to(self, point):
mover = VGroup(self)
if self.content is not None:
mover.add(self.content)
mover.shift(point - self.get_tip())
self.shift(point - self.get_tip())
return self
def flip(self, axis=UP):
super().flip(axis=axis)
def flip(self, axis=UP, only_body=True, **kwargs):
super().flip(axis=axis, **kwargs)
if only_body:
# Flip in place, don't use kwargs
self.content.flip(axis=axis)
if abs(axis[1]) > 0:
self.direction = -np.array(self.direction)
return self
@ -418,9 +436,9 @@ class Bubble(SVGMobject):
self.move_tip_to(mob_center + vector_from_center)
return self
def position_mobject_inside(self, mobject):
mobject.set_max_width(self.content_scale_factor * self.get_width())
mobject.set_max_height(self.content_scale_factor * self.get_height() / 1.5)
def position_mobject_inside(self, mobject, buff=MED_LARGE_BUFF):
mobject.set_max_width(self.body.get_width() - 2 * buff)
mobject.set_max_height(self.body.get_height() / 1.5 - 2 * buff)
mobject.shift(self.get_bubble_center() - mobject.get_center())
return mobject
@ -429,26 +447,110 @@ class Bubble(SVGMobject):
self.content = mobject
return self.content
def write(self, *text):
self.add_content(TexText(*text))
def write(self, text):
self.add_content(Text(text))
return self
def resize_to_content(self, buff=0.75):
width = self.content.get_width()
height = self.content.get_height()
target_width = width + min(buff, height)
target_height = 1.35 * (self.content.get_height() + buff)
tip_point = self.get_tip()
self.stretch_to_fit_width(target_width, about_point=tip_point)
self.stretch_to_fit_height(target_height, about_point=tip_point)
self.position_mobject_inside(self.content)
def resize_to_content(self, buff=1.0): # TODO
self.body.match_points(self.get_body(
self.content, self.direction, buff
))
def clear(self):
self.add_content(VMobject())
self.remove(self.content)
return self
class SpeechBubble(Bubble):
def __init__(
self,
content: str | VMobject | None = None,
buff: float = MED_SMALL_BUFF,
filler_shape: Tuple[float, float] = (2.0, 1.0),
stem_height_to_bubble_height: float = 0.5,
stem_top_x_props: Tuple[float, float] = (0.2, 0.3),
**kwargs
):
self.stem_height_to_bubble_height = stem_height_to_bubble_height
self.stem_top_x_props = stem_top_x_props
super().__init__(content, buff, filler_shape, **kwargs)
def get_body(self, content: VMobject, direction: Vect3, buff: float) -> VMobject:
rect = SurroundingRectangle(content, buff=buff)
rect.round_corners()
lp = rect.get_corner(DL)
rp = rect.get_corner(DR)
stem_height = self.stem_height_to_bubble_height * rect.get_height()
low_prop, high_prop = self.stem_top_x_props
triangle = Polygon(
interpolate(lp, rp, low_prop),
interpolate(lp, rp, high_prop),
lp + stem_height * DOWN,
)
result = Union(rect, triangle)
result.insert_n_curves(20)
if direction[0] > 0:
result.flip()
return result
class ThoughtBubble(Bubble):
def __init__(
self,
content: str | VMobject | None = None,
buff: float = SMALL_BUFF,
filler_shape: Tuple[float, float] = (2.0, 1.0),
bulge_radius: float = 0.35,
bulge_overlap: float = 0.25,
noise_factor: float = 0.1,
circle_radii: list[float] = [0.1, 0.15, 0.2],
**kwargs
):
self.bulge_radius = bulge_radius
self.bulge_overlap = bulge_overlap
self.noise_factor = noise_factor
self.circle_radii = circle_radii
super().__init__(content, buff, filler_shape, **kwargs)
def get_body(self, content: VMobject, direction: Vect3, buff: float) -> VMobject:
rect = SurroundingRectangle(content, buff)
perimeter = rect.get_arc_length()
radius = self.bulge_radius
step = (1 - self.bulge_overlap) * (2 * radius)
nf = self.noise_factor
corners = [rect.get_corner(v) for v in [DL, UL, UR, DR]]
points = []
for c1, c2 in adjacent_pairs(corners):
n_alphas = int(get_norm(c1 - c2) / step) + 1
for alpha in np.linspace(0, 1, n_alphas):
points.append(interpolate(
c1, c2, alpha + nf * (step / n_alphas) * (random.random() - 0.5)
))
cloud = Union(rect, *(
# Add bulges
Circle(radius=radius * (1 + nf * random.random())).move_to(point)
for point in points
))
cloud.set_stroke(WHITE, 2)
circles = VGroup(Circle(radius=radius) for radius in self.circle_radii)
circ_buff = 0.25 * self.circle_radii[0]
circles.arrange(UR, buff=circ_buff)
circles[1].shift(circ_buff * DR)
circles.next_to(cloud, DOWN, 4 * circ_buff, aligned_edge=LEFT)
circles.set_stroke(WHITE, 2)
result = VGroup(*circles, cloud)
if direction[0] > 0:
result.flip()
return result
class OldSpeechBubble(Bubble):
file_name: str = "Bubbles_speech.svg"
@ -456,17 +558,16 @@ class DoubleSpeechBubble(Bubble):
file_name: str = "Bubbles_double_speech.svg"
class ThoughtBubble(Bubble):
class OldThoughtBubble(Bubble):
file_name: str = "Bubbles_thought.svg"
def __init__(self, **kwargs):
Bubble.__init__(self, **kwargs)
self.submobjects.sort(
key=lambda m: m.get_bottom()[1]
)
def get_body(self, content: VMobject, direction: Vect3, buff: float) -> VMobject:
body = super().get_body(content, direction, buff)
body.sort(lambda p: p[1])
return body
def make_green_screen(self):
self.submobjects[-1].set_fill(GREEN_SCREEN, opacity=1)
self.body[-1].set_fill(GREEN_SCREEN, opacity=1)
return self

View file

@ -231,7 +231,7 @@ class Tex(StringMobject):
))
return re.findall(pattern, self.string)
def make_number_changable(
def make_number_changeable(
self,
value: float | int | str,
index: int = 0,
@ -241,7 +241,7 @@ class Tex(StringMobject):
substr = str(value)
parts = self.select_parts(substr)
if len(parts) == 0:
log.warning(f"{value} not found in Tex.make_number_changable call")
log.warning(f"{value} not found in Tex.make_number_changeable call")
return VMobject()
if index > len(parts) - 1:
log.warning(f"Requested {index}th occurance of {value}, but only {len(parts)} exist")

View file

@ -71,5 +71,5 @@ class ImageMobject(Mobject):
rgb = self.image.getpixel((
int((pw - 1) * x_alpha),
int((ph - 1) * y_alpha),
))
))[:3]
return np.array(rgb) / 255

View file

@ -133,20 +133,28 @@ class Surface(Mobject):
if len(indices) == 0:
return np.zeros((3, 0))
left = indices - 1
right = indices + 1
up = indices - nv
down = indices + nv
# For each point, find two adjacent points at indices
# step1 and step2, such that crossing points[step1] - points
# with points[step1] - points gives a normal vector
step1 = indices + 1
step2 = indices + nu
left[0] = indices[0]
right[-1] = indices[-1]
up[:nv] = indices[:nv]
down[-nv:] = indices[-nv:]
# Right edge
step1[nu - 1::nu] = indices[nu - 1::nu] + nu
step2[nu - 1::nu] = indices[nu - 1::nu] - 1
# Bottom edge
step1[-nu:] = indices[-nu:] - nu
step2[-nu:] = indices[-nu:] + 1
# Lower right point
step1[-1] = indices[-1] - 1
step2[-1] = indices[-1] - nu
points = self.get_points()
crosses = cross(
points[right] - points[left],
points[up] - points[down],
points[step1] - points,
points[step2] - points,
)
self.data["normal"] = normalize_along_axis(crosses, 1)
return self.data["normal"]

View file

@ -13,6 +13,7 @@ from manimlib.constants import JOINT_TYPE_MAP
from manimlib.constants import ORIGIN, OUT
from manimlib.constants import TAU
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
from manimlib.mobject.mobject import Point
from manimlib.utils.bezier import bezier
from manimlib.utils.bezier import get_quadratic_approximation_of_cubic
@ -40,7 +41,6 @@ from manimlib.utils.space_ops import get_norm
from manimlib.utils.space_ops import get_unit_normal
from manimlib.utils.space_ops import line_intersects_path
from manimlib.utils.space_ops import midpoint
from manimlib.utils.space_ops import normalize_along_axis
from manimlib.utils.space_ops import rotation_between_vectors
from manimlib.utils.space_ops import poly_line_length
from manimlib.utils.space_ops import z_to_vector
@ -48,15 +48,18 @@ from manimlib.shader_wrapper import ShaderWrapper
from manimlib.shader_wrapper import FillShaderWrapper
from typing import TYPE_CHECKING
from typing import Generic, TypeVar, Iterable
SubVmobjectType = TypeVar('SubVmobjectType', bound='VMobject')
if TYPE_CHECKING:
from typing import Callable, Iterable, Tuple, Any
from typing import Callable, Tuple, Any
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array, Self
from moderngl.context import Context
DEFAULT_STROKE_COLOR = GREY_A
DEFAULT_FILL_COLOR = GREY_C
class VMobject(Mobject):
fill_shader_folder: str = "quadratic_bezier_fill"
stroke_shader_folder: str = "quadratic_bezier_stroke"
@ -97,7 +100,7 @@ class VMobject(Mobject):
flat_stroke: bool = True,
use_simple_quadratic_approx: bool = False,
# Measured in pixel widths
anti_alias_width: float = 1.0,
anti_alias_width: float = 1.5,
fill_border_width: float = 0.5,
use_winding_fill: bool = True,
**kwargs
@ -187,9 +190,10 @@ class VMobject(Mobject):
recurse: bool = True
) -> Self:
self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse)
if border_width is not None:
for mob in self.get_family(recurse):
mob.data["fill_border_width"] = border_width
if border_width is None:
border_width = 0 if self.get_fill_opacity() < 1 else 0.5
for mob in self.get_family(recurse):
mob.data["fill_border_width"] = border_width
self.note_changed_fill()
return self
@ -1415,17 +1419,23 @@ class VMobject(Mobject):
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
class VGroup(VMobject):
def __init__(self, *vmobjects: VMobject, **kwargs):
class VGroup(Group, VMobject, Generic[SubVmobjectType]):
def __init__(self, *vmobjects: SubVmobjectType | Iterable[SubVmobjectType], **kwargs):
super().__init__(**kwargs)
self.add(*vmobjects)
if vmobjects:
self.uniforms.update(vmobjects[0].uniforms)
if any(isinstance(vmob, Mobject) and not isinstance(vmob, VMobject) for vmob in vmobjects):
raise Exception("Only VMobjects can be passed into VGroup")
self._ingest_args(*vmobjects)
if self.submobjects:
self.uniforms.update(self.submobjects[0].uniforms)
def __add__(self, other: VMobject) -> Self:
assert(isinstance(other, VMobject))
assert isinstance(other, VMobject)
return self.add(other)
# This is just here to make linters happy with references to things like VGroup(...)[0]
def __getitem__(self, index) -> SubVmobjectType:
return super().__getitem__(index)
class VectorizedPoint(Point, VMobject):
def __init__(

View file

@ -48,6 +48,7 @@ RESIZE_KEY = 't'
COLOR_KEY = 'c'
INFORMATION_KEY = 'i'
CURSOR_KEY = 'k'
COPY_FRAME_POSITION_KEY = 'p'
# Note, a lot of the functionality here is still buggy and very much a work in progress.
@ -504,8 +505,8 @@ class InteractiveScene(Scene):
self.toggle_selection_mode()
elif char == "s" and modifiers == COMMAND_MODIFIER:
self.save_selection_to_file()
elif char == PAN_3D_KEY and modifiers == COMMAND_MODIFIER:
self.copy_frame_anim_call()
elif char == "d" and modifiers == SHIFT_MODIFIER:
self.copy_frame_positioning()
elif symbol in ARROW_SYMBOLS:
self.nudge_selection(
vect=[LEFT, UP, RIGHT, DOWN][ARROW_SYMBOLS.index(symbol)],
@ -615,16 +616,18 @@ class InteractiveScene(Scene):
self.clear_selection()
# Copying code to recreate state
def copy_frame_anim_call(self):
def copy_frame_positioning(self):
frame = self.frame
center = frame.get_center()
height = frame.get_height()
angles = frame.get_euler_angles()
call = f"self.frame.animate.reorient"
call += str(tuple((angles / DEGREES).astype(int)))
call = f"reorient("
theta, phi, gamma = (angles / DEGREES).astype(int)
call += f"{theta}, {phi}, {gamma}"
if any(center != 0):
call += f".move_to({list(np.round(center, 2))})"
call += f", {tuple(np.round(center, 2))}"
if height != FRAME_HEIGHT:
call += ".set_height({:.2f})".format(height)
call += ", {:.2f}".format(height)
call += ")"
pyperclip.copy(call)

View file

@ -7,6 +7,7 @@ import platform
import pyperclip
import random
import time
import re
from functools import wraps
from IPython.terminal import pt_inputhooks
@ -44,9 +45,11 @@ from manimlib.utils.iterables import batch_by_property
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable
from typing import Callable, Iterable, TypeVar
from manimlib.typing import Vect3
T = TypeVar('T')
from PIL.Image import Image
from manimlib.animation.animation import Animation
@ -210,7 +213,8 @@ class Scene(object):
show_animation_progress: bool = False,
) -> None:
if not self.preview:
return # Embed is only relevant with a preview
# Embed is only relevant with a preview
return
self.stop_skipping()
self.update_frame()
self.save_state()
@ -236,6 +240,8 @@ class Scene(object):
i2g=self.i2g,
i2m=self.i2m,
checkpoint_paste=self.checkpoint_paste,
touch=lambda: shell.enable_gui("manim"),
notouch=lambda: shell.enable_gui(None),
)
# Enables gui interactions during the embed
@ -257,20 +263,19 @@ class Scene(object):
# namespace, since this is just a shell session anyway.
shell.events.register(
"pre_run_cell",
lambda: shell.user_global_ns.update(shell.user_ns)
lambda *args, **kwargs: shell.user_global_ns.update(shell.user_ns)
)
# Operation to run after each ipython command
def post_cell_func():
def post_cell_func(*args, **kwargs):
if not self.is_window_closing():
self.update_frame(dt=0, ignore_skipping=True)
self.save_state()
shell.events.register("post_run_cell", post_cell_func)
# Flash border, and potentially play sound, on exceptions
def custom_exc(shell, etype, evalue, tb, tb_offset=None):
# still show the error don't just swallow it
# Show the error don't just swallow it
shell.showtraceback((etype, evalue, tb), tb_offset=tb_offset)
if self.embed_error_sound:
os.system("printf '\a'")
@ -342,17 +347,9 @@ class Scene(object):
mobject.update(dt)
def should_update_mobjects(self) -> bool:
return self.always_update_mobjects or any([
len(mob.get_family_updaters()) > 0
for mob in self.mobjects
])
def has_time_based_updaters(self) -> bool:
return any([
sm.has_time_based_updater()
for mob in self.mobjects()
for sm in mob.get_family()
])
return self.always_update_mobjects or any(
mob.has_updaters() for mob in self.mobjects
)
# Related to time
@ -399,7 +396,8 @@ class Scene(object):
for batch, key in batches
]
def affects_mobject_list(func: Callable):
@staticmethod
def affects_mobject_list(func: Callable[..., T]) -> Callable[..., T]:
@wraps(func)
def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs)
@ -774,13 +772,31 @@ class Scene(object):
)
pasted = pyperclip.paste()
line0 = pasted.lstrip().split("\n")[0]
if line0.startswith("#"):
if line0 not in self.checkpoint_states:
self.checkpoint(line0)
else:
self.revert_to_checkpoint(line0)
lines = pasted.split("\n")
# Commented lines trigger saved checkpoints
if lines[0].lstrip().startswith("#"):
if lines[0] not in self.checkpoint_states:
self.checkpoint(lines[0])
else:
self.revert_to_checkpoint(lines[0])
# Copied methods of a scene are handled specially
# A bit hacky, yes, but convenient
method_pattern = r"^def\s+([a-zA-Z_]\w*)\s*\(self.*\):"
method_names = re.findall(method_pattern ,lines[0].strip())
if method_names:
method_name = method_names[0]
indent = " " * lines[0].index(lines[0].strip())
pasted = "\n".join([
# Remove self from function signature
re.sub(r"self(,\s*)?", "", lines[0]),
*lines[1:],
# Attach to scene via self.func_name = func_name
f"{indent}self.{method_name} = {method_name}"
])
# Keep track of skipping and progress bar status
prev_skipping = self.skip_animations
self.skip_animations = skip
@ -836,6 +852,13 @@ class Scene(object):
return self.window and (self.window.is_closing or self.quit_interaction)
# Event handling
def set_floor_plane(self, plane: str = "xy"):
if plane == "xy":
self.frame.set_euler_axes("zxz")
elif plane == "xz":
self.frame.set_euler_axes("zxy")
else:
raise Exception("Only `xz` and `xy` are valid floor planes")
def on_mouse_motion(
self,
@ -1023,7 +1046,7 @@ class ThreeDScene(Scene):
default_frame_orientation = (-30, 70)
always_depth_test = True
def add(self, *mobjects, set_depth_test: bool = True):
def add(self, *mobjects: Mobject, set_depth_test: bool = True):
for mob in mobjects:
if set_depth_test and not mob.is_fixed_in_frame() and self.always_depth_test:
mob.apply_depth_test()

View file

@ -9,9 +9,8 @@ const float Y_SCALE = 2.0 / DEFAULT_FRAME_HEIGHT;
void emit_gl_Position(vec3 point){
vec4 result = vec4(point, 1.0);
if(!bool(is_fixed_in_frame)){
result = view * result;
}
// This allow for smooth transitions between objects fixed and unfixed from frame
result = mix(view * result, result, is_fixed_in_frame);
// Essentially a projection matrix
result.x *= X_SCALE;
result.y *= Y_SCALE;

View file

@ -1,23 +0,0 @@
uniform float is_fixed_in_frame;
uniform mat4 view;
uniform float focal_distance;
const float DEFAULT_FRAME_HEIGHT = 8.0;
const float ASPECT_RATIO = 16.0 / 9.0;
const float X_SCALE = 2.0 / DEFAULT_FRAME_HEIGHT / ASPECT_RATIO;
const float Y_SCALE = 2.0 / DEFAULT_FRAME_HEIGHT;
void emit_gl_Position(vec3 point){
vec4 result = vec4(point, 1.0);
if(!bool(is_fixed_in_frame)){
result = view * result;
}
// Essentially a projection matrix
result.x *= X_SCALE;
result.y *= Y_SCALE;
result.z /= focal_distance;
result.w = 1.0 - result.z;
// Flip and scale to prevent premature clipping
result.z *= -0.1;
gl_Position = result;
}

View file

@ -1,66 +1,20 @@
#version 330
in vec2 uv_coords;
in float uv_stroke_width;
in float uv_anti_alias_width;
// Value between -1 and 1
in float scaled_signed_dist_to_curve;
in float scaled_anti_alias_width;
in vec4 color;
in float is_linear;
out vec4 frag_color;
const float QUICK_DIST_WIDTH = 0.2;
float dist_to_curve(){
// In the linear case, the curve will have
// been set to equal the x axis
if(bool(is_linear)) return abs(uv_coords.y);
// Otherwise, find the distance from uv_coords to the curve y = x^2
float x0 = uv_coords.x;
float y0 = uv_coords.y;
// This is a quick approximation for computing
// the distance to the curve.
// Evaluate F(x, y) = y - x^2
// divide by its gradient's magnitude
float Fxy = y0 - x0 * x0;
float approx_dist = abs(Fxy) * inversesqrt(1.0 + 4 * x0 * x0);
if(approx_dist < QUICK_DIST_WIDTH) return approx_dist;
// Otherwise, solve for the minimal distance.
// The distance squared between (x0, y0) and a point (x, x^2) looks like
//
// (x0 - x)^2 + (y0 - x^2)^2 = x^4 + (1 - 2y0)x^2 - 2x0 * x + (x0^2 + y0^2)
//
// Setting the derivative equal to zero (and rescaling) looks like
//
// x^3 + (0.5 - y0) * x - 0.5 * x0 = 0
//
// Adapted from https://www.shadertoy.com/view/ws3GD7
x0 = abs(x0);
float p = (0.5 - y0) / 3.0; // p / 3 in usual Cardano's formula notation
float q = 0.25 * x0; // -q / 2 in usual Cardano's formula notation
float disc = q*q + p*p*p;
float r = sqrt(abs(disc));
float x = (disc > 0.0) ?
// 1 root
pow(q + r, 1.0 / 3.0) + pow(abs(q - r), 1.0 / 3.0) * sign(-p) :
// 3 roots
2.0 * cos(atan(r, q) / 3.0) * sqrt(-p);
return length(vec2(x0 - x, y0 - x * x));
}
void main() {
if (uv_stroke_width == 0) discard;
if(scaled_anti_alias_width < 0) discard;
frag_color = color;
// sdf for the region around the curve we wish to color.
float signed_dist = dist_to_curve() - 0.5 * uv_stroke_width;
frag_color.a *= smoothstep(0.5, -0.5, signed_dist / uv_anti_alias_width);
float signed_dist_to_region = abs(scaled_signed_dist_to_curve) - 1.0;
frag_color.a *= smoothstep(
0, -scaled_anti_alias_width,
signed_dist_to_region
);
}

View file

@ -1,12 +1,13 @@
#version 330
layout (triangles) in;
layout (triangle_strip, max_vertices = 6) out;
layout (triangle_strip, max_vertices = 64) out; // Related to MAX_STEPS below
uniform float anti_alias_width;
uniform float flat_stroke;
uniform float pixel_size;
uniform float joint_type;
uniform float frame_scale;
in vec3 verts[3];
@ -15,12 +16,8 @@ in float v_stroke_width[3];
in vec4 v_color[3];
out vec4 color;
out float uv_stroke_width;
out float uv_anti_alias_width;
out float is_linear;
out vec2 uv_coords;
out float scaled_anti_alias_width;
out float scaled_signed_dist_to_curve;
// Codes for joint types
const int NO_JOINT = 0;
@ -32,11 +29,13 @@ const int MITER_JOINT = 3;
// two vectors is larger than this, we
// consider them aligned
const float COS_THRESHOLD = 0.99;
// Used to determine how many lines to break the curve into
const float POLYLINE_FACTOR = 30;
const int MAX_STEPS = 32;
vec3 unit_normal = vec3(0.0, 0.0, 1.0);
#INSERT emit_gl_Position.glsl
#INSERT get_xyz_to_uv.glsl
#INSERT finalize_color.glsl
@ -54,6 +53,30 @@ vec4 normalized_joint_product(vec4 joint_product){
}
vec3 point_on_curve(float t){
return verts[0] + 2 * (verts[1] - verts[0]) * t + (verts[0] - 2 * verts[1] + verts[2]) * t * t;
}
vec3 tangent_on_curve(float t){
return 2 * (verts[1] - verts[0]) + 2 * (verts[0] - 2 * verts[1] + verts[2]) * t;
}
void compute_subdivision(out int n_steps, out float subdivision[MAX_STEPS]){
// Crude estimate for the number of polyline segments to use, based
// on the area spanned by the control points
float area = 0.5 * length(v_joint_product[1].xzy);
int count = 2 + int(round(POLYLINE_FACTOR * sqrt(area) / frame_scale));
n_steps = min(count, MAX_STEPS);
for(int i = 0; i < MAX_STEPS; i++){
if (i >= n_steps) break;
subdivision[i] = float(i) / (n_steps - 1);
}
}
void create_joint(
vec4 joint_product,
vec3 unit_tan,
@ -83,73 +106,58 @@ void create_joint(
changing_c1 = static_c1 + shift * unit_tan;
}
vec3 get_perp(int index, vec4 joint_product, vec3 point, vec3 tangent, float aaw){
vec3 left_step(vec3 point, vec3 tangent, vec4 joint_product){
/*
Perpendicular vectors to the left of the curve
*/
float buff = 0.5 * v_stroke_width[index] + aaw;
// Add correction for sharp angles to prevent weird bevel effects
if(joint_product.w < -0.75) buff *= 4 * (joint_product.w + 1.0);
float mult = 1.0;
if(joint_product.w < -0.75) mult *= 4 * (joint_product.w + 1.0);
vec3 normal = get_joint_unit_normal(joint_product);
// Set global unit normal
unit_normal = normal;
// Choose the "outward" normal direction
if(normal.z < 0) normal *= -1;
if(bool(flat_stroke)){
return buff * normalize(cross(normal, tangent));
return mult * normalize(cross(normal, tangent));
}else{
return buff * normalize(cross(camera_position - point, tangent));
return mult * normalize(cross(camera_position - point, tangent));
}
}
// This function is responsible for finding the corners of
// a bounding region around the bezier curve, which can be
// emitted as a triangle fan, with vertices vaguely close
// to control points so that the passage of vert data to
// frag shaders is most natural.
void get_corners(
// Control points for a bezier curve
vec3 p0,
vec3 p1,
vec3 p2,
// Unit tangent vectors at p0 and p2
vec3 v01,
vec3 v12,
// Anti-alias width
float aaw,
out vec3 corners[6]
void emit_point_with_width(
vec3 point,
vec3 tangent,
vec4 joint_product,
float width,
vec4 joint_color
){
bool linear = bool(is_linear);
vec4 jp0 = normalized_joint_product(v_joint_product[0]);
vec4 jp2 = normalized_joint_product(v_joint_product[2]);
vec3 p0_perp = get_perp(0, jp0, p0, v01, aaw);
vec3 p2_perp = get_perp(2, jp2, p2, v12, aaw);
vec3 p1_perp = 0.5 * (p0_perp + p2_perp);
if(linear){
p1_perp *= (0.5 * v_stroke_width[1] + aaw) / length(p1_perp);
}
vec3 unit_tan = normalize(tangent);
vec4 normed_join_product = normalized_joint_product(joint_product);
vec3 perp = 0.5 * width * left_step(point, unit_tan, normed_join_product);
// The order of corners should be for a triangle_strip.
vec3 c0 = p0 + p0_perp;
vec3 c1 = p0 - p0_perp;
vec3 c2 = p1 + p1_perp;
vec3 c3 = p1 - p1_perp;
vec3 c4 = p2 + p2_perp;
vec3 c5 = p2 - p2_perp;
// Move the inner middle control point to make
// room for the curve
// float orientation = dot(unit_normal, v_joint_product[1].xyz);
float orientation = v_joint_product[1].z;
if(!linear && orientation >= 0.0) c2 = 0.5 * (c0 + c4);
else if(!linear && orientation < 0.0) c3 = 0.5 * (c1 + c5);
vec3 corners[2] = vec3[2](point + perp, point - perp);
create_joint(
normed_join_product, unit_tan, length(perp),
corners[0], corners[0],
corners[1], corners[1]
);
// Account for previous and next control points
if(bool(flat_stroke)){
create_joint(jp0, v01, length(p0_perp), c1, c1, c0, c0);
create_joint(jp2, -v12, length(p2_perp), c5, c5, c4, c4);
}
color = finalize_color(joint_color, point, unit_normal);
if (width == 0) scaled_anti_alias_width = -1.0; // Signal to discard in frag
else scaled_anti_alias_width = 2.0 * anti_alias_width * pixel_size / width;
corners = vec3[6](c0, c1, c2, c3, c4, c5);
// Emit two corners
// The frag shader will receive a value from -1 to 1,
// reflecting where in the stroke that point is
scaled_signed_dist_to_curve = -1.0;
emit_gl_Position(corners[0]);
EmitVertex();
scaled_signed_dist_to_curve = +1.0;
emit_gl_Position(corners[1]);
EmitVertex();
}
void main() {
@ -157,53 +165,40 @@ void main() {
// the first anchor is set equal to that anchor
if (verts[0] == verts[1]) return;
vec3 p0 = verts[0];
vec3 p1 = verts[1];
vec3 p2 = verts[2];
vec3 v01 = normalize(p1 - p0);
vec3 v12 = normalize(p2 - p1);
vec4 jp1 = normalized_joint_product(v_joint_product[1]);
is_linear = float(jp1.w > COS_THRESHOLD);
// We want to change the coordinates to a space where the curve
// coincides with y = x^2, between some values x0 and x2. Or, in
// the case of a linear curve just put it on the x-axis
mat4 xyz_to_uv;
float uv_scale_factor;
if(!bool(is_linear)){
bool too_steep;
xyz_to_uv = get_xyz_to_uv(p0, p1, p2, 2.0, too_steep);
is_linear = float(too_steep);
uv_scale_factor = length(xyz_to_uv[0].xyz);
// Compute subdivision
int n_steps;
float subdivision[MAX_STEPS];
compute_subdivision(n_steps, subdivision);
vec3 points[MAX_STEPS];
for (int i = 0; i < MAX_STEPS; i++){
if (i >= n_steps) break;
points[i] = point_on_curve(subdivision[i]);
}
float scaled_aaw = anti_alias_width * pixel_size;
vec3 corners[6];
get_corners(p0, p1, p2, v01, v12, scaled_aaw, corners);
// Compute joint products
vec4 joint_products[MAX_STEPS];
joint_products[0] = v_joint_product[0];
joint_products[0].xyz *= -1;
joint_products[n_steps - 1] = v_joint_product[2];
for (int i = 1; i < MAX_STEPS; i++){
if (i >= n_steps - 1) break;
vec3 v1 = points[i] - points[i - 1];
vec3 v2 = points[i + 1] - points[i];
joint_products[i].xyz = cross(v1, v2);
joint_products[i].w = dot(v1, v2);
}
// Emit each corner
float max_sw = max(v_stroke_width[0], v_stroke_width[2]);
for(int i = 0; i < 6; i++){
float stroke_width = v_stroke_width[i / 2];
if(bool(is_linear)){
float sign = vec2(-1, 1)[i % 2];
// In this case, we only really care about
// the v coordinate
uv_coords = vec2(0, sign * (0.5 * stroke_width + scaled_aaw));
uv_anti_alias_width = scaled_aaw;
uv_stroke_width = stroke_width;
}else{
uv_coords = (xyz_to_uv * vec4(corners[i], 1.0)).xy;
uv_stroke_width = uv_scale_factor * stroke_width;
uv_anti_alias_width = uv_scale_factor * scaled_aaw;
}
color = finalize_color(v_color[i / 2], corners[i], unit_normal);
emit_gl_Position(corners[i]);
EmitVertex();
// Emit vertex pairs aroudn subdivided points
for (int i = 0; i < MAX_STEPS; i++){
if (i >= n_steps) break;
float t = subdivision[i];
emit_point_with_width(
points[i],
tangent_on_curve(t),
joint_products[i],
mix(v_stroke_width[0], v_stroke_width[2], t),
mix(v_color[0], v_color[2], t)
);
}
EndPrimitive();
}

View file

@ -20,8 +20,7 @@ const float STROKE_WIDTH_CONVERSION = 0.01;
void main(){
verts = point;
v_stroke_width = STROKE_WIDTH_CONVERSION * stroke_width;
v_stroke_width *= mix(frame_scale, 1, is_fixed_in_frame);
v_stroke_width = STROKE_WIDTH_CONVERSION * stroke_width * mix(frame_scale, 1, is_fixed_in_frame);
v_joint_product = joint_product;
v_color = stroke_rgba;
}

View file

@ -4,6 +4,7 @@ from colour import Color
from colour import hex2rgb
from colour import rgb2hex
import numpy as np
import random
from manimlib.constants import COLORMAP_3B1B
from manimlib.constants import WHITE
@ -102,6 +103,16 @@ def interpolate_color(
return rgb_to_color(rgb)
def interpolate_color_by_hsl(
color1: ManimColor,
color2: ManimColor,
alpha: float
) -> Color:
hsl1 = np.array(Color(color1).get_hsl())
hsl2 = np.array(Color(color2).get_hsl())
return Color(hsl=interpolate(hsl1, hsl2, alpha))
def average_color(*colors: ManimColor) -> Color:
rgbs = np.array(list(map(color_to_rgb, colors)))
return rgb_to_color(np.sqrt((rgbs**2).mean(0)))
@ -111,9 +122,16 @@ def random_color() -> Color:
return Color(rgb=tuple(np.random.random(3)))
def random_bright_color() -> Color:
color = random_color()
return average_color(color, Color(WHITE))
def random_bright_color(
hue_range: tuple[float, float] = (0.0, 1.0),
saturation_range: tuple[float, float] = (0.5, 0.8),
luminance_range: tuple[float, float] = (0.5, 1.0),
) -> Color:
return Color(hsl=(
interpolate(*hue_range, random.random()),
interpolate(*saturation_range, random.random()),
interpolate(*luminance_range, random.random()),
))
def get_colormap_list(

View file

@ -3,6 +3,7 @@ from __future__ import annotations
from colour import Color
import numpy as np
import random
from typing import TYPE_CHECKING
@ -83,6 +84,12 @@ def listify(obj: object) -> list:
return [obj]
def shuffled(iterable: Iterable) -> list:
as_list = list(iterable)
random.shuffle(as_list)
return as_list
def resize_array(nparray: np.ndarray, length: int) -> np.ndarray:
if len(nparray) == length:
return nparray

View file

@ -9,7 +9,7 @@ import numpy as np
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, TypeVar
from typing import Callable, TypeVar, Iterable
from manimlib.typing import FloatArray
Scalable = TypeVar("Scalable", float, FloatArray)
@ -30,11 +30,11 @@ def gen_choose(n: int, r: int) -> int:
def get_num_args(function: Callable) -> int:
return len(get_parameters(function))
return function.__code__.co_argcount
def get_parameters(function: Callable) -> list:
return list(inspect.signature(function).parameters.keys())
def get_parameters(function: Callable) -> Iterable[str]:
return inspect.signature(function).parameters.keys()
# Just to have a less heavyweight name for this extremely common operation
#