Simulations for an SIR epidemic model

This commit is contained in:
Grant Sanderson 2020-03-19 10:47:05 -07:00
parent f7ad9e71e0
commit 6fee2f5a0d

770
from_3b1b/active/sir.py Normal file
View file

@ -0,0 +1,770 @@
from manimlib.imports import *
# import scipy.stats
SICKLY_GREEN = "#9BBD37"
COLOR_MAP = {
"S": BLUE,
"I": RED,
"R": GREY_D,
}
def update_time(mob, dt):
mob.time += dt
class Person(VGroup):
CONFIG = {
"status": "S", # S, I or R
"height": 0.2,
"color_map": COLOR_MAP,
"infection_ring_style": {
"stroke_color": RED,
"stroke_opacity": 0.8,
"stroke_width": 0,
},
"infection_radius": 0.5,
"infection_animation_period": 2,
"max_speed": 1,
"dl_bound": [-FRAME_WIDTH / 2, -FRAME_HEIGHT / 2],
"ur_bound": [FRAME_WIDTH / 2, FRAME_HEIGHT / 2],
"gravity_well": None,
"gravity_strength": 1,
"wander_step_size": 1,
"wander_step_duration": 1,
"social_distance_factor": 0,
"n_repulsion_points": 10,
"social_distance_color": YELLOW,
"max_social_distance_stroke_width": 5,
}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.time = 0
self.last_step_change = -1
self.change_anims = []
self.velocity = np.zeros(3)
self.infection_start_time = np.inf
self.infection_end_time = np.inf
self.repulsion_points = []
self.center_point = VectorizedPoint()
self.add(self.center_point)
self.add_body()
self.add_infection_ring()
self.set_status(self.status, run_time=0)
# Updaters
self.add_updater(update_time)
self.add_updater(lambda m, dt: m.update_position(dt))
self.add_updater(lambda m, dt: m.update_infection_ring(dt))
self.add_updater(lambda m: m.progress_through_change_anims())
def add_body(self):
body = self.get_body()
body.set_height(self.height)
body.move_to(self.get_center())
self.add(body)
self.body = body
def get_body(self, status):
person = SVGMobject(file_name="person")
person.set_stroke(width=0)
return person
def set_status(self, status, run_time=1):
start_color = self.color_map[self.status]
end_color = self.color_map[status]
anims = [
UpdateFromAlphaFunc(
self.body,
lambda m, a: m.set_color(interpolate_color(
start_color, end_color, a
)),
run_time=run_time,
)
]
for anim in anims:
self.push_anim(anim)
if status == "I":
self.infection_start_time = self.time
self.infection_ring.set_stroke(width=0, opacity=0)
if self.status == "I":
self.infection_end_time = self.time
self.status = status
def push_anim(self, anim):
anim.begin()
anim.start_time = self.time
self.change_anims.append(anim)
return self
def pop_anim(self, anim):
anim.update(1)
anim.finish()
self.change_anims.remove(anim)
def add_infection_ring(self):
self.infection_ring = Circle(
radius=self.height / 2,
)
self.infection_ring.set_style(**self.infection_ring_style)
self.add(self.infection_ring)
self.infection_ring.time = 0
return self
def update_position(self, dt):
center = self.get_center()
total_force = np.zeros(3)
# Gravity
if self.wander_step_size != 0:
if (self.time - self.last_step_change) > self.wander_step_duration:
vect = rotate_vector(RIGHT, TAU * random.random())
self.gravity_well = center + self.wander_step_size * vect
self.last_step_change = self.time
if self.gravity_well is not None:
to_well = (self.gravity_well - center)
dist = get_norm(to_well)
if dist != 0:
total_force += self.gravity_strength * to_well / (dist**3)
# Potentially avoid neighbors
if self.social_distance_factor > 0:
repulsion_force = np.zeros(3)
min_dist = np.inf
for point in self.repulsion_points:
to_point = point - center
dist = get_norm(to_point)
if dist < min_dist:
min_dist = dist
if dist > 0:
repulsion_force -= self.social_distance_factor * to_point / (dist**3)
self.body.set_stroke(
self.social_distance_color,
width=clip(4 * min_dist - 1, 0, self.max_social_distance_stroke_width),
background=True,
)
total_force += repulsion_force
# Avoid walls
wall_force = np.zeros(3)
for i in range(2):
to_lower = center[i] - self.dl_bound[i]
to_upper = self.ur_bound[i] - center[i]
# Bounce
if to_lower < 0:
self.velocity[i] = abs(self.velocity[i])
if to_upper < 0:
self.velocity[i] = -abs(self.velocity[i])
# Repelling force
wall_force[i] += max((-1 + 1 / to_lower), 0)
wall_force[i] -= max((-1 + 1 / to_upper), 0)
total_force += wall_force
# Apply force
self.velocity += total_force * dt
# Limit speed
speed = get_norm(self.velocity)
if speed > self.max_speed:
self.velocity *= self.max_speed / speed
# Update velocity
self.shift(self.velocity * dt)
def update_infection_ring(self, dt):
ring = self.infection_ring
if not (self.infection_start_time <= self.time <= self.infection_end_time + 1):
return self
ring_time = self.time - self.infection_start_time
period = self.infection_animation_period
alpha = (ring_time % period) / period
ring.set_height(interpolate(
self.height,
self.infection_radius,
smooth(alpha),
))
ring.set_stroke(
width=interpolate(
0, 5,
there_and_back(alpha),
),
opacity=min([
min([ring_time, 1]),
min([self.infection_end_time + 1 - self.time, 1]),
]),
)
return self
def progress_through_change_anims(self):
for anim in self.change_anims:
if anim.run_time == 0:
alpha = 1
else:
alpha = (self.time - anim.start_time) / anim.run_time
anim.update(alpha)
if alpha >= 1:
self.pop_anim(anim)
def get_center(self):
return self.center_point.points[0]
class DotPerson(Person):
def get_body(self):
return Dot()
class PiPerson(Person):
CONFIG = {
"mode_map": {
"S": "guilty",
"I": "sick",
"R": "tease",
}
}
def get_body(self):
return Randolph()
def set_status(self, status, run_time=1):
super().set_status(status)
target = self.body.copy()
target.change(self.mode_map[status])
target.set_color(self.color_map[status])
transform = Transform(self.body, target)
transform.begin()
def update(body, alpha):
transform.update(alpha)
body.move_to(self.center_point)
anims = [
UpdateFromAlphaFunc(self.body, update, run_time=run_time),
]
for anim in anims:
self.push_anim(anim)
return self
class SIRSimulation(VGroup):
CONFIG = {
"n_cities": 1,
"city_population": 100,
"box_size": 7,
"person_type": PiPerson,
"person_config": {
"height": 0.2,
"infection_radius": 0.6,
"gravity_strength": 1,
"wander_step_size": 1,
},
"p_infection_per_day": 0.2,
"infection_time": 5,
"travel_rate": 0,
}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.time = 0
self.add_updater(update_time)
self.add_boxes()
self.add_people()
self.add_updater(lambda m, dt: m.update_statusses(dt))
def add_boxes(self):
boxes = VGroup()
for x in range(self.n_cities):
box = Square()
box.set_height(self.box_size)
box.set_stroke(WHITE, 3)
boxes.add(box)
boxes.arrange_in_grid(buff=LARGE_BUFF)
self.add(boxes)
self.boxes = boxes
def add_people(self):
people = VGroup()
for box in self.boxes:
dl_bound = box.get_corner(DL)
ur_bound = box.get_corner(UR)
for x in range(self.city_population):
person = self.person_type(
dl_bound=dl_bound,
ur_bound=ur_bound,
**self.person_config
)
person.move_to([
interpolate(lower, upper, random.random())
for lower, upper in zip(dl_bound, ur_bound)
])
person.box = box
people.add(person)
# Choose a patient zero
random.choice(people).set_status("I")
self.add(people)
self.people = people
def update_statusses(self, dt):
s_group, i_group = [
list(filter(
lambda m: m.status == status,
self.people
))
for status in ["S", "I"]
]
for s_person in s_group:
for i_person in i_group:
dist = get_norm(i_person.get_center() - s_person.get_center())
if dist < s_person.infection_radius and random.random() < self.p_infection_per_day * dt:
s_person.set_status("I")
for i_person in i_group:
if (i_person.time - i_person.infection_start_time) > self.infection_time:
i_person.set_status("R")
# Travel
if self.travel_rate > 0:
path_func = path_along_arc(45 * DEGREES)
for person in self.people:
if random.random() < self.travel_rate * dt:
new_box = random.choice(self.boxes)
person.box = new_box
person.dl_bound = new_box.get_corner(DL)
person.ur_bound = new_box.get_corner(UR)
person.old_center = person.get_center()
person.new_center = new_box.get_center()
anim = UpdateFromAlphaFunc(
person,
lambda m, a: m.move_to(path_func(
m.old_center, m.new_center, a,
)),
run_time=3,
)
person.push_anim(anim)
# Social distancing
centers = np.array([person.get_center() for person in self.people])
for center, person in zip(centers, self.people):
if person.social_distance_factor > 0:
diffs = np.linalg.norm(centers - center, axis=1)
person.repulsion_points = centers[np.argsort(diffs)[1:person.n_repulsion_points + 1]]
def get_status_counts(self):
return np.array([
len(list(filter(
lambda m: m.status == status,
self.people
)))
for status in "SIR"
])
def get_status_proportions(self):
counts = self.get_status_counts()
return counts / sum(counts)
class SIRGraph(VGroup):
CONFIG = {
"color_map": COLOR_MAP,
"height": 7,
"width": 5,
"update_frequency": 0.5,
"include_braces": False,
}
def __init__(self, simulation, **kwargs):
super().__init__(**kwargs)
self.simulation = simulation
self.data = [simulation.get_status_proportions()] * 2
self.add_axes()
self.add_graph()
self.add_x_labels()
self.time = 0
self.last_update_time = 0
self.add_updater(update_time)
self.add_updater(lambda m: m.update_graph())
self.add_updater(lambda m: m.update_x_labels())
def add_axes(self):
axes = Axes(
y_min=0,
y_max=1,
y_axis_config={
"tick_frequency": 0.1,
},
x_min=0,
x_max=1,
axis_config={
"include_tip": False,
},
)
origin = axes.c2p(0, 0)
axes.x_axis.set_width(self.width, about_point=origin, stretch=True)
axes.y_axis.set_height(self.height, about_point=origin, stretch=True)
self.add(axes)
self.axes = axes
def add_graph(self):
self.graph = self.get_graph(self.data)
self.add(self.graph)
def add_x_labels(self):
self.x_labels = VGroup()
self.x_ticks = VGroup()
self.add(self.x_ticks, self.x_labels)
def get_graph(self, data):
axes = self.axes
i_points = []
s_points = []
for x, props in zip(np.linspace(0, 1, len(data)), data):
i_point = axes.c2p(x, props[1])
s_point = axes.c2p(x, sum(props[:2]))
i_points.append(i_point)
s_points.append(s_point)
r_points = [
axes.c2p(0, 1),
axes.c2p(1, 1),
*s_points[::-1],
axes.c2p(0, 1),
]
s_points.extend([
*i_points[::-1],
s_points[0],
])
i_points.extend([
axes.c2p(1, 0),
axes.c2p(0, 0),
i_points[0],
])
points_lists = [s_points, i_points, r_points]
regions = VGroup(VMobject(), VMobject(), VMobject())
for region, status, points in zip(regions, "SIR", points_lists):
region.set_points_as_corners(points)
region.set_stroke(width=0)
region.set_fill(self.color_map[status], 1)
regions[0].set_fill(opacity=0.5)
return regions
def update_graph(self):
if (self.time - self.last_update_time) > self.update_frequency:
self.data.append(self.simulation.get_status_proportions())
self.graph.become(self.get_graph(self.data))
self.last_update_time = self.time
def update_x_labels(self):
tick_height = 0.03 * self.graph.get_height()
tick_template = Line(DOWN, UP)
tick_template.set_height(tick_height)
def get_tick(x):
tick = tick_template.copy()
tick.move_to(self.axes.c2p(x / self.time, 0))
return tick
def get_label(x, tick):
label = Integer(x)
label.set_height(tick_height)
label.next_to(tick, DOWN, buff=0.5 * tick_height)
return label
self.x_labels.set_submobjects([])
self.x_ticks.set_submobjects([])
if self.time < 15:
tick_range = range(1, int(self.time) + 1)
elif self.time < 50:
tick_range = range(5, int(self.time) + 1, 5)
elif self.time < 100:
tick_range = range(10, int(self.time) + 1, 10)
for x in tick_range:
tick = get_tick(x)
label = get_label(x, tick)
self.x_ticks.add(tick)
self.x_labels.add(label)
if 10 < self.time < 15:
alpha = (self.time - 10) / 5
for tick, label in zip(self.x_ticks, self.x_labels):
if label.get_value() % 5 != 0:
label.set_opacity(1 - alpha)
tick.set_opacity(1 - alpha)
class GraphBraces(VGroup):
CONFIG = {
"update_frequency": 0.5,
}
def __init__(self, graph, simulation, **kwargs):
super().__init__(**kwargs)
axes = self.axes = graph.axes
self.simulation = simulation
ys = np.linspace(0, 1, 4)
self.lines = VGroup(*[
Line(axes.c2p(1, y1), axes.c2p(1, y2))
for y1, y2 in zip(ys, ys[1:])
])
self.braces = VGroup(*[Brace(line, RIGHT) for line in self.lines])
self.labels = VGroup(
TextMobject("Susceptible", color=COLOR_MAP["S"]),
TextMobject("Infectious", color=COLOR_MAP["I"]),
TextMobject("Recovered", color=COLOR_MAP["R"]),
)
self.max_label_height = graph.get_height() * 0.05
self.add(self.braces, self.labels)
self.time = 0
self.last_update_time = 0
self.add_updater(update_time)
self.add_updater(lambda m: m.update_braces())
def update_braces(self):
if (self.time - self.last_update_time) <= self.update_frequency:
return
self.last_update_time = self.time
lines = self.lines
braces = self.braces
labels = self.labels
axes = self.axes
props = self.simulation.get_status_proportions()
ys = np.cumsum([0, props[1], props[0], props[2]])
epsilon = 1e-6
for i, y1, y2 in zip([1, 0, 2], ys, ys[1:]):
lines[i].set_points_as_corners([
axes.c2p(1, y1),
axes.c2p(1, y2),
])
height = lines[i].get_height()
braces[i].set_height(
max(height, epsilon),
stretch=True
)
braces[i].next_to(lines[i], RIGHT)
label_height = clip(height, epsilon, self.max_label_height)
labels[i].scale(label_height / labels[i][0][0].get_height())
labels[i].next_to(braces[i], RIGHT)
return self
class ValueSlider(NumberLine):
CONFIG = {
"x_min": 0,
"x_max": 1,
"tick_frequency": 0.1,
"numbers_with_elongated_ticks": [],
"numbers_to_show": np.linspace(0, 1, 6),
"decimal_number_config": {
"num_decimal_places": 1,
},
"stroke_width": 5,
"width": 8,
"marker_color": BLUE,
}
def __init__(self, name, value, **kwargs):
super().__init__(**kwargs)
self.set_width(self.width, stretch=True)
self.add_numbers()
self.marker = ArrowTip(start_angle=-90 * DEGREES)
self.marker.move_to(self.n2p(value), DOWN)
self.marker.set_color(self.marker_color)
self.add(self.marker)
self.label = DecimalNumber(value)
self.label.next_to(self.marker, UP)
self.add(self.label)
self.name = TextMobject(name)
self.name.scale(1.25)
self.name.next_to(self, DOWN)
self.name.match_color(self.marker)
self.add(self.name)
def get_change_anim(self, new_value, **kwargs):
start_value = self.label.get_value()
m2l = self.label.get_center() - self.marker.get_center()
def update(mob, alpha):
interim_value = interpolate(start_value, new_value, alpha)
mob.marker.move_to(mob.n2p(interim_value), DOWN)
mob.label.move_to(mob.marker.get_center() + m2l)
mob.label.set_value(interim_value)
return UpdateFromAlphaFunc(self, update, **kwargs)
# Scenes
class RunSimpleSimulation(Scene):
CONFIG = {
"simulation_config": {
"person_type": PiPerson,
"n_cities": 1,
"city_population": 100,
"person_config": {
"infection_radius": 0.75,
"social_distance_factor": 0,
"gravity_strength": 0.2,
"max_speed": 0.5,
},
"travel_rate": 0,
"infection_time": 5,
},
"graph_config": {
"update_frequency": 0.25,
},
"graph_height_to_frame_height": 0.5,
"graph_width_to_frame_height": 0.75,
"include_graph_braces": True,
}
def setup(self):
self.add_simulation()
self.position_camera()
self.add_graph()
self.add_sliders()
def construct(self):
for x in range(5):
self.wait(5)
def add_simulation(self):
self.simulation = SIRSimulation(**self.simulation_config)
self.add(self.simulation)
def position_camera(self):
frame = self.camera.frame
boxes = self.simulation.boxes
min_height = boxes.get_height() + 1
min_width = 3 * boxes.get_width()
if frame.get_height() < min_height:
frame.set_height(min_height)
if frame.get_width() < min_width:
frame.set_width(min_width)
frame.next_to(boxes.get_right(), LEFT, buff=-0.1 * boxes.get_width())
def add_graph(self):
frame = self.camera.frame
frame_height = frame.get_height()
graph = SIRGraph(
self.simulation,
height=self.graph_height_to_frame_height * frame_height,
width=self.graph_width_to_frame_height * frame_height,
**self.graph_config,
)
graph.move_to(frame, UL)
graph.shift(0.05 * DR * frame_height)
self.add(graph)
self.graph = graph
if self.include_graph_braces:
self.graph_braces = GraphBraces(
graph,
self.simulation,
update_frequency=graph.update_frequency
)
self.add(self.graph_braces)
def add_sliders(self):
pass
class RunSimpleSimulationWithDots(RunSimpleSimulation):
CONFIG = {
"person_type": DotPerson,
}
class SimpleSocialDistancing(RunSimpleSimulation):
CONFIG = {
"simulation_config": {
"person_type": PiPerson,
"n_cities": 1,
"city_population": 100,
"person_config": {
"infection_radius": 0.75,
"social_distance_factor": 2,
"gravity_strength": 0.1,
},
"travel_rate": 0,
"infection_time": 5,
},
}
def construct(self):
for x in range(5):
self.wait(5)
class SimpleTravel(RunSimpleSimulation):
CONFIG = {
"simulation_config": {
"person_type": DotPerson,
"n_cities": 12,
"city_population": 100,
"person_config": {
"infection_radius": 0.75,
"social_distance_factor": 0,
"gravity_strength": 0.5,
},
"travel_rate": 0.02,
"infection_time": 5,
},
}
def construct(self):
for x in range(10):
self.wait(5)
def add_sliders(self):
slider = ValueSlider(
"Travel rate",
self.simulation.travel_rate,
x_min=0,
x_max=0.1,
tick_frequency=0.01,
numbers_with_elongated_ticks=[],
numbers_to_show=np.linspace(0, 0.1, 6),
decimal_number_config={
"num_decimal_places": 2,
}
)
slider.match_width(self.graph)
slider.next_to(self.graph, DOWN, buff=5)
self.add(slider)
self.slider = slider