Kill CONFIG, and slightly refactor, matrix.py

This commit is contained in:
Grant Sanderson 2022-12-15 20:09:52 -08:00
parent 9d65ef3cae
commit a7d7ed0793

View file

@ -18,9 +18,10 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Sequence
import numpy.typing as npt
from manimlib.mobject.mobject import Mobject
from manimlib.constants import ManimColor
from manimlib.constants import ManimColor, np_vector
VECTOR_LABEL_SCALE_FACTOR = 0.8
@ -31,13 +32,13 @@ def matrix_to_tex_string(matrix: npt.ArrayLike) -> str:
if matrix.ndim == 1:
matrix = matrix.reshape((matrix.size, 1))
n_rows, n_cols = matrix.shape
prefix = "\\left[ \\begin{array}{%s}" % ("c" * n_cols)
suffix = "\\end{array} \\right]"
prefix = R"\left[ \begin{array}{%s}" % ("c" * n_cols)
suffix = R"\end{array} \right]"
rows = [
" & ".join(row)
for row in matrix
]
return prefix + " \\\\ ".join(rows) + suffix
return prefix + R" \\ ".join(rows) + suffix
def matrix_to_mobject(matrix: npt.ArrayLike) -> Tex:
@ -71,73 +72,83 @@ def vector_coordinate_label(
class Matrix(VMobject):
CONFIG = {
"v_buff": 0.8,
"h_buff": 1.3,
"bracket_h_buff": 0.2,
"bracket_v_buff": 0.25,
"add_background_rectangles_to_entries": False,
"include_background_rectangle": False,
"element_to_mobject": Tex,
"element_to_mobject_config": {},
"element_alignment_corner": DOWN,
}
def __init__(self, matrix: npt.ArrayLike, **kwargs):
def __init__(
self,
matrix: Sequence[Sequence[str | float | VMobject]],
v_buff: float = 0.8,
h_buff: float = 1.3,
bracket_h_buff: float = 0.2,
bracket_v_buff: float = 0.25,
add_background_rectangles_to_entries: bool = False,
include_background_rectangle: bool = False,
element_config: dict = dict(),
element_alignment_corner: np_vector = DOWN,
**kwargs
):
"""
Matrix can either include numbers, tex_strings,
or mobjects
"""
VMobject.__init__(self, **kwargs)
mob_matrix = self.mob_matrix = self.matrix_to_mob_matrix(matrix)
self.organize_mob_matrix(mob_matrix)
super().__init__(**kwargs)
mob_matrix = self.matrix_to_mob_matrix(matrix, **element_config)
self.mob_matrix = mob_matrix
self.organize_mob_matrix(mob_matrix, v_buff, h_buff, element_alignment_corner)
self.elements = VGroup(*it.chain(*mob_matrix))
self.add(self.elements)
self.add_brackets()
self.add_brackets(bracket_v_buff, bracket_h_buff)
self.center()
if self.add_background_rectangles_to_entries:
if add_background_rectangles_to_entries:
for mob in self.elements:
mob.add_background_rectangle()
if self.include_background_rectangle:
if include_background_rectangle:
self.add_background_rectangle()
def matrix_to_mob_matrix(self, matrix: npt.ArrayLike) -> list[list[Mobject]]:
def element_to_mobject(self, element: str | float, **config) -> Tex:
return Tex(str(element), **config)
def matrix_to_mob_matrix(self, matrix: npt.ArrayLike, **config) -> list[list[VMobject]]:
return [
[
self.element_to_mobject(item, **self.element_to_mobject_config)
self.element_to_mobject(item, **config)
for item in row
]
for row in matrix
]
def organize_mob_matrix(self, matrix: npt.ArrayLike):
def organize_mob_matrix(
self,
matrix: list[list[Mobject]],
v_buff: float,
h_buff: float,
aligned_corner: np_vector,
):
for i, row in enumerate(matrix):
for j, elem in enumerate(row):
mob = matrix[i][j]
mob.move_to(
i * self.v_buff * DOWN + j * self.h_buff * RIGHT,
self.element_alignment_corner
i * v_buff * DOWN + j * h_buff * RIGHT,
aligned_corner
)
return self
def add_brackets(self):
def add_brackets(self, v_buff: float, h_buff: float):
height = len(self.mob_matrix)
bracket_pair = Tex("".join([
"\\left[",
"\\begin{array}{c}",
*height * ["\\quad \\\\"],
"\\end{array}",
"\\right]",
]))[0]
bracket_pair.set_height(
self.get_height() + 1 * self.bracket_v_buff
)
l_bracket = bracket_pair[:len(bracket_pair) // 2]
r_bracket = bracket_pair[len(bracket_pair) // 2:]
l_bracket.next_to(self, LEFT, self.bracket_h_buff)
r_bracket.next_to(self, RIGHT, self.bracket_h_buff)
self.add(l_bracket, r_bracket)
self.brackets = VGroup(l_bracket, r_bracket)
brackets = Tex("".join((
R"\left[\begin{array}{c}",
*height * [R"\quad \\"],
R"\end{array}\right]",
)))[0]
brackets.set_height(self.get_height() + v_buff)
l_bracket = brackets[:len(brackets) // 2]
r_bracket = brackets[len(brackets) // 2:]
l_bracket.next_to(self, LEFT, h_buff)
r_bracket.next_to(self, RIGHT, h_buff)
brackets.set_submobjects([l_bracket, r_bracket])
self.brackets = brackets
self.add(*brackets)
return self
def get_columns(self) -> VGroup:
@ -174,23 +185,26 @@ class Matrix(VMobject):
class DecimalMatrix(Matrix):
CONFIG = {
"element_to_mobject": DecimalNumber,
"element_to_mobject_config": {"num_decimal_places": 1}
}
def element_to_mobject(self, element: float, num_decimal_places: int = 1, **config) -> DecimalNumber:
return DecimalNumber(element, num_decimal_places=num_decimal_places, **config)
class IntegerMatrix(Matrix):
CONFIG = {
"element_to_mobject": Integer,
"element_alignment_corner": UP,
}
def __init__(
self,
matrix: npt.ArrayLike,
element_alignment_corner: np_vector = UP,
**kwargs
):
super().__init__(matrix, element_alignment_corner=element_alignment_corner, **kwargs)
def element_to_mobject(self, element: int, **config) -> Integer:
return Integer(element, **config)
class MobjectMatrix(Matrix):
CONFIG = {
"element_to_mobject": lambda m: m,
}
def element_to_mobject(self, element: VMobject, **config) -> VMobject:
return element
def get_det_text(