summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-06-18 15:49:05 +0100
committerErik Johnston <erik@matrix.org>2015-06-18 15:49:24 +0100
commit22049ea700173017cf2f8e88fb8848e06b82f9b3 (patch)
tree5a6375689c4009777690da704bc2baec86e5a014 /synapse
parentFix notifier leak (diff)
downloadsynapse-22049ea700173017cf2f8e88fb8848e06b82f9b3.tar.xz
Refactor the notifier.wait_for_events code to be clearer. Add _NotifierUserStream.new_listener that accpets a token to avoid races.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/notifier.py122
-rw-r--r--synapse/util/__init__.py8
-rw-r--r--synapse/util/async.py13
3 files changed, 73 insertions, 70 deletions
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 27c034ed51..e441561029 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -16,7 +16,7 @@
 from twisted.internet import defer
 
 from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor
+from synapse.util.async import run_on_reactor, ObservableDeferred
 from synapse.types import StreamToken
 import synapse.metrics
 
@@ -45,20 +45,16 @@ class _NotificationListener(object):
     The events stream handler will have yielded to the deferred, so to
     notify the handler it is sufficient to resolve the deferred.
     """
+    __slots__ = ["deferred"]
 
     def __init__(self, deferred):
-        self.deferred = deferred
+        object.__setattr__(self, "deferred", deferred)
 
-    def notified(self):
-        return self.deferred.called
+    def __getattr__(self, name):
+        return getattr(self.deferred, name)
 
-    def notify(self, token):
-        """ Inform whoever is listening about the new events.
-        """
-        try:
-            self.deferred.callback(token)
-        except defer.AlreadyCalledError:
-            pass
+    def __setattr__(self, name, value):
+        setattr(self.deferred, name, value)
 
 
 class _NotifierUserStream(object):
@@ -75,11 +71,12 @@ class _NotifierUserStream(object):
                  appservice=None):
         self.user = str(user)
         self.appservice = appservice
-        self.listeners = set()
         self.rooms = set(rooms)
         self.current_token = current_token
         self.last_notified_ms = time_now_ms
 
+        self.notify_deferred = ObservableDeferred(defer.Deferred())
+
     def notify(self, stream_key, stream_id, time_now_ms):
         """Notify any listeners for this user of a new event from an
         event source.
@@ -91,12 +88,10 @@ class _NotifierUserStream(object):
         self.current_token = self.current_token.copy_and_advance(
             stream_key, stream_id
         )
-        if self.listeners:
-            self.last_notified_ms = time_now_ms
-            listeners = self.listeners
-            self.listeners = set()
-            for listener in listeners:
-                listener.notify(self.current_token)
+        self.last_notified_ms = time_now_ms
+        noify_deferred = self.notify_deferred
+        self.notify_deferred = ObservableDeferred(defer.Deferred())
+        noify_deferred.callback(self.current_token)
 
     def remove(self, notifier):
         """ Remove this listener from all the indexes in the Notifier
@@ -114,6 +109,18 @@ class _NotifierUserStream(object):
                 self.appservice, set()
             ).discard(self)
 
+    def count_listeners(self):
+        return len(self.noify_deferred.observers())
+
+    def new_listener(self, token):
+        """Returns a deferred that is resolved when there is a new token
+        greater than the given token.
+        """
+        if self.current_token.is_after(token):
+            return _NotificationListener(defer.succeed(self.current_token))
+        else:
+            return _NotificationListener(self.notify_deferred.observe())
+
 
 class Notifier(object):
     """ This class is responsible for notifying any listeners when there are
@@ -158,7 +165,7 @@ class Notifier(object):
             for x in self.appservice_to_user_streams.values():
                 all_user_streams |= x
 
-            return sum(len(stream.listeners) for stream in all_user_streams)
+            return sum(stream.count_listeners() for stream in all_user_streams)
         metrics.register_callback("listeners", count_listeners)
 
         metrics.register_callback(
@@ -286,10 +293,6 @@ class Notifier(object):
         """Wait until the callback returns a non empty response or the
         timeout fires.
         """
-
-        deferred = defer.Deferred()
-        time_now_ms = self.clock.time_msec()
-
         user = str(user)
         user_stream = self.user_to_user_stream.get(user)
         if user_stream is None:
@@ -302,54 +305,38 @@ class Notifier(object):
                 rooms=rooms,
                 appservice=appservice,
                 current_token=current_token,
-                time_now_ms=time_now_ms,
+                time_now_ms=self.clock.time_msec(),
             )
             self._register_with_keys(user_stream)
