Merge branch 'do-not-leave-dangling-sk-pointers-in-pf-create-functions'

Ignat Korchagin says:

====================
do not leave dangling sk pointers in pf->create functions

Some protocol family create() implementations have an error path after
allocating the sk object and calling sock_init_data(). sock_init_data()
attaches the allocated sk object to the sock object, provided by the
caller.

If the create() implementation errors out after calling sock_init_data(),
it releases the allocated sk object, but the caller ends up having a
dangling sk pointer in its sock object on return. Subsequent manipulations
on this sock object may try to access the sk pointer, because it is not
NULL thus creating a use-after-free scenario.

We have implemented a stable hotfix in commit 6310831433
("net: explicitly clear the sk pointer, when pf->create fails"), but this
series aims to fix it properly by going through each of the pf->create()
implementations and making sure they all don't return a sock object with
a dangling pointer on error.
====================

Link: https://patch.msgid.link/20241014153808.51894-1-ignat@cloudflare.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
This commit is contained in:
Jakub Kicinski 2024-10-15 18:43:11 -07:00
commit 2d859aff77
9 changed files with 42 additions and 45 deletions

View file

@ -1886,6 +1886,7 @@ static struct sock *l2cap_sock_alloc(struct net *net, struct socket *sock,
chan = l2cap_chan_create(); chan = l2cap_chan_create();
if (!chan) { if (!chan) {
sk_free(sk); sk_free(sk);
sock->sk = NULL;
return NULL; return NULL;
} }

View file

@ -274,13 +274,13 @@ static struct sock *rfcomm_sock_alloc(struct net *net, struct socket *sock,
struct rfcomm_dlc *d; struct rfcomm_dlc *d;
struct sock *sk; struct sock *sk;
sk = bt_sock_alloc(net, sock, &rfcomm_proto, proto, prio, kern); d = rfcomm_dlc_alloc(prio);
if (!sk) if (!d)
return NULL; return NULL;
d = rfcomm_dlc_alloc(prio); sk = bt_sock_alloc(net, sock, &rfcomm_proto, proto, prio, kern);
if (!d) { if (!sk) {
sk_free(sk); rfcomm_dlc_free(d);
return NULL; return NULL;
} }

View file

@ -171,6 +171,7 @@ static int can_create(struct net *net, struct socket *sock, int protocol,
/* release sk on errors */ /* release sk on errors */
sock_orphan(sk); sock_orphan(sk);
sock_put(sk); sock_put(sk);
sock->sk = NULL;
} }
errout: errout:

View file

@ -3827,9 +3827,6 @@ void sk_common_release(struct sock *sk)
sk->sk_prot->unhash(sk); sk->sk_prot->unhash(sk);
if (sk->sk_socket)
sk->sk_socket->sk = NULL;
/* /*
* In this point socket cannot receive new packets, but it is possible * In this point socket cannot receive new packets, but it is possible
* that some packets are in flight because some CPU runs receiver and * that some packets are in flight because some CPU runs receiver and

View file

@ -1043,19 +1043,21 @@ static int ieee802154_create(struct net *net, struct socket *sock,
if (sk->sk_prot->hash) { if (sk->sk_prot->hash) {
rc = sk->sk_prot->hash(sk); rc = sk->sk_prot->hash(sk);
if (rc) { if (rc)
sk_common_release(sk); goto out_sk_release;
goto out;
}
} }
if (sk->sk_prot->init) { if (sk->sk_prot->init) {
rc = sk->sk_prot->init(sk); rc = sk->sk_prot->init(sk);
if (rc) if (rc)
sk_common_release(sk); goto out_sk_release;
} }
out: out:
return rc; return rc;
out_sk_release:
sk_common_release(sk);
sock->sk = NULL;
goto out;
} }
static const struct net_proto_family ieee802154_family_ops = { static const struct net_proto_family ieee802154_family_ops = {

View file

@ -376,32 +376,30 @@ lookup_protocol:
inet->inet_sport = htons(inet->inet_num); inet->inet_sport = htons(inet->inet_num);
/* Add to protocol hash chains. */ /* Add to protocol hash chains. */
err = sk->sk_prot->hash(sk); err = sk->sk_prot->hash(sk);
if (err) { if (err)
sk_common_release(sk); goto out_sk_release;
goto out;
}
} }
if (sk->sk_prot->init) { if (sk->sk_prot->init) {
err = sk->sk_prot->init(sk); err = sk->sk_prot->init(sk);
if (err) { if (err)
sk_common_release(sk); goto out_sk_release;
goto out;
}
} }
if (!kern) { if (!kern) {
err = BPF_CGROUP_RUN_PROG_INET_SOCK(sk); err = BPF_CGROUP_RUN_PROG_INET_SOCK(sk);
if (err) { if (err)
sk_common_release(sk); goto out_sk_release;
goto out;
}
} }
out: out:
return err; return err;
out_rcu_unlock: out_rcu_unlock:
rcu_read_unlock(); rcu_read_unlock();
goto out; goto out;
out_sk_release:
sk_common_release(sk);
sock->sk = NULL;
goto out;
} }

