summary refs log tree commit diff
path: root/synapse/storage/databases/main/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/events.py')
-rw-r--r--synapse/storage/databases/main/events.py180
1 files changed, 98 insertions, 82 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index dd255aefb9..1ae1ebe108 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -39,7 +39,6 @@ from synapse.api.room_versions import RoomVersions
 from synapse.crypto.event_signing import compute_event_reference_hash
 from synapse.events import EventBase  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
-from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -69,7 +68,7 @@ event_counter = Counter(
 )
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.
 
@@ -80,9 +79,9 @@ class DeltaState:
             should e.g. be removed from `current_state_events` table.
     """
 
-    to_delete = attr.ib(type=List[Tuple[str, str]])
-    to_insert = attr.ib(type=StateMap[str])
-    no_longer_in_room = attr.ib(type=bool, default=False)
+    to_delete: List[Tuple[str, str]]
+    to_insert: StateMap[str]
+    no_longer_in_room: bool = False
 
 
 class PersistEventsStore:
@@ -328,7 +327,6 @@ class PersistEventsStore:
 
         return existing_prevs
 
-    @log_function
     def _persist_events_txn(
         self,
         txn: LoggingTransaction,
@@ -442,12 +440,9 @@ class PersistEventsStore:
         self.db_pool.simple_insert_many_txn(
             txn,
             table="event_auth",
+            keys=("event_id", "room_id", "auth_id"),
             values=[
-                {
-                    "event_id": event.event_id,
-                    "room_id": event.room_id,
-                    "auth_id": auth_id,
-                }
+                (event.event_id, event.room_id, auth_id)
                 for event in events
                 for auth_id in event.auth_event_ids()
                 if event.is_state()
@@ -675,8 +670,9 @@ class PersistEventsStore:
         db_pool.simple_insert_many_txn(
             txn,
             table="event_auth_chains",
+            keys=("event_id", "chain_id", "sequence_number"),
             values=[
-                {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
+                (event_id, c_id, seq)
                 for event_id, (c_id, seq) in new_chain_tuples.items()
             ],
         )
@@ -782,13 +778,14 @@ class PersistEventsStore:
         db_pool.simple_insert_many_txn(
             txn,
             table="event_auth_chain_links",
+            keys=(
+                "origin_chain_id",
+                "origin_sequence_number",
+                "target_chain_id",
+                "target_sequence_number",
+            ),
             values=[
-                {
-                    "origin_chain_id": source_id,
-                    "origin_sequence_number": source_seq,
-                    "target_chain_id": target_id,
-                    "target_sequence_number": target_seq,
-                }
+                (source_id, source_seq, target_id, target_seq)
                 for (
                     source_id,
                     source_seq,
@@ -943,20 +940,28 @@ class PersistEventsStore:
             txn_id = getattr(event.internal_metadata, "txn_id", None)
             if token_id and txn_id:
                 to_insert.append(
-                    {
-                        "event_id": event.event_id,
-                        "room_id": event.room_id,
-                        "user_id": event.sender,
-                        "token_id": token_id,
-                        "txn_id": txn_id,
-                        "inserted_ts": self._clock.time_msec(),
-                    }
+                    (
+                        event.event_id,
+                        event.room_id,
+                        event.sender,
+                        token_id,
+                        txn_id,
+                        self._clock.time_msec(),
+                    )
                 )
 
         if to_insert:
             self.db_pool.simple_insert_many_txn(
                 txn,
                 table="event_txn_id",
+                keys=(
+                    "event_id",
+                    "room_id",
+                    "user_id",
+                    "token_id",
+                    "txn_id",
+                    "inserted_ts",
+                ),
                 values=to_insert,
             )
 
@@ -1161,8 +1166,9 @@ class PersistEventsStore:
         self.db_pool.simple_insert_many_txn(
             txn,
             table="event_forward_extremities",
+            keys=("event_id", "room_id"),
             values=[
-                {"event_id": ev_id, "room_id": room_id}
+                (ev_id, room_id)
                 for room_id, new_extrem in new_forward_extremities.items()
                 for ev_id in new_extrem
             ],
@@ -1174,12 +1180,9 @@ class PersistEventsStore:
         self.db_pool.simple_insert_many_txn(
             txn,
             table="stream_ordering_to_exterm",
+            keys=("room_id", "event_id", "stream_ordering"),
             values=[
-                {
-                    "room_id": room_id,
-                    "event_id": event_id,
-                    "stream_ordering": max_stream_order,
-                }
+                (room_id, event_id, max_stream_order)
                 for room_id, new_extrem in new_forward_extremities.items()
                 for event_id in new_extrem
             ],
@@ -1251,20 +1254,22 @@ class PersistEventsStore:
         for room_id, depth in depth_updates.items():
             self._update_min_depth_for_room_txn(txn, room_id, depth)
 
-    def _update_outliers_txn(self, txn, events_and_contexts):
+    def _update_outliers_txn(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> List[Tuple[EventBase, EventContext]]:
         """Update any outliers with new event info.
 
