mirror of
https://github.com/3b1b/manim.git
synced 2025-04-13 09:47:07 +00:00
Allow Scalable type to be any FloatArray
This commit is contained in:
parent
6f0020950f
commit
f8b39f2ff1
2 changed files with 26 additions and 19 deletions
|
@ -15,9 +15,7 @@ if TYPE_CHECKING:
|
|||
from typing import Callable, Sequence, TypeVar
|
||||
from manimlib.typing import VectN, FloatArray
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
Scalable = TypeVar("Scalable", float, VectN)
|
||||
Scalable = TypeVar("Scalable", float, FloatArray)
|
||||
|
||||
|
||||
CLOSED_THRESHOLD = 0.001
|
||||
|
|
|
@ -5,26 +5,34 @@ import math
|
|||
|
||||
import numpy as np
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from typing import Callable, TypeVar
|
||||
from manimlib.typing import FloatArray
|
||||
|
||||
def sigmoid(x):
|
||||
Scalable = TypeVar("Scalable", float, FloatArray)
|
||||
|
||||
|
||||
|
||||
def sigmoid(x: float | FloatArray):
|
||||
return 1.0 / (1 + np.exp(-x))
|
||||
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def choose(n, k):
|
||||
def choose(n: int, k: int) -> int:
|
||||
return math.comb(n, k)
|
||||
|
||||
|
||||
def gen_choose(n, r):
|
||||
return np.prod(np.arange(n, n - r, -1)) / math.factorial(r)
|
||||
def gen_choose(n: int, r: int) -> int:
|
||||
return int(np.prod(range(n, n - r, -1)) / math.factorial(r))
|
||||
|
||||
|
||||
def get_num_args(function):
|
||||
def get_num_args(function: Callable) -> int:
|
||||
return len(get_parameters(function))
|
||||
|
||||
|
||||
def get_parameters(function):
|
||||
return inspect.signature(function).parameters
|
||||
def get_parameters(function: Callable) -> list:
|
||||
return list(inspect.signature(function).parameters.keys())
|
||||
|
||||
# Just to have a less heavyweight name for this extremely common operation
|
||||
#
|
||||
|
@ -33,7 +41,7 @@ def get_parameters(function):
|
|||
# but for now, we just allow the option to handle indeterminate 0/0.
|
||||
|
||||
|
||||
def clip(a, min_a, max_a):
|
||||
def clip(a: float, min_a: float, max_a: float) -> float:
|
||||
if a < min_a:
|
||||
return min_a
|
||||
elif a > max_a:
|
||||
|
@ -41,7 +49,7 @@ def clip(a, min_a, max_a):
|
|||
return a
|
||||
|
||||
|
||||
def fdiv(a, b, zero_over_zero_value=None):
|
||||
def fdiv(a: Scalable, b: Scalable, zero_over_zero_value: Scalable | None = None) -> Scalable:
|
||||
if zero_over_zero_value is not None:
|
||||
out = np.full_like(a, zero_over_zero_value)
|
||||
where = np.logical_or(a != 0, b != 0)
|
||||
|
@ -52,15 +60,15 @@ def fdiv(a, b, zero_over_zero_value=None):
|
|||
return np.true_divide(a, b, out=out, where=where)
|
||||
|
||||
|
||||
def binary_search(function,
|
||||
target,
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
tolerance=1e-4):
|
||||
def binary_search(function: Callable[[float], float],
|
||||
target: float,
|
||||
lower_bound: float,
|
||||
upper_bound: float,
|
||||
tolerance:float = 1e-4) -> float | None:
|
||||
lh = lower_bound
|
||||
rh = upper_bound
|
||||
mh = (lh + rh) / 2
|
||||
while abs(rh - lh) > tolerance:
|
||||
mh = np.mean([lh, rh])
|
||||
lx, mx, rx = [function(h) for h in (lh, mh, rh)]
|
||||
if lx == target:
|
||||
return lx
|
||||
|
@ -76,10 +84,11 @@ def binary_search(function,
|
|||
lh, rh = rh, lh
|
||||
else:
|
||||
return None
|
||||
mh = (lh + rh) / 2
|
||||
return mh
|
||||
|
||||
|
||||
def hash_string(string):
|
||||
def hash_string(string: str) -> str:
|
||||
# Truncating at 16 bytes for cleanliness
|
||||
hasher = hashlib.sha256(string.encode())
|
||||
return hasher.hexdigest()[:16]
|
||||
|
|
Loading…
Add table
Reference in a new issue