summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12656.misc1
-rw-r--r--synapse/handlers/presence.py42
-rw-r--r--synapse/util/wheel_timer.py39
3 files changed, 54 insertions, 28 deletions
diff --git a/changelog.d/12656.misc b/changelog.d/12656.misc
new file mode 100644
index 0000000000..8a8743e614
--- /dev/null
+++ b/changelog.d/12656.misc
@@ -0,0 +1 @@
+Prevent memory leak from reoccurring when presence is disabled.
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index d078162c29..268481ec19 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -659,27 +659,28 @@ class PresenceHandler(BasePresenceHandler):
         )
 
         now = self.clock.time_msec()
-        for state in self.user_to_current_state.values():
-            self.wheel_timer.insert(
-                now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER
-            )
-            self.wheel_timer.insert(
-                now=now,
-                obj=state.user_id,
-                then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
-            )
-            if self.is_mine_id(state.user_id):
+        if self._presence_enabled:
+            for state in self.user_to_current_state.values():
                 self.wheel_timer.insert(
-                    now=now,
-                    obj=state.user_id,
-                    then=state.last_federation_update_ts + FEDERATION_PING_INTERVAL,
+                    now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER
                 )
-            else:
                 self.wheel_timer.insert(
                     now=now,
                     obj=state.user_id,
-                    then=state.last_federation_update_ts + FEDERATION_TIMEOUT,
+                    then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
                 )
+                if self.is_mine_id(state.user_id):
+                    self.wheel_timer.insert(
+                        now=now,
+                        obj=state.user_id,
+                        then=state.last_federation_update_ts + FEDERATION_PING_INTERVAL,
+                    )
+                else:
+                    self.wheel_timer.insert(
+                        now=now,
+                        obj=state.user_id,
+                        then=state.last_federation_update_ts + FEDERATION_TIMEOUT,
+                    )
 
         # Set of users who have presence in the `user_to_current_state` that
         # have not yet been persisted
@@ -804,6 +805,13 @@ class PresenceHandler(BasePresenceHandler):
                 This is currently used to bump the max presence stream ID without changing any
                 user's presence (see PresenceHandler.add_users_to_send_full_presence_to).
         """
+        if not self._presence_enabled:
+            # We shouldn't get here if presence is disabled, but we check anyway
+            # to ensure that we don't a) send out presence federation and b)
+            # don't add things to the wheel timer that will never be handled.
+            logger.warning("Tried to update presence states when presence is disabled")
+            return
+
         now = self.clock.time_msec()
 
         with Measure(self.clock, "presence_update_states"):
@@ -1229,6 +1237,10 @@ class PresenceHandler(BasePresenceHandler):
         ):
             raise SynapseError(400, "Invalid presence state")
 
+        # If presence is disabled, no-op
+        if not self.hs.config.server.use_presence:
+            return
+
         user_id = target_user.to_string()
 
         prev_state = await self.current_state_for_user(user_id)
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index e108adc460..177e198e7e 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -11,17 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Generic, List, TypeVar
+import logging
+from typing import Generic, Hashable, List, Set, TypeVar
 
-T = TypeVar("T")
+import attr
 
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=Hashable)
 
-class _Entry(Generic[T]):
-    __slots__ = ["end_key", "queue"]
 
-    def __init__(self, end_key: int) -> None:
-        self.end_key: int = end_key
-        self.queue: List[T] = []
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _Entry(Generic[T]):
+    end_key: int
+    elements: Set[T] = attr.Factory(set)
 
 
 class WheelTimer(Generic[T]):
@@ -48,17 +51,27 @@ class WheelTimer(Generic[T]):
             then: When to return the object strictly after.
         """
         then_key = int(then / self.bucket_size) + 1
+        now_key = int(now / self.bucket_size)
 
         if self.entries:
             min_key = self.entries[0].end_key
             max_key = self.entries[-1].end_key
 
+            if min_key < now_key - 10:
+                # If we have ten buckets that are due and still nothing has
+                # called `fetch()` then we likely have a bug that is causing a
+                # memory leak.
+                logger.warning(
+                    "Inserting into a wheel timer that hasn't been read from recently. Item: %s",
+                    obj,
+                )
+
             if then_key <= max_key:
                 # The max here is to protect against inserts for times in the past
-                self.entries[max(min_key, then_key) - min_key].queue.append(obj)
+                self.entries[max(min_key, then_key) - min_key].elements.add(obj)
                 return
 
-        next_key = int(now / self.bucket_size) + 1
+        next_key = now_key + 1
         if self.entries:
             last_key = self.entries[-1].end_key
         else:
@@ -71,7 +84,7 @@ class WheelTimer(Generic[T]):
         # to insert. This ensures there are no gaps.
         self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1))
 
-        self.entries[-1].queue.append(obj)
+        self.entries[-1].elements.add(obj)
 
     def fetch(self, now: int) -> List[T]:
         """Fetch any objects that have timed out
@@ -84,11 +97,11 @@ class WheelTimer(Generic[T]):
         """
         now_key = int(now / self.bucket_size)
 
-        ret = []
+        ret: List[T] = []
         while self.entries and self.entries[0].end_key <= now_key:
-            ret.extend(self.entries.pop(0).queue)
+            ret.extend(self.entries.pop(0).elements)
 
         return ret
 
     def __len__(self) -> int:
-        return sum(len(entry.queue) for entry in self.entries)
+        return sum(len(entry.elements) for entry in self.entries)