@@ -17,9 +17,16 @@ struct ovpn_mcast_group {
struct list_head subs;
};
+struct ovpn_mcast_source {
+ struct list_head list;
+ struct in6_addr addr;
+};
+
struct ovpn_mcast_sub {
struct list_head list;
struct ovpn_peer *peer;
+ enum ovpn_mcast_filter_mode filter_mode;
+ struct list_head sources;
};
static inline u32 ovpn_mcast_hash(const struct in6_addr *group_addr)
@@ -47,10 +54,21 @@ ovpn_mcast_group_find(const struct ovpn_priv *ovpn, const struct in6_addr *group
return NULL;
}
+static void ovpn_mcast_srcs_del_all(struct list_head *srcs)
+{
+ struct ovpn_mcast_source *src, *next;
+
+ list_for_each_entry_safe(src, next, srcs, list) {
+ list_del(&src->list);
+ kfree(src);
+ }
+}
+
static struct ovpn_peer *ovpn_mcast_sub_del(struct ovpn_mcast_sub *sub)
{
struct ovpn_peer *peer = sub->peer;
+ ovpn_mcast_srcs_del_all(&sub->sources);
list_del(&sub->list);
kfree(sub);
return peer;
@@ -85,20 +103,138 @@ void ovpn_mcast_cleanup(struct ovpn_priv *ovpn)
}
}
+static void ovpn_mcast_srcs_del(struct ovpn_mcast_sub *sub,
+ const struct in6_addr *sources,
+ const unsigned int nsrcs)
+{
+ struct ovpn_mcast_source *src, *next;
+ unsigned int i;
+
+ for (i = 0; i < nsrcs; i++) {
+ list_for_each_entry_safe(src, next, &sub->sources, list) {
+ if (ipv6_addr_equal(&src->addr, &sources[i])) {
+ list_del(&src->list);
+ kfree(src);
+ break;
+ }
+ }
+ }
+}
+
+static bool ovpn_mcast_source_exists(const struct ovpn_mcast_sub *sub,
+ const struct in6_addr *addr)
+{
+ struct ovpn_mcast_source *src;
+
+ list_for_each_entry(src, &sub->sources, list) {
+ if (ipv6_addr_equal(&src->addr, addr))
+ return true;
+ }
+ return false;
+}
+
+static void ovpn_mcast_srcs_add(struct ovpn_mcast_sub *sub,
+ const struct in6_addr *sources,
+ const unsigned int nsrcs)
+{
+ struct ovpn_mcast_source *src;
+ unsigned int i;
+
+ for (i = 0; i < nsrcs; i++) {
+ if (ovpn_mcast_source_exists(sub, &sources[i]))
+ continue;
+
+ src = kzalloc_obj(*src, GFP_ATOMIC);
+ if (!src)
+ break;
+ src->addr = sources[i];
+ list_add_tail(&src->list, &sub->sources);
+ }
+}
+
+static struct ovpn_peer *ovpn_mcast_srcs_update(struct ovpn_mcast_sub *sub,
+ const enum ovpn_mcast_filter_mode msg_mode,
+ const struct in6_addr *sources,
+ const unsigned int nsrcs)
+{
+ if (!sources || !nsrcs)
+ return NULL;
+
+ /* ALLOW_NEW: add in INCLUDE, del in EXCLUDE.
+ * BLOCK_OLD: del in INCLUDE, add in EXCLUDE.
+ */
+ if (sub->filter_mode == msg_mode) {
+ ovpn_mcast_srcs_add(sub, sources, nsrcs);
+ } else {
+ ovpn_mcast_srcs_del(sub, sources, nsrcs);
+ if (sub->filter_mode == OVPN_MCAST_INCLUDE &&
+ list_empty(&sub->sources))
+ return ovpn_mcast_sub_del(sub);
+ }
+ return NULL;
+}
+
+static bool ovpn_mcast_sub_init(struct ovpn_mcast_sub **subp,
+ struct ovpn_peer *peer,
+ const enum ovpn_mcast_filter_mode mode,
+ struct ovpn_mcast_group *group)
+{
+ struct ovpn_mcast_sub *sub;
+
+ sub = kzalloc_obj(*sub, GFP_ATOMIC);
+ if (unlikely(!sub))
+ return false;
+
+ if (!ovpn_peer_hold(peer)) {
+ kfree(sub);
+ return false;
+ }
+
+ sub->peer = peer;
+ sub->filter_mode = mode;
+ INIT_LIST_HEAD(&sub->sources);
+ list_add_tail(&sub->list, &group->subs);
+ *subp = sub;
+ return true;
+}
+
/**
- * ovpn_mcast_join - add a peer to a multicast group
+ * ovpn_mcast_sub_update - create, replace, or incrementally update a multicast subscription
* @ovpn: the ovpn instance
- * @peer: the peer joining the group
- * @group_addr: the multicast group address (IPv4-mapped IPv6 for IPv4 groups)
+ * @peer: the peer whose subscription is being updated
+ * @group_addr: the multicast group address
+ * @mode: the filter mode (INCLUDE or EXCLUDE)
+ * @sources: array of source addresses to add or remove
+ * @nsrcs: number of sources in @sources
+ * @incremental_update: if true, merge sources into existing state;
+ * if false, replace state entirely
*
- * Creates the group if it does not exist and adds a subscription for @peer.
- * If the peer is already subscribed, returns success without doing anything.
+ * When @incremental_update is false the subscription is fully replaced with
+ * the given @mode and @sources. An empty source list with INCLUDE mode is
+ * equivalent to leaving the group; with EXCLUDE mode it is an ASM join
+ * (receive all sources).
+ *
+ * When @incremental_update is true the sources are merged: they are added
+ * to the list when @mode matches the current filter mode, or removed when
+ * it differs. ALLOW_NEW maps to INCLUDE; BLOCK_OLD maps to EXCLUDE. If a
+ * BLOCK_OLD operation removes the last source from an INCLUDE subscription,
+ * the subscription is destroyed.
+ *
+ * If no subscription exists for @peer on @group_addr one is created. If the
+ * group does not exist it is created.
+ *
+ * All updates are atomic under @ovpn->lock.
*/
-void ovpn_mcast_join(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
- const struct in6_addr *group_addr)
+void ovpn_mcast_sub_update(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
+ const struct in6_addr *group_addr,
+ const enum ovpn_mcast_filter_mode mode,
+ const struct in6_addr *sources,
+ const unsigned int nsrcs,
+ const bool incremental_update)
{
struct ovpn_mcast_group *group;
struct ovpn_mcast_sub *sub;
+ struct ovpn_peer *peer_to_put = NULL;
if (!ovpn_mcast_addr_valid(group_addr))
return;
@@ -117,19 +253,47 @@ void ovpn_mcast_join(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
}
list_for_each_entry(sub, &group->subs, list) {
- if (sub->peer == peer)
+ if (sub->peer != peer)
+ continue;
+ if (incremental_update) {
+ peer_to_put = ovpn_mcast_srcs_update(sub, mode, sources, nsrcs);
+ ovpn_mcast_group_try_del(group);
goto end;
+ } else {
+ sub->filter_mode = mode;
+ ovpn_mcast_srcs_del_all(&sub->sources);
+ goto add_sources;
+ }
}
- sub = kzalloc_obj(*sub, GFP_ATOMIC);
- if (unlikely(!sub))
+ if (!ovpn_mcast_sub_init(&sub, peer, mode, group)) {
+ ovpn_mcast_group_try_del(group);
goto end;
-
- sub->peer = peer;
- ovpn_peer_hold(peer);
- list_add_tail(&sub->list, &group->subs);
+ }
+add_sources:
+ if (sources && nsrcs)
+ ovpn_mcast_srcs_add(sub, sources, nsrcs);
end:
spin_unlock_bh(&ovpn->lock);
+
+ if (peer_to_put)
+ ovpn_peer_put(peer_to_put);
+}
+
+/**
+ * ovpn_mcast_join - add a peer to a multicast group
+ * @ovpn: the ovpn instance
+ * @peer: the peer joining the group
+ * @group_addr: the multicast group address (IPv4-mapped IPv6 for IPv4 groups)
+ *
+ * Creates the group if it does not exist and adds a subscription for @peer.
+ * If the peer is already subscribed, returns without doing anything.
+ */
+void ovpn_mcast_join(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
+ const struct in6_addr *group_addr)
+{
+ ovpn_mcast_sub_update(ovpn, peer, group_addr, OVPN_MCAST_EXCLUDE,
+ NULL, 0, false);
}
/**
@@ -202,20 +366,36 @@ void ovpn_mcast_leave_all(struct ovpn_peer *peer)
ovpn_peer_put(peer);
}
+static bool ovpn_mcast_src_allowed(const struct ovpn_mcast_sub *sub,
+ const struct in6_addr *src_addr)
+{
+ struct ovpn_mcast_source *src;
+
+ list_for_each_entry(src, &sub->sources, list) {
+ if (ipv6_addr_equal(&src->addr, src_addr))
+ return sub->filter_mode == OVPN_MCAST_INCLUDE;
+ }
+ return sub->filter_mode == OVPN_MCAST_EXCLUDE;
+}
+
/**
* ovpn_peer_list_get_by_mcast_group - retrieve peers subscribed to a multicast group
* @ovpn: the ovpn instance to search
* @group_addr: the multicast group address to look up
* @list: the lockless list to append matching peers to
*
- * Searches for the multicast group identified by @group_addr and appends all
- * subscribed peers to @list, acquiring a reference on each one.
+ * @src: the source address to match against per-peer source filters
+ *
+ * Searches for the multicast group identified by @group_addr and appends
+ * subscribed peers whose source filter allows @src to @list, acquiring a
+ * reference on each one.
*
* Return: false if no peer was found, true otherwise
*/
bool ovpn_peer_list_get_by_mcast_group(struct ovpn_priv *ovpn,
const struct in6_addr *group_addr,
- struct llist_head *list)
+ struct llist_head *list,
+ const struct in6_addr *src)
{
struct ovpn_mcast_group *group;
struct ovpn_mcast_sub *sub;
@@ -225,7 +405,8 @@ bool ovpn_peer_list_get_by_mcast_group(struct ovpn_priv *ovpn,
group = ovpn_mcast_group_find(ovpn, group_addr);
if (group) {
list_for_each_entry(sub, &group->subs, list) {
- if (ovpn_peer_hold(sub->peer))
+ if (ovpn_mcast_src_allowed(sub, src) &&
+ ovpn_peer_hold(sub->peer))
llist_add(&sub->peer->mcast_entry, list);
}
}
@@ -305,18 +486,47 @@ static bool ovpn_mcast_snoop_mldv2(struct ovpn_peer *peer, struct sk_buff *skb,
/* recompute grec after potential head reallocation */
grec = (struct mld2_grec *)(skb_network_header(skb) + offset - rec_len);
- /* In MLDv2 ASM, EXCLUDE mode with an empty source list means
- * "exclude nothing, receive everything" -> JOIN.
- * INCLUDE mode with an empty source list means
- * "include nothing, receive nothing" -> LEAVE.
- * See RFC 3810, section 4.
- */
- if (nsrcs == 0 &&
- (grec->grec_type == MLD2_CHANGE_TO_INCLUDE ||
- grec->grec_type == MLD2_MODE_IS_INCLUDE)) {
- ovpn_mcast_leave(peer->ovpn, peer, &grec->grec_mca);
- } else {
- ovpn_mcast_join(peer->ovpn, peer, &grec->grec_mca);
+ switch (grec->grec_type) {
+ case MLD2_MODE_IS_INCLUDE:
+ case MLD2_CHANGE_TO_INCLUDE:
+ if (nsrcs == 0)
+ ovpn_mcast_leave(peer->ovpn, peer,
+ &grec->grec_mca);
+ else
+ ovpn_mcast_sub_update(peer->ovpn, peer,
+ &grec->grec_mca,
+ OVPN_MCAST_INCLUDE,
+ grec->grec_src, nsrcs,
+ false);
+ break;
+ case MLD2_MODE_IS_EXCLUDE:
+ case MLD2_CHANGE_TO_EXCLUDE:
+ if (nsrcs == 0)
+ ovpn_mcast_join(peer->ovpn, peer,
+ &grec->grec_mca);
+ else
+ ovpn_mcast_sub_update(peer->ovpn, peer,
+ &grec->grec_mca,
+ OVPN_MCAST_EXCLUDE,
+ grec->grec_src, nsrcs,
+ false);
+ break;
+ case MLD2_ALLOW_NEW_SOURCES:
+ if (nsrcs)
+ ovpn_mcast_sub_update(peer->ovpn, peer,
+ &grec->grec_mca,
+ OVPN_MCAST_INCLUDE,
+ grec->grec_src, nsrcs,
+ true);
+ break;
+ case MLD2_BLOCK_OLD_SOURCES:
+ if (nsrcs)
+ ovpn_mcast_sub_update(peer->ovpn, peer,
+ &grec->grec_mca,
+ OVPN_MCAST_EXCLUDE,
+ grec->grec_src, nsrcs,
+ true);
+ break;
}
}
@@ -381,9 +591,9 @@ static bool ovpn_mcast_snoop_igmpv3(struct ovpn_peer *peer, struct sk_buff *skb,
unsigned int offset, const int ngrec)
{
struct igmpv3_grec *grec;
- struct in6_addr addr6;
+ struct in6_addr addr6, *srcs = NULL;
int i;
- unsigned int rec_len;
+ unsigned int j, rec_len;
__u16 nsrcs;
for (i = 0; i < ngrec; i++) {
@@ -403,21 +613,53 @@ static bool ovpn_mcast_snoop_igmpv3(struct ovpn_peer *peer, struct sk_buff *skb,
/* recompute grec after potential head reallocation */
grec = (struct igmpv3_grec *)(skb_network_header(skb) + offset - rec_len);
- /* In IGMPv3 ASM, EXCLUDE mode with an empty source list means
- * "exclude nothing, receive everything" -> JOIN.
- * INCLUDE mode with an empty source list means
- * "include nothing, receive nothing" -> LEAVE.
- * See RFC 3376, section 3.
- */
- if (nsrcs == 0 &&
- (grec->grec_type == IGMPV3_CHANGE_TO_INCLUDE ||
- grec->grec_type == IGMPV3_MODE_IS_INCLUDE)) {
- ipv6_addr_set_v4mapped(grec->grec_mca, &addr6);
- ovpn_mcast_leave(peer->ovpn, peer, &addr6);
- } else {
- ipv6_addr_set_v4mapped(grec->grec_mca, &addr6);
- ovpn_mcast_join(peer->ovpn, peer, &addr6);
+ ipv6_addr_set_v4mapped(grec->grec_mca, &addr6);
+
+ if (nsrcs > 0) {
+ srcs = kcalloc(nsrcs, sizeof(*srcs), GFP_ATOMIC);
+ if (!srcs)
+ return false;
+
+ for (j = 0; j < nsrcs; j++)
+ ipv6_addr_set_v4mapped(grec->grec_src[j],
+ &srcs[j]);
}
+
+ switch (grec->grec_type) {
+ case IGMPV3_MODE_IS_INCLUDE:
+ case IGMPV3_CHANGE_TO_INCLUDE:
+ if (nsrcs == 0)
+ ovpn_mcast_leave(peer->ovpn, peer, &addr6);
+ else
+ ovpn_mcast_sub_update(peer->ovpn, peer, &addr6,
+ OVPN_MCAST_INCLUDE, srcs,
+ nsrcs, false);
+ break;
+ case IGMPV3_MODE_IS_EXCLUDE:
+ case IGMPV3_CHANGE_TO_EXCLUDE:
+ if (nsrcs == 0)
+ ovpn_mcast_join(peer->ovpn, peer, &addr6);
+ else
+ ovpn_mcast_sub_update(peer->ovpn, peer, &addr6,
+ OVPN_MCAST_EXCLUDE, srcs,
+ nsrcs, false);
+ break;
+ case IGMPV3_ALLOW_NEW_SOURCES:
+ if (nsrcs)
+ ovpn_mcast_sub_update(peer->ovpn, peer, &addr6,
+ OVPN_MCAST_INCLUDE, srcs,
+ nsrcs, true);
+ break;
+ case IGMPV3_BLOCK_OLD_SOURCES:
+ if (nsrcs)
+ ovpn_mcast_sub_update(peer->ovpn, peer, &addr6,
+ OVPN_MCAST_EXCLUDE, srcs,
+ nsrcs, true);
+ break;
+ }
+
+ kfree(srcs);
+ srcs = NULL;
}
return true;
@@ -13,15 +13,27 @@ struct in6_addr;
struct llist_head;
struct sk_buff;
+enum ovpn_mcast_filter_mode {
+ OVPN_MCAST_EXCLUDE,
+ OVPN_MCAST_INCLUDE,
+};
+
void ovpn_mcast_cleanup(struct ovpn_priv *ovpn);
void ovpn_mcast_join(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
const struct in6_addr *group_addr);
void ovpn_mcast_leave(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
const struct in6_addr *group_addr);
+void ovpn_mcast_sub_update(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
+ const struct in6_addr *group_addr,
+ const enum ovpn_mcast_filter_mode mode,
+ const struct in6_addr *sources,
+ const unsigned int nsrcs,
+ const bool incremental_update);
void ovpn_mcast_leave_all(struct ovpn_peer *peer);
bool ovpn_peer_list_get_by_mcast_group(struct ovpn_priv *ovpn,
const struct in6_addr *group_addr,
- struct llist_head *list);
+ struct llist_head *list,
+ const struct in6_addr *src);
bool ovpn_mcast_is_control(struct sk_buff *skb);
bool ovpn_mcast_snoop_skb(struct ovpn_peer *peer, struct sk_buff *skb);
@@ -751,7 +751,7 @@ void ovpn_peer_list_get_by_dst(struct ovpn_priv *ovpn, struct sk_buff *skb,
{
struct ovpn_peer *peer = NULL;
unsigned int addr_type;
- struct in6_addr addr6;
+ struct in6_addr addr6, src;
__be32 addr4;
/* in P2P mode, no matter the destination, packets are always sent to
@@ -779,7 +779,8 @@ void ovpn_peer_list_get_by_dst(struct ovpn_priv *ovpn, struct sk_buff *skb,
addr_type = inet_dev_addr_type(dev_net(ovpn->dev), ovpn->dev, addr4);
if (addr_type == RTN_MULTICAST) {
ipv6_addr_set_v4mapped(addr4, &addr6);
- if (!ovpn_peer_list_get_by_mcast_group(ovpn, &addr6, list) &&
+ ipv6_addr_set_v4mapped(ip_hdr(skb)->saddr, &src);
+ if (!ovpn_peer_list_get_by_mcast_group(ovpn, &addr6, list, &src) &&
ovpn_mcast_is_control(skb)) {
ovpn_peer_list_get_all(ovpn, list);
}
@@ -797,7 +798,7 @@ void ovpn_peer_list_get_by_dst(struct ovpn_priv *ovpn, struct sk_buff *skb,
rcu_read_unlock();
if (ipv6_addr_is_multicast(&addr6) &&
- !ovpn_peer_list_get_by_mcast_group(ovpn, &addr6, list) &&
+ !ovpn_peer_list_get_by_mcast_group(ovpn, &addr6, list, &ipv6_hdr(skb)->saddr) &&
ovpn_mcast_is_control(skb)) {
ovpn_peer_list_get_all(ovpn, list);
}
Extend the multicast subscription table to track per-source state, enabling SSM (S,G) forwarding instead of group-only ASM forwarding. Data structures: - Add ovpn_mcast_source to track individual source addresses - Extend ovpn_mcast_sub with filter_mode (INCLUDE/EXCLUDE) and a source list Subscription API: - Add ovpn_mcast_sub_update() to create or update a subscription with a full source list and filter mode - ovpn_mcast_join() becomes a thin wrapper around sub_update() (EXCLUDE mode with empty source list = ASM join) - Add ovpn_mcast_srcs_update() for incremental source merging: ALLOW_NEW adds sources in INCLUDE mode and removes them in EXCLUDE mode; BLOCK_OLD does the opposite - Empty INCLUDE subscriptions are automatically deleted when BLOCK_OLD removes the last source RX path (snooping): - IGMPv3 and MLDv2 parsers now extract source lists from reports and pass them to sub_update() - All record types are handled: MODE_IS_*, CHANGE_TO_*, ALLOW_NEW_SOURCES, BLOCK_OLD_SOURCES TX path (forwarding): - Add ovpn_mcast_src_allowed() to evaluate a source against a peer's filter mode and source list - ovpn_peer_list_get_by_mcast_group() now takes a source address and only returns peers whose subscription allows the source - ASM backward compatibility preserved: EXCLUDE with empty source list allows all sources Signed-off-by: Marco Baffo <marco@mandelbit.com> --- drivers/net/ovpn/mcast.c | 334 +++++++++++++++++++++++++++++++++------ drivers/net/ovpn/mcast.h | 14 +- drivers/net/ovpn/peer.c | 7 +- 3 files changed, 305 insertions(+), 50 deletions(-)