[Openvpn-devel] Fix StatusChangeCallback so it works without a LogCallback

Message ID 20230709231929.195048-1-jeremyfleischman@gmail.com
State New
Delegated to: David Sommerseth
Headers show
Series [Openvpn-devel] Fix StatusChangeCallback so it works without a LogCallback | expand

Commit Message

Jeremy Fleischman July 9, 2023, 11:19 p.m. UTC
`StatusChangeCallback` requires that LogForward be enabled, which
previously only happened in `LogCallback`. Now both of them do it, which
requires a bit of bookkeeping. This fixes
https://github.com/OpenVPN/openvpn3-linux/issues/197

Notes:

1. I'm not sure if this would place nicely with a multithreaded program
   or not. Hopefully that's not something we care to support?
2. To keep the bookkeeping accurate, I opted to explictly remove any
   preexisting callbacks before registering a new one. This means that
   when you're clobbering an existing LogCallback (for example), you'll
   actually end disabling LogForward right before you re-enable it. I
   don't think this is a big deal, but just wanted to call it out.

Signed-off-by: Jeremy Fleischman <jeremyfleischman@gmail.com>
---
 src/python/openvpn3/SessionManager.py | 63 ++++++++++++++++++++++-----
 1 file changed, 51 insertions(+), 12 deletions(-)

Comments

David Sommerseth Sept. 3, 2023, 10:56 p.m. UTC | #1
Hi Jeremy,

First of all; sorry it's taken so long time to get back to you.  GH 
issue #171 has unfortunately taken most of my time, so this patch went 
on the side burner.

I've looked at your patch, and I wonder if it can be done a bit simpler. 
  I'm open to hear your views; I might have overlooked some details.


On 10/07/2023 01:19, Jeremy Fleischman wrote:
[...snip...]
> @@ -285,22 +286,24 @@ def GetFormattedStatistics(self, prefix='Connection statistics:\n', format_str='
>       #
>       def LogCallback(self, cbfnc):
>           if cbfnc is not None:
> +            # Remove the existing callback if there is one.
> +            if self.__log_callback is not None:
> +                self.LogCallback(None)
> +
>               self.__log_callback = cbfnc
>               self.__dbuscon.add_signal_receiver(cbfnc,
>                                                  signal_name='Log',
>                                                  dbus_interface='net.openvpn.v3.backends',
>                                                  bus_name='net.openvpn.v3.log',
>                                                  path=self.__session_path)
> -            self.__session_intf.LogForward(True)
> +            self.__add_LogForward_receiver()


I'm wondering if this could be made cleared.  The recursion is kinda 
clever, but feels like it hides the purpose.

Wouldn't this code below work just as well (consider this more a concept 
code than proper Python code)?  Here there is no recursion and no 
reference counting which could go wrong.

------------------------------------------------------------------------
class Session(object):
      [...]
      self.__log_forward_enabled = False

      def LogCallback(self, cbfnc):	
	if cbfnc is not None and self.__log_callback is None:
	    # Log Callback function is being enabled; not
             # set before
             self.__log_callback = cbfnc
	    self.__dbuscon.add_signal_receiver(cbfnc, ...)
	    self.__set_log_forward(True)

	elif cbfnc is None and self.__log_callback is not None:
	    # Log Callback function is being disabled; can only
             # happen because it was set
             self.__dbuscon.remove_signal_receiver(self.__log_callback,
                                                   ...)
             self.__log_callback = None
	    try:
		self.__set_log_forward(False)
	    except dbus.exception.DBusException:
		pass

	elif (cbfnc is not None and self.__log_callback is not None):
             # In this case, the program must first disable the
             # current LogCallback() before setting a new one.
	    raise RuntimeErrpr('LogCallback() is already enabled')

	# No need to complain if unsetting an unset log callback
         # function; this will not have any behavioral impact at all.




      def __set_log_forward(self, enable):
          # This method can only be called *after* callback
          # function has been set in this object

          if not self.__log_forward_enabled and enable:
              # If log forwarding is disabled it can be enabled
              self.__log_forward_enabled = True
              self.__session_intf.LogForward(True)

	 elif self.__log_forward_enabled and not enable:
              # Log forwarding can only be disabled if
              # both Log and StatusChange callbacks are unset
	     if (self.__log_callback is None)
                 and (self.__status_callback is None):
		 self.__log_forward_enabled = False
                  self.__session_intf.LogForward(False)

