diff --git a/utils/space_ops.py b/utils/space_ops.py index 3b65b4cd..0bc6b744 100644 --- a/utils/space_ops.py +++ b/utils/space_ops.py @@ -100,12 +100,15 @@ def project_along_vector(point, vector): return np.dot(point, matrix.T) -def normalize(vect): +def normalize(vect, fall_back=None): norm = get_norm(vect) if norm > 0: return vect / norm else: - return np.zeros(len(vect)) + if fall_back is not None: + return fall_back + else: + return np.zeros(len(vect)) def cross(v1, v2):