summary refs log tree commit diff
path: root/synapse/storage/events_worker.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/events_worker.py')
-rw-r--r--synapse/storage/events_worker.py141
1 files changed, 68 insertions, 73 deletions
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 1716be529a..663991a9b6 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -21,8 +21,9 @@ from canonicaljson import json
 
 from twisted.internet import defer
 
-from synapse.api.constants import EventFormatVersions, EventTypes
+from synapse.api.constants import EventTypes
 from synapse.api.errors import NotFoundError
+from synapse.api.room_versions import EventFormatVersions
 from synapse.events import FrozenEvent, event_type_from_format_version  # noqa: F401
 # these are only included to make the type annotations work
 from synapse.events.snapshot import EventContext  # noqa: F401
@@ -70,17 +71,21 @@ class EventsWorkerStore(SQLBaseStore):
         """
         return self._simple_select_one_onecol(
             table="events",
-            keyvalues={
-                "event_id": event_id,
-            },
+            keyvalues={"event_id": event_id},
             retcol="received_ts",
             desc="get_received_ts",
         )
 
     @defer.inlineCallbacks
-    def get_event(self, event_id, check_redacted=True,
-                  get_prev_content=False, allow_rejected=False,
-                  allow_none=False, check_room_id=None):
+    def get_event(
+        self,
+        event_id,
+        check_redacted=True,
+        get_prev_content=False,
+        allow_rejected=False,
+        allow_none=False,
+        check_room_id=None,
+    ):
         """Get an event from the database by event_id.
 
         Args:
@@ -117,8 +122,13 @@ class EventsWorkerStore(SQLBaseStore):
         defer.returnValue(event)
 
     @defer.inlineCallbacks
