mirror of
https://github.com/3b1b/manim.git
synced 2025-04-13 09:47:07 +00:00
Refactor Matrix, DecimalMatrix, MobjectMatrix, etc.
This commit is contained in:
parent
41ece958fd
commit
855ef9be8d
1 changed files with 140 additions and 147 deletions
|
@ -4,85 +4,36 @@ import itertools as it
|
|||
|
||||
import numpy as np
|
||||
|
||||
from manimlib.constants import DEFAULT_MOBJECT_TO_MOBJECT_BUFFER
|
||||
from manimlib.constants import DOWN, LEFT, RIGHT, UP
|
||||
from manimlib.constants import WHITE
|
||||
from manimlib.constants import DOWN, LEFT, RIGHT, ORIGIN
|
||||
from manimlib.mobject.numbers import DecimalNumber
|
||||
from manimlib.mobject.numbers import Integer
|
||||
from manimlib.mobject.shape_matchers import BackgroundRectangle
|
||||
from manimlib.mobject.svg.tex_mobject import Tex
|
||||
from manimlib.mobject.svg.tex_mobject import TexText
|
||||
from manimlib.mobject.types.vectorized_mobject import VGroup
|
||||
from manimlib.mobject.types.vectorized_mobject import VMobject
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Sequence
|
||||
import numpy.typing as npt
|
||||
from manimlib.mobject.mobject import Mobject
|
||||
from manimlib.typing import ManimColor, Vect3, Self
|
||||
from typing import Sequence, Union, Tuple
|
||||
from manimlib.typing import ManimColor, Vect3, VectNArray, Self
|
||||
|
||||
StringMatrixType = Union[Sequence[Sequence[str]], np.ndarray[int, np.dtype[np.str_]]]
|
||||
FloatMatrixType = Union[Sequence[Sequence[float]], VectNArray]
|
||||
VMobjectMatrixType = Sequence[Sequence[VMobject]]
|
||||
GenericMatrixType = Union[FloatMatrixType, StringMatrixType, VMobjectMatrixType]
|
||||
|
||||
VECTOR_LABEL_SCALE_FACTOR = 0.8
|
||||
|
||||
|
||||
def matrix_to_tex_string(matrix: npt.ArrayLike) -> str:
|
||||
matrix = np.array(matrix).astype("str")
|
||||
if matrix.ndim == 1:
|
||||
matrix = matrix.reshape((matrix.size, 1))
|
||||
n_rows, n_cols = matrix.shape
|
||||
prefix = R"\left[ \begin{array}{%s}" % ("c" * n_cols)
|
||||
suffix = R"\end{array} \right]"
|
||||
rows = [
|
||||
" & ".join(row)
|
||||
for row in matrix
|
||||
]
|
||||
return prefix + R" \\ ".join(rows) + suffix
|
||||
|
||||
|
||||
def matrix_to_mobject(matrix: npt.ArrayLike) -> Tex:
|
||||
return Tex(matrix_to_tex_string(matrix))
|
||||
|
||||
|
||||
def vector_coordinate_label(
|
||||
vector_mob: VMobject,
|
||||
integer_labels: bool = True,
|
||||
n_dim: int = 2,
|
||||
color: ManimColor = WHITE
|
||||
) -> Matrix:
|
||||
vect = np.array(vector_mob.get_end())
|
||||
if integer_labels:
|
||||
vect = np.round(vect).astype(int)
|
||||
vect = vect[:n_dim]
|
||||
vect = vect.reshape((n_dim, 1))
|
||||
label = Matrix(vect, add_background_rectangles_to_entries=True)
|
||||
label.scale(VECTOR_LABEL_SCALE_FACTOR)
|
||||
|
||||
shift_dir = np.array(vector_mob.get_end())
|
||||
if shift_dir[0] >= 0: # Pointing right
|
||||
shift_dir -= label.get_left() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER * LEFT
|
||||
else: # Pointing left
|
||||
shift_dir -= label.get_right() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER * RIGHT
|
||||
label.shift(shift_dir)
|
||||
label.set_color(color)
|
||||
label.rect = BackgroundRectangle(label)
|
||||
label.add_to_back(label.rect)
|
||||
return label
|
||||
|
||||
|
||||
class Matrix(VMobject):
|
||||
def __init__(
|
||||
self,
|
||||
matrix: Sequence[Sequence[str | float | VMobject]],
|
||||
v_buff: float = 0.8,
|
||||
h_buff: float = 1.0,
|
||||
matrix: GenericMatrixType,
|
||||
v_buff: float = 0.5,
|
||||
h_buff: float = 0.5,
|
||||
bracket_h_buff: float = 0.2,
|
||||
bracket_v_buff: float = 0.25,
|
||||
add_background_rectangles_to_entries: bool = False,
|
||||
include_background_rectangle: bool = False,
|
||||
height: float | None = None,
|
||||
element_config: dict = dict(),
|
||||
element_alignment_corner: Vect3 = DOWN,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Matrix can either include numbers, tex_strings,
|
||||
|
@ -90,54 +41,60 @@ class Matrix(VMobject):
|
|||
"""
|
||||
super().__init__()
|
||||
|
||||
mob_matrix = self.matrix_to_mob_matrix(matrix, **kwargs)
|
||||
self.mob_matrix = mob_matrix
|
||||
self.mob_matrix = self.create_mobject_matrix(
|
||||
matrix, v_buff, h_buff, element_alignment_corner,
|
||||
**element_config
|
||||
)
|
||||
# Create helpful groups for the elements
|
||||
n_cols = len(self.mob_matrix[0])
|
||||
self.elements = VGroup(*it.chain(*self.mob_matrix))
|
||||
self.columns = VGroup(*(
|
||||
VGroup(*(row[i] for row in self.mob_matrix))
|
||||
for i in range(n_cols)
|
||||
))
|
||||
self.rows = VGroup(*(VGroup(*row) for row in self.mob_matrix))
|
||||
|
||||
self.organize_mob_matrix(mob_matrix, v_buff, h_buff, element_alignment_corner)
|
||||
self.elements = VGroup(*it.chain(*mob_matrix))
|
||||
# Add elements and brackets
|
||||
self.add(self.elements)
|
||||
if height is not None:
|
||||
self.set_height(height - 2 * bracket_v_buff)
|
||||
self.add_brackets(bracket_v_buff, bracket_h_buff)
|
||||
self.center()
|
||||
if add_background_rectangles_to_entries:
|
||||
for mob in self.elements:
|
||||
mob.add_background_rectangle()
|
||||
if include_background_rectangle:
|
||||
self.add_background_rectangle()
|
||||
|
||||
|
||||
def element_to_mobject(self, element: str | float | VMobject, **config) -> VMobject:
|
||||
if isinstance(element, VMobject):
|
||||
return element
|
||||
return Tex(str(element), **config)
|
||||
|
||||
def matrix_to_mob_matrix(
|
||||
def create_mobject_matrix(
|
||||
self,
|
||||
matrix: Sequence[Sequence[str | float | VMobject]],
|
||||
**config
|
||||
) -> list[list[VMobject]]:
|
||||
return [
|
||||
[
|
||||
self.element_to_mobject(item, **config)
|
||||
for item in row
|
||||
]
|
||||
for row in matrix
|
||||
]
|
||||
|
||||
def organize_mob_matrix(
|
||||
self,
|
||||
matrix: list[list[VMobject]],
|
||||
matrix: GenericMatrixType,
|
||||
v_buff: float,
|
||||
h_buff: float,
|
||||
aligned_corner: Vect3,
|
||||
) -> Self:
|
||||
for i, row in enumerate(matrix):
|
||||
**element_config
|
||||
) -> VMobjectMatrixType:
|
||||
"""
|
||||
Creates and organizes the matrix of mobjects
|
||||
"""
|
||||
mob_matrix = [
|
||||
[
|
||||
self.element_to_mobject(element, **element_config)
|
||||
for element in row
|
||||
]
|
||||
for row in matrix
|
||||
]
|
||||
max_width = max(elem.get_width() for row in mob_matrix for elem in row)
|
||||
max_height = max(elem.get_height() for row in mob_matrix for elem in row)
|
||||
x_step = (max_width + h_buff) * RIGHT
|
||||
y_step = (max_height + v_buff) * DOWN
|
||||
for i, row in enumerate(mob_matrix):
|
||||
for j, elem in enumerate(row):
|
||||
mob = matrix[i][j]
|
||||
mob.move_to(
|
||||
i * v_buff * DOWN + j * h_buff * RIGHT,
|
||||
aligned_corner
|
||||
)
|
||||
return self
|
||||
elem.move_to(i * y_step + j * x_step, aligned_corner)
|
||||
return mob_matrix
|
||||
|
||||
def element_to_mobject(self, element, **config) -> VMobject:
|
||||
if isinstance(element, VMobject):
|
||||
return element
|
||||
elif isinstance(element, float | complex):
|
||||
return DecimalNumber(element, **config)
|
||||
else:
|
||||
return Tex(str(element), **config)
|
||||
|
||||
def add_brackets(self, v_buff: float, h_buff: float) -> Self:
|
||||
height = len(self.mob_matrix)
|
||||
|
@ -152,21 +109,25 @@ class Matrix(VMobject):
|
|||
l_bracket.next_to(self, LEFT, h_buff)
|
||||
r_bracket.next_to(self, RIGHT, h_buff)
|
||||
brackets.set_submobjects([l_bracket, r_bracket])
|
||||
self.brackets = brackets
|
||||
self.brackets = VGroup(l_bracket, r_bracket)
|
||||
self.add(*brackets)
|
||||
return self
|
||||
|
||||
def get_column(self, index: int):
|
||||
if not 0 <= index < len(self.columns):
|
||||
raise IndexError(f"Index {index} out of bound for matrix with {len(self.columns)} columns")
|
||||
return self.columns[index]
|
||||
|
||||
def get_row(self, index: int):
|
||||
if not 0 <= index < len(self.rows):
|
||||
raise IndexError(f"Index {index} out of bound for matrix with {len(self.rows)} rows")
|
||||
return self.rows[index]
|
||||
|
||||
def get_columns(self) -> VGroup:
|
||||
return VGroup(*[
|
||||
VGroup(*[row[i] for row in self.mob_matrix])
|
||||
for i in range(len(self.mob_matrix[0]))
|
||||
])
|
||||
return self.columns
|
||||
|
||||
def get_rows(self) -> VGroup:
|
||||
return VGroup(*[
|
||||
VGroup(*row)
|
||||
for row in self.mob_matrix
|
||||
])
|
||||
return self.rows
|
||||
|
||||
def set_column_colors(self, *colors: ManimColor) -> Self:
|
||||
columns = self.get_columns()
|
||||
|
@ -179,7 +140,7 @@ class Matrix(VMobject):
|
|||
mob.add_background_rectangle()
|
||||
return self
|
||||
|
||||
def get_mob_matrix(self) -> list[list[Mobject]]:
|
||||
def get_mob_matrix(self) -> VMobjectMatrixType:
|
||||
return self.mob_matrix
|
||||
|
||||
def get_entries(self) -> VGroup:
|
||||
|
@ -190,50 +151,82 @@ class Matrix(VMobject):
|
|||
|
||||
|
||||
class DecimalMatrix(Matrix):
|
||||
def element_to_mobject(self, element: float, num_decimal_places: int = 1, **config) -> DecimalNumber:
|
||||
return DecimalNumber(element, num_decimal_places=num_decimal_places, **config)
|
||||
|
||||
|
||||
class IntegerMatrix(Matrix):
|
||||
def __init__(
|
||||
self,
|
||||
matrix: npt.ArrayLike,
|
||||
element_alignment_corner: Vect3 = UP,
|
||||
**kwargs
|
||||
matrix: FloatMatrixType,
|
||||
num_decimal_places: int = 2,
|
||||
decimal_config: dict = dict(),
|
||||
**config
|
||||
):
|
||||
super().__init__(matrix, element_alignment_corner=element_alignment_corner, **kwargs)
|
||||
super().__init__(
|
||||
matrix,
|
||||
element_config=dict(
|
||||
num_decimal_places=num_decimal_places,
|
||||
**decimal_config
|
||||
),
|
||||
**config
|
||||
)
|
||||
|
||||
def element_to_mobject(self, element, **decimal_config) -> VMobject:
|
||||
return DecimalNumber(element, **decimal_config)
|
||||
|
||||
|
||||
class IntegerMatrix(DecimalMatrix):
|
||||
def __init__(
|
||||
self,
|
||||
matrix: FloatMatrixType,
|
||||
num_decimal_places: int = 0,
|
||||
decimal_config: dict = dict(),
|
||||
**config
|
||||
):
|
||||
super().__init__(matrix, num_decimal_places, decimal_config, **config)
|
||||
|
||||
|
||||
|
||||
class TexMatrix(Matrix):
|
||||
def __init__(
|
||||
self,
|
||||
matrix: StringMatrixType,
|
||||
tex_config: dict = dict(),
|
||||
**config,
|
||||
):
|
||||
super().__init__(
|
||||
matrix,
|
||||
element_config=tex_config,
|
||||
**config
|
||||
)
|
||||
|
||||
def element_to_mobject(self, element: int, **config) -> Integer:
|
||||
return Integer(element, **config)
|
||||
|
||||
|
||||
class MobjectMatrix(Matrix):
|
||||
def __init__(
|
||||
self,
|
||||
group: VGroup,
|
||||
n_rows: int | None = None,
|
||||
n_cols: int | None = None,
|
||||
height: float = 4.0,
|
||||
element_alignment_corner=ORIGIN,
|
||||
**config,
|
||||
):
|
||||
# Have fallback defaults of n_rows and n_cols
|
||||
n_mobs = len(group)
|
||||
if n_rows is None:
|
||||
n_rows = int(np.sqrt(n_mobs)) if n_cols is None else n_mobs // n_cols
|
||||
if n_cols is None:
|
||||
n_cols = n_mobs // n_rows
|
||||
|
||||
if len(group) < n_rows * n_cols:
|
||||
raise Exception("Input to MobjectMatrix must have at least n_rows * n_cols entries")
|
||||
|
||||
mob_matrix = [
|
||||
[group[n * n_cols + k] for k in range(n_cols)]
|
||||
for n in range(n_rows)
|
||||
]
|
||||
config.update(
|
||||
height=height,
|
||||
element_alignment_corner=element_alignment_corner,
|
||||
)
|
||||
super().__init__(mob_matrix, **config)
|
||||
|
||||
def element_to_mobject(self, element: VMobject, **config) -> VMobject:
|
||||
return element
|
||||
|
||||
|
||||
def get_det_text(
|
||||
matrix: Matrix,
|
||||
determinant: int | str | None = None,
|
||||
background_rect: bool = False,
|
||||
initial_scale_factor: int = 2
|
||||
) -> VGroup:
|
||||
parens = Tex("()")
|
||||
parens.scale(initial_scale_factor)
|
||||
parens.stretch_to_fit_height(matrix.get_height())
|
||||
l_paren, r_paren = parens.split()
|
||||
l_paren.next_to(matrix, LEFT, buff=0.1)
|
||||
r_paren.next_to(matrix, RIGHT, buff=0.1)
|
||||
det = TexText("det")
|
||||
det.scale(initial_scale_factor)
|
||||
det.next_to(l_paren, LEFT, buff=0.1)
|
||||
if background_rect:
|
||||
det.add_background_rectangle()
|
||||
det_text = VGroup(det, l_paren, r_paren)
|
||||
if determinant is not None:
|
||||
eq = Tex("=")
|
||||
eq.next_to(r_paren, RIGHT, buff=0.1)
|
||||
result = Tex(str(determinant))
|
||||
result.next_to(eq, RIGHT, buff=0.2)
|
||||
det_text.add(eq, result)
|
||||
return det_text
|
||||
|
|
Loading…
Add table
Reference in a new issue