-        This turns outliers into ex-outliers (unless the new event was
-        rejected).
+        This turns outliers into ex-outliers (unless the new event was rejected), and
+        also removes any other events we have already seen from the list.
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
+            txn: db connection
+            events_and_contexts: events we are persisting
 
         Returns:
-            list[(EventBase, EventContext)] new list, without events which
-            are already in the events table.
+            new list, without events which are already in the events table.
         """
         txn.execute(
             "SELECT event_id, outlier FROM events WHERE event_id in (%s)"
@@ -1272,7 +1277,9 @@ class PersistEventsStore:
             [event.event_id for event, _ in events_and_contexts],
         )
 
-        have_persisted = {event_id: outlier for event_id, outlier in txn}
+        have_persisted: Dict[str, bool] = {
+            event_id: outlier for event_id, outlier in txn
+        }
 
         to_remove = set()
         for event, context in events_and_contexts:
@@ -1282,15 +1289,22 @@ class PersistEventsStore:
             to_remove.add(event)
 
             if context.rejected:
-                # If the event is rejected then we don't care if the event
-                # was an outlier or not.
+                # If the incoming event is rejected then we don't care if the event
+                # was an outlier or not - what we have is at least as good.
                 continue
 
             outlier_persisted = have_persisted[event.event_id]
             if not event.internal_metadata.is_outlier() and outlier_persisted:
                 # We received a copy of an event that we had already stored as
-                # an outlier in the database. We now have some state at that
+                # an outlier in the database. We now have some state at that event
                 # so we need to update the state_groups table with that state.
+                #
+                # Note that we do not update the stream_ordering of the event in this
+                # scenario. XXX: does this cause bugs? It will mean we won't send such
+                # events down /sync. In general they will be historical events, so that
+                # doesn't matter too much, but that is not always the case.
+
+                logger.info("Updating state for ex-outlier event %s", event.event_id)
 
                 # insert into event_to_state_groups.
                 try:
@@ -1342,7 +1356,7 @@ class PersistEventsStore:
             d.pop("redacted_because", None)
             return d
 
-        self.db_pool.simple_insert_many_values_txn(
+        self.db_pool.simple_insert_many_txn(
             txn,
             table="event_json",
             keys=("event_id", "room_id", "internal_metadata", "json", "format_version"),
@@ -1358,7 +1372,7 @@ class PersistEventsStore:
             ),
         )
 
-        self.db_pool.simple_insert_many_values_txn(
+        self.db_pool.simple_insert_many_txn(
             txn,
             table="events",
             keys=(
@@ -1412,7 +1426,7 @@ class PersistEventsStore:
         )
         txn.execute(sql + clause, [False] + args)
 
-        self.db_pool.simple_insert_many_values_txn(
+        self.db_pool.simple_insert_many_txn(
             txn,
             table="state_events",
             keys=("event_id", "room_id", "type", "state_key"),
@@ -1622,14 +1636,9 @@ class PersistEventsStore:
         return self.db_pool.simple_insert_many_txn(
             txn=txn,
             table="event_labels",
+            keys=("event_id", "label", "room_id", "topological_ordering"),
             values=[
-                {
-                    "event_id": event_id,
-                    "label": label,
-                    "room_id": room_id,
-                    "topological_ordering": topological_ordering,
-                }
-                for label in labels
+                (event_id, label, room_id, topological_ordering) for label in labels
             ],
         )
 
@@ -1657,16 +1666,13 @@ class PersistEventsStore:
         vals = []
         for event in events:
             ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
-            vals.append(
-                {
-                    "event_id": event.event_id,
-                    "algorithm": ref_alg,
-                    "hash": memoryview(ref_hash_bytes),
-                }
-            )
+            vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes)))
 
         self.db_pool.simple_insert_many_txn(
-            txn, table="event_reference_hashes", values=vals
+            txn,
+            table="event_reference_hashes",
+            keys=("event_id", "algorithm", "hash"),
+            values=vals,
         )
 
     def _store_room_members_txn(
@@ -1689,18 +1695,25 @@ class PersistEventsStore:
         self.db_pool.simple_insert_many_txn(
             txn,
             table="room_memberships",
+            keys=(
+                "event_id",
+                "user_id",
+                "sender",
+                "room_id",
+                "membership",
+                "display_name",
+                "avatar_url",
+            ),
             values=[
-                {
-                    "event_id": event.event_id,
-                    "user_id": event.state_key,
-                    "sender": event.user_id,
-                    "room_id": event.room_id,
-                    "membership": event.membership,
-                    "display_name": non_null_str_or_none(
-                        event.content.get("displayname")
-                    ),
-                    "avatar_url": non_null_str_or_none(event.content.get("avatar_url")),
-                }
+                (
+                    event.event_id,
+                    event.state_key,
+                    event.user_id,
+                    event.room_id,
+                    event.membership,
+                    non_null_str_or_none(event.content.get("displayname")),
+                    non_null_str_or_none(event.content.get("avatar_url")),
+                )
                 for event in events
             ],
         )
@@ -1791,6 +1804,13 @@ class PersistEventsStore:
             txn.call_after(
                 self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
             )
+            # It should be safe to only invalidate the cache if the user has not
+            # previously participated in the thread, but that's difficult (and
+            # potentially error-prone) so it is always invalidated.
+            txn.call_after(
+                self.store.get_thread_participated.invalidate,
+                (parent_id, event.room_id, event.sender),
+            )
 
     def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
         """Handles keeping track of insertion events and edges/connections.
@@ -2163,13 +2183,9 @@ class PersistEventsStore:
         self.db_pool.simple_insert_many_txn(
             txn,
             table="event_edges",
+            keys=("event_id", "prev_event_id", "room_id", "is_state"),
             values=[
-                {
-                    "event_id": ev.event_id,
-                    "prev_event_id": e_id,
-                    "room_id": ev.room_id,
-                    "is_state": False,
-                }
+                (ev.event_id, e_id, ev.room_id, False)
                 for ev in events
                 for e_id in ev.prev_event_ids()
             ],
@@ -2226,17 +2242,17 @@ class PersistEventsStore:
         )
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _LinkMap:
     """A helper type for tracking links between chains."""
 
     # Stores the set of links as nested maps: source chain ID -> target chain ID
     # -> source sequence number -> target sequence number.
-    maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
+    maps: Dict[int, Dict[int, Dict[int, int]]] = attr.Factory(dict)
 
     # Stores the links that have been added (with new set to true), as tuples of
     # `(source chain ID, source sequence no, target chain ID, target sequence no.)`
-    additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
+    additions: Set[Tuple[int, int, int, int]] = attr.Factory(set)
 
     def add_link(
         self,