diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 8c63a0dc4d..e6247d682d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -19,9 +19,10 @@ import itertools
import logging
import threading
from collections import namedtuple
-from typing import List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, overload
from constantly import NamedConstant, Names
+from typing_extensions import Literal
from twisted.internet import defer
@@ -32,7 +33,7 @@ from synapse.api.room_versions import (
EventFormatVersions,
RoomVersions,
)
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
+from synapse.types import Collection, get_domain_from_id
+from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -112,33 +113,58 @@ class EventsWorkerStore(SQLBaseStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == EventsStream.NAME:
- self._stream_id_gen.advance(token)
+ self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
- self._backfill_id_gen.advance(-token)
+ self._backfill_id_gen.advance(instance_name, -token)
super().process_replication_rows(stream_name, instance_name, token, rows)
- def get_received_ts(self, event_id):
+ async def get_received_ts(self, event_id: str) -> Optional[int]:
"""Get received_ts (when it was persisted) for the event.
Raises an exception for unknown events.
Args:
- event_id (str)
+ event_id: The event ID to query.
Returns:
- Deferred[int|None]: Timestamp in milliseconds, or None for events
- that were persisted before received_ts was implemented.
+ Timestamp in milliseconds, or None for events that were persisted
+ before received_ts was implemented.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="received_ts",
desc="get_received_ts",
)
- @defer.inlineCallbacks
- def get_event(
+ # Inform mypy that if allow_none is False (the default) then get_event always
+ # returns an EventBase; only allow_none=True makes the result Optional.
+ @overload
+ async def get_event(
+ self,
+ event_id: str,
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: Literal[False] = False,
+ check_room_id: Optional[str] = None,
+ ) -> EventBase:
+ ...
+
+ @overload
+ async def get_event(
+ self,
+ event_id: str,
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: Literal[True] = ...,  # default elided: False is not a Literal[True]
+ check_room_id: Optional[str] = None,
+ ) -> Optional[EventBase]:
+ ...
+
+ async def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
@@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected: bool = False,
allow_none: bool = False,
check_room_id: Optional[str] = None,
- ):
+ ) -> Optional[EventBase]:
"""Get an event from the database by event_id.
Args:
@@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
If there is a mismatch, behave as per allow_none.
Returns:
- Deferred[EventBase|None]
+ The event, or None if the event was not found.
"""
if not isinstance(event_id, str):
raise TypeError("Invalid event event_id %r" % (event_id,))
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[event_id],
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):
return event
- @defer.inlineCallbacks
- def get_events(
+ async def get_events(
self,
- event_ids: List[str],
+ event_ids: Iterable[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
- ):
+ ) -> Dict[str, EventBase]:
"""Get events from the database
Args:
@@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
omits rejeted events from the response.
Returns:
- Deferred : Dict from event_id to event.
+ A mapping from event_id to event.
"""
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
event_ids,
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):
return {e.event_id: e for e in events}
- @defer.inlineCallbacks
- def get_events_as_list(
+ async def get_events_as_list(
self,
- event_ids: List[str],
+ event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
- ):
+ ) -> List[EventBase]:
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
@@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
omits rejected events from the response.
Returns:
- Deferred[list[EventBase]]: List of events fetched from the database. The
- events are in the same order as `event_ids` arg.
+ List of events fetched from the database. The events are in the same
+ order as `event_ids` arg.
Note that the returned list may be smaller than the list of event
IDs if not all events could be fetched.
@@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
return []
# there may be duplicates so we cast the list to a set
- event_entry_map = yield self._get_events_from_cache_or_db(
+ event_entry_map = await self._get_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected
)
@@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
continue
redacted_event_id = entry.event.redacts
- event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+ event_map = await self._get_events_from_cache_or_db([redacted_event_id])
original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry:
# we don't have the redacted event (or it was rejected).
@@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
if get_prev_content:
if "replaces_state" in event.unsigned:
- prev = yield self.get_event(
+ prev = await self.get_event(
event.unsigned["replaces_state"],
get_prev_content=False,
allow_none=True,
@@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):
return events
- @defer.inlineCallbacks
- def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
- Deferred[Dict[str, _EventCacheEntry]]:
+ Dict[str, _EventCacheEntry]:
map from event id to result
"""
event_entry_map = self._get_events_from_cache(
@@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
- missing_events = yield self._get_events_from_db(
+ missing_events = await self._get_events_from_db(
missing_events_ids, allow_rejected=allow_rejected
)
@@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
- @defer.inlineCallbacks
- def _get_events_from_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the database.
Returned events will be added to the cache for future lookups.
@@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
- Deferred[Dict[str, _EventCacheEntry]]:
+ Dict[str, _EventCacheEntry]:
map from event id to result. May return extra events which
weren't asked for.
"""
@@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
events_to_fetch = event_ids
while events_to_fetch:
- row_map = yield self._enqueue_events(events_to_fetch)
+ row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids = set()
@@ -574,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected and rejected_reason:
continue
- d = db_to_json(row["json"])
- internal_metadata = db_to_json(row["internal_metadata"])
+ # If the event or metadata cannot be parsed, log the error and act
+ # as if the event is unknown.
+ try:
+ d = db_to_json(row["json"])
+ except ValueError:
+ logger.error("Unable to parse json from event: %s", event_id)
+ continue
+ try:
+ internal_metadata = db_to_json(row["internal_metadata"])
+ except ValueError:
+ logger.error(
+ "Unable to parse internal_metadata from event: %s", event_id
+ )
+ continue
format_version = row["format_version"]
if format_version is None:
@@ -586,19 +620,38 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row["room_version_id"]
if not room_version_id:
- # this should only happen for out-of-band membership events
- if not internal_metadata.get("out_of_band_membership"):
- logger.warning(
- "Room %s for event %s is unknown", d["room_id"], event_id
+ # this should only happen for out-of-band membership events which
+ # arrived before #6983 landed. For all other events, we should have
+ # an entry in the 'rooms' table.
+ #
+ # However, the 'out_of_band_membership' flag is unreliable for older
+ # invites, so just accept it for all membership events.
+ #
+ if d["type"] != EventTypes.Member:
+ raise Exception(
+ "Room %s for event %s is unknown" % (d["room_id"], event_id)
)
- continue
- # take a wild stab at the room version based on the event format
+ # so, assuming this is an out-of-band-invite that arrived before #6983
+ # landed, we know that the room version must be v5 or earlier (because
+ # v6 hadn't been invented at that point, so invites from such rooms
+ # would have been rejected.)
+ #
+ # The main reason we need to know the room version here (other than
+ # choosing the right python Event class) is in case the event later has
+ # to be redacted - and all the room versions up to v5 used the same
+ # redaction algorithm.
+ #
+ # So, the following approximations should be adequate.
+
if format_version == EventFormatVersions.V1:
+ # if it's event format v1 then it must be room v1 or v2
room_version = RoomVersions.V1
elif format_version == EventFormatVersions.V2:
+ # if it's event format v2 then it must be room v3
room_version = RoomVersions.V3
else:
+ # if it's event format v3 then it must be room v4 or v5
room_version = RoomVersions.V5
else:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
@@ -650,8 +703,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- @defer.inlineCallbacks
- def _enqueue_events(self, events):
+ async def _enqueue_events(self, events):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -660,7 +712,7 @@ class EventsWorkerStore(SQLBaseStore):
events (Iterable[str]): events to be fetched.
Returns:
- Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+ Dict[str, Dict]: map from event id to row data from the database.
May contain events that weren't requested.
"""
@@ -683,7 +735,7 @@ class EventsWorkerStore(SQLBaseStore):
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
- row_map = yield events_d
+ row_map = await events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
return row_map
@@ -842,33 +894,29 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
- @defer.inlineCallbacks
- def have_events_in_timeline(self, event_ids):
+ async def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
- rows = yield defer.ensureDeferred(
- self.db_pool.simple_select_many_batch(
- table="events",
- retcols=("event_id",),
- column="event_id",
- iterable=list(event_ids),
- keyvalues={"outlier": False},
- desc="have_events_in_timeline",
- )
+ rows = await self.db_pool.simple_select_many_batch(
+ table="events",
+ retcols=("event_id",),
+ column="event_id",
+ iterable=list(event_ids),
+ keyvalues={"outlier": False},
+ desc="have_events_in_timeline",
)
return {r["event_id"] for r in rows}
- @defer.inlineCallbacks
- def have_seen_events(self, event_ids):
+ async def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Args:
event_ids (iterable[str]):
Returns:
- Deferred[set[str]]: The events we have already seen.
+ set[str]: The events we have already seen.
"""
results = set()
@@ -884,7 +932,7 @@ class EventsWorkerStore(SQLBaseStore):
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
return results
@@ -914,8 +962,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
- @defer.inlineCallbacks
- def get_room_complexity(self, room_id):
+ async def get_room_complexity(self, room_id):
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -926,9 +973,9 @@ class EventsWorkerStore(SQLBaseStore):
room_id (str)
Returns:
- Deferred[dict[str:int]] of complexity version to complexity.
+ dict[str, int]: map from complexity version to complexity.
"""
- state_events = yield self.get_current_state_event_counts(room_id)
+ state_events = await self.get_current_state_event_counts(room_id)
# Call this one "v1", so we can introduce new ones as we want to develop
# it.
@@ -1165,9 +1212,9 @@ class EventsWorkerStore(SQLBaseStore):
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
- @cachedInlineCallbacks(max_entries=5000)
- def get_event_ordering(self, event_id):
- res = yield self.db_pool.simple_select_one(
+ @cached(max_entries=5000)
+ async def get_event_ordering(self, event_id):
+ res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
|