summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17130.misc1
-rw-r--r--synapse/util/caches/stream_change_cache.py20
-rw-r--r--tests/util/test_stream_change_cache.py17
3 files changed, 34 insertions, 4 deletions
diff --git a/changelog.d/17130.misc b/changelog.d/17130.misc
new file mode 100644
index 0000000000..ac20c90bde
--- /dev/null
+++ b/changelog.d/17130.misc
@@ -0,0 +1 @@
+Add optimisation to `StreamChangeCache.get_entities_changed(..)`.
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 2079ca789c..91c335f85b 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -165,7 +165,7 @@ class StreamChangeCache:
         return False
 
     def get_entities_changed(
-        self, entities: Collection[EntityType], stream_pos: int
+        self, entities: Collection[EntityType], stream_pos: int, _perf_factor: int = 1
     ) -> Union[Set[EntityType], FrozenSet[EntityType]]:
         """
         Returns the subset of the given entities that have had changes after the given position.
@@ -177,6 +177,8 @@ class StreamChangeCache:
         Args:
             entities: Entities to check for changes.
             stream_pos: The stream position to check for changes after.
+            _perf_factor: Used by unit tests to choose when to use each
+                optimisation.
 
         Return:
             A subset of entities which have changed after the given stream position.
@@ -184,6 +186,22 @@ class StreamChangeCache:
             This will be all entities if the given stream position is at or earlier
             than the earliest known stream position.
         """
+        if not self._cache or stream_pos <= self._earliest_known_stream_pos:
+            self.metrics.inc_misses()
+            return set(entities)
+
+        # If there have been tonnes of changes compared with the number of
+        # entities, it is faster to check each entities stream ordering
+        # one-by-one.
+        max_stream_pos, _ = self._cache.peekitem()
+        if max_stream_pos - stream_pos > _perf_factor * len(entities):
+            self.metrics.inc_hits()
+            return {
+                entity
+                for entity in entities
+                if self._entity_to_key.get(entity, -1) > stream_pos
+            }
+
         cache_result = self.get_all_entities_changed(stream_pos)
         if cache_result.hit:
             # We now do an intersection, trying to do so in the most efficient
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 3df053493b..5d38718a50 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -1,3 +1,5 @@
+from parameterized import parameterized
+
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from tests import unittest
@@ -161,7 +163,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         self.assertFalse(cache.has_any_entity_changed(2))
         self.assertFalse(cache.has_any_entity_changed(3))
 
-    def test_get_entities_changed(self) -> None:
+    @parameterized.expand([(0,), (1000000000,)])
+    def test_get_entities_changed(self, perf_factor: int) -> None:
         """
         StreamChangeCache.get_entities_changed will return the entities in the
         given list that have changed since the provided stream ID.  If the
@@ -178,7 +181,9 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         # get the ones after that point.
         self.assertEqual(
             cache.get_entities_changed(
-                ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2
+                ["user@foo.com", "bar@baz.net", "user@elsewhere.org"],
+                stream_pos=2,
+                _perf_factor=perf_factor,
             ),
             {"bar@baz.net", "user@elsewhere.org"},
         )
@@ -195,6 +200,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
                     "not@here.website",
                 ],
                 stream_pos=2,
+                _perf_factor=perf_factor,
             ),
             {"bar@baz.net", "user@elsewhere.org"},
         )
@@ -210,6 +216,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
                     "not@here.website",
                 ],
                 stream_pos=0,
+                _perf_factor=perf_factor,
             ),
             {"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"},
         )
@@ -217,7 +224,11 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         # Query a subset of the entries mid-way through the stream. We should
         # only get back the subset.
         self.assertEqual(
-            cache.get_entities_changed(["bar@baz.net"], stream_pos=2),
+            cache.get_entities_changed(
+                ["bar@baz.net"],
+                stream_pos=2,
+                _perf_factor=perf_factor,
+            ),
             {"bar@baz.net"},
         )