-    def get_events(self, event_ids, check_redacted=True,
-                   get_prev_content=False, allow_rejected=False):
+    def get_events(
+        self,
+        event_ids,
+        check_redacted=True,
+        get_prev_content=False,
+        allow_rejected=False,
+    ):
         """Get events from the database
 
         Args:
@@ -142,8 +152,13 @@ class EventsWorkerStore(SQLBaseStore):
         defer.returnValue({e.event_id: e for e in events})
 
     @defer.inlineCallbacks
-    def _get_events(self, event_ids, check_redacted=True,
-                    get_prev_content=False, allow_rejected=False):
+    def _get_events(
+        self,
+        event_ids,
+        check_redacted=True,
+        get_prev_content=False,
+        allow_rejected=False,
+    ):
         if not event_ids:
             defer.returnValue([])
 
@@ -151,8 +166,7 @@ class EventsWorkerStore(SQLBaseStore):
         event_ids = set(event_ids)
 
         event_entry_map = self._get_events_from_cache(
-            event_ids,
-            allow_rejected=allow_rejected,
+            event_ids, allow_rejected=allow_rejected
         )
 
         missing_events_ids = [e for e in event_ids if e not in event_entry_map]
@@ -168,8 +182,7 @@ class EventsWorkerStore(SQLBaseStore):
             #
             # _enqueue_events is a bit of a rubbish name but naming is hard.
             missing_events = yield self._enqueue_events(
-                missing_events_ids,
-                allow_rejected=allow_rejected,
+                missing_events_ids, allow_rejected=allow_rejected
             )
 
             event_entry_map.update(missing_events)
@@ -213,7 +226,10 @@ class EventsWorkerStore(SQLBaseStore):
                     )
 
                     expected_domain = get_domain_from_id(entry.event.sender)
-                    if orig_sender and get_domain_from_id(orig_sender) == expected_domain:
+                    if (
+                        orig_sender
+                        and get_domain_from_id(orig_sender) == expected_domain
+                    ):
                         # This redaction event is allowed. Mark as not needing a
                         # recheck.
                         entry.event.internal_metadata.recheck_redaction = False
@@ -266,8 +282,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         for event_id in events:
             ret = self._get_event_cache.get(
-                (event_id,), None,
-                update_metrics=update_metrics,
+                (event_id,), None, update_metrics=update_metrics
             )
             if not ret:
                 continue
@@ -317,19 +332,13 @@ class EventsWorkerStore(SQLBaseStore):
         with Measure(self._clock, "_fetch_event_list"):
             try:
                 event_id_lists = list(zip(*event_list))[0]
-                event_ids = [
-                    item for sublist in event_id_lists for item in sublist
-                ]
+                event_ids = [item for sublist in event_id_lists for item in sublist]
 
                 rows = self._new_transaction(
-                    conn, "do_fetch", [], [],
-                    self._fetch_event_rows, event_ids,
+                    conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
                 )
 
-                row_dict = {
-                    r["event_id"]: r
-                    for r in rows
-                }
+                row_dict = {r["event_id"]: r for r in rows}
 
                 # We only want to resolve deferreds from the main thread
                 def fire(lst, res):
@@ -337,13 +346,10 @@ class EventsWorkerStore(SQLBaseStore):
                         if not d.called:
                             try:
                                 with PreserveLoggingContext():
-                                    d.callback([
-                                        res[i]
-                                        for i in ids
-                                        if i in res
-                                    ])
+                                    d.callback([res[i] for i in ids if i in res])
                             except Exception:
                                 logger.exception("Failed to callback")
+
                 with PreserveLoggingContext():
                     self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
             except Exception as e:
@@ -370,9 +376,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         events_d = defer.Deferred()
         with self._event_fetch_lock:
-            self._event_fetch_list.append(
-                (events, events_d)
-            )
+            self._event_fetch_list.append((events, events_d))
 
             self._event_fetch_lock.notify()
 
@@ -384,9 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         if should_start:
             run_as_background_process(
-                "fetch_events",
-                self.runWithConnection,
-                self._do_fetch,
+                "fetch_events", self.runWithConnection, self._do_fetch
             )
 
         logger.debug("Loading %d events", len(events))
@@ -397,29 +399,30 @@ class EventsWorkerStore(SQLBaseStore):
         if not allow_rejected:
             rows[:] = [r for r in rows if not r["rejects"]]
 
-        res = yield make_deferred_yieldable(defer.gatherResults(
-            [
-                run_in_background(
-                    self._get_event_from_row,
-                    row["internal_metadata"], row["json"], row["redacts"],
-                    rejected_reason=row["rejects"],
-                    format_version=row["format_version"],
-                )
-                for row in rows
-            ],
-            consumeErrors=True
-        ))
+        res = yield make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_in_background(
+                        self._get_event_from_row,
+                        row["internal_metadata"],
+                        row["json"],
+                        row["redacts"],
+                        rejected_reason=row["rejects"],
+                        format_version=row["format_version"],
+                    )
+                    for row in rows
+                ],
+                consumeErrors=True,
+            )
+        )
 
-        defer.returnValue({
-            e.event.event_id: e
-            for e in res if e
-        })
+        defer.returnValue({e.event.event_id: e for e in res if e})
 
     def _fetch_event_rows(self, txn, events):
         rows = []
         N = 200
         for i in range(1 + len(events) // N):
-            evs = events[i * N:(i + 1) * N]
+            evs = events[i * N : (i + 1) * N]
             if not evs:
                 break
 
@@ -443,8 +446,9 @@ class EventsWorkerStore(SQLBaseStore):
         return rows
 
     @defer.inlineCallbacks
-    def _get_event_from_row(self, internal_metadata, js, redacted,
-                            format_version, rejected_reason=None):
+    def _get_event_from_row(
+        self, internal_metadata, js, redacted, format_version, rejected_reason=None
+    ):
         with Measure(self._clock, "_get_event_from_row"):
             d = json.loads(js)
             internal_metadata = json.loads(internal_metadata)
@@ -483,9 +487,7 @@ class EventsWorkerStore(SQLBaseStore):
                 # Get the redaction event.
 
                 because = yield self.get_event(
-                    redaction_id,
-                    check_redacted=False,
-                    allow_none=True,
+                    redaction_id, check_redacted=False, allow_none=True
                 )
 
                 if because:
@@ -507,8 +509,7 @@ class EventsWorkerStore(SQLBaseStore):
                             redacted_event = None
 
             cache_entry = _EventCacheEntry(
-                event=original_ev,
-                redacted_event=redacted_event,
+                event=original_ev, redacted_event=redacted_event
             )
 
             self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
@@ -544,23 +545,17 @@ class EventsWorkerStore(SQLBaseStore):
         results = set()
 
         def have_seen_events_txn(txn, chunk):
-            sql = (
-                "SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
-                % (",".join("?" * len(chunk)), )
+            sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % (
+                ",".join("?" * len(chunk)),
             )
             txn.execute(sql, chunk)
-            for (event_id, ) in txn:
+            for (event_id,) in txn:
                 results.add(event_id)
 
         # break the input up into chunks of 100
         input_iterator = iter(event_ids)
-        for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
-                          []):
-            yield self.runInteraction(
-                "have_seen_events",
-                have_seen_events_txn,
-                chunk,
-            )
+        for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
+            yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
         defer.returnValue(results)
 
     def get_seen_events_with_rejections(self, event_ids):