From f4eb2724c5f7621f85935929fd6135d429539a24 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Tue, 25 Jan 2022 14:04:35 +0800 Subject: [PATCH] refactor SVGMobject.handle_transforms --- manimlib/mobject/svg/svg_mobject.py | 107 +++++++++++++++------------- 1 file changed, 58 insertions(+), 49 deletions(-) diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index 650252ac..dd94a634 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -236,7 +236,6 @@ class SVGMobject(VMobject): return mob def handle_transforms(self, element, mobject): - # TODO, this could use some cleaning... x, y = 0, 0 try: x = self.attribute_to_float(element.getAttribute('x')) @@ -246,56 +245,66 @@ class SVGMobject(VMobject): except Exception: pass - transform = element.getAttribute('transform') + transform_names = [ + "matrix", + "translate", "translateX", "translateY", + "scale", "scaleX", "scaleY", + "rotate", + "skew", "skewX", "skewY" + ] + transform_pattern = re.compile("|".join([x + r"[^)]*\)" for x in transform_names])) + number_pattern = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?") + transforms = transform_pattern.findall(element.getAttribute('transform'))[::-1] - try: # transform matrix - prefix = "matrix(" - suffix = ")" - if not transform.startswith(prefix) or not transform.endswith(suffix): - raise Exception() - transform = transform[len(prefix):-len(suffix)] - transform = string_to_numbers(transform) - transform = np.array(transform).reshape([3, 2]) - x = transform[2][0] - y = -transform[2][1] - matrix = np.identity(self.dim) - matrix[:2, :2] = transform[:2, :] - matrix[1] *= -1 - matrix[:, 1] *= -1 + for transform in transforms: + op_name, op_args = transform.split("(") + op_name = op_name.strip() + op_args = [float(x) for x in number_pattern.findall(op_args)] + + if op_name == "matrix": + self._handle_matrix_transform(mobject, op_name, op_args) + elif op_name.startswith("translate"): + self._handle_translate_transform(mobject, op_name, op_args) + elif op_name.startswith("scale"): + self._handle_scale_transform(mobject, op_name, op_args) + + def _handle_matrix_transform(self, mobject, op_name, op_args): + transform = np.array(op_args).reshape([3, 2]) + x = transform[2][0] + y = -transform[2][1] + matrix = np.identity(self.dim) + matrix[:2, :2] = transform[:2, :] + matrix[1] *= -1 + matrix[:, 1] *= -1 + for mob in mobject.family_members_with_points(): + mob.apply_matrix(matrix.T) + mobject.shift(x * RIGHT + y * UP) - for mob in mobject.family_members_with_points(): - mob.apply_matrix(matrix.T) - mobject.shift(x * RIGHT + y * UP) - except: - pass - - try: # transform scale - prefix = "scale(" - suffix = ")" - if not transform.startswith(prefix) or not transform.endswith(suffix): - raise Exception() - transform = transform[len(prefix):-len(suffix)] - scale_values = string_to_numbers(transform) - if len(scale_values) == 2: - scale_x, scale_y = scale_values - mobject.scale(np.array([scale_x, scale_y, 1]), about_point=ORIGIN) - elif len(scale_values) == 1: - scale = scale_values[0] - mobject.scale(np.array([scale, scale, 1]), about_point=ORIGIN) - except: - pass - - try: # transform translate - prefix = "translate(" - suffix = ")" - if not transform.startswith(prefix) or not transform.endswith(suffix): - raise Exception() - transform = transform[len(prefix):-len(suffix)] - x, y = string_to_numbers(transform) - mobject.shift(x * RIGHT + y * DOWN) - except: - pass - # TODO, ... + def _handle_translate_transform(self, mobject, op_name, op_args): + if op_name.endswith("X"): + x, y = op_args[0], 0 + elif op_name.endswith("Y"): + x, y = 0, op_args[0] + else: + x, y = op_args + mobject.shift(x * RIGHT + y * DOWN) + + def _handle_scale_transform(self, mobject, op_name, op_args): + if op_name.endswith("X"): + sx, sy = op_args[0], 1 + elif op_name.endswith("Y"): + sx, sy = 1, op_args[0] + elif len(op_args) == 2: + sx, sy = op_args + else: + sx = sy = op_args[0] + if sx < 0: + mobject.flip(UP) + sx = -sx + if sy < 0: + mobject.flip(RIGHT) + sy = -sy + mobject.scale(np.array([sx, sy, 1]), about_point=ORIGIN) def flatten(self, input_list): output_list = []