Merge pull request #1938 from 3b1b/video-work

Small refactors and bug fixes
This commit is contained in:
Grant Sanderson 2022-12-19 17:08:08 -08:00 committed by GitHub
commit 78fd6d3f35
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 58 additions and 47 deletions

View file

@ -405,6 +405,7 @@ class AnnularSector(VMobject):
fill_color=fill_color, fill_color=fill_color,
fill_opacity=fill_opacity, fill_opacity=fill_opacity,
stroke_width=stroke_width, stroke_width=stroke_width,
**kwargs,
) )
# Initialize points # Initialize points
@ -417,11 +418,10 @@ class AnnularSector(VMobject):
) )
for radius in (inner_radius, outer_radius) for radius in (inner_radius, outer_radius)
] ]
outer_arc.reverse_points() self.append_points(inner_arc.get_points()[::-1]) # Reverse
self.append_points(inner_arc.get_points())
self.add_line_to(outer_arc.get_points()[0]) self.add_line_to(outer_arc.get_points()[0])
self.append_points(outer_arc.get_points()) self.append_points(outer_arc.get_points())
self.add_line_to(inner_arc.get_points()[0]) self.add_line_to(inner_arc.get_points()[-1])
class Sector(AnnularSector): class Sector(AnnularSector):
@ -454,6 +454,7 @@ class Annulus(VMobject):
fill_color=fill_color, fill_color=fill_color,
fill_opacity=fill_opacity, fill_opacity=fill_opacity,
stroke_width=stroke_width, stroke_width=stroke_width,
**kwargs,
) )
self.radius = outer_radius self.radius = outer_radius

View file

@ -52,6 +52,7 @@ class Brace(SingleStringTex):
self.shift(left - self.get_corner(UL) + buff * DOWN) self.shift(left - self.get_corner(UL) + buff * DOWN)
for mob in mobject, self: for mob in mobject, self:
mob.rotate(angle, about_point=ORIGIN) mob.rotate(angle, about_point=ORIGIN)
self.refresh_unit_normal()
def set_initial_width(self, width: float): def set_initial_width(self, width: float):
width_diff = width - self.get_width() width_diff = width - self.get_width()

View file

@ -3,7 +3,6 @@ from __future__ import annotations
import re import re
from manimlib.mobject.svg.string_mobject import StringMobject from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.utils.tex_file_writing import display_during_execution
from manimlib.utils.tex_file_writing import tex_content_to_svg_file from manimlib.utils.tex_file_writing import tex_content_to_svg_file
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -79,10 +78,9 @@ class MTex(StringMobject):
) )
def get_file_path_by_content(self, content: str) -> str: def get_file_path_by_content(self, content: str) -> str:
with display_during_execution(f"Writing \"{self.tex_string}\""): file_path = tex_content_to_svg_file(
file_path = tex_content_to_svg_file( content, self.template, self.additional_preamble, self.tex_string
content, self.template, self.additional_preamble )
)
return file_path return file_path
# Parsing # Parsing

View file

