From a7d7ed079386143b2092eee9d9a396f2f990785f Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Thu, 15 Dec 2022 20:09:52 -0800 Subject: [PATCH] Kill CONFIG, and slightly refactor, matrix.py --- manimlib/mobject/matrix.py | 126 ++++++++++++++++++++----------------- 1 file changed, 70 insertions(+), 56 deletions(-) diff --git a/manimlib/mobject/matrix.py b/manimlib/mobject/matrix.py index fae35a26..e5ba8407 100644 --- a/manimlib/mobject/matrix.py +++ b/manimlib/mobject/matrix.py @@ -18,9 +18,10 @@ from manimlib.mobject.types.vectorized_mobject import VMobject from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import Sequence import numpy.typing as npt from manimlib.mobject.mobject import Mobject - from manimlib.constants import ManimColor + from manimlib.constants import ManimColor, np_vector VECTOR_LABEL_SCALE_FACTOR = 0.8 @@ -31,13 +32,13 @@ def matrix_to_tex_string(matrix: npt.ArrayLike) -> str: if matrix.ndim == 1: matrix = matrix.reshape((matrix.size, 1)) n_rows, n_cols = matrix.shape - prefix = "\\left[ \\begin{array}{%s}" % ("c" * n_cols) - suffix = "\\end{array} \\right]" + prefix = R"\left[ \begin{array}{%s}" % ("c" * n_cols) + suffix = R"\end{array} \right]" rows = [ " & ".join(row) for row in matrix ] - return prefix + " \\\\ ".join(rows) + suffix + return prefix + R" \\ ".join(rows) + suffix def matrix_to_mobject(matrix: npt.ArrayLike) -> Tex: @@ -71,73 +72,83 @@ def vector_coordinate_label( class Matrix(VMobject): - CONFIG = { - "v_buff": 0.8, - "h_buff": 1.3, - "bracket_h_buff": 0.2, - "bracket_v_buff": 0.25, - "add_background_rectangles_to_entries": False, - "include_background_rectangle": False, - "element_to_mobject": Tex, - "element_to_mobject_config": {}, - "element_alignment_corner": DOWN, - } - - def __init__(self, matrix: npt.ArrayLike, **kwargs): + def __init__( + self, + matrix: Sequence[Sequence[str | float | VMobject]], + v_buff: float = 0.8, + h_buff: float = 1.3, + bracket_h_buff: float = 0.2, + bracket_v_buff: float = 0.25, + add_background_rectangles_to_entries: bool = False, + include_background_rectangle: bool = False, + element_config: dict = dict(), + element_alignment_corner: np_vector = DOWN, + **kwargs + ): """ Matrix can either include numbers, tex_strings, or mobjects """ - VMobject.__init__(self, **kwargs) - mob_matrix = self.mob_matrix = self.matrix_to_mob_matrix(matrix) - self.organize_mob_matrix(mob_matrix) + super().__init__(**kwargs) + + mob_matrix = self.matrix_to_mob_matrix(matrix, **element_config) + self.mob_matrix = mob_matrix + + self.organize_mob_matrix(mob_matrix, v_buff, h_buff, element_alignment_corner) self.elements = VGroup(*it.chain(*mob_matrix)) self.add(self.elements) - self.add_brackets() + self.add_brackets(bracket_v_buff, bracket_h_buff) self.center() - if self.add_background_rectangles_to_entries: + if add_background_rectangles_to_entries: for mob in self.elements: mob.add_background_rectangle() - if self.include_background_rectangle: + if include_background_rectangle: self.add_background_rectangle() - def matrix_to_mob_matrix(self, matrix: npt.ArrayLike) -> list[list[Mobject]]: + + def element_to_mobject(self, element: str | float, **config) -> Tex: + return Tex(str(element), **config) + + def matrix_to_mob_matrix(self, matrix: npt.ArrayLike, **config) -> list[list[VMobject]]: return [ [ - self.element_to_mobject(item, **self.element_to_mobject_config) + self.element_to_mobject(item, **config) for item in row ] for row in matrix ] - def organize_mob_matrix(self, matrix: npt.ArrayLike): + def organize_mob_matrix( + self, + matrix: list[list[Mobject]], + v_buff: float, + h_buff: float, + aligned_corner: np_vector, + ): for i, row in enumerate(matrix): for j, elem in enumerate(row): mob = matrix[i][j] mob.move_to( - i * self.v_buff * DOWN + j * self.h_buff * RIGHT, - self.element_alignment_corner + i * v_buff * DOWN + j * h_buff * RIGHT, + aligned_corner ) return self - def add_brackets(self): + def add_brackets(self, v_buff: float, h_buff: float): height = len(self.mob_matrix) - bracket_pair = Tex("".join([ - "\\left[", - "\\begin{array}{c}", - *height * ["\\quad \\\\"], - "\\end{array}", - "\\right]", - ]))[0] - bracket_pair.set_height( - self.get_height() + 1 * self.bracket_v_buff - ) - l_bracket = bracket_pair[:len(bracket_pair) // 2] - r_bracket = bracket_pair[len(bracket_pair) // 2:] - l_bracket.next_to(self, LEFT, self.bracket_h_buff) - r_bracket.next_to(self, RIGHT, self.bracket_h_buff) - self.add(l_bracket, r_bracket) - self.brackets = VGroup(l_bracket, r_bracket) + brackets = Tex("".join(( + R"\left[\begin{array}{c}", + *height * [R"\quad \\"], + R"\end{array}\right]", + )))[0] + brackets.set_height(self.get_height() + v_buff) + l_bracket = brackets[:len(brackets) // 2] + r_bracket = brackets[len(brackets) // 2:] + l_bracket.next_to(self, LEFT, h_buff) + r_bracket.next_to(self, RIGHT, h_buff) + brackets.set_submobjects([l_bracket, r_bracket]) + self.brackets = brackets + self.add(*brackets) return self def get_columns(self) -> VGroup: @@ -174,23 +185,26 @@ class Matrix(VMobject): class DecimalMatrix(Matrix): - CONFIG = { - "element_to_mobject": DecimalNumber, - "element_to_mobject_config": {"num_decimal_places": 1} - } + def element_to_mobject(self, element: float, num_decimal_places: int = 1, **config) -> DecimalNumber: + return DecimalNumber(element, num_decimal_places=num_decimal_places, **config) class IntegerMatrix(Matrix): - CONFIG = { - "element_to_mobject": Integer, - "element_alignment_corner": UP, - } + def __init__( + self, + matrix: npt.ArrayLike, + element_alignment_corner: np_vector = UP, + **kwargs + ): + super().__init__(matrix, element_alignment_corner=element_alignment_corner, **kwargs) + + def element_to_mobject(self, element: int, **config) -> Integer: + return Integer(element, **config) class MobjectMatrix(Matrix): - CONFIG = { - "element_to_mobject": lambda m: m, - } + def element_to_mobject(self, element: VMobject, **config) -> VMobject: + return element def get_det_text(