Faster VMobject.get_arc_length

This commit is contained in:
Grant Sanderson 2023-02-16 15:02:30 -08:00
parent 3a05352f73
commit c372ef4aaa

View file

@ -17,6 +17,7 @@ from manimlib.utils.bezier import bezier
from manimlib.utils.bezier import get_quadratic_approximation_of_cubic
from manimlib.utils.bezier import approx_smooth_quadratic_bezier_handles
from manimlib.utils.bezier import smooth_quadratic_path
from manimlib.utils.bezier import interpolate
from manimlib.utils.bezier import integer_interpolate
from manimlib.utils.bezier import inverse_interpolate
from manimlib.utils.bezier import find_intersection
@ -38,6 +39,7 @@ from manimlib.utils.space_ops import get_unit_normal
from manimlib.utils.space_ops import line_intersects_path
from manimlib.utils.space_ops import midpoint
from manimlib.utils.space_ops import normalize_along_axis
from manimlib.utils.space_ops import poly_line_length
from manimlib.utils.space_ops import z_to_vector
from manimlib.shader_wrapper import ShaderWrapper
from manimlib.shader_wrapper import FillShaderWrapper
@ -814,14 +816,16 @@ class VMobject(Mobject):
return np.vstack(new_points)
def get_arc_length(self, n_sample_points: int | None = None) -> float:
if n_sample_points is None:
n_sample_points = 4 * self.get_num_curves() + 1
points = np.array([
self.point_from_proportion(a)
for a in np.linspace(0, 1, n_sample_points)
])
diffs = points[1:] - points[:-1]
return sum(map(get_norm, diffs))
if n_sample_points is not None:
points = np.array([
self.quick_point_from_proportion(a)
for a in np.linspace(0, 1, n_sample_points)
])
return poly_line_length(points)
points = self.get_points()
inner_len = poly_line_length(points[::2])
outer_len = poly_line_length(points)
return interpolate(inner_len, outer_len, 1 / 3)
def get_area_vector(self) -> Vect3:
# Returns a vector whose length is the area bound by