diff --git a/changelog.d/11121.misc b/changelog.d/11121.misc
new file mode 100644
index 0000000000..916beeaacb
--- /dev/null
+++ b/changelog.d/11121.misc
@@ -0,0 +1 @@
+Add type hints for event fetching.
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4a1a2f4a6a..ae37901be9 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -55,8 +55,9 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine
+from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
@@ -86,6 +87,47 @@ class _EventCacheEntry:
redacted_event: Optional[EventBase]
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventRow:
+ """
+ An event, as pulled from the database.
+
+ Properties:
+ event_id: The event ID of the event.
+
+ stream_ordering: stream ordering for this event
+
+ json: json-encoded event structure
+
+ internal_metadata: json-encoded internal metadata dict
+
+ format_version: The format of the event. Hopefully one of EventFormatVersions.
+ 'None' means the event predates EventFormatVersions (so the event is format V1).
+
+ room_version_id: The version of the room which contains the event. Hopefully
+ one of RoomVersions.
+
+ Due to historical reasons, there may be a few events in the database which
+ do not have an associated room; in this case None will be returned here.
+
+ rejected_reason: if the event was rejected, the reason why.
+
+ redactions: a list of event-ids which (claim to) redact this event.
+
+ outlier: True if this event is an outlier.
+ """
+
+ event_id: str
+ stream_ordering: int
+ json: str
+ internal_metadata: str
+ format_version: Optional[int]
+ room_version_id: Optional[str]
+ rejected_reason: Optional[str]
+ redactions: List[str]
+ outlier: bool
+
+
class EventRedactBehaviour(Names):
"""
What to do when retrieving a redacted event from the database.
@@ -686,7 +728,7 @@ class EventsWorkerStore(SQLBaseStore):
for e in state_to_include.values()
]
- def _do_fetch(self, conn):
+ def _do_fetch(self, conn: Connection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
@@ -713,13 +755,15 @@ class EventsWorkerStore(SQLBaseStore):
self._fetch_event_list(conn, event_list)
- def _fetch_event_list(self, conn, event_list):
+ def _fetch_event_list(
+ self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+ ) -> None:
"""Handle a load of requests from the _event_fetch_list queue
Args:
- conn (twisted.enterprise.adbapi.Connection): database connection
+ conn: database connection
- event_list (list[Tuple[list[str], Deferred]]):
+ event_list:
The fetch requests. Each entry consists of a list of event
ids to be fetched, and a deferred to be completed once the
events have been fetched.
@@ -788,7 +832,7 @@ class EventsWorkerStore(SQLBaseStore):
row = row_map.get(event_id)
fetched_events[event_id] = row
if row:
- redaction_ids.update(row["redactions"])
+ redaction_ids.update(row.redactions)
events_to_fetch = redaction_ids.difference(fetched_events.keys())
if events_to_fetch:
@@ -799,32 +843,32 @@ class EventsWorkerStore(SQLBaseStore):
for event_id, row in fetched_events.items():
if not row:
continue
- assert row["event_id"] == event_id
+ assert row.event_id == event_id
- rejected_reason = row["rejected_reason"]
+ rejected_reason = row.rejected_reason
# If the event or metadata cannot be parsed, log the error and act
# as if the event is unknown.
try:
- d = db_to_json(row["json"])
+ d = db_to_json(row.json)
except ValueError:
logger.error("Unable to parse json from event: %s", event_id)
continue
try:
- internal_metadata = db_to_json(row["internal_metadata"])
+ internal_metadata = db_to_json(row.internal_metadata)
except ValueError:
logger.error(
"Unable to parse internal_metadata from event: %s", event_id
)
continue
- format_version = row["format_version"]
+ format_version = row.format_version
if format_version is None:
# This means that we stored the event before we had the concept
# of a event format version, so it must be a V1 event.
format_version = EventFormatVersions.V1
- room_version_id = row["room_version_id"]
+ room_version_id = row.room_version_id
if not room_version_id:
# this should only happen for out-of-band membership events which
@@ -889,8 +933,8 @@ class EventsWorkerStore(SQLBaseStore):
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
- original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
- original_ev.internal_metadata.outlier = row["outlier"]
+ original_ev.internal_metadata.stream_ordering = row.stream_ordering
+ original_ev.internal_metadata.outlier = row.outlier
event_map[event_id] = original_ev
@@ -898,7 +942,7 @@ class EventsWorkerStore(SQLBaseStore):
# the cache entries.
result_map = {}
for event_id, original_ev in event_map.items():
- redactions = fetched_events[event_id]["redactions"]
+ redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
@@ -912,17 +956,17 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- async def _enqueue_events(self, events):
+ async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
Args:
- events (Iterable[str]): events to be fetched.
+ events: events to be fetched.
Returns:
- Dict[str, Dict]: map from event id to row data from the database.
- May contain events that weren't requested.
+ A map from event id to row data from the database. May contain events
+ that weren't requested.
"""
events_d = defer.Deferred()
@@ -949,43 +993,19 @@ class EventsWorkerStore(SQLBaseStore):
return row_map
- def _fetch_event_rows(self, txn, event_ids):
+ def _fetch_event_rows(
+ self, txn: LoggingTransaction, event_ids: Iterable[str]
+ ) -> Dict[str, _EventRow]:
"""Fetch event rows from the database
Events which are not found are omitted from the result.
- The returned per-event dicts contain the following keys:
-
- * event_id (str)
-
- * stream_ordering (int): stream ordering for this event
-
- * json (str): json-encoded event structure
-
- * internal_metadata (str): json-encoded internal metadata dict
-
- * format_version (int|None): The format of the event. Hopefully one
- of EventFormatVersions. 'None' means the event predates
- EventFormatVersions (so the event is format V1).
-
- * room_version_id (str|None): The version of the room which contains the event.
- Hopefully one of RoomVersions.
-
- Due to historical reasons, there may be a few events in the database which
- do not have an associated room; in this case None will be returned here.
-
- * rejected_reason (str|None): if the event was rejected, the reason
- why.
-
- * redactions (List[str]): a list of event-ids which (claim to) redact
- this event.
-
Args:
- txn (twisted.enterprise.adbapi.Connection):
- event_ids (Iterable[str]): event IDs to fetch
+ txn: The database transaction.
+ event_ids: event IDs to fetch
Returns:
- Dict[str, Dict]: a map from event id to event info.
+ A map from event id to event info.
"""
event_dict = {}
for evs in batch_iter(event_ids, 200):
@@ -1013,17 +1033,17 @@ class EventsWorkerStore(SQLBaseStore):
for row in txn:
event_id = row[0]
- event_dict[event_id] = {
- "event_id": event_id,
- "stream_ordering": row[1],
- "internal_metadata": row[2],
- "json": row[3],
- "format_version": row[4],
- "room_version_id": row[5],
- "rejected_reason": row[6],
- "redactions": [],
- "outlier": row[7],
- }
+ event_dict[event_id] = _EventRow(
+ event_id=event_id,
+ stream_ordering=row[1],
+ internal_metadata=row[2],
+ json=row[3],
+ format_version=row[4],
+ room_version_id=row[5],
+ rejected_reason=row[6],
+ redactions=[],
+ outlier=row[7],
+ )
# check for redactions
redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
@@ -1035,7 +1055,7 @@ class EventsWorkerStore(SQLBaseStore):
for (redacter, redacted) in txn:
d = event_dict.get(redacted)
if d:
- d["redactions"].append(redacter)
+ d.redactions.append(redacter)
return event_dict
|