3b1b-manim/manimlib/mobject/matrix.py

223 lines
6.7 KiB
Python
Raw Normal View History

from __future__ import annotations
import itertools as it
from typing import Union, Sequence
import numpy as np
import numpy.typing as npt
from manimlib.constants import *
from manimlib.mobject.numbers import DecimalNumber
from manimlib.mobject.numbers import Integer
from manimlib.mobject.shape_matchers import BackgroundRectangle
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.svg.tex_mobject import TexText
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    # These names are only used in annotations. Because this module uses
    # `from __future__ import annotations`, annotations are stored as strings
    # and not evaluated at import time, so the imports and the alias can stay
    # behind this type-checking-only guard.
    import colour
    from manimlib.mobject.mobject import Mobject
    # Color specifications accepted in annotations: a string, a
    # `colour.Color`, or a sequence of floats (presumably color components —
    # TODO confirm expected length/range).
    ManimColor = Union[str, colour.Color, Sequence[float]]

# Scale applied to the coordinate matrix built by `vector_coordinate_label`.
VECTOR_LABEL_SCALE_FACTOR = 0.8
def matrix_to_tex_string(matrix: npt.ArrayLike) -> str:
    """
    Return LaTeX source for ``matrix`` as a square-bracketed ``array``.

    Every entry is converted with numpy's ``astype("str")``. A 1D input is
    typeset as a column vector. Inputs that are neither 1D nor 2D raise
    ``ValueError`` when the shape is unpacked.
    """
    cells = np.array(matrix).astype("str")
    if cells.ndim == 1:
        # A flat sequence is typeset as a single column.
        cells = cells.reshape((cells.size, 1))
    _, col_count = cells.shape
    row_strings = (" & ".join(cells_in_row) for cells_in_row in cells)
    body = " \\\\ ".join(row_strings)
    header = "\\left[ \\begin{array}{%s}" % ("c" * col_count)
    footer = "\\end{array} \\right]"
    return header + body + footer
def matrix_to_mobject(matrix: npt.ArrayLike) -> Tex:
    """Typeset ``matrix`` via ``matrix_to_tex_string`` and wrap the result in a Tex mobject."""
    tex_source = matrix_to_tex_string(matrix)
    return Tex(tex_source)
def vector_coordinate_label(
    vector_mob: VMobject,
    integer_labels: bool = True,
    n_dim: int = 2,
    color: ManimColor = WHITE
) -> Matrix:
    """
    Build a column-matrix label listing the coordinates of ``vector_mob``'s tip.

    The first ``n_dim`` components of ``vector_mob.get_end()`` are shown,
    rounded to integers when ``integer_labels`` is true. The label is scaled
    by ``VECTOR_LABEL_SCALE_FACTOR`` and shifted next to the tip: its left
    edge goes just right of the tip when the tip's x-coordinate is >= 0,
    otherwise its right edge goes just left of the tip. Finally it is
    colored with ``color``, and a ``BackgroundRectangle`` is added behind it
    and stored as ``label.rect``.
    """
    tip = np.array(vector_mob.get_end())
    coords = np.round(tip).astype(int) if integer_labels else tip
    coords = coords[:n_dim].reshape((n_dim, 1))

    label = Matrix(coords, add_background_rectangles_to_entries=True)
    label.scale(VECTOR_LABEL_SCALE_FACTOR)

    # Offset that brings the label's near edge (plus a buffer) onto the tip.
    offset = np.array(vector_mob.get_end())
    if not offset[0] >= 0:  # Pointing left: label sits left of the tip
        offset -= label.get_right() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER * RIGHT
    else:  # Pointing right: label sits right of the tip
        offset -= label.get_left() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER * LEFT
    label.shift(offset)

    label.set_color(color)
    label.rect = BackgroundRectangle(label)
    label.add_to_back(label.rect)
    return label
