Merge pull request #1938 from 3b1b/video-work

Small refactors and bug fixes
This commit is contained in:
Grant Sanderson 2022-12-19 17:08:08 -08:00 committed by GitHub
commit 78fd6d3f35
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 58 additions and 47 deletions

View file

@ -405,6 +405,7 @@ class AnnularSector(VMobject):
fill_color=fill_color,
fill_opacity=fill_opacity,
stroke_width=stroke_width,
**kwargs,
)
# Initialize points
@ -417,11 +418,10 @@ class AnnularSector(VMobject):
)
for radius in (inner_radius, outer_radius)
]
outer_arc.reverse_points()
self.append_points(inner_arc.get_points())
self.append_points(inner_arc.get_points()[::-1]) # Reverse
self.add_line_to(outer_arc.get_points()[0])
self.append_points(outer_arc.get_points())
self.add_line_to(inner_arc.get_points()[0])
self.add_line_to(inner_arc.get_points()[-1])
class Sector(AnnularSector):
@ -454,6 +454,7 @@ class Annulus(VMobject):
fill_color=fill_color,
fill_opacity=fill_opacity,
stroke_width=stroke_width,
**kwargs,
)
self.radius = outer_radius

View file

@ -52,6 +52,7 @@ class Brace(SingleStringTex):
self.shift(left - self.get_corner(UL) + buff * DOWN)
for mob in mobject, self:
mob.rotate(angle, about_point=ORIGIN)
self.refresh_unit_normal()
def set_initial_width(self, width: float):
width_diff = width - self.get_width()

View file

@ -3,7 +3,6 @@ from __future__ import annotations
import re
from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.utils.tex_file_writing import display_during_execution
from manimlib.utils.tex_file_writing import tex_content_to_svg_file
from typing import TYPE_CHECKING
@ -79,10 +78,9 @@ class MTex(StringMobject):
)
def get_file_path_by_content(self, content: str) -> str:
with display_during_execution(f"Writing \"{self.tex_string}\""):
file_path = tex_content_to_svg_file(
content, self.template, self.additional_preamble
)
file_path = tex_content_to_svg_file(
content, self.template, self.additional_preamble, self.tex_string
)
return file_path
# Parsing

View file

