Refactor Matrix, DecimalMatrix, MobjectMatrix, etc.

This commit is contained in:
Grant Sanderson 2024-01-18 11:12:42 -06:00
parent 41ece958fd
commit 855ef9be8d

View file

@ -4,85 +4,36 @@ import itertools as it
import numpy as np
from manimlib.constants import DEFAULT_MOBJECT_TO_MOBJECT_BUFFER
from manimlib.constants import DOWN, LEFT, RIGHT, UP
from manimlib.constants import WHITE
from manimlib.constants import DOWN, LEFT, RIGHT, ORIGIN
from manimlib.mobject.numbers import DecimalNumber
from manimlib.mobject.numbers import Integer
from manimlib.mobject.shape_matchers import BackgroundRectangle
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.svg.tex_mobject import TexText
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Sequence
import numpy.typing as npt
from manimlib.mobject.mobject import Mobject
from manimlib.typing import ManimColor, Vect3, Self
from typing import Sequence, Union, Tuple
from manimlib.typing import ManimColor, Vect3, VectNArray, Self
StringMatrixType = Union[Sequence[Sequence[str]], np.ndarray[int, np.dtype[np.str_]]]
FloatMatrixType = Union[Sequence[Sequence[float]], VectNArray]
VMobjectMatrixType = Sequence[Sequence[VMobject]]
GenericMatrixType = Union[FloatMatrixType, StringMatrixType, VMobjectMatrixType]
VECTOR_LABEL_SCALE_FACTOR = 0.8
def matrix_to_tex_string(matrix: npt.ArrayLike) -> str:
matrix = np.array(matrix).astype("str")
if matrix.ndim == 1:
matrix = matrix.reshape((matrix.size, 1))
n_rows, n_cols = matrix.shape
prefix = R"\left[ \begin{array}{%s}" % ("c" * n_cols)
suffix = R"\end{array} \right]"
rows = [
" & ".join(row)
for row in matrix
]
return prefix + R" \\ ".join(rows) + suffix
def matrix_to_mobject(matrix: npt.ArrayLike) -> Tex:
return Tex(matrix_to_tex_string(matrix))
def vector_coordinate_label(
vector_mob: VMobject,
integer_labels: bool = True,
n_dim: int = 2,
color: ManimColor = WHITE
) -> Matrix:
vect = np.array(vector_mob.get_end())
if integer_labels:
vect = np.round(vect).astype(int)
vect = vect[:n_dim]
vect = vect.reshape((n_dim, 1))
label = Matrix(vect, add_background_rectangles_to_entries=True)
label.scale(VECTOR_LABEL_SCALE_FACTOR)
shift_dir = np.array(vector_mob.get_end())
if shift_dir[0] >= 0: # Pointing right
shift_dir -= label.get_left() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER * LEFT
else: # Pointing left
shift_dir -= label.get_right() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER * RIGHT
label.shift(shift_dir)
label.set_color(color)
label.rect = BackgroundRectangle(label)
label.add_to_back(label.rect)
return label
class Matrix(VMobject):
def __init__(
self,
matrix: Sequence[Sequence[str | float | VMobject]],
v_buff: float = 0.8,
h_buff: float = 1.0,
matrix: GenericMatrixType,
v_buff: float = 0.5,
h_buff: float = 0.5,
bracket_h_buff: float = 0.2,
bracket_v_buff: float = 0.25,
add_background_rectangles_to_entries: bool = False,
include_background_rectangle: bool = False,
height: float | None = None,
element_config: dict = dict(),
element_alignment_corner: Vect3 = DOWN,
**kwargs
):
"""
Matrix can either include numbers, tex_strings,
@ -90,54 +41,60 @@ class Matrix(VMobject):
"""
super().__init__()
mob_matrix = self.matrix_to_mob_matrix(matrix, **kwargs)
self.mob_matrix = mob_matrix
self.mob_matrix = self.create_mobject_matrix(
matrix, v_buff, h_buff, element_alignment_corner,
**element_config
)
# Create helpful groups for the elements
n_cols = len(self.mob_matrix[0])
self.elements = VGroup(*it.chain(*self.mob_matrix))
self.columns = VGroup(*(
VGroup(*(row[i] for row in self.mob_matrix))
for i in range(n_cols)
))
self.rows = VGroup(*(VGroup(*row) for row in self.mob_matrix))
self.organize_mob_matrix(mob_matrix, v_buff, h_buff, element_alignment_corner)
self.elements = VGroup(*it.chain(*mob_matrix))
# Add elements and brackets
self.add(self.elements)
if height is not None:
self.set_height(height - 2 * bracket_v_buff)
self.add_brackets(bracket_v_buff, bracket_h_buff)
self.center()
if add_background_rectangles_to_entries:
for mob in self.elements:
mob.add_background_rectangle()
if include_background_rectangle:
self.add_background_rectangle()
def element_to_mobject(self, element: str | float | VMobject, **config) -> VMobject:
if isinstance(element, VMobject):
return element
return Tex(str(element), **config)
def matrix_to_mob_matrix(
def create_mobject_matrix(
self,
matrix: Sequence[Sequence[str | float | VMobject]],
**config
) -> list[list[VMobject]]:
return [
[
self.element_to_mobject(item, **config)
for item in row
]
for row in matrix
]
def organize_mob_matrix(
self,
matrix: list[list[VMobject]],
matrix: GenericMatrixType,
v_buff: float,
h_buff: float,
aligned_corner: Vect3,
) -> Self:
for i, row in enumerate(matrix):
**element_config
) -> VMobjectMatrixType:
"""
Creates and organizes the matrix of mobjects
"""
mob_matrix = [
[
self.element_to_mobject(element, **element_config)
for element in row
]
for row in matrix
]
max_width = max(elem.get_width() for row in mob_matrix for elem in row)
max_height = max(elem.get_height() for row in mob_matrix for elem in row)
x_step = (max_width + h_buff) * RIGHT
y_step = (max_height + v_buff) * DOWN
for i, row in enumerate(mob_matrix):
for j, elem in enumerate(row):
mob = matrix[i][j]
mob.move_to(
i * v_buff * DOWN + j * h_buff * RIGHT,
aligned_corner
)
return self
elem.move_to(i * y_step + j * x_step, aligned_corner)
return mob_matrix
def element_to_mobject(self, element, **config) -> VMobject:
if isinstance(element, VMobject):
return element
elif isinstance(element, float | complex):
return DecimalNumber(element, **config)
else:
return Tex(str(element), **config)
def add_brackets(self, v_buff: float, h_buff: float) -> Self:
height = len(self.mob_matrix)
@ -152,21 +109,25 @@ class Matrix(VMobject):
l_bracket.next_to(self, LEFT, h_buff)
r_bracket.next_to(self, RIGHT, h_buff)
brackets.set_submobjects([l_bracket, r_bracket])
self.brackets = brackets
self.brackets = VGroup(l_bracket, r_bracket)
self.add(*brackets)
return self
def get_column(self, index: int):
if not 0 <= index < len(self.columns):
raise IndexError(f"Index {index} out of bound for matrix with {len(self.columns)} columns")
return self.columns[index]
def get_row(self, index: int):
if not 0 <= index < len(self.rows):
raise IndexError(f"Index {index} out of bound for matrix with {len(self.rows)} rows")
return self.rows[index]
def get_columns(self) -> VGroup:
return VGroup(*[
VGroup(*[row[i] for row in self.mob_matrix])
for i in range(len(self.mob_matrix[0]))
])
return self.columns
def get_rows(self) -> VGroup:
return VGroup(*[
VGroup(*row)
for row in self.mob_matrix
])
return self.rows
def set_column_colors(self, *colors: ManimColor) -> Self:
columns = self.get_columns()
@ -179,7 +140,7 @@ class Matrix(VMobject):
mob.add_background_rectangle()
return self
def get_mob_matrix(self) -> list[list[Mobject]]:
def get_mob_matrix(self) -> VMobjectMatrixType:
return self.mob_matrix
def get_entries(self) -> VGroup:
@ -190,50 +151,82 @@ class Matrix(VMobject):
class DecimalMatrix(Matrix):
def element_to_mobject(self, element: float, num_decimal_places: int = 1, **config) -> DecimalNumber:
return DecimalNumber(element, num_decimal_places=num_decimal_places, **config)
class IntegerMatrix(Matrix):
def __init__(
self,
matrix: npt.ArrayLike,
element_alignment_corner: Vect3 = UP,
**kwargs
matrix: FloatMatrixType,
num_decimal_places: int = 2,
decimal_config: dict = dict(),
**config
):
super().__init__(matrix, element_alignment_corner=element_alignment_corner, **kwargs)
super().__init__(
matrix,
element_config=dict(
num_decimal_places=num_decimal_places,
**decimal_config
),
**config
)
def element_to_mobject(self, element, **decimal_config) -> VMobject:
return DecimalNumber(element, **decimal_config)
class IntegerMatrix(DecimalMatrix):
def __init__(
self,
matrix: FloatMatrixType,
num_decimal_places: int = 0,
decimal_config: dict = dict(),
**config
):
super().__init__(matrix, num_decimal_places, decimal_config, **config)
class TexMatrix(Matrix):
def __init__(
self,
matrix: StringMatrixType,
tex_config: dict = dict(),
**config,
):
super().__init__(
matrix,
element_config=tex_config,
**config
)
def element_to_mobject(self, element: int, **config) -> Integer:
return Integer(element, **config)
class MobjectMatrix(Matrix):
def __init__(
self,
group: VGroup,
n_rows: int | None = None,
n_cols: int | None = None,
height: float = 4.0,
element_alignment_corner=ORIGIN,
**config,
):
# Have fallback defaults of n_rows and n_cols
n_mobs = len(group)
if n_rows is None:
n_rows = int(np.sqrt(n_mobs)) if n_cols is None else n_mobs // n_cols
if n_cols is None:
n_cols = n_mobs // n_rows
if len(group) < n_rows * n_cols:
raise Exception("Input to MobjectMatrix must have at least n_rows * n_cols entries")
mob_matrix = [
[group[n * n_cols + k] for k in range(n_cols)]
for n in range(n_rows)
]
config.update(
height=height,
element_alignment_corner=element_alignment_corner,
)
super().__init__(mob_matrix, **config)
def element_to_mobject(self, element: VMobject, **config) -> VMobject:
return element
def get_det_text(
matrix: Matrix,
determinant: int | str | None = None,
background_rect: bool = False,
initial_scale_factor: int = 2
) -> VGroup:
parens = Tex("()")
parens.scale(initial_scale_factor)
parens.stretch_to_fit_height(matrix.get_height())
l_paren, r_paren = parens.split()
l_paren.next_to(matrix, LEFT, buff=0.1)
r_paren.next_to(matrix, RIGHT, buff=0.1)
det = TexText("det")
det.scale(initial_scale_factor)
det.next_to(l_paren, LEFT, buff=0.1)
if background_rect:
det.add_background_rectangle()
det_text = VGroup(det, l_paren, r_paren)
if determinant is not None:
eq = Tex("=")
eq.next_to(r_paren, RIGHT, buff=0.1)
result = Tex(str(determinant))
result.next_to(eq, RIGHT, buff=0.2)
det_text.add(eq, result)
return det_text