[Openvpn-devel,RFC,ovpn,net-next,5/5] ovpn: implement IGMPv3/MLDv2 SSM snooping

Message ID 20260514095210.288979-5-marco@mandelbit.com
State New
Headers show
Series [Openvpn-devel,RFC,ovpn,net-next,1/5] ovpn: add multicast/braodcast packet transmission support | expand

Commit Message

Marco Baffo May 14, 2026, 9:52 a.m. UTC
Extend the multicast subscription table to track per-source state,
enabling SSM (S,G) forwarding instead of group-only ASM forwarding.

Data structures:
- Add ovpn_mcast_source to track individual source addresses
- Extend ovpn_mcast_sub with filter_mode (INCLUDE/EXCLUDE) and a
  source list

Subscription API:
- Add ovpn_mcast_sub_update() to create or update a subscription
  with a full source list and filter mode
- ovpn_mcast_join() becomes a thin wrapper around sub_update()
  (EXCLUDE mode with empty source list = ASM join)
- Add ovpn_mcast_srcs_update() for incremental source merging:
  ALLOW_NEW adds sources in INCLUDE mode and removes them in EXCLUDE
  mode; BLOCK_OLD does the opposite
- Empty INCLUDE subscriptions are automatically deleted when
  BLOCK_OLD removes the last source

RX path (snooping):
- IGMPv3 and MLDv2 parsers now extract source lists from reports
  and pass them to sub_update()
- All record types are handled: MODE_IS_*, CHANGE_TO_*,
  ALLOW_NEW_SOURCES, BLOCK_OLD_SOURCES

TX path (forwarding):
- Add ovpn_mcast_src_allowed() to evaluate a source against a peer's
  filter mode and source list
- ovpn_peer_list_get_by_mcast_group() now takes a source address
  and only returns peers whose subscription allows the source
- ASM backward compatibility preserved: EXCLUDE with empty source
  list allows all sources

Signed-off-by: Marco Baffo <marco@mandelbit.com>
---
 drivers/net/ovpn/mcast.c | 334 +++++++++++++++++++++++++++++++++------
 drivers/net/ovpn/mcast.h |  14 +-
 drivers/net/ovpn/peer.c  |   7 +-
 3 files changed, 305 insertions(+), 50 deletions(-)

Patch

diff --git a/drivers/net/ovpn/mcast.c b/drivers/net/ovpn/mcast.c
index 1e436a6721bb..74b791ad7489 100644
--- a/drivers/net/ovpn/mcast.c
+++ b/drivers/net/ovpn/mcast.c
@@ -17,9 +17,16 @@  struct ovpn_mcast_group {
 	struct list_head subs;
 };
 
+struct ovpn_mcast_source {
+	struct list_head list;
+	struct in6_addr addr;
+};
+
 struct ovpn_mcast_sub {
 	struct list_head list;
 	struct ovpn_peer *peer;
+	enum ovpn_mcast_filter_mode filter_mode;
+	struct list_head sources;
 };
 
 static inline u32 ovpn_mcast_hash(const struct in6_addr *group_addr)
@@ -47,10 +54,21 @@  ovpn_mcast_group_find(const struct ovpn_priv *ovpn, const struct in6_addr *group
 	return NULL;
 }
 
+static void ovpn_mcast_srcs_del_all(struct list_head *srcs)
+{
+	struct ovpn_mcast_source *src, *next;
+
+	list_for_each_entry_safe(src, next, srcs, list) {
+		list_del(&src->list);
+		kfree(src);
+	}
+}
+
 static struct ovpn_peer *ovpn_mcast_sub_del(struct ovpn_mcast_sub *sub)
 {
 	struct ovpn_peer *peer = sub->peer;
 
+	ovpn_mcast_srcs_del_all(&sub->sources);
 	list_del(&sub->list);
 	kfree(sub);
 	return peer;
@@ -85,20 +103,138 @@  void ovpn_mcast_cleanup(struct ovpn_priv *ovpn)
 	}
 }
 