------------------------------------------------------------------------

The StatusChangeCallback() would need a similar implementation as 
LogCallback() too.

In regards to multi-threading; I would not expect this code to be used 
in a multi-threaded setup where different callbacks would be 
enabled/disabled on the fly in parallel.  But it might be good to add 
this as a comment in general, that these methods are not considered 
thread-safe.  However, since the code snippets above does not use 
reference counting, it should be a bit more robust as it bases the 
decision on the value of the callback function pointers.


Thoughts?
Jeremy Fleischman Sept. 4, 2023, 5:46 p.m. UTC | #2
Hey David,

On Sun, Sep 3, 2023 at 3:56 PM David Sommerseth
<dazo+openvpn@eurephia.org> wrote:
>
> First of all; sorry it's taken so long time to get back to you.  GH
> issue #171 has unfortunately taken most of my time, so this patch went
> on the side burner.

No worries, thanks for getting around to this!

> I've looked at your patch, and I wonder if it can be done a bit simpler.
>   I'm open to hear your views; I might have overlooked some details.

I like this! At a high level, your proposal consolidates the logic
about "who needs LogForward enabled" into this new __set_log_forward
method. It maybe makes it a little more complicated to add more
callbacks in the future that require LogForward, but that's an
incredibly minor point, and I don't think it's worth the complication
of the reference counting implementation I originally proposed.

I took your idea, and rolled with it a bit. I think we can simplify
things a bit further if we let __set_log_forward be completely
responsible for determining if LogForward should be enabled (it knows
if it should be enabled because it can check if there are any
callbacks registered that require LogForward). How does this updated
diff look?

diff --git a/src/python/openvpn3/SessionManager.py
b/src/python/openvpn3/SessionManager.py
index 3632790..1a567be 100644
--- a/src/python/openvpn3/SessionManager.py
+++ b/src/python/openvpn3/SessionManager.py
@@ -114,6 +114,7 @@ def __init__(self, dbuscon, objpath):
         self.__log_callback = None
         self.__status_callback = None
         self.__deleted = False
+        self.__log_forward_enabled = False


     def __del__(self):
@@ -291,16 +292,11 @@ def LogCallback(self, cbfnc):

dbus_interface='net.openvpn.v3.backends',
                                                bus_name='net.openvpn.v3.log',
                                                path=self.__session_path)
-            self.__session_intf.LogForward(True)
         else:
-            try:
-                self.__session_intf.LogForward(False)
-            except dbus.exceptions.DBusException:
-                # If this fails, the session is typically already removed
-                pass
             self.__dbuscon.remove_signal_receiver(self.__log_callback, 'Log')
             self.__log_callback = None

+        self.__set_log_forward()

     ##
     #  Subscribes to the StatusChange signals for this session and register
@@ -318,10 +314,14 @@ def StatusChangeCallback(self, cbfnc):
                                                bus_name='net.openvpn.v3.log',
                                                path=self.__session_path)
         else:
-            self.__dbuscon.remove_signal_receiver(self.__status_callback,
-                                                  'StatusChange')
-            self.__status_callback = None
+            # Only remove the callback if there actually *is* a callback
+            # currently.
+            if self.__status_callback is not None:
+                self.__dbuscon.remove_signal_receiver(self.__status_callback,
+                                                      'StatusChange')
+                self.__status_callback = None

+        self.__set_log_forward()


     ##
@@ -417,6 +417,30 @@ def GetDCO(self):
     def SetDCO(self, dco):
         self.__prop_intf.Set('net.openvpn.v3.sessions', 'dco', dco)

