summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/controllers/persist_events.py34
-rw-r--r--synapse/storage/databases/main/cache.py2
-rw-r--r--tests/storage/databases/main/test_events_worker.py120
3 files changed, 106 insertions, 50 deletions
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index bde7a6648a..de4c8c15f9 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -43,7 +43,7 @@ from prometheus_client import Counter, Histogram
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.events import EventBase
+from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.logging.tracing import (
@@ -435,6 +435,22 @@ class EventsPersistenceStorageController:
             else:
                 events.append(event)
 
+                # We expect events to be persisted by this point and this makes
+                # mypy happy about `stream_ordering` not being optional below
+                assert event.internal_metadata.stream_ordering
+                # Invalidate related caches after we persist a new event
+                relation = relation_from_event(event)
+                self.main_store._invalidate_caches_for_event(
+                    stream_ordering=event.internal_metadata.stream_ordering,
+                    event_id=event.event_id,
+                    room_id=event.room_id,
+                    etype=event.type,
+                    state_key=event.state_key if hasattr(event, "state_key") else None,
+                    redacts=event.redacts,
+                    relates_to=relation.parent_id if relation else None,
+                    backfilled=backfilled,
+                )
+
         return (
             events,
             self.main_store.get_room_max_token(),
@@ -467,6 +483,22 @@ class EventsPersistenceStorageController:
         replaced_event = replaced_events.get(event.event_id)
         if replaced_event:
             event = await self.main_store.get_event(replaced_event)
+        else:
+            # We expect events to be persisted by this point and this makes
+            # mypy happy about `stream_ordering` not being optional below
+            assert event.internal_metadata.stream_ordering
+            # Invalidate related caches after we persist a new event
+            relation = relation_from_event(event)
+            self.main_store._invalidate_caches_for_event(
+                stream_ordering=event.internal_metadata.stream_ordering,
+                event_id=event.event_id,
+                room_id=event.room_id,
+                etype=event.type,
+                state_key=event.state_key if hasattr(event, "state_key") else None,
+                redacts=event.redacts,
+                relates_to=relation.parent_id if relation else None,
+                backfilled=backfilled,
+            )
 
         event_stream_id = event.internal_metadata.stream_ordering
         # stream ordering should have been assigned by now
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 12e9a42382..99b0b5d09f 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -223,7 +223,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         # process triggering the invalidation is responsible for clearing any external
         # cached objects.
         self._invalidate_local_get_event_cache(event_id)
-        self.have_seen_event.invalidate((room_id, event_id))
+        self.have_seen_event.invalidate(((room_id, event_id),))
 
         self.get_latest_event_ids_in_room.invalidate((room_id,))
 
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 67401272ac..158ad1f439 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -35,66 +35,45 @@ from synapse.util import Clock
 from synapse.util.async_helpers import yieldable_gather_results
 
 from tests import unittest
+from tests.test_utils.event_injection import create_event, inject_event
 
 
 class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
     def prepare(self, reactor, clock, hs):
+        self.hs = hs
         self.store: EventsWorkerStore = hs.get_datastores().main
 
-        # insert some test data
-        for rid in ("room1", "room2"):
-            self.get_success(
-                self.store.db_pool.simple_insert(
-                    "rooms",
-                    {"room_id": rid, "room_version": 4},
-                )
-            )
+        self.user = self.register_user("user", "pass")
+        self.token = self.login(self.user, "pass")
+        self.room_id = self.helper.create_room_as(self.user, tok=self.token)
 
         self.event_ids: List[str] = []
-        for idx, rid in enumerate(
-            (
-                "room1",
-                "room1",
-                "room1",
-                "room2",
-            )
-        ):
-            event_json = {"type": f"test {idx}", "room_id": rid}
-            event = make_event_from_dict(event_json, room_version=RoomVersions.V4)
-            event_id = event.event_id
-
-            self.get_success(
-                self.store.db_pool.simple_insert(
-                    "events",
-                    {
-                        "event_id": event_id,
-                        "room_id": rid,
-                        "topological_ordering": idx,
-                        "stream_ordering": idx,
-                        "type": event.type,
-                        "processed": True,
-                        "outlier": False,
-                    },
-                )
-            )
-            self.get_success(
-                self.store.db_pool.simple_insert(
-                    "event_json",
-                    {
-                        "event_id": event_id,
-                        "room_id": rid,
-                        "json": json.dumps(event_json),
-                        "internal_metadata": "{}",
-                        "format_version": 3,
-                    },
+        for i in range(3):
+            event = self.get_success(
+                inject_event(
+                    hs,
+                    room_version=RoomVersions.V7.identifier,
+                    room_id=self.room_id,
+                    sender=self.user,
+                    type="test_event_type",
+                    content={"body": f"foobarbaz{i}"},
                 )
             )
-            self.event_ids.append(event_id)
+
+            self.event_ids.append(event.event_id)
 
     def test_simple(self):
         with LoggingContext(name="test") as ctx:
             res = self.get_success(
-                self.store.have_seen_events("room1", [self.event_ids[0], "event19"])
+                self.store.have_seen_events(
+                    self.room_id, [self.event_ids[0], "eventdoesnotexist"]
+                )
             )
             self.assertEqual(res, {self.event_ids[0]})
 
@@ -104,7 +83,9 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
         # a second lookup of the same events should cause no queries
         with LoggingContext(name="test") as ctx:
             res = self.get_success(
-                self.store.have_seen_events("room1", [self.event_ids[0], "event19"])
+                self.store.have_seen_events(
+                    self.room_id, [self.event_ids[0], "eventdoesnotexist"]
+                )
             )
             self.assertEqual(res, {self.event_ids[0]})
             self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
@@ -116,11 +97,54 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
         # looking it up should now cause no db hits
         with LoggingContext(name="test") as ctx:
             res = self.get_success(
-                self.store.have_seen_events("room1", [self.event_ids[0]])
+                self.store.have_seen_events(self.room_id, [self.event_ids[0]])
             )
             self.assertEqual(res, {self.event_ids[0]})
             self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
 
+    def test_persisting_event_invalidates_cache(self):
+        event, event_context = self.get_success(
+            create_event(
+                self.hs,
+                room_id=self.room_id,
+                sender=self.user,
+                type="test_event_type",
+                content={"body": "garply"},
+            )
+        )
+
+        with LoggingContext(name="test") as ctx:
+            # First, check `have_seen_event` for an event we have not seen yet
+            # to prime the cache with a `false` value.
+            res = self.get_success(
+                self.store.have_seen_events(event.room_id, [event.event_id])
+            )
+            self.assertEqual(res, set())
+
+            # That should result in a single db query to lookup
+            self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
+        # Persist the event which should invalidate or prefill the
+        # `have_seen_event` cache so we don't return stale values.
+        persistence = self.hs.get_storage_controllers().persistence
+        self.get_success(
+            persistence.persist_event(
+                event,
+                event_context,
+            )
+        )
+
+        with LoggingContext(name="test") as ctx:
+            # Check `have_seen_event` again and we should see the updated fact
+            # that we have now seen the event after persisting it.
+            res = self.get_success(
+                self.store.have_seen_events(event.room_id, [event.event_id])
+            )
+            self.assertEqual(res, {event.event_id})
+
+            # That should result in a single db query to lookup
+            self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
 
 class EventCacheTestCase(unittest.HomeserverTestCase):
     """Test that the various layers of event cache works."""