mirror of
https://github.com/3b1b/manim.git
synced 2025-08-21 05:44:04 +00:00
Refactor svg reading
This commit is contained in:
parent
8c07fcca24
commit
ccef2485b2
1 changed files with 72 additions and 109 deletions
|
@ -34,9 +34,9 @@ class SVGMobject(VMobject):
|
||||||
# Must be filled in in a subclass, or when called
|
# Must be filled in in a subclass, or when called
|
||||||
"file_name": None,
|
"file_name": None,
|
||||||
"unpack_groups": True, # if False, creates a hierarchy of VGroups
|
"unpack_groups": True, # if False, creates a hierarchy of VGroups
|
||||||
|
# TODO, style components should be read in, not defaulted
|
||||||
"stroke_width": DEFAULT_STROKE_WIDTH,
|
"stroke_width": DEFAULT_STROKE_WIDTH,
|
||||||
"fill_opacity": 1.0,
|
"fill_opacity": 1.0,
|
||||||
# "fill_color" : LIGHT_GREY,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, file_name=None, **kwargs):
|
def __init__(self, file_name=None, **kwargs):
|
||||||
|
@ -47,24 +47,25 @@ class SVGMobject(VMobject):
|
||||||
self.move_into_position()
|
self.move_into_position()
|
||||||
|
|
||||||
def ensure_valid_file(self):
|
def ensure_valid_file(self):
|
||||||
if self.file_name is None:
|
file_name = self.file_name
|
||||||
|
if file_name is None:
|
||||||
raise Exception("Must specify file for SVGMobject")
|
raise Exception("Must specify file for SVGMobject")
|
||||||
possible_paths = [
|
possible_paths = [
|
||||||
os.path.join(os.path.join("assets", "svg_images"), self.file_name),
|
os.path.join(os.path.join("assets", "svg_images"), file_name),
|
||||||
os.path.join(os.path.join("assets", "svg_images"), self.file_name + ".svg"),
|
os.path.join(os.path.join("assets", "svg_images"), file_name + ".svg"),
|
||||||
os.path.join(os.path.join("assets", "svg_images"), self.file_name + ".xdv"),
|
os.path.join(os.path.join("assets", "svg_images"), file_name + ".xdv"),
|
||||||
self.file_name,
|
file_name,
|
||||||
]
|
]
|
||||||
for path in possible_paths:
|
for path in possible_paths:
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
self.file_path = path
|
self.file_path = path
|
||||||
return
|
return
|
||||||
raise IOError("No file matching %s in image directory" %
|
raise IOError(f"No file matching {file_name} in image directory")
|
||||||
self.file_name)
|
|
||||||
|
|
||||||
def generate_points(self):
|
def generate_points(self):
|
||||||
doc = minidom.parse(self.file_path)
|
doc = minidom.parse(self.file_path)
|
||||||
self.ref_to_element = {}
|
self.ref_to_element = {}
|
||||||
|
|
||||||
for svg in doc.getElementsByTagName("svg"):
|
for svg in doc.getElementsByTagName("svg"):
|
||||||
mobjects = self.get_mobjects_from(svg)
|
mobjects = self.get_mobjects_from(svg)
|
||||||
if self.unpack_groups:
|
if self.unpack_groups:
|
||||||
|
@ -122,7 +123,7 @@ class SVGMobject(VMobject):
|
||||||
# Remove initial "#" character
|
# Remove initial "#" character
|
||||||
ref = use_element.getAttribute("xlink:href")[1:]
|
ref = use_element.getAttribute("xlink:href")[1:]
|
||||||
if ref not in self.ref_to_element:
|
if ref not in self.ref_to_element:
|
||||||
warnings.warn("%s not recognized" % ref)
|
warnings.warn(f"{ref} not recognized")
|
||||||
return VGroup()
|
return VGroup()
|
||||||
return self.get_mobjects_from(
|
return self.get_mobjects_from(
|
||||||
self.ref_to_element[ref]
|
self.ref_to_element[ref]
|
||||||
|
@ -136,15 +137,12 @@ class SVGMobject(VMobject):
|
||||||
return float(stripped_attr)
|
return float(stripped_attr)
|
||||||
|
|
||||||
def polygon_to_mobject(self, polygon_element):
|
def polygon_to_mobject(self, polygon_element):
|
||||||
# TODO, This seems hacky...
|
|
||||||
path_string = polygon_element.getAttribute("points")
|
path_string = polygon_element.getAttribute("points")
|
||||||
for digit in string.digits:
|
for digit in string.digits:
|
||||||
path_string = path_string.replace(" " + digit, " L" + digit)
|
path_string = path_string.replace(f" {digit}", f"{digit} L")
|
||||||
path_string = "M" + path_string
|
path_string = "M" + path_string
|
||||||
return self.path_string_to_mobject(path_string)
|
return self.path_string_to_mobject(path_string)
|
||||||
|
|
||||||
# <circle class="st1" cx="143.8" cy="268" r="22.6"/>
|
|
||||||
|
|
||||||
def circle_to_mobject(self, circle_element):
|
def circle_to_mobject(self, circle_element):
|
||||||
x, y, r = [
|
x, y, r = [
|
||||||
self.attribute_to_float(
|
self.attribute_to_float(
|
||||||
|
@ -321,111 +319,76 @@ class VMobjectFromSVGPathstring(VMobject):
|
||||||
digest_locals(self)
|
digest_locals(self)
|
||||||
VMobject.__init__(self, **kwargs)
|
VMobject.__init__(self, **kwargs)
|
||||||
|
|
||||||
def get_path_commands(self):
|
|
||||||
result = [
|
|
||||||
"M", # moveto
|
|
||||||
"L", # lineto
|
|
||||||
"H", # horizontal lineto
|
|
||||||
"V", # vertical lineto
|
|
||||||
"C", # curveto
|
|
||||||
"S", # smooth curveto
|
|
||||||
"Q", # quadratic Bezier curve
|
|
||||||
"T", # smooth quadratic Bezier curveto
|
|
||||||
"A", # elliptical Arc
|
|
||||||
"Z", # closepath
|
|
||||||
]
|
|
||||||
result += [s.lower() for s in result]
|
|
||||||
return result
|
|
||||||
|
|
||||||
def generate_points(self):
|
def generate_points(self):
|
||||||
pattern = "[%s]" % ("".join(self.get_path_commands()))
|
self.relative_point = ORIGIN
|
||||||
pairs = list(zip(
|
for command, coord_string in self.get_commands_and_coord_strings():
|
||||||
|
new_points = self.string_to_points(command, coord_string)
|
||||||
|
self.handle_command(command, new_points)
|
||||||
|
# SVG treats y-coordinate differently
|
||||||
|
self.stretch(-1, 1, about_point=ORIGIN)
|
||||||
|
|
||||||
|
def get_commands_and_coord_strings(self):
|
||||||
|
all_commands = list(self.get_command_to_function_map().keys())
|
||||||
|
all_commands += [c.lower() for c in all_commands]
|
||||||
|
pattern = "[{}]".format("".join(all_commands))
|
||||||
|
return zip(
|
||||||
re.findall(pattern, self.path_string),
|
re.findall(pattern, self.path_string),
|
||||||
re.split(pattern, self.path_string)[1:]
|
re.split(pattern, self.path_string)[1:]
|
||||||
))
|
)
|
||||||
# Which mobject should new points be added to
|
|
||||||
self = self
|
|
||||||
for command, coord_string in pairs:
|
|
||||||
self.handle_command(command, coord_string)
|
|
||||||
# people treat y-coordinate differently
|
|
||||||
self.rotate(np.pi, RIGHT, about_point=ORIGIN)
|
|
||||||
|
|
||||||
def handle_command(self, command, coord_string):
|
def handle_command(self, command, new_points):
|
||||||
isLower = command.islower()
|
if command.islower(): # Treat it as a relative command
|
||||||
command = command.upper()
|
new_points += self.relative_point
|
||||||
# new_points are the points that will be added to the curr_points
|
|
||||||
# list. This variable may get modified in the conditionals below.
|
|
||||||
points = self.points
|
|
||||||
new_points = self.string_to_points(coord_string)
|
|
||||||
|
|
||||||
if isLower and len(points) > 0:
|
func, n_points = self.command_to_function(command)
|
||||||
new_points += points[-1]
|
func(*new_points[:n_points])
|
||||||
|
leftover_points = new_points[n_points:]
|
||||||
|
|
||||||
if command == "M": # moveto
|
# Recursively handle the rest of the points
|
||||||
self.start_new_path(new_points[0])
|
if len(leftover_points) > 0:
|
||||||
if len(new_points) <= 1:
|
if command.upper() == "M":
|
||||||
return
|
command = "l" # Treat following points as relative line coordinates
|
||||||
|
self.handle_command(command, leftover_points)
|
||||||
|
else:
|
||||||
|
# Command is over, reset for future relative commands
|
||||||
|
self.relative_point = self.points[-1]
|
||||||
|
|
||||||
# Draw relative line-to values.
|
def string_to_points(self, command, coord_string):
|
||||||
points = self.points
|
|
||||||
new_points = new_points[1:]
|
|
||||||
command = "L"
|
|
||||||
|
|
||||||
for p in new_points:
|
|
||||||
if isLower:
|
|
||||||
# Treat everything as relative line-to until empty
|
|
||||||
p[0] += self.points[-1, 0]
|
|
||||||
p[1] += self.points[-1, 1]
|
|
||||||
self.add_line_to(p)
|
|
||||||
return
|
|
||||||
|
|
||||||
elif command in ["L", "H", "V"]: # lineto
|
|
||||||
if command == "H":
|
|
||||||
new_points[0, 1] = points[-1, 1]
|
|
||||||
elif command == "V":
|
|
||||||
if isLower:
|
|
||||||
new_points[0, 0] -= points[-1, 0]
|
|
||||||
new_points[0, 0] += points[-1, 1]
|
|
||||||
new_points[0, 1] = new_points[0, 0]
|
|
||||||
new_points[0, 0] = points[-1, 0]
|
|
||||||
self.add_line_to(new_points[0])
|
|
||||||
return
|
|
||||||
|
|
||||||
if command == "C": # curveto
|
|
||||||
pass # Yay! No action required
|
|
||||||
elif command in ["S", "T"]: # smooth curveto
|
|
||||||
self.add_smooth_curve_to(*new_points)
|
|
||||||
# handle1 = points[-1] + (points[-1] - points[-2])
|
|
||||||
# new_points = np.append([handle1], new_points, axis=0)
|
|
||||||
return
|
|
||||||
elif command == "Q": # quadratic Bezier curve
|
|
||||||
# TODO, this is a suboptimal approximation
|
|
||||||
new_points = np.append([new_points[0]], new_points, axis=0)
|
|
||||||
elif command == "A": # elliptical Arc
|
|
||||||
raise Exception("Not implemented")
|
|
||||||
elif command == "Z": # closepath
|
|
||||||
return
|
|
||||||
|
|
||||||
# Add first three points
|
|
||||||
self.add_cubic_bezier_curve_to(*new_points[0:3])
|
|
||||||
|
|
||||||
# Handle situations where there's multiple relative control points
|
|
||||||
if len(new_points) > 3:
|
|
||||||
# Add subsequent offset points relatively.
|
|
||||||
for i in range(3, len(new_points), 3):
|
|
||||||
if isLower:
|
|
||||||
new_points[i:i + 3] -= points[-1]
|
|
||||||
new_points[i:i + 3] += new_points[i - 1]
|
|
||||||
self.add_cubic_bezier_curve_to(*new_points[i:i+3])
|
|
||||||
|
|
||||||
def string_to_points(self, coord_string):
|
|
||||||
numbers = string_to_numbers(coord_string)
|
numbers = string_to_numbers(coord_string)
|
||||||
if len(numbers) % 2 == 1:
|
if command.upper() in ["H", "V"]:
|
||||||
numbers.append(0)
|
i = {"H": 0, "V": 1}[command.upper()]
|
||||||
num_points = len(numbers) // 2
|
xy = np.zeros((len(numbers), 2))
|
||||||
result = np.zeros((num_points, self.dim))
|
xy[:, i] = numbers
|
||||||
result[:, :2] = np.array(numbers).reshape((num_points, 2))
|
if command.isupper():
|
||||||
|
xy[:, 1 - i] = self.relative_point[1 - i]
|
||||||
|
elif command.upper() == "A":
|
||||||
|
raise Exception("Not implemented")
|
||||||
|
else:
|
||||||
|
xy = np.array(numbers).reshape((len(numbers) // 2, 2))
|
||||||
|
result = np.zeros((xy.shape[0], self.dim))
|
||||||
|
result[:, :2] = xy
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def command_to_function(self, command):
|
||||||
|
return self.get_command_to_function_map()[command.upper()]
|
||||||
|
|
||||||
|
def get_command_to_function_map(self):
|
||||||
|
"""
|
||||||
|
Associates svg command to VMobject function, and
|
||||||
|
the number of arguments it takes in
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"M": (self.start_new_path, 1),
|
||||||
|
"L": (self.add_line_to, 1),
|
||||||
|
"H": (self.add_line_to, 1),
|
||||||
|
"V": (self.add_line_to, 1),
|
||||||
|
"C": (self.add_cubic_bezier_curve_to, 3),
|
||||||
|
"S": (self.add_smooth_cubic_curve_to, 2),
|
||||||
|
"Q": (self.add_quadratic_bezier_curve_to, 2),
|
||||||
|
"T": (self.add_smooth_curve_to, 1),
|
||||||
|
"A": (self.add_quadratic_bezier_curve_to, 2), # TODO
|
||||||
|
"Z": (lambda: self.add_line_to(self.points[0]), 0),
|
||||||
|
}
|
||||||
|
|
||||||
def get_original_path_string(self):
|
def get_original_path_string(self):
|
||||||
return self.path_string
|
return self.path_string
|
||||||
|
|
Loading…
Add table
Reference in a new issue