summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/notifier.py50
-rw-r--r--tests/rest/client/v1/test_presence.py1
-rw-r--r--tests/utils.py3
3 files changed, 38 insertions, 16 deletions
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 6fcb7767a0..344dd03172 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -71,14 +71,17 @@ class _NotifierUserStream(object):
     so that it can remove itself from the indexes in the Notifier class.
     """
 
-    def __init__(self, user, rooms, current_token, appservice=None):
+    def __init__(self, user, rooms, current_token, time_now_ms,
+                 appservice=None):
         self.user = str(user)
         self.appservice = appservice
         self.listeners = set()
         self.rooms = set(rooms)
         self.current_token = current_token
+        self.last_notified_ms = time_now_ms
 
-    def notify(self, stream_key, stream_id):
+    def notify(self, stream_key, stream_id, time_now_ms):
+        self.last_notified_ms = time_now_ms
         self.current_token = self.current_token.copy_and_replace(
             stream_key, stream_id
         )
@@ -96,7 +99,7 @@ class _NotifierUserStream(object):
             lst = notifier.room_to_user_streams.get(room, set())
             lst.discard(self)
 
-        notifier.user_to_user_streams.get(self.user, set()).discard(self)
+        notifier.user_to_user_stream.pop(self.user)
 
         if self.appservice:
             notifier.appservice_to_user_streams.get(
@@ -111,6 +114,8 @@ class Notifier(object):
     Primarily used from the /events stream.
     """
 
+    UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
+
     def __init__(self, hs):
         self.hs = hs
 
@@ -128,6 +133,10 @@ class Notifier(object):
             "user_joined_room", self._user_joined_room
         )
 
+        self.clock.looping_call(
+            self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
+        )
+
         # This is not a very cheap test to perform, but it's only executed
         # when rendering the metrics page, which is likely once per minute at
         # most when scraping it.
@@ -221,9 +230,12 @@ class Notifier(object):
 
         logger.debug("on_new_room_event listeners %s", user_streams)
 
+        time_now_ms = self.clock.time_msec()
         for user_stream in user_streams:
             try:
-                user_stream.notify("room_key", "s%d" % (room_stream_id,))
+                user_stream.notify(
+                    "room_key", "s%d" % (room_stream_id,), time_now_ms
+                )
             except:
                 logger.exception("Failed to notify listener")
 
@@ -246,9 +258,10 @@ class Notifier(object):
         for room in rooms:
             user_streams |= self.room_to_user_streams.get(room, set())
 
+        time_now_ms = self.clock.time_msec()
         for user_stream in user_streams:
             try:
-                user_stream.notify(stream_key, new_token)
+                user_stream.notify(stream_key, new_token, time_now_ms)
             except:
                 logger.exception("Failed to notify listener")
 
@@ -260,6 +273,7 @@ class Notifier(object):
         """
 
         deferred = defer.Deferred()
+        time_now_ms = self.clock.time_msec()
 
         user = str(user)
         user_stream = self.user_to_user_stream.get(user)
@@ -272,6 +286,7 @@ class Notifier(object):
                 rooms=rooms,
                 appservice=appservice,
                 current_token=current_token,
+                time_now_ms=time_now_ms,
             )
             self._register_with_keys(user_stream)
         else:
@@ -366,6 +381,20 @@ class Notifier(object):
         defer.returnValue(result)
 
     @log_function
+    def remove_expired_streams(self):
+        time_now_ms = self.clock.time_msec()
+        expired_streams = []
+        expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
+        for stream in self.user_to_user_stream.values():
+            if stream.listeners:
+                continue
+            if stream.last_notified_ms < expire_before_ts:
+                expired_streams.append(stream)
+
+        for expired_stream in expired_streams:
+            expired_stream.remove(self)
+
+    @log_function
     def _register_with_keys(self, user_stream):
         self.user_to_user_stream[user_stream.user] = user_stream
 
@@ -385,14 +414,3 @@ class Notifier(object):
             room_streams = self.room_to_user_streams.setdefault(room_id, set())
             room_streams.add(new_user_stream)
             new_user_stream.rooms.add(room_id)
-
-
-def _discard_if_notified(listener_set):
-    """Remove any 'stale' listeners from the given set.
-    """
-    to_discard = set()
-    for l in listener_set:
-        if l.notified():
-            to_discard.add(l)
-
-    listener_set -= to_discard
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index c0c52796ad..29c0038f06 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -271,6 +271,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
                 "call_later",
                 "cancel_call_later",
                 "time_msec",
+                "looping_call",
             ]),
         )
 
diff --git a/tests/utils.py b/tests/utils.py
index a67530bd63..3b5c335911 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -197,6 +197,9 @@ class MockClock(object):
 
         return t
 
+    def looping_call(self, function, interval):
+        pass
+
     def cancel_call_later(self, timer):
         if timer[2]:
             raise Exception("Cannot cancel an expired timer")