Pull out helper functions from shader_wrapper.py

This commit is contained in:
Grant Sanderson 2023-01-16 14:18:35 -08:00
parent 7050c7e7b0
commit f15ac81131
2 changed files with 53 additions and 43 deletions

View file

@ -7,16 +7,13 @@ import re
import moderngl
import numpy as np
from functools import lru_cache
from manimlib.utils.directories import get_shader_dir
from manimlib.utils.file_ops import find_file
from manimlib.utils.iterables import resize_array
from manimlib.utils.shaders import get_shader_code_from_file
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Iterable, List
from typing import List
# Mobjects that should be rendered with
@ -174,41 +171,3 @@ class ShaderWrapper(object):
n_verts = new_n_verts
n_points += len(data)
return self
@lru_cache(maxsize=12)
def get_shader_code_from_file(filename: str) -> str | None:
if not filename:
return None
try:
filepath = find_file(
filename,
directories=[get_shader_dir(), "/"],
extensions=[],
)
except IOError:
return None
with open(filepath, "r") as f:
result = f.read()
# To share functionality between shaders, some functions are read in
# from other files an inserted into the relevant strings before
# passing to ctx.program for compiling
# Replace "#INSERT " lines with relevant code
insertions = re.findall(r"^#INSERT .*\.glsl$", result, flags=re.MULTILINE)
for line in insertions:
inserted_code = get_shader_code_from_file(
os.path.join("inserts", line.replace("#INSERT ", ""))
)
result = result.replace(line, inserted_code)
return result
def get_colormap_code(rgb_list: Iterable[float]) -> str:
data = ",".join(
"vec3({}, {}, {})".format(*rgb)
for rgb in rgb_list
)
return f"vec3[{len(rgb_list)}]({data})"

51
manimlib/utils/shaders.py Normal file
View file

@ -0,0 +1,51 @@
from __future__ import annotations
import os
import re
from functools import lru_cache
from manimlib.utils.directories import get_shader_dir
from manimlib.utils.file_ops import find_file
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Sequence
@lru_cache(maxsize=12)
def get_shader_code_from_file(filename: str) -> str | None:
if not filename:
return None
try:
filepath = find_file(
filename,
directories=[get_shader_dir(), "/"],
extensions=[],
)
except IOError:
return None
with open(filepath, "r") as f:
result = f.read()
# To share functionality between shaders, some functions are read in
# from other files an inserted into the relevant strings before
# passing to ctx.program for compiling
# Replace "#INSERT " lines with relevant code
insertions = re.findall(r"^#INSERT .*\.glsl$", result, flags=re.MULTILINE)
for line in insertions:
inserted_code = get_shader_code_from_file(
os.path.join("inserts", line.replace("#INSERT ", ""))
)
result = result.replace(line, inserted_code)
return result
def get_colormap_code(rgb_list: Sequence[float]) -> str:
data = ",".join(
"vec3({}, {}, {})".format(*rgb)
for rgb in rgb_list
)
return f"vec3[{len(rgb_list)}]({data})"