// SPDX-License-Identifier: GPL-2.0-only #include #include #include #include #include #include "psp-nl-gen.h" #include "psp.h" /* Netlink helpers */ static struct sk_buff *psp_nl_reply_new(struct genl_info *info) { struct sk_buff *rsp; void *hdr; rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); if (!rsp) return NULL; hdr = genlmsg_iput(rsp, info); if (!hdr) { nlmsg_free(rsp); return NULL; } return rsp; } static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info) { /* Note that this *only* works with a single message per skb! */ nlmsg_end(rsp, (struct nlmsghdr *)rsp->data); return genlmsg_reply(rsp, info); } /* Device stuff */ static struct psp_dev * psp_device_get_and_lock(struct net *net, struct nlattr *dev_id) { struct psp_dev *psd; int err; mutex_lock(&psp_devs_lock); psd = xa_load(&psp_devs, nla_get_u32(dev_id)); if (!psd) { mutex_unlock(&psp_devs_lock); return ERR_PTR(-ENODEV); } mutex_lock(&psd->lock); mutex_unlock(&psp_devs_lock); err = psp_dev_check_access(psd, net); if (err) { mutex_unlock(&psd->lock); return ERR_PTR(err); } return psd; } int psp_device_get_locked(const struct genl_split_ops *ops, struct sk_buff *skb, struct genl_info *info) { if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_ID)) return -EINVAL; info->user_ptr[0] = psp_device_get_and_lock(genl_info_net(info), info->attrs[PSP_A_DEV_ID]); return PTR_ERR_OR_ZERO(info->user_ptr[0]); } void psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb, struct genl_info *info) { struct socket *socket = info->user_ptr[1]; struct psp_dev *psd = info->user_ptr[0]; mutex_unlock(&psd->lock); if (socket) sockfd_put(socket); } static int psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp, const struct genl_info *info) { void *hdr; hdr = genlmsg_iput(rsp, info); if (!hdr) return -EMSGSIZE; if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) || nla_put_u32(rsp, PSP_A_DEV_IFINDEX, psd->main_netdev->ifindex) || nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_CAP, psd->caps->versions) || nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_ENA, psd->config.versions)) goto err_cancel_msg; genlmsg_end(rsp, hdr); return 0; err_cancel_msg: genlmsg_cancel(rsp, hdr); return -EMSGSIZE; } void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd) { struct genl_info info; struct sk_buff *ntf; if (!genl_has_listeners(&psp_nl_family, dev_net(psd->main_netdev), PSP_NLGRP_MGMT)) return; ntf = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); if (!ntf) return; genl_info_init_ntf(&info, &psp_nl_family, cmd); if (psp_nl_dev_fill(psd, ntf, &info)) { nlmsg_free(ntf); return; } genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf, 0, PSP_NLGRP_MGMT, GFP_KERNEL); } int psp_nl_dev_get_doit(struct sk_buff *req, struct genl_info *info) { struct psp_dev *psd = info->user_ptr[0]; struct sk_buff *rsp; int err; rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); if (!rsp) return -ENOMEM; err = psp_nl_dev_fill(psd, rsp, info); if (err) goto err_free_msg; return genlmsg_reply(rsp, info); err_free_msg: nlmsg_free(rsp); return err; } static int psp_nl_dev_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb, struct psp_dev *psd) { if (psp_dev_check_access(psd, sock_net(rsp->sk))) return 0; return psp_nl_dev_fill(psd, rsp, genl_info_dump(cb)); } int psp_nl_dev_get_dumpit(struct sk_buff *rsp, struct netlink_callback *cb) { struct psp_dev *psd; int err = 0; mutex_lock(&psp_devs_lock); xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) { mutex_lock(&psd->lock); err = psp_nl_dev_get_dumpit_one(rsp, cb, psd); mutex_unlock(&psd->lock); if (err) break; } mutex_unlock(&psp_devs_lock); return err; } int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info) { struct psp_dev *psd = info->user_ptr[0]; struct psp_dev_config new_config; struct sk_buff *rsp; int err; memcpy(&new_config, &psd->config, sizeof(new_config)); if (info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]) { new_config.versions = nla_get_u32(info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]); if (new_config.versions & ~psd->caps->versions) { NL_SET_ERR_MSG(info->extack, "Requested PSP versions not supported by the device"); return -EINVAL; } } else { NL_SET_ERR_MSG(info->extack, "No settings present"); return -EINVAL; } rsp = psp_nl_reply_new(info); if (!rsp) return -ENOMEM; if (memcmp(&new_config, &psd->config, sizeof(new_config))) { err = psd->ops->set_config(psd, &new_config, info->extack); if (err) goto err_free_rsp; memcpy(&psd->config, &new_config, sizeof(new_config)); } psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF); return psp_nl_reply_send(rsp, info); err_free_rsp: nlmsg_free(rsp); return err; } int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info) { struct psp_dev *psd = info->user_ptr[0]; struct genl_info ntf_info; struct sk_buff *ntf, *rsp; u8 prev_gen; int err; rsp = psp_nl_reply_new(info); if (!rsp) return -ENOMEM; genl_info_init_ntf(&ntf_info, &psp_nl_family, PSP_CMD_KEY_ROTATE_NTF); ntf = psp_nl_reply_new(&ntf_info); if (!ntf) { err = -ENOMEM; goto err_free_rsp; } if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) || nla_put_u32(ntf, PSP_A_DEV_ID, psd->id)) { err = -EMSGSIZE; goto err_free_ntf; } /* suggest the next gen number, driver can override */ prev_gen = psd->generation; psd->generation = (prev_gen + 1) & PSP_GEN_VALID_MASK; err = psd->ops->key_rotate(psd, info->extack); if (err) goto err_free_ntf; WARN_ON_ONCE((psd->generation && psd->generation == prev_gen) || psd->generation & ~PSP_GEN_VALID_MASK); psp_assocs_key_rotated(psd); nlmsg_end(ntf, (struct nlmsghdr *)ntf->data); genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf, 0, PSP_NLGRP_USE, GFP_KERNEL); return psp_nl_reply_send(rsp, info); err_free_ntf: nlmsg_free(ntf); err_free_rsp: nlmsg_free(rsp); return err; } /* Key etc. */ int psp_assoc_device_get_locked(const struct genl_split_ops *ops, struct sk_buff *skb, struct genl_info *info) { struct socket *socket; struct psp_dev *psd; struct nlattr *id; int fd, err; if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD)) return -EINVAL; fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]); socket = sockfd_lookup(fd, &err); if (!socket) return err; if (!sk_is_tcp(socket->sk)) { NL_SET_ERR_MSG_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD], "Unsupported socket family and type"); err = -EOPNOTSUPP; goto err_sock_put; } psd = psp_dev_get_for_sock(socket->sk); if (psd) { err = psp_dev_check_access(psd, genl_info_net(info)); if (err) { psp_dev_put(psd); psd = NULL; } } if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) { err = -EINVAL; goto err_sock_put; } id = info->attrs[PSP_A_ASSOC_DEV_ID]; if (psd) { mutex_lock(&psd->lock); if (id && psd->id != nla_get_u32(id)) { mutex_unlock(&psd->lock); NL_SET_ERR_MSG_ATTR(info->extack, id, "Device id vs socket mismatch"); err = -EINVAL; goto err_psd_put; } psp_dev_put(psd); } else { psd = psp_device_get_and_lock(genl_info_net(info), id); if (IS_ERR(psd)) { err = PTR_ERR(psd); goto err_sock_put; } } info->user_ptr[0] = psd; info->user_ptr[1] = socket; return 0; err_psd_put: psp_dev_put(psd); err_sock_put: sockfd_put(socket); return err; } static int psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key, unsigned int key_sz) { struct nlattr *nest = info->attrs[attr]; struct nlattr *tb[PSP_A_KEYS_SPI + 1]; u32 spi; int err; err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest, psp_keys_nl_policy, info->extack); if (err) return err; if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) || NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI)) return -EINVAL; if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) { NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY], "incorrect key length"); return -EINVAL; } spi = nla_get_u32(tb[PSP_A_KEYS_SPI]); if (!(spi & PSP_SPI_KEY_ID)) { NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY], "invalid SPI: lower 31b must be non-zero"); return -EINVAL; } key->spi = cpu_to_be32(spi); memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz); return 0; } static int psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version, struct psp_key_parsed *key) { int key_sz = psp_key_size(version); void *nest; nest = nla_nest_start(skb, attr); if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) || nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) { nla_nest_cancel(skb, nest); return -EMSGSIZE; } nla_nest_end(skb, nest); return 0; } int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info) { struct socket *socket = info->user_ptr[1]; struct psp_dev *psd = info->user_ptr[0]; struct psp_key_parsed key; struct psp_assoc *pas; struct sk_buff *rsp; u32 version; int err; if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION)) return -EINVAL; version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]); if (!(psd->caps->versions & (1 << version))) { NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]); return -EOPNOTSUPP; } rsp = psp_nl_reply_new(info); if (!rsp) return -ENOMEM; pas = psp_assoc_create(psd); if (!pas) { err = -ENOMEM; goto err_free_rsp; } pas->version = version; err = psd->ops->rx_spi_alloc(psd, version, &key, info->extack); if (err) goto err_free_pas; if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) || psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) { err = -EMSGSIZE; goto err_free_pas; } err = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack); if (err) { NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]); goto err_free_pas; } psp_assoc_put(pas); return psp_nl_reply_send(rsp, info); err_free_pas: psp_assoc_put(pas); err_free_rsp: nlmsg_free(rsp); return err; } int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info) { struct socket *socket = info->user_ptr[1]; struct psp_dev *psd = info->user_ptr[0]; struct psp_key_parsed key; struct sk_buff *rsp; unsigned int key_sz; u32 version; int err; if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) || GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY)) return -EINVAL; version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]); if (!(psd->caps->versions & (1 << version))) { NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]); return -EOPNOTSUPP; } key_sz = psp_key_size(version); if (!key_sz) return -EINVAL; err = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz); if (err < 0) return err; rsp = psp_nl_reply_new(info); if (!rsp) return -ENOMEM; err = psp_sock_assoc_set_tx(socket->sk, psd, version, &key, info->extack); if (err) goto err_free_msg; return psp_nl_reply_send(rsp, info); err_free_msg: nlmsg_free(rsp); return err; }