class Matrix(VMobject):
    """
    A grid of mobjects arranged in rows and columns, enclosed in square brackets.

    Each raw entry is turned into a mobject by ``element_to_mobject`` (called
    with ``element_to_mobject_config`` as keyword arguments). Subclasses pick
    a different entry type by overriding ``CONFIG``.
    """
    CONFIG = {
        # Distance between the grid points of consecutive rows.
        "v_buff": 0.8,
        # Distance between the grid points of consecutive columns.
        "h_buff": 1.3,
        # Gap between the entries and each bracket.
        "bracket_h_buff": 0.2,
        # Extra height given to the brackets beyond the entries' height.
        "bracket_v_buff": 0.25,
        "add_background_rectangles_to_entries": False,
        "include_background_rectangle": False,
        # Factory that turns one raw entry into a mobject.
        "element_to_mobject": Tex,
        "element_to_mobject_config": {},
        # Passed as the second argument of `move_to` when placing each entry
        # on its grid point.
        "element_alignment_corner": DOWN,
    }

    def __init__(self, matrix: npt.ArrayLike, **kwargs):
        """
        Build the matrix mobject.

        :param matrix: entries to display. They may be numbers, tex strings,
            or mobjects, as long as ``element_to_mobject`` accepts them. The
            input is converted with ``np.array(matrix, ndmin=2)``, so a flat
            sequence becomes a single row.
        """
        VMobject.__init__(self, **kwargs)
        matrix = self.matrix = np.array(matrix, ndmin=2)
        mob_matrix = self.matrix_to_mob_matrix(matrix)
        self.organize_mob_matrix(mob_matrix)
        self.mob_matrix = mob_matrix
        self.elements = VGroup(*it.chain.from_iterable(mob_matrix))
        self.add(self.elements)
        self.add_brackets()
        self.center()
        if self.add_background_rectangles_to_entries:
            self.add_background_to_entries()
        if self.include_background_rectangle:
            self.add_background_rectangle()

    def matrix_to_mob_matrix(self, matrix: npt.ArrayLike) -> list[list[Mobject]]:
        """Convert every raw entry into a mobject, preserving the row/column layout."""
        return [
            [
                self.element_to_mobject(item, **self.element_to_mobject_config)
                for item in row
            ]
            for row in matrix
        ]

    def organize_mob_matrix(self, matrix: list[list[Mobject]]):
        """
        Move each entry to its grid point.

        Row ``i``, column ``j`` is placed at ``i * v_buff * DOWN +
        j * h_buff * RIGHT``, aligned via ``element_alignment_corner``.
        """
        for i, row in enumerate(matrix):
            for j, mob in enumerate(row):
                mob.move_to(
                    i * self.v_buff * DOWN + j * self.h_buff * RIGHT,
                    self.element_alignment_corner,
                )
        return self

    def add_brackets(self):
        """
        Add a pair of square brackets around the current contents.

        An empty LaTeX ``array`` with one row per matrix row is typeset. Its
        first submobject is then split in half into the left and right
        brackets, which are stored in ``self.brackets``.
        """
        n_rows = self.matrix.shape[0]
        bracket_pair = Tex("".join([
            "\\left[",
            "\\begin{array}{c}",
            *n_rows * ["\\quad \\\\"],
            "\\end{array}",
            "\\right]",
        ]))[0]
        # Brackets span the entries' height plus a small margin.
        bracket_pair.set_height(self.get_height() + self.bracket_v_buff)
        half = len(bracket_pair) // 2
        l_bracket = bracket_pair[:half]
        r_bracket = bracket_pair[half:]
        l_bracket.next_to(self, LEFT, self.bracket_h_buff)
        r_bracket.next_to(self, RIGHT, self.bracket_h_buff)
        self.add(l_bracket, r_bracket)
        self.brackets = VGroup(l_bracket, r_bracket)
        return self

    def get_columns(self) -> VGroup:
        """Return a VGroup holding one VGroup per column, from left to right."""
        n_cols = len(self.mob_matrix[0])
        return VGroup(*[
            VGroup(*[row[j] for row in self.mob_matrix])
            for j in range(n_cols)
        ])

    def get_rows(self) -> VGroup:
        """Return a VGroup holding one VGroup per row, from top to bottom."""
        return VGroup(*[
            VGroup(*row)
            for row in self.mob_matrix
        ])

    def set_column_colors(self, *colors: ManimColor):
        """
        Color columns in order with ``colors``.

        Extra colors are ignored. Columns without a matching color are left
        unchanged.
        """
        columns = self.get_columns()
        for color, column in zip(colors, columns):
            column.set_color(color)
        return self

    def add_background_to_entries(self):
        """Give every entry its own background rectangle."""
        for mob in self.get_entries():
            mob.add_background_rectangle()
        return self

    def get_mob_matrix(self) -> list[list[Mobject]]:
        """Return the entries as a list of rows."""
        return self.mob_matrix

    def get_entries(self) -> VGroup:
        """Return all entries, flattened row by row, as one VGroup."""
        return self.elements

    def get_brackets(self) -> VGroup:
        """Return the (left, right) bracket pair."""
        return self.brackets
class DecimalMatrix(Matrix):
    """A Matrix whose entries are rendered as DecimalNumber mobjects with one decimal place."""
    CONFIG = dict(
        element_to_mobject=DecimalNumber,
        element_to_mobject_config=dict(num_decimal_places=1),
    )
class IntegerMatrix(Matrix):
    """
    A Matrix whose entries are rendered as Integer mobjects.

    It uses ``UP`` as ``element_alignment_corner``, where Matrix uses ``DOWN``.
    """
    CONFIG = dict(
        element_to_mobject=Integer,
        element_alignment_corner=UP,
    )
class MobjectMatrix(Matrix):
    """A Matrix whose entries are already mobjects and are used unchanged."""
    CONFIG = dict(
        element_to_mobject=lambda mob: mob,
    )
def get_det_text(
    matrix: Matrix,
    determinant: int | str | None = None,
    background_rect: bool = False,
    initial_scale_factor: int = 2
) -> VGroup:
    """
    Build ``det( ... )`` around ``matrix``, optionally followed by ``= determinant``.

    Parentheses are stretched to ``matrix``'s height and placed on both sides
    of it, with "det" to the left of the opening parenthesis. ``matrix``
    itself is not part of the returned group.

    :param determinant: when not None, ``str(determinant)`` is shown after an
        equals sign to the right of the closing parenthesis.
    :param background_rect: add a background rectangle behind "det" only.
    :param initial_scale_factor: scale applied to the parentheses and to
        "det" (the parentheses are then stretched to the matrix's height).
    """
    parens = Tex("(", ")")
    parens.scale(initial_scale_factor)
    parens.stretch_to_fit_height(matrix.get_height())
    l_paren, r_paren = parens.split()
    l_paren.next_to(matrix, LEFT, buff=0.1)
    r_paren.next_to(matrix, RIGHT, buff=0.1)

    det_word = TexText("det")
    det_word.scale(initial_scale_factor)
    det_word.next_to(l_paren, LEFT, buff=0.1)
    if background_rect:
        det_word.add_background_rectangle()

    pieces = [det_word, l_paren, r_paren]
    if determinant is not None:
        equals_sign = Tex("=")
        equals_sign.next_to(r_paren, RIGHT, buff=0.1)
        value = Tex(str(determinant))
        value.next_to(equals_sign, RIGHT, buff=0.2)
        pieces += [equals_sign, value]
    return VGroup(*pieces)