Clean up tex_file_writing

This commit is contained in:
Grant Sanderson 2024-12-04 20:30:53 -06:00
parent 129e512b0c
commit ac01b144e8
3 changed files with 82 additions and 76 deletions

View file

@ -7,7 +7,7 @@ import re
from manimlib.constants import BLACK, WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.tex_file_writing import tex_to_svg
from manimlib.utils.tex_file_writing import latex_to_svg
from typing import TYPE_CHECKING
@ -79,7 +79,7 @@ class SingleStringTex(SVGMobject):
def get_svg_string_by_content(self, content: str) -> str:
return get_cached_value(
key=hash_string(str((content, self.template, self.additional_preamble))),
value_func=lambda: tex_to_svg(content, self.template, self.additional_preamble),
value_func=lambda: latex_to_svg(content, self.template, self.additional_preamble),
message=f"Writing {self.tex_string}..."
)

View file

@ -9,7 +9,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.cache import get_cached_value
from manimlib.utils.color import color_to_hex
from manimlib.utils.color import hex_to_int
from manimlib.utils.tex_file_writing import tex_to_svg
from manimlib.utils.tex_file_writing import latex_to_svg
from manimlib.utils.tex import num_tex_symbols
from manimlib.utils.simple_functions import hash_string
from manimlib.logger import log
@ -88,7 +88,7 @@ class Tex(StringMobject):
def get_svg_string_by_content(self, content: str) -> str:
return get_cached_value(
key=hash_string(str((content, self.template, self.additional_preamble))),
value_func=lambda: tex_to_svg(content, self.template, self.additional_preamble),
value_func=lambda: latex_to_svg(content, self.template, self.additional_preamble),
message=f"Writing {self.tex_string}..."
)

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import os
import re
import yaml
import subprocess
from pathlib import Path
import tempfile
@ -19,9 +20,8 @@ SAVED_TEX_CONFIG = {}
def get_tex_template_config(template_name: str) -> dict[str, str]:
name = template_name.replace(" ", "_").lower()
with open(os.path.join(
get_manim_dir(), "manimlib", "tex_templates.yml"
), encoding="utf-8") as tex_templates_file:
template_path = os.path.join(get_manim_dir(), "manimlib", "tex_templates.yml")
with open(template_path, encoding="utf-8") as tex_templates_file:
templates_dict = yaml.safe_load(tex_templates_file)
if name not in templates_dict:
log.warning(
@ -53,23 +53,8 @@ def get_tex_config() -> dict[str, str]:
return SAVED_TEX_CONFIG
def tex_to_svg(
content: str,
template: str,
additional_preamble: str,
) -> str:
tex_config = get_tex_config()
if not template or template == tex_config["template"]:
compiler = tex_config["compiler"]
preamble = tex_config["preamble"]
else:
config = get_tex_template_config(template)
compiler = config["compiler"]
preamble = config["preamble"]
if additional_preamble:
preamble += "\n" + additional_preamble
full_tex = "\n\n".join((
def get_full_tex(content: str, preamble: str = ""):
return "\n\n".join((
"\\documentclass[preview]{standalone}",
preamble,
"\\begin{document}",
@ -77,14 +62,32 @@ def tex_to_svg(
"\\end{document}"
)) + "\n"
with tempfile.NamedTemporaryFile(suffix='.svg', mode='r+') as tmp:
create_tex_svg(full_tex, tmp.name, compiler)
# Read the contents
tmp.seek(0)
return tmp.read()
def latex_to_svg(
latex: str,
template: str = "",
additional_preamble: str = ""
) -> str:
"""Convert LaTeX string to SVG string.
Args:
latex: LaTeX source code
template: Path to a template LaTeX file
additional_preamble: String including any added "\\usepackage{...}" style imports
Returns:
str: SVG source code
Raises:
LatexError: If LaTeX compilation fails
NotImplementedError: If compiler is not supported
"""
tex_config = get_tex_config()
if template and template != tex_config["template"]:
tex_config = get_tex_template_config(template)
compiler = tex_config["compiler"]
def create_tex_svg(full_tex: str, svg_file: str, compiler: str) -> None:
if compiler == "latex":
program = "latex"
dvi_ext = ".dvi"
@ -92,57 +95,60 @@ def create_tex_svg(full_tex: str, svg_file: str, compiler: str) -> None:
program = "xelatex -no-pdf"
dvi_ext = ".xdv"
else:
raise NotImplementedError(
f"Compiler '{compiler}' is not implemented"
raise NotImplementedError(f"Compiler '{compiler}' is not implemented")
preamble = tex_config["preamble"] + "\n" + additional_preamble
full_tex = get_full_tex(latex, preamble)
# Write intermediate files to a temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
base_path = os.path.join(temp_dir, "working")
tex_path = base_path + ".tex"
dvi_path = base_path + dvi_ext
# Write tex file
with open(tex_path, "w", encoding="utf-8") as tex_file:
tex_file.write(full_tex)
# Run latex compiler
process = subprocess.run(
[
program.split()[0], # Split for xelatex case
"-interaction=batchmode",
"-halt-on-error",
"-output-directory=" + temp_dir,
tex_path
] + (["--no-pdf"] if compiler == "xelatex" else []),
capture_output=True,
text=True
)
# Write tex file
root, _ = os.path.splitext(svg_file)
with open(root + ".tex", "w", encoding="utf-8") as tex_file:
tex_file.write(full_tex)
if process.returncode != 0:
# Handle error
error_str = ""
log_path = base_path + ".log"
if os.path.exists(log_path):
with open(log_path, "r", encoding="utf-8") as log_file:
content = log_file.read()
error_match = re.search(r"(?<=\n! ).*\n.*\n", content)
if error_match:
error_str = error_match.group()
raise LatexError(error_str or "LaTeX compilation failed")
# tex to dvi
if os.system(" ".join((
program,
"-interaction=batchmode",
"-halt-on-error",
f"-output-directory=\"{os.path.dirname(svg_file)}\"",
f"\"{root}.tex\"",
">",
os.devnull
))):
log.error(
"LaTeX Error! Not a worry, it happens to the best of us."
# Run dvisvgm and capture output directly
process = subprocess.run(
[
"dvisvgm",
dvi_path,
"-n", # no fonts
"-v", "0", # quiet
"--stdout", # output to stdout instead of file
],
capture_output=True
)
error_str = ""
with open(root + ".log", "r", encoding="utf-8") as log_file:
error_match_obj = re.search(r"(?<=\n! ).*\n.*\n", log_file.read())
if error_match_obj:
error_str = error_match_obj.group()
log.debug(
f"The error could be:\n`{error_str}`",
)
raise LatexError(error_str)
# dvi to svg
os.system(" ".join((
"dvisvgm",
f"\"{root}{dvi_ext}\"",
"-n",
"-v",
"0",
"-o",
f"\"{svg_file}\"",
">",
os.devnull
)))
# Cleanup superfluous documents
for ext in (".tex", dvi_ext, ".log", ".aux"):
try:
os.remove(root + ext)
except FileNotFoundError:
pass
# Return SVG string
return process.stdout.decode('utf-8')
class LatexError(Exception):