summary refs log tree commit diff
path: root/synapse/storage/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/events.py')
-rw-r--r--synapse/storage/events.py39
1 files changed, 26 insertions, 13 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index a3c260ddc4..b2ab4b02f3 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -42,7 +42,7 @@ class EventsStore(SQLBaseStore):
             stream_ordering = self.min_token
 
         try:
-            yield self.runInteraction(
+            invalidates = yield self.runInteraction(
                 "persist_event",
                 self._persist_event_txn,
                 event=event,
@@ -52,6 +52,11 @@ class EventsStore(SQLBaseStore):
                 is_new_state=is_new_state,
                 current_state=current_state,
             )
+            for invalidated in invalidates:
+                invalidated_callback = invalidated[0]
+                invalidated_args = invalidated[1:]
+                invalidated_callback(*invalidated_args)
+
         except _RollbackButIsFineException:
             pass
 
@@ -91,9 +96,10 @@ class EventsStore(SQLBaseStore):
     def _persist_event_txn(self, txn, event, context, backfilled,
                            stream_ordering=None, is_new_state=True,
                            current_state=None):
+        invalidates = []
 
         # Remove the any existing cache entries for the event_id
-        self._invalidate_get_event_cache(event.event_id)
+        invalidates.append((self._invalidate_get_event_cache, event.event_id))
 
         if stream_ordering is None:
             with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
@@ -150,10 +156,11 @@ class EventsStore(SQLBaseStore):
         outlier = event.internal_metadata.is_outlier()
 
         if not outlier:
-            self._store_state_groups_txn(txn, event, context)
+            self._store_state_groups_txn(txn, invalidates, event, context)
 
             self._update_min_depth_for_room_txn(
                 txn,
+                invalidates,
                 event.room_id,
                 event.depth
             )
@@ -199,6 +206,7 @@ class EventsStore(SQLBaseStore):
 
         self._handle_prev_events(
             txn,
+            invalidates,
             outlier=outlier,
             event_id=event.event_id,
             prev_events=event.prev_events,
@@ -206,13 +214,13 @@ class EventsStore(SQLBaseStore):
         )
 
         if event.type == EventTypes.Member:
-            self._store_room_member_txn(txn, event)
+            self._store_room_member_txn(txn, invalidates, event)
         elif event.type == EventTypes.Name:
-            self._store_room_name_txn(txn, event)
+            self._store_room_name_txn(txn, invalidates, event)
         elif event.type == EventTypes.Topic:
-            self._store_room_topic_txn(txn, event)
+            self._store_room_topic_txn(txn, invalidates, event)
         elif event.type == EventTypes.Redaction:
-            self._store_redaction(txn, event)
+            self._store_redaction(txn, invalidates, event)
 
         event_dict = {
             k: v
@@ -281,19 +289,22 @@ class EventsStore(SQLBaseStore):
         )
 
         if context.rejected:
-            self._store_rejections_txn(txn, event.event_id, context.rejected)
+            self._store_rejections_txn(
+                txn, invalidates, event.event_id, context.rejected
+            )
 
         for hash_alg, hash_base64 in event.hashes.items():
             hash_bytes = decode_base64(hash_base64)
             self._store_event_content_hash_txn(
-                txn, event.event_id, hash_alg, hash_bytes,
+                txn, invalidates, event.event_id, hash_alg, hash_bytes,
             )
 
         for prev_event_id, prev_hashes in event.prev_events:
             for alg, hash_base64 in prev_hashes.items():
                 hash_bytes = decode_base64(hash_base64)
                 self._store_prev_event_hash_txn(
-                    txn, event.event_id, prev_event_id, alg, hash_bytes
+                    txn, invalidates, event.event_id, prev_event_id, alg,
+                    hash_bytes
                 )
 
         for auth_id, _ in event.auth_events:
@@ -309,7 +320,7 @@ class EventsStore(SQLBaseStore):
 
         (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
         self._store_event_reference_hash_txn(
-            txn, event.event_id, ref_alg, ref_hash_bytes
+            txn, invalidates, event.event_id, ref_alg, ref_hash_bytes
         )
 
         if event.is_state():
@@ -356,9 +367,11 @@ class EventsStore(SQLBaseStore):
                     }
                 )
 
-    def _store_redaction(self, txn, event):
+        return invalidates
+
+    def _store_redaction(self, txn, invalidates, event):
         # invalidate the cache for the redacted event
-        self._invalidate_get_event_cache(event.redacts)
+        invalidates.append((self._invalidate_get_event_cache, event.redacts))
         txn.execute(
             "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
             (event.event_id, event.redacts)