diff --git a/drivers/net/ovpn/socket.c b/drivers/net/ovpn/socket.c
index a83cbab72591..66a2ecbc483b 100644
--- a/drivers/net/ovpn/socket.c
+++ b/drivers/net/ovpn/socket.c
@@ -66,6 +66,7 @@ static bool ovpn_socket_put(struct ovpn_peer *peer, struct ovpn_socket *sock)
 void ovpn_socket_release(struct ovpn_peer *peer)
 {
 	struct ovpn_socket *sock;
+	struct sock *sk;
 	bool released;
 
 	might_sleep();
@@ -75,13 +76,14 @@ void ovpn_socket_release(struct ovpn_peer *peer)
 	if (!sock)
 		return;
 
-	/* sanity check: we should not end up here if the socket
-	 * was already closed
+	/* sock->sk may be released concurrently, therefore we
+	 * first attempt grabbing a reference.
+	 * if sock->sk is NULL it means it is already being
+	 * destroyed and we don't need any further cleanup
 	 */
-	if (!sock->sock->sk) {
-		DEBUG_NET_WARN_ON_ONCE(1);
+	sk = sock->sock->sk;
+	if (!sk || !refcount_inc_not_zero(&sk->sk_refcnt))
 		return;
-	}
 
 	/* Drop the reference while holding the sock lock to avoid
 	 * concurrent ovpn_socket_new call to mess up with a partially
@@ -90,18 +92,18 @@ void ovpn_socket_release(struct ovpn_peer *peer)
 	 * Holding the lock ensures that a socket with refcnt 0 is fully
 	 * detached before it can be picked by a concurrent reader.
 	 */
-	lock_sock(sock->sock->sk);
+	lock_sock(sk);
 	released = ovpn_socket_put(peer, sock);
-	release_sock(sock->sock->sk);
+	release_sock(sk);
 
 	/* align all readers with sk_user_data being NULL */
 	synchronize_rcu();
 
 	/* following cleanup should happen with lock released */
 	if (released) {
-		if (sock->sock->sk->sk_protocol == IPPROTO_UDP) {
+		if (sk->sk_protocol == IPPROTO_UDP) {
 			netdev_put(sock->ovpn->dev, &sock->dev_tracker);
-		} else if (sock->sock->sk->sk_protocol == IPPROTO_TCP) {
+		} else if (sk->sk_protocol == IPPROTO_TCP) {
 			/* wait for TCP jobs to terminate */
 			ovpn_tcp_socket_wait_finish(sock);
 			ovpn_peer_put(sock->peer);
@@ -111,6 +113,7 @@ void ovpn_socket_release(struct ovpn_peer *peer)
 		 */
 		kfree(sock);
 	}
+	sock_put(sk);
 }
 
 static bool ovpn_socket_hold(struct ovpn_socket *sock)
