diff options
Diffstat (limited to 'synapse/storage/events_worker.py')
-rw-r--r-- | synapse/storage/events_worker.py | 141 |
1 files changed, 68 insertions, 73 deletions
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 1716be529a..663991a9b6 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -21,8 +21,9 @@ from canonicaljson import json from twisted.internet import defer -from synapse.api.constants import EventFormatVersions, EventTypes +from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError +from synapse.api.room_versions import EventFormatVersions from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 # these are only included to make the type annotations work from synapse.events.snapshot import EventContext # noqa: F401 @@ -70,17 +71,21 @@ class EventsWorkerStore(SQLBaseStore): """ return self._simple_select_one_onecol( table="events", - keyvalues={ - "event_id": event_id, - }, + keyvalues={"event_id": event_id}, retcol="received_ts", desc="get_received_ts", ) @defer.inlineCallbacks - def get_event(self, event_id, check_redacted=True, - get_prev_content=False, allow_rejected=False, - allow_none=False, check_room_id=None): + def get_event( + self, + event_id, + check_redacted=True, + get_prev_content=False, + allow_rejected=False, + allow_none=False, + check_room_id=None, + ): """Get an event from the database by event_id. Args: @@ -117,8 +122,13 @@ class EventsWorkerStore(SQLBaseStore): defer.returnValue(event) @defer.inlineCallbacks - def get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): + def get_events( + self, + event_ids, + check_redacted=True, + get_prev_content=False, + allow_rejected=False, + ): """Get events from the database Args: @@ -142,8 +152,13 @@ class EventsWorkerStore(SQLBaseStore): defer.returnValue({e.event_id: e for e in events}) @defer.inlineCallbacks - def _get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): + def _get_events( + self, + event_ids, + check_redacted=True, + get_prev_content=False, + allow_rejected=False, + ): if not event_ids: defer.returnValue([]) @@ -151,8 +166,7 @@ class EventsWorkerStore(SQLBaseStore): event_ids = set(event_ids) event_entry_map = self._get_events_from_cache( - event_ids, - allow_rejected=allow_rejected, + event_ids, allow_rejected=allow_rejected ) missing_events_ids = [e for e in event_ids if e not in event_entry_map] @@ -168,8 +182,7 @@ class EventsWorkerStore(SQLBaseStore): # # _enqueue_events is a bit of a rubbish name but naming is hard. missing_events = yield self._enqueue_events( - missing_events_ids, - allow_rejected=allow_rejected, + missing_events_ids, allow_rejected=allow_rejected ) event_entry_map.update(missing_events) @@ -213,7 +226,10 @@ class EventsWorkerStore(SQLBaseStore): ) expected_domain = get_domain_from_id(entry.event.sender) - if orig_sender and get_domain_from_id(orig_sender) == expected_domain: + if ( + orig_sender + and get_domain_from_id(orig_sender) == expected_domain + ): # This redaction event is allowed. Mark as not needing a # recheck. entry.event.internal_metadata.recheck_redaction = False @@ -266,8 +282,7 @@ class EventsWorkerStore(SQLBaseStore): for event_id in events: ret = self._get_event_cache.get( - (event_id,), None, - update_metrics=update_metrics, + (event_id,), None, update_metrics=update_metrics ) if not ret: continue @@ -317,19 +332,13 @@ class EventsWorkerStore(SQLBaseStore): with Measure(self._clock, "_fetch_event_list"): try: event_id_lists = list(zip(*event_list))[0] - event_ids = [ - item for sublist in event_id_lists for item in sublist - ] + event_ids = [item for sublist in event_id_lists for item in sublist] rows = self._new_transaction( - conn, "do_fetch", [], [], - self._fetch_event_rows, event_ids, + conn, "do_fetch", [], [], self._fetch_event_rows, event_ids ) - row_dict = { - r["event_id"]: r - for r in rows - } + row_dict = {r["event_id"]: r for r in rows} # We only want to resolve deferreds from the main thread def fire(lst, res): @@ -337,13 +346,10 @@ class EventsWorkerStore(SQLBaseStore): if not d.called: try: with PreserveLoggingContext(): - d.callback([ - res[i] - for i in ids - if i in res - ]) + d.callback([res[i] for i in ids if i in res]) except Exception: logger.exception("Failed to callback") + with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, row_dict) except Exception as e: @@ -370,9 +376,7 @@ class EventsWorkerStore(SQLBaseStore): events_d = defer.Deferred() with self._event_fetch_lock: - self._event_fetch_list.append( - (events, events_d) - ) + self._event_fetch_list.append((events, events_d)) self._event_fetch_lock.notify() @@ -384,9 +388,7 @@ class EventsWorkerStore(SQLBaseStore): if should_start: run_as_background_process( - "fetch_events", - self.runWithConnection, - self._do_fetch, + "fetch_events", self.runWithConnection, self._do_fetch ) logger.debug("Loading %d events", len(events)) @@ -397,29 +399,30 @@ class EventsWorkerStore(SQLBaseStore): if not allow_rejected: rows[:] = [r for r in rows if not r["rejects"]] - res = yield make_deferred_yieldable(defer.gatherResults( - [ - run_in_background( - self._get_event_from_row, - row["internal_metadata"], row["json"], row["redacts"], - rejected_reason=row["rejects"], - format_version=row["format_version"], - ) - for row in rows - ], - consumeErrors=True - )) + res = yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self._get_event_from_row, + row["internal_metadata"], + row["json"], + row["redacts"], + rejected_reason=row["rejects"], + format_version=row["format_version"], + ) + for row in rows + ], + consumeErrors=True, + ) + ) - defer.returnValue({ - e.event.event_id: e - for e in res if e - }) + defer.returnValue({e.event.event_id: e for e in res if e}) def _fetch_event_rows(self, txn, events): rows = [] N = 200 for i in range(1 + len(events) // N): - evs = events[i * N:(i + 1) * N] + evs = events[i * N : (i + 1) * N] if not evs: break @@ -443,8 +446,9 @@ class EventsWorkerStore(SQLBaseStore): return rows @defer.inlineCallbacks - def _get_event_from_row(self, internal_metadata, js, redacted, - format_version, rejected_reason=None): + def _get_event_from_row( + self, internal_metadata, js, redacted, format_version, rejected_reason=None + ): with Measure(self._clock, "_get_event_from_row"): d = json.loads(js) internal_metadata = json.loads(internal_metadata) @@ -483,9 +487,7 @@ class EventsWorkerStore(SQLBaseStore): # Get the redaction event. because = yield self.get_event( - redaction_id, - check_redacted=False, - allow_none=True, + redaction_id, check_redacted=False, allow_none=True ) if because: @@ -507,8 +509,7 @@ class EventsWorkerStore(SQLBaseStore): redacted_event = None cache_entry = _EventCacheEntry( - event=original_ev, - redacted_event=redacted_event, + event=original_ev, redacted_event=redacted_event ) self._get_event_cache.prefill((original_ev.event_id,), cache_entry) @@ -544,23 +545,17 @@ class EventsWorkerStore(SQLBaseStore): results = set() def have_seen_events_txn(txn, chunk): - sql = ( - "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" - % (",".join("?" * len(chunk)), ) + sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % ( + ",".join("?" * len(chunk)), ) txn.execute(sql, chunk) - for (event_id, ) in txn: + for (event_id,) in txn: results.add(event_id) # break the input up into chunks of 100 input_iterator = iter(event_ids) - for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), - []): - yield self.runInteraction( - "have_seen_events", - have_seen_events_txn, - chunk, - ) + for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): + yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk) defer.returnValue(results) def get_seen_events_with_rejections(self, event_ids): |