Refactor Matrix, DecimalMatrix, MobjectMatrix, etc.

This commit is contained in:
Grant Sanderson 2024-01-18 11:12:42 -06:00
parent 41ece958fd
commit 855ef9be8d

View file

@ -4,85 +4,36 @@ import itertools as it
import numpy as np import numpy as np
from manimlib.constants import DEFAULT_MOBJECT_TO_MOBJECT_BUFFER from manimlib.constants import DOWN, LEFT, RIGHT, ORIGIN
from manimlib.constants import DOWN, LEFT, RIGHT, UP
from manimlib.constants import WHITE
from manimlib.mobject.numbers import DecimalNumber from manimlib.mobject.numbers import DecimalNumber
from manimlib.mobject.numbers import Integer
from manimlib.mobject.shape_matchers import BackgroundRectangle
from manimlib.mobject.svg.tex_mobject import Tex from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.svg.tex_mobject import TexText
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VMobject
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Sequence from typing import Sequence, Union, Tuple
import numpy.typing as npt from manimlib.typing import ManimColor, Vect3, VectNArray, Self
from manimlib.mobject.mobject import Mobject
from manimlib.typing import ManimColor, Vect3, Self
StringMatrixType = Union[Sequence[Sequence[str]], np.ndarray[int, np.dtype[np.str_]]]
FloatMatrixType = Union[Sequence[Sequence[float]], VectNArray]
VMobjectMatrixType = Sequence[Sequence[VMobject]]
GenericMatrixType = Union[FloatMatrixType, StringMatrixType, VMobjectMatrixType]
VECTOR_LABEL_SCALE_FACTOR = 0.8
def matrix_to_tex_string(matrix: npt.ArrayLike) -> str:
matrix = np.array(matrix).astype("str")
if matrix.ndim == 1:
matrix = matrix.reshape((matrix.size, 1))
n_rows, n_cols = matrix.shape
prefix = R"\left[ \begin{array}{%s}" % ("c" * n_cols)
suffix = R"\end{array} \right]"
rows = [
" & ".join(row)
for row in matrix
]
return prefix + R" \\ ".join(rows) + suffix
def matrix_to_mobject(matrix: npt.ArrayLike) -> Tex:
return Tex(matrix_to_tex_string(matrix))
def vector_coordinate_label(
vector_mob: VMobject,
integer_labels: bool = True,
n_dim: int = 2,
color: ManimColor = WHITE
) -> Matrix:
vect = np.array(vector_mob.get_end())
if integer_labels:
vect = np.round(vect).astype(int)
vect = vect[:n_dim]
vect = vect.reshape((n_dim, 1))
label = Matrix(vect, add_background_rectangles_to_entries=True)
label.scale(VECTOR_LABEL_SCALE_FACTOR)
shift_dir = np.array(vector_mob.get_end())
if shift_dir[0] >= 0: # Pointing right
shift_dir -= label.get_left() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER * LEFT
else: # Pointing left
shift_dir -= label.get_right() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER * RIGHT
label.shift(shift_dir)
label.set_color(color)
label.rect = BackgroundRectangle(label)
label.add_to_back(label.rect)
return label
class Matrix(VMobject): class Matrix(VMobject):
def __init__( def __init__(
self, self,
matrix: Sequence[Sequence[str | float | VMobject]], matrix: GenericMatrixType,
v_buff: float = 0.8, v_buff: float = 0.5,
h_buff: float = 1.0, h_buff: float = 0.5,
bracket_h_buff: float = 0.2, bracket_h_buff: float = 0.2,
bracket_v_buff: float = 0.25, bracket_v_buff: float = 0.25,
add_background_rectangles_to_entries: bool = False, height: float | None = None,
include_background_rectangle: bool = False, element_config: dict = dict(),
element_alignment_corner: Vect3 = DOWN, element_alignment_corner: Vect3 = DOWN,
**kwargs
): ):
""" """
Matrix can either include numbers, tex_strings, Matrix can either include numbers, tex_strings,
@ -90,54 +41,60 @@ class Matrix(VMobject):
""" """
super().__init__() super().__init__()
mob_matrix = self.matrix_to_mob_matrix(matrix, **kwargs) self.mob_matrix = self.create_mobject_matrix(
self.mob_matrix = mob_matrix matrix, v_buff, h_buff, element_alignment_corner,
**element_config
)
# Create helpful groups for the elements
n_cols = len(self.mob_matrix[0])
self.elements = VGroup(*it.chain(*self.mob_matrix))
self.columns = VGroup(*(
VGroup(*(row[i] for row in self.mob_matrix))
for i in range(n_cols)
))
self.rows = VGroup(*(VGroup(*row) for row in self.mob_matrix))
self.organize_mob_matrix(mob_matrix, v_buff, h_buff, element_alignment_corner) # Add elements and brackets
self.elements = VGroup(*it.chain(*mob_matrix))
self.add(self.elements) self.add(self.elements)
if height is not None:
self.set_height(height - 2 * bracket_v_buff)
self.add_brackets(bracket_v_buff, bracket_h_buff) self.add_brackets(bracket_v_buff, bracket_h_buff)
self.center() self.center()
if add_background_rectangles_to_entries:
for mob in self.elements:
mob.add_background_rectangle()
if include_background_rectangle:
self.add_background_rectangle()
def create_mobject_matrix(
def element_to_mobject(self, element: str | float | VMobject, **config) -> VMobject:
if isinstance(element, VMobject):
return element
return Tex(str(element), **config)
def matrix_to_mob_matrix(
self, self,
matrix: Sequence[Sequence[str | float | VMobject]], matrix: GenericMatrixType,
**config
) -> list[list[VMobject]]:
return [
[
self.element_to_mobject(item, **config)
for item in row
]
for row in matrix
]
def organize_mob_matrix(
self,
matrix: list[list[VMobject]],
v_buff: float, v_buff: float,
h_buff: float, h_buff: float,
aligned_corner: Vect3, aligned_corner: Vect3,
) -> Self: **element_config
for i, row in enumerate(matrix): ) -> VMobjectMatrixType:
"""
Creates and organizes the matrix of mobjects
"""
mob_matrix = [
[
self.element_to_mobject(element, **element_config)
for element in row
]
for row in matrix
]
max_width = max(elem.get_width() for row in mob_matrix for elem in row)
max_height = max(elem.get_height() for row in mob_matrix for elem in row)
x_step = (max_width + h_buff) * RIGHT
y_step = (max_height + v_buff) * DOWN
for i, row in enumerate(mob_matrix):
for j, elem in enumerate(row): for j, elem in enumerate(row):
mob = matrix[i][j] elem.move_to(i * y_step + j * x_step, aligned_corner)
mob.move_to( return mob_matrix
i * v_buff * DOWN + j * h_buff * RIGHT,
aligned_corner def element_to_mobject(self, element, **config) -> VMobject:
) if isinstance(element, VMobject):
return self return element
elif isinstance(element, float | complex):
return DecimalNumber(element, **config)
else:
return Tex(str(element), **config)
def add_brackets(self, v_buff: float, h_buff: float) -> Self: def add_brackets(self, v_buff: float, h_buff: float) -> Self:
height = len(self.mob_matrix) height = len(self.mob_matrix)
@ -152,21 +109,25 @@ class Matrix(VMobject):
l_bracket.next_to(self, LEFT, h_buff) l_bracket.next_to(self, LEFT, h_buff)
r_bracket.next_to(self, RIGHT, h_buff) r_bracket.next_to(self, RIGHT, h_buff)
brackets.set_submobjects([l_bracket, r_bracket]) brackets.set_submobjects([l_bracket, r_bracket])
self.brackets = brackets self.brackets = VGroup(l_bracket, r_bracket)
self.add(*brackets) self.add(*brackets)
return self return self
def get_column(self, index: int):
if not 0 <= index < len(self.columns):
raise IndexError(f"Index {index} out of bound for matrix with {len(self.columns)} columns")
return self.columns[index]
def get_row(self, index: int):
if not 0 <= index < len(self.rows):
raise IndexError(f"Index {index} out of bound for matrix with {len(self.rows)} rows")
return self.rows[index]
def get_columns(self) -> VGroup: def get_columns(self) -> VGroup:
return VGroup(*[ return self.columns
VGroup(*[row[i] for row in self.mob_matrix])
for i in range(len(self.mob_matrix[0]))
])
def get_rows(self) -> VGroup: def get_rows(self) -> VGroup:
return VGroup(*[ return self.rows
VGroup(*row)
for row in self.mob_matrix
])
def set_column_colors(self, *colors: ManimColor) -> Self: def set_column_colors(self, *colors: ManimColor) -> Self:
columns = self.get_columns() columns = self.get_columns()
@ -179,7 +140,7 @@ class Matrix(VMobject):
mob.add_background_rectangle() mob.add_background_rectangle()
return self return self
def get_mob_matrix(self) -> list[list[Mobject]]: def get_mob_matrix(self) -> VMobjectMatrixType:
return self.mob_matrix return self.mob_matrix
def get_entries(self) -> VGroup: def get_entries(self) -> VGroup:
@ -190,50 +151,82 @@ class Matrix(VMobject):
class DecimalMatrix(Matrix): class DecimalMatrix(Matrix):
def element_to_mobject(self, element: float, num_decimal_places: int = 1, **config) -> DecimalNumber:
return DecimalNumber(element, num_decimal_places=num_decimal_places, **config)
class IntegerMatrix(Matrix):
def __init__( def __init__(
self, self,
matrix: npt.ArrayLike, matrix: FloatMatrixType,
element_alignment_corner: Vect3 = UP, num_decimal_places: int = 2,
**kwargs decimal_config: dict = dict(),
**config
): ):
super().__init__(matrix, element_alignment_corner=element_alignment_corner, **kwargs) super().__init__(
matrix,
element_config=dict(
num_decimal_places=num_decimal_places,
**decimal_config
),
**config
)
def element_to_mobject(self, element, **decimal_config) -> VMobject:
return DecimalNumber(element, **decimal_config)
class IntegerMatrix(DecimalMatrix):
def __init__(
self,
matrix: FloatMatrixType,
num_decimal_places: int = 0,
decimal_config: dict = dict(),
**config
):
super().__init__(matrix, num_decimal_places, decimal_config, **config)
class TexMatrix(Matrix):
def __init__(
self,
matrix: StringMatrixType,
tex_config: dict = dict(),
**config,
):
super().__init__(
matrix,
element_config=tex_config,
**config
)
def element_to_mobject(self, element: int, **config) -> Integer:
return Integer(element, **config)
class MobjectMatrix(Matrix): class MobjectMatrix(Matrix):
def __init__(
self,
group: VGroup,
n_rows: int | None = None,
n_cols: int | None = None,
height: float = 4.0,
element_alignment_corner=ORIGIN,
**config,
):
# Have fallback defaults of n_rows and n_cols
n_mobs = len(group)
if n_rows is None:
n_rows = int(np.sqrt(n_mobs)) if n_cols is None else n_mobs // n_cols
if n_cols is None:
n_cols = n_mobs // n_rows
if len(group) < n_rows * n_cols:
raise Exception("Input to MobjectMatrix must have at least n_rows * n_cols entries")
mob_matrix = [
[group[n * n_cols + k] for k in range(n_cols)]
for n in range(n_rows)
]
config.update(
height=height,
element_alignment_corner=element_alignment_corner,
)
super().__init__(mob_matrix, **config)
def element_to_mobject(self, element: VMobject, **config) -> VMobject: def element_to_mobject(self, element: VMobject, **config) -> VMobject:
return element return element
def get_det_text(
matrix: Matrix,
determinant: int | str | None = None,
background_rect: bool = False,
initial_scale_factor: int = 2
) -> VGroup:
parens = Tex("()")
parens.scale(initial_scale_factor)
parens.stretch_to_fit_height(matrix.get_height())
l_paren, r_paren = parens.split()
l_paren.next_to(matrix, LEFT, buff=0.1)
r_paren.next_to(matrix, RIGHT, buff=0.1)
det = TexText("det")
det.scale(initial_scale_factor)
det.next_to(l_paren, LEFT, buff=0.1)
if background_rect:
det.add_background_rectangle()
det_text = VGroup(det, l_paren, r_paren)
if determinant is not None:
eq = Tex("=")
eq.next_to(r_paren, RIGHT, buff=0.1)
result = Tex(str(determinant))
result.next_to(eq, RIGHT, buff=0.2)
det_text.add(eq, result)
return det_text