mirror of
https://github.com/3b1b/manim.git
synced 2025-04-13 09:47:07 +00:00
295 lines
9.2 KiB
Python
295 lines
9.2 KiB
Python
from __future__ import annotations
|
|
|
|
import itertools as it
|
|
|
|
import numpy as np
|
|
|
|
from manimlib.constants import DOWN, LEFT, RIGHT, ORIGIN
|
|
from manimlib.constants import DEGREES
|
|
from manimlib.mobject.numbers import DecimalNumber
|
|
from manimlib.mobject.svg.tex_mobject import Tex
|
|
from manimlib.mobject.types.vectorized_mobject import VGroup
|
|
from manimlib.mobject.types.vectorized_mobject import VMobject
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
    from typing import Optional, Sequence, Tuple, Union

    from manimlib.typing import ManimColor, Self, Vect3, VectNArray

    # Annotation-only aliases; this branch never runs at runtime.

    # Rows of strings, or a numpy array of strings.
    StringMatrixType = Union[
        Sequence[Sequence[str]],
        np.ndarray[int, np.dtype[np.str_]],
    ]
    # Rows of floats, or a VectNArray.
    FloatMatrixType = Union[
        Sequence[Sequence[float]],
        VectNArray,
    ]
    # Rows of ready-made mobjects.
    VMobjectMatrixType = Sequence[Sequence[VMobject]]
    # Everything Matrix.__init__ is annotated to accept as `matrix`.
    GenericMatrixType = Union[
        FloatMatrixType,
        StringMatrixType,
        VMobjectMatrixType,
    ]