+static void ovpn_mcast_srcs_del(struct ovpn_mcast_sub *sub,
+				const struct in6_addr *sources,
+				const unsigned int nsrcs)
+{
+	struct ovpn_mcast_source *src, *next;
+	unsigned int i;
+
+	for (i = 0; i < nsrcs; i++) {
+		list_for_each_entry_safe(src, next, &sub->sources, list) {
+			if (ipv6_addr_equal(&src->addr, &sources[i])) {
+				list_del(&src->list);
+				kfree(src);
+				break;
+			}
+		}
+	}
+}
+
+static bool ovpn_mcast_source_exists(const struct ovpn_mcast_sub *sub,
+				     const struct in6_addr *addr)
+{
+	struct ovpn_mcast_source *src;
+
+	list_for_each_entry(src, &sub->sources, list) {
+		if (ipv6_addr_equal(&src->addr, addr))
+			return true;
+	}
+	return false;
+}
+
+static void ovpn_mcast_srcs_add(struct ovpn_mcast_sub *sub,
+				const struct in6_addr *sources,
+				const unsigned int nsrcs)
+{
+	struct ovpn_mcast_source *src;
+	unsigned int i;
+
+	for (i = 0; i < nsrcs; i++) {
+		if (ovpn_mcast_source_exists(sub, &sources[i]))
+			continue;
+
+		src = kzalloc_obj(*src, GFP_ATOMIC);
+		if (!src)
+			break;
+		src->addr = sources[i];
+		list_add_tail(&src->list, &sub->sources);
+	}
+}
+
+static struct ovpn_peer *ovpn_mcast_srcs_update(struct ovpn_mcast_sub *sub,
+						const enum ovpn_mcast_filter_mode msg_mode,
+						const struct in6_addr *sources,
+						const unsigned int nsrcs)
+{
+	if (!sources || !nsrcs)
+		return NULL;
+
+	/* ALLOW_NEW: add in INCLUDE, del in EXCLUDE.
+	 * BLOCK_OLD: del in INCLUDE, add in EXCLUDE.
+	 */
+	if (sub->filter_mode == msg_mode) {
+		ovpn_mcast_srcs_add(sub, sources, nsrcs);
+	} else {
+		ovpn_mcast_srcs_del(sub, sources, nsrcs);
+		if (sub->filter_mode == OVPN_MCAST_INCLUDE &&
+		    list_empty(&sub->sources))
+			return ovpn_mcast_sub_del(sub);
+	}
+	return NULL;
+}
+
+static bool ovpn_mcast_sub_init(struct ovpn_mcast_sub **subp,
+				struct ovpn_peer *peer,
+				const enum ovpn_mcast_filter_mode mode,
+				struct ovpn_mcast_group *group)
+{
+	struct ovpn_mcast_sub *sub;
+
+	sub = kzalloc_obj(*sub, GFP_ATOMIC);
+	if (unlikely(!sub))
+		return false;
+
+	if (!ovpn_peer_hold(peer)) {
+		kfree(sub);
+		return false;
+	}
+
+	sub->peer = peer;
+	sub->filter_mode = mode;
+	INIT_LIST_HEAD(&sub->sources);
+	list_add_tail(&sub->list, &group->subs);
+	*subp = sub;
+	return true;
+}
+
 /**
- * ovpn_mcast_join - add a peer to a multicast group
+ * ovpn_mcast_sub_update - create, replace, or incrementally update a multicast subscription
  * @ovpn: the ovpn instance
- * @peer: the peer joining the group
- * @group_addr: the multicast group address (IPv4-mapped IPv6 for IPv4 groups)
+ * @peer: the peer whose subscription is being updated
+ * @group_addr: the multicast group address
+ * @mode: the filter mode (INCLUDE or EXCLUDE)
+ * @sources: array of source addresses to add or remove
+ * @nsrcs: number of sources in @sources
+ * @incremental_update: if true, merge sources into existing state;
+ *			if false, replace state entirely
  *
- * Creates the group if it does not exist and adds a subscription for @peer.
- * If the peer is already subscribed, returns success without doing anything.
+ * When @incremental_update is false the subscription is fully replaced with
+ * the given @mode and @sources. An empty source list with INCLUDE mode is
+ * equivalent to leaving the group; with EXCLUDE mode it is an ASM join
+ * (receive all sources).
+ *
+ * When @incremental_update is true the sources are merged: they are added
+ * to the list when @mode matches the current filter mode, or removed when
+ * it differs. ALLOW_NEW maps to INCLUDE; BLOCK_OLD maps to EXCLUDE. If a
+ * BLOCK_OLD operation removes the last source from an INCLUDE subscription,
+ * the subscription is destroyed.
+ *
+ * If no subscription exists for @peer on @group_addr one is created. If the
+ * group does not exist it is created.
+ *
+ * All updates are atomic under @ovpn->lock.
  */
