diff --git a/manimlib/mobject/coordinate_systems.py b/manimlib/mobject/coordinate_systems.py index 3ebc4377..e97d0ec4 100644 --- a/manimlib/mobject/coordinate_systems.py +++ b/manimlib/mobject/coordinate_systems.py @@ -4,8 +4,9 @@ from abc import ABC, abstractmethod import numbers import numpy as np +import itertools as it -from manimlib.constants import BLACK, BLUE, BLUE_D, GREEN, GREY_A, WHITE, RED +from manimlib.constants import BLACK, BLUE, BLUE_D, BLUE_E, GREEN, GREY_A, WHITE, RED from manimlib.constants import DEGREES, PI from manimlib.constants import DL, UL, DOWN, DR, LEFT, ORIGIN, OUT, RIGHT, UP from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH @@ -18,14 +19,16 @@ from manimlib.mobject.geometry import Line from manimlib.mobject.geometry import Rectangle from manimlib.mobject.number_line import NumberLine from manimlib.mobject.svg.tex_mobject import Tex -from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.dot_cloud import DotCloud +from manimlib.mobject.types.surface import ParametricSurface +from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import merge_dicts_recursively from manimlib.utils.simple_functions import binary_search from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import rotate_vector +from manimlib.utils.space_ops import normalize from typing import TYPE_CHECKING @@ -234,16 +237,27 @@ class CoordinateSystem(ABC): """ return self.input_to_graph_point(x, graph) - def bind_graph_to_func(self, graph, func, jagged=False): + def bind_graph_to_func(self, graph, func, jagged=False, get_discontinuities=None): """ Use for graphing functions which might change over time, or change with conditions """ - graph.x_values = [self.x_axis.p2n(p) for p in graph.get_points()] - graph.add_updater(lambda g: g.set_points([self.c2p(x, func(x)) for x in g.x_values])) - if jagged: - graph.add_updater(lambda g: g.make_jagged()) - else: + x_values = [self.x_axis.p2n(p) for p in graph.get_points()] + + def get_x_values(): + if get_discontinuities: + ds = get_discontinuities() + ep = 1e-6 + added_xs = it.chain(*((d - ep, d + ep) for d in ds)) + return sorted([*x_values, *added_xs])[:len(x_values)] + else: + return x_values + + graph.add_updater(lambda g: g.set_points_as_corners([ + self.c2p(x, func(x)) + for x in get_x_values() + ])) + if not jagged: graph.add_updater(lambda g: g.make_approximately_smooth()) return graph