summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/wheel_timer.py39
1 files changed, 26 insertions, 13 deletions
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index e108adc460..177e198e7e 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -11,17 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Generic, List, TypeVar
+import logging
+from typing import Generic, Hashable, List, Set, TypeVar
 
-T = TypeVar("T")
+import attr
 
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=Hashable)
 
-class _Entry(Generic[T]):
-    __slots__ = ["end_key", "queue"]
 
-    def __init__(self, end_key: int) -> None:
-        self.end_key: int = end_key
-        self.queue: List[T] = []
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _Entry(Generic[T]):
+    end_key: int
+    elements: Set[T] = attr.Factory(set)
 
 
 class WheelTimer(Generic[T]):
@@ -48,17 +51,27 @@ class WheelTimer(Generic[T]):
             then: When to return the object strictly after.
         """
         then_key = int(then / self.bucket_size) + 1
+        now_key = int(now / self.bucket_size)
 
         if self.entries:
             min_key = self.entries[0].end_key
             max_key = self.entries[-1].end_key
 
+            if min_key < now_key - 10:
+                # If we have ten buckets that are due and still nothing has
+                # called `fetch()` then we likely have a bug that is causing a
+                # memory leak.
+                logger.warning(
+                    "Inserting into a wheel timer that hasn't been read from recently. Item: %s",
+                    obj,
+                )
+
             if then_key <= max_key:
                 # The max here is to protect against inserts for times in the past
-                self.entries[max(min_key, then_key) - min_key].queue.append(obj)
+                self.entries[max(min_key, then_key) - min_key].elements.add(obj)
                 return
 
-        next_key = int(now / self.bucket_size) + 1
+        next_key = now_key + 1
         if self.entries:
             last_key = self.entries[-1].end_key
         else:
@@ -71,7 +84,7 @@ class WheelTimer(Generic[T]):
         # to insert. This ensures there are no gaps.
         self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1))
 
-        self.entries[-1].queue.append(obj)
+        self.entries[-1].elements.add(obj)
 
     def fetch(self, now: int) -> List[T]:
         """Fetch any objects that have timed out
@@ -84,11 +97,11 @@ class WheelTimer(Generic[T]):
         """
         now_key = int(now / self.bucket_size)
 
-        ret = []
+        ret: List[T] = []
         while self.entries and self.entries[0].end_key <= now_key:
-            ret.extend(self.entries.pop(0).queue)
+            ret.extend(self.entries.pop(0).elements)
 
         return ret
 
     def __len__(self) -> int:
-        return sum(len(entry.queue) for entry in self.entries)
+        return sum(len(entry.elements) for entry in self.entries)