@@ -534,6 +534,12 @@ int ovpn_nl_peer_set_doit(struct sk_buff *skb, struct genl_info *info)
*/
if (ret > 0)
ovpn_peer_hash_vpn_ip(peer);
+ /* if the remote endpoint was updated, the by_transp_addr hash bucket
+ * also needs to be refreshed, otherwise incoming packets from the new
+ * remote address would fail the lockless lookup
+ */
+ if (attrs[OVPN_A_PEER_REMOTE_IPV4] || attrs[OVPN_A_PEER_REMOTE_IPV6])
+ ovpn_peer_hash_transp_addr(peer);
spin_unlock_bh(&ovpn->lock);
ovpn_peer_put(peer);
@@ -188,6 +188,9 @@ int ovpn_peer_reset_sockaddr(struct ovpn_peer *peer,
&(*__tbl1)[ovpn_get_hash_slot(*__tbl1, _key, _key_len)];\
})
+static void __ovpn_peer_hash_transp_addr(struct ovpn_peer *peer,
+ const struct ovpn_bind *bind);
+
/**
* ovpn_peer_endpoints_update - update remote or local endpoint for peer
* @peer: peer to update the remote endpoint for
@@ -195,7 +198,6 @@ int ovpn_peer_reset_sockaddr(struct ovpn_peer *peer,
*/
void ovpn_peer_endpoints_update(struct ovpn_peer *peer, struct sk_buff *skb)
{
- struct hlist_nulls_head *nhead;
struct sockaddr_storage ss;
struct sockaddr_in6 *sa6;
bool reset_cache = false;
@@ -294,49 +296,17 @@ void ovpn_peer_endpoints_update(struct ovpn_peer *peer, struct sk_buff *skb)
ovpn_nl_peer_float_notify(peer, &ss);
/* rehashing is required only in MP mode as P2P has one peer
- * only and thus there is no hashtable
+ * only and thus there is no hashtable.
+ *
+ * This function may be invoked concurrently, so re-read peer->bind
+ * under the proper locks and rehash against its current value.
*/
if (peer->ovpn->mode == OVPN_MODE_MP) {
spin_lock_bh(&peer->ovpn->lock);
spin_lock_bh(&peer->lock);
bind = rcu_dereference_protected(peer->bind,
lockdep_is_held(&peer->lock));
- if (unlikely(!bind)) {
- spin_unlock_bh(&peer->lock);
- spin_unlock_bh(&peer->ovpn->lock);
- return;
- }
-
- /* peer may have been concurrently removed between the caller's
- * initial lookup and our acquisition of ovpn->lock; skip the
- * rehash so we don't re-insert a removed peer
- */
- if (unlikely(hlist_unhashed(&peer->hash_entry_id))) {
- spin_unlock_bh(&peer->lock);
- spin_unlock_bh(&peer->ovpn->lock);
- return;
- }
-
- /* This function may be invoked concurrently, therefore another
- * float may have happened in parallel: perform rehashing
- * using the peer->bind->remote directly as key
- */
-
- switch (bind->remote.in4.sin_family) {
- case AF_INET:
- salen = sizeof(*sa);
- break;
- case AF_INET6:
- salen = sizeof(*sa6);
- break;
- }
-
- /* remove old hashing */
- hlist_nulls_del_init_rcu(&peer->hash_entry_transp_addr);
- /* re-add with new transport address */
- nhead = ovpn_get_hash_head(peer->ovpn->peers->by_transp_addr,
- &bind->remote, salen);
- hlist_nulls_add_head_rcu(&peer->hash_entry_transp_addr, nhead);
+ __ovpn_peer_hash_transp_addr(peer, bind);
spin_unlock_bh(&peer->lock);
spin_unlock_bh(&peer->ovpn->lock);
}
@@ -905,6 +875,66 @@ bool ovpn_peer_check_by_src(struct ovpn_priv *ovpn, struct sk_buff *skb,
return match;
}
+/* Move @peer to the by_transp_addr bucket matching its current bind.
+ *
+ * Caller must hold both peer->ovpn->lock and peer->lock, and must have
+ * already dereferenced a valid (non-NULL) peer->bind, passed in as @bind.
+ */
+static void __ovpn_peer_hash_transp_addr(struct ovpn_peer *peer,
+ const struct ovpn_bind *bind)
+{
+ struct hlist_nulls_head *nhead;
+ size_t salen;
+
+ lockdep_assert_held(&peer->ovpn->lock);
+ lockdep_assert_held(&peer->lock);
+
+ if (WARN_ON_ONCE(!bind))
+ return;
+
+ /* peer may have been concurrently removed between the caller's
+ * initial lookup and our acquisition of ovpn->lock; skip the
+ * rehash so we don't re-insert a removed peer
+ */
+ if (unlikely(hlist_unhashed(&peer->hash_entry_id)))
+ return;
+
+ switch (bind->remote.in4.sin_family) {
+ case AF_INET:
+ salen = sizeof(struct sockaddr_in);
+ break;
+ case AF_INET6:
+ salen = sizeof(struct sockaddr_in6);
+ break;
+ default:
+ return;
+ }
+
+ /* remove old hashing (no-op if entry is not currently linked) */
+ hlist_nulls_del_init_rcu(&peer->hash_entry_transp_addr);
+ /* re-add with current transport address */
+ nhead = ovpn_get_hash_head(peer->ovpn->peers->by_transp_addr,
+ &bind->remote, salen);
+ hlist_nulls_add_head_rcu(&peer->hash_entry_transp_addr, nhead);
+}
+
+void ovpn_peer_hash_transp_addr(struct ovpn_peer *peer)
+{
+ struct ovpn_bind *bind;
+
+ lockdep_assert_held(&peer->ovpn->lock);
+
+ /* rehashing makes sense only in multipeer mode */
+ if (peer->ovpn->mode != OVPN_MODE_MP)
+ return;
+
+ spin_lock_bh(&peer->lock);
+ bind = rcu_dereference_protected(peer->bind,
+ lockdep_is_held(&peer->lock));
+ __ovpn_peer_hash_transp_addr(peer, bind);
+ spin_unlock_bh(&peer->lock);
+}
+
void ovpn_peer_hash_vpn_ip(struct ovpn_peer *peer)
{
struct hlist_nulls_head *nhead;
@@ -150,6 +150,7 @@ struct ovpn_peer *ovpn_peer_get_by_id(struct ovpn_priv *ovpn, u32 peer_id);
struct ovpn_peer *ovpn_peer_get_by_dst(struct ovpn_priv *ovpn,
struct sk_buff *skb);
void ovpn_peer_hash_vpn_ip(struct ovpn_peer *peer);
+void ovpn_peer_hash_transp_addr(struct ovpn_peer *peer);
bool ovpn_peer_check_by_src(struct ovpn_priv *ovpn, struct sk_buff *skb,
struct ovpn_peer *peer);