summary refs log tree commit diff
diff options
context:
space:
mode:
authorNick Mills-Barrett <nick@beeper.com>2022-09-26 16:26:35 +0100
committerGitHub <noreply@github.com>2022-09-26 16:26:35 +0100
commit6b4593a80fa2fd9ec8e1ec82fad74f3b7fbb9ba3 (patch)
tree39c8f39a45b9bf4e36575fbb2c2c0b2878255808
parentUpdate NixOS module URL (#13818) (diff)
downloadsynapse-6b4593a80fa2fd9ec8e1ec82fad74f3b7fbb9ba3.tar.xz
Simplify cache invalidation after event persist txn (#13796)
This moves all the invalidations into a single place and de-duplicates
the code involved in invalidating caches for a given event by using
the base class method.
-rw-r--r--changelog.d/13796.misc1
-rw-r--r--synapse/storage/_base.py3
-rw-r--r--synapse/storage/databases/main/cache.py34
-rw-r--r--synapse/storage/databases/main/events.py133
4 files changed, 52 insertions, 119 deletions
diff --git a/changelog.d/13796.misc b/changelog.d/13796.misc
new file mode 100644
index 0000000000..9ed1662394
--- /dev/null
+++ b/changelog.d/13796.misc
@@ -0,0 +1 @@
+Use shared methods for cache invalidation when persisting events, remove duplicate codepaths. Contributed by Nick @ Beeper (@fizzadar).
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 303a5d5298..313e8aca7d 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -91,6 +91,9 @@ class SQLBaseStore(metaclass=ABCMeta):
             self._attempt_to_invalidate_cache(
                 "get_user_in_room_with_profile", (room_id, user_id)
             )
+            self._attempt_to_invalidate_cache(
+                "get_rooms_for_user_with_stream_ordering", (user_id,)
+            )
 
         # Purge other caches based on room state.
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2c421151c1..db6ce83a2b 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -223,15 +223,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         # process triggering the invalidation is responsible for clearing any external
         # cached objects.
         self._invalidate_local_get_event_cache(event_id)
-        self.have_seen_event.invalidate((room_id, event_id))
 
-        self.get_latest_event_ids_in_room.invalidate((room_id,))
-
-        self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
+        self._attempt_to_invalidate_cache("have_seen_event", (room_id, event_id))
+        self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,))
+        self._attempt_to_invalidate_cache(
+            "get_unread_event_push_actions_by_room_for_user", (room_id,)
+        )
 
         # The `_get_membership_from_event_id` is immutable, except for the
         # case where we look up an event *before* persisting it.
-        self._get_membership_from_event_id.invalidate((event_id,))
+        self._attempt_to_invalidate_cache("_get_membership_from_event_id", (event_id,))
 
         if not backfilled:
             self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
@@ -240,19 +241,26 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._invalidate_local_get_event_cache(redacts)
             # Caches which might leak edits must be invalidated for the event being
             # redacted.
-            self.get_relations_for_event.invalidate((redacts,))
-            self.get_applicable_edit.invalidate((redacts,))
+            self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
+            self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
 
         if etype == EventTypes.Member:
             self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
-            self.get_invited_rooms_for_local_user.invalidate((state_key,))
+            self._attempt_to_invalidate_cache(
+                "get_invited_rooms_for_local_user", (state_key,)
+            )
 
         if relates_to:
-            self.get_relations_for_event.invalidate((relates_to,))
-            self.get_aggregation_groups_for_event.invalidate((relates_to,))
-            self.get_applicable_edit.invalidate((relates_to,))
-            self.get_thread_summary.invalidate((relates_to,))
-            self.get_thread_participated.invalidate((relates_to,))
+            self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
+            self._attempt_to_invalidate_cache(
+                "get_aggregation_groups_for_event", (relates_to,)
+            )
+            self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
+            self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
+            self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
+            self._attempt_to_invalidate_cache(
+                "get_mutual_event_relations_for_rel_type", (relates_to,)
+            )
 
     async def invalidate_cache_and_stream(
         self, cache_name: str, keys: Tuple[Any, ...]
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1b54a2eb57..2e156a4a11 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -35,7 +35,7 @@ import attr
 from prometheus_client import Counter
 
 import synapse.metrics
-from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
+from synapse.api.constants import EventContentFields, EventTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, relation_from_event
@@ -410,6 +410,31 @@ class PersistEventsStore:
         assert min_stream_order
         assert max_stream_order
 
+        # Once the txn completes, invalidate all of the relevant caches. Note that we do this
+        # up here because it captures all the events_and_contexts before any are removed.
+        for event, _ in events_and_contexts:
+            self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
+            if event.redacts:
+                self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
+
+            relates_to = None
+            relation = relation_from_event(event)
+            if relation:
+                relates_to = relation.parent_id
+
+            assert event.internal_metadata.stream_ordering is not None
+            txn.call_after(
+                self.store._invalidate_caches_for_event,
+                event.internal_metadata.stream_ordering,
+                event.event_id,
+                event.room_id,
+                event.type,
+                getattr(event, "state_key", None),
+                event.redacts,
+                relates_to,
+                backfilled=False,
+            )
+
         self._update_forward_extremities_txn(
             txn,
             new_forward_extremities=new_forward_extremities,
@@ -459,6 +484,7 @@ class PersistEventsStore:
 
         # We call this last as it assumes we've inserted the events into
         # room_memberships, where applicable.
+        # NB: This function invalidates all state related caches
         self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
 
     def _persist_event_auth_chain_txn(
@@ -1172,13 +1198,6 @@ class PersistEventsStore:
             )
 
             # Invalidate the various caches
-
-            for member in members_changed:
-                txn.call_after(
-                    self.store.get_rooms_for_user_with_stream_ordering.invalidate,
-                    (member,),
-                )
-
             self.store._invalidate_state_caches_and_stream(
                 txn, room_id, members_changed
             )
@@ -1222,9 +1241,6 @@ class PersistEventsStore:
             self.db_pool.simple_delete_txn(
                 txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
             )
-            txn.call_after(
-                self.store.get_latest_event_ids_in_room.invalidate, (room_id,)
-            )
 
         self.db_pool.simple_insert_many_txn(
             txn,
@@ -1294,8 +1310,6 @@ class PersistEventsStore:
         """
         depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
-            # Remove the any existing cache entries for the event_ids
-            self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
             # Then update the `stream_ordering` position to mark the latest
             # event as the front of the room. This should not be done for
             # backfilled events because backfilled events have negative
@@ -1697,16 +1711,7 @@ class PersistEventsStore:
         txn.async_call_after(prefill)
 
     def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
-        """Invalidate the caches for the redacted event.
-
-        Note that these caches are also cleared as part of event replication in
-        _invalidate_caches_for_event.
-        """
         assert event.redacts is not None
-        self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
-        txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
-        txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
-
         self.db_pool.simple_upsert_txn(
             txn,
             table="redactions",
@@ -1807,34 +1812,6 @@ class PersistEventsStore:
 
         for event in events:
             assert event.internal_metadata.stream_ordering is not None
-            txn.call_after(
-                self.store._membership_stream_cache.entity_has_changed,
-                event.state_key,
-                event.internal_metadata.stream_ordering,
-            )
-            txn.call_after(
-                self.store.get_invited_rooms_for_local_user.invalidate,
-                (event.state_key,),
-            )
-            txn.call_after(
-                self.store.get_local_users_in_room.invalidate,
-                (event.room_id,),
-            )
-            txn.call_after(
-                self.store.get_number_joined_users_in_room.invalidate,
-                (event.room_id,),
-            )
-            txn.call_after(
-                self.store.get_user_in_room_with_profile.invalidate,
-                (event.room_id, event.state_key),
-            )
-
-            # The `_get_membership_from_event_id` is immutable, except for the
-            # case where we look up an event *before* persisting it.
-            txn.call_after(
-                self.store._get_membership_from_event_id.invalidate,
-                (event.event_id,),
-            )
 
             # We update the local_current_membership table only if the event is
             # "current", i.e., its something that has just happened.
@@ -1883,35 +1860,6 @@ class PersistEventsStore:
             },
         )
 
-        txn.call_after(
-            self.store.get_relations_for_event.invalidate, (relation.parent_id,)
-        )
-        txn.call_after(
-            self.store.get_aggregation_groups_for_event.invalidate,
-            (relation.parent_id,),
-        )
-        txn.call_after(
-            self.store.get_mutual_event_relations_for_rel_type.invalidate,
-            (relation.parent_id,),
-        )
-
-        if relation.rel_type == RelationTypes.REPLACE:
-            txn.call_after(
-                self.store.get_applicable_edit.invalidate, (relation.parent_id,)
-            )
-
-        if relation.rel_type == RelationTypes.THREAD:
-            txn.call_after(
-                self.store.get_thread_summary.invalidate, (relation.parent_id,)
-            )
-            # It should be safe to only invalidate the cache if the user has not
-            # previously participated in the thread, but that's difficult (and
-            # potentially error-prone) so it is always invalidated.
-            txn.call_after(
-                self.store.get_thread_participated.invalidate,
-                (relation.parent_id, event.sender),
-            )
-
     def _handle_insertion_event(
         self, txn: LoggingTransaction, event: EventBase
     ) -> None:
@@ -2213,28 +2161,6 @@ class PersistEventsStore:
                 ),
             )
 
-            room_to_event_ids: Dict[str, List[str]] = {}
-            for e in non_outlier_events:
-                room_to_event_ids.setdefault(e.room_id, []).append(e.event_id)
-
-            for room_id, event_ids in room_to_event_ids.items():
-                rows = self.db_pool.simple_select_many_txn(
-                    txn,
-                    table="event_push_actions_staging",
-                    column="event_id",
-                    iterable=event_ids,
-                    keyvalues={},
-                    retcols=("user_id",),
-                )
-
-                user_ids = {row["user_id"] for row in rows}
-
-                for user_id in user_ids:
-                    txn.call_after(
-                        self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
-                        (room_id, user_id),
-                    )
-
         # Now we delete the staging area for *all* events that were being
         # persisted.
         txn.execute_batch(
@@ -2249,11 +2175,6 @@ class PersistEventsStore:
     def _remove_push_actions_for_event_id_txn(
         self, txn: LoggingTransaction, room_id: str, event_id: str
     ) -> None:
-        # Sad that we have to blow away the cache for the whole room here
-        txn.call_after(
-            self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
-            (room_id,),
-        )
         txn.execute(
             "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
             (room_id, event_id),