Refactor command handling in svg_mobject.py

This commit is contained in:
Michael W 2021-10-24 22:30:18 +08:00 committed by GitHub
parent c82f60e29e
commit b6f9da87d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -345,10 +345,7 @@ class VMobjectFromSVGPathstring(VMobject):
self.triangulation = np.load(tris_filepath) self.triangulation = np.load(tris_filepath)
self.needs_new_triangulation = False self.needs_new_triangulation = False
else: else:
self.relative_point = np.array(ORIGIN) self.handle_commands()
for command, coord_string in self.get_commands_and_coord_strings():
new_points = self.string_to_points(command, coord_string)
self.handle_command(command, new_points)
if self.should_subdivide_sharp_curves: if self.should_subdivide_sharp_curves:
# For a healthy triangulation later # For a healthy triangulation later
self.subdivide_sharp_curves() self.subdivide_sharp_curves()
@ -370,64 +367,40 @@ class VMobjectFromSVGPathstring(VMobject):
re.split(pattern, self.path_string)[1:] re.split(pattern, self.path_string)[1:]
) )
def handle_command(self, command, new_points): def handle_commands(self):
relative_point = ORIGIN
for command, coord_string in self.get_commands_and_coord_strings():
func, number_types_str = self.command_to_function(command)
upper_command = command.upper()
if upper_command == "Z":
func() # `close_path` takes no arguments
continue
number_types = np.array(list(number_types_str))
n_numbers = len(number_types_str)
number_groups = np.array(string_to_numbers(coord_string)).reshape((-1, n_numbers))
for numbers in number_groups:
if command.islower(): if command.islower():
# Treat it as a relative command # Treat it as a relative command
if command == "a": numbers[number_types == "x"] += relative_point[0]
# Only the last `self.dim` columns refer to points numbers[number_types == "y"] += relative_point[1]
new_points[:, -self.dim:] += self.relative_point
else:
new_points += self.relative_point
func, n_points = self.command_to_function(command) if upper_command == "A":
command_points = new_points[:n_points] args = [*numbers[:5], np.array([*numbers[5:7], 0.0])]
if command.upper() == "A": elif upper_command == "H":
func(*command_points[0][:-self.dim], np.array(command_points[0][-self.dim:])) args = [np.array([numbers[0], relative_point[1], 0.0])]
elif upper_command == "V":
args = [np.array([relative_point[0], numbers[0], 0.0])]
else: else:
func(*command_points) args = list(np.hstack((
leftover_points = new_points[n_points:] numbers.reshape((-1, 2)), np.zeros((n_numbers // 2, 1))
)))
func(*args)
relative_point = self.get_last_point()
# Recursively handle the rest of the points
if len(leftover_points) > 0:
if command.upper() == "M":
# Treat following points as relative line coordinates
command = "l"
if command.islower():
if command == "a":
leftover_points[:, -self.dim:] -= self.relative_point
else:
leftover_points -= self.relative_point
self.relative_point = self.get_last_point()
self.handle_command(command, leftover_points)
else:
# Command is over, reset for future relative commands
self.relative_point = self.get_last_point()
def string_to_points(self, command, coord_string):
numbers = string_to_numbers(coord_string)
if command.upper() == "A":
# Only the last `self.dim` columns refer to points
# Each "point" returned here has a size of `(5 + self.dim)`
params = np.array(numbers).reshape((-1, 7))
result = np.zeros((params.shape[0], 5 + self.dim))
result[:, :7] = params
return result
if command.upper() in ["H", "V"]:
i = {"H": 0, "V": 1}[command.upper()]
xy = np.zeros((len(numbers), 2))
xy[:, i] = numbers
if command.isupper():
xy[:, 1 - i] = self.relative_point[1 - i]
else:
xy = np.array(numbers).reshape((-1, 2))
result = np.zeros((xy.shape[0], self.dim))
result[:, :2] = xy
return result
def add_elliptical_arc_to(self, rx, ry, x_axis_rotation, large_arc_flag, sweep_flag, point): def add_elliptical_arc_to(self, rx, ry, x_axis_rotation, large_arc_flag, sweep_flag, point):
"""
In fact, this method only suits 2d VMobjects.
"""
def close_to_zero(a, threshold=1e-5): def close_to_zero(a, threshold=1e-5):
return abs(a) < threshold return abs(a) < threshold
@ -536,19 +509,19 @@ class VMobjectFromSVGPathstring(VMobject):
def get_command_to_function_map(self): def get_command_to_function_map(self):
""" """
Associates svg command to VMobject function, and Associates svg command to VMobject function, and
the number of arguments it takes in the types of arguments it takes in
""" """
return { return {
"M": (self.start_new_path, 1), "M": (self.start_new_path, "xy"),
"L": (self.add_line_to, 1), "L": (self.add_line_to, "xy"),
"H": (self.add_line_to, 1), "H": (self.add_line_to, "x"),
"V": (self.add_line_to, 1), "V": (self.add_line_to, "y"),
"C": (self.add_cubic_bezier_curve_to, 3), "C": (self.add_cubic_bezier_curve_to, "xyxyxy"),
"S": (self.add_smooth_cubic_curve_to, 2), "S": (self.add_smooth_cubic_curve_to, "xyxy"),
"Q": (self.add_quadratic_bezier_curve_to, 2), "Q": (self.add_quadratic_bezier_curve_to, "xyxy"),
"T": (self.add_smooth_curve_to, 1), "T": (self.add_smooth_curve_to, "xy"),
"A": (self.add_elliptical_arc_to, 1), "A": (self.add_elliptical_arc_to, "-----xy"),
"Z": (self.close_path, 0), "Z": (self.close_path, ""),
} }
def get_original_path_string(self): def get_original_path_string(self):