|
|
|
|
|
|
class Matrix(VMobject):
    """
    A rectangular grid of entries framed by a pair of square brackets.

    Entries may be mobjects (used as-is), floats or complex numbers
    (rendered with DecimalNumber), or anything else (rendered with
    Tex(str(entry))).
    """

    def __init__(
        self,
        matrix: GenericMatrixType,
        v_buff: float = 0.5,
        h_buff: float = 0.5,
        bracket_h_buff: float = 0.2,
        bracket_v_buff: float = 0.25,
        height: float | None = None,
        element_config: dict | None = None,
        element_alignment_corner: Vect3 = DOWN,
        ellipses_row: Optional[int] = None,
        ellipses_col: Optional[int] = None,
    ):
        """
        Matrix can either include numbers, tex_strings,
        or mobjects.

        matrix: rows of entries. Columns are built from the length of the
            first row, so every row is expected to be that long.
        v_buff / h_buff: extra space between rows / columns, on top of the
            tallest / widest entry.
        bracket_h_buff: horizontal gap between the entries and each bracket.
        bracket_v_buff: added to the rows' height when sizing the brackets.
        height: if given, the rows are rescaled to
            height - 2 * bracket_v_buff before the brackets are built.
        element_config: keyword arguments forwarded to element_to_mobject
            for every entry. None means no extra arguments.
        element_alignment_corner: passed to each entry's move_to when it is
            placed on its grid point.
        ellipses_row / ellipses_col: optional (possibly negative) row /
            column index whose entries are replaced by vertical / horizontal
            dots.

        Raises ValueError if matrix has no rows or its first row is empty.
        """
        super().__init__()
        # None default avoids a shared mutable default dict.
        if element_config is None:
            element_config = {}

        self.mob_matrix = self.create_mobject_matrix(
            matrix, v_buff, h_buff, element_alignment_corner,
            **element_config
        )

        # Create helpful groups for the elements
        n_cols = len(self.mob_matrix[0])
        self.elements = [elem for row in self.mob_matrix for elem in row]
        self.columns = VGroup(*(
            VGroup(*(row[i] for row in self.mob_matrix))
            for i in range(n_cols)
        ))
        self.rows = VGroup(*(VGroup(*row) for row in self.mob_matrix))
        if height is not None:
            # NOTE(review): create_brackets makes the brackets
            # rows_height + bracket_v_buff tall. The overall height therefore
            # appears to be height - bracket_v_buff; confirm this is intended.
            self.rows.set_height(height - 2 * bracket_v_buff)
        self.brackets = self.create_brackets(self.rows, bracket_v_buff, bracket_h_buff)
        self.ellipses = []

        # Add elements and brackets
        self.add(*self.elements)
        self.add(*self.brackets)
        self.center()

        # Potentially add ellipses
        self.swap_entries_for_ellipses(
            ellipses_row,
            ellipses_col,
        )

    def copy(self, deep: bool = False) -> Self:
        """
        Return a copy whose bookkeeping attributes refer to the copy's own
        submobjects.

        The entries and brackets in elements, ellipses, mob_matrix, rows,
        columns and brackets are mapped to their counterparts in the copy.
        The mapping pairs positions in the two family lists. Before this
        change, mob_matrix, rows, columns and brackets were left referring
        to the original's mobjects.
        """
        result = super().copy(deep)

        # Pair every mobject in this family with the mobject at the same
        # position in the copy's family. setdefault keeps the first
        # occurrence, as list.index did, but each lookup is O(1) instead
        # of a linear scan.
        counterparts = {}
        for mob, mob_copy in zip(self.get_family(), result.get_family()):
            counterparts.setdefault(id(mob), mob_copy)

        def counterpart(mob):
            # A mobject no longer in the family (e.g. removed by the caller)
            # gets one independent copy. It is memoised so that every
            # attribute refers to the same replacement.
            key = id(mob)
            if key not in counterparts:
                counterparts[key] = mob.copy()
            return counterparts[key]

        def regroup(group):
            # Rebuild a VGroup of VGroups (rows or columns) from counterparts.
            return VGroup(*(VGroup(*map(counterpart, sub)) for sub in group))

        result.elements = [counterpart(mob) for mob in self.elements]
        result.ellipses = [counterpart(mob) for mob in self.ellipses]
        result.mob_matrix = [
            [counterpart(mob) for mob in row]
            for row in self.mob_matrix
        ]
        result.rows = regroup(self.rows)
        result.columns = regroup(self.columns)
        result.brackets = VGroup(*map(counterpart, self.brackets))
        return result

    def create_mobject_matrix(
        self,
        matrix: GenericMatrixType,
        v_buff: float,
        h_buff: float,
        aligned_corner: Vect3,
        **element_config
    ) -> VMobjectMatrixType:
        """
        Creates and organizes the matrix of mobjects.

        Each entry is converted with element_to_mobject(entry, **element_config).
        Entries are then placed on a uniform grid, with rows stepping
        downward and columns to the right. The cell size is the largest
        entry width plus h_buff by the largest entry height plus v_buff.
        The grid point for row i, column j is i * y_step + j * x_step.

        Raises ValueError if there are no rows or the first row is empty.
        """
        mob_matrix = [
            [self.element_to_mobject(element, **element_config) for element in row]
            for row in matrix
        ]
        # Fail with a clear message instead of an opaque "max() arg is an
        # empty sequence" or a later IndexError on the first row.
        if not mob_matrix or not mob_matrix[0]:
            raise ValueError("Matrix must have at least one row and one column")

        entries = [elem for row in mob_matrix for elem in row]
        x_step = (max(elem.get_width() for elem in entries) + h_buff) * RIGHT
        y_step = (max(elem.get_height() for elem in entries) + v_buff) * DOWN
        for i, row in enumerate(mob_matrix):
            for j, elem in enumerate(row):
                elem.move_to(i * y_step + j * x_step, aligned_corner)
        return mob_matrix

    def element_to_mobject(self, element, **config) -> VMobject:
        """
        Convert one raw entry into a mobject.

        VMobjects are returned unchanged. Floats and complex numbers become
        DecimalNumber(element, **config). Everything else, including ints,
        becomes Tex(str(element), **config).
        """
        if isinstance(element, VMobject):
            return element
        # A tuple of types works on every Python version. isinstance with
        # `float | complex` raises TypeError before Python 3.10.
        if isinstance(element, (float, complex)):
            return DecimalNumber(element, **config)
        return Tex(str(element), **config)

    def create_brackets(self, rows, v_buff: float, h_buff: float) -> VGroup:
        """
        Build the left and right brackets around `rows`.

        A single Tex bracket pair, with one empty array line per row, is
        scaled to rows.get_height() + v_buff. It is split in half by
        submobject count, and each half is placed h_buff away from the rows.
        Returns VGroup(left_bracket, right_bracket).
        """
        brackets = Tex("".join((
            R"\left[\begin{array}{c}",
            *len(rows) * [R"\quad \\"],
            R"\end{array}\right]",
        )))
        brackets.set_height(rows.get_height() + v_buff)
        l_bracket = brackets[:len(brackets) // 2]
        r_bracket = brackets[len(brackets) // 2:]
        l_bracket.next_to(rows, LEFT, h_buff)
        r_bracket.next_to(rows, RIGHT, h_buff)
        return VGroup(l_bracket, r_bracket)

    def get_column(self, index: int):
        """
        Return column `index`.

        Negative indices count from the end, consistent with
        swap_entries_for_ellipses. Raises IndexError when out of range.
        """
        n_cols = len(self.columns)
        if not -n_cols <= index < n_cols:
            raise IndexError(f"Index {index} out of bound for matrix with {len(self.columns)} columns")
        return self.columns[index]

    def get_row(self, index: int):
        """
        Return row `index`.

        Negative indices count from the end, consistent with
        swap_entries_for_ellipses. Raises IndexError when out of range.
        """
        n_rows = len(self.rows)
        if not -n_rows <= index < n_rows:
            raise IndexError(f"Index {index} out of bound for matrix with {len(self.rows)} rows")
        return self.rows[index]

    def get_columns(self) -> VGroup:
        """Return the VGroup of column VGroups."""
        return self.columns

    def get_rows(self) -> VGroup:
        """Return the VGroup of row VGroups."""
        return self.rows

    def set_column_colors(self, *colors: ManimColor) -> Self:
        """
        Color column k with colors[k].

        Extra colors, or extra columns, are ignored because zip stops at the
        shorter sequence.
        """
        columns = self.get_columns()
        for color, column in zip(colors, columns):
            column.set_color(color)
        return self

    def add_background_to_entries(self) -> Self:
        """Call add_background_rectangle on every non-ellipsis entry."""
        for mob in self.get_entries():
            mob.add_background_rectangle()
        return self

    def swap_entry_for_dots(self, entry, dots):
        """
        Turn `entry` into `dots` in place and record it as an ellipsis.

        The entry object itself is kept (it becomes the dots), so it stays
        in mob_matrix, rows and columns. It leaves self.elements and joins
        self.ellipses, at most once each.
        """
        dots.move_to(entry)
        entry.become(dots)
        if entry in self.elements:
            self.elements.remove(entry)
        if entry not in self.ellipses:
            self.ellipses.append(entry)

    def swap_entries_for_ellipses(
        self,
        row_index: Optional[int] = None,
        col_index: Optional[int] = None,
        height_ratio: float = 0.65,
        width_ratio: float = 0.4
    ):
        """
        Replace a whole row with vertical dots and/or a whole column with
        horizontal dots.

        An index that is None or out of range (negative indices allowed) is
        ignored. The dots are sized from the average row height
        (height_ratio) and average column width (width_ratio). If both a
        row and a column are given, their shared entry is rotated by -45
        degrees.
        """
        rows = self.get_rows()
        cols = self.get_columns()

        avg_row_height = rows.get_height() / len(rows)
        vdots_height = height_ratio * avg_row_height

        avg_col_width = cols.get_width() / len(cols)
        hdots_width = width_ratio * avg_col_width

        use_vdots = row_index is not None and -len(rows) <= row_index < len(rows)
        use_hdots = col_index is not None and -len(cols) <= col_index < len(cols)

        if use_vdots:
            for column in cols:
                # Add vdots
                dots = Tex(R"\vdots")
                dots.set_height(vdots_height)
                self.swap_entry_for_dots(column[row_index], dots)
        if use_hdots:
            for row in rows:
                # Add hdots
                dots = Tex(R"\hdots")
                dots.set_width(hdots_width)
                self.swap_entry_for_dots(row[col_index], dots)
        if use_vdots and use_hdots:
            rows[row_index][col_index].rotate(-45 * DEGREES)
        return self

    def get_mob_matrix(self) -> VMobjectMatrixType:
        """Return the nested list of entry mobjects, indexed [row][col]."""
        return self.mob_matrix

    def get_entries(self) -> VGroup:
        """Return the entries that have not been swapped for ellipses."""
        return VGroup(*self.elements)

    def get_brackets(self) -> VGroup:
        """Return a VGroup of the left and right brackets."""
        return VGroup(*self.brackets)

    def get_ellipses(self) -> VGroup:
        """Return the entries that were swapped for dots."""
        return VGroup(*self.ellipses)
|
|
|
|
|
|
class DecimalMatrix(Matrix):
    """Matrix whose entries are all rendered as DecimalNumber mobjects."""

    def __init__(
        self,
        matrix: FloatMatrixType,
        num_decimal_places: int = 2,
        decimal_config: dict | None = None,
        **config
    ):
        """
        matrix: rows of numbers. The raw input is kept as self.float_matrix.
        num_decimal_places: precision given to every DecimalNumber.
        decimal_config: extra DecimalNumber keyword arguments. A
            "num_decimal_places" key here overrides the argument above.
            Previously that combination raised TypeError.
        **config: forwarded to Matrix.__init__.
        """
        self.float_matrix = matrix
        # Merge instead of dict(num_decimal_places=..., **decimal_config),
        # which fails on a duplicate "num_decimal_places" key.
        element_config = {"num_decimal_places": num_decimal_places}
        if decimal_config:
            element_config.update(decimal_config)
        super().__init__(
            matrix,
            element_config=element_config,
            **config
        )

    def element_to_mobject(self, element, **decimal_config) -> DecimalNumber:
        """Render every entry as DecimalNumber(element, **decimal_config)."""
        return DecimalNumber(element, **decimal_config)
|
|
|
|
|
|
class IntegerMatrix(DecimalMatrix):
    """DecimalMatrix whose num_decimal_places defaults to 0."""

    def __init__(
        self,
        matrix: FloatMatrixType,
        num_decimal_places: int = 0,
        decimal_config: dict | None = None,
        **config
    ):
        """
        Same arguments as DecimalMatrix, except that num_decimal_places
        defaults to 0. decimal_config=None means no extra DecimalNumber
        arguments; the None default avoids a shared mutable default.
        """
        super().__init__(
            matrix,
            num_decimal_places=num_decimal_places,
            decimal_config={} if decimal_config is None else decimal_config,
            **config
        )
|
|
|
|
|
|
class TexMatrix(Matrix):
    """Matrix built from strings. Each entry goes through Matrix.element_to_mobject."""

    def __init__(
        self,
        matrix: StringMatrixType,
        tex_config: dict | None = None,
        **config,
    ):
        """
        matrix: rows of tex strings.
        tex_config: keyword arguments forwarded for every entry, as
            element_config. None means no extra arguments; the None default
            avoids a shared mutable default.
        **config: forwarded to Matrix.__init__.
        """
        super().__init__(
            matrix,
            element_config={} if tex_config is None else tex_config,
            **config
        )
|
|
|
|
|
|
class MobjectMatrix(Matrix):
    """Matrix whose entries are existing mobjects taken from a group, row by row."""

    def __init__(
        self,
        group: VGroup,
        n_rows: int | None = None,
        n_cols: int | None = None,
        height: float = 4.0,
        element_alignment_corner=ORIGIN,
        **config,
    ):
        """
        group: mobjects to lay out. Entry (n, k) is group[n * n_cols + k],
            so any mobjects past n_rows * n_cols are not placed.
        n_rows / n_cols: grid shape. When both are missing, n_rows is
            int(sqrt(len(group))). A missing dimension is len(group) // the
            other one.
        height, element_alignment_corner, **config: forwarded to Matrix.__init__.

        Raises ValueError if the group is empty, a given dimension is not
        positive, the derived grid would have zero rows or columns, or the
        group has fewer than n_rows * n_cols entries.
        """
        n_mobs = len(group)
        # Validate up front. These cases used to surface as ZeroDivisionError
        # or as an opaque failure on an empty grid inside Matrix.
        if n_mobs == 0:
            raise ValueError("Input to MobjectMatrix must contain at least one mobject")
        for name, value in (("n_rows", n_rows), ("n_cols", n_cols)):
            if value is not None and value < 1:
                raise ValueError(f"{name} must be a positive integer, got {value}")

        # Have fallback defaults of n_rows and n_cols
        if n_rows is None:
            n_rows = int(np.sqrt(n_mobs)) if n_cols is None else n_mobs // n_cols
        if n_cols is None:
            n_cols = n_mobs // n_rows

        # A given dimension larger than the group leaves the derived one at 0.
        if n_rows < 1 or n_cols < 1:
            raise ValueError(
                f"Input to MobjectMatrix has too few entries ({n_mobs}) "
                f"for a grid of {n_rows} rows and {n_cols} columns"
            )
        if n_mobs < n_rows * n_cols:
            raise ValueError("Input to MobjectMatrix must have at least n_rows * n_cols entries")

        mob_matrix = [
            [group[n * n_cols + k] for k in range(n_cols)]
            for n in range(n_rows)
        ]
        config.update(
            height=height,
            element_alignment_corner=element_alignment_corner,
        )
        super().__init__(mob_matrix, **config)

    def element_to_mobject(self, element: VMobject, **config) -> VMobject:
        """Entries are already mobjects; use them unchanged and ignore config."""
        return element
|