mirror of
https://github.com/3b1b/manim.git
synced 2025-11-15 05:17:47 +00:00
chore: add type hints to manimlib.shader_wrapper
This commit is contained in:
parent
db71ed1ae9
commit
91ffdeb2d4
1 changed files with 30 additions and 25 deletions
|
|
@ -1,8 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import copy
|
||||
from typing import Iterable
|
||||
|
||||
import moderngl
|
||||
import numpy as np
|
||||
import copy
|
||||
|
||||
from manimlib.utils.directories import get_shader_dir
|
||||
from manimlib.utils.file_ops import find_file
|
||||
|
|
@ -15,15 +19,16 @@ from manimlib.utils.file_ops import find_file
|
|||
|
||||
|
||||
class ShaderWrapper(object):
|
||||
def __init__(self,
|
||||
vert_data=None,
|
||||
vert_indices=None,
|
||||
shader_folder=None,
|
||||
uniforms=None, # A dictionary mapping names of uniform variables
|
||||
texture_paths=None, # A dictionary mapping names to filepaths for textures.
|
||||
depth_test=False,
|
||||
render_primitive=moderngl.TRIANGLE_STRIP,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
vert_data: np.ndarray | None = None,
|
||||
vert_indices: np.ndarray | None = None,
|
||||
shader_folder: str | None = None,
|
||||
uniforms: dict[str, float] | None = None, # A dictionary mapping names of uniform variables
|
||||
texture_paths: dict[str, str] | None = None, # A dictionary mapping names to filepaths for textures.
|
||||
depth_test: bool = False,
|
||||
render_primitive: int = moderngl.TRIANGLE_STRIP,
|
||||
):
|
||||
self.vert_data = vert_data
|
||||
self.vert_indices = vert_indices
|
||||
self.vert_attributes = vert_data.dtype.names
|
||||
|
|
@ -46,20 +51,20 @@ class ShaderWrapper(object):
|
|||
result.texture_paths = dict(self.texture_paths)
|
||||
return result
|
||||
|
||||
def is_valid(self):
|
||||
def is_valid(self) -> bool:
|
||||
return all([
|
||||
self.vert_data is not None,
|
||||
self.program_code["vertex_shader"] is not None,
|
||||
self.program_code["fragment_shader"] is not None,
|
||||
])
|
||||
|
||||
def get_id(self):
|
||||
def get_id(self) -> str:
|
||||
return self.id
|
||||
|
||||
def get_program_id(self):
|
||||
def get_program_id(self) -> int:
|
||||
return self.program_id
|
||||
|
||||
def create_id(self):
|
||||
def create_id(self) -> str:
|
||||
# A unique id for a shader
|
||||
return "|".join(map(str, [
|
||||
self.program_id,
|
||||
|
|
@ -69,32 +74,32 @@ class ShaderWrapper(object):
|
|||
self.render_primitive,
|
||||
]))
|
||||
|
||||
def refresh_id(self):
|
||||
def refresh_id(self) -> None:
|
||||
self.program_id = self.create_program_id()
|
||||
self.id = self.create_id()
|
||||
|
||||
def create_program_id(self):
|
||||
def create_program_id(self) -> int:
|
||||
return hash("".join((
|
||||
self.program_code[f"{name}_shader"] or ""
|
||||
for name in ("vertex", "geometry", "fragment")
|
||||
)))
|
||||
|
||||
def init_program_code(self):
|
||||
def get_code(name):
|
||||
def init_program_code(self) -> None:
|
||||
def get_code(name: str) -> str | None:
|
||||
return get_shader_code_from_file(
|
||||
os.path.join(self.shader_folder, f"{name}.glsl")
|
||||
)
|
||||
|
||||
self.program_code = {
|
||||
self.program_code: dict[str, str | None] = {
|
||||
"vertex_shader": get_code("vert"),
|
||||
"geometry_shader": get_code("geom"),
|
||||
"fragment_shader": get_code("frag"),
|
||||
}
|
||||
|
||||
def get_program_code(self):
|
||||
def get_program_code(self) -> dict[str, str | None]:
|
||||
return self.program_code
|
||||
|
||||
def replace_code(self, old, new):
|
||||
def replace_code(self, old: str, new: str) -> None:
|
||||
code_map = self.program_code
|
||||
for (name, code) in code_map.items():
|
||||
if code_map[name] is None:
|
||||
|
|
@ -102,7 +107,7 @@ class ShaderWrapper(object):
|
|||
code_map[name] = re.sub(old, new, code_map[name])
|
||||
self.refresh_id()
|
||||
|
||||
def combine_with(self, *shader_wrappers):
|
||||
def combine_with(self, *shader_wrappers: ShaderWrapper):
|
||||
# Assume they are of the same type
|
||||
if len(shader_wrappers) == 0:
|
||||
return
|
||||
|
|
@ -122,10 +127,10 @@ class ShaderWrapper(object):
|
|||
|
||||
|
||||
# For caching
|
||||
filename_to_code_map = {}
|
||||
filename_to_code_map: dict[str, str] = {}
|
||||
|
||||
|
||||
def get_shader_code_from_file(filename):
|
||||
def get_shader_code_from_file(filename: str) -> str | None:
|
||||
if not filename:
|
||||
return None
|
||||
if filename in filename_to_code_map:
|
||||
|
|
@ -157,7 +162,7 @@ def get_shader_code_from_file(filename):
|
|||
return result
|
||||
|
||||
|
||||
def get_colormap_code(rgb_list):
|
||||
def get_colormap_code(rgb_list: Iterable[float]) -> str:
|
||||
data = ",".join(
|
||||
"vec3({}, {}, {})".format(*rgb)
|
||||
for rgb in rgb_list
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue