2018-12-24 12:37:51 -08:00
|
|
|
import inspect
|
2018-03-30 18:19:23 -07:00
|
|
|
import numpy as np
|
2022-02-13 15:16:16 -08:00
|
|
|
import math
|
2022-01-26 13:03:14 +08:00
|
|
|
from functools import lru_cache
|
2018-03-30 18:19:23 -07:00
|
|
|
|
2018-04-06 13:58:59 -07:00
|
|
|
|
2018-03-30 18:19:23 -07:00
|
|
|
def sigmoid(x):
    """Return the logistic function 1 / (1 + e^(-x)).

    Works elementwise when x is a NumPy array, since it is built from
    np.exp and ordinary arithmetic.
    """
    denominator = np.exp(-x) + 1
    return 1.0 / denominator
|
|
|
|
|
2018-03-30 18:19:23 -07:00
|
|
|
|
2022-01-26 13:03:14 +08:00
|
|
|
@lru_cache(maxsize=10)
def choose(n, k):
    """Return the binomial coefficient C(n, k) via math.comb.

    Results for the 10 most recently used (n, k) pairs are memoized.
    """
    binomial = math.comb(n, k)
    return binomial
|
|
|
|
|
|
|
|
|
|
|
|
def gen_choose(n, r):
    """Return the generalized binomial coefficient n(n-1)...(n-r+1) / r!.

    n may be any real number; r is a non-negative integer (it is passed to
    range and math.factorial). The result is a float.
    """
    # NumPy integer scalars would overflow int64 just like the old np.arange
    # approach; convert them to Python ints so the product is exact.
    if isinstance(n, np.integer):
        n = int(n)
    # Python-level multiplication: exact for int n (no silent int64
    # wrap-around), ordinary float arithmetic for real n.
    numerator = math.prod(n - i for i in range(r))
    return numerator / math.factorial(r)
|
2018-03-30 18:19:23 -07:00
|
|
|
|
2018-08-12 12:17:32 -07:00
|
|
|
|
|
|
|
def get_num_args(function):
    """Return how many parameters `function`'s signature declares.

    Every parameter reported by get_parameters counts, whatever its kind.
    """
    parameters = get_parameters(function)
    return len(parameters)
|
|
|
|
|
|
|
|
|
|
|
|
def get_parameters(function):
    """Return the ordered name -> inspect.Parameter mapping of `function`'s signature."""
    signature = inspect.signature(function)
    return signature.parameters
|
2018-08-12 12:17:32 -07:00
|
|
|
|
2018-03-30 18:19:23 -07:00
|
|
|
# fdiv (defined below, after clip) exists just to have a less heavyweight name
# for this extremely common operation: elementwise true division.
#
# We may wish to have more fine-grained control over division-by-zero behavior
# in the future (separate specifiable values for 0/0 and x/0 with x != 0),
# but for now, we only allow the option to handle the indeterminate form 0/0.
|
2018-04-06 13:58:59 -07:00
|
|
|
|
|
|
|
|
2020-02-18 22:27:13 -08:00
|
|
|
def clip(a, min_a, max_a):
    """Clamp a scalar `a` to the range [min_a, max_a].

    The lower bound is checked first, so if min_a > max_a and a < min_a the
    result is min_a.
    """
    return min_a if a < min_a else (max_a if a > max_a else a)
|
|
|
|
|
|
|
|
|
2018-04-06 13:58:59 -07:00
|
|
|
def fdiv(a, b, zero_over_zero_value=None):
    """Elementwise true division a / b, optionally defining 0/0.

    Args:
        a, b: Numerators and denominators (scalars or array-likes; they are
            broadcast against each other as in np.true_divide).
        zero_over_zero_value: If not None, positions where both a and b are 0
            get this value instead of being computed (so no 0/0 is
            evaluated). x/0 with x != 0 is still computed normally.

    Returns:
        The quotient. With zero_over_zero_value set, the result comes from
        np.where and is always an ndarray (0-d for scalar inputs).
    """
    if zero_over_zero_value is None:
        return np.true_divide(a, b)

    is_zero_over_zero = np.logical_and(np.equal(a, 0), np.equal(b, 0))
    # Let NumPy allocate the output so it gets the broadcast shape and the
    # float dtype true division produces (a buffer shaped/typed like `a`
    # fails for integer inputs and for broadcasting). Positions skipped by
    # `where` are left uninitialized and are overwritten just below.
    quotient = np.true_divide(a, b, where=np.logical_not(is_zero_over_zero))
    return np.where(is_zero_over_zero, zero_over_zero_value, quotient)
|
2019-02-06 21:16:26 -08:00
|
|
|
|
|
|
|
|
|
|
|
def binary_search(function,
                  target,
                  lower_bound,
                  upper_bound,
                  tolerance=1e-4):
    """Find an input h between lower_bound and upper_bound with function(h) ~= target.

    Bisection that works when `function` is increasing or decreasing on the
    interval (presumably it is monotonic there; the search only inspects the
    endpoints and midpoint each step).

    Args:
        function: Callable of one number, evaluated at the bracket ends and
            midpoint on every iteration.
        target: Output value being searched for.
        lower_bound, upper_bound: Initial bracket of inputs.
        tolerance: Stop once the bracket is no wider than this.

    Returns:
        An input value (not function's output): a bracket endpoint if
        function hits target exactly there, otherwise the midpoint of the
        final bracket. Returns None if, during the search, target is not
        between function(lh) and function(rh). If the initial bracket is
        already within tolerance, its midpoint is returned without
        evaluating function.
    """
    lh = lower_bound
    rh = upper_bound
    # Initialise before the loop so a bracket already within tolerance
    # returns its midpoint instead of raising UnboundLocalError.
    mh = np.mean([lh, rh])
    while abs(rh - lh) > tolerance:
        mh = np.mean([lh, rh])
        lx, mx, rx = [function(h) for h in (lh, mh, rh)]
        # Return the input that produced the target, not the output value.
        if lx == target:
            return lh
        if rx == target:
            return rh

        if lx <= target and rx >= target:
            # Invariant: function(lh) <= target <= function(rh).
            if mx > target:
                rh = mh
            else:
                lh = mh
        elif lx > target and rx < target:
            # Decreasing on this bracket: swap the ends so the invariant
            # above holds on the next pass.
            lh, rh = rh, lh
        else:
            # target is not bracketed by the endpoint values.
            return None

    return mh
|