@ -51,10 +51,11 @@ class StringMobject(SVGMobject, ABC):
so that each submobject of the original `SVGMobject` will be labelled
by the color of its paired submobject from the additional `SVGMobject`.
"""
height = None
def __init__(
self,
string: str,
height: float | None = None,
fill_color: ManimColor = WHITE,
stroke_color: ManimColor = WHITE,
stroke_width: float = 0,
@ -68,17 +69,16 @@ class StringMobject(SVGMobject, ABC):
**kwargs
):
self.string = string
self.path_string_config = dict(path_string_config)
self.base_color = base_color or WHITE
self.isolate = isolate
self.protect = protect
self.parse()
super().__init__(
height=height,
stroke_color=stroke_color,
fill_color=fill_color,
stroke_width=stroke_width,
path_string_config=path_string_config,
**kwargs
)
self.labels = [submob.label for submob in self.submobjects]
@ -105,7 +105,7 @@ class StringMobject(SVGMobject, ABC):
labelled_svg = SVGMobject(file_path)
if len(self.submobjects) != len(labelled_svg.submobjects):
log.warning(
"Cannot align submobjects of the labelled svg "
"Cannot align submobjects of the labelled svg " + \
"to the original svg. Skip the labelling process."
)
for submob in self.submobjects:
@ -126,7 +126,7 @@ class StringMobject(SVGMobject, ABC):
submob.label = label
if unrecognizable_colors:
log.warning(
"Unrecognizable color labels detected (%s). "
"Unrecognizable color labels detected (%s). " + \
"The result could be unexpected.",
", ".join(
self.int_to_hex(color)
@ -144,11 +144,7 @@ class StringMobject(SVGMobject, ABC):
if not labelled_svg.submobjects:
return
bb_0 = self.get_bounding_box()
bb_1 = labelled_svg.get_bounding_box()
scale_factor = abs((bb_0[2] - bb_0[0]) / (bb_1[2] - bb_1[0]))
labelled_svg.move_to(self).scale(scale_factor)
labelled_svg.replace(self)
distance_matrix = cdist(
[submob.get_center() for submob in self.submobjects],
[submob.get_center() for submob in labelled_svg.submobjects]

View file

@ -5,6 +5,7 @@ from xml.etree import ElementTree as ET
import numpy as np
import svgelements as se
import io
from manimlib.constants import RIGHT
from manimlib.logger import log
@ -35,12 +36,14 @@ def _convert_point_to_3d(x: float, y: float) -> np.ndarray:
class SVGMobject(VMobject):
file_name: str = ""
height: float | None = 2.0
width: float | None = None
def __init__(
self,
file_name: str = "",
should_center: bool = True,
height: float | None = 2.0,
height: float | None = None,
width: float | None = None,
# Style that overrides the original svg
color: ManimColor = None,
@ -66,7 +69,6 @@ class SVGMobject(VMobject):
self.file_name = file_name or self.file_name
self.svg_default = dict(svg_default)
self.path_string_config = dict(path_string_config)
self.height = height
super().__init__(**kwargs )
self.init_svg_mobject()
@ -82,6 +84,9 @@ class SVGMobject(VMobject):
)
# Initialize position
height = height or self.height
width = width or self.width
if should_center:
self.center()
if height is not None:
@ -114,13 +119,13 @@ class SVGMobject(VMobject):
file_path = self.get_file_path()
element_tree = ET.parse(file_path)
new_tree = self.modify_xml_tree(element_tree)
# Create a temporary svg file to dump modified svg to be parsed
root, ext = os.path.splitext(file_path)
modified_file_path = root + "_" + ext
new_tree.write(modified_file_path)
svg = se.SVG.parse(modified_file_path)
os.remove(modified_file_path)
# New svg based on tree contents
data_stream = io.BytesIO()
new_tree.write(data_stream)
data_stream.seek(0)
svg = se.SVG.parse(data_stream)
data_stream.close()
mobjects = self.get_mobjects_from(svg)
self.add(*mobjects)

View file

@ -11,7 +11,6 @@ from manimlib.constants import MED_LARGE_BUFF, SMALL_BUFF
from manimlib.mobject.geometry import Line
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.tex_file_writing import display_during_execution
from manimlib.utils.tex_file_writing import tex_content_to_svg_file
from typing import TYPE_CHECKING
@ -25,6 +24,8 @@ SCALE_FACTOR_PER_FONT_POINT = 0.001
class SingleStringTex(SVGMobject):
height: float | None = None
def __init__(
self,
tex_string: str,
@ -60,6 +61,7 @@ class SingleStringTex(SVGMobject):
fill_color=fill_color,
fill_opacity=fill_opacity,
stroke_width=stroke_width,
path_string_config=path_string_config,
**kwargs
)
@ -83,10 +85,9 @@ class SingleStringTex(SVGMobject):
def get_file_path(self) -> str:
content = self.get_tex_file_body(self.tex_string)
with display_during_execution(f"Writing \"{self.tex_string}\""):
file_path = tex_content_to_svg_file(
content, self.template, self.additional_preamble
)
file_path = tex_content_to_svg_file(
content, self.template, self.additional_preamble, self.tex_string
)
return file_path
def get_tex_file_body(self, tex_string: str) -> str:
@ -246,7 +247,7 @@ class Tex(SingleStringTex):
tex_string = tex_string.strip()
if len(tex_string) == 0:
continue
sub_tex_mob = SingleStringTex(tex_string)
sub_tex_mob = SingleStringTex(tex_string, math_mode=self.math_mode)
num_submobs = len(sub_tex_mob)
if num_submobs == 0:
continue

View file

@ -100,6 +100,7 @@ class VMobject(Mobject):
self.flat_stroke = flat_stroke
self.needs_new_triangulation = True
self.needs_new_unit_normal = True
self.triangulation = np.zeros(0, dtype='i4')
super().__init__(**kwargs)
@ -114,7 +115,7 @@ class VMobject(Mobject):
"fill_rgba": np.zeros((1, 4)),
"stroke_rgba": np.zeros((1, 4)),
"stroke_width": np.zeros((1, 1)),
"unit_normal": np.zeros((1, 3))
"unit_normal": np.array(OUT, ndmin=2),
})
# These are here just to make type checkers happy
@ -771,7 +772,7 @@ class VMobject(Mobject):
])
def get_unit_normal(self, recompute: bool = False) -> Vect3:
if not recompute:
if not self.needs_new_unit_normal and not recompute:
return self.data["unit_normal"][0]
if self.get_num_points() < 3:
@ -788,17 +789,12 @@ class VMobject(Mobject):
points[2] - points[1],
)
self.data["unit_normal"][:] = normal
self.needs_new_unit_normal = False
return normal
def refresh_unit_normal(self):
for mob in self.get_family():
mob.get_unit_normal(recompute=True)
return self
def reverse_points(self):
super().reverse_points()
self.refresh_unit_normal()
self.refresh_triangulation()
mob.needs_new_unit_normal = True
return self
# Alignment
@ -1030,6 +1026,16 @@ class VMobject(Mobject):
super().set_points(points)
return self
@triggers_refreshed_triangulation
def append_points(self, points: Vect3Array):
super().append_points(points)
return self
@triggers_refreshed_triangulation
def reverse_points(self):
super().reverse_points()
return self
@triggers_refreshed_triangulation
def set_data(self, data: dict):
super().set_data(data)

View file

@ -52,7 +52,8 @@ def get_tex_config() -> dict[str, str]:
def tex_content_to_svg_file(
content: str, template: str, additional_preamble: str
content: str, template: str, additional_preamble: str,
short_tex: str
) -> str:
tex_config = get_tex_config()
if not template or template == tex_config["template"]:
@ -78,7 +79,8 @@ def tex_content_to_svg_file(
)
if not os.path.exists(svg_file):
# If svg doesn't exist, create it
create_tex_svg(full_tex, svg_file, compiler)
with display_during_execution("Writing " + short_tex):
create_tex_svg(full_tex, svg_file, compiler)
return svg_file
@ -112,14 +114,15 @@ def create_tex_svg(full_tex: str, svg_file: str, compiler: str) -> None:
log.error(
"LaTeX Error! Not a worry, it happens to the best of us."
)
error_str = ""
with open(root + ".log", "r", encoding="utf-8") as log_file:
error_match_obj = re.search(r"(?<=\n! ).*", log_file.read())
error_match_obj = re.search(r"(?<=\n! ).*\n.*\n", log_file.read())
if error_match_obj:
error_str = error_match_obj.group()
log.debug(
"The error could be: `%s`",
error_match_obj.group()
f"The error could be:\n`{error_str}`",
)
raise LatexError()
raise LatexError(error_str)
# dvi to svg
os.system(" ".join((