// SPDX-License-Identifier: GPL-2.0-only #include #include #include #include #include #include #include "psp.h" struct psp_dev *psp_dev_get_for_sock(struct sock *sk) { struct psp_dev *psd = NULL; struct dst_entry *dst; rcu_read_lock(); dst = __sk_dst_get(sk); if (dst) { psd = rcu_dereference(dst_dev_rcu(dst)->psp_dev); if (psd && !psp_dev_tryget(psd)) psd = NULL; } rcu_read_unlock(); return psd; } static struct sk_buff * psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb) { struct psp_assoc *pas; bool good; rcu_read_lock(); pas = psp_skb_get_assoc_rcu(skb); good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd; rcu_read_unlock(); if (!good) { sk_skb_reason_drop(sk, skb, SKB_DROP_REASON_PSP_OUTPUT); return NULL; } return skb; } struct psp_assoc *psp_assoc_create(struct psp_dev *psd) { struct psp_assoc *pas; lockdep_assert_held(&psd->lock); pas = kzalloc(struct_size(pas, drv_data, psd->caps->assoc_drv_spc), GFP_KERNEL_ACCOUNT); if (!pas) return NULL; pas->psd = psd; pas->dev_id = psd->id; pas->generation = psd->generation; psp_dev_get(psd); refcount_set(&pas->refcnt, 1); list_add_tail(&pas->assocs_list, &psd->active_assocs); return pas; } static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas) { struct psp_dev *psd = pas->psd; size_t sz; lockdep_assert_held(&psd->lock); sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc); return kmemdup(pas, sz, GFP_KERNEL); } static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas, struct netlink_ext_ack *extack) { return psd->ops->tx_key_add(psd, pas, extack); } void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas) { if (pas->tx.spi) psd->ops->tx_key_del(psd, pas); list_del(&pas->assocs_list); } static void psp_assoc_free(struct work_struct *work) { struct psp_assoc *pas = container_of(work, struct psp_assoc, work); struct psp_dev *psd = pas->psd; mutex_lock(&psd->lock); if (psd->ops) psp_dev_tx_key_del(psd, pas); mutex_unlock(&psd->lock); psp_dev_put(psd); kfree(pas); } static void psp_assoc_free_queue(struct rcu_head *head) { struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu); INIT_WORK(&pas->work, psp_assoc_free); schedule_work(&pas->work); } /** * psp_assoc_put() - release a reference on a PSP association * @pas: association to release */ void psp_assoc_put(struct psp_assoc *pas) { if (pas && refcount_dec_and_test(&pas->refcnt)) call_rcu(&pas->rcu, psp_assoc_free_queue); } void psp_sk_assoc_free(struct sock *sk) { struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1); rcu_assign_pointer(sk->psp_assoc, NULL); psp_assoc_put(pas); } int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas, struct psp_key_parsed *key, struct netlink_ext_ack *extack) { int err; memcpy(&pas->rx, key, sizeof(*key)); lock_sock(sk); if (psp_sk_assoc(sk)) { NL_SET_ERR_MSG(extack, "Socket already has PSP state"); err = -EBUSY; goto exit_unlock; } refcount_inc(&pas->refcnt); rcu_assign_pointer(sk->psp_assoc, pas); err = 0; exit_unlock: release_sock(sk); return err; } static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas) { struct psp_skb_ext *pse; struct sk_buff *skb; skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) { pse = skb_ext_find(skb, SKB_EXT_PSP); if (!psp_pse_matches_pas(pse, pas)) return -EBUSY; } skb_queue_walk(&sk->sk_receive_queue, skb) { pse = skb_ext_find(skb, SKB_EXT_PSP); if (!psp_pse_matches_pas(pse, pas)) return -EBUSY; } return 0; } int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd, u32 version, struct psp_key_parsed *key, struct netlink_ext_ack *extack) { struct inet_connection_sock *icsk; struct psp_assoc *pas, *dummy; int err; lock_sock(sk); pas = psp_sk_assoc(sk); if (!pas) { NL_SET_ERR_MSG(extack, "Socket has no Rx key"); err = -EINVAL; goto exit_unlock; } if (pas->psd != psd) { NL_SET_ERR_MSG(extack, "Rx key from different device"); err = -EINVAL; goto exit_unlock; } if (pas->version != version) { NL_SET_ERR_MSG(extack, "PSP version mismatch with existing state"); err = -EINVAL; goto exit_unlock; } if (pas->tx.spi) { NL_SET_ERR_MSG(extack, "Tx key already set"); err = -EBUSY; goto exit_unlock; } err = psp_sock_recv_queue_check(sk, pas); if (err) { NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue"); goto exit_unlock; } /* Pass a fake association to drivers to make sure they don't * try to store pointers to it. For re-keying we'll need to * re-allocate the assoc structures. */ dummy = psp_assoc_dummy(pas); if (!dummy) { err = -ENOMEM; goto exit_unlock; } memcpy(&dummy->tx, key, sizeof(*key)); err = psp_dev_tx_key_add(psd, dummy, extack); if (err) goto exit_free_dummy; memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc); memcpy(&pas->tx, key, sizeof(*key)); WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit); tcp_write_collapse_fence(sk); pas->upgrade_seq = tcp_sk(sk)->rcv_nxt; icsk = inet_csk(sk); icsk->icsk_ext_hdr_len += psp_sk_overhead(sk); icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie); exit_free_dummy: kfree(dummy); exit_unlock: release_sock(sk); return err; } void psp_assocs_key_rotated(struct psp_dev *psd) { struct psp_assoc *pas, *next; /* Mark the stale associations as invalid, they will no longer * be able to Rx any traffic. */ list_for_each_entry_safe(pas, next, &psd->prev_assocs, assocs_list) pas->generation |= ~PSP_GEN_VALID_MASK; list_splice_init(&psd->prev_assocs, &psd->stale_assocs); list_splice_init(&psd->active_assocs, &psd->prev_assocs); /* TODO: we should inform the sockets that got shut down */ } void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) { struct psp_assoc *pas = psp_sk_assoc(sk); if (pas) refcount_inc(&pas->refcnt); rcu_assign_pointer(tw->psp_assoc, pas); tw->tw_validate_xmit_skb = psp_validate_xmit; } void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1); rcu_assign_pointer(tw->psp_assoc, NULL); psp_assoc_put(pas); } void psp_reply_set_decrypted(struct sk_buff *skb) { struct psp_assoc *pas; rcu_read_lock(); pas = psp_sk_get_assoc_rcu(skb->sk); if (pas && pas->tx.spi) skb->decrypted = 1; rcu_read_unlock(); } EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted);