summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13863.bugfix1
-rw-r--r--synapse/storage/databases/main/events_worker.py40
-rw-r--r--synapse/util/caches/descriptors.py6
-rw-r--r--tests/storage/databases/main/test_events_worker.py152
-rw-r--r--tests/util/caches/test_descriptors.py33
5 files changed, 165 insertions, 67 deletions
diff --git a/changelog.d/13863.bugfix b/changelog.d/13863.bugfix
new file mode 100644
index 0000000000..74264a4fab
--- /dev/null
+++ b/changelog.d/13863.bugfix
@@ -0,0 +1 @@
+Fix `have_seen_event` cache not being invalidated after we persist an event which causes inefficiency effects like extra `/state` federation calls.
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 52914febf9..7cdc9fe98f 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1474,32 +1474,38 @@ class EventsWorkerStore(SQLBaseStore):
         # the batches as big as possible.
 
         results: Set[str] = set()
-        for chunk in batch_iter(event_ids, 500):
-            r = await self._have_seen_events_dict(
-                [(room_id, event_id) for event_id in chunk]
+        for event_ids_chunk in batch_iter(event_ids, 500):
+            events_seen_dict = await self._have_seen_events_dict(
+                room_id, event_ids_chunk
+            )
+            results.update(
+                eid for (eid, have_event) in events_seen_dict.items() if have_event
             )
-            results.update(eid for ((_rid, eid), have_event) in r.items() if have_event)
 
         return results
 
-    @cachedList(cached_method_name="have_seen_event", list_name="keys")
+    @cachedList(cached_method_name="have_seen_event", list_name="event_ids")
     async def _have_seen_events_dict(
-        self, keys: Collection[Tuple[str, str]]
-    ) -> Dict[Tuple[str, str], bool]:
+        self,
+        room_id: str,
+        event_ids: Collection[str],
+    ) -> Dict[str, bool]:
         """Helper for have_seen_events
 
         Returns:
-             a dict {(room_id, event_id)-> bool}
+             a dict {event_id -> bool}
         """
         # if the event cache contains the event, obviously we've seen it.
 
         cache_results = {
-            (rid, eid)
-            for (rid, eid) in keys
-            if await self._get_event_cache.contains((eid,))
+            event_id
+            for event_id in event_ids
+            if await self._get_event_cache.contains((event_id,))
         }
         results = dict.fromkeys(cache_results, True)
-        remaining = [k for k in keys if k not in cache_results]
+        remaining = [
+            event_id for event_id in event_ids if event_id not in cache_results
+        ]
         if not remaining:
             return results
 
@@ -1511,23 +1517,21 @@ class EventsWorkerStore(SQLBaseStore):
 
             sql = "SELECT event_id FROM events AS e WHERE "
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining]
+                txn.database_engine, "e.event_id", remaining
             )
             txn.execute(sql + clause, args)
             found_events = {eid for eid, in txn}
 
             # ... and then we can update the results for each key
-            results.update(
-                {(rid, eid): (eid in found_events) for (rid, eid) in remaining}
-            )
+            results.update({eid: (eid in found_events) for eid in remaining})
 
         await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
         return results
 
     @cached(max_entries=100000, tree=True)
     async def have_seen_event(self, room_id: str, event_id: str) -> bool:
-        res = await self._have_seen_events_dict(((room_id, event_id),))
-        return res[(room_id, event_id)]
+        res = await self._have_seen_events_dict(room_id, [event_id])
+        return res[event_id]
 
     def _get_current_state_event_counts_txn(
         self, txn: LoggingTransaction, room_id: str
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 3909f1caea..0391966462 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -431,6 +431,12 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
         cache: DeferredCache[CacheKey, Any] = cached_method.cache
         num_args = cached_method.num_args
 
+        if num_args != self.num_args:
+            raise Exception(
+                "Number of args (%s) does not match underlying cache_method_name=%s (%s)."
+                % (self.num_args, self.cached_method_name, num_args)
+            )
+
         @functools.wraps(self.orig)
         def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]":
             # If we're passed a cache_context then we'll want to call its
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 67401272ac..32a798d74b 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -35,66 +35,45 @@ from synapse.util import Clock
 from synapse.util.async_helpers import yieldable_gather_results
 
 from tests import unittest
+from tests.test_utils.event_injection import create_event, inject_event
 
 
 class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
     def prepare(self, reactor, clock, hs):
+        self.hs = hs
         self.store: EventsWorkerStore = hs.get_datastores().main
 
-        # insert some test data
-        for rid in ("room1", "room2"):
-            self.get_success(
-                self.store.db_pool.simple_insert(
-                    "rooms",
-                    {"room_id": rid, "room_version": 4},
-                )
-            )
+        self.user = self.register_user("user", "pass")
+        self.token = self.login(self.user, "pass")
+        self.room_id = self.helper.create_room_as(self.user, tok=self.token)
 
         self.event_ids: List[str] = []
