diff --git a/manimlib/mobject/probability.py b/manimlib/mobject/probability.py index 69a71069..fb7a76b7 100644 --- a/manimlib/mobject/probability.py +++ b/manimlib/mobject/probability.py @@ -8,6 +8,7 @@ from manimlib.mobject.svg.tex_mobject import TexText from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_gradient from manimlib.utils.iterables import listify +from numpy import ndarray EPSILON = 0.0001 @@ -149,7 +150,9 @@ class BarChart(VGroup): "height": 4, "width": 6, "n_ticks": 4, + "include_x_ticks": False, "tick_width": 0.2, + "tick_width_x": 0.15, "label_y_axis": True, "y_axis_label_height": 0.25, "max_value": 1, @@ -165,6 +168,7 @@ class BarChart(VGroup): if self.max_value is None: self.max_value = max(values) + self.n_ticks_x = len(values) self.add_axes() self.add_bars(values) self.center() @@ -182,6 +186,17 @@ class BarChart(VGroup): ticks.add(tick) y_axis.add(ticks) + if self.include_x_ticks == True: + ticks1 = VGroup() + widths = np.linspace(0, self.width, self.n_ticks_x + 1) + label_values = np.linspace(0, len(self.bar_names), self.n_ticks_x + 1) + for x, value in zip(widths, label_values): + tick1 = Line(UP*0.05, DOWN*0.05) + tick1.set_width(self.tick_width_x) + tick1.move_to(x * RIGHT) + ticks.add(tick1) + x_axis.add(ticks1) + self.add(x_axis, y_axis) self.x_axis, self.y_axis = x_axis, y_axis @@ -196,7 +211,7 @@ class BarChart(VGroup): self.add(labels) def add_bars(self, values): - buff = float(self.width) / (2 * len(values) + 1) + buff = float(self.width) / (2 * len(values)) bars = VGroup() for i, value in enumerate(values): bar = Rectangle( @@ -205,7 +220,7 @@ class BarChart(VGroup): stroke_width=self.bar_stroke_width, fill_opacity=self.bar_fill_opacity, ) - bar.move_to((2 * i + 1) * buff * RIGHT, DOWN + LEFT) + bar.move_to((2 * i + 0.5) * buff * RIGHT, DOWN + LEFT * 5) bars.add(bar) bars.set_color_by_gradient(*self.bar_colors)