Move texture id tracking to ShaderWrapper

Rather than having a globally unique id for each texture, dynamically allocate new texure ids within each ShaderWrapper, so that there is no upper bound on how many textures can be used.
This commit is contained in:
Grant Sanderson 2024-09-06 11:07:38 -05:00
parent 76fdd02db0
commit e7c540f415
2 changed files with 28 additions and 31 deletions

View file

@ -15,8 +15,6 @@ from manimlib.utils.iterables import resize_array
from manimlib.utils.shaders import get_shader_code_from_file from manimlib.utils.shaders import get_shader_code_from_file
from manimlib.utils.shaders import get_shader_program from manimlib.utils.shaders import get_shader_program
from manimlib.utils.shaders import image_path_to_texture from manimlib.utils.shaders import image_path_to_texture
from manimlib.utils.shaders import get_texture_id
from manimlib.utils.shaders import release_texture
from manimlib.utils.shaders import set_program_uniform from manimlib.utils.shaders import set_program_uniform
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -50,7 +48,7 @@ class ShaderWrapper(object):
self.shader_folder = shader_folder self.shader_folder = shader_folder
self.depth_test = depth_test self.depth_test = depth_test
self.render_primitive = render_primitive self.render_primitive = render_primitive
self.texture_names_to_ids = dict() self.texture_paths = texture_paths or dict()
self.program_uniform_mirror: UniformDict = dict() self.program_uniform_mirror: UniformDict = dict()
self.bind_to_mobject_uniforms(mobject_uniforms or dict()) self.bind_to_mobject_uniforms(mobject_uniforms or dict())
@ -59,8 +57,7 @@ class ShaderWrapper(object):
for old, new in code_replacements.items(): for old, new in code_replacements.items():
self.replace_code(old, new) self.replace_code(old, new)
self.init_program() self.init_program()
if texture_paths is not None: self.init_textures()
self.init_textures(texture_paths)
self.init_vertex_objects() self.init_vertex_objects()
self.refresh_id() self.refresh_id()
@ -92,16 +89,24 @@ class ShaderWrapper(object):
self.vert_format = moderngl.detect_format(self.program, self.vert_attributes) self.vert_format = moderngl.detect_format(self.program, self.vert_attributes)
self.programs = [self.program] self.programs = [self.program]
def init_textures(self, texture_paths: dict[str, str]): def init_textures(self):
self.texture_names_to_ids = { self.texture_names_to_ids = dict()
name: get_texture_id(image_path_to_texture(path, self.ctx)) self.textures = []
for name, path in texture_paths.items() for name, path in self.texture_paths.items():
} self.add_texture(name, image_path_to_texture(path, self.ctx))
def init_vertex_objects(self): def init_vertex_objects(self):
self.vbo = None self.vbo = None
self.vaos = [] self.vaos = []
def add_texture(self, name: str, texture: moderngl.Texture):
max_units = self.ctx.info['GL_MAX_TEXTURE_IMAGE_UNITS']
if len(self.textures) >= max_units:
raise ValueError(f"Unable to use more than {max_units} textures for a program")
# The position in the list determines its id
self.texture_names_to_ids[name] = len(self.textures)
self.textures.append(texture)
def bind_to_mobject_uniforms(self, mobject_uniforms: UniformDict): def bind_to_mobject_uniforms(self, mobject_uniforms: UniformDict):
self.mobject_uniforms = mobject_uniforms self.mobject_uniforms = mobject_uniforms
@ -114,7 +119,7 @@ class ShaderWrapper(object):
self.mobject_uniforms, self.mobject_uniforms,
self.depth_test, self.depth_test,
self.render_primitive, self.render_primitive,
self.texture_names_to_ids, self.texture_paths,
]))) ])))
def replace_code(self, old: str, new: str) -> None: def replace_code(self, old: str, new: str) -> None:
@ -182,6 +187,8 @@ class ShaderWrapper(object):
def pre_render(self): def pre_render(self):
self.set_ctx_depth_test(self.depth_test) self.set_ctx_depth_test(self.depth_test)
self.set_ctx_clip_plane(self.use_clip_plane()) self.set_ctx_clip_plane(self.use_clip_plane())
for tid, texture in enumerate(self.textures):
texture.use(tid)
def render(self): def render(self):
for vao in self.vaos: for vao in self.vaos:
@ -201,6 +208,13 @@ class ShaderWrapper(object):
obj.release() obj.release()
self.init_vertex_objects() self.init_vertex_objects()
def release_textures(self):
for texture in self.textures:
texture.release()
del texture
self.textures = []
self.texture_names_to_ids = dict()
class VShaderWrapper(ShaderWrapper): class VShaderWrapper(ShaderWrapper):
def __init__( def __init__(
@ -227,6 +241,8 @@ class VShaderWrapper(ShaderWrapper):
code_replacements=code_replacements, code_replacements=code_replacements,
) )
self.fill_canvas = VShaderWrapper.get_fill_canvas(self.ctx) self.fill_canvas = VShaderWrapper.get_fill_canvas(self.ctx)
self.add_texture('Texture', self.fill_canvas[0].color_attachments[0])
self.add_texture('DepthTexture', self.fill_canvas[2].color_attachments[0])
def init_program_code(self) -> None: def init_program_code(self) -> None:
self.program_code = { self.program_code = {
@ -441,9 +457,6 @@ class VShaderWrapper(ShaderWrapper):
fragment_shader=alpha_adjust_frag, fragment_shader=alpha_adjust_frag,
) )
fill_program['Texture'].value = get_texture_id(fill_texture)
fill_program['DepthTexture'].value = get_texture_id(depth_texture)
verts = np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) verts = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
simple_vbo = ctx.buffer(verts.astype('f4').tobytes()) simple_vbo = ctx.buffer(verts.astype('f4').tobytes())
fill_texture_vao = ctx.simple_vertex_array( fill_texture_vao = ctx.simple_vertex_array(

View file

@ -19,8 +19,7 @@ if TYPE_CHECKING:
from moderngl.framebuffer import Framebuffer from moderngl.framebuffer import Framebuffer
# Global maps updated as textures are allocated # Global maps to reflect uniform status
ID_TO_TEXTURE: dict[int, moderngl.Texture] = dict()
PROGRAM_UNIFORM_MIRRORS: dict[int, dict[str, float | tuple]] = dict() PROGRAM_UNIFORM_MIRRORS: dict[int, dict[str, float | tuple]] = dict()
@ -34,21 +33,6 @@ def image_path_to_texture(path: str, ctx: moderngl.Context) -> moderngl.Texture:
) )
def get_texture_id(texture: moderngl.Texture) -> int:
tid = 0
while tid in ID_TO_TEXTURE:
tid += 1
ID_TO_TEXTURE[tid] = texture
texture.use(location=tid)
return tid
def release_texture(texture_id: int):
texture = ID_TO_TEXTURE.pop(texture_id, None)
if texture is not None:
texture.release()
@lru_cache() @lru_cache()
def get_shader_program( def get_shader_program(
ctx: moderngl.context.Context, ctx: moderngl.context.Context,