-        for idx, rid in enumerate(
-            (
-                "room1",
-                "room1",
-                "room1",
-                "room2",
-            )
-        ):
-            event_json = {"type": f"test {idx}", "room_id": rid}
-            event = make_event_from_dict(event_json, room_version=RoomVersions.V4)
-            event_id = event.event_id
-
-            self.get_success(
-                self.store.db_pool.simple_insert(
-                    "events",
-                    {
-                        "event_id": event_id,
-                        "room_id": rid,
-                        "topological_ordering": idx,
-                        "stream_ordering": idx,
-                        "type": event.type,
-                        "processed": True,
-                        "outlier": False,
-                    },
+        for i in range(3):
+            event = self.get_success(
+                inject_event(
+                    hs,
+                    room_version=RoomVersions.V7.identifier,
+                    room_id=self.room_id,
+                    sender=self.user,
+                    type="test_event_type",
+                    content={"body": f"foobarbaz{i}"},
                 )
             )
-            self.get_success(
-                self.store.db_pool.simple_insert(
-                    "event_json",
-                    {
-                        "event_id": event_id,
-                        "room_id": rid,
-                        "json": json.dumps(event_json),
-                        "internal_metadata": "{}",
-                        "format_version": 3,
-                    },
-                )
-            )
-            self.event_ids.append(event_id)
+
+            self.event_ids.append(event.event_id)
 
     def test_simple(self):
         with LoggingContext(name="test") as ctx:
             res = self.get_success(
-                self.store.have_seen_events("room1", [self.event_ids[0], "event19"])
+                self.store.have_seen_events(
+                    self.room_id, [self.event_ids[0], "eventdoesnotexist"]
+                )
             )
             self.assertEqual(res, {self.event_ids[0]})
 
@@ -104,7 +83,9 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
         # a second lookup of the same events should cause no queries
         with LoggingContext(name="test") as ctx:
             res = self.get_success(
-                self.store.have_seen_events("room1", [self.event_ids[0], "event19"])
+                self.store.have_seen_events(
+                    self.room_id, [self.event_ids[0], "eventdoesnotexist"]
+                )
             )
             self.assertEqual(res, {self.event_ids[0]})
             self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
@@ -116,11 +97,86 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
         # looking it up should now cause no db hits
         with LoggingContext(name="test") as ctx:
             res = self.get_success(
-                self.store.have_seen_events("room1", [self.event_ids[0]])
+                self.store.have_seen_events(self.room_id, [self.event_ids[0]])
             )
             self.assertEqual(res, {self.event_ids[0]})
             self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
 
+    def test_persisting_event_invalidates_cache(self):
+        """
+        Test to make sure that the `have_seen_event` cache
+        is invalidated after we persist an event and returns
+        the updated value.
+        """
+        event, event_context = self.get_success(
+            create_event(
+                self.hs,
+                room_id=self.room_id,
+                sender=self.user,
+                type="test_event_type",
+                content={"body": "garply"},
+            )
+        )
+
+        with LoggingContext(name="test") as ctx:
+            # First, check `have_seen_event` for an event we have not seen yet
+            # to prime the cache with a `false` value.
+            res = self.get_success(
+                self.store.have_seen_events(event.room_id, [event.event_id])
+            )
+            self.assertEqual(res, set())
+
+            # That should result in a single db query to lookup
+            self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
+        # Persist the event which should invalidate or prefill the
+        # `have_seen_event` cache so we don't return stale values.
+        persistence = self.hs.get_storage_controllers().persistence
+        self.get_success(
+            persistence.persist_event(
+                event,
+                event_context,
+            )
+        )
+
+        with LoggingContext(name="test") as ctx:
+            # Check `have_seen_event` again and we should see the updated fact
+            # that we have now seen the event after persisting it.
+            res = self.get_success(
+                self.store.have_seen_events(event.room_id, [event.event_id])
+            )
+            self.assertEqual(res, {event.event_id})
+
+            # That should result in a single db query to lookup
+            self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
+    def test_invalidate_cache_by_room_id(self):
+        """
+        Test to make sure that all events associated with the given `(room_id,)`
+        are invalidated in the `have_seen_event` cache.
+        """
+        with LoggingContext(name="test") as ctx:
+            # Prime the cache with some values
+            res = self.get_success(
+                self.store.have_seen_events(self.room_id, self.event_ids)
+            )
+            self.assertEqual(res, set(self.event_ids))
+
+            # That should result in a single db query to lookup
+            self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
+        # Clear the cache with any events associated with the `room_id`
+        self.store.have_seen_event.invalidate((self.room_id,))
+
+        with LoggingContext(name="test") as ctx:
+            res = self.get_success(
+                self.store.have_seen_events(self.room_id, self.event_ids)
+            )
+            self.assertEqual(res, set(self.event_ids))
+
+            # Since we cleared the cache, it should result in another db query to lookup
+            self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
 
 class EventCacheTestCase(unittest.HomeserverTestCase):
     """Test that the various layers of event cache works."""
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 48e616ac74..90861fe522 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Set
+from typing import Iterable, Set, Tuple
 from unittest import mock
 
 from twisted.internet import defer, reactor
@@ -1008,3 +1008,34 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             obj.inner_context_was_finished, "Tried to restart a finished logcontext"
         )
         self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+    def test_num_args_mismatch(self):
+        """
+        Make sure someone does not accidentally use @cachedList on a method with
+        a mismatch in the number args to the underlying single cache method.
+        """
+
+        class Cls:
+            @descriptors.cached(tree=True)
+            def fn(self, room_id, event_id):
+                pass
+
+            # This is wrong ❌. `@cachedList` expects to be given the same number
+            # of arguments as the underlying cached function, just with one of
+            # the arguments being an iterable
+            @descriptors.cachedList(cached_method_name="fn", list_name="keys")
+            def list_fn(self, keys: Iterable[Tuple[str, str]]):
+                pass
+
+            # Corrected syntax ✅
+            #
+            # @cachedList(cached_method_name="fn", list_name="event_ids")
+            # async def list_fn(
+            #     self, room_id: str, event_ids: Collection[str],
+            # )
+
+        obj = Cls()
+
+        # Make sure this raises an error about the arg mismatch
+        with self.assertRaises(Exception):
+            obj.list_fn([("foo", "bar")])