wifi: mt76: add a wrapper for wcid access with validation

Several places use rcu_dereference to get a wcid entry without validating
if the index exceeds the array boundary. Fix this by using a helper function,
which handles validation.

Link: https://patch.msgid.link/20250707154702.1726-1-nbd@nbd.name
Signed-off-by: Felix Fietkau <nbd@nbd.name>
This commit is contained in:
Felix Fietkau 2025-07-07 17:47:01 +02:00
parent 7035a08234
commit dc66a129ad
17 changed files with 41 additions and 68 deletions

View file

@ -1224,6 +1224,16 @@ static inline int mt76_wed_dma_setup(struct mt76_dev *dev, struct mt76_queue *q,
#define mt76_dereference(p, dev) \ #define mt76_dereference(p, dev) \
rcu_dereference_protected(p, lockdep_is_held(&(dev)->mutex)) rcu_dereference_protected(p, lockdep_is_held(&(dev)->mutex))
static inline struct mt76_wcid *
__mt76_wcid_ptr(struct mt76_dev *dev, u16 idx)
{
if (idx >= ARRAY_SIZE(dev->wcid))
return NULL;
return rcu_dereference(dev->wcid[idx]);
}
#define mt76_wcid_ptr(dev, idx) __mt76_wcid_ptr(&(dev)->mt76, idx)
struct mt76_dev *mt76_alloc_device(struct device *pdev, unsigned int size, struct mt76_dev *mt76_alloc_device(struct device *pdev, unsigned int size,
const struct ieee80211_ops *ops, const struct ieee80211_ops *ops,
const struct mt76_driver_ops *drv_ops); const struct mt76_driver_ops *drv_ops);

View file

@ -44,7 +44,7 @@ mt7603_rx_loopback_skb(struct mt7603_dev *dev, struct sk_buff *skb)
if (idx >= MT7603_WTBL_STA - 1) if (idx >= MT7603_WTBL_STA - 1)
goto free; goto free;
wcid = rcu_dereference(dev->mt76.wcid[idx]); wcid = mt76_wcid_ptr(dev, idx);
if (!wcid) if (!wcid)
goto free; goto free;

View file

@ -487,10 +487,7 @@ mt7603_rx_get_wcid(struct mt7603_dev *dev, u8 idx, bool unicast)
struct mt7603_sta *sta; struct mt7603_sta *sta;
struct mt76_wcid *wcid; struct mt76_wcid *wcid;
if (idx >= MT7603_WTBL_SIZE) wcid = mt76_wcid_ptr(dev, idx);
return NULL;
wcid = rcu_dereference(dev->mt76.wcid[idx]);
if (unicast || !wcid) if (unicast || !wcid)
return wcid; return wcid;
@ -1266,12 +1263,9 @@ void mt7603_mac_add_txs(struct mt7603_dev *dev, void *data)
if (pid == MT_PACKET_ID_NO_ACK) if (pid == MT_PACKET_ID_NO_ACK)
return; return;
if (wcidx >= MT7603_WTBL_SIZE)
return;
rcu_read_lock(); rcu_read_lock();
wcid = rcu_dereference(dev->mt76.wcid[wcidx]); wcid = mt76_wcid_ptr(dev, wcidx);
if (!wcid) if (!wcid)
goto out; goto out;

View file

@ -90,10 +90,7 @@ static struct mt76_wcid *mt7615_rx_get_wcid(struct mt7615_dev *dev,
struct mt7615_sta *sta; struct mt7615_sta *sta;
struct mt76_wcid *wcid; struct mt76_wcid *wcid;
if (idx >= MT7615_WTBL_SIZE) wcid = mt76_wcid_ptr(dev, idx);
return NULL;
wcid = rcu_dereference(dev->mt76.wcid[idx]);
if (unicast || !wcid) if (unicast || !wcid)
return wcid; return wcid;
@ -1504,7 +1501,7 @@ static void mt7615_mac_add_txs(struct mt7615_dev *dev, void *data)
rcu_read_lock(); rcu_read_lock();
wcid = rcu_dereference(dev->mt76.wcid[wcidx]); wcid = mt76_wcid_ptr(dev, wcidx);
if (!wcid) if (!wcid)
goto out; goto out;

View file

@ -1172,7 +1172,7 @@ void mt76_connac2_txwi_free(struct mt76_dev *dev, struct mt76_txwi_cache *t,
wcid_idx = wcid->idx; wcid_idx = wcid->idx;
} else { } else {
wcid_idx = le32_get_bits(txwi[1], MT_TXD1_WLAN_IDX); wcid_idx = le32_get_bits(txwi[1], MT_TXD1_WLAN_IDX);
wcid = rcu_dereference(dev->wcid[wcid_idx]); wcid = __mt76_wcid_ptr(dev, wcid_idx);
if (wcid && wcid->sta) { if (wcid && wcid->sta) {
sta = container_of((void *)wcid, struct ieee80211_sta, sta = container_of((void *)wcid, struct ieee80211_sta,

View file

@ -262,10 +262,7 @@ mt76x02_rx_get_sta(struct mt76_dev *dev, u8 idx)
{ {
struct mt76_wcid *wcid; struct mt76_wcid *wcid;
if (idx >= MT76x02_N_WCIDS) wcid = __mt76_wcid_ptr(dev, idx);
return NULL;
wcid = rcu_dereference(dev->wcid[idx]);
if (!wcid) if (!wcid)
return NULL; return NULL;

View file

@ -564,9 +564,7 @@ void mt76x02_send_tx_status(struct mt76x02_dev *dev,
rcu_read_lock(); rcu_read_lock();
if (stat->wcid < MT76x02_N_WCIDS) wcid = mt76_wcid_ptr(dev, stat->wcid);
wcid = rcu_dereference(dev->mt76.wcid[stat->wcid]);
if (wcid && wcid->sta) { if (wcid && wcid->sta) {
void *priv; void *priv;

View file

@ -56,10 +56,7 @@ static struct mt76_wcid *mt7915_rx_get_wcid(struct mt7915_dev *dev,
struct mt7915_sta *sta; struct mt7915_sta *sta;
struct mt76_wcid *wcid; struct mt76_wcid *wcid;
if (idx >= ARRAY_SIZE(dev->mt76.wcid)) wcid = mt76_wcid_ptr(dev, idx);
return NULL;
wcid = rcu_dereference(dev->mt76.wcid[idx]);
if (unicast || !wcid) if (unicast || !wcid)
return wcid; return wcid;
@ -917,7 +914,7 @@ mt7915_mac_tx_free(struct mt7915_dev *dev, void *data, int len)
u16 idx; u16 idx;
idx = FIELD_GET(MT_TX_FREE_WLAN_ID, info); idx = FIELD_GET(MT_TX_FREE_WLAN_ID, info);
wcid = rcu_dereference(dev->mt76.wcid[idx]); wcid = mt76_wcid_ptr(dev, idx);
sta = wcid_to_sta(wcid); sta = wcid_to_sta(wcid);
if (!sta) if (!sta)
continue; continue;
@ -1013,12 +1010,9 @@ static void mt7915_mac_add_txs(struct mt7915_dev *dev, void *data)
if (pid < MT_PACKET_ID_WED) if (pid < MT_PACKET_ID_WED)
return; return;
if (wcidx >= mt7915_wtbl_size(dev))
return;
rcu_read_lock(); rcu_read_lock();
wcid = rcu_dereference(dev->mt76.wcid[wcidx]); wcid = mt76_wcid_ptr(dev, wcidx);
if (!wcid) if (!wcid)
goto out; goto out;

View file

@ -3986,7 +3986,7 @@ int mt7915_mcu_wed_wa_tx_stats(struct mt7915_dev *dev, u16 wlan_idx)
rcu_read_lock(); rcu_read_lock();
wcid = rcu_dereference(dev->mt76.wcid[wlan_idx]); wcid = mt76_wcid_ptr(dev, wlan_idx);
if (wcid) if (wcid)
wcid->stats.tx_packets += le32_to_cpu(res->tx_packets); wcid->stats.tx_packets += le32_to_cpu(res->tx_packets);
else else

View file

@ -587,12 +587,9 @@ static void mt7915_mmio_wed_update_rx_stats(struct mtk_wed_device *wed,
dev = container_of(wed, struct mt7915_dev, mt76.mmio.wed); dev = container_of(wed, struct mt7915_dev, mt76.mmio.wed);
if (idx >= mt7915_wtbl_size(dev))
return;
rcu_read_lock(); rcu_read_lock();
wcid = rcu_dereference(dev->mt76.wcid[idx]); wcid = mt76_wcid_ptr(dev, idx);
if (wcid) { if (wcid) {
wcid->stats.rx_bytes += le32_to_cpu(stats->rx_byte_cnt); wcid->stats.rx_bytes += le32_to_cpu(stats->rx_byte_cnt);
wcid->stats.rx_packets += le32_to_cpu(stats->rx_pkt_cnt); wcid->stats.rx_packets += le32_to_cpu(stats->rx_pkt_cnt);

View file

@ -465,7 +465,7 @@ void mt7921_mac_add_txs(struct mt792x_dev *dev, void *data)
rcu_read_lock(); rcu_read_lock();
wcid = rcu_dereference(dev->mt76.wcid[wcidx]); wcid = mt76_wcid_ptr(dev, wcidx);
if (!wcid) if (!wcid)
goto out; goto out;
@ -516,7 +516,7 @@ static void mt7921_mac_tx_free(struct mt792x_dev *dev, void *data, int len)
count++; count++;
idx = FIELD_GET(MT_TX_FREE_WLAN_ID, info); idx = FIELD_GET(MT_TX_FREE_WLAN_ID, info);
wcid = rcu_dereference(dev->mt76.wcid[idx]); wcid = mt76_wcid_ptr(dev, idx);
sta = wcid_to_sta(wcid); sta = wcid_to_sta(wcid);
if (!sta) if (!sta)
continue; continue;
@ -816,7 +816,7 @@ void mt7921_usb_sdio_tx_complete_skb(struct mt76_dev *mdev,
u16 idx; u16 idx;
idx = le32_get_bits(txwi[1], MT_TXD1_WLAN_IDX); idx = le32_get_bits(txwi[1], MT_TXD1_WLAN_IDX);
wcid = rcu_dereference(mdev->wcid[idx]); wcid = __mt76_wcid_ptr(mdev, idx);
sta = wcid_to_sta(wcid); sta = wcid_to_sta(wcid);
if (sta && likely(e->skb->protocol != cpu_to_be16(ETH_P_PAE))) if (sta && likely(e->skb->protocol != cpu_to_be16(ETH_P_PAE)))

View file

@ -1040,7 +1040,7 @@ void mt7925_mac_add_txs(struct mt792x_dev *dev, void *data)
rcu_read_lock(); rcu_read_lock();
wcid = rcu_dereference(dev->mt76.wcid[wcidx]); wcid = mt76_wcid_ptr(dev, wcidx);
if (!wcid) if (!wcid)
goto out; goto out;
@ -1122,7 +1122,7 @@ mt7925_mac_tx_free(struct mt792x_dev *dev, void *data, int len)
u16 idx; u16 idx;
idx = FIELD_GET(MT_TXFREE_INFO_WLAN_ID, info); idx = FIELD_GET(MT_TXFREE_INFO_WLAN_ID, info);
wcid = rcu_dereference(dev->mt76.wcid[idx]); wcid = mt76_wcid_ptr(dev, idx);
sta = wcid_to_sta(wcid); sta = wcid_to_sta(wcid);
if (!sta) if (!sta)
continue; continue;
@ -1445,7 +1445,7 @@ void mt7925_usb_sdio_tx_complete_skb(struct mt76_dev *mdev,
u16 idx; u16 idx;
idx = le32_get_bits(txwi[1], MT_TXD1_WLAN_IDX); idx = le32_get_bits(txwi[1], MT_TXD1_WLAN_IDX);
wcid = rcu_dereference(mdev->wcid[idx]); wcid = __mt76_wcid_ptr(mdev, idx);
sta = wcid_to_sta(wcid); sta = wcid_to_sta(wcid);
if (sta && likely(e->skb->protocol != cpu_to_be16(ETH_P_PAE))) if (sta && likely(e->skb->protocol != cpu_to_be16(ETH_P_PAE)))

View file

@ -142,10 +142,7 @@ struct mt76_wcid *mt792x_rx_get_wcid(struct mt792x_dev *dev, u16 idx,
struct mt792x_sta *sta; struct mt792x_sta *sta;
struct mt76_wcid *wcid; struct mt76_wcid *wcid;
if (idx >= ARRAY_SIZE(dev->mt76.wcid)) wcid = mt76_wcid_ptr(dev, idx);
return NULL;
wcid = rcu_dereference(dev->mt76.wcid[idx]);
if (unicast || !wcid) if (unicast || !wcid)
return wcid; return wcid;

View file

@ -61,10 +61,7 @@ static struct mt76_wcid *mt7996_rx_get_wcid(struct mt7996_dev *dev,
struct mt76_wcid *wcid; struct mt76_wcid *wcid;
int i; int i;
if (idx >= ARRAY_SIZE(dev->mt76.wcid)) wcid = mt76_wcid_ptr(dev, idx);
return NULL;
wcid = rcu_dereference(dev->mt76.wcid[idx]);
if (!wcid) if (!wcid)
return NULL; return NULL;
@ -1249,7 +1246,7 @@ mt7996_mac_tx_free(struct mt7996_dev *dev, void *data, int len)
u16 idx; u16 idx;
idx = FIELD_GET(MT_TXFREE_INFO_WLAN_ID, info); idx = FIELD_GET(MT_TXFREE_INFO_WLAN_ID, info);
wcid = rcu_dereference(dev->mt76.wcid[idx]); wcid = mt76_wcid_ptr(dev, idx);
sta = wcid_to_sta(wcid); sta = wcid_to_sta(wcid);
if (!sta) if (!sta)
goto next; goto next;
@ -1471,12 +1468,9 @@ static void mt7996_mac_add_txs(struct mt7996_dev *dev, void *data)
if (pid < MT_PACKET_ID_NO_SKB) if (pid < MT_PACKET_ID_NO_SKB)
return; return;
if (wcidx >= mt7996_wtbl_size(dev))
return;
rcu_read_lock(); rcu_read_lock();
wcid = rcu_dereference(dev->mt76.wcid[wcidx]); wcid = mt76_wcid_ptr(dev, wcidx);
if (!wcid) if (!wcid)
goto out; goto out;

View file

@ -555,7 +555,7 @@ mt7996_mcu_rx_all_sta_info_event(struct mt7996_dev *dev, struct sk_buff *skb)
switch (le16_to_cpu(res->tag)) { switch (le16_to_cpu(res->tag)) {
case UNI_ALL_STA_TXRX_RATE: case UNI_ALL_STA_TXRX_RATE:
wlan_idx = le16_to_cpu(res->rate[i].wlan_idx); wlan_idx = le16_to_cpu(res->rate[i].wlan_idx);
wcid = rcu_dereference(dev->mt76.wcid[wlan_idx]); wcid = mt76_wcid_ptr(dev, wlan_idx);
if (!wcid) if (!wcid)
break; break;
@ -565,7 +565,7 @@ mt7996_mcu_rx_all_sta_info_event(struct mt7996_dev *dev, struct sk_buff *skb)
break; break;
case UNI_ALL_STA_TXRX_ADM_STAT: case UNI_ALL_STA_TXRX_ADM_STAT:
wlan_idx = le16_to_cpu(res->adm_stat[i].wlan_idx); wlan_idx = le16_to_cpu(res->adm_stat[i].wlan_idx);
wcid = rcu_dereference(dev->mt76.wcid[wlan_idx]); wcid = mt76_wcid_ptr(dev, wlan_idx);
if (!wcid) if (!wcid)
break; break;
@ -579,7 +579,7 @@ mt7996_mcu_rx_all_sta_info_event(struct mt7996_dev *dev, struct sk_buff *skb)
break; break;
case UNI_ALL_STA_TXRX_MSDU_COUNT: case UNI_ALL_STA_TXRX_MSDU_COUNT:
wlan_idx = le16_to_cpu(res->msdu_cnt[i].wlan_idx); wlan_idx = le16_to_cpu(res->msdu_cnt[i].wlan_idx);
wcid = rcu_dereference(dev->mt76.wcid[wlan_idx]); wcid = mt76_wcid_ptr(dev, wlan_idx);
if (!wcid) if (!wcid)
break; break;
@ -676,10 +676,7 @@ mt7996_mcu_wed_rro_event(struct mt7996_dev *dev, struct sk_buff *skb)
e = (void *)skb->data; e = (void *)skb->data;
idx = le16_to_cpu(e->wlan_id); idx = le16_to_cpu(e->wlan_id);
if (idx >= ARRAY_SIZE(dev->mt76.wcid)) wcid = mt76_wcid_ptr(dev, idx);
break;
wcid = rcu_dereference(dev->mt76.wcid[idx]);
if (!wcid || !wcid->sta) if (!wcid || !wcid->sta)
break; break;

View file

@ -64,7 +64,7 @@ mt76_tx_status_unlock(struct mt76_dev *dev, struct sk_buff_head *list)
struct mt76_tx_cb *cb = mt76_tx_skb_cb(skb); struct mt76_tx_cb *cb = mt76_tx_skb_cb(skb);
struct mt76_wcid *wcid; struct mt76_wcid *wcid;
wcid = rcu_dereference(dev->wcid[cb->wcid]); wcid = __mt76_wcid_ptr(dev, cb->wcid);
if (wcid) { if (wcid) {
status.sta = wcid_to_sta(wcid); status.sta = wcid_to_sta(wcid);
if (status.sta && (wcid->rate.flags || wcid->rate.legacy)) { if (status.sta && (wcid->rate.flags || wcid->rate.legacy)) {
@ -251,9 +251,7 @@ void __mt76_tx_complete_skb(struct mt76_dev *dev, u16 wcid_idx, struct sk_buff *
rcu_read_lock(); rcu_read_lock();
if (wcid_idx < ARRAY_SIZE(dev->wcid)) wcid = __mt76_wcid_ptr(dev, wcid_idx);
wcid = rcu_dereference(dev->wcid[wcid_idx]);
mt76_tx_check_non_aql(dev, wcid, skb); mt76_tx_check_non_aql(dev, wcid, skb);
#ifdef CONFIG_NL80211_TESTMODE #ifdef CONFIG_NL80211_TESTMODE
@ -538,7 +536,7 @@ mt76_txq_schedule_list(struct mt76_phy *phy, enum mt76_txq_id qid)
break; break;
mtxq = (struct mt76_txq *)txq->drv_priv; mtxq = (struct mt76_txq *)txq->drv_priv;
wcid = rcu_dereference(dev->wcid[mtxq->wcid]); wcid = __mt76_wcid_ptr(dev, mtxq->wcid);
if (!wcid || test_bit(MT_WCID_FLAG_PS, &wcid->flags)) if (!wcid || test_bit(MT_WCID_FLAG_PS, &wcid->flags))
continue; continue;

View file

@ -83,7 +83,7 @@ int mt76_get_min_avg_rssi(struct mt76_dev *dev, u8 phy_idx)
if (!(mask & 1)) if (!(mask & 1))
continue; continue;
wcid = rcu_dereference(dev->wcid[j]); wcid = __mt76_wcid_ptr(dev, j);
if (!wcid || wcid->phy_idx != phy_idx) if (!wcid || wcid->phy_idx != phy_idx)
continue; continue;