From e9f99f1cff16e7bae59575b5067158eb917956f1 Mon Sep 17 00:00:00 2001 From: Ben Hambrecht Date: Thu, 12 Apr 2018 17:21:46 +0200 Subject: [PATCH] expanded (and fixed) Histogram class --- active_projects/eop/histograms.py | 379 +++++++++++++++++------------- 1 file changed, 210 insertions(+), 169 deletions(-) diff --git a/active_projects/eop/histograms.py b/active_projects/eop/histograms.py index 2de66ea4..2f5a847e 100644 --- a/active_projects/eop/histograms.py +++ b/active_projects/eop/histograms.py @@ -2,190 +2,231 @@ from big_ol_pile_of_manim_imports import * from random import * def text_range(start,stop,step): # a range as a list of strings - numbers = np.arange(start,stop,step) - labels = [] - for x in numbers: - labels.append(str(x)) - return labels + numbers = np.arange(start,stop,step) + labels = [] + for x in numbers: + labels.append(str(x)) + return labels class Histogram(VMobject): - CONFIG = { - "start_color" : RED, - "end_color" : BLUE, - "x_scale" : 1.0, - "y_scale" : 1.0, - } + CONFIG = { + "start_color" : RED, + "end_color" : BLUE, + "x_scale" : 1.0, + "y_scale" : 1.0, + "x_labels" : "auto", + "x_min" : 0 + } - def __init__(self, x_values, y_values, **kwargs): + def __init__(self, x_values, y_values, mode = "widths", **kwargs): + # mode = "widths" : x_values means the widths of the bars + # mode = "posts" : x_values means the delimiters btw the bars - digest_config(self, kwargs) + digest_config(self, kwargs) - # preliminaries - self.x_values = x_values - self.y_values = y_values + if mode == "widths" and len(x_values) != len(y_values): + raise Exception("Array lengths do not match up!") + elif mode == "posts" and len(x_values) != len(y_values) + 1: + raise Exception("Array lengths do not match up!") - self.x_steps = x_values[1:] - x_values[:-1] - self.x_min = x_values[0] - self.x_steps[0] * 0.5 - self.x_posts = (x_values[1:] + x_values[:-1]) * 0.5 - self.x_max = x_values[-1] + self.x_steps[-1] * 0.5 - self.x_posts = np.insert(self.x_posts,0,self.x_min) - self.x_posts = np.append(self.x_posts,self.x_max) + # preliminaries + self.y_values = np.array(y_values) - self.x_widths = self.x_posts[1:] - self.x_posts[:-1] + if mode == "widths": + self.widths = x_values + self.posts = np.cumsum(self.widths) + self.posts = np.insert(self.posts, 0, 0) + self.posts += self.x_min + self.x_max = self.posts[-1] + elif mode == "posts": + self.posts = x_values + self.widths = x_values[1:] - x_values[:-1] + self.x_min = self.posts[0] + self.x_max = self.posts[-1] + else: + raise Exception("Invalid mode or no mode specified!") - self.x_values_scaled = self.x_scale * x_values - self.x_steps_scaled = self.x_scale * self.x_steps - self.x_posts_scaled = self.x_scale * self.x_posts - self.x_min_scaled = self.x_scale * self.x_min - self.x_max_scaled = self.x_scale * self.x_max - self.x_widths_scaled = self.x_scale * self.x_widths + self.x_mids = 0.5 * (self.posts[:-1] + self.posts[1:]) - self.y_values_scaled = self.y_scale * self.y_values + self.widths_scaled = self.x_scale * self.widths + self.posts_scaled = self.x_scale * self.posts + self.x_min_scaled = self.x_scale * self.x_min + self.x_max_scaled = self.x_scale * self.x_max - VMobject.__init__(self, **kwargs) - digest_config(self, kwargs) - + self.y_values_scaled = self.y_scale * self.y_values - def generate_points(self): + VMobject.__init__(self, **kwargs) + digest_config(self, kwargs) + - previous_bar = ORIGIN - self.bars = [] - outline_points = [] - self.x_labels = text_range(self.x_values[0], self.x_max, self.x_steps[0]) + def generate_points(self): - for (i,x) in enumerate(self.x_values): + def empty_string_array(n): + arr = [] + for i in range(n): + arr.append("") + return arr - bar = Rectangle( - width = self.x_widths_scaled[i], - height = self.y_values_scaled[i], - ) - t = float(x - self.x_values[0])/(self.x_values[-1] - self.x_values[0]) - bar_color = interpolate_color( - self.start_color, - self.end_color, - t - ) - bar.set_fill(color = bar_color, opacity = 1) - bar.set_stroke(width = 0) - bar.next_to(previous_bar,RIGHT,buff = 0, aligned_edge = DOWN) - - self.add(bar) - self.bars.append(bar) + def num_arr_to_string_arr(arr): # converts number array to string array + ret_arr = [] + for x in arr: + ret_arr.append(str(x)) + return ret_arr - label = TextMobject(self.x_labels[i]) - label.next_to(bar,DOWN) - self.add(label) + previous_bar = ORIGIN + self.bars = [] + outline_points = [] + if self.x_labels == "widths": + self.x_labels = num_arr_to_string_arr(self.widths) + elif self.x_labels == "mids": + print self.x_mids + self.x_labels = num_arr_to_string_arr(self.x_mids) + elif self.x_labels == "none": + self.x_labels = empty_string_array(len(self.widths)) - if i == 0: - # start with the lower left - outline_points.append(bar.get_anchors()[-2]) + print self.x_labels - # upper two points of each bar - outline_points.append(bar.get_anchors()[0]) - outline_points.append(bar.get_anchors()[1]) + for (i,x) in enumerate(self.x_mids): - previous_bar = bar + bar = Rectangle( + width = self.widths_scaled[i], + height = self.y_values_scaled[i], + ) + t = float(x - self.x_min)/(self.x_max - self.x_min) + bar_color = interpolate_color( + self.start_color, + self.end_color, + t + ) + bar.set_fill(color = bar_color, opacity = 1) + bar.set_stroke(width = 0) + bar.next_to(previous_bar,RIGHT,buff = 0, aligned_edge = DOWN) + + self.add(bar) + self.bars.append(bar) - # close the outline - # lower right - outline_points.append(bar.get_anchors()[2]) - # lower left - outline_points.append(outline_points[0]) + label = TextMobject(self.x_labels[i]) + label.next_to(bar,DOWN) + self.add(label) + + if i == 0: + # start with the lower left + outline_points.append(bar.get_anchors()[-2]) + + # upper two points of each bar + outline_points.append(bar.get_anchors()[0]) + outline_points.append(bar.get_anchors()[1]) + + previous_bar = bar + # close the outline + # lower right + outline_points.append(bar.get_anchors()[2]) + # lower left + outline_points.append(outline_points[0]) + + self.outline = Polygon(*outline_points) + self.outline.set_stroke(color = WHITE) + self.add(self.outline) + + def get_lower_left_point(self): + return self.bars[0].get_anchors()[-2] + + + +class BuildUpHistogram(Animation): + + def __init__(self, hist, **kwargs): + self.histogram = hist - self.outline = Polygon(*outline_points) - self.outline.set_stroke(color = WHITE) - self.add(self.outline) - def get_lower_left_point(self): - return self.bars[0].get_anchors()[-2] class FlashThroughHistogram(Animation): - CONFIG = { - "cell_color" : WHITE, - "cell_opacity" : 0.8, - "hist_opacity" : 0.2 - } + CONFIG = { + "cell_color" : WHITE, + "cell_opacity" : 0.8, + "hist_opacity" : 0.2 + } - def __init__(self, mobject, direction = "horizontal", mode = "random", **kwargs): + def __init__(self, mobject, direction = "horizontal", mode = "random", **kwargs): - digest_config(self, kwargs) + digest_config(self, kwargs) - self.cell_height = mobject.y_scale - self.prototype_cell = Rectangle( - width = 1, - height = self.cell_height, - fill_color = self.cell_color, - fill_opacity = self.cell_opacity, - stroke_width = 0, - ) + self.cell_height = mobject.y_scale + self.prototype_cell = Rectangle( + width = 1, + height = self.cell_height, + fill_color = self.cell_color, + fill_opacity = self.cell_opacity, + stroke_width = 0, + ) - x_values = mobject.x_values - y_values = mobject.y_values + x_values = mobject.x_values + y_values = mobject.y_values - self.mode = mode - self.direction = direction + self.mode = mode + self.direction = direction - self.generate_cell_indices(x_values,y_values) - Animation.__init__(self,mobject,**kwargs) + self.generate_cell_indices(x_values,y_values) + Animation.__init__(self,mobject,**kwargs) - def generate_cell_indices(self,x_values,y_values): + def generate_cell_indices(self,x_values,y_values): - self.cell_indices = [] - for (i,x) in enumerate(x_values): + self.cell_indices = [] + for (i,x) in enumerate(x_values): - nb_cells = y_values[i] - for j in range(nb_cells): - self.cell_indices.append((i, j)) + nb_cells = y_values[i] + for j in range(nb_cells): + self.cell_indices.append((i, j)) - self.reordered_cell_indices = self.cell_indices - if self.mode == "random": - shuffle(self.reordered_cell_indices) + self.reordered_cell_indices = self.cell_indices + if self.mode == "random": + shuffle(self.reordered_cell_indices) - def cell_for_index(self,i,j): + def cell_for_index(self,i,j): - if self.direction == "vertical": - width = self.mobject.x_scale - height = self.mobject.y_scale - x = (i + 0.5) * self.mobject.x_scale - y = (j + 0.5) * self.mobject.y_scale - center = self.mobject.get_lower_left_point() + x * RIGHT + y * UP - - elif self.direction == "horizontal": - width = self.mobject.x_scale / self.mobject.y_values[i] - height = self.mobject.y_scale * self.mobject.y_values[i] - x = i * self.mobject.x_scale + (j + 0.5) * width - y = height / 2 - center = self.mobject.get_lower_left_point() + x * RIGHT + y * UP + if self.direction == "vertical": + width = self.mobject.x_scale + height = self.mobject.y_scale + x = (i + 0.5) * self.mobject.x_scale + y = (j + 0.5) * self.mobject.y_scale + center = self.mobject.get_lower_left_point() + x * RIGHT + y * UP + + elif self.direction == "horizontal": + width = self.mobject.x_scale / self.mobject.y_values[i] + height = self.mobject.y_scale * self.mobject.y_values[i] + x = i * self.mobject.x_scale + (j + 0.5) * width + y = height / 2 + center = self.mobject.get_lower_left_point() + x * RIGHT + y * UP - cell = Rectangle(width = width, height = height) - cell.move_to(center) - return cell + cell = Rectangle(width = width, height = height) + cell.move_to(center) + return cell - def update_mobject(self,t): + def update_mobject(self,t): - if t == 0: - self.mobject.add(self.prototype_cell) + if t == 0: + self.mobject.add(self.prototype_cell) - flash_nb = int(t * (len(self.cell_indices))) - 1 - (i,j) = self.reordered_cell_indices[flash_nb] - cell = self.cell_for_index(i,j) - self.prototype_cell.width = cell.get_width() - self.prototype_cell.height = cell.get_height() - self.prototype_cell.generate_points() - self.prototype_cell.move_to(cell.get_center()) + flash_nb = int(t * (len(self.cell_indices))) - 1 + (i,j) = self.reordered_cell_indices[flash_nb] + cell = self.cell_for_index(i,j) + self.prototype_cell.width = cell.get_width() + self.prototype_cell.height = cell.get_height() + self.prototype_cell.generate_points() + self.prototype_cell.move_to(cell.get_center()) - #if t == 1: - # self.mobject.remove(self.prototype_cell) + #if t == 1: + # self.mobject.remove(self.prototype_cell) @@ -205,47 +246,47 @@ class FlashThroughHistogram(Animation): class SampleScene(Scene): - def construct(self): + def construct(self): - x_values = np.array([1,2,3,4,5]) - y_values = np.array([4,3,5,2,3]) + x_values = np.array([1,2,3,4,5]) + y_values = np.array([4,3,5,2,3]) - hist1 = Histogram( - x_values = x_values, - y_values = y_values, - x_scale = 0.5, - y_scale = 0.5, - ).shift(1*DOWN) - self.add(hist1) - self.wait() + hist1 = Histogram( + x_values = x_values, + y_values = y_values, + x_scale = 0.5, + y_scale = 0.5, + ).shift(1*DOWN) + self.add(hist1) + self.wait() - y_values2 = np.array([3,8,7,15,5]) + y_values2 = np.array([3,8,7,15,5]) - hist2 = Histogram( - x_values = x_values, - y_values = y_values2, - x_scale = 0.5, - y_scale = 0.5, - x_labels = text_range(1,6,1), - ) + hist2 = Histogram( + x_values = x_values, + y_values = y_values2, + x_scale = 0.5, + y_scale = 0.5, + x_labels = text_range(1,6,1), + ) - v1 = hist1.get_lower_left_point() - v2 = hist2.get_lower_left_point() - hist2.shift(v1 - v2) - - # self.play( - # ReplacementTransform(hist1,hist2) - # ) + v1 = hist1.get_lower_left_point() + v2 = hist2.get_lower_left_point() + hist2.shift(v1 - v2) + + # self.play( + # ReplacementTransform(hist1,hist2) + # ) - self.play( - FlashThroughHistogram( - hist1, - direction = "horizontal", - mode = "linear", - run_time = 10, - rate_func = None, - ) - ) + self.play( + FlashThroughHistogram( + hist1, + direction = "horizontal", + mode = "linear", + run_time = 10, + rate_func = None, + ) + )