Clean up num_tex_symbols

This commit is contained in:
Grant Sanderson 2024-12-11 17:47:26 -06:00
parent dbcec1fcea
commit 4dfd4a8736

View file

@ -1,40 +1,38 @@
from __future__ import annotations
import re
from functools import lru_cache
from manimlib.utils.tex_to_symbol_count import TEX_TO_SYMBOL_COUNT
@lru_cache
def num_tex_symbols(tex: str) -> int:
tex = remove_tex_environments(tex)
commands_pattern = r"""
(?P<sqrt>\\sqrt\[[0-9]+\])| # Special sqrt with number
(?P<cmd>\\[a-zA-Z!,-/:;<>]+) # Regular commands
"""
This function attempts to estimate the number of symbols that
a given string of tex would produce.
Warning, it may not behave perfectly
"""
# First, remove patterns like \begin{align}, \phantom{thing},
# \begin{array}{cc}, etc.
pattern = "|".join(
rf"(\\{s})" + r"(\{\w+\})?(\{\w+\})?(\[\w+\])?"
for s in ["begin", "end", "phantom"]
)
tex = re.sub(pattern, "", tex)
# Progressively count the symbols associated with certain tex commands,
# and remove those commands from the string, adding the number of symbols
# that command creates
total = 0
pos = 0
for match in re.finditer(commands_pattern, tex, re.VERBOSE):
# Count normal characters up to this command
total += sum(1 for c in tex[pos:match.start()] if c not in "^{} \n\t_$\\&")
# Start with the special case \sqrt[number]
for substr in re.findall(r"\\sqrt\[[0-9]+\]", tex):
total += len(substr) - 5 # e.g. \sqrt[3] is 3 symbols
tex = tex.replace(substr, " ")
general_command = r"\\[a-zA-Z!,-/:;<>]+"
for substr in re.findall(general_command, tex):
total += TEX_TO_SYMBOL_COUNT.get(substr, 1)
tex = tex.replace(substr, " ")
if match.group("sqrt"):
total += len(match.group()) - 5
else:
total += TEX_TO_SYMBOL_COUNT.get(match.group(), 1)
pos = match.end()
# Count remaining characters
total += sum(map(lambda c: c not in "^{} \n\t_$\\&", tex))
total += sum(1 for c in tex[pos:] if c not in "^{} \n\t_$\\&")
return total
def remove_tex_environments(tex: str) -> str:
# Handle \phantom{...} with any content
tex = re.sub(r"\\phantom\{[^}]*\}", "", tex)
# Handle other environment commands
tex = re.sub(r"\\(begin|end)(\{\w+\})?(\{\w+\})?(\[\w+\])?", "", tex)
return tex