+    ##
+    #  Internal method to enable/disable LogForward as needed.
+    #  Must be called whenever a callback that needs LogForward enabled is
+    #  added or removed.
+    #
+    def __set_log_forward(self):
+        # The LogCallback and the StatusChangeCallback both need LogForward
+        # enabled. In other words, LogForward should be enabled iff one or both
+        # of those callbacks are registered.
+        should_log_forward_be_enabled = (
+            self.__log_callback is not None or self.__status_callback
is not None
+        )
+
+        if should_log_forward_be_enabled and not self.__log_forward_enabled:
+            self.__session_intf.LogForward(True)
+            self.__log_forward_enabled = True
+        elif not should_log_forward_be_enabled and self.__log_forward_enabled:
+            try:
+                self.__session_intf.LogForward(False)
+            except dbus.exceptions.DBusException:
+                # If this fails, the session is typically already removed
+                pass
+
+            self.__log_forward_enabled = False


 ##
Jeremy Fleischman Sept. 4, 2023, 5:51 p.m. UTC | #3
> diff --git a/src/python/openvpn3/SessionManager.py
> b/src/python/openvpn3/SessionManager.py
> index 3632790..1a567be 100644
> --- a/src/python/openvpn3/SessionManager.py
> +++ b/src/python/openvpn3/SessionManager.py
> @@ -114,6 +114,7 @@ def __init__(self, dbuscon, objpath):
>          self.__log_callback = None
>          self.__status_callback = None
>          self.__deleted = False
> +        self.__log_forward_enabled = False
>
>
>      def __del__(self):
> @@ -291,16 +292,11 @@ def LogCallback(self, cbfnc):
>
> dbus_interface='net.openvpn.v3.backends',
>                                                 bus_name='net.openvpn.v3.log',
>                                                 path=self.__session_path)
> -            self.__session_intf.LogForward(True)
>          else:
> -            try:
> -                self.__session_intf.LogForward(False)
> -            except dbus.exceptions.DBusException:
> -                # If this fails, the session is typically already removed
> -                pass
>              self.__dbuscon.remove_signal_receiver(self.__log_callback, 'Log')
>              self.__log_callback = None

Oops, this code unconditionally removes the callback, even if there
isn't currently a callback. The code below in StatusChangeCallback
first checks if there is currently a callback registered before removing
it. If `self.__dbuscon.remove_signal_receiver` is resilient to getting
passed None values for callbacks, then I suppose would could skip the
check and just unconditionally remove the callback.
Let me know what's best.

>
> +        self.__set_log_forward()
>
>      ##
>      #  Subscribes to the StatusChange signals for this session and register
> @@ -318,10 +314,14 @@ def StatusChangeCallback(self, cbfnc):
>                                                 bus_name='net.openvpn.v3.log',
>                                                 path=self.__session_path)
>          else:
> -            self.__dbuscon.remove_signal_receiver(self.__status_callback,
> -                                                  'StatusChange')
> -            self.__status_callback = None
> +            # Only remove the callback if there actually *is* a callback
> +            # currently.
> +            if self.__status_callback is not None:
> +                self.__dbuscon.remove_signal_receiver(self.__status_callback,
> +                                                      'StatusChange')
> +                self.__status_callback = None
>
> +        self.__set_log_forward()
>
>
>      ##
> @@ -417,6 +417,30 @@ def GetDCO(self):
>      def SetDCO(self, dco):
>          self.__prop_intf.Set('net.openvpn.v3.sessions', 'dco', dco)
>
> +    ##
> +    #  Internal method to enable/disable LogForward as needed.
> +    #  Must be called whenever a callback that needs LogForward enabled is
> +    #  added or removed.
> +    #
> +    def __set_log_forward(self):
> +        # The LogCallback and the StatusChangeCallback both need LogForward
> +        # enabled. In other words, LogForward should be enabled iff one or both
> +        # of those callbacks are registered.
> +        should_log_forward_be_enabled = (
> +            self.__log_callback is not None or self.__status_callback
> is not None
> +        )
> +
> +        if should_log_forward_be_enabled and not self.__log_forward_enabled:
> +            self.__session_intf.LogForward(True)
> +            self.__log_forward_enabled = True
> +        elif not should_log_forward_be_enabled and self.__log_forward_enabled:
> +            try:
> +                self.__session_intf.LogForward(False)
> +            except dbus.exceptions.DBusException:
> +                # If this fails, the session is typically already removed
> +                pass
> +
> +            self.__log_forward_enabled = False
>
>
>  ##