-        else:
-            current_token = user_stream.current_token
 
         result = None
-        if current_token.is_after(from_token):
-            result = yield callback(from_token, current_token)
-
-        if result:
-            defer.returnValue(result)
-
         if timeout:
-            timer = [None]
-            listeners = []
-            timed_out = [False]
-
-            def notify_listeners():
-                user_stream.listeners.difference_update(listeners)
-                for listener in listeners:
-                    listener.notify(current_token)
-                del listeners[:]
-
-            def _timeout_listener():
-                timed_out[0] = True
-                timer[0] = None
-                notify_listeners()
-
-            # We create multiple notification listeners so we have to manage
-            # canceling the timeout ourselves.
-            timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
-
-            while not result and not timed_out[0]:
-                deferred = defer.Deferred()
-                notify_listeners()
-                listeners.append(_NotificationListener(deferred))
-                user_stream.listeners.update(listeners)
-                new_token = yield deferred
-
-                result = yield callback(current_token, new_token)
-                current_token = new_token
-
-            if timer[0] is not None:
+            listener = None
+            timer = self.clock.call_later(
+                timeout/1000., lambda: listener.cancel()
+            )
+
+            prev_token = from_token
+            while not result:
                 try:
-                    self.clock.cancel_call_later(timer[0])
-                except:
-                    logger.exception("Failed to cancel notifer timer")
+                    # We need to start listening to the streams *before* doing
+                    # the callback, as otherwise we may miss something.
+                    current_token = user_stream.current_token
+
+                    result = yield callback(prev_token, current_token)
+                    if result:
+                        break
+
+                    prev_token = current_token
+                    listener = user_stream.new_listener(prev_token)
+                    yield listener.deferred
+                except defer.CancelledError:
+                    break
+
+            self.clock.cancel_call_later(timer, ignore_errs=True)
+        else:
+            current_token = user_stream.current_token
+            result = yield callback(from_token, current_token)
 
         defer.returnValue(result)
 
@@ -367,6 +354,9 @@ class Notifier(object):
 
         @defer.inlineCallbacks
         def check_for_updates(before_token, after_token):
+            if not after_token.is_after(before_token):
+                defer.returnValue(None)
+
             events = []
             end_token = from_token
             for name, source in self.event_sources.sources.items():
@@ -401,7 +391,7 @@ class Notifier(object):
         expired_streams = []
         expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
         for stream in self.user_to_user_stream.values():
-            if stream.listeners:
+            if stream.count_listeners():
                 continue
             if stream.last_notified_ms < expire_before_ts:
                 expired_streams.append(stream)
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 260714ccc2..07ff25cef3 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -91,8 +91,12 @@ class Clock(object):
         with PreserveLoggingContext():
             return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
 
-    def cancel_call_later(self, timer):
-        timer.cancel()
+    def cancel_call_later(self, timer, ignore_errs=False):
+        try:
+            timer.cancel()
+        except:
+            if not ignore_errs:
+                raise
 
     def time_bound_deferred(self, given_deferred, time_out):
         if given_deferred.called:
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 1c2044e5b4..6f567bcaa6 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -45,7 +45,7 @@ class ObservableDeferred(object):
     def __init__(self, deferred, consumeErrors=False):
         object.__setattr__(self, "_deferred", deferred)
         object.__setattr__(self, "_result", None)
-        object.__setattr__(self, "_observers", [])
+        object.__setattr__(self, "_observers", set())
 
         def callback(r):
             self._result = (True, r)
@@ -74,12 +74,21 @@ class ObservableDeferred(object):
     def observe(self):
         if not self._result:
             d = defer.Deferred()
-            self._observers.append(d)
+
+            def remove(r):
+                self._observers.discard(d)
+                return r
+            d.addBoth(remove)
+
+            self._observers.add(d)
             return d
         else:
             success, res = self._result
             return defer.succeed(res) if success else defer.fail(res)
 
+    def observers(self):
+        return self._observers
+
     def __getattr__(self, name):
         return getattr(self._deferred, name)