View file

@ -252,31 +252,29 @@ lookup_protocol:
*/ */
inet->inet_sport = htons(inet->inet_num); inet->inet_sport = htons(inet->inet_num);
err = sk->sk_prot->hash(sk); err = sk->sk_prot->hash(sk);
if (err) { if (err)
sk_common_release(sk); goto out_sk_release;
goto out;
}
} }
if (sk->sk_prot->init) { if (sk->sk_prot->init) {
err = sk->sk_prot->init(sk); err = sk->sk_prot->init(sk);
if (err) { if (err)
sk_common_release(sk); goto out_sk_release;
goto out;
}
} }
if (!kern) { if (!kern) {
err = BPF_CGROUP_RUN_PROG_INET_SOCK(sk); err = BPF_CGROUP_RUN_PROG_INET_SOCK(sk);
if (err) { if (err)
sk_common_release(sk); goto out_sk_release;
goto out;
}
} }
out: out:
return err; return err;
out_rcu_unlock: out_rcu_unlock:
rcu_read_unlock(); rcu_read_unlock();
goto out; goto out;
out_sk_release:
sk_common_release(sk);
sock->sk = NULL;
goto out;
} }
static int __inet6_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len, static int __inet6_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len,

View file

@ -3422,17 +3422,17 @@ static int packet_create(struct net *net, struct socket *sock, int protocol,
if (sock->type == SOCK_PACKET) if (sock->type == SOCK_PACKET)
sock->ops = &packet_ops_spkt; sock->ops = &packet_ops_spkt;
po = pkt_sk(sk);
err = packet_alloc_pending(po);
if (err)
goto out_sk_free;
sock_init_data(sock, sk); sock_init_data(sock, sk);
po = pkt_sk(sk);
init_completion(&po->skb_completion); init_completion(&po->skb_completion);
sk->sk_family = PF_PACKET; sk->sk_family = PF_PACKET;
po->num = proto; po->num = proto;
err = packet_alloc_pending(po);
if (err)
goto out2;
packet_cached_dev_reset(po); packet_cached_dev_reset(po);
sk->sk_destruct = packet_sock_destruct; sk->sk_destruct = packet_sock_destruct;
@ -3464,7 +3464,7 @@ static int packet_create(struct net *net, struct socket *sock, int protocol,
sock_prot_inuse_add(net, &packet_proto, 1); sock_prot_inuse_add(net, &packet_proto, 1);
return 0; return 0;
out2: out_sk_free:
sk_free(sk); sk_free(sk);
out: out:
return err; return err;

View file

@ -1576,9 +1576,9 @@ int __sock_create(struct net *net, int family, int type, int protocol,
err = pf->create(net, sock, protocol, kern); err = pf->create(net, sock, protocol, kern);
if (err < 0) { if (err < 0) {
/* ->create should release the allocated sock->sk object on error /* ->create should release the allocated sock->sk object on error
* but it may leave the dangling pointer * and make sure sock->sk is set to NULL to avoid use-after-free
*/ */
sock->sk = NULL; DEBUG_NET_WARN_ON_ONCE(sock->sk);
goto out_module_put; goto out_module_put;
} }