Patch

diff --git a/src/python/openvpn3/SessionManager.py b/src/python/openvpn3/SessionManager.py
index 3632790..05126aa 100644
--- a/src/python/openvpn3/SessionManager.py
+++ b/src/python/openvpn3/SessionManager.py
@@ -114,6 +114,7 @@  def __init__(self, dbuscon, objpath):
         self.__log_callback = None
         self.__status_callback = None
         self.__deleted = False
+        self.__LogForward_receiver_count = 0
 
 
     def __del__(self):
@@ -285,22 +286,24 @@  def GetFormattedStatistics(self, prefix='Connection statistics:\n', format_str='
     #
     def LogCallback(self, cbfnc):
         if cbfnc is not None:
+            # Remove the existing callback if there is one.
+            if self.__log_callback is not None:
+                self.LogCallback(None)
+
             self.__log_callback = cbfnc
             self.__dbuscon.add_signal_receiver(cbfnc,
                                                signal_name='Log',
                                                dbus_interface='net.openvpn.v3.backends',
                                                bus_name='net.openvpn.v3.log',
                                                path=self.__session_path)
-            self.__session_intf.LogForward(True)
+            self.__add_LogForward_receiver()
         else:
-            try:
-                self.__session_intf.LogForward(False)
-            except dbus.exceptions.DBusException:
-                # If this fails, the session is typically already removed
-                pass
-            self.__dbuscon.remove_signal_receiver(self.__log_callback, 'Log')
-            self.__log_callback = None
-
+            # Only remove the callback if there actually *is* a callback
+            # currently.
+            if self.__log_callback is not None:
+                self.__remove_LogForward_receiver()
+                self.__dbuscon.remove_signal_receiver(self.__log_callback, 'Log')
+                self.__log_callback = None
 
     ##
     #  Subscribes to the StatusChange signals for this session and register
@@ -311,16 +314,25 @@  def LogCallback(self, cbfnc):
     #
     def StatusChangeCallback(self, cbfnc):
         if cbfnc is not None:
+            # Remove the existing callback if there is one.
+            if self.__status_callback is not None:
+                self.StatusChangeCallback(None)
+
             self.__status_callback = cbfnc
             self.__dbuscon.add_signal_receiver(cbfnc,
                                                signal_name='StatusChange',
                                                dbus_interface='net.openvpn.v3.backends',
                                                bus_name='net.openvpn.v3.log',
                                                path=self.__session_path)
+            self.__add_LogForward_receiver()
         else:
-            self.__dbuscon.remove_signal_receiver(self.__status_callback,
-                                                  'StatusChange')
-            self.__status_callback = None
+            # Only remove the callback if there actually *is* a callback
+            # currently.
+            if self.__status_callback is not None:
+                self.__remove_LogForward_receiver()
+                self.__dbuscon.remove_signal_receiver(self.__status_callback,
+                                                      'StatusChange')
+                self.__status_callback = None
 
 
 
@@ -417,6 +429,33 @@  def GetDCO(self):
     def SetDCO(self, dco):
         self.__prop_intf.Set('net.openvpn.v3.sessions', 'dco', dco)
 
+    ##
+    #  Internal method to increase the count of how many signal receivers need
+    #  LogForward. Turns on LogForward if this is the first receiver.
+    #
+    def __add_LogForward_receiver(self):
+        # This is our first need for LogForward. Turn it on.
+        if self.__LogForward_receiver_count == 0:
+            self.__session_intf.LogForward(True)
+
+        self.__LogForward_receiver_count += 1
+
+    ##
+    #  Internal method track to reduce the count of how many signal receivers
+    #  need LogForward. Turns off LogForward if this was the last receiver.
+    #
+    def __remove_LogForward_receiver(self):
+        assert self.__LogForward_receiver_count > 0
+        self.__LogForward_receiver_count -= 1
+
+        # No receivers are left in need of LogForward. Turn it off.
+        if self.__LogForward_receiver_count == 0:
+            try:
+                self.__session_intf.LogForward(False)
+            except dbus.exceptions.DBusException:
+                # If this fails, the session is typically already removed
+                pass
+
 
 
 ##