summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/14435.bugfix1
-rw-r--r--poetry.lock2
-rw-r--r--pyproject.toml3
-rw-r--r--synapse/util/caches/stream_change_cache.py142
-rw-r--r--tests/util/test_stream_change_cache.py38
5 files changed, 133 insertions, 53 deletions
diff --git a/changelog.d/14435.bugfix b/changelog.d/14435.bugfix
new file mode 100644
index 0000000000..149ee99dd7
--- /dev/null
+++ b/changelog.d/14435.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances.
diff --git a/poetry.lock b/poetry.lock
index 8c63134578..90b363a548 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1639,7 +1639,7 @@ url-preview = ["lxml"]
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.7.1"
-content-hash = "27811bd21d56ceeb0f68ded5a00375efcd1a004928f0736f5b02927ce8594cb0"
+content-hash = "8c44ceeb9df5c3ab43040400e0a6b895de49417e61293a1ba027640b34f03263"
 
 [metadata.files]
 attrs = [
diff --git a/pyproject.toml b/pyproject.toml
index af5ce2aa03..1368e4e688 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -141,7 +141,8 @@ pyasn1 = ">=0.1.9"
 pyasn1-modules = ">=0.0.7"
 bcrypt = ">=3.1.7"
 Pillow = ">=5.4.0"
-sortedcontainers = ">=1.4.4"
+# We use SortedDict.peekitem(), which was added in sortedcontainers 1.5.2.
+sortedcontainers = ">=1.5.2"
 pymacaroons = ">=0.13.0"
 msgpack = ">=0.5.2"
 phonenumbers = ">=8.2.0"
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 666f4b6895..042de8d7c8 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -27,13 +27,17 @@ EntityType = str
 
 
 class StreamChangeCache:
-    """Keeps track of the stream positions of the latest change in a set of entities.
+    """
+    Keeps track of the stream positions of the latest change in a set of entities.
+
+    The entity will is typically a room ID or user ID, but can be any string.
 
-    Typically the entity will be a room or user id.
+    Can be queried for whether a specific entity has changed after a stream position
+    or for a list of changed entities after a stream position. See the individual
+    methods for more information.
 
-    Given a list of entities and a stream position, it will give a subset of
-    entities that may have changed since that position. If position key is too
-    old then the cache will simply return all given entities.
+    Only tracks to a maximum cache size, any position earlier than the earliest
+    known stream position must be treated as unknown.
     """
 
     def __init__(
@@ -45,16 +49,20 @@ class StreamChangeCache:
     ) -> None:
         self._original_max_size: int = max_size
         self._max_size = math.floor(max_size)
-        self._entity_to_key: Dict[EntityType, int] = {}
 
-        # map from stream id to the a set of entities which changed at that stream id.
+        # map from stream id to the set of entities which changed at that stream id.
         self._cache: SortedDict[int, Set[EntityType]] = SortedDict()
+        # map from entity to the stream ID of the latest change for that entity.
+        #
+        # Must be kept in sync with _cache.
+        self._entity_to_key: Dict[EntityType, int] = {}
 
         # the earliest stream_pos for which we can reliably answer
         # get_all_entities_changed. In other words, one less than the earliest
         # stream_pos for which we know _cache is valid.
         #
         self._earliest_known_stream_pos = current_stream_pos
+
         self.name = name
         self.metrics = caches.register_cache(
             "cache", self.name, self._cache, resize_callback=self.set_cache_factor
@@ -82,22 +90,46 @@ class StreamChangeCache:
         return False
 
     def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool:
-        """Returns True if the entity may have been updated since stream_pos"""
+        """
+        Returns True if the entity may have been updated after stream_pos.
+
+        Args:
+            entity: The entity to check for changes.
+            stream_pos: The stream position to check for changes after.
+
+        Return:
+            True if the entity may have been updated, this happens if:
+                * The given stream position is at or earlier than the earliest
+                  known stream position.
+                * The given stream position is earlier than the latest change for
+                  the entity.
+
+            False otherwise:
+                * The entity is unknown.
+                * The given stream position is at or later than the latest change
+                  for the entity.
+        """
         assert isinstance(stream_pos, int)
 
-        if stream_pos < self._earliest_known_stream_pos:
+        # _cache is not valid at or before the earliest known stream position, so
+        # return that the entity has changed.
+        if stream_pos <= self._earliest_known_stream_pos:
             self.metrics.inc_misses()
             return True
 
+        # If the entity is unknown, it hasn't changed.
         latest_entity_change_pos = self._entity_to_key.get(entity, None)
         if latest_entity_change_pos is None:
             self.metrics.inc_hits()
             return False
 
+        # This is a known entity, return true if the stream position is earlier
+        # than the last change.
         if stream_pos < latest_entity_change_pos:
             self.metrics.inc_misses()
             return True
 
+        # Otherwise, the stream position is after the latest change: return false.
         self.metrics.inc_hits()
         return False
 
@@ -105,15 +137,27 @@ class StreamChangeCache:
         self, entities: Collection[EntityType], stream_pos: int
     ) -> Union[Set[EntityType], FrozenSet[EntityType]]:
         """
-        Returns subset of entities that have had new things since the given
-        position.  Entities unknown to the cache will be returned.  If the
-        position is too old it will just return the given list.
+        Returns the subset of the given entities that have had changes after the given position.
+
+        Entities unknown to the cache will be returned.
+
+        If the position is too old it will just return the given list.
+
+        Args:
+            entities: Entities to check for changes.
+            stream_pos: The stream position to check for changes after.
+
+        Return:
+            A subset of entities which have changed after the given stream position.
+
+            This will be all entities if the given stream position is at or earlier
+            than the earliest known stream position.
         """
         changed_entities = self.get_all_entities_changed(stream_pos)
         if changed_entities is not None:
             # We now do an intersection, trying to do so in the most efficient
             # way possible (some of these sets are *large*). First check in the
-            # given iterable is already set that we can reuse, otherwise we
+            # given iterable is already a set that we can reuse, otherwise we
             # create a set of the *smallest* of the two iterables and call
             # `intersection(..)` on it (this can be twice as fast as the reverse).
             if isinstance(entities, (set, frozenset)):
@@ -130,29 +174,57 @@ class StreamChangeCache:
         return result
 
     def has_any_entity_changed(self, stream_pos: int) -> bool:
-        """Returns if any entity has changed"""
-        assert type(stream_pos) is int
+        """
+        Returns true if any entity has changed after the given stream position.
+
+        Args:
+            stream_pos: The stream position to check for changes after.
+
+        Return:
+            True if any entity has changed after the given stream position or
+            if the given stream position is at or earlier than the earliest
+            known stream position.
+
+            False otherwise.
+        """
+        assert isinstance(stream_pos, int)
 
         if not self._cache:
             # If the cache is empty, nothing can have changed.
             return False
 
-        if stream_pos >= self._earliest_known_stream_pos:
-            self.metrics.inc_hits()
-            return self._cache.bisect_right(stream_pos) < len(self._cache)
-        else:
+        # _cache is not valid at or before the earliest known stream position, so
+        # return that an entity has changed.
+        if stream_pos <= self._earliest_known_stream_pos:
             self.metrics.inc_misses()
             return True
 
+        self.metrics.inc_hits()
+        return stream_pos < self._cache.peekitem()[0]
+
     def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
-        """Returns all entities that have had new things since the given
-        position. If the position is too old it will return None.
+        """
+        Returns all entities that have had changes after the given position.
+
+        If the stream change cache does not go far enough back, i.e. the position
+        is too old, it will return None.
 
         Returns the entities in the order that they were changed.
+
+        Args:
+            stream_pos: The stream position to check for changes after.
+
+        Return:
+            Entities which have changed after the given stream position.
+
+            None if the given stream position is at or earlier than the earliest
+            known stream position.
         """
-        assert type(stream_pos) is int
+        assert isinstance(stream_pos, int)
 
-        if stream_pos < self._earliest_known_stream_pos:
+        # _cache is not valid at or before the earliest known stream position, so
+        # return None to mark that it is unknown if an entity has changed.
+        if stream_pos <= self._earliest_known_stream_pos:
             return None
 
         changed_entities: List[EntityType] = []
@@ -162,11 +234,17 @@ class StreamChangeCache:
         return changed_entities
 
     def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
-        """Informs the cache that the entity has been changed at the given
-        position.
         """
-        assert type(stream_pos) is int
+        Informs the cache that the entity has been changed at the given position.
+
+        Args:
+            entity: The entity to mark as changed.
+            stream_pos: The stream position to update the entity to.
+        """
+        assert isinstance(stream_pos, int)
 
+        # For a change before _cache is valid (e.g. at or before the earliest known
+        # stream position) there's nothing to do.
         if stream_pos <= self._earliest_known_stream_pos:
             return
 
@@ -189,6 +267,11 @@ class StreamChangeCache:
         self._evict()
 
     def _evict(self) -> None:
+        """
+        Ensure the cache has not exceeded the maximum size.
+
+        Evicts entries until it is at the maximum size.
+        """
         # if the cache is too big, remove entries
         while len(self._cache) > self._max_size:
             k, r = self._cache.popitem(0)
@@ -199,5 +282,12 @@ class StreamChangeCache:
     def get_max_pos_of_last_change(self, entity: EntityType) -> int:
         """Returns an upper bound of the stream id of the last change to an
         entity.
+
+        Args:
+            entity: The entity to check.
+
+        Return:
+            The stream position of the latest change for the given entity or
+            the earliest known stream position if the entitiy is unknown.
         """
         return self._entity_to_key.get(entity, self._earliest_known_stream_pos)
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 1b0fa52ad1..a29cc872f9 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -51,6 +51,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         # return True, whether it's a known entity or not.
         self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
         self.assertTrue(cache.has_entity_changed("not@here.website", 0))
+        self.assertTrue(cache.has_entity_changed("user@foo.com", 3))
+        self.assertTrue(cache.has_entity_changed("not@here.website", 3))
 
     def test_entity_has_changed_pops_off_start(self) -> None:
         """
@@ -65,15 +67,14 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
 
         # The cache is at the max size, 2
         self.assertEqual(len(cache._cache), 2)
+        # The cache's earliest known position is 2.
+        self.assertEqual(cache._earliest_known_stream_pos, 2)
 
         # The oldest item has been popped off
         self.assertTrue("user@foo.com" not in cache._entity_to_key)
 
-        self.assertEqual(
-            cache.get_all_entities_changed(2),
-            ["bar@baz.net", "user@elsewhere.org"],
-        )
-        self.assertIsNone(cache.get_all_entities_changed(1))
+        self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
+        self.assertIsNone(cache.get_all_entities_changed(2))
 
         # If we update an existing entity, it keeps the two existing entities
         cache.entity_has_changed("bar@baz.net", 5)
@@ -81,10 +82,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
             {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
         )
         self.assertEqual(
-            cache.get_all_entities_changed(2),
+            cache.get_all_entities_changed(3),
             ["user@elsewhere.org", "bar@baz.net"],
         )
-        self.assertIsNone(cache.get_all_entities_changed(1))
+        self.assertIsNone(cache.get_all_entities_changed(2))
 
     def test_get_all_entities_changed(self) -> None:
         """
@@ -99,28 +100,15 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         cache.entity_has_changed("anotheruser@foo.com", 3)
         cache.entity_has_changed("user@elsewhere.org", 4)
 
-        r = cache.get_all_entities_changed(1)
+        r = cache.get_all_entities_changed(2)
 
-        # either of these are valid
-        ok1 = [
-            "user@foo.com",
-            "bar@baz.net",
-            "anotheruser@foo.com",
-            "user@elsewhere.org",
-        ]
-        ok2 = [
-            "user@foo.com",
-            "anotheruser@foo.com",
-            "bar@baz.net",
-            "user@elsewhere.org",
-        ]
+        # Results are ordered so either of these are valid.
+        ok1 = ["bar@baz.net", "anotheruser@foo.com", "user@elsewhere.org"]
+        ok2 = ["anotheruser@foo.com", "bar@baz.net", "user@elsewhere.org"]
         self.assertTrue(r == ok1 or r == ok2)
 
-        r = cache.get_all_entities_changed(2)
-        self.assertTrue(r == ok1[1:] or r == ok2[1:])
-
         self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
-        self.assertEqual(cache.get_all_entities_changed(0), None)
+        self.assertEqual(cache.get_all_entities_changed(1), None)
 
         # ... later, things gest more updates
         cache.entity_has_changed("user@foo.com", 5)