@ -51,10 +51,11 @@ class StringMobject(SVGMobject, ABC):
so that each submobject of the original `SVGMobject` will be labelled so that each submobject of the original `SVGMobject` will be labelled
by the color of its paired submobject from the additional `SVGMobject`. by the color of its paired submobject from the additional `SVGMobject`.
""" """
height = None
def __init__( def __init__(
self, self,
string: str, string: str,
height: float | None = None,
fill_color: ManimColor = WHITE, fill_color: ManimColor = WHITE,
stroke_color: ManimColor = WHITE, stroke_color: ManimColor = WHITE,
stroke_width: float = 0, stroke_width: float = 0,
@ -68,17 +69,16 @@ class StringMobject(SVGMobject, ABC):
**kwargs **kwargs
): ):
self.string = string self.string = string
self.path_string_config = dict(path_string_config)
self.base_color = base_color or WHITE self.base_color = base_color or WHITE
self.isolate = isolate self.isolate = isolate
self.protect = protect self.protect = protect
self.parse() self.parse()
super().__init__( super().__init__(
height=height,
stroke_color=stroke_color, stroke_color=stroke_color,
fill_color=fill_color, fill_color=fill_color,
stroke_width=stroke_width, stroke_width=stroke_width,
path_string_config=path_string_config,
**kwargs **kwargs
) )
self.labels = [submob.label for submob in self.submobjects] self.labels = [submob.label for submob in self.submobjects]
@ -105,7 +105,7 @@ class StringMobject(SVGMobject, ABC):
labelled_svg = SVGMobject(file_path) labelled_svg = SVGMobject(file_path)
if len(self.submobjects) != len(labelled_svg.submobjects): if len(self.submobjects) != len(labelled_svg.submobjects):
log.warning( log.warning(
"Cannot align submobjects of the labelled svg " "Cannot align submobjects of the labelled svg " + \
"to the original svg. Skip the labelling process." "to the original svg. Skip the labelling process."
) )
for submob in self.submobjects: for submob in self.submobjects:
@ -126,7 +126,7 @@ class StringMobject(SVGMobject, ABC):
submob.label = label submob.label = label
if unrecognizable_colors: if unrecognizable_colors:
log.warning( log.warning(
"Unrecognizable color labels detected (%s). " "Unrecognizable color labels detected (%s). " + \
"The result could be unexpected.", "The result could be unexpected.",
", ".join( ", ".join(
self.int_to_hex(color) self.int_to_hex(color)
@ -144,11 +144,7 @@ class StringMobject(SVGMobject, ABC):
if not labelled_svg.submobjects: if not labelled_svg.submobjects:
return return
bb_0 = self.get_bounding_box() labelled_svg.replace(self)
bb_1 = labelled_svg.get_bounding_box()
scale_factor = abs((bb_0[2] - bb_0[0]) / (bb_1[2] - bb_1[0]))
labelled_svg.move_to(self).scale(scale_factor)
distance_matrix = cdist( distance_matrix = cdist(
[submob.get_center() for submob in self.submobjects], [submob.get_center() for submob in self.submobjects],
[submob.get_center() for submob in labelled_svg.submobjects] [submob.get_center() for submob in labelled_svg.submobjects]

View file

@ -5,6 +5,7 @@ from xml.etree import ElementTree as ET
import numpy as np import numpy as np
import svgelements as se import svgelements as se
import io
from manimlib.constants import RIGHT from manimlib.constants import RIGHT
from manimlib.logger import log from manimlib.logger import log
@ -35,12 +36,14 @@ def _convert_point_to_3d(x: float, y: float) -> np.ndarray:
class SVGMobject(VMobject): class SVGMobject(VMobject):
file_name: str = "" file_name: str = ""
height: float | None = 2.0
width: float | None = None
def __init__( def __init__(
self, self,
file_name: str = "", file_name: str = "",
should_center: bool = True, should_center: bool = True,
height: float | None = 2.0, height: float | None = None,
width: float | None = None, width: float | None = None,
# Style that overrides the original svg # Style that overrides the original svg
color: ManimColor = None, color: ManimColor = None,
@ -66,7 +69,6 @@ class SVGMobject(VMobject):
self.file_name = file_name or self.file_name self.file_name = file_name or self.file_name
self.svg_default = dict(svg_default) self.svg_default = dict(svg_default)
self.path_string_config = dict(path_string_config) self.path_string_config = dict(path_string_config)
self.height = height
super().__init__(**kwargs ) super().__init__(**kwargs )
self.init_svg_mobject() self.init_svg_mobject()
@ -82,6 +84,9 @@ class SVGMobject(VMobject):
) )
# Initialize position # Initialize position
height = height or self.height
width = width or self.width
if should_center: if should_center:
self.center() self.center()
if height is not None: if height is not None:
@ -114,13 +119,13 @@ class SVGMobject(VMobject):
file_path = self.get_file_path() file_path = self.get_file_path()
element_tree = ET.parse(file_path) element_tree = ET.parse(file_path)
new_tree = self.modify_xml_tree(element_tree) new_tree = self.modify_xml_tree(element_tree)
# Create a temporary svg file to dump modified svg to be parsed
root, ext = os.path.splitext(file_path)
modified_file_path = root + "_" + ext
new_tree.write(modified_file_path)
svg = se.SVG.parse(modified_file_path) # New svg based on tree contents
os.remove(modified_file_path) data_stream = io.BytesIO()
new_tree.write(data_stream)
data_stream.seek(0)
svg = se.SVG.parse(data_stream)
data_stream.close()
mobjects = self.get_mobjects_from(svg) mobjects = self.get_mobjects_from(svg)
self.add(*mobjects) self.add(*mobjects)

View file

@ -11,7 +11,6 @@ from manimlib.constants import MED_LARGE_BUFF, SMALL_BUFF
from manimlib.mobject.geometry import Line from manimlib.mobject.geometry import Line
from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.tex_file_writing import display_during_execution
from manimlib.utils.tex_file_writing import tex_content_to_svg_file from manimlib.utils.tex_file_writing import tex_content_to_svg_file
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -25,6 +24,8 @@ SCALE_FACTOR_PER_FONT_POINT = 0.001
class SingleStringTex(SVGMobject): class SingleStringTex(SVGMobject):
height: float | None = None
def __init__( def __init__(
self, self,
tex_string: str, tex_string: str,
@ -60,6 +61,7 @@ class SingleStringTex(SVGMobject):
fill_color=fill_color, fill_color=fill_color,
fill_opacity=fill_opacity, fill_opacity=fill_opacity,
stroke_width=stroke_width, stroke_width=stroke_width,
path_string_config=path_string_config,
**kwargs **kwargs
) )
@ -83,10 +85,9 @@ class SingleStringTex(SVGMobject):
def get_file_path(self) -> str: def get_file_path(self) -> str:
content = self.get_tex_file_body(self.tex_string) content = self.get_tex_file_body(self.tex_string)
with display_during_execution(f"Writing \"{self.tex_string}\""): file_path = tex_content_to_svg_file(
file_path = tex_content_to_svg_file( content, self.template, self.additional_preamble, self.tex_string
content, self.template, self.additional_preamble )
)
return file_path return file_path
def get_tex_file_body(self, tex_string: str) -> str: def get_tex_file_body(self, tex_string: str) -> str:
@ -246,7 +247,7 @@ class Tex(SingleStringTex):
tex_string = tex_string.strip() tex_string = tex_string.strip()
if len(tex_string) == 0: if len(tex_string) == 0:
continue continue
sub_tex_mob = SingleStringTex(tex_string) sub_tex_mob = SingleStringTex(tex_string, math_mode=self.math_mode)
num_submobs = len(sub_tex_mob) num_submobs = len(sub_tex_mob)
if num_submobs == 0: if num_submobs == 0:
continue continue

View file

@ -100,6 +100,7 @@ class VMobject(Mobject):
self.flat_stroke = flat_stroke self.flat_stroke = flat_stroke
self.needs_new_triangulation = True self.needs_new_triangulation = True
self.needs_new_unit_normal = True
self.triangulation = np.zeros(0, dtype='i4') self.triangulation = np.zeros(0, dtype='i4')
super().__init__(**kwargs) super().__init__(**kwargs)
@ -114,7 +115,7 @@ class VMobject(Mobject):
"fill_rgba": np.zeros((1, 4)), "fill_rgba": np.zeros((1, 4)),
"stroke_rgba": np.zeros((1, 4)), "stroke_rgba": np.zeros((1, 4)),
"stroke_width": np.zeros((1, 1)), "stroke_width": np.zeros((1, 1)),
"unit_normal": np.zeros((1, 3)) "unit_normal": np.array(OUT, ndmin=2),
}) })
# These are here just to make type checkers happy # These are here just to make type checkers happy
@ -771,7 +772,7 @@ class VMobject(Mobject):
]) ])
def get_unit_normal(self, recompute: bool = False) -> Vect3: def get_unit_normal(self, recompute: bool = False) -> Vect3:
if not recompute: if not self.needs_new_unit_normal and not recompute:
return self.data["unit_normal"][0] return self.data["unit_normal"][0]
if self.get_num_points() < 3: if self.get_num_points() < 3:
@ -788,17 +789,12 @@ class VMobject(Mobject):
points[2] - points[1], points[2] - points[1],
) )
self.data["unit_normal"][:] = normal self.data["unit_normal"][:] = normal
self.needs_new_unit_normal = False
return normal return normal
def refresh_unit_normal(self): def refresh_unit_normal(self):
for mob in self.get_family(): for mob in self.get_family():
mob.get_unit_normal(recompute=True) mob.needs_new_unit_normal = True
return self
def reverse_points(self):
super().reverse_points()
self.refresh_unit_normal()
self.refresh_triangulation()
return self return self
# Alignment # Alignment
@ -1030,6 +1026,16 @@ class VMobject(Mobject):
super().set_points(points) super().set_points(points)
return self return self
@triggers_refreshed_triangulation
def append_points(self, points: Vect3Array):
super().append_points(points)
return self
@triggers_refreshed_triangulation
def reverse_points(self):
super().reverse_points()
return self
@triggers_refreshed_triangulation @triggers_refreshed_triangulation
def set_data(self, data: dict): def set_data(self, data: dict):
super().set_data(data) super().set_data(data)

