diff --git a/include/linux/net.h b/include/linux/net.h
index 91216b16feb78e0a19b2129aa5c390fd0cd7eaff..2a0391eea05c3612b2ef560ed4852811fd16dfff 100644
--- a/include/linux/net.h
+++ b/include/linux/net.h
@@ -222,6 +222,7 @@ enum {
 int sock_wake_async(struct socket_wq *sk_wq, int how, int band);
 int sock_register(const struct net_proto_family *fam);
 void sock_unregister(int family);
+bool sock_is_registered(int family);
 int __sock_create(struct net *net, int family, int type, int proto,
 		  struct socket **res, int kern);
 int sock_create(int family, int type, int proto, struct socket **res);
diff --git a/include/net/sock.h b/include/net/sock.h
index 169c92afcafa3d548f8238e91606b87c187559f4..ae23f3b389cac6bcf7f270a4cd2ae3498b2d5ad5 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -1137,6 +1137,7 @@ struct proto {
 
 int proto_register(struct proto *prot, int alloc_slab);
 void proto_unregister(struct proto *prot);
+int sock_load_diag_module(int family, int protocol);
 
 #ifdef SOCK_REFCNT_DEBUG
 static inline void sk_refcnt_debug_inc(struct sock *sk)
diff --git a/net/core/sock.c b/net/core/sock.c
index c501499a04fe973e80e18655b306d762d348ff44..85b0b64e7f9dd565b5e85579ce43b78ddb65dc86 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -3261,6 +3261,27 @@ void proto_unregister(struct proto *prot)
 }
 EXPORT_SYMBOL(proto_unregister);
 
+int sock_load_diag_module(int family, int protocol)
+{
+	if (!protocol) {
+		if (!sock_is_registered(family))
+			return -ENOENT;
+
+		return request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK,
+				      NETLINK_SOCK_DIAG, family);
+	}
+
+#ifdef CONFIG_INET
+	if (family == AF_INET &&
+	    !rcu_access_pointer(inet_protos[protocol]))
+		return -ENOENT;
+#endif
+
+	return request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK,
+			      NETLINK_SOCK_DIAG, family, protocol);
+}
+EXPORT_SYMBOL(sock_load_diag_module);
+
 #ifdef CONFIG_PROC_FS
 static void *proto_seq_start(struct seq_file *seq, loff_t *pos)
 	__acquires(proto_list_mutex)
diff --git a/net/core/sock_diag.c b/net/core/sock_diag.c
index 146b50e30659daca3bdc4413d800890ef482f521..c37b5be7c5e4f0b4b91267b34c5ba867e90cbc69 100644
--- a/net/core/sock_diag.c
+++ b/net/core/sock_diag.c
@@ -220,8 +220,7 @@ static int __sock_diag_cmd(struct sk_buff *skb, struct nlmsghdr *nlh)
 		return -EINVAL;
 
 	if (sock_diag_handlers[req->sdiag_family] == NULL)
-		request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK,
-				NETLINK_SOCK_DIAG, req->sdiag_family);
+		sock_load_diag_module(req->sdiag_family, 0);
 
 	mutex_lock(&sock_diag_table_mutex);
 	hndl = sock_diag_handlers[req->sdiag_family];
@@ -247,8 +246,7 @@ static int sock_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
 	case TCPDIAG_GETSOCK:
 	case DCCPDIAG_GETSOCK:
 		if (inet_rcv_compat == NULL)
-			request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK,
-					NETLINK_SOCK_DIAG, AF_INET);
+			sock_load_diag_module(AF_INET, 0);
 
 		mutex_lock(&sock_diag_table_mutex);
 		if (inet_rcv_compat != NULL)
@@ -281,14 +279,12 @@ static int sock_diag_bind(struct net *net, int group)
 	case SKNLGRP_INET_TCP_DESTROY:
 	case SKNLGRP_INET_UDP_DESTROY:
 		if (!sock_diag_handlers[AF_INET])
-			request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK,
-				       NETLINK_SOCK_DIAG, AF_INET);
+			sock_load_diag_module(AF_INET, 0);
 		break;
 	case SKNLGRP_INET6_TCP_DESTROY:
 	case SKNLGRP_INET6_UDP_DESTROY:
 		if (!sock_diag_handlers[AF_INET6])
-			request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK,
-				       NETLINK_SOCK_DIAG, AF_INET6);
+			sock_load_diag_module(AF_INET6, 0);
 		break;
 	}
 	return 0;
diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c
index a383f299ce246bd84b28bad1909bae600ad699c0..4e5bc4b2f14e6786ceb7d63e5902f8fc17819dfa 100644
--- a/net/ipv4/inet_diag.c
+++ b/net/ipv4/inet_diag.c
@@ -53,8 +53,7 @@ static DEFINE_MUTEX(inet_diag_table_mutex);
 static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
 {
 	if (!inet_diag_table[proto])
-		request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK,
-			       NETLINK_SOCK_DIAG, AF_INET, proto);
+		sock_load_diag_module(AF_INET, proto);
 
 	mutex_lock(&inet_diag_table_mutex);
 	if (!inet_diag_table[proto])
diff --git a/net/socket.c b/net/socket.c
index a93c99b518cab69b67d53a71298d15ce17c4d06b..08847c3b8c39f708a0d11fb35184852b22a40835 100644
--- a/net/socket.c
+++ b/net/socket.c
@@ -2587,6 +2587,11 @@ void sock_unregister(int family)
 }
 EXPORT_SYMBOL(sock_unregister);
 
+bool sock_is_registered(int family)
+{
+	return family < NPROTO && rcu_access_pointer(net_families[family]);
+}
+
 static int __init sock_init(void)
 {
 	int err;