chore: add type hints to manimlib.shader_wrapper

This commit is contained in:
TonyCrane 2022-02-15 14:49:02 +08:00
parent db71ed1ae9
commit 91ffdeb2d4
No known key found for this signature in database
GPG key ID: 2313A5058A9C637C

View file

@ -1,8 +1,12 @@
from __future__ import annotations
import os
import re
import copy
from typing import Iterable
import moderngl
import numpy as np
import copy
from manimlib.utils.directories import get_shader_dir
from manimlib.utils.file_ops import find_file
@ -15,15 +19,16 @@ from manimlib.utils.file_ops import find_file
class ShaderWrapper(object):
def __init__(self,
vert_data=None,
vert_indices=None,
shader_folder=None,
uniforms=None, # A dictionary mapping names of uniform variables
texture_paths=None, # A dictionary mapping names to filepaths for textures.
depth_test=False,
render_primitive=moderngl.TRIANGLE_STRIP,
):
def __init__(
self,
vert_data: np.ndarray | None = None,
vert_indices: np.ndarray | None = None,
shader_folder: str | None = None,
uniforms: dict[str, float] | None = None, # A dictionary mapping names of uniform variables
texture_paths: dict[str, str] | None = None, # A dictionary mapping names to filepaths for textures.
depth_test: bool = False,
render_primitive: int = moderngl.TRIANGLE_STRIP,
):
self.vert_data = vert_data
self.vert_indices = vert_indices
self.vert_attributes = vert_data.dtype.names
@ -46,20 +51,20 @@ class ShaderWrapper(object):
result.texture_paths = dict(self.texture_paths)
return result
def is_valid(self):
def is_valid(self) -> bool:
return all([
self.vert_data is not None,
self.program_code["vertex_shader"] is not None,
self.program_code["fragment_shader"] is not None,
])
def get_id(self):
def get_id(self) -> str:
return self.id
def get_program_id(self):
def get_program_id(self) -> int:
return self.program_id
def create_id(self):
def create_id(self) -> str:
# A unique id for a shader
return "|".join(map(str, [
self.program_id,
@ -69,32 +74,32 @@ class ShaderWrapper(object):
self.render_primitive,
]))
def refresh_id(self):
def refresh_id(self) -> None:
self.program_id = self.create_program_id()
self.id = self.create_id()
def create_program_id(self):
def create_program_id(self) -> int:
return hash("".join((
self.program_code[f"{name}_shader"] or ""
for name in ("vertex", "geometry", "fragment")
)))
def init_program_code(self):
def get_code(name):
def init_program_code(self) -> None:
def get_code(name: str) -> str | None:
return get_shader_code_from_file(
os.path.join(self.shader_folder, f"{name}.glsl")
)
self.program_code = {
self.program_code: dict[str, str | None] = {
"vertex_shader": get_code("vert"),
"geometry_shader": get_code("geom"),
"fragment_shader": get_code("frag"),
}
def get_program_code(self):
def get_program_code(self) -> dict[str, str | None]:
return self.program_code
def replace_code(self, old, new):
def replace_code(self, old: str, new: str) -> None:
code_map = self.program_code
for (name, code) in code_map.items():
if code_map[name] is None:
@ -102,7 +107,7 @@ class ShaderWrapper(object):
code_map[name] = re.sub(old, new, code_map[name])
self.refresh_id()
def combine_with(self, *shader_wrappers):
def combine_with(self, *shader_wrappers: ShaderWrapper):
# Assume they are of the same type
if len(shader_wrappers) == 0:
return
@ -122,10 +127,10 @@ class ShaderWrapper(object):
# For caching
filename_to_code_map = {}
filename_to_code_map: dict[str, str] = {}
def get_shader_code_from_file(filename):
def get_shader_code_from_file(filename: str) -> str | None:
if not filename:
return None
if filename in filename_to_code_map:
@ -157,7 +162,7 @@ def get_shader_code_from_file(filename):
return result
def get_colormap_code(rgb_list):
def get_colormap_code(rgb_list: Iterable[float]) -> str:
data = ",".join(
"vec3({}, {}, {})".format(*rgb)
for rgb in rgb_list