View file

@ -52,7 +52,8 @@ def get_tex_config() -> dict[str, str]:
def tex_content_to_svg_file( def tex_content_to_svg_file(
content: str, template: str, additional_preamble: str content: str, template: str, additional_preamble: str,
short_tex: str
) -> str: ) -> str:
tex_config = get_tex_config() tex_config = get_tex_config()
if not template or template == tex_config["template"]: if not template or template == tex_config["template"]:
@ -78,7 +79,8 @@ def tex_content_to_svg_file(
) )
if not os.path.exists(svg_file): if not os.path.exists(svg_file):
# If svg doesn't exist, create it # If svg doesn't exist, create it
create_tex_svg(full_tex, svg_file, compiler) with display_during_execution("Writing " + short_tex):
create_tex_svg(full_tex, svg_file, compiler)
return svg_file return svg_file
@ -112,14 +114,15 @@ def create_tex_svg(full_tex: str, svg_file: str, compiler: str) -> None:
log.error( log.error(
"LaTeX Error! Not a worry, it happens to the best of us." "LaTeX Error! Not a worry, it happens to the best of us."
) )
error_str = ""
with open(root + ".log", "r", encoding="utf-8") as log_file: with open(root + ".log", "r", encoding="utf-8") as log_file:
error_match_obj = re.search(r"(?<=\n! ).*", log_file.read()) error_match_obj = re.search(r"(?<=\n! ).*\n.*\n", log_file.read())
if error_match_obj: if error_match_obj:
error_str = error_match_obj.group()
log.debug( log.debug(
"The error could be: `%s`", f"The error could be:\n`{error_str}`",
error_match_obj.group()
) )
raise LatexError() raise LatexError(error_str)
# dvi to svg # dvi to svg
os.system(" ".join(( os.system(" ".join((