diff --git a/changelog.d/17850.bugfix b/changelog.d/17850.bugfix
new file mode 100644
index 0000000000..8ea99c4ef9
--- /dev/null
+++ b/changelog.d/17850.bugfix
@@ -0,0 +1 @@
+Fix bug when some presence and typing timeouts can expire early.
\ No newline at end of file
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 44b109bdfd..95eb1d7185 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -47,7 +47,6 @@ class WheelTimer(Generic[T]):
"""
self.bucket_size: int = bucket_size
self.entries: List[_Entry[T]] = []
- self.current_tick: int = 0
def insert(self, now: int, obj: T, then: int) -> None:
"""Inserts object into timer.
@@ -78,11 +77,10 @@ class WheelTimer(Generic[T]):
self.entries[max(min_key, then_key) - min_key].elements.add(obj)
return
- next_key = now_key + 1
if self.entries:
- last_key = self.entries[-1].end_key
+ last_key = self.entries[-1].end_key + 1
else:
- last_key = next_key
+ last_key = now_key + 1
# Handle the case when `then` is in the past and `entries` is empty.
then_key = max(last_key, then_key)
diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py
index 173a7cfaec..6fa575a18e 100644
--- a/tests/util/test_wheel_timer.py
+++ b/tests/util/test_wheel_timer.py
@@ -28,53 +28,55 @@ class WheelTimerTestCase(unittest.TestCase):
def test_single_insert_fetch(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
- obj = object()
- wheel.insert(100, obj, 150)
+ wheel.insert(100, "1", 150)
self.assertListEqual(wheel.fetch(101), [])
self.assertListEqual(wheel.fetch(110), [])
self.assertListEqual(wheel.fetch(120), [])
self.assertListEqual(wheel.fetch(130), [])
self.assertListEqual(wheel.fetch(149), [])
- self.assertListEqual(wheel.fetch(156), [obj])
+ self.assertListEqual(wheel.fetch(156), ["1"])
self.assertListEqual(wheel.fetch(170), [])
def test_multi_insert(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
- obj1 = object()
- obj2 = object()
- obj3 = object()
- wheel.insert(100, obj1, 150)
- wheel.insert(105, obj2, 130)
- wheel.insert(106, obj3, 160)
+ wheel.insert(100, "1", 150)
+ wheel.insert(105, "2", 130)
+ wheel.insert(106, "3", 160)
self.assertListEqual(wheel.fetch(110), [])
- self.assertListEqual(wheel.fetch(135), [obj2])
+ self.assertListEqual(wheel.fetch(135), ["2"])
self.assertListEqual(wheel.fetch(149), [])
- self.assertListEqual(wheel.fetch(158), [obj1])
+ self.assertListEqual(wheel.fetch(158), ["1"])
self.assertListEqual(wheel.fetch(160), [])
- self.assertListEqual(wheel.fetch(200), [obj3])
+ self.assertListEqual(wheel.fetch(200), ["3"])
self.assertListEqual(wheel.fetch(210), [])
def test_insert_past(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
- obj = object()
- wheel.insert(100, obj, 50)
- self.assertListEqual(wheel.fetch(120), [obj])
+ wheel.insert(100, "1", 50)
+ self.assertListEqual(wheel.fetch(120), ["1"])
def test_insert_past_multi(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
- obj1 = object()
- obj2 = object()
- obj3 = object()
- wheel.insert(100, obj1, 150)
- wheel.insert(100, obj2, 140)
- wheel.insert(100, obj3, 50)
- self.assertListEqual(wheel.fetch(110), [obj3])
+ wheel.insert(100, "1", 150)
+ wheel.insert(100, "2", 140)
+ wheel.insert(100, "3", 50)
+ self.assertListEqual(wheel.fetch(110), ["3"])
self.assertListEqual(wheel.fetch(120), [])
- self.assertListEqual(wheel.fetch(147), [obj2])
- self.assertListEqual(wheel.fetch(200), [obj1])
+ self.assertListEqual(wheel.fetch(147), ["2"])
+ self.assertListEqual(wheel.fetch(200), ["1"])
self.assertListEqual(wheel.fetch(240), [])
+
+ def test_multi_insert_then_past(self) -> None:
+ wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
+
+ wheel.insert(100, "1", 150)
+ wheel.insert(100, "2", 160)
+ wheel.insert(100, "3", 155)
+
+ self.assertListEqual(wheel.fetch(110), [])
+ self.assertListEqual(wheel.fetch(158), ["1"])
|