2022-12-21 12:47:18 -08:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2022-12-20 22:35:41 -08:00
|
|
|
import re
|
2024-12-12 10:39:54 -06:00
|
|
|
from functools import lru_cache
|
2022-12-21 12:47:18 -08:00
|
|
|
|
2022-12-29 10:37:46 -08:00
|
|
|
from manimlib.utils.tex_to_symbol_count import TEX_TO_SYMBOL_COUNT
|
2022-12-21 12:47:18 -08:00
|
|
|
|
2022-12-20 22:35:41 -08:00
|
|
|
|
2024-12-12 10:39:54 -06:00
|
|
|
@lru_cache
def num_tex_symbols(tex: str) -> int:
    """
    Estimate how many symbols a string of tex would produce.

    The string is first passed through remove_tex_environments. What is
    left is scanned for tex commands:

    - a square root with a numeric index (backslash, "sqrt", then digits
      in square brackets) is counted by a dedicated rule;
    - any other command is looked up in TEX_TO_SYMBOL_COUNT, and counts
      as 1 when it has no entry there;
    - every character outside a command counts as 1, except for the
      markup characters ^ { } _ $ & backslash, and whitespace.

    Results are memoized by lru_cache, so repeated calls with the same
    string are not recomputed.
    """
    tex = remove_tex_environments(tex)

    # A command is either a control word (backslash + letters) or a control
    # symbol (backslash + ONE punctuation character). Matching them
    # separately keeps punctuation that merely follows a control word out
    # of the command, e.g. "\pi/2" is "\pi", "/", "2" rather than "\pi/", "2".
    commands_pattern = r"""
        (?P<sqrt>\\sqrt\[[0-9]+\])|             # Special sqrt with number
        (?P<cmd>\\(?:[a-zA-Z]+|[!,\-./:;<>]))   # Control word or control symbol
    """
    # Characters that are never counted as a symbol on their own
    ignored_chars = "^{} \n\t_$\\&"

    def count_plain(text: str) -> int:
        return sum(1 for char in text if char not in ignored_chars)

    total = 0
    pos = 0
    for match in re.finditer(commands_pattern, tex, re.VERBOSE):
        # Count normal characters up to this command
        total += count_plain(tex[pos:match.start()])

        if match.group("sqrt"):
            # The match is "\sqrt[" + digits + "]", so this adds
            # (number of digits + 2), e.g. 3 for \sqrt[3].
            # NOTE(review): presumably two pieces for the radical plus one
            # per index digit -- confirm against rendered output.
            total += len(match.group()) - 5
        else:
            total += TEX_TO_SYMBOL_COUNT.get(match.group(), 1)
        pos = match.end()

    # Count remaining characters
    total += count_plain(tex[pos:])
    return total
|
2024-12-12 10:39:54 -06:00
|
|
|
|
|
|
|
|
|
|
|
def _strip_phantoms(tex: str) -> str:
    r"""Remove every \phantom{...} group, honoring nested and escaped braces."""
    opener = "\\phantom{"
    kept = []
    cursor = 0
    while True:
        start = tex.find(opener, cursor)
        if start < 0:
            break
        # Walk forward until the brace opened by \phantom{ is closed
        depth = 1
        idx = start + len(opener)
        while depth and idx < len(tex):
            char = tex[idx]
            if char == "\\":
                # Skip the escaped character so \{ and \} don't change depth
                idx += 2
                continue
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
            idx += 1
        if depth:
            # Unbalanced braces: stop scanning and let the fallback below
            # deal with whatever is left.
            break
        kept.append(tex[cursor:start])
        cursor = idx
    # For a balanced input the tail holds no further \phantom{ and this
    # substitution is a no-op; for an unbalanced one it strips up to the
    # first closing brace.
    kept.append(re.sub(r"\\phantom\{[^}]*\}", "", tex[cursor:]))
    return "".join(kept)


def remove_tex_environments(tex: str) -> str:
    r"""
    Strip markup that wraps content without producing symbols of its own.

    Two kinds of text are removed:

    - \phantom{...} groups, including their content. Nested braces such as
      \phantom{\frac{a}{b}} are removed as a whole.
    - \begin and \end, together with an optional {name} (a trailing * as
      in align* is accepted), an optional second {word} group and an
      optional [word] group, e.g. \begin{array}{cc}.

    Returns the remaining tex string.
    """
    # Handle \phantom{...} with any content
    tex = _strip_phantoms(tex)
    # Handle other environment commands
    tex = re.sub(r"\\(begin|end)(\{\w+\*?\})?(\{\w+\})?(\[\w+\])?", "", tex)
    return tex
|