mirror of
https://github.com/3b1b/manim.git
synced 2025-08-05 16:49:03 +00:00
Preview back propogation scene tweaked to my liking
This commit is contained in:
parent
5048fe80a5
commit
151f733f09
1 changed files with 61 additions and 19 deletions
80
nn/part2.py
80
nn/part2.py
|
@ -103,6 +103,8 @@ class PreviewLearning(NetworkScene):
|
|||
"max_stroke_width" : 3,
|
||||
"stroke_width_exp" : 3,
|
||||
"eta" : 5.0,
|
||||
"positive_change_color" : GREEN_B,
|
||||
"negative_change_color" : RED_B,
|
||||
}
|
||||
def construct(self):
|
||||
self.initialize_network()
|
||||
|
@ -115,7 +117,7 @@ class PreviewLearning(NetworkScene):
|
|||
self.color_network_edges()
|
||||
|
||||
def add_training_words(self):
|
||||
words = TextMobject("Training in \\\\ progress $\\dots$")
|
||||
words = TextMobject("Training in \\\\ progress$\\dots$")
|
||||
words.scale(1.5)
|
||||
words.to_corner(UP+LEFT)
|
||||
|
||||
|
@ -127,12 +129,14 @@ class PreviewLearning(NetworkScene):
|
|||
image = get_training_image_group(train_in, train_out)
|
||||
self.activate_network(train_in, FadeIn(image))
|
||||
self.backprop_one_example(
|
||||
train_in, train_out, FadeOut(image)
|
||||
train_in, train_out,
|
||||
FadeOut(image), self.network_mob.layers.restore
|
||||
)
|
||||
|
||||
def activate_network(self, train_in, *added_anims):
|
||||
network_mob = self.network_mob
|
||||
layers = network_mob.layers
|
||||
layers.save_state()
|
||||
activations = self.network.get_activation_of_all_layers(train_in)
|
||||
active_layers = [
|
||||
self.network_mob.get_active_layer(i, vect)
|
||||
|
@ -173,6 +177,8 @@ class PreviewLearning(NetworkScene):
|
|||
delta_neuron_groups, neuron_groups,
|
||||
delta_edge_groups, edge_groups
|
||||
)
|
||||
pc_color = self.positive_change_color
|
||||
nc_color = self.negative_change_color
|
||||
for i, nb, nw, delta_neurons, neurons, delta_edges, edges in reversed(tups):
|
||||
shown_nw = self.get_adjusted_first_matrix(nw)
|
||||
if np.max(shown_nw) == 0:
|
||||
|
@ -180,42 +186,45 @@ class PreviewLearning(NetworkScene):
|
|||
max_b = np.max(np.abs(nb))
|
||||
max_w = np.max(np.abs(shown_nw))
|
||||
for neuron, b in zip(delta_neurons, nb):
|
||||
color = RED_E if b > 0 else GREEN_E
|
||||
color = nc_color if b > 0 else pc_color
|
||||
# neuron.set_fill(color, abs(b)/max_b)
|
||||
neuron.set_stroke(color, 3)
|
||||
for edge, w in zip(delta_edges.split(), shown_nw.T.flatten()):
|
||||
edge.set_stroke(
|
||||
RED_E if w > 0 else GREEN_E,
|
||||
nc_color if w > 0 else pc_color,
|
||||
3*abs(w)/max_w
|
||||
)
|
||||
edge.rotate_in_place(np.pi)
|
||||
if i == 0:
|
||||
if i == 2:
|
||||
delta_edges.submobjects = [
|
||||
delta_edges[j]
|
||||
delta_edges[-(j+1)]
|
||||
for j in np.argsort(shown_nw.T.flatten())
|
||||
]
|
||||
network = self.network
|
||||
network.weights[i] -= self.eta*nw
|
||||
network.biases[i] -= self.eta*nb
|
||||
|
||||
reversed_delta_edges = VGroup(*reversed(delta_edge_groups))
|
||||
reversed_delta_edges = VGroup(*it.chain(*reversed(delta_edge_groups)))
|
||||
reversed_delta_neurons = VGroup(*reversed(delta_neuron_groups))
|
||||
edge_groups.save_state()
|
||||
|
||||
self.play(
|
||||
ShowCreation(
|
||||
LaggedStart(
|
||||
ShowCreation,
|
||||
reversed_delta_edges,
|
||||
run_time = 2,
|
||||
submobject_mode = "lagged_start",
|
||||
lag_factor = 6,
|
||||
run_time = 1.5,
|
||||
lag_ratio = 0.15,
|
||||
),
|
||||
FadeIn(
|
||||
reversed_delta_neurons,
|
||||
run_time = 2,
|
||||
submobject_mode = "lagged_start",
|
||||
lag_factor = 4,
|
||||
rate_func = None,
|
||||
)
|
||||
)
|
||||
self.color_network_edges()
|
||||
self.remove(edge_groups)
|
||||
self.play(*it.chain(
|
||||
[ReplacementTransform(
|
||||
edge_groups.saved_state, edge_groups,
|
||||
|
@ -281,6 +290,8 @@ class TrainingVsTestData(Scene):
|
|||
self.get_examples() for x in range(2)
|
||||
]
|
||||
|
||||
training_examples.next_to(ORIGIN, LEFT)
|
||||
test_examples.next_to(ORIGIN, RIGHT)
|
||||
self.play(
|
||||
LaggedStart(FadeIn, training_examples),
|
||||
LaggedStart(FadeIn, test_examples),
|
||||
|
@ -339,7 +350,7 @@ class TrainingVsTestData(Scene):
|
|||
self.remove(train_ex)
|
||||
self.add(new_ex)
|
||||
new_ex[0][0].highlight(color)
|
||||
self.dither(1./10)
|
||||
self.dither(1./30)
|
||||
training_examples = new_examples
|
||||
|
||||
class NotSciFi(TeacherStudentsScene):
|
||||
|
@ -348,20 +359,51 @@ class NotSciFi(TeacherStudentsScene):
|
|||
self.student_says(
|
||||
"Machines learning?!?",
|
||||
student_index = 0,
|
||||
target_mode = "confused",
|
||||
target_mode = "pleading",
|
||||
run_time = 1,
|
||||
)
|
||||
bubble = students[0].bubble
|
||||
students[0].bubble = None
|
||||
self.student_says(
|
||||
"Run!", student_index = 2,
|
||||
target_mode = "pleading",
|
||||
bubble_kwargs = {"direction" : LEFT}
|
||||
"Should we \\\\ be worried?", student_index = 2,
|
||||
target_mode = "confused",
|
||||
bubble_kwargs = {"direction" : LEFT},
|
||||
run_time = 1,
|
||||
)
|
||||
self.dither()
|
||||
students[0].bubble = bubble
|
||||
|
||||
|
||||
|
||||
self.teacher_says(
|
||||
"It's actually \\\\ just calculus.",
|
||||
run_time = 1
|
||||
)
|
||||
self.teacher.bubble = None
|
||||
self.dither()
|
||||
self.student_says(
|
||||
"Even worse!",
|
||||
target_mode = "horrified",
|
||||
bubble_kwargs = {
|
||||
"direction" : LEFT,
|
||||
"width" : 3,
|
||||
"height" : 2,
|
||||
},
|
||||
)
|
||||
self.dither(2)
|
||||
|
||||
class FunctionMinmization(GraphScene):
|
||||
CONFIG = {
|
||||
"x_labeled_nums" : range(-1, 10),
|
||||
}
|
||||
def construct(self):
|
||||
self.setup_axes()
|
||||
def func(x):
|
||||
x -= 5
|
||||
return 0.1*(x**3 - 9*x) + 4
|
||||
graph = self.get_graph(func)
|
||||
graph_label = self.get_graph_label(graph, "C(x)")
|
||||
self.add(graph, graph_label)
|
||||
|
||||
dot = Dot(color = YELLOW)
|
||||
x =
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue