diff --git a/net/core/sock_diag.c b/net/core/sock_diag.c index 73b2e36032b3..e6ea6764d10a 100644 --- a/net/core/sock_diag.c +++ b/net/core/sock_diag.c @@ -16,7 +16,7 @@ #include #include -static const struct sock_diag_handler __rcu *sock_diag_handlers[AF_MAX]; +static const struct sock_diag_handler *sock_diag_handlers[AF_MAX]; static int (*inet_rcv_compat)(struct sk_buff *skb, struct nlmsghdr *nlh); static DEFINE_MUTEX(sock_diag_table_mutex); static struct workqueue_struct *broadcast_wq; @@ -119,24 +119,6 @@ static size_t sock_diag_nlmsg_size(void) + nla_total_size_64bit(sizeof(struct tcp_info))); /* INET_DIAG_INFO */ } -static const struct sock_diag_handler *sock_diag_lock_handler(int family) -{ - const struct sock_diag_handler *handler; - - rcu_read_lock(); - handler = rcu_dereference(sock_diag_handlers[family]); - if (handler && !try_module_get(handler->owner)) - handler = NULL; - rcu_read_unlock(); - - return handler; -} - -static void sock_diag_unlock_handler(const struct sock_diag_handler *handler) -{ - module_put(handler->owner); -} - static void sock_diag_broadcast_destroy_work(struct work_struct *work) { struct broadcast_sk *bsk = @@ -153,12 +135,12 @@ static void sock_diag_broadcast_destroy_work(struct work_struct *work) if (!skb) goto out; - hndl = sock_diag_lock_handler(sk->sk_family); - if (hndl) { - if (hndl->get_info) - err = hndl->get_info(skb, sk); - sock_diag_unlock_handler(hndl); - } + mutex_lock(&sock_diag_table_mutex); + hndl = sock_diag_handlers[sk->sk_family]; + if (hndl && hndl->get_info) + err = hndl->get_info(skb, sk); + mutex_unlock(&sock_diag_table_mutex); + if (!err) nlmsg_multicast(sock_net(sk)->diag_nlsk, skb, 0, group, GFP_KERNEL); @@ -199,26 +181,33 @@ EXPORT_SYMBOL_GPL(sock_diag_unregister_inet_compat); int sock_diag_register(const struct sock_diag_handler *hndl) { - int family = hndl->family; + int err = 0; - if (family >= AF_MAX) + if (hndl->family >= AF_MAX) return -EINVAL; - return !cmpxchg((const struct sock_diag_handler **) - &sock_diag_handlers[family], - NULL, hndl) ? 0 : -EBUSY; + mutex_lock(&sock_diag_table_mutex); + if (sock_diag_handlers[hndl->family]) + err = -EBUSY; + else + WRITE_ONCE(sock_diag_handlers[hndl->family], hndl); + mutex_unlock(&sock_diag_table_mutex); + + return err; } EXPORT_SYMBOL_GPL(sock_diag_register); -void sock_diag_unregister(const struct sock_diag_handler *hndl) +void sock_diag_unregister(const struct sock_diag_handler *hnld) { - int family = hndl->family; + int family = hnld->family; if (family >= AF_MAX) return; - xchg((const struct sock_diag_handler **)&sock_diag_handlers[family], - NULL); + mutex_lock(&sock_diag_table_mutex); + BUG_ON(sock_diag_handlers[family] != hnld); + WRITE_ONCE(sock_diag_handlers[family], NULL); + mutex_unlock(&sock_diag_table_mutex); } EXPORT_SYMBOL_GPL(sock_diag_unregister); @@ -235,20 +224,20 @@ static int __sock_diag_cmd(struct sk_buff *skb, struct nlmsghdr *nlh) return -EINVAL; req->sdiag_family = array_index_nospec(req->sdiag_family, AF_MAX); - if (!rcu_access_pointer(sock_diag_handlers[req->sdiag_family])) + if (READ_ONCE(sock_diag_handlers[req->sdiag_family]) == NULL) sock_load_diag_module(req->sdiag_family, 0); - hndl = sock_diag_lock_handler(req->sdiag_family); + mutex_lock(&sock_diag_table_mutex); + hndl = sock_diag_handlers[req->sdiag_family]; if (hndl == NULL) - return -ENOENT; - - if (nlh->nlmsg_type == SOCK_DIAG_BY_FAMILY) + err = -ENOENT; + else if (nlh->nlmsg_type == SOCK_DIAG_BY_FAMILY) err = hndl->dump(skb, nlh); else if (nlh->nlmsg_type == SOCK_DESTROY && hndl->destroy) err = hndl->destroy(skb, nlh); else err = -EOPNOTSUPP; - sock_diag_unlock_handler(hndl); + mutex_unlock(&sock_diag_table_mutex); return err; } @@ -294,12 +283,12 @@ static int sock_diag_bind(struct net *net, int group) switch (group) { case SKNLGRP_INET_TCP_DESTROY: case SKNLGRP_INET_UDP_DESTROY: - if (!rcu_access_pointer(sock_diag_handlers[AF_INET])) + if (!READ_ONCE(sock_diag_handlers[AF_INET])) sock_load_diag_module(AF_INET, 0); break; case SKNLGRP_INET6_TCP_DESTROY: case SKNLGRP_INET6_UDP_DESTROY: - if (!rcu_access_pointer(sock_diag_handlers[AF_INET6])) + if (!READ_ONCE(sock_diag_handlers[AF_INET6])) sock_load_diag_module(AF_INET6, 0); break; }