mirror of
https://github.com/3b1b/manim.git
synced 2025-04-13 09:47:07 +00:00
Refactor Matrix, DecimalMatrix, MobjectMatrix, etc.
This commit is contained in:
parent
41ece958fd
commit
855ef9be8d
1 changed files with 140 additions and 147 deletions
|
@ -4,85 +4,36 @@ import itertools as it
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from manimlib.constants import DEFAULT_MOBJECT_TO_MOBJECT_BUFFER
|
from manimlib.constants import DOWN, LEFT, RIGHT, ORIGIN
|
||||||
from manimlib.constants import DOWN, LEFT, RIGHT, UP
|
|
||||||
from manimlib.constants import WHITE
|
|
||||||
from manimlib.mobject.numbers import DecimalNumber
|
from manimlib.mobject.numbers import DecimalNumber
|
||||||
from manimlib.mobject.numbers import Integer
|
|
||||||
from manimlib.mobject.shape_matchers import BackgroundRectangle
|
|
||||||
from manimlib.mobject.svg.tex_mobject import Tex
|
from manimlib.mobject.svg.tex_mobject import Tex
|
||||||
from manimlib.mobject.svg.tex_mobject import TexText
|
|
||||||
from manimlib.mobject.types.vectorized_mobject import VGroup
|
from manimlib.mobject.types.vectorized_mobject import VGroup
|
||||||
from manimlib.mobject.types.vectorized_mobject import VMobject
|
from manimlib.mobject.types.vectorized_mobject import VMobject
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Sequence
|
from typing import Sequence, Union, Tuple
|
||||||
import numpy.typing as npt
|
from manimlib.typing import ManimColor, Vect3, VectNArray, Self
|
||||||
from manimlib.mobject.mobject import Mobject
|
|
||||||
from manimlib.typing import ManimColor, Vect3, Self
|
|
||||||
|
|
||||||
|
StringMatrixType = Union[Sequence[Sequence[str]], np.ndarray[int, np.dtype[np.str_]]]
|
||||||
|
FloatMatrixType = Union[Sequence[Sequence[float]], VectNArray]
|
||||||
|
VMobjectMatrixType = Sequence[Sequence[VMobject]]
|
||||||
|
GenericMatrixType = Union[FloatMatrixType, StringMatrixType, VMobjectMatrixType]
|
||||||
|
|
||||||
VECTOR_LABEL_SCALE_FACTOR = 0.8
|
|
||||||
|
|
||||||
|
|
||||||
def matrix_to_tex_string(matrix: npt.ArrayLike) -> str:
|
|
||||||
matrix = np.array(matrix).astype("str")
|
|
||||||
if matrix.ndim == 1:
|
|
||||||
matrix = matrix.reshape((matrix.size, 1))
|
|
||||||
n_rows, n_cols = matrix.shape
|
|
||||||
prefix = R"\left[ \begin{array}{%s}" % ("c" * n_cols)
|
|
||||||
suffix = R"\end{array} \right]"
|
|
||||||
rows = [
|
|
||||||
" & ".join(row)
|
|
||||||
for row in matrix
|
|
||||||
]
|
|
||||||
return prefix + R" \\ ".join(rows) + suffix
|
|
||||||
|
|
||||||
|
|
||||||
def matrix_to_mobject(matrix: npt.ArrayLike) -> Tex:
|
|
||||||
return Tex(matrix_to_tex_string(matrix))
|
|
||||||
|
|
||||||
|
|
||||||
def vector_coordinate_label(
|
|
||||||
vector_mob: VMobject,
|
|
||||||
integer_labels: bool = True,
|
|
||||||
n_dim: int = 2,
|
|
||||||
color: ManimColor = WHITE
|
|
||||||
) -> Matrix:
|
|
||||||
vect = np.array(vector_mob.get_end())
|
|
||||||
if integer_labels:
|
|
||||||
vect = np.round(vect).astype(int)
|
|
||||||
vect = vect[:n_dim]
|
|
||||||
vect = vect.reshape((n_dim, 1))
|
|
||||||
label = Matrix(vect, add_background_rectangles_to_entries=True)
|
|
||||||
label.scale(VECTOR_LABEL_SCALE_FACTOR)
|
|
||||||
|
|
||||||
shift_dir = np.array(vector_mob.get_end())
|
|
||||||
if shift_dir[0] >= 0: # Pointing right
|
|
||||||
shift_dir -= label.get_left() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER * LEFT
|
|
||||||
else: # Pointing left
|
|
||||||
shift_dir -= label.get_right() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER * RIGHT
|
|
||||||
label.shift(shift_dir)
|
|
||||||
label.set_color(color)
|
|
||||||
label.rect = BackgroundRectangle(label)
|
|
||||||
label.add_to_back(label.rect)
|
|
||||||
return label
|
|
||||||
|
|
||||||
|
|
||||||
class Matrix(VMobject):
|
class Matrix(VMobject):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
matrix: Sequence[Sequence[str | float | VMobject]],
|
matrix: GenericMatrixType,
|
||||||
v_buff: float = 0.8,
|
v_buff: float = 0.5,
|
||||||
h_buff: float = 1.0,
|
h_buff: float = 0.5,
|
||||||
bracket_h_buff: float = 0.2,
|
bracket_h_buff: float = 0.2,
|
||||||
bracket_v_buff: float = 0.25,
|
bracket_v_buff: float = 0.25,
|
||||||
add_background_rectangles_to_entries: bool = False,
|
height: float | None = None,
|
||||||
include_background_rectangle: bool = False,
|
element_config: dict = dict(),
|
||||||
element_alignment_corner: Vect3 = DOWN,
|
element_alignment_corner: Vect3 = DOWN,
|
||||||
**kwargs
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Matrix can either include numbers, tex_strings,
|
Matrix can either include numbers, tex_strings,
|
||||||
|
@ -90,54 +41,60 @@ class Matrix(VMobject):
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
mob_matrix = self.matrix_to_mob_matrix(matrix, **kwargs)
|
self.mob_matrix = self.create_mobject_matrix(
|
||||||
self.mob_matrix = mob_matrix
|
matrix, v_buff, h_buff, element_alignment_corner,
|
||||||
|
**element_config
|
||||||
|
)
|
||||||
|
# Create helpful groups for the elements
|
||||||
|
n_cols = len(self.mob_matrix[0])
|
||||||
|
self.elements = VGroup(*it.chain(*self.mob_matrix))
|
||||||
|
self.columns = VGroup(*(
|
||||||
|
VGroup(*(row[i] for row in self.mob_matrix))
|
||||||
|
for i in range(n_cols)
|
||||||
|
))
|
||||||
|
self.rows = VGroup(*(VGroup(*row) for row in self.mob_matrix))
|
||||||
|
|
||||||
self.organize_mob_matrix(mob_matrix, v_buff, h_buff, element_alignment_corner)
|
# Add elements and brackets
|
||||||
self.elements = VGroup(*it.chain(*mob_matrix))
|
|
||||||
self.add(self.elements)
|
self.add(self.elements)
|
||||||
|
if height is not None:
|
||||||
|
self.set_height(height - 2 * bracket_v_buff)
|
||||||
self.add_brackets(bracket_v_buff, bracket_h_buff)
|
self.add_brackets(bracket_v_buff, bracket_h_buff)
|
||||||
self.center()
|
self.center()
|
||||||
if add_background_rectangles_to_entries:
|
|
||||||
for mob in self.elements:
|
|
||||||
mob.add_background_rectangle()
|
|
||||||
if include_background_rectangle:
|
|
||||||
self.add_background_rectangle()
|
|
||||||
|
|
||||||
|
def create_mobject_matrix(
|
||||||
def element_to_mobject(self, element: str | float | VMobject, **config) -> VMobject:
|
|
||||||
if isinstance(element, VMobject):
|
|
||||||
return element
|
|
||||||
return Tex(str(element), **config)
|
|
||||||
|
|
||||||
def matrix_to_mob_matrix(
|
|
||||||
self,
|
self,
|
||||||
matrix: Sequence[Sequence[str | float | VMobject]],
|
matrix: GenericMatrixType,
|
||||||
**config
|
|
||||||
) -> list[list[VMobject]]:
|
|
||||||
return [
|
|
||||||
[
|
|
||||||
self.element_to_mobject(item, **config)
|
|
||||||
for item in row
|
|
||||||
]
|
|
||||||
for row in matrix
|
|
||||||
]
|
|
||||||
|
|
||||||
def organize_mob_matrix(
|
|
||||||
self,
|
|
||||||
matrix: list[list[VMobject]],
|
|
||||||
v_buff: float,
|
v_buff: float,
|
||||||
h_buff: float,
|
h_buff: float,
|
||||||
aligned_corner: Vect3,
|
aligned_corner: Vect3,
|
||||||
) -> Self:
|
**element_config
|
||||||
for i, row in enumerate(matrix):
|
) -> VMobjectMatrixType:
|
||||||
|
"""
|
||||||
|
Creates and organizes the matrix of mobjects
|
||||||
|
"""
|
||||||
|
mob_matrix = [
|
||||||
|
[
|
||||||
|
self.element_to_mobject(element, **element_config)
|
||||||
|
for element in row
|
||||||
|
]
|
||||||
|
for row in matrix
|
||||||
|
]
|
||||||
|
max_width = max(elem.get_width() for row in mob_matrix for elem in row)
|
||||||
|
max_height = max(elem.get_height() for row in mob_matrix for elem in row)
|
||||||
|
x_step = (max_width + h_buff) * RIGHT
|
||||||
|
y_step = (max_height + v_buff) * DOWN
|
||||||
|
for i, row in enumerate(mob_matrix):
|
||||||
for j, elem in enumerate(row):
|
for j, elem in enumerate(row):
|
||||||
mob = matrix[i][j]
|
elem.move_to(i * y_step + j * x_step, aligned_corner)
|
||||||
mob.move_to(
|
return mob_matrix
|
||||||
i * v_buff * DOWN + j * h_buff * RIGHT,
|
|
||||||
aligned_corner
|
def element_to_mobject(self, element, **config) -> VMobject:
|
||||||
)
|
if isinstance(element, VMobject):
|
||||||
return self
|
return element
|
||||||
|
elif isinstance(element, float | complex):
|
||||||
|
return DecimalNumber(element, **config)
|
||||||
|
else:
|
||||||
|
return Tex(str(element), **config)
|
||||||
|
|
||||||
def add_brackets(self, v_buff: float, h_buff: float) -> Self:
|
def add_brackets(self, v_buff: float, h_buff: float) -> Self:
|
||||||
height = len(self.mob_matrix)
|
height = len(self.mob_matrix)
|
||||||
|
@ -152,21 +109,25 @@ class Matrix(VMobject):
|
||||||
l_bracket.next_to(self, LEFT, h_buff)
|
l_bracket.next_to(self, LEFT, h_buff)
|
||||||
r_bracket.next_to(self, RIGHT, h_buff)
|
r_bracket.next_to(self, RIGHT, h_buff)
|
||||||
brackets.set_submobjects([l_bracket, r_bracket])
|
brackets.set_submobjects([l_bracket, r_bracket])
|
||||||
self.brackets = brackets
|
self.brackets = VGroup(l_bracket, r_bracket)
|
||||||
self.add(*brackets)
|
self.add(*brackets)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def get_column(self, index: int):
|
||||||
|
if not 0 <= index < len(self.columns):
|
||||||
|
raise IndexError(f"Index {index} out of bound for matrix with {len(self.columns)} columns")
|
||||||
|
return self.columns[index]
|
||||||
|
|
||||||
|
def get_row(self, index: int):
|
||||||
|
if not 0 <= index < len(self.rows):
|
||||||
|
raise IndexError(f"Index {index} out of bound for matrix with {len(self.rows)} rows")
|
||||||
|
return self.rows[index]
|
||||||
|
|
||||||
def get_columns(self) -> VGroup:
|
def get_columns(self) -> VGroup:
|
||||||
return VGroup(*[
|
return self.columns
|
||||||
VGroup(*[row[i] for row in self.mob_matrix])
|
|
||||||
for i in range(len(self.mob_matrix[0]))
|
|
||||||
])
|
|
||||||
|
|
||||||
def get_rows(self) -> VGroup:
|
def get_rows(self) -> VGroup:
|
||||||
return VGroup(*[
|
return self.rows
|
||||||
VGroup(*row)
|
|
||||||
for row in self.mob_matrix
|
|
||||||
])
|
|
||||||
|
|
||||||
def set_column_colors(self, *colors: ManimColor) -> Self:
|
def set_column_colors(self, *colors: ManimColor) -> Self:
|
||||||
columns = self.get_columns()
|
columns = self.get_columns()
|
||||||
|
@ -179,7 +140,7 @@ class Matrix(VMobject):
|
||||||
mob.add_background_rectangle()
|
mob.add_background_rectangle()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get_mob_matrix(self) -> list[list[Mobject]]:
|
def get_mob_matrix(self) -> VMobjectMatrixType:
|
||||||
return self.mob_matrix
|
return self.mob_matrix
|
||||||
|
|
||||||
def get_entries(self) -> VGroup:
|
def get_entries(self) -> VGroup:
|
||||||
|
@ -190,50 +151,82 @@ class Matrix(VMobject):
|
||||||
|
|
||||||
|
|
||||||
class DecimalMatrix(Matrix):
|
class DecimalMatrix(Matrix):
|
||||||
def element_to_mobject(self, element: float, num_decimal_places: int = 1, **config) -> DecimalNumber:
|
|
||||||
return DecimalNumber(element, num_decimal_places=num_decimal_places, **config)
|
|
||||||
|
|
||||||
|
|
||||||
class IntegerMatrix(Matrix):
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
matrix: npt.ArrayLike,
|
matrix: FloatMatrixType,
|
||||||
element_alignment_corner: Vect3 = UP,
|
num_decimal_places: int = 2,
|
||||||
**kwargs
|
decimal_config: dict = dict(),
|
||||||
|
**config
|
||||||
):
|
):
|
||||||
super().__init__(matrix, element_alignment_corner=element_alignment_corner, **kwargs)
|
super().__init__(
|
||||||
|
matrix,
|
||||||
|
element_config=dict(
|
||||||
|
num_decimal_places=num_decimal_places,
|
||||||
|
**decimal_config
|
||||||
|
),
|
||||||
|
**config
|
||||||
|
)
|
||||||
|
|
||||||
|
def element_to_mobject(self, element, **decimal_config) -> VMobject:
|
||||||
|
return DecimalNumber(element, **decimal_config)
|
||||||
|
|
||||||
|
|
||||||
|
class IntegerMatrix(DecimalMatrix):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
matrix: FloatMatrixType,
|
||||||
|
num_decimal_places: int = 0,
|
||||||
|
decimal_config: dict = dict(),
|
||||||
|
**config
|
||||||
|
):
|
||||||
|
super().__init__(matrix, num_decimal_places, decimal_config, **config)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TexMatrix(Matrix):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
matrix: StringMatrixType,
|
||||||
|
tex_config: dict = dict(),
|
||||||
|
**config,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
matrix,
|
||||||
|
element_config=tex_config,
|
||||||
|
**config
|
||||||
|
)
|
||||||
|
|
||||||
def element_to_mobject(self, element: int, **config) -> Integer:
|
|
||||||
return Integer(element, **config)
|
|
||||||
|
|
||||||
|
|
||||||
class MobjectMatrix(Matrix):
|
class MobjectMatrix(Matrix):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
group: VGroup,
|
||||||
|
n_rows: int | None = None,
|
||||||
|
n_cols: int | None = None,
|
||||||
|
height: float = 4.0,
|
||||||
|
element_alignment_corner=ORIGIN,
|
||||||
|
**config,
|
||||||
|
):
|
||||||
|
# Have fallback defaults of n_rows and n_cols
|
||||||
|
n_mobs = len(group)
|
||||||
|
if n_rows is None:
|
||||||
|
n_rows = int(np.sqrt(n_mobs)) if n_cols is None else n_mobs // n_cols
|
||||||
|
if n_cols is None:
|
||||||
|
n_cols = n_mobs // n_rows
|
||||||
|
|
||||||
|
if len(group) < n_rows * n_cols:
|
||||||
|
raise Exception("Input to MobjectMatrix must have at least n_rows * n_cols entries")
|
||||||
|
|
||||||
|
mob_matrix = [
|
||||||
|
[group[n * n_cols + k] for k in range(n_cols)]
|
||||||
|
for n in range(n_rows)
|
||||||
|
]
|
||||||
|
config.update(
|
||||||
|
height=height,
|
||||||
|
element_alignment_corner=element_alignment_corner,
|
||||||
|
)
|
||||||
|
super().__init__(mob_matrix, **config)
|
||||||
|
|
||||||
def element_to_mobject(self, element: VMobject, **config) -> VMobject:
|
def element_to_mobject(self, element: VMobject, **config) -> VMobject:
|
||||||
return element
|
return element
|
||||||
|
|
||||||
|
|
||||||
def get_det_text(
|
|
||||||
matrix: Matrix,
|
|
||||||
determinant: int | str | None = None,
|
|
||||||
background_rect: bool = False,
|
|
||||||
initial_scale_factor: int = 2
|
|
||||||
) -> VGroup:
|
|
||||||
parens = Tex("()")
|
|
||||||
parens.scale(initial_scale_factor)
|
|
||||||
parens.stretch_to_fit_height(matrix.get_height())
|
|
||||||
l_paren, r_paren = parens.split()
|
|
||||||
l_paren.next_to(matrix, LEFT, buff=0.1)
|
|
||||||
r_paren.next_to(matrix, RIGHT, buff=0.1)
|
|
||||||
det = TexText("det")
|
|
||||||
det.scale(initial_scale_factor)
|
|
||||||
det.next_to(l_paren, LEFT, buff=0.1)
|
|
||||||
if background_rect:
|
|
||||||
det.add_background_rectangle()
|
|
||||||
det_text = VGroup(det, l_paren, r_paren)
|
|
||||||
if determinant is not None:
|
|
||||||
eq = Tex("=")
|
|
||||||
eq.next_to(r_paren, RIGHT, buff=0.1)
|
|
||||||
result = Tex(str(determinant))
|
|
||||||
result.next_to(eq, RIGHT, buff=0.2)
|
|
||||||
det_text.add(eq, result)
|
|
||||||
return det_text
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue