3b1b-manim/manimlib/mobject/matrix.py
2024-03-25 19:10:42 -03:00

295 lines
9.2 KiB
Python

from __future__ import annotations
import itertools as it
import numpy as np
from manimlib.constants import DOWN, LEFT, RIGHT, ORIGIN
from manimlib.constants import DEGREES
from manimlib.mobject.numbers import DecimalNumber
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    # Everything in this guard exists only for annotations. Because the module
    # starts with `from __future__ import annotations`, annotations are never
    # evaluated at runtime, so these names are not needed outside type checking.
    from typing import Sequence, Union, Tuple, Optional
    from manimlib.typing import ManimColor, Vect3, VectNArray, Self

    # Rows of strings, given as nested sequences or as a numpy string array
    StringMatrixType = Union[Sequence[Sequence[str]], np.ndarray[int, np.dtype[np.str_]]]
    # Rows of numbers, given as nested sequences or as a VectNArray
    FloatMatrixType = Union[Sequence[Sequence[float]], VectNArray]
    # Rows of already-constructed mobjects
    VMobjectMatrixType = Sequence[Sequence[VMobject]]
    # Any of the input forms accepted by Matrix
    GenericMatrixType = Union[FloatMatrixType, StringMatrixType, VMobjectMatrixType]
class Matrix(VMobject):
    """
    A rectangular grid of entries enclosed in a pair of square brackets.

    Bookkeeping attributes set up by the constructor:

    - ``mob_matrix``: nested list (rows of entries) of the entry mobjects
    - ``elements``: flat, row-major list of the entries that have not
      been swapped for ellipses
    - ``ellipses``: list of the entries that have been swapped for dots
    - ``rows`` / ``columns``: VGroups holding one VGroup per row / column
    - ``brackets``: VGroup holding the left and right bracket
    """

    def __init__(
        self,
        matrix: GenericMatrixType,
        v_buff: float = 0.5,
        h_buff: float = 0.5,
        bracket_h_buff: float = 0.2,
        bracket_v_buff: float = 0.25,
        height: float | None = None,
        element_config: dict = dict(),
        element_alignment_corner: Vect3 = DOWN,
        ellipses_row: Optional[int] = None,
        ellipses_col: Optional[int] = None,
    ):
        """
        Matrix can either include numbers, tex_strings,
        or mobjects

        v_buff and h_buff are the gaps added to the tallest / widest entry
        to get the spacing between neighboring rows / columns.

        bracket_h_buff is the horizontal gap between the rows and each
        bracket; bracket_v_buff is the extra height given to the brackets
        beyond the height of the rows.

        If height is given, the rows are rescaled to a height of
        (height - 2 * bracket_v_buff) before the brackets are created.

        element_config is unpacked as keyword arguments into
        element_to_mobject for every entry. It is only read, never
        mutated, so the shared default dict is harmless here.

        element_alignment_corner is forwarded as the second argument of
        each entry's move_to call when the entries are laid out.

        ellipses_row and ellipses_col are forwarded to
        swap_entries_for_ellipses once everything else is built.
        """
        super().__init__()

        self.mob_matrix = self.create_mobject_matrix(
            matrix, v_buff, h_buff, element_alignment_corner,
            **element_config
        )

        # Create helpful groups for the elements
        n_cols = len(self.mob_matrix[0])
        self.elements = [elem for row in self.mob_matrix for elem in row]
        self.columns = VGroup(*(
            VGroup(*(row[i] for row in self.mob_matrix))
            for i in range(n_cols)
        ))
        self.rows = VGroup(*(VGroup(*row) for row in self.mob_matrix))
        if height is not None:
            # NOTE(review): create_brackets adds bracket_v_buff once to the
            # rows' height, while 2 * bracket_v_buff is subtracted here, so the
            # bracket height appears to come out as (height - bracket_v_buff).
            # Confirm whether that is intended.
            self.rows.set_height(height - 2 * bracket_v_buff)
        self.brackets = self.create_brackets(self.rows, bracket_v_buff, bracket_h_buff)
        self.ellipses = []

        # Add elements and brackets
        self.add(*self.elements)
        self.add(*self.brackets)
        self.center()

        # Potentially add ellipses
        self.swap_entries_for_ellipses(
            ellipses_row,
            ellipses_col,
        )

    def copy(self, deep: bool = False) -> Self:
        """
        Copy the matrix, then rebuild the plain-list attributes
        (elements and ellipses) so that they refer to the copy's own
        family members rather than to those of the original.

        Each mobject is matched by its position in the family list.
        """
        result = super().copy(deep)
        self_family = self.get_family()
        copy_family = result.get_family()
        # NOTE(review): only the two list attributes are remapped here.
        # Whether mob_matrix, rows, columns and brackets of the copy point
        # at the copied entries depends on the parent copy(); verify.
        for attr in ["elements", "ellipses"]:
            setattr(result, attr, [
                copy_family[self_family.index(mob)]
                for mob in getattr(self, attr)
            ])
        return result

    def create_mobject_matrix(
        self,
        matrix: GenericMatrixType,
        v_buff: float,
        h_buff: float,
        aligned_corner: Vect3,
        **element_config
    ) -> VMobjectMatrixType:
        """
        Creates and organizes the matrix of mobjects

        Every entry is converted with element_to_mobject and then placed
        on a uniform grid: the cell width is the widest entry plus h_buff
        and the cell height is the tallest entry plus v_buff. Rows advance
        in the DOWN direction and columns in the RIGHT direction.

        Returns the nested list (rows of entries) of positioned mobjects.
        """
        mob_matrix = [
            [
                self.element_to_mobject(element, **element_config)
                for element in row
            ]
            for row in matrix
        ]
        max_width = max(elem.get_width() for row in mob_matrix for elem in row)
        max_height = max(elem.get_height() for row in mob_matrix for elem in row)
        x_step = (max_width + h_buff) * RIGHT
        y_step = (max_height + v_buff) * DOWN
        for i, row in enumerate(mob_matrix):
            for j, elem in enumerate(row):
                elem.move_to(i * y_step + j * x_step, aligned_corner)
        return mob_matrix

    def element_to_mobject(self, element, **config) -> VMobject:
        """
        Convert a single raw entry into a mobject.

        VMobjects are returned untouched, floats and complex numbers
        become DecimalNumbers, and anything else (ints and strings
        included) is rendered as Tex(str(element)).
        """
        if isinstance(element, VMobject):
            return element
        # A tuple is used rather than `float | complex`: a union passed to
        # isinstance is only accepted from Python 3.10 on, while the tuple
        # form gives the same result on every version.
        elif isinstance(element, (float, complex)):
            return DecimalNumber(element, **config)
        else:
            return Tex(str(element), **config)

    def create_brackets(self, rows: VGroup, v_buff: float, h_buff: float) -> VGroup:
        """
        Build the left and right brackets around rows.

        A bracketed tex array with one blank line per row is typeset,
        scaled to the height of rows plus v_buff, and cut in two halves.
        The halves are placed to the LEFT and RIGHT of rows with a
        buffer of h_buff.

        Returns VGroup(left_bracket, right_bracket).
        """
        brackets = Tex("".join((
            R"\left[\begin{array}{c}",
            *len(rows) * [R"\quad \\"],
            R"\end{array}\right]",
        )))
        brackets.set_height(rows.get_height() + v_buff)
        # First half of the Tex is taken as the left bracket, the rest as the right
        l_bracket = brackets[:len(brackets) // 2]
        r_bracket = brackets[len(brackets) // 2:]
        l_bracket.next_to(rows, LEFT, h_buff)
        r_bracket.next_to(rows, RIGHT, h_buff)
        return VGroup(l_bracket, r_bracket)

    def get_column(self, index: int) -> VGroup:
        """
        Return the column at index.

        Raises IndexError unless 0 <= index < number of columns
        (negative indices are rejected).
        """
        if not 0 <= index < len(self.columns):
            raise IndexError(f"Index {index} out of bound for matrix with {len(self.columns)} columns")
        return self.columns[index]

    def get_row(self, index: int) -> VGroup:
        """
        Return the row at index.

        Raises IndexError unless 0 <= index < number of rows
        (negative indices are rejected).
        """
        if not 0 <= index < len(self.rows):
            raise IndexError(f"Index {index} out of bound for matrix with {len(self.rows)} rows")
        return self.rows[index]

    def get_columns(self) -> VGroup:
        """Return the VGroup holding one VGroup per column."""
        return self.columns

    def get_rows(self) -> VGroup:
        """Return the VGroup holding one VGroup per row."""
        return self.rows

    def set_column_colors(self, *colors: ManimColor) -> Self:
        """
        Color the columns from left to right with the given colors.

        Colors and columns are paired with zip, so surplus colors or
        surplus columns are left alone.
        """
        columns = self.get_columns()
        for color, column in zip(colors, columns):
            column.set_color(color)
        return self

    def add_background_to_entries(self) -> Self:
        """Call add_background_rectangle on every entry of get_entries()."""
        for mob in self.get_entries():
            mob.add_background_rectangle()
        return self

    def swap_entry_for_dots(self, entry: VMobject, dots: VMobject) -> None:
        """
        Turn entry into dots, in place.

        The dots are moved onto the entry and the entry is then told to
        become them, so the entry object itself stays in the matrix.
        It is moved from the elements list to the ellipses list.
        """
        dots.move_to(entry)
        entry.become(dots)
        if entry in self.elements:
            self.elements.remove(entry)
        if entry not in self.ellipses:
            self.ellipses.append(entry)

    def swap_entries_for_ellipses(
        self,
        row_index: Optional[int] = None,
        col_index: Optional[int] = None,
        height_ratio: float = 0.65,
        width_ratio: float = 0.4
    ) -> Self:
        """
        Replace a whole row with vertical dots and / or a whole column
        with horizontal dots.

        An index of None, or one outside of [-length, length), leaves
        the corresponding direction untouched; negative indices count
        from the end. The vertical dots get a height of height_ratio
        times the average row height, the horizontal dots a width of
        width_ratio times the average column width.

        When both are applied, the entry where the row and the column
        cross ends up as horizontal dots rotated by -45 degrees.
        """
        rows = self.get_rows()
        cols = self.get_columns()

        avg_row_height = rows.get_height() / len(rows)
        vdots_height = height_ratio * avg_row_height

        avg_col_width = cols.get_width() / len(cols)
        hdots_width = width_ratio * avg_col_width

        use_vdots = row_index is not None and -len(rows) <= row_index < len(rows)
        use_hdots = col_index is not None and -len(cols) <= col_index < len(cols)

        if use_vdots:
            for column in cols:
                # Add vdots
                dots = Tex(R"\vdots")
                dots.set_height(vdots_height)
                self.swap_entry_for_dots(column[row_index], dots)
        if use_hdots:
            for row in rows:
                # Add hdots
                dots = Tex(R"\hdots")
                dots.set_width(hdots_width)
                self.swap_entry_for_dots(row[col_index], dots)
        if use_vdots and use_hdots:
            # The crossing entry was swapped twice and is currently hdots
            rows[row_index][col_index].rotate(-45 * DEGREES)
        return self

    def get_mob_matrix(self) -> VMobjectMatrixType:
        """Return the nested list (rows of entries) of entry mobjects."""
        return self.mob_matrix

    def get_entries(self) -> VGroup:
        """Return a new VGroup of the entries not swapped for ellipses."""
        return VGroup(*self.elements)

    def get_brackets(self) -> VGroup:
        """Return a new VGroup holding the left and right bracket."""
        return VGroup(*self.brackets)

    def get_ellipses(self) -> VGroup:
        """Return a new VGroup of the entries swapped for ellipses."""
        return VGroup(*self.ellipses)
class DecimalMatrix(Matrix):
    """
    Matrix in which every entry is turned into a DecimalNumber.
    """

    def __init__(
        self,
        matrix: FloatMatrixType,
        num_decimal_places: int = 2,
        decimal_config: dict = dict(),
        **config
    ):
        """
        num_decimal_places and the contents of decimal_config are handed
        to each entry's DecimalNumber; any remaining keyword arguments
        go to Matrix.
        """
        # Hold on to the raw values that were passed in
        self.float_matrix = matrix
        number_kwargs = dict(
            num_decimal_places=num_decimal_places,
            **decimal_config
        )
        super().__init__(matrix, element_config=number_kwargs, **config)

    def element_to_mobject(self, element, **decimal_config) -> DecimalNumber:
        """Wrap element in a DecimalNumber, whatever its type."""
        number = DecimalNumber(element, **decimal_config)
        return number
class IntegerMatrix(DecimalMatrix):
    """
    DecimalMatrix whose entries show no decimal places by default.
    """

    def __init__(
        self,
        matrix: FloatMatrixType,
        num_decimal_places: int = 0,
        decimal_config: dict = dict(),
        **config
    ):
        """Same arguments as DecimalMatrix, with num_decimal_places defaulting to 0."""
        super().__init__(
            matrix,
            num_decimal_places=num_decimal_places,
            decimal_config=decimal_config,
            **config
        )
class TexMatrix(Matrix):
    """
    Matrix built from a grid of strings.
    """

    def __init__(
        self,
        matrix: StringMatrixType,
        tex_config: dict = dict(),
        **config,
    ):
        """
        tex_config is passed to Matrix as its element_config; all other
        keyword arguments are forwarded to Matrix unchanged.
        """
        super().__init__(matrix, element_config=tex_config, **config)
class MobjectMatrix(Matrix):
    """
    Matrix whose entries are the members of an existing group of
    mobjects, taken in order and laid out row by row.
    """

    def __init__(
        self,
        group: VGroup,
        n_rows: int | None = None,
        n_cols: int | None = None,
        height: float = 4.0,
        element_alignment_corner=ORIGIN,
        **config,
    ):
        """
        group supplies the entries; only the first n_rows * n_cols of
        them are used, any further members are ignored.

        A dimension left as None is inferred from the size of the group:
        with both missing, n_rows is the integer part of the square root
        of the group size; a single missing dimension is the group size
        floor-divided by the given one.

        height and element_alignment_corner override whatever config
        holds under those names before config is passed on to Matrix.

        Raises ValueError if a given dimension is not positive, if the
        resulting grid would have no rows or no columns (an empty group,
        for instance), or if the group has fewer than n_rows * n_cols
        members.
        """
        n_mobs = len(group)

        # Reject explicit non-positive dimensions up front, rather than letting
        # them surface later as a ZeroDivisionError or an empty-grid failure
        for label, dim in (("n_rows", n_rows), ("n_cols", n_cols)):
            if dim is not None and dim < 1:
                raise ValueError(f"MobjectMatrix {label} must be a positive integer, got {dim}")

        # Have fallback defaults of n_rows and n_cols
        if n_rows is None:
            n_rows = int(np.sqrt(n_mobs)) if n_cols is None else n_mobs // n_cols
        if n_cols is None:
            # n_rows can only be 0 here when it was inferred from an empty
            # group; skip the division so the check below reports it instead
            n_cols = n_mobs // n_rows if n_rows > 0 else 0

        if n_rows < 1 or n_cols < 1:
            raise ValueError(
                f"Cannot arrange {n_mobs} mobjects into a {n_rows} x {n_cols} MobjectMatrix"
            )
        if n_mobs < n_rows * n_cols:
            raise ValueError("Input to MobjectMatrix must have at least n_rows * n_cols entries")

        mob_matrix = [
            [group[n * n_cols + k] for k in range(n_cols)]
            for n in range(n_rows)
        ]
        config.update(
            height=height,
            element_alignment_corner=element_alignment_corner,
        )
        super().__init__(mob_matrix, **config)

    def element_to_mobject(self, element: VMobject, **config) -> VMobject:
        """Entries are already mobjects, so they are used as they are."""
        return element