diff --git a/include/linux/net.h b/include/linux/net.h index 91216b16feb78e0a19b2129aa5c390fd0cd7eaff..2a0391eea05c3612b2ef560ed4852811fd16dfff 100644 --- a/include/linux/net.h +++ b/include/linux/net.h @@ -222,6 +222,7 @@ enum { int sock_wake_async(struct socket_wq *sk_wq, int how, int band); int sock_register(const struct net_proto_family *fam); void sock_unregister(int family); +bool sock_is_registered(int family); int __sock_create(struct net *net, int family, int type, int proto, struct socket **res, int kern); int sock_create(int family, int type, int proto, struct socket **res); diff --git a/include/net/sock.h b/include/net/sock.h index 169c92afcafa3d548f8238e91606b87c187559f4..ae23f3b389cac6bcf7f270a4cd2ae3498b2d5ad5 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -1137,6 +1137,7 @@ struct proto { int proto_register(struct proto *prot, int alloc_slab); void proto_unregister(struct proto *prot); +int sock_load_diag_module(int family, int protocol); #ifdef SOCK_REFCNT_DEBUG static inline void sk_refcnt_debug_inc(struct sock *sk) diff --git a/net/core/sock.c b/net/core/sock.c index c501499a04fe973e80e18655b306d762d348ff44..85b0b64e7f9dd565b5e85579ce43b78ddb65dc86 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -3261,6 +3261,27 @@ void proto_unregister(struct proto *prot) } EXPORT_SYMBOL(proto_unregister); +int sock_load_diag_module(int family, int protocol) +{ + if (!protocol) { + if (!sock_is_registered(family)) + return -ENOENT; + + return request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK, + NETLINK_SOCK_DIAG, family); + } + +#ifdef CONFIG_INET + if (family == AF_INET && + !rcu_access_pointer(inet_protos[protocol])) + return -ENOENT; +#endif + + return request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK, + NETLINK_SOCK_DIAG, family, protocol); +} +EXPORT_SYMBOL(sock_load_diag_module); + #ifdef CONFIG_PROC_FS static void *proto_seq_start(struct seq_file *seq, loff_t *pos) __acquires(proto_list_mutex) diff --git a/net/core/sock_diag.c b/net/core/sock_diag.c index 146b50e30659daca3bdc4413d800890ef482f521..c37b5be7c5e4f0b4b91267b34c5ba867e90cbc69 100644 --- a/net/core/sock_diag.c +++ b/net/core/sock_diag.c @@ -220,8 +220,7 @@ static int __sock_diag_cmd(struct sk_buff *skb, struct nlmsghdr *nlh) return -EINVAL; if (sock_diag_handlers[req->sdiag_family] == NULL) - request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK, - NETLINK_SOCK_DIAG, req->sdiag_family); + sock_load_diag_module(req->sdiag_family, 0); mutex_lock(&sock_diag_table_mutex); hndl = sock_diag_handlers[req->sdiag_family]; @@ -247,8 +246,7 @@ static int sock_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh, case TCPDIAG_GETSOCK: case DCCPDIAG_GETSOCK: if (inet_rcv_compat == NULL) - request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK, - NETLINK_SOCK_DIAG, AF_INET); + sock_load_diag_module(AF_INET, 0); mutex_lock(&sock_diag_table_mutex); if (inet_rcv_compat != NULL) @@ -281,14 +279,12 @@ static int sock_diag_bind(struct net *net, int group) case SKNLGRP_INET_TCP_DESTROY: case SKNLGRP_INET_UDP_DESTROY: if (!sock_diag_handlers[AF_INET]) - request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK, - NETLINK_SOCK_DIAG, AF_INET); + sock_load_diag_module(AF_INET, 0); break; case SKNLGRP_INET6_TCP_DESTROY: case SKNLGRP_INET6_UDP_DESTROY: if (!sock_diag_handlers[AF_INET6]) - request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK, - NETLINK_SOCK_DIAG, AF_INET6); + sock_load_diag_module(AF_INET6, 0); break; } return 0; diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c index a383f299ce246bd84b28bad1909bae600ad699c0..4e5bc4b2f14e6786ceb7d63e5902f8fc17819dfa 100644 --- a/net/ipv4/inet_diag.c +++ b/net/ipv4/inet_diag.c @@ -53,8 +53,7 @@ static DEFINE_MUTEX(inet_diag_table_mutex); static const struct inet_diag_handler *inet_diag_lock_handler(int proto) { if (!inet_diag_table[proto]) - request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK, - NETLINK_SOCK_DIAG, AF_INET, proto); + sock_load_diag_module(AF_INET, proto); mutex_lock(&inet_diag_table_mutex); if (!inet_diag_table[proto]) diff --git a/net/socket.c b/net/socket.c index a93c99b518cab69b67d53a71298d15ce17c4d06b..08847c3b8c39f708a0d11fb35184852b22a40835 100644 --- a/net/socket.c +++ b/net/socket.c @@ -2587,6 +2587,11 @@ void sock_unregister(int family) } EXPORT_SYMBOL(sock_unregister); +bool sock_is_registered(int family) +{ + return family < NPROTO && rcu_access_pointer(net_families[family]); +} + static int __init sock_init(void) { int err;