Add type annotations for color.py

This commit is contained in:
YishiMichael 2022-04-12 11:13:05 +08:00
parent 859680d5ab
commit f307c2a298
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80

View file

@ -1,6 +1,10 @@
import random from __future__ import annotations
from typing import Iterable, Union
from colour import Color from colour import Color
from colour import hex2rgb
from colour import rgb2hex
import numpy as np import numpy as np
from manimlib.constants import WHITE from manimlib.constants import WHITE
@ -8,8 +12,10 @@ from manimlib.constants import COLORMAP_3B1B
from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import interpolate
from manimlib.utils.iterables import resize_with_interpolation from manimlib.utils.iterables import resize_with_interpolation
ManimColor = Union[str, Color]
def color_to_rgb(color):
def color_to_rgb(color: ManimColor) -> np.ndarray:
if isinstance(color, str): if isinstance(color, str):
return hex_to_rgb(color) return hex_to_rgb(color)
elif isinstance(color, Color): elif isinstance(color, Color):
@ -18,55 +24,48 @@ def color_to_rgb(color):
raise Exception("Invalid color type") raise Exception("Invalid color type")
def color_to_rgba(color, alpha=1): def color_to_rgba(color: ManimColor, alpha: float = 1.0) -> np.ndarray:
return np.array([*color_to_rgb(color), alpha]) return np.array([*color_to_rgb(color), alpha])
def rgb_to_color(rgb): def rgb_to_color(rgb: Iterable[float]) -> Color:
try: try:
return Color(rgb=rgb) return Color(rgb=tuple(rgb))
except ValueError: except ValueError:
return Color(WHITE) return Color(WHITE)
def rgba_to_color(rgba): def rgba_to_color(rgba: Iterable[float]) -> Color:
return rgb_to_color(rgba[:3]) return rgb_to_color(tuple(rgba)[:3])
def rgb_to_hex(rgb): def rgb_to_hex(rgb: Iterable[float]) -> str:
return "#" + "".join( return rgb2hex(rgb, force_long=True).upper()
hex(int_x // 16)[2] + hex(int_x % 16)[2]
for x in rgb
for int_x in [int(255 * x)]
)
def hex_to_rgb(hex_code): def hex_to_rgb(hex_code: str) -> tuple[float]:
hex_part = hex_code[1:] return hex2rgb(hex_code)
if len(hex_part) == 3:
hex_part = "".join([2 * c for c in hex_part])
return np.array([
int(hex_part[i:i + 2], 16) / 255
for i in range(0, 6, 2)
])
def invert_color(color): def invert_color(color: ManimColor) -> Color:
return rgb_to_color(1.0 - color_to_rgb(color)) return rgb_to_color(1.0 - color_to_rgb(color))
def color_to_int_rgb(color): def color_to_int_rgb(color: ManimColor) -> np.ndarray:
return (255 * color_to_rgb(color)).astype('uint8') return (255 * color_to_rgb(color)).astype('uint8')
def color_to_int_rgba(color, opacity=1.0): def color_to_int_rgba(color: ManimColor, opacity: float = 1.0) -> np.ndarray:
alpha = int(255 * opacity) alpha = int(255 * opacity)
return np.array([*color_to_int_rgb(color), alpha]) return np.array([*color_to_int_rgb(color), alpha])
def color_gradient(reference_colors, length_of_output): def color_gradient(
reference_colors: Iterable[ManimColor],
length_of_output: int
) -> list[Color]:
if length_of_output == 0: if length_of_output == 0:
return reference_colors[0] return []
rgbs = list(map(color_to_rgb, reference_colors)) rgbs = list(map(color_to_rgb, reference_colors))
alphas = np.linspace(0, (len(rgbs) - 1), length_of_output) alphas = np.linspace(0, (len(rgbs) - 1), length_of_output)
floors = alphas.astype('int') floors = alphas.astype('int')
@ -80,30 +79,33 @@ def color_gradient(reference_colors, length_of_output):
] ]
def interpolate_color(color1, color2, alpha): def interpolate_color(
color1: ManimColor,
color2: ManimColor,
alpha: float
) -> Color:
rgb = interpolate(color_to_rgb(color1), color_to_rgb(color2), alpha) rgb = interpolate(color_to_rgb(color1), color_to_rgb(color2), alpha)
return rgb_to_color(rgb) return rgb_to_color(rgb)
def average_color(*colors): def average_color(*colors: ManimColor) -> Color:
rgbs = np.array(list(map(color_to_rgb, colors))) rgbs = np.array(list(map(color_to_rgb, colors)))
return rgb_to_color(rgbs.mean(0)) return rgb_to_color(rgbs.mean(0))
def random_bright_color(): def random_color() -> Color:
return Color(rgb=tuple(np.random.random(3)))
def random_bright_color() -> Color:
color = random_color() color = random_color()
curr_rgb = color_to_rgb(color) return average_color(color, Color(WHITE))
new_rgb = interpolate(
curr_rgb, np.ones(len(curr_rgb)), 0.5
)
return Color(rgb=new_rgb)
def random_color(): def get_colormap_list(
return Color(rgb=(random.random() for i in range(3))) map_name: str = "viridis",
n_colors: int = 9
) -> np.ndarray:
def get_colormap_list(map_name="viridis", n_colors=9):
""" """
Options for map_name: Options for map_name:
3b1b_colormap 3b1b_colormap