summary refs log tree commit diff
diff options
context:
space:
mode:
author Patrick Cloke <clokep@users.noreply.github.com> 2021-10-19 10:29:03 -0400
committer GitHub <noreply@github.com> 2021-10-19 14:29:03 +0000
commit 0dd0c40329cf620590b781b13d5b79332581fea7 (patch)
tree bed05a4f4ade07d0519d735cdfad9c570814d90a
parent Fix instances of [example]{.title-ref} in the upgrade notes (#11118) (diff)
download synapse-0dd0c40329cf620590b781b13d5b79332581fea7.tar.xz
Add missing type hints to event fetching. (#11121)
Updates the event rows returned from the database to be
attrs classes instead of dictionaries.
-rw-r--r-- changelog.d/11121.misc 1
-rw-r--r-- synapse/storage/databases/main/events_worker.py 142
2 files changed, 82 insertions, 61 deletions
diff --git a/changelog.d/11121.misc b/changelog.d/11121.misc
new file mode 100644
index 0000000000..916beeaacb
--- /dev/null
+++ b/changelog.d/11121.misc
@@ -0,0 +1 @@
+Add type hints for event fetching.
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4a1a2f4a6a..ae37901be9 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -55,8 +55,9 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.engines import PostgresEngine
+from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
@@ -86,6 +87,47 @@ class _EventCacheEntry:
     redacted_event: Optional[EventBase]
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventRow:
+    """
+    An event, as pulled from the database.
+
+    Properties:
+        event_id: The event ID of the event.
+
+        stream_ordering: stream ordering for this event
+
+        json: json-encoded event structure
+
+        internal_metadata: json-encoded internal metadata dict
+
+        format_version: The format of the event. Hopefully one of EventFormatVersions.
+            'None' means the event predates EventFormatVersions (so the event is format V1).
+
+        room_version_id: The version of the room which contains the event. Hopefully
+            one of RoomVersions. Due to historical reasons, there may be a few
+            events in the database which do not have an associated room; in this
+            case None will be returned here.
+
+        rejected_reason: if the event was rejected, the reason why.
+
+        redactions: a list of event-ids which (claim to) redact this event. The
+            list is filled in place after the row is built (frozen allows this).
+
+        outlier: True if this event is an outlier.
+    """
+
+    event_id: str
+    stream_ordering: int
+    json: str
+    internal_metadata: str
+    format_version: Optional[int]
+    room_version_id: Optional[str]
+    rejected_reason: Optional[str]
+    redactions: List[str]
+    outlier: bool
+
+
 class EventRedactBehaviour(Names):
     """
     What to do when retrieving a redacted event from the database.
@@ -686,7 +728,7 @@ class EventsWorkerStore(SQLBaseStore):
             for e in state_to_include.values()
         ]
 
-    def _do_fetch(self, conn):
+    def _do_fetch(self, conn: Connection) -> None:
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.
         """
@@ -713,13 +755,15 @@ class EventsWorkerStore(SQLBaseStore):
 
             self._fetch_event_list(conn, event_list)
 
-    def _fetch_event_list(self, conn, event_list):
+    def _fetch_event_list(
+        self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+    ) -> None:
         """Handle a load of requests from the _event_fetch_list queue
 
         Args:
-            conn (twisted.enterprise.adbapi.Connection): database connection
+            conn: database connection
 
-            event_list (list[Tuple[list[str], Deferred]]):
+            event_list:
                 The fetch requests. Each entry consists of a list of event
                 ids to be fetched, and a deferred to be completed once the
                 events have been fetched.
@@ -788,7 +832,7 @@ class EventsWorkerStore(SQLBaseStore):
                 row = row_map.get(event_id)
                 fetched_events[event_id] = row
                 if row:
-                    redaction_ids.update(row["redactions"])
+                    redaction_ids.update(row.redactions)
 
             events_to_fetch = redaction_ids.difference(fetched_events.keys())
             if events_to_fetch:
@@ -799,32 +843,32 @@ class EventsWorkerStore(SQLBaseStore):
         for event_id, row in fetched_events.items():
             if not row:
                 continue
-            assert row["event_id"] == event_id
+            assert row.event_id == event_id
 
-            rejected_reason = row["rejected_reason"]
+            rejected_reason = row.rejected_reason
 
             # If the event or metadata cannot be parsed, log the error and act
             # as if the event is unknown.
             try:
-                d = db_to_json(row["json"])
+                d = db_to_json(row.json)
             except ValueError:
                 logger.error("Unable to parse json from event: %s", event_id)
                 continue
             try:
-                internal_metadata = db_to_json(row["internal_metadata"])
+                internal_metadata = db_to_json(row.internal_metadata)
             except ValueError:
                 logger.error(
                     "Unable to parse internal_metadata from event: %s", event_id
                 )
                 continue
 
-            format_version = row["format_version"]
+            format_version = row.format_version
             if format_version is None:
                 # This means that we stored the event before we had the concept
                 # of a event format version, so it must be a V1 event.
                 format_version = EventFormatVersions.V1
 
-            room_version_id = row["room_version_id"]
+            room_version_id = row.room_version_id
 
             if not room_version_id:
                 # this should only happen for out-of-band membership events which
@@ -889,8 +933,8 @@ class EventsWorkerStore(SQLBaseStore):
                 internal_metadata_dict=internal_metadata,
                 rejected_reason=rejected_reason,
             )
-            original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
-            original_ev.internal_metadata.outlier = row["outlier"]
+            original_ev.internal_metadata.stream_ordering = row.stream_ordering
+            original_ev.internal_metadata.outlier = row.outlier
 
             event_map[event_id] = original_ev
 
@@ -898,7 +942,7 @@ class EventsWorkerStore(SQLBaseStore):
         # the cache entries.
         result_map = {}
         for event_id, original_ev in event_map.items():
-            redactions = fetched_events[event_id]["redactions"]
+            redactions = fetched_events[event_id].redactions
             redacted_event = self._maybe_redact_event_row(
                 original_ev, redactions, event_map
             )
@@ -912,17 +956,17 @@ class EventsWorkerStore(SQLBaseStore):
 
         return result_map
 
-    async def _enqueue_events(self, events):
+    async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
 
         Args:
-            events (Iterable[str]): events to be fetched.
+            events: events to be fetched.
 
         Returns:
-            Dict[str, Dict]: map from event id to row data from the database.
-                May contain events that weren't requested.
+            A map from event id to row data from the database. May contain events
+            that weren't requested.
         """
 
         events_d = defer.Deferred()
@@ -949,43 +993,19 @@ class EventsWorkerStore(SQLBaseStore):
 
         return row_map
 
-    def _fetch_event_rows(self, txn, event_ids):
+    def _fetch_event_rows(
+        self, txn: LoggingTransaction, event_ids: Iterable[str]
+    ) -> Dict[str, _EventRow]:
         """Fetch event rows from the database
 
         Events which are not found are omitted from the result.
 
-        The returned per-event dicts contain the following keys:
-
-         * event_id (str)
-
-         * stream_ordering (int): stream ordering for this event
-
-         * json (str): json-encoded event structure
-
-         * internal_metadata (str): json-encoded internal metadata dict
-
-         * format_version (int|None): The format of the event. Hopefully one
-           of EventFormatVersions. 'None' means the event predates
-           EventFormatVersions (so the event is format V1).
-
-         * room_version_id (str|None): The version of the room which contains the event.
-           Hopefully one of RoomVersions.
-
-           Due to historical reasons, there may be a few events in the database which
-           do not have an associated room; in this case None will be returned here.
-
-         * rejected_reason (str|None): if the event was rejected, the reason
-           why.
-
-         * redactions (List[str]): a list of event-ids which (claim to) redact
-           this event.
-
         Args:
-            txn (twisted.enterprise.adbapi.Connection):
-            event_ids (Iterable[str]): event IDs to fetch
+            txn: The database transaction.
+            event_ids: event IDs to fetch
 
         Returns:
-            Dict[str, Dict]: a map from event id to event info.
+            A map from event id to event info.
         """
         event_dict = {}
         for evs in batch_iter(event_ids, 200):
@@ -1013,17 +1033,17 @@ class EventsWorkerStore(SQLBaseStore):
 
             for row in txn:
                 event_id = row[0]
-                event_dict[event_id] = {
-                    "event_id": event_id,
-                    "stream_ordering": row[1],
-                    "internal_metadata": row[2],
-                    "json": row[3],
-                    "format_version": row[4],
-                    "room_version_id": row[5],
-                    "rejected_reason": row[6],
-                    "redactions": [],
-                    "outlier": row[7],
-                }
+                event_dict[event_id] = _EventRow(
+                    event_id=event_id,
+                    stream_ordering=row[1],
+                    internal_metadata=row[2],
+                    json=row[3],
+                    format_version=row[4],
+                    room_version_id=row[5],
+                    rejected_reason=row[6],
+                    redactions=[],
+                    outlier=row[7],
+                )
 
             # check for redactions
             redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
@@ -1035,7 +1055,7 @@ class EventsWorkerStore(SQLBaseStore):
             for (redacter, redacted) in txn:
                 d = event_dict.get(redacted)
                 if d:
-                    d["redactions"].append(redacter)
+                    d.redactions.append(redacter)
 
         return event_dict