mirror of
https://github.com/3b1b/manim.git
synced 2025-08-05 16:49:03 +00:00
expanded functionality of histograms, esp. for updating
This commit is contained in:
parent
fd7dc42d16
commit
fe5438a841
1 changed files with 52 additions and 21 deletions
|
@ -16,7 +16,8 @@ class Histogram(VMobject):
|
||||||
"end_color" : BLUE,
|
"end_color" : BLUE,
|
||||||
"x_scale" : 1.0,
|
"x_scale" : 1.0,
|
||||||
"y_scale" : 1.0,
|
"y_scale" : 1.0,
|
||||||
"x_labels" : "auto",
|
"x_labels" : "auto", # widths, mids, auto, none, [...]
|
||||||
|
"y_labels" : "auto", # auto, none, [...]
|
||||||
"x_min" : 0
|
"x_min" : 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -31,18 +32,29 @@ class Histogram(VMobject):
|
||||||
elif mode == "posts" and len(x_values) != len(y_values) + 1:
|
elif mode == "posts" and len(x_values) != len(y_values) + 1:
|
||||||
raise Exception("Array lengths do not match up!")
|
raise Exception("Array lengths do not match up!")
|
||||||
|
|
||||||
# preliminaries
|
|
||||||
self.y_values = np.array(y_values)
|
|
||||||
|
|
||||||
if mode == "widths":
|
self.y_values = y_values
|
||||||
self.widths = x_values
|
self.x_values = x_values
|
||||||
|
self.mode = mode
|
||||||
|
self.process_values()
|
||||||
|
|
||||||
|
VMobject.__init__(self, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def process_values(self):
|
||||||
|
|
||||||
|
# preliminaries
|
||||||
|
self.y_values = np.array(self.y_values)
|
||||||
|
|
||||||
|
if self.mode == "widths":
|
||||||
|
self.widths = self.x_values
|
||||||
self.posts = np.cumsum(self.widths)
|
self.posts = np.cumsum(self.widths)
|
||||||
self.posts = np.insert(self.posts, 0, 0)
|
self.posts = np.insert(self.posts, 0, 0)
|
||||||
self.posts += self.x_min
|
self.posts += self.x_min
|
||||||
self.x_max = self.posts[-1]
|
self.x_max = self.posts[-1]
|
||||||
elif mode == "posts":
|
elif self.mode == "posts":
|
||||||
self.posts = x_values
|
self.posts = self.x_values
|
||||||
self.widths = x_values[1:] - x_values[:-1]
|
self.widths = self.x_values[1:] - self.x_values[:-1]
|
||||||
self.x_min = self.posts[0]
|
self.x_min = self.posts[0]
|
||||||
self.x_max = self.posts[-1]
|
self.x_max = self.posts[-1]
|
||||||
else:
|
else:
|
||||||
|
@ -57,12 +69,14 @@ class Histogram(VMobject):
|
||||||
|
|
||||||
self.y_values_scaled = self.y_scale * self.y_values
|
self.y_values_scaled = self.y_scale * self.y_values
|
||||||
|
|
||||||
VMobject.__init__(self, **kwargs)
|
|
||||||
digest_config(self, kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_points(self):
|
def generate_points(self):
|
||||||
|
|
||||||
|
self.process_values()
|
||||||
|
for submob in self.submobjects:
|
||||||
|
self.remove(submob)
|
||||||
|
|
||||||
def empty_string_array(n):
|
def empty_string_array(n):
|
||||||
arr = []
|
arr = []
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
|
@ -72,21 +86,33 @@ class Histogram(VMobject):
|
||||||
def num_arr_to_string_arr(arr): # converts number array to string array
|
def num_arr_to_string_arr(arr): # converts number array to string array
|
||||||
ret_arr = []
|
ret_arr = []
|
||||||
for x in arr:
|
for x in arr:
|
||||||
ret_arr.append(str(x))
|
if x == np.floor(x):
|
||||||
|
new_x = int(np.floor(x))
|
||||||
|
else:
|
||||||
|
new_x = x
|
||||||
|
ret_arr.append(str(new_x))
|
||||||
return ret_arr
|
return ret_arr
|
||||||
|
|
||||||
previous_bar = ORIGIN
|
previous_bar = ORIGIN
|
||||||
self.bars = []
|
self.bars = VGroup()
|
||||||
|
self.x_labels_group = VGroup()
|
||||||
|
self.y_labels_group = VGroup()
|
||||||
outline_points = []
|
outline_points = []
|
||||||
|
|
||||||
if self.x_labels == "widths":
|
if self.x_labels == "widths":
|
||||||
self.x_labels = num_arr_to_string_arr(self.widths)
|
self.x_labels = num_arr_to_string_arr(self.widths)
|
||||||
elif self.x_labels == "mids":
|
elif self.x_labels == "mids":
|
||||||
print self.x_mids
|
|
||||||
self.x_labels = num_arr_to_string_arr(self.x_mids)
|
self.x_labels = num_arr_to_string_arr(self.x_mids)
|
||||||
elif self.x_labels == "none":
|
elif self.x_labels == "none":
|
||||||
self.x_labels = empty_string_array(len(self.widths))
|
self.x_labels = empty_string_array(len(self.widths))
|
||||||
|
|
||||||
print self.x_labels
|
if self.y_labels == "auto":
|
||||||
|
self.y_labels = num_arr_to_string_arr(self.y_values)
|
||||||
|
elif self.y_labels == "none":
|
||||||
|
self.y_labels = empty_string_array(len(self.y_values))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for (i,x) in enumerate(self.x_mids):
|
for (i,x) in enumerate(self.x_mids):
|
||||||
|
|
||||||
|
@ -104,12 +130,15 @@ class Histogram(VMobject):
|
||||||
bar.set_stroke(width = 0)
|
bar.set_stroke(width = 0)
|
||||||
bar.next_to(previous_bar,RIGHT,buff = 0, aligned_edge = DOWN)
|
bar.next_to(previous_bar,RIGHT,buff = 0, aligned_edge = DOWN)
|
||||||
|
|
||||||
self.add(bar)
|
self.bars.add(bar)
|
||||||
self.bars.append(bar)
|
|
||||||
|
|
||||||
label = TextMobject(self.x_labels[i])
|
x_label = TextMobject(self.x_labels[i])
|
||||||
label.next_to(bar,DOWN)
|
x_label.next_to(bar,DOWN)
|
||||||
self.add(label)
|
self.x_labels_group.add(x_label)
|
||||||
|
|
||||||
|
y_label = TextMobject(self.y_labels[i])
|
||||||
|
y_label.next_to(bar, UP)
|
||||||
|
self.y_labels_group.add(y_label)
|
||||||
|
|
||||||
if i == 0:
|
if i == 0:
|
||||||
# start with the lower left
|
# start with the lower left
|
||||||
|
@ -128,7 +157,9 @@ class Histogram(VMobject):
|
||||||
|
|
||||||
self.outline = Polygon(*outline_points)
|
self.outline = Polygon(*outline_points)
|
||||||
self.outline.set_stroke(color = WHITE)
|
self.outline.set_stroke(color = WHITE)
|
||||||
self.add(self.outline)
|
self.add(self.bars, self.x_labels_group, self.y_labels_group, self.outline)
|
||||||
|
|
||||||
|
print self.submobjects
|
||||||
|
|
||||||
def get_lower_left_point(self):
|
def get_lower_left_point(self):
|
||||||
return self.bars[0].get_anchors()[-2]
|
return self.bars[0].get_anchors()[-2]
|
||||||
|
|
Loading…
Add table
Reference in a new issue