-void ovpn_mcast_join(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
-		     const struct in6_addr *group_addr)
+void ovpn_mcast_sub_update(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
+			   const struct in6_addr *group_addr,
+			   const enum ovpn_mcast_filter_mode mode,
+			   const struct in6_addr *sources,
+			   const unsigned int nsrcs,
+			   const bool incremental_update)
 {
 	struct ovpn_mcast_group *group;
 	struct ovpn_mcast_sub *sub;
+	struct ovpn_peer *peer_to_put = NULL;
 
 	if (!ovpn_mcast_addr_valid(group_addr))
 		return;
@@ -117,19 +253,47 @@  void ovpn_mcast_join(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
 	}
 
 	list_for_each_entry(sub, &group->subs, list) {
-		if (sub->peer == peer)
+		if (sub->peer != peer)
+			continue;
+		if (incremental_update) {
+			peer_to_put = ovpn_mcast_srcs_update(sub, mode, sources, nsrcs);
+			ovpn_mcast_group_try_del(group);
 			goto end;
+		} else {
+			sub->filter_mode = mode;
+			ovpn_mcast_srcs_del_all(&sub->sources);
+			goto add_sources;
+		}
 	}
 
-	sub = kzalloc_obj(*sub, GFP_ATOMIC);
-	if (unlikely(!sub))
+	if (!ovpn_mcast_sub_init(&sub, peer, mode, group)) {
+		ovpn_mcast_group_try_del(group);
 		goto end;
-
-	sub->peer = peer;
-	ovpn_peer_hold(peer);
-	list_add_tail(&sub->list, &group->subs);
+	}
+add_sources:
+	if (sources && nsrcs)
+		ovpn_mcast_srcs_add(sub, sources, nsrcs);
 end:
 	spin_unlock_bh(&ovpn->lock);
+
+	if (peer_to_put)
+		ovpn_peer_put(peer_to_put);
+}
+
+/**
+ * ovpn_mcast_join - add a peer to a multicast group
+ * @ovpn: the ovpn instance
+ * @peer: the peer joining the group
+ * @group_addr: the multicast group address (IPv4-mapped IPv6 for IPv4 groups)
+ *
+ * Creates the group if it does not exist and adds a subscription for @peer.
+ * If the peer is already subscribed, returns without doing anything.
+ */
+void ovpn_mcast_join(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
+		     const struct in6_addr *group_addr)
+{
+	ovpn_mcast_sub_update(ovpn, peer, group_addr, OVPN_MCAST_EXCLUDE,
+			      NULL, 0, false);
 }
 
 /**
@@ -202,20 +366,36 @@  void ovpn_mcast_leave_all(struct ovpn_peer *peer)
 		ovpn_peer_put(peer);
 }
 
+static bool ovpn_mcast_src_allowed(const struct ovpn_mcast_sub *sub,
+				   const struct in6_addr *src_addr)
+{
+	struct ovpn_mcast_source *src;
+
+	list_for_each_entry(src, &sub->sources, list) {
+		if (ipv6_addr_equal(&src->addr, src_addr))
+			return sub->filter_mode == OVPN_MCAST_INCLUDE;
+	}
+	return sub->filter_mode == OVPN_MCAST_EXCLUDE;
+}
+
 /**
  * ovpn_peer_list_get_by_mcast_group - retrieve peers subscribed to a multicast group
  * @ovpn: the ovpn instance to search
  * @group_addr: the multicast group address to look up
  * @list: the lockless list to append matching peers to
  *
- * Searches for the multicast group identified by @group_addr and appends all
- * subscribed peers to @list, acquiring a reference on each one.
+ * @src: the source address to match against per-peer source filters
+ *
+ * Searches for the multicast group identified by @group_addr and appends
+ * subscribed peers whose source filter allows @src to @list, acquiring a
+ * reference on each one.
  *
  * Return: false if no peer was found, true otherwise
  */
 bool ovpn_peer_list_get_by_mcast_group(struct ovpn_priv *ovpn,
 				       const struct in6_addr *group_addr,
-				       struct llist_head *list)
+				       struct llist_head *list,
+				       const struct in6_addr *src)
 {
 	struct ovpn_mcast_group *group;
 	struct ovpn_mcast_sub *sub;
@@ -225,7 +405,8 @@  bool ovpn_peer_list_get_by_mcast_group(struct ovpn_priv *ovpn,
 	group = ovpn_mcast_group_find(ovpn, group_addr);
 	if (group) {
 		list_for_each_entry(sub, &group->subs, list) {
-			if (ovpn_peer_hold(sub->peer))
+			if (ovpn_mcast_src_allowed(sub, src) &&
+			    ovpn_peer_hold(sub->peer))
 				llist_add(&sub->peer->mcast_entry, list);
 		}
 	}
@@ -305,18 +486,47 @@  static bool ovpn_mcast_snoop_mldv2(struct ovpn_peer *peer, struct sk_buff *skb,
 		/* recompute grec after potential head reallocation */
 		grec = (struct mld2_grec *)(skb_network_header(skb) + offset - rec_len);
 
-		/* In MLDv2 ASM, EXCLUDE mode with an empty source list means
-		 * "exclude nothing, receive everything" -> JOIN.
-		 * INCLUDE mode with an empty source list means
-		 * "include nothing, receive nothing" -> LEAVE.
-		 * See RFC 3810, section 4.
-		 */
-		if (nsrcs == 0 &&
-		    (grec->grec_type == MLD2_CHANGE_TO_INCLUDE ||
-		     grec->grec_type == MLD2_MODE_IS_INCLUDE)) {
-			ovpn_mcast_leave(peer->ovpn, peer, &grec->grec_mca);
-		} else {
-			ovpn_mcast_join(peer->ovpn, peer, &grec->grec_mca);
+		switch (grec->grec_type) {
+		case MLD2_MODE_IS_INCLUDE:
+		case MLD2_CHANGE_TO_INCLUDE:
+			if (nsrcs == 0)
+				ovpn_mcast_leave(peer->ovpn, peer,
+						 &grec->grec_mca);
+			else
+				ovpn_mcast_sub_update(peer->ovpn, peer,
+						      &grec->grec_mca,
+						      OVPN_MCAST_INCLUDE,
+						      grec->grec_src, nsrcs,
+						      false);
+			break;
+		case MLD2_MODE_IS_EXCLUDE:
+		case MLD2_CHANGE_TO_EXCLUDE:
+			if (nsrcs == 0)
+				ovpn_mcast_join(peer->ovpn, peer,
+						&grec->grec_mca);
+			else
+				ovpn_mcast_sub_update(peer->ovpn, peer,
+						      &grec->grec_mca,
+						      OVPN_MCAST_EXCLUDE,
+						      grec->grec_src, nsrcs,
+						      false);
+			break;
+		case MLD2_ALLOW_NEW_SOURCES:
+			if (nsrcs)
+				ovpn_mcast_sub_update(peer->ovpn, peer,
+						      &grec->grec_mca,
+						      OVPN_MCAST_INCLUDE,
+						      grec->grec_src, nsrcs,
+						      true);
+			break;
+		case MLD2_BLOCK_OLD_SOURCES:
+			if (nsrcs)
+				ovpn_mcast_sub_update(peer->ovpn, peer,
+						      &grec->grec_mca,
+						      OVPN_MCAST_EXCLUDE,
+						      grec->grec_src, nsrcs,
+						      true);
+			break;
 		}
 	}
 
@@ -381,9 +591,9 @@  static bool ovpn_mcast_snoop_igmpv3(struct ovpn_peer *peer, struct sk_buff *skb,
 				    unsigned int offset, const int ngrec)
 {
 	struct igmpv3_grec *grec;
-	struct in6_addr addr6;
+	struct in6_addr addr6, *srcs = NULL;
 	int i;
-	unsigned int rec_len;
+	unsigned int j, rec_len;
 	__u16 nsrcs;
 
 	for (i = 0; i < ngrec; i++) {
@@ -403,21 +613,53 @@  static bool ovpn_mcast_snoop_igmpv3(struct ovpn_peer *peer, struct sk_buff *skb,
 		/* recompute grec after potential head reallocation */
 		grec = (struct igmpv3_grec *)(skb_network_header(skb) + offset - rec_len);
 
-		/* In IGMPv3 ASM, EXCLUDE mode with an empty source list means
-		 * "exclude nothing, receive everything" -> JOIN.
-		 * INCLUDE mode with an empty source list means
-		 * "include nothing, receive nothing" -> LEAVE.
-		 * See RFC 3376, section 3.
-		 */
-		if (nsrcs == 0 &&
-		    (grec->grec_type == IGMPV3_CHANGE_TO_INCLUDE ||
-		     grec->grec_type == IGMPV3_MODE_IS_INCLUDE)) {
-			ipv6_addr_set_v4mapped(grec->grec_mca, &addr6);
-			ovpn_mcast_leave(peer->ovpn, peer, &addr6);
-		} else {
-			ipv6_addr_set_v4mapped(grec->grec_mca, &addr6);
-			ovpn_mcast_join(peer->ovpn, peer, &addr6);
+		ipv6_addr_set_v4mapped(grec->grec_mca, &addr6);
+
+		if (nsrcs > 0) {
+			srcs = kcalloc(nsrcs, sizeof(*srcs), GFP_ATOMIC);
+			if (!srcs)
+				return false;
+
+			for (j = 0; j < nsrcs; j++)
+				ipv6_addr_set_v4mapped(grec->grec_src[j],
+						       &srcs[j]);
 		}
+
+		switch (grec->grec_type) {
+		case IGMPV3_MODE_IS_INCLUDE:
+		case IGMPV3_CHANGE_TO_INCLUDE:
+			if (nsrcs == 0)
+				ovpn_mcast_leave(peer->ovpn, peer, &addr6);
+			else
+				ovpn_mcast_sub_update(peer->ovpn, peer, &addr6,
+						      OVPN_MCAST_INCLUDE, srcs,
+						      nsrcs, false);
+			break;
+		case IGMPV3_MODE_IS_EXCLUDE:
+		case IGMPV3_CHANGE_TO_EXCLUDE:
+			if (nsrcs == 0)
+				ovpn_mcast_join(peer->ovpn, peer, &addr6);
+			else
+				ovpn_mcast_sub_update(peer->ovpn, peer, &addr6,
+						      OVPN_MCAST_EXCLUDE, srcs,
+						      nsrcs, false);
+			break;
+		case IGMPV3_ALLOW_NEW_SOURCES:
+			if (nsrcs)
+				ovpn_mcast_sub_update(peer->ovpn, peer, &addr6,
+						      OVPN_MCAST_INCLUDE, srcs,
+						      nsrcs, true);
+			break;
+		case IGMPV3_BLOCK_OLD_SOURCES:
+			if (nsrcs)
+				ovpn_mcast_sub_update(peer->ovpn, peer, &addr6,
+						      OVPN_MCAST_EXCLUDE, srcs,
+						      nsrcs, true);
+			break;
+		}
+
+		kfree(srcs);
+		srcs = NULL;
 	}
 
 	return true;
diff --git a/drivers/net/ovpn/mcast.h b/drivers/net/ovpn/mcast.h
index 9e06e893a355..b41812534d58 100644
--- a/drivers/net/ovpn/mcast.h
+++ b/drivers/net/ovpn/mcast.h
@@ -13,15 +13,27 @@  struct in6_addr;
 struct llist_head;
 struct sk_buff;
 
+enum ovpn_mcast_filter_mode {
+	OVPN_MCAST_EXCLUDE,
+	OVPN_MCAST_INCLUDE,
+};
+
 void ovpn_mcast_cleanup(struct ovpn_priv *ovpn);
 void ovpn_mcast_join(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
 		     const struct in6_addr *group_addr);
 void ovpn_mcast_leave(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
 		      const struct in6_addr *group_addr);
+void ovpn_mcast_sub_update(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
+			   const struct in6_addr *group_addr,
+			   const enum ovpn_mcast_filter_mode mode,
+			   const struct in6_addr *sources,
+			   const unsigned int nsrcs,
+			   const bool incremental_update);
 void ovpn_mcast_leave_all(struct ovpn_peer *peer);
 bool ovpn_peer_list_get_by_mcast_group(struct ovpn_priv *ovpn,
 				       const struct in6_addr *group_addr,
-				       struct llist_head *list);
+				       struct llist_head *list,
+				       const struct in6_addr *src);
 bool ovpn_mcast_is_control(struct sk_buff *skb);
 bool ovpn_mcast_snoop_skb(struct ovpn_peer *peer, struct sk_buff *skb);
 
diff --git a/drivers/net/ovpn/peer.c b/drivers/net/ovpn/peer.c
index a9728a157210..3fc69c3cecc0 100644
--- a/drivers/net/ovpn/peer.c
+++ b/drivers/net/ovpn/peer.c
@@ -751,7 +751,7 @@  void ovpn_peer_list_get_by_dst(struct ovpn_priv *ovpn, struct sk_buff *skb,
 {
 	struct ovpn_peer *peer = NULL;
 	unsigned int addr_type;
-	struct in6_addr addr6;
+	struct in6_addr addr6, src;
 	__be32 addr4;
 
 	/* in P2P mode, no matter the destination, packets are always sent to
@@ -779,7 +779,8 @@  void ovpn_peer_list_get_by_dst(struct ovpn_priv *ovpn, struct sk_buff *skb,
 		addr_type = inet_dev_addr_type(dev_net(ovpn->dev), ovpn->dev, addr4);
 		if (addr_type == RTN_MULTICAST) {
 			ipv6_addr_set_v4mapped(addr4, &addr6);
-			if (!ovpn_peer_list_get_by_mcast_group(ovpn, &addr6, list) &&
+			ipv6_addr_set_v4mapped(ip_hdr(skb)->saddr, &src);
+			if (!ovpn_peer_list_get_by_mcast_group(ovpn, &addr6, list, &src) &&
 			    ovpn_mcast_is_control(skb)) {
 				ovpn_peer_list_get_all(ovpn, list);
 			}
@@ -797,7 +798,7 @@  void ovpn_peer_list_get_by_dst(struct ovpn_priv *ovpn, struct sk_buff *skb,
 
 		rcu_read_unlock();
 		if (ipv6_addr_is_multicast(&addr6) &&
-		    !ovpn_peer_list_get_by_mcast_group(ovpn, &addr6, list) &&
+		    !ovpn_peer_list_get_by_mcast_group(ovpn, &addr6, list, &ipv6_hdr(skb)->saddr) &&
 		    ovpn_mcast_is_control(skb)) {
 			ovpn_peer_list_get_all(ovpn, list);
 		}