Have matrix keep track of elements and ellipses as lists instead of VGroups

This commit is contained in:
Grant Sanderson 2024-02-13 14:52:16 -06:00
parent ed3ac74d67
commit 7b577e9fc1

View file

@ -47,23 +47,24 @@ class Matrix(VMobject):
matrix, v_buff, h_buff, element_alignment_corner,
**element_config
)
# Create helpful groups for the elements
n_cols = len(self.mob_matrix[0])
self.elements = VGroup(*it.chain(*self.mob_matrix))
self.elements = [elem for row in self.mob_matrix for elem in row]
self.columns = VGroup(*(
VGroup(*(row[i] for row in self.mob_matrix))
for i in range(n_cols)
))
self.rows = VGroup(*(VGroup(*row) for row in self.mob_matrix))
self.ellipses = VGroup()
if height is not None:
self.rows.set_height(height - 2 * bracket_v_buff)
self.brackets = self.create_brackets(self.rows, bracket_v_buff, bracket_h_buff)
self.ellipses = []
# Add elements and brackets
self.elements.center()
self.add(self.elements)
if height is not None:
self.set_height(height - 2 * bracket_v_buff)
self.add_brackets(bracket_v_buff, bracket_h_buff)
self.add(self.ellipses)
self.add(*self.elements)
self.add(*self.brackets)
self.center()
# Potentially add ellipses
self.swap_entries_for_ellipses(
@ -71,6 +72,17 @@ class Matrix(VMobject):
ellipses_col,
)
def copy(self, deep: bool = False):
result = super().copy(deep)
self_family = self.get_family()
copy_family = result.get_family()
for attr in ["elements", "ellipses"]:
setattr(result, attr, [
copy_family[self_family.index(mob)]
for mob in getattr(self, attr)
])
return result
def create_mobject_matrix(
self,
matrix: GenericMatrixType,
@ -106,22 +118,18 @@ class Matrix(VMobject):
else:
return Tex(str(element), **config)
def add_brackets(self, v_buff: float, h_buff: float) -> Self:
height = len(self.mob_matrix)
def create_brackets(self, rows, v_buff: float, h_buff: float) -> VGroup:
brackets = Tex("".join((
R"\left[\begin{array}{c}",
*height * [R"\quad \\"],
*len(rows) * [R"\quad \\"],
R"\end{array}\right]",
)))
brackets.set_height(self.get_height() + v_buff)
brackets.set_height(rows.get_height() + v_buff)
l_bracket = brackets[:len(brackets) // 2]
r_bracket = brackets[len(brackets) // 2:]
l_bracket.next_to(self, LEFT, h_buff)
r_bracket.next_to(self, RIGHT, h_buff)
brackets.set_submobjects([l_bracket, r_bracket])
self.brackets = VGroup(l_bracket, r_bracket)
self.add(*brackets)
return self
l_bracket.next_to(rows, LEFT, h_buff)
r_bracket.next_to(rows, RIGHT, h_buff)
return VGroup(l_bracket, r_bracket)
def get_column(self, index: int):
if not 0 <= index < len(self.columns):
@ -150,6 +158,14 @@ class Matrix(VMobject):
mob.add_background_rectangle()
return self
def swap_entry_for_dots(self, entry, dots):
dots.move_to(entry)
entry.become(dots)
if entry in self.elements:
self.elements.remove(entry)
if entry not in self.ellipses:
self.ellipses.append(entry)
def swap_entries_for_ellipses(
self,
row_index: Optional[int] = None,
@ -169,24 +185,18 @@ class Matrix(VMobject):
use_vdots = row_index is not None and -len(rows) <= row_index < len(rows)
use_hdots = col_index is not None and -len(cols) <= col_index < len(cols)
def swap_entry_for_dots(entry, dots):
dots.move_to(entry)
entry.become(dots)
self.elements.remove(entry)
self.ellipses.add(entry)
if use_vdots:
for column in cols:
# Add vdots
dots = Tex(R"\vdots")
dots.set_height(vdots_height)
swap_entry_for_dots(column[row_index], dots)
self.swap_entry_for_dots(column[row_index], dots)
if use_hdots:
for row in rows:
# Add hdots
dots = Tex(R"\hdots")
dots.set_width(hdots_width)
swap_entry_for_dots(row[col_index], dots)
self.swap_entry_for_dots(row[col_index], dots)
if use_vdots and use_hdots:
rows[row_index][col_index].rotate(-45 * DEGREES)
return self
@ -195,10 +205,13 @@ class Matrix(VMobject):
return self.mob_matrix
def get_entries(self) -> VGroup:
return self.elements
return VGroup(*self.elements)
def get_brackets(self) -> VGroup:
return self.brackets
return VGroup(*self.brackets)
def get_ellipses(self) -> VGroup:
return VGroup(*self.ellipses)
class DecimalMatrix(Matrix):