From b817e6f15f2e9c301cbe701de5f32a64a604ab64 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Sat, 17 Dec 2022 17:29:49 -0800 Subject: [PATCH] Tweak type hints in color.py --- manimlib/utils/color.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/manimlib/utils/color.py b/manimlib/utils/color.py index ec864086..e0078bfd 100644 --- a/manimlib/utils/color.py +++ b/manimlib/utils/color.py @@ -13,11 +13,11 @@ from manimlib.utils.iterables import resize_with_interpolation from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Iterable - from manimlib.typing import ManimColor + from typing import Iterable, Sequence + from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array -def color_to_rgb(color: ManimColor) -> np.ndarray: +def color_to_rgb(color: ManimColor) -> Vect3: if isinstance(color, str): return hex_to_rgb(color) elif isinstance(color, Color): @@ -26,26 +26,26 @@ def color_to_rgb(color: ManimColor) -> np.ndarray: raise Exception("Invalid color type") -def color_to_rgba(color: ManimColor, alpha: float = 1.0) -> np.ndarray: +def color_to_rgba(color: ManimColor, alpha: float = 1.0) -> Vect4: return np.array([*color_to_rgb(color), alpha]) -def rgb_to_color(rgb: Iterable[float]) -> Color: +def rgb_to_color(rgb: Vect3 | Sequence[float]) -> Color: try: return Color(rgb=tuple(rgb)) except ValueError: return Color(WHITE) -def rgba_to_color(rgba: Iterable[float]) -> Color: - return rgb_to_color(tuple(rgba)[:3]) +def rgba_to_color(rgba: Vect4) -> Color: + return rgb_to_color(rgba[:3]) -def rgb_to_hex(rgb: Iterable[float]) -> str: +def rgb_to_hex(rgb: Vect3 | Sequence[float]) -> str: return rgb2hex(rgb, force_long=True).upper() -def hex_to_rgb(hex_code: str) -> np.ndarray: +def hex_to_rgb(hex_code: str) -> Vect3: return np.array(hex2rgb(hex_code)) @@ -53,13 +53,13 @@ def invert_color(color: ManimColor) -> Color: return rgb_to_color(1.0 - color_to_rgb(color)) -def color_to_int_rgb(color: ManimColor) -> np.ndarray: +def color_to_int_rgb(color: ManimColor) -> np.ndarray[int, np.dtype[np.uint8]]: return (255 * color_to_rgb(color)).astype('uint8') -def color_to_int_rgba(color: ManimColor, opacity: float = 1.0) -> np.ndarray: +def color_to_int_rgba(color: ManimColor, opacity: float = 1.0) -> np.ndarray[int, np.dtype[np.uint8]]: alpha = int(255 * opacity) - return np.array([*color_to_int_rgb(color), alpha]) + return np.array([*color_to_int_rgb(color), alpha], dtype=np.uint8) def color_gradient( @@ -107,7 +107,7 @@ def random_bright_color() -> Color: def get_colormap_list( map_name: str = "viridis", n_colors: int = 9 -) -> np.ndarray: +) -> Vect3Array: """ Options for map_name: 3b1b_colormap @@ -123,7 +123,7 @@ def get_colormap_list( from matplotlib.cm import get_cmap if map_name == "3b1b_colormap": - rgbs = [color_to_rgb(color) for color in COLORMAP_3B1B] + rgbs = np.array([color_to_rgb(color) for color in COLORMAP_3B1B]) else: rgbs = get_cmap(map_name).colors # Make more general? return resize_with_interpolation(np.array(rgbs), n_colors)