chore: add type hints to manimlib.camera

This commit is contained in:
TonyCrane 2022-02-13 19:32:53 +08:00
parent 992e61ddf2
commit 1064e2bb30
No known key found for this signature in database
GPG key ID: 2313A5058A9C637C

View file

@ -1,15 +1,18 @@
import moderngl
import math
from colour import Color
import OpenGL.GL as gl
from __future__ import annotations
from PIL import Image
import numpy as np
import math
import itertools as it
import moderngl
import numpy as np
from PIL import Image
import OpenGL.GL as gl
from colour import Color
from manimlib.constants import *
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Point
from manimlib.shader_wrapper import ShaderWrapper
from manimlib.utils.config_ops import digest_config
from manimlib.utils.simple_functions import fdiv
from manimlib.utils.simple_functions import clip
@ -29,12 +32,12 @@ class CameraFrame(Mobject):
"focal_distance": 2,
}
def init_data(self):
def init_data(self) -> None:
super().init_data()
self.data["euler_angles"] = np.array(self.euler_angles, dtype=float)
self.refresh_rotation_matrix()
def init_points(self):
def init_points(self) -> None:
self.set_points([ORIGIN, LEFT, RIGHT, DOWN, UP])
self.set_width(self.frame_shape[0], stretch=True)
self.set_height(self.frame_shape[1], stretch=True)
@ -47,13 +50,13 @@ class CameraFrame(Mobject):
self.set_euler_angles(0, 0, 0)
return self
def get_euler_angles(self):
def get_euler_angles(self) -> np.ndarray:
return self.data["euler_angles"]
def get_inverse_camera_rotation_matrix(self):
def get_inverse_camera_rotation_matrix(self) -> list[list[float]]:
return self.inverse_camera_rotation_matrix
def refresh_rotation_matrix(self):
def refresh_rotation_matrix(self) -> None:
# Rotate based on camera orientation
theta, phi, gamma = self.get_euler_angles()
quat = quaternion_mult(
@ -63,7 +66,7 @@ class CameraFrame(Mobject):
)
self.inverse_camera_rotation_matrix = rotation_matrix_transpose_from_quaternion(quat)
def rotate(self, angle, axis=OUT, **kwargs):
def rotate(self, angle: float, axis: np.ndarray = OUT, **kwargs):
curr_rot_T = self.get_inverse_camera_rotation_matrix()
added_rot_T = rotation_matrix_transpose(angle, axis)
new_rot_T = np.dot(curr_rot_T, added_rot_T)
@ -78,7 +81,13 @@ class CameraFrame(Mobject):
self.set_euler_angles(theta, phi, gamma)
return self
def set_euler_angles(self, theta=None, phi=None, gamma=None, units=RADIANS):
def set_euler_angles(
self,
theta: float | None = None,
phi: float | None = None,
gamma: float | None = None,
units: float = RADIANS
):
if theta is not None:
self.data["euler_angles"][0] = theta * units
if phi is not None:
@ -88,7 +97,12 @@ class CameraFrame(Mobject):
self.refresh_rotation_matrix()
return self
def reorient(self, theta_degrees=None, phi_degrees=None, gamma_degrees=None):
def reorient(
self,
theta_degrees: float | None = None,
phi_degrees: float | None = None,
gamma_degrees: float | None = None,
):
"""
Shortcut for set_euler_angles, defaulting to taking
in angles in degrees
@ -96,60 +110,60 @@ class CameraFrame(Mobject):
self.set_euler_angles(theta_degrees, phi_degrees, gamma_degrees, units=DEGREES)
return self
def set_theta(self, theta):
def set_theta(self, theta: float):
return self.set_euler_angles(theta=theta)
def set_phi(self, phi):
def set_phi(self, phi: float):
return self.set_euler_angles(phi=phi)
def set_gamma(self, gamma):
def set_gamma(self, gamma: float):
return self.set_euler_angles(gamma=gamma)
def increment_theta(self, dtheta):
def increment_theta(self, dtheta: float):
self.data["euler_angles"][0] += dtheta
self.refresh_rotation_matrix()
return self
def increment_phi(self, dphi):
def increment_phi(self, dphi: float):
phi = self.data["euler_angles"][1]
new_phi = clip(phi + dphi, 0, PI)
self.data["euler_angles"][1] = new_phi
self.refresh_rotation_matrix()
return self
def increment_gamma(self, dgamma):
def increment_gamma(self, dgamma: float):
self.data["euler_angles"][2] += dgamma
self.refresh_rotation_matrix()
return self
def get_theta(self):
def get_theta(self) -> float:
return self.data["euler_angles"][0]
def get_phi(self):
def get_phi(self) -> float:
return self.data["euler_angles"][1]
def get_gamma(self):
def get_gamma(self) -> float:
return self.data["euler_angles"][2]
def get_shape(self):
def get_shape(self) -> tuple[float, float]:
return (self.get_width(), self.get_height())
def get_center(self):
def get_center(self) -> np.ndarray:
# Assumes first point is at the center
return self.get_points()[0]
def get_width(self):
def get_width(self) -> float:
points = self.get_points()
return points[2, 0] - points[1, 0]
def get_height(self):
def get_height(self) -> float:
points = self.get_points()
return points[4, 1] - points[3, 1]
def get_focal_distance(self):
def get_focal_distance(self) -> float:
return self.focal_distance * self.get_height()
def get_implied_camera_location(self):
def get_implied_camera_location(self) -> tuple[float, float, float]:
theta, phi, gamma = self.get_euler_angles()
dist = self.get_focal_distance()
x, y, z = self.get_center()
@ -190,10 +204,10 @@ class Camera(object):
"samples": 0,
}
def __init__(self, ctx=None, **kwargs):
def __init__(self, ctx: moderngl.Context | None = None, **kwargs):
digest_config(self, kwargs, locals())
self.rgb_max_val = np.iinfo(self.pixel_array_dtype).max
self.background_rgba = [
self.rgb_max_val: float = np.iinfo(self.pixel_array_dtype).max
self.background_rgba: list[float] = [
*Color(self.background_color).get_rgb(),
self.background_opacity
]
@ -205,10 +219,10 @@ class Camera(object):
self.refresh_perspective_uniforms()
self.static_mobject_to_render_group_list = {}
def init_frame(self):
def init_frame(self) -> None:
self.frame = CameraFrame(**self.frame_config)
def init_context(self, ctx=None):
def init_context(self, ctx: moderngl.Context | None = None) -> None:
if ctx is None:
ctx = moderngl.create_standalone_context()
fbo = self.get_fbo(ctx, 0)
@ -223,7 +237,7 @@ class Camera(object):
fbo_msaa.use()
self.fbo_msaa = fbo_msaa
def set_ctx_blending(self, enable=True):
def set_ctx_blending(self, enable: bool = True) -> None:
if enable:
self.ctx.enable(moderngl.BLEND)
else:
@ -233,17 +247,21 @@ class Camera(object):
# moderngl.ONE, moderngl.ONE
)
def set_ctx_depth_test(self, enable=True):
def set_ctx_depth_test(self, enable: bool = True) -> None:
if enable:
self.ctx.enable(moderngl.DEPTH_TEST)
else:
self.ctx.disable(moderngl.DEPTH_TEST)
def init_light_source(self):
def init_light_source(self) -> None:
self.light_source = Point(self.light_source_position)
# Methods associated with the frame buffer
def get_fbo(self, ctx, samples=0):
def get_fbo(
self,
ctx: moderngl.Context,
samples: int = 0
) -> moderngl.Framebuffer:
pw = self.pixel_width
ph = self.pixel_height
return ctx.framebuffer(
@ -258,16 +276,16 @@ class Camera(object):
)
)
def clear(self):
def clear(self) -> None:
self.fbo.clear(*self.background_rgba)
self.fbo_msaa.clear(*self.background_rgba)
def reset_pixel_shape(self, new_width, new_height):
def reset_pixel_shape(self, new_width: int, new_height: int) -> None:
self.pixel_width = new_width
self.pixel_height = new_height
self.refresh_perspective_uniforms()
def get_raw_fbo_data(self, dtype='f1'):
def get_raw_fbo_data(self, dtype: str = 'f1') -> bytes:
# Copy blocks from the fbo_msaa to the drawn fbo using Blit
pw, ph = (self.pixel_width, self.pixel_height)
gl.glBindFramebuffer(gl.GL_READ_FRAMEBUFFER, self.fbo_msaa.glo)
@ -279,7 +297,7 @@ class Camera(object):
dtype=dtype,
)
def get_image(self, pixel_array=None):
def get_image(self) -> Image:
return Image.frombytes(
'RGBA',
self.get_pixel_shape(),
@ -287,7 +305,7 @@ class Camera(object):
'raw', 'RGBA', 0, -1
)
def get_pixel_array(self):
def get_pixel_array(self) -> np.ndarray:
raw = self.get_raw_fbo_data(dtype='f4')
flat_arr = np.frombuffer(raw, dtype='f4')
arr = flat_arr.reshape([*self.fbo.size, self.n_channels])
@ -295,7 +313,7 @@ class Camera(object):
return (self.rgb_max_val * arr).astype(self.pixel_array_dtype)
# Needed?
def get_texture(self):
def get_texture(self) -> moderngl.Texture:
texture = self.ctx.texture(
size=self.fbo.size,
components=4,
@ -305,32 +323,32 @@ class Camera(object):
return texture
# Getting camera attributes
def get_pixel_shape(self):
def get_pixel_shape(self) -> tuple[int, int]:
return self.fbo.viewport[2:4]
# return (self.pixel_width, self.pixel_height)
def get_pixel_width(self):
def get_pixel_width(self) -> int:
return self.get_pixel_shape()[0]
def get_pixel_height(self):
def get_pixel_height(self) -> int:
return self.get_pixel_shape()[1]
def get_frame_height(self):
def get_frame_height(self) -> float:
return self.frame.get_height()
def get_frame_width(self):
def get_frame_width(self) -> float:
return self.frame.get_width()
def get_frame_shape(self):
def get_frame_shape(self) -> tuple[float, float]:
return (self.get_frame_width(), self.get_frame_height())
def get_frame_center(self):
def get_frame_center(self) -> np.ndarray:
return self.frame.get_center()
def get_location(self):
def get_location(self) -> tuple[float, float, float]:
return self.frame.get_implied_camera_location()
def resize_frame_shape(self, fixed_dimension=0):
def resize_frame_shape(self, fixed_dimension: bool = False) -> None:
"""
Changes frame_shape to match the aspect ratio
of the pixels, where fixed_dimension determines
@ -342,7 +360,7 @@ class Camera(object):
frame_height = self.get_frame_height()
frame_width = self.get_frame_width()
aspect_ratio = fdiv(pixel_width, pixel_height)
if fixed_dimension == 0:
if not fixed_dimension:
frame_height = frame_width / aspect_ratio
else:
frame_width = aspect_ratio * frame_height
@ -350,13 +368,13 @@ class Camera(object):
self.frame.set_width(frame_width)
# Rendering
def capture(self, *mobjects, **kwargs):
def capture(self, *mobjects: Mobject, **kwargs) -> None:
self.refresh_perspective_uniforms()
for mobject in mobjects:
for render_group in self.get_render_group_list(mobject):
self.render(render_group)
def render(self, render_group):
def render(self, render_group: dict[str]) -> None:
shader_wrapper = render_group["shader_wrapper"]
shader_program = render_group["prog"]
self.set_shader_uniforms(shader_program, shader_wrapper)
@ -365,13 +383,17 @@ class Camera(object):
if render_group["single_use"]:
self.release_render_group(render_group)
def get_render_group_list(self, mobject):
def get_render_group_list(self, mobject: Mobject) -> list[dict[str]] | map[dict[str]]:
try:
return self.static_mobject_to_render_group_list[id(mobject)]
except KeyError:
return map(self.get_render_group, mobject.get_shader_wrapper_list())
def get_render_group(self, shader_wrapper, single_use=True):
def get_render_group(
self,
shader_wrapper: ShaderWrapper,
single_use: bool = True
) -> dict[str]:
# Data buffers
vbo = self.ctx.buffer(shader_wrapper.vert_data.tobytes())
if shader_wrapper.vert_indices is None:
@ -399,12 +421,12 @@ class Camera(object):
"single_use": single_use,
}
def release_render_group(self, render_group):
def release_render_group(self, render_group: dict[str]) -> None:
for key in ["vbo", "ibo", "vao"]:
if render_group[key] is not None:
render_group[key].release()
def set_mobjects_as_static(self, *mobjects):
def set_mobjects_as_static(self, *mobjects: Mobject) -> None:
# Creates buffer and array objects holding each mobjects shader data
for mob in mobjects:
self.static_mobject_to_render_group_list[id(mob)] = [
@ -412,18 +434,23 @@ class Camera(object):
for sw in mob.get_shader_wrapper_list()
]
def release_static_mobjects(self):
def release_static_mobjects(self) -> None:
for rg_list in self.static_mobject_to_render_group_list.values():
for render_group in rg_list:
self.release_render_group(render_group)
self.static_mobject_to_render_group_list = {}
# Shaders
def init_shaders(self):
def init_shaders(self) -> None:
# Initialize with the null id going to None
self.id_to_shader_program = {"": None}
self.id_to_shader_program: dict[
int | str, tuple[moderngl.Program, str] | None
] = {"": None}
def get_shader_program(self, shader_wrapper):
def get_shader_program(
self,
shader_wrapper: ShaderWrapper
) -> tuple[moderngl.Program, str]:
sid = shader_wrapper.get_program_id()
if sid not in self.id_to_shader_program:
# Create shader program for the first time, then cache
@ -433,7 +460,11 @@ class Camera(object):
self.id_to_shader_program[sid] = (program, vert_format)
return self.id_to_shader_program[sid]
def set_shader_uniforms(self, shader, shader_wrapper):
def set_shader_uniforms(
self,
shader: moderngl.Program,
shader_wrapper: ShaderWrapper
) -> None:
for name, path in shader_wrapper.texture_paths.items():
tid = self.get_texture_id(path)
shader[name].value = tid
@ -445,7 +476,7 @@ class Camera(object):
except KeyError:
pass
def refresh_perspective_uniforms(self):
def refresh_perspective_uniforms(self) -> None:
frame = self.frame
pw, ph = self.get_pixel_shape()
fw, fh = frame.get_shape()
@ -470,11 +501,13 @@ class Camera(object):
"focal_distance": frame.get_focal_distance(),
}
def init_textures(self):
self.n_textures = 0
self.path_to_texture = {}
def init_textures(self) -> None:
self.n_textures: int = 0
self.path_to_texture: dict[
str, tuple[int, moderngl.Texture]
] = {}
def get_texture_id(self, path):
def get_texture_id(self, path: str) -> int:
if path not in self.path_to_texture:
if self.n_textures == 15: # I have no clue why this is needed
self.n_textures += 1
@ -490,7 +523,7 @@ class Camera(object):
self.path_to_texture[path] = (tid, texture)
return self.path_to_texture[path][0]
def release_texture(self, path):
def release_texture(self, path: str):
tid_and_texture = self.path_to_texture.pop(path, None)
if tid_and_texture:
tid_and_texture[1].release()