diff --git a/net/xdp/xdp_umem.c b/net/xdp/xdp_umem.c
index 555427b3e0fe..4d6c6652f5d1 100644
--- a/net/xdp/xdp_umem.c
+++ b/net/xdp/xdp_umem.c
@@ -42,27 +42,47 @@ void xdp_del_sk_umem(struct xdp_umem *umem, struct xdp_sock *xs)
 	}
 }
 
-int xdp_umem_query(struct net_device *dev, u16 queue_id)
+/* The umem is stored both in the _rx struct and the _tx struct as we do
+ * not know if the device has more tx queues than rx, or the opposite.
+ * This might also change during run time.
+ */
+static void xdp_reg_umem_at_qid(struct net_device *dev, struct xdp_umem *umem,
+				u16 queue_id)
 {
-	struct netdev_bpf bpf;
+	if (queue_id < dev->real_num_rx_queues)
+		dev->_rx[queue_id].umem = umem;
+	if (queue_id < dev->real_num_tx_queues)
+		dev->_tx[queue_id].umem = umem;
+}
 
-	ASSERT_RTNL();
+static struct xdp_umem *xdp_get_umem_from_qid(struct net_device *dev,
+					      u16 queue_id)
+{
+	if (queue_id < dev->real_num_rx_queues)
+		return dev->_rx[queue_id].umem;
+	if (queue_id < dev->real_num_tx_queues)
+		return dev->_tx[queue_id].umem;
 
-	memset(&bpf, 0, sizeof(bpf));
-	bpf.command = XDP_QUERY_XSK_UMEM;
-	bpf.xsk.queue_id = queue_id;
+	return NULL;
+}
 
-	if (!dev->netdev_ops->ndo_bpf)
-		return 0;
-	return dev->netdev_ops->ndo_bpf(dev, &bpf) ?: !!bpf.xsk.umem;
+static void xdp_clear_umem_at_qid(struct net_device *dev, u16 queue_id)
+{
+	/* Zero out the entry independent on how many queues are configured
+	 * at this point in time, as it might be used in the future.
+	 */
+	if (queue_id < dev->num_rx_queues)
+		dev->_rx[queue_id].umem = NULL;
+	if (queue_id < dev->num_tx_queues)
+		dev->_tx[queue_id].umem = NULL;
 }
 
 int xdp_umem_assign_dev(struct xdp_umem *umem, struct net_device *dev,
-			u32 queue_id, u16 flags)
+			u16 queue_id, u16 flags)
 {
 	bool force_zc, force_copy;
 	struct netdev_bpf bpf;
-	int err;
+	int err = 0;
 
 	force_zc = flags & XDP_ZEROCOPY;
 	force_copy = flags & XDP_COPY;
@@ -70,17 +90,23 @@ int xdp_umem_assign_dev(struct xdp_umem *umem, struct net_device *dev,
 	if (force_zc && force_copy)
 		return -EINVAL;
 
-	if (force_copy)
-		return 0;
-
-	if (!dev->netdev_ops->ndo_bpf || !dev->netdev_ops->ndo_xsk_async_xmit)
-		return force_zc ? -EOPNOTSUPP : 0; /* fail or fallback */
-
 	rtnl_lock();
-	err = xdp_umem_query(dev, queue_id);
-	if (err) {
-		err = err < 0 ? -EOPNOTSUPP : -EBUSY;
-		goto err_rtnl_unlock;
+	if (xdp_get_umem_from_qid(dev, queue_id)) {
+		err = -EBUSY;
+		goto out_rtnl_unlock;
+	}
+
+	xdp_reg_umem_at_qid(dev, umem, queue_id);
+	umem->dev = dev;
+	umem->queue_id = queue_id;
+	if (force_copy)
+		/* For copy-mode, we are done. */
+		goto out_rtnl_unlock;
+
+	if (!dev->netdev_ops->ndo_bpf ||
+	    !dev->netdev_ops->ndo_xsk_async_xmit) {
+		err = -EOPNOTSUPP;
+		goto err_unreg_umem;
 	}
 
 	bpf.command = XDP_SETUP_XSK_UMEM;
@@ -89,18 +115,20 @@ int xdp_umem_assign_dev(struct xdp_umem *umem, struct net_device *dev,
 
 	err = dev->netdev_ops->ndo_bpf(dev, &bpf);
 	if (err)
-		goto err_rtnl_unlock;
+		goto err_unreg_umem;
 	rtnl_unlock();
 
 	dev_hold(dev);
-	umem->dev = dev;
-	umem->queue_id = queue_id;
 	umem->zc = true;
 	return 0;
 
-err_rtnl_unlock:
+err_unreg_umem:
+	xdp_clear_umem_at_qid(dev, queue_id);
+	if (!force_zc)
+		err = 0; /* fallback to copy mode */
+out_rtnl_unlock:
 	rtnl_unlock();
-	return force_zc ? err : 0; /* fail or fallback */
+	return err;
 }
 
 static void xdp_umem_clear_dev(struct xdp_umem *umem)
@@ -108,7 +136,7 @@ static void xdp_umem_clear_dev(struct xdp_umem *umem)
 	struct netdev_bpf bpf;
 	int err;
 
-	if (umem->dev) {
+	if (umem->zc) {
 		bpf.command = XDP_SETUP_XSK_UMEM;
 		bpf.xsk.umem = NULL;
 		bpf.xsk.queue_id = umem->queue_id;
@@ -119,9 +147,17 @@ static void xdp_umem_clear_dev(struct xdp_umem *umem)
 
 		if (err)
 			WARN(1, "failed to disable umem!\n");
+	}
 
+	if (umem->dev) {
+		rtnl_lock();
+		xdp_clear_umem_at_qid(umem->dev, umem->queue_id);
+		rtnl_unlock();
+	}
+
+	if (umem->zc) {
 		dev_put(umem->dev);
-		umem->dev = NULL;
+		umem->zc = false;
 	}
 }
 
diff --git a/net/xdp/xdp_umem.h b/net/xdp/xdp_umem.h
index c8be1ad3eb88..27603227601b 100644
--- a/net/xdp/xdp_umem.h
+++ b/net/xdp/xdp_umem.h
@@ -9,7 +9,7 @@
 #include <net/xdp_sock.h>
 
 int xdp_umem_assign_dev(struct xdp_umem *umem, struct net_device *dev,
-			u32 queue_id, u16 flags);
+			u16 queue_id, u16 flags);
 bool xdp_umem_validate_queues(struct xdp_umem *umem);
 void xdp_get_umem(struct xdp_umem *umem);
 void xdp_put_umem(struct xdp_umem *umem);
diff --git a/net/xdp/xsk.c b/net/xdp/xsk.c
index 5a432dfee4ee..caeddad15b7c 100644
--- a/net/xdp/xsk.c
+++ b/net/xdp/xsk.c
@@ -419,13 +419,6 @@ static int xsk_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
 	}
 
 	qid = sxdp->sxdp_queue_id;
-
-	if ((xs->rx && qid >= dev->real_num_rx_queues) ||
-	    (xs->tx && qid >= dev->real_num_tx_queues)) {
-		err = -EINVAL;
-		goto out_unlock;
-	}
-
 	flags = sxdp->sxdp_flags;
 
 	if (flags & XDP_SHARED_UMEM) {