From fe5438a841327ffaa1fbe1e46b3081bcbf997bb3 Mon Sep 17 00:00:00 2001 From: Ben Hambrecht Date: Fri, 13 Apr 2018 12:24:39 +0200 Subject: [PATCH] expanded functionality of histograms, esp. for updating --- active_projects/eop/histograms.py | 73 ++++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 21 deletions(-) diff --git a/active_projects/eop/histograms.py b/active_projects/eop/histograms.py index 2f5a847e..97ed6445 100644 --- a/active_projects/eop/histograms.py +++ b/active_projects/eop/histograms.py @@ -16,7 +16,8 @@ class Histogram(VMobject): "end_color" : BLUE, "x_scale" : 1.0, "y_scale" : 1.0, - "x_labels" : "auto", + "x_labels" : "auto", # widths, mids, auto, none, [...] + "y_labels" : "auto", # auto, none, [...] "x_min" : 0 } @@ -31,18 +32,29 @@ class Histogram(VMobject): elif mode == "posts" and len(x_values) != len(y_values) + 1: raise Exception("Array lengths do not match up!") - # preliminaries - self.y_values = np.array(y_values) - if mode == "widths": - self.widths = x_values + self.y_values = y_values + self.x_values = x_values + self.mode = mode + self.process_values() + + VMobject.__init__(self, **kwargs) + + + def process_values(self): + + # preliminaries + self.y_values = np.array(self.y_values) + + if self.mode == "widths": + self.widths = self.x_values self.posts = np.cumsum(self.widths) self.posts = np.insert(self.posts, 0, 0) self.posts += self.x_min self.x_max = self.posts[-1] - elif mode == "posts": - self.posts = x_values - self.widths = x_values[1:] - x_values[:-1] + elif self.mode == "posts": + self.posts = self.x_values + self.widths = self.x_values[1:] - self.x_values[:-1] self.x_min = self.posts[0] self.x_max = self.posts[-1] else: @@ -57,12 +69,14 @@ class Histogram(VMobject): self.y_values_scaled = self.y_scale * self.y_values - VMobject.__init__(self, **kwargs) - digest_config(self, kwargs) - + def generate_points(self): + self.process_values() + for submob in self.submobjects: + self.remove(submob) + def empty_string_array(n): arr = [] for i in range(n): @@ -72,21 +86,33 @@ class Histogram(VMobject): def num_arr_to_string_arr(arr): # converts number array to string array ret_arr = [] for x in arr: - ret_arr.append(str(x)) + if x == np.floor(x): + new_x = int(np.floor(x)) + else: + new_x = x + ret_arr.append(str(new_x)) return ret_arr previous_bar = ORIGIN - self.bars = [] + self.bars = VGroup() + self.x_labels_group = VGroup() + self.y_labels_group = VGroup() outline_points = [] + if self.x_labels == "widths": self.x_labels = num_arr_to_string_arr(self.widths) elif self.x_labels == "mids": - print self.x_mids self.x_labels = num_arr_to_string_arr(self.x_mids) elif self.x_labels == "none": self.x_labels = empty_string_array(len(self.widths)) - print self.x_labels + if self.y_labels == "auto": + self.y_labels = num_arr_to_string_arr(self.y_values) + elif self.y_labels == "none": + self.y_labels = empty_string_array(len(self.y_values)) + + + for (i,x) in enumerate(self.x_mids): @@ -104,12 +130,15 @@ class Histogram(VMobject): bar.set_stroke(width = 0) bar.next_to(previous_bar,RIGHT,buff = 0, aligned_edge = DOWN) - self.add(bar) - self.bars.append(bar) + self.bars.add(bar) - label = TextMobject(self.x_labels[i]) - label.next_to(bar,DOWN) - self.add(label) + x_label = TextMobject(self.x_labels[i]) + x_label.next_to(bar,DOWN) + self.x_labels_group.add(x_label) + + y_label = TextMobject(self.y_labels[i]) + y_label.next_to(bar, UP) + self.y_labels_group.add(y_label) if i == 0: # start with the lower left @@ -128,7 +157,9 @@ class Histogram(VMobject): self.outline = Polygon(*outline_points) self.outline.set_stroke(color = WHITE) - self.add(self.outline) + self.add(self.bars, self.x_labels_group, self.y_labels_group, self.outline) + + print self.submobjects def get_lower_left_point(self): return self.bars[0].get_anchors()[-2]