mirror of
https://github.com/3b1b/manim.git
synced 2025-08-05 16:49:03 +00:00
Preview back propogation scene tweaked to my liking
This commit is contained in:
parent
5048fe80a5
commit
151f733f09
1 changed files with 61 additions and 19 deletions
80
nn/part2.py
80
nn/part2.py
|
@ -103,6 +103,8 @@ class PreviewLearning(NetworkScene):
|
||||||
"max_stroke_width" : 3,
|
"max_stroke_width" : 3,
|
||||||
"stroke_width_exp" : 3,
|
"stroke_width_exp" : 3,
|
||||||
"eta" : 5.0,
|
"eta" : 5.0,
|
||||||
|
"positive_change_color" : GREEN_B,
|
||||||
|
"negative_change_color" : RED_B,
|
||||||
}
|
}
|
||||||
def construct(self):
|
def construct(self):
|
||||||
self.initialize_network()
|
self.initialize_network()
|
||||||
|
@ -115,7 +117,7 @@ class PreviewLearning(NetworkScene):
|
||||||
self.color_network_edges()
|
self.color_network_edges()
|
||||||
|
|
||||||
def add_training_words(self):
|
def add_training_words(self):
|
||||||
words = TextMobject("Training in \\\\ progress $\\dots$")
|
words = TextMobject("Training in \\\\ progress$\\dots$")
|
||||||
words.scale(1.5)
|
words.scale(1.5)
|
||||||
words.to_corner(UP+LEFT)
|
words.to_corner(UP+LEFT)
|
||||||
|
|
||||||
|
@ -127,12 +129,14 @@ class PreviewLearning(NetworkScene):
|
||||||
image = get_training_image_group(train_in, train_out)
|
image = get_training_image_group(train_in, train_out)
|
||||||
self.activate_network(train_in, FadeIn(image))
|
self.activate_network(train_in, FadeIn(image))
|
||||||
self.backprop_one_example(
|
self.backprop_one_example(
|
||||||
train_in, train_out, FadeOut(image)
|
train_in, train_out,
|
||||||
|
FadeOut(image), self.network_mob.layers.restore
|
||||||
)
|
)
|
||||||
|
|
||||||
def activate_network(self, train_in, *added_anims):
|
def activate_network(self, train_in, *added_anims):
|
||||||
network_mob = self.network_mob
|
network_mob = self.network_mob
|
||||||
layers = network_mob.layers
|
layers = network_mob.layers
|
||||||
|
layers.save_state()
|
||||||
activations = self.network.get_activation_of_all_layers(train_in)
|
activations = self.network.get_activation_of_all_layers(train_in)
|
||||||
active_layers = [
|
active_layers = [
|
||||||
self.network_mob.get_active_layer(i, vect)
|
self.network_mob.get_active_layer(i, vect)
|
||||||
|
@ -173,6 +177,8 @@ class PreviewLearning(NetworkScene):
|
||||||
delta_neuron_groups, neuron_groups,
|
delta_neuron_groups, neuron_groups,
|
||||||
delta_edge_groups, edge_groups
|
delta_edge_groups, edge_groups
|
||||||
)
|
)
|
||||||
|
pc_color = self.positive_change_color
|
||||||
|
nc_color = self.negative_change_color
|
||||||
for i, nb, nw, delta_neurons, neurons, delta_edges, edges in reversed(tups):
|
for i, nb, nw, delta_neurons, neurons, delta_edges, edges in reversed(tups):
|
||||||
shown_nw = self.get_adjusted_first_matrix(nw)
|
shown_nw = self.get_adjusted_first_matrix(nw)
|
||||||
if np.max(shown_nw) == 0:
|
if np.max(shown_nw) == 0:
|
||||||
|
@ -180,42 +186,45 @@ class PreviewLearning(NetworkScene):
|
||||||
max_b = np.max(np.abs(nb))
|
max_b = np.max(np.abs(nb))
|
||||||
max_w = np.max(np.abs(shown_nw))
|
max_w = np.max(np.abs(shown_nw))
|
||||||
for neuron, b in zip(delta_neurons, nb):
|
for neuron, b in zip(delta_neurons, nb):
|
||||||
color = RED_E if b > 0 else GREEN_E
|
color = nc_color if b > 0 else pc_color
|
||||||
# neuron.set_fill(color, abs(b)/max_b)
|
# neuron.set_fill(color, abs(b)/max_b)
|
||||||
neuron.set_stroke(color, 3)
|
neuron.set_stroke(color, 3)
|
||||||
for edge, w in zip(delta_edges.split(), shown_nw.T.flatten()):
|
for edge, w in zip(delta_edges.split(), shown_nw.T.flatten()):
|
||||||
edge.set_stroke(
|
edge.set_stroke(
|
||||||
RED_E if w > 0 else GREEN_E,
|
nc_color if w > 0 else pc_color,
|
||||||
3*abs(w)/max_w
|
3*abs(w)/max_w
|
||||||
)
|
)
|
||||||
edge.rotate_in_place(np.pi)
|
edge.rotate_in_place(np.pi)
|
||||||
if i == 0:
|
if i == 2:
|
||||||
delta_edges.submobjects = [
|
delta_edges.submobjects = [
|
||||||
delta_edges[j]
|
delta_edges[-(j+1)]
|
||||||
for j in np.argsort(shown_nw.T.flatten())
|
for j in np.argsort(shown_nw.T.flatten())
|
||||||
]
|
]
|
||||||
network = self.network
|
network = self.network
|
||||||
network.weights[i] -= self.eta*nw
|
network.weights[i] -= self.eta*nw
|
||||||
network.biases[i] -= self.eta*nb
|
network.biases[i] -= self.eta*nb
|
||||||
|
|
||||||
reversed_delta_edges = VGroup(*reversed(delta_edge_groups))
|
reversed_delta_edges = VGroup(*it.chain(*reversed(delta_edge_groups)))
|
||||||
reversed_delta_neurons = VGroup(*reversed(delta_neuron_groups))
|
reversed_delta_neurons = VGroup(*reversed(delta_neuron_groups))
|
||||||
edge_groups.save_state()
|
edge_groups.save_state()
|
||||||
|
|
||||||
self.play(
|
self.play(
|
||||||
ShowCreation(
|
LaggedStart(
|
||||||
|
ShowCreation,
|
||||||
reversed_delta_edges,
|
reversed_delta_edges,
|
||||||
run_time = 2,
|
run_time = 1.5,
|
||||||
submobject_mode = "lagged_start",
|
lag_ratio = 0.15,
|
||||||
lag_factor = 6,
|
|
||||||
),
|
),
|
||||||
FadeIn(
|
FadeIn(
|
||||||
reversed_delta_neurons,
|
reversed_delta_neurons,
|
||||||
run_time = 2,
|
run_time = 2,
|
||||||
submobject_mode = "lagged_start",
|
submobject_mode = "lagged_start",
|
||||||
lag_factor = 4,
|
lag_factor = 4,
|
||||||
|
rate_func = None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.color_network_edges()
|
self.color_network_edges()
|
||||||
|
self.remove(edge_groups)
|
||||||
self.play(*it.chain(
|
self.play(*it.chain(
|
||||||
[ReplacementTransform(
|
[ReplacementTransform(
|
||||||
edge_groups.saved_state, edge_groups,
|
edge_groups.saved_state, edge_groups,
|
||||||
|
@ -281,6 +290,8 @@ class TrainingVsTestData(Scene):
|
||||||
self.get_examples() for x in range(2)
|
self.get_examples() for x in range(2)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
training_examples.next_to(ORIGIN, LEFT)
|
||||||
|
test_examples.next_to(ORIGIN, RIGHT)
|
||||||
self.play(
|
self.play(
|
||||||
LaggedStart(FadeIn, training_examples),
|
LaggedStart(FadeIn, training_examples),
|
||||||
LaggedStart(FadeIn, test_examples),
|
LaggedStart(FadeIn, test_examples),
|
||||||
|
@ -339,7 +350,7 @@ class TrainingVsTestData(Scene):
|
||||||
self.remove(train_ex)
|
self.remove(train_ex)
|
||||||
self.add(new_ex)
|
self.add(new_ex)
|
||||||
new_ex[0][0].highlight(color)
|
new_ex[0][0].highlight(color)
|
||||||
self.dither(1./10)
|
self.dither(1./30)
|
||||||
training_examples = new_examples
|
training_examples = new_examples
|
||||||
|
|
||||||
class NotSciFi(TeacherStudentsScene):
|
class NotSciFi(TeacherStudentsScene):
|
||||||
|
@ -348,20 +359,51 @@ class NotSciFi(TeacherStudentsScene):
|
||||||
self.student_says(
|
self.student_says(
|
||||||
"Machines learning?!?",
|
"Machines learning?!?",
|
||||||
student_index = 0,
|
student_index = 0,
|
||||||
target_mode = "confused",
|
target_mode = "pleading",
|
||||||
|
run_time = 1,
|
||||||
)
|
)
|
||||||
bubble = students[0].bubble
|
bubble = students[0].bubble
|
||||||
students[0].bubble = None
|
students[0].bubble = None
|
||||||
self.student_says(
|
self.student_says(
|
||||||
"Run!", student_index = 2,
|
"Should we \\\\ be worried?", student_index = 2,
|
||||||
target_mode = "pleading",
|
target_mode = "confused",
|
||||||
bubble_kwargs = {"direction" : LEFT}
|
bubble_kwargs = {"direction" : LEFT},
|
||||||
|
run_time = 1,
|
||||||
)
|
)
|
||||||
self.dither()
|
self.dither()
|
||||||
students[0].bubble = bubble
|
students[0].bubble = bubble
|
||||||
|
self.teacher_says(
|
||||||
|
"It's actually \\\\ just calculus.",
|
||||||
|
run_time = 1
|
||||||
|
)
|
||||||
|
self.teacher.bubble = None
|
||||||
|
self.dither()
|
||||||
|
self.student_says(
|
||||||
|
"Even worse!",
|
||||||
|
target_mode = "horrified",
|
||||||
|
bubble_kwargs = {
|
||||||
|
"direction" : LEFT,
|
||||||
|
"width" : 3,
|
||||||
|
"height" : 2,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.dither(2)
|
||||||
|
|
||||||
|
class FunctionMinmization(GraphScene):
|
||||||
|
CONFIG = {
|
||||||
|
"x_labeled_nums" : range(-1, 10),
|
||||||
|
}
|
||||||
|
def construct(self):
|
||||||
|
self.setup_axes()
|
||||||
|
def func(x):
|
||||||
|
x -= 5
|
||||||
|
return 0.1*(x**3 - 9*x) + 4
|
||||||
|
graph = self.get_graph(func)
|
||||||
|
graph_label = self.get_graph_label(graph, "C(x)")
|
||||||
|
self.add(graph, graph_label)
|
||||||
|
|
||||||
|
dot = Dot(color = YELLOW)
|
||||||
|
x =
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue