summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/events.py542
-rw-r--r--synapse/storage/roommember.py173
2 files changed, 335 insertions, 380 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 428300ea0a..b1d5f469c8 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -30,7 +30,6 @@ from twisted.internet import defer
 import synapse.metrics
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
-# these are only included to make the type annotations work
 from synapse.events import EventBase  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -51,8 +50,11 @@ from synapse.util.metrics import Measure
 logger = logging.getLogger(__name__)
 
 persist_event_counter = Counter("synapse_storage_events_persisted_events", "")
-event_counter = Counter("synapse_storage_events_persisted_events_sep", "",
-                        ["type", "origin_type", "origin_entity"])
+event_counter = Counter(
+    "synapse_storage_events_persisted_events_sep",
+    "",
+    ["type", "origin_type", "origin_entity"],
+)
 
 # The number of times we are recalculating the current state
 state_delta_counter = Counter("synapse_storage_events_state_delta", "")
@@ -60,13 +62,15 @@ state_delta_counter = Counter("synapse_storage_events_state_delta", "")
 # The number of times we are recalculating state when there is only a
 # single forward extremity
 state_delta_single_event_counter = Counter(
-    "synapse_storage_events_state_delta_single_event", "")
+    "synapse_storage_events_state_delta_single_event", ""
+)
 
 # The number of times we are reculating state when we could have resonably
 # calculated the delta when we calculated the state for an event we were
 # persisting.
 state_delta_reuse_delta_counter = Counter(
-    "synapse_storage_events_state_delta_reuse_delta", "")
+    "synapse_storage_events_state_delta_reuse_delta", ""
+)
 
 
 def encode_json(json_object):
@@ -84,9 +88,9 @@ class _EventPeristenceQueue(object):
     concurrent transaction per room.
     """
 
-    _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", (
-        "events_and_contexts", "backfilled", "deferred",
-    ))
+    _EventPersistQueueItem = namedtuple(
+        "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
+    )
 
     def __init__(self):
         self._event_persist_queues = {}
@@ -119,11 +123,13 @@ class _EventPeristenceQueue(object):
 
         deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
 
-        queue.append(self._EventPersistQueueItem(
-            events_and_contexts=events_and_contexts,
-            backfilled=backfilled,
-            deferred=deferred,
-        ))
+        queue.append(
+            self._EventPersistQueueItem(
+                events_and_contexts=events_and_contexts,
+                backfilled=backfilled,
+                deferred=deferred,
+            )
+        )
 
         return deferred.observe()
 
@@ -191,6 +197,7 @@ def _retry_on_integrity_error(func):
     Args:
         func: function that returns a Deferred and accepts a `delete_existing` arg
     """
+
     @wraps(func)
     @defer.inlineCallbacks
     def f(self, *args, **kwargs):
@@ -206,8 +213,12 @@ def _retry_on_integrity_error(func):
 
 # inherits from EventFederationStore so that we can call _update_backward_extremities
 # and _handle_mult_prev_events (though arguably those could both be moved in here)
-class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
-                  BackgroundUpdateStore):
+class EventsStore(
+    StateGroupWorkerStore,
+    EventFederationStore,
+    EventsWorkerStore,
+    BackgroundUpdateStore,
+):
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
     EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
 
@@ -265,8 +276,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         deferreds = []
         for room_id, evs_ctxs in iteritems(partitioned):
             d = self._event_persist_queue.add_to_queue(
-                room_id, evs_ctxs,
-                backfilled=backfilled,
+                room_id, evs_ctxs, backfilled=backfilled
             )
             deferreds.append(d)
 
@@ -296,8 +306,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             and the stream ordering of the latest persisted event
         """
         deferred = self._event_persist_queue.add_to_queue(
-            event.room_id, [(event, context)],
-            backfilled=backfilled,
+            event.room_id, [(event, context)], backfilled=backfilled
         )
 
         self._maybe_start_persisting(event.room_id)
@@ -312,16 +321,16 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         def persisting_queue(item):
             with Measure(self._clock, "persist_events"):
                 yield self._persist_events(
-                    item.events_and_contexts,
-                    backfilled=item.backfilled,
+                    item.events_and_contexts, backfilled=item.backfilled
                 )
 
         self._event_persist_queue.handle_queue(room_id, persisting_queue)
 
     @_retry_on_integrity_error
     @defer.inlineCallbacks
-    def _persist_events(self, events_and_contexts, backfilled=False,
-                        delete_existing=False):
+    def _persist_events(
+        self, events_and_contexts, backfilled=False, delete_existing=False
+    ):
         """Persist events to db
 
         Args:
@@ -345,13 +354,11 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             )
 
         with stream_ordering_manager as stream_orderings:
-            for (event, context), stream, in zip(
-                events_and_contexts, stream_orderings
-            ):
+            for (event, context), stream in zip(events_and_contexts, stream_orderings):
                 event.internal_metadata.stream_ordering = stream
 
             chunks = [
-                events_and_contexts[x:x + 100]
+                events_and_contexts[x : x + 100]
                 for x in range(0, len(events_and_contexts), 100)
             ]
 
@@ -445,12 +452,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                                         state_delta_reuse_delta_counter.inc()
                                         break
 
-                            logger.info(
-                                "Calculating state delta for room %s", room_id,
-                            )
+                            logger.info("Calculating state delta for room %s", room_id)
                             with Measure(
-                                self._clock,
-                                "persist_events.get_new_state_after_events",
+                                self._clock, "persist_events.get_new_state_after_events"
                             ):
                                 res = yield self._get_new_state_after_events(
                                     room_id,
@@ -470,11 +474,10 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                                 state_delta_for_room[room_id] = ([], delta_ids)
                             elif current_state is not None:
                                 with Measure(
-                                    self._clock,
-                                    "persist_events.calculate_state_delta",
+                                    self._clock, "persist_events.calculate_state_delta"
                                 ):
                                     delta = yield self._calculate_state_delta(
-                                        room_id, current_state,
+                                        room_id, current_state
                                     )
                                 state_delta_for_room[room_id] = delta
 
@@ -498,7 +501,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                     # backfilled events have negative stream orderings, so we don't
                     # want to set the event_persisted_position to that.
                     synapse.metrics.event_persisted_position.set(
-                        chunk[-1][0].internal_metadata.stream_ordering,
+                        chunk[-1][0].internal_metadata.stream_ordering
                     )
 
                 for event, context in chunk:
@@ -515,9 +518,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                     event_counter.labels(event.type, origin_type, origin_entity).inc()
 
                 for room_id, new_state in iteritems(current_state_for_room):
-                    self.get_current_state_ids.prefill(
-                        (room_id, ), new_state
-                    )
+                    self.get_current_state_ids.prefill((room_id,), new_state)
 
                 for room_id, latest_event_ids in iteritems(new_forward_extremeties):
                     self.get_latest_event_ids_in_room.prefill(
@@ -535,8 +536,10 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         # we're only interested in new events which aren't outliers and which aren't
         # being rejected.
         new_events = [
-            event for event, ctx in event_contexts
-            if not event.internal_metadata.is_outlier() and not ctx.rejected
+            event
+            for event, ctx in event_contexts
+            if not event.internal_metadata.is_outlier()
+            and not ctx.rejected
             and not event.internal_metadata.is_soft_failed()
         ]
 
@@ -544,15 +547,11 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         result = set(latest_event_ids)
 
         # add all the new events to the list
-        result.update(
-            event.event_id for event in new_events
-        )
+        result.update(event.event_id for event in new_events)
 
         # Now remove all events which are prev_events of any of the new events
         result.difference_update(
-            e_id
-            for event in new_events
-            for e_id in event.prev_event_ids()
+            e_id for event in new_events for e_id in event.prev_event_ids()
         )
 
         # Finally, remove any events which are prev_events of any existing events.
@@ -592,17 +591,14 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             results.extend(r[0] for r in txn)
 
         for chunk in batch_iter(event_ids, 100):
-            yield self.runInteraction(
-                "_get_events_which_are_prevs",
-                _get_events,
-                chunk,
-            )
+            yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk)
 
         defer.returnValue(results)
 
     @defer.inlineCallbacks
-    def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
-                                    new_latest_event_ids):
+    def _get_new_state_after_events(
+        self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
+    ):
         """Calculate the current state dict after adding some new events to
         a room
 
@@ -642,7 +638,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                 if not ev.internal_metadata.is_outlier():
                     raise Exception(
                         "Context for new event %s has no state "
-                        "group" % (ev.event_id, ),
+                        "group" % (ev.event_id,)
                     )
                 continue
 
@@ -682,9 +678,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
         if missing_event_ids:
             # Now pull out the state groups for any missing events from DB
-            event_to_groups = yield self._get_state_group_for_events(
-                missing_event_ids,
-            )
+            event_to_groups = yield self._get_state_group_for_events(missing_event_ids)
             event_id_to_state_group.update(event_to_groups)
 
         # State groups of old_latest_event_ids
@@ -710,9 +704,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             new_state_group = next(iter(new_state_groups))
             old_state_group = next(iter(old_state_groups))
 
-            delta_ids = state_group_deltas.get(
-                (old_state_group, new_state_group,), None
-            )
+            delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
             if delta_ids is not None:
                 # We have a delta from the existing to new current state,
                 # so lets just return that. If we happen to already have
@@ -735,9 +727,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
         # Ok, we need to defer to the state handler to resolve our state sets.
 
-        state_groups = {
-            sg: state_groups_map[sg] for sg in new_state_groups
-        }
+        state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
 
         events_map = {ev.event_id: ev for ev, _ in events_context}
 
@@ -755,8 +745,11 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
         logger.debug("calling resolve_state_groups from preserve_events")
         res = yield self._state_resolution_handler.resolve_state_groups(
-            room_id, room_version, state_groups, events_map,
-            state_res_store=StateResolutionStore(self)
+            room_id,
+            room_version,
+            state_groups,
+            events_map,
+            state_res_store=StateResolutionStore(self),
         )
 
         defer.returnValue((res.state, None))
@@ -774,22 +767,26 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         """
         existing_state = yield self.get_current_state_ids(room_id)
 
-        to_delete = [
-            key for key in existing_state
-            if key not in current_state
-        ]
+        to_delete = [key for key in existing_state if key not in current_state]
 
         to_insert = {
-            key: ev_id for key, ev_id in iteritems(current_state)
+            key: ev_id
+            for key, ev_id in iteritems(current_state)
             if ev_id != existing_state.get(key)
         }
 
         defer.returnValue((to_delete, to_insert))
 
     @log_function
-    def _persist_events_txn(self, txn, events_and_contexts, backfilled,
-                            delete_existing=False, state_delta_for_room={},
-                            new_forward_extremeties={}):
+    def _persist_events_txn(
+        self,
+        txn,
+        events_and_contexts,
+        backfilled,
+        delete_existing=False,
+        state_delta_for_room={},
+        new_forward_extremeties={},
+    ):
         """Insert some number of room events into the necessary database tables.
 
         Rejected events are only inserted into the events table, the events_json table,
@@ -828,20 +825,17 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
         # Ensure that we don't have the same event twice.
         events_and_contexts = self._filter_events_and_contexts_for_duplicates(
-            events_and_contexts,
+            events_and_contexts
         )
 
         self._update_room_depths_txn(
-            txn,
-            events_and_contexts=events_and_contexts,
-            backfilled=backfilled,
+            txn, events_and_contexts=events_and_contexts, backfilled=backfilled
         )
 
         # _update_outliers_txn filters out any events which have already been
         # persisted, and returns the filtered list.
         events_and_contexts = self._update_outliers_txn(
-            txn,
-            events_and_contexts=events_and_contexts,
+            txn, events_and_contexts=events_and_contexts
         )
 
         # From this point onwards the events are only events that we haven't
@@ -852,15 +846,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             # for these events so we can reinsert them.
             # This gets around any problems with some tables already having
             # entries.
-            self._delete_existing_rows_txn(
-                txn,
-                events_and_contexts=events_and_contexts,
-            )
+            self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts)
 
-        self._store_event_txn(
-            txn,
-            events_and_contexts=events_and_contexts,
-        )
+        self._store_event_txn(txn, events_and_contexts=events_and_contexts)
 
         # Insert into event_to_state_groups.
         self._store_event_state_mappings_txn(txn, events_and_contexts)
@@ -889,8 +877,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         # _store_rejected_events_txn filters out any events which were
         # rejected, and returns the filtered list.
         events_and_contexts = self._store_rejected_events_txn(
-            txn,
-            events_and_contexts=events_and_contexts,
+            txn, events_and_contexts=events_and_contexts
         )
 
         # From this point onwards the events are only ones that weren't
@@ -920,22 +907,40 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                     WHERE room_id = ? AND type = ? AND state_key = ?
                 )
             """
-            txn.executemany(sql, (
+            txn.executemany(
+                sql,
                 (
-                    max_stream_order, room_id, etype, state_key, None,
-                    room_id, etype, state_key,
-                )
-                for etype, state_key in to_delete
-                # We sanity check that we're deleting rather than updating
-                if (etype, state_key) not in to_insert
-            ))
-            txn.executemany(sql, (
+                    (
+                        max_stream_order,
+                        room_id,
+                        etype,
+                        state_key,
+                        None,
+                        room_id,
+                        etype,
+                        state_key,
+                    )
+                    for etype, state_key in to_delete
+                    # We sanity check that we're deleting rather than updating
+                    if (etype, state_key) not in to_insert
+                ),
+            )
+            txn.executemany(
+                sql,
                 (
-                    max_stream_order, room_id, etype, state_key, ev_id,
-                    room_id, etype, state_key,
-                )
-                for (etype, state_key), ev_id in iteritems(to_insert)
-            ))
+                    (
+                        max_stream_order,
+                        room_id,
+                        etype,
+                        state_key,
+                        ev_id,
+                        room_id,
+                        etype,
+                        state_key,
+                    )
+                    for (etype, state_key), ev_id in iteritems(to_insert)
+                ),
+            )
 
             # Now we actually update the current_state_events table
 
@@ -964,7 +969,8 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
             txn.call_after(
                 self._curr_state_delta_stream_cache.entity_has_changed,
-                room_id, max_stream_order,
+                room_id,
+                max_stream_order,
             )
 
             # Invalidate the various caches
@@ -982,26 +988,20 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
             self._invalidate_state_caches_and_stream(txn, room_id, members_changed)
 
-    def _update_forward_extremities_txn(self, txn, new_forward_extremities,
-                                        max_stream_order):
+    def _update_forward_extremities_txn(
+        self, txn, new_forward_extremities, max_stream_order
+    ):
         for room_id, new_extrem in iteritems(new_forward_extremities):
             self._simple_delete_txn(
-                txn,
-                table="event_forward_extremities",
-                keyvalues={"room_id": room_id},
-            )
-            txn.call_after(
-                self.get_latest_event_ids_in_room.invalidate, (room_id,)
+                txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
             )
+            txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
 
         self._simple_insert_many_txn(
             txn,
             table="event_forward_extremities",
             values=[
-                {
-                    "event_id": ev_id,
-                    "room_id": room_id,
-                }
+                {"event_id": ev_id, "room_id": room_id}
                 for room_id, new_extrem in iteritems(new_forward_extremities)
                 for ev_id in new_extrem
             ],
@@ -1021,7 +1021,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                 }
                 for room_id, new_extrem in iteritems(new_forward_extremities)
                 for event_id in new_extrem
-            ]
+            ],
         )
 
     @classmethod
@@ -1065,7 +1065,8 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             if not backfilled:
                 txn.call_after(
                     self._events_stream_cache.entity_has_changed,
-                    event.room_id, event.internal_metadata.stream_ordering,
+                    event.room_id,
+                    event.internal_metadata.stream_ordering,
                 )
 
             if not event.internal_metadata.is_outlier() and not context.rejected:
@@ -1092,16 +1093,12 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             are already in the events table.
         """
         txn.execute(
-            "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
-                ",".join(["?"] * len(events_and_contexts)),
-            ),
-            [event.event_id for event, _ in events_and_contexts]
+            "SELECT event_id, outlier FROM events WHERE event_id in (%s)"
+            % (",".join(["?"] * len(events_and_contexts)),),
+            [event.event_id for event, _ in events_and_contexts],
         )
 
-        have_persisted = {
-            event_id: outlier
-            for event_id, outlier in txn
-        }
+        have_persisted = {event_id: outlier for event_id, outlier in txn}
 
         to_remove = set()
         for event, context in events_and_contexts:
@@ -1128,18 +1125,12 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                     logger.exception("")
                     raise
 
-                metadata_json = encode_json(
-                    event.internal_metadata.get_dict()
-                )
+                metadata_json = encode_json(event.internal_metadata.get_dict())
 
                 sql = (
-                    "UPDATE event_json SET internal_metadata = ?"
-                    " WHERE event_id = ?"
-                )
-                txn.execute(
-                    sql,
-                    (metadata_json, event.event_id,)
+                    "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?"
                 )
+                txn.execute(sql, (metadata_json, event.event_id))
 
                 # Add an entry to the ex_outlier_stream table to replicate the
                 # change in outlier status to our workers.
@@ -1152,25 +1143,17 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                         "event_stream_ordering": stream_order,
                         "event_id": event.event_id,
                         "state_group": state_group_id,
-                    }
+                    },
                 )
 
-                sql = (
-                    "UPDATE events SET outlier = ?"
-                    " WHERE event_id = ?"
-                )
-                txn.execute(
-                    sql,
-                    (False, event.event_id,)
-                )
+                sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?"
+                txn.execute(sql, (False, event.event_id))
 
                 # Update the event_backward_extremities table now that this
                 # event isn't an outlier any more.
                 self._update_backward_extremeties(txn, [event])
 
-        return [
-            ec for ec in events_and_contexts if ec[0] not in to_remove
-        ]
+        return [ec for ec in events_and_contexts if ec[0] not in to_remove]
 
     @classmethod
     def _delete_existing_rows_txn(cls, txn, events_and_contexts):
@@ -1181,39 +1164,37 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         logger.info("Deleting existing")
 
         for table in (
-                "events",
-                "event_auth",
-                "event_json",
-                "event_content_hashes",
-                "event_destinations",
-                "event_edge_hashes",
-                "event_edges",
-                "event_forward_extremities",
-                "event_reference_hashes",
-                "event_search",
-                "event_signatures",
-                "event_to_state_groups",
-                "guest_access",
-                "history_visibility",
-                "local_invites",
-                "room_names",
-                "state_events",
-                "rejections",
-                "redactions",
-                "room_memberships",
-                "topics"
+            "events",
+            "event_auth",
+            "event_json",
+            "event_content_hashes",
+            "event_destinations",
+            "event_edge_hashes",
+            "event_edges",
+            "event_forward_extremities",
+            "event_reference_hashes",
+            "event_search",
+            "event_signatures",
+            "event_to_state_groups",
+            "guest_access",
+            "history_visibility",
+            "local_invites",
+            "room_names",
+            "state_events",
+            "rejections",
+            "redactions",
+            "room_memberships",
+            "topics",
         ):
             txn.executemany(
                 "DELETE FROM %s WHERE event_id = ?" % (table,),
-                [(ev.event_id,) for ev, _ in events_and_contexts]
+                [(ev.event_id,) for ev, _ in events_and_contexts],
             )
 
-        for table in (
-            "event_push_actions",
-        ):
+        for table in ("event_push_actions",):
             txn.executemany(
                 "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
-                [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts]
+                [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts],
             )
 
     def _store_event_txn(self, txn, events_and_contexts):
@@ -1296,17 +1277,14 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         for event, context in events_and_contexts:
             if context.rejected:
                 # Insert the event_id into the rejections table
-                self._store_rejections_txn(
-                    txn, event.event_id, context.rejected
-                )
+                self._store_rejections_txn(txn, event.event_id, context.rejected)
                 to_remove.add(event)
 
-        return [
-            ec for ec in events_and_contexts if ec[0] not in to_remove
-        ]
+        return [ec for ec in events_and_contexts if ec[0] not in to_remove]
 
-    def _update_metadata_tables_txn(self, txn, events_and_contexts,
-                                    all_events_and_contexts, backfilled):
+    def _update_metadata_tables_txn(
+        self, txn, events_and_contexts, all_events_and_contexts, backfilled
+    ):
         """Update all the miscellaneous tables for new events
 
         Args:
@@ -1342,8 +1320,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
         self._handle_mult_prev_events(
-            txn,
-            events=[event for event, _ in events_and_contexts],
+            txn, events=[event for event, _ in events_and_contexts]
         )
 
         for event, _ in events_and_contexts:
@@ -1401,11 +1378,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
             state_values.append(vals)
 
-        self._simple_insert_many_txn(
-            txn,
-            table="state_events",
-            values=state_values,
-        )
+        self._simple_insert_many_txn(txn, table="state_events", values=state_values)
 
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
@@ -1416,10 +1389,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         rows = []
         N = 200
         for i in range(0, len(events_and_contexts), N):
-            ev_map = {
-                e[0].event_id: e[0]
-                for e in events_and_contexts[i:i + N]
-            }
+            ev_map = {e[0].event_id: e[0] for e in events_and_contexts[i : i + N]}
             if not ev_map:
                 break
 
@@ -1439,14 +1409,14 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             for row in rows:
                 event = ev_map[row["event_id"]]
                 if not row["rejects"] and not row["redacts"]:
-                    to_prefill.append(_EventCacheEntry(
-                        event=event,
-                        redacted_event=None,
-                    ))
+                    to_prefill.append(
+                        _EventCacheEntry(event=event, redacted_event=None)
+                    )
 
         def prefill():
             for cache_entry in to_prefill:
                 self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry)
+
         txn.call_after(prefill)
 
     def _store_redaction(self, txn, event):
@@ -1454,7 +1424,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         txn.call_after(self._invalidate_get_event_cache, event.redacts)
         txn.execute(
             "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
-            (event.event_id, event.redacts)
+            (event.event_id, event.redacts),
         )
 
     @defer.inlineCallbacks
@@ -1465,6 +1435,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         If it has been significantly less or more than one day since the last
         call to this function, it will return None.
         """
+
         def _count_messages(txn):
             sql = """
                 SELECT COALESCE(COUNT(*), 0) FROM events
@@ -1492,7 +1463,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                 AND stream_ordering > ?
             """
 
-            txn.execute(sql, (like_clause, self.stream_ordering_day_ago,))
+            txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
             count, = txn.fetchone()
             return count
 
@@ -1557,18 +1528,16 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
                 update_rows.append((sender, contains_url, event_id))
 
-            sql = (
-                "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
-            )
+            sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
 
             for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
-                clump = update_rows[index:index + INSERT_CLUMP_SIZE]
+                clump = update_rows[index : index + INSERT_CLUMP_SIZE]
                 txn.executemany(sql, clump)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
                 "max_stream_id_exclusive": min_stream_id,
-                "rows_inserted": rows_inserted + len(rows)
+                "rows_inserted": rows_inserted + len(rows),
             }
 
             self._background_update_progress_txn(
@@ -1613,10 +1582,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
             rows_to_update = []
 
-            chunks = [
-                event_ids[i:i + 100]
-                for i in range(0, len(event_ids), 100)
-            ]
+            chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
             for chunk in chunks:
                 ev_rows = self._simple_select_many_txn(
                     txn,
@@ -1639,18 +1605,16 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
                     rows_to_update.append((origin_server_ts, event_id))
 
-            sql = (
-                "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
-            )
+            sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
 
             for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
-                clump = rows_to_update[index:index + INSERT_CLUMP_SIZE]
+                clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
                 txn.executemany(sql, clump)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
                 "max_stream_id_exclusive": min_stream_id,
-                "rows_inserted": rows_inserted + len(rows_to_update)
+                "rows_inserted": rows_inserted + len(rows_to_update),
             }
 
             self._background_update_progress_txn(
@@ -1714,6 +1678,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             new_event_updates.extend(txn)
 
             return new_event_updates
+
         return self.runInteraction(
             "get_all_new_forward_event_rows", get_all_new_forward_event_rows
         )
@@ -1756,13 +1721,20 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             new_event_updates.extend(txn.fetchall())
 
             return new_event_updates
+
         return self.runInteraction(
             "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
         )
 
     @cached(num_args=5, max_entries=10)
-    def get_all_new_events(self, last_backfill_id, last_forward_id,
-                           current_backfill_id, current_forward_id, limit):
+    def get_all_new_events(
+        self,
+        last_backfill_id,
+        last_forward_id,
+        current_backfill_id,
+        current_forward_id,
+        limit,
+    ):
         """Get all the new events that have arrived at the server either as
         new events or as backfilled events"""
         have_backfill_events = last_backfill_id != current_backfill_id
@@ -1837,14 +1809,15 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                 backward_ex_outliers = []
 
             return AllNewEventsResult(
-                new_forward_events, new_backfill_events,
-                forward_ex_outliers, backward_ex_outliers,
+                new_forward_events,
+                new_backfill_events,
+                forward_ex_outliers,
+                backward_ex_outliers,
             )
+
         return self.runInteraction("get_all_new_events", get_all_new_events_txn)
 
-    def purge_history(
-        self, room_id, token, delete_local_events,
-    ):
+    def purge_history(self, room_id, token, delete_local_events):
         """Deletes room history before a certain point
 
         Args:
@@ -1860,13 +1833,13 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
         return self.runInteraction(
             "purge_history",
-            self._purge_history_txn, room_id, token,
+            self._purge_history_txn,
+            room_id,
+            token,
             delete_local_events,
         )
 
-    def _purge_history_txn(
-        self, txn, room_id, token_str, delete_local_events,
-    ):
+    def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
         token = RoomStreamToken.parse(token_str)
 
         # Tables that should be pruned:
@@ -1913,7 +1886,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             "ON e.event_id = f.event_id "
             "AND e.room_id = f.room_id "
             "WHERE f.room_id = ?",
-            (room_id,)
+            (room_id,),
         )
         rows = txn.fetchall()
         max_depth = max(row[1] for row in rows)
@@ -1934,10 +1907,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             should_delete_expr += " AND event_id NOT LIKE ?"
 
             # We include the parameter twice since we use the expression twice
-            should_delete_params += (
-                "%:" + self.hs.hostname,
-                "%:" + self.hs.hostname,
-            )
+            should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname)
 
         should_delete_params += (room_id, token.topological)
 
@@ -1948,10 +1918,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             " SELECT event_id, %s"
             " FROM events AS e LEFT JOIN state_events USING (event_id)"
             " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
-            % (
-                should_delete_expr,
-                should_delete_expr,
-            ),
+            % (should_delete_expr, should_delete_expr),
             should_delete_params,
         )
 
@@ -1961,23 +1928,19 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         # the should_delete / shouldn't_delete subsets
         txn.execute(
             "CREATE INDEX events_to_purge_should_delete"
-            " ON events_to_purge(should_delete)",
+            " ON events_to_purge(should_delete)"
         )
 
         # We do joins against events_to_purge for e.g. calculating state
         # groups to purge, etc., so lets make an index.
-        txn.execute(
-            "CREATE INDEX events_to_purge_id"
-            " ON events_to_purge(event_id)",
-        )
+        txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)")
 
-        txn.execute(
-            "SELECT event_id, should_delete FROM events_to_purge"
-        )
+        txn.execute("SELECT event_id, should_delete FROM events_to_purge")
         event_rows = txn.fetchall()
         logger.info(
             "[purge] found %i events before cutoff, of which %i can be deleted",
-            len(event_rows), sum(1 for e in event_rows if e[1]),
+            len(event_rows),
+            sum(1 for e in event_rows if e[1]),
         )
 
         logger.info("[purge] Finding new backward extremities")
@@ -1989,24 +1952,21 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
             " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
             " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
-            " WHERE ep2.event_id IS NULL",
+            " WHERE ep2.event_id IS NULL"
         )
         new_backwards_extrems = txn.fetchall()
 
         logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
 
         txn.execute(
-            "DELETE FROM event_backward_extremities WHERE room_id = ?",
-            (room_id,)
+            "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,)
         )
 
         # Update backward extremeties
         txn.executemany(
             "INSERT INTO event_backward_extremities (room_id, event_id)"
             " VALUES (?, ?)",
-            [
-                (room_id, event_id) for event_id, in new_backwards_extrems
-            ]
+            [(room_id, event_id) for event_id, in new_backwards_extrems],
         )
 
         logger.info("[purge] finding redundant state groups")
@@ -2014,28 +1974,25 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         # Get all state groups that are referenced by events that are to be
         # deleted. We then go and check if they are referenced by other events
         # or state groups, and if not we delete them.
-        txn.execute("""
+        txn.execute(
+            """
             SELECT DISTINCT state_group FROM events_to_purge
             INNER JOIN event_to_state_groups USING (event_id)
-        """)
+        """
+        )
 
         referenced_state_groups = set(sg for sg, in txn)
         logger.info(
-            "[purge] found %i referenced state groups",
-            len(referenced_state_groups),
+            "[purge] found %i referenced state groups", len(referenced_state_groups)
         )
 
         logger.info("[purge] finding state groups that can be deleted")
 
-        state_groups_to_delete, remaining_state_groups = (
-            self._find_unreferenced_groups_during_purge(
-                txn, referenced_state_groups,
-            )
-        )
+        _ = self._find_unreferenced_groups_during_purge(txn, referenced_state_groups)
+        state_groups_to_delete, remaining_state_groups = _
 
         logger.info(
-            "[purge] found %i state groups to delete",
-            len(state_groups_to_delete),
+            "[purge] found %i state groups to delete", len(state_groups_to_delete)
         )
 
         logger.info(
@@ -2047,25 +2004,15 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         # groups to non delta versions.
         for sg in remaining_state_groups:
             logger.info("[purge] de-delta-ing remaining state group %s", sg)
-            curr_state = self._get_state_groups_from_groups_txn(
-                txn, [sg],
-            )
+            curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
             curr_state = curr_state[sg]
 
             self._simple_delete_txn(
-                txn,
-                table="state_groups_state",
-                keyvalues={
-                    "state_group": sg,
-                }
+                txn, table="state_groups_state", keyvalues={"state_group": sg}
             )
 
             self._simple_delete_txn(
-                txn,
-                table="state_group_edges",
-                keyvalues={
-                    "state_group": sg,
-                }
+                txn, table="state_group_edges", keyvalues={"state_group": sg}
             )
 
             self._simple_insert_many_txn(
@@ -2099,9 +2046,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             "WHERE event_id IN (SELECT event_id from events_to_purge)"
         )
         for event_id, _ in event_rows:
-            txn.call_after(self._get_state_group_for_event.invalidate, (
-                event_id,
-            ))
+            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
 
         # Delete all remote non-state events
         for table in (
@@ -2123,21 +2068,19 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             txn.execute(
                 "DELETE FROM %s WHERE event_id IN ("
                 "    SELECT event_id FROM events_to_purge WHERE should_delete"
-                ")" % (table,),
+                ")" % (table,)
             )
 
         # event_push_actions lacks an index on event_id, and has one on
         # (room_id, event_id) instead.
-        for table in (
-            "event_push_actions",
-        ):
+        for table in ("event_push_actions",):
             logger.info("[purge] removing events from %s", table)
 
             txn.execute(
                 "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
                 "    SELECT event_id FROM events_to_purge WHERE should_delete"
                 ")" % (table,),
-                (room_id, )
+                (room_id,),
             )
 
         # Mark all state and own events as outliers
@@ -2162,27 +2105,28 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         # extremities. However, the events in event_backward_extremities
         # are ones we don't have yet so we need to look at the events that
         # point to it via event_edges table.
-        txn.execute("""
+        txn.execute(
+            """
             SELECT COALESCE(MIN(depth), 0)
             FROM event_backward_extremities AS eb
             INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
             INNER JOIN events AS e ON e.event_id = eg.event_id
             WHERE eb.room_id = ?
-        """, (room_id,))
+        """,
+            (room_id,),
+        )
         min_depth, = txn.fetchone()
 
         logger.info("[purge] updating room_depth to %d", min_depth)
 
         txn.execute(
             "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
-            (min_depth, room_id,)
+            (min_depth, room_id),
         )
 
         # finally, drop the temp table. this will commit the txn in sqlite,
         # so make sure to keep this actually last.
-        txn.execute(
-            "DROP TABLE events_to_purge"
-        )
+        txn.execute("DROP TABLE events_to_purge")
 
         logger.info("[purge] done")
 
@@ -2226,7 +2170,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                 SELECT DISTINCT state_group FROM event_to_state_groups
                 LEFT JOIN events_to_purge AS ep USING (event_id)
                 WHERE state_group IN (%s) AND ep.event_id IS NULL
-            """ % (",".join("?" for _ in current_search),)
+            """ % (
+                ",".join("?" for _ in current_search),
+            )
             txn.execute(sql, list(current_search))
 
             referenced = set(sg for sg, in txn)
@@ -2242,7 +2188,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                 column="prev_state_group",
                 iterable=current_search,
                 keyvalues={},
-                retcols=("prev_state_group", "state_group",),
+                retcols=("prev_state_group", "state_group"),
             )
 
             prevs = set(row["state_group"] for row in rows)
@@ -2279,13 +2225,15 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             table="events",
             retcols=["topological_ordering", "stream_ordering"],
             keyvalues={"event_id": event_id},
-            allow_none=True
+            allow_none=True,
         )
 
         if not res:
             raise SynapseError(404, "Could not find event %s" % (event_id,))
 
-        defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"])))
+        defer.returnValue(
+            (int(res["topological_ordering"]), int(res["stream_ordering"]))
+        )
 
     def get_max_current_state_delta_stream_id(self):
         return self._stream_id_gen.get_current_token()
@@ -2300,13 +2248,19 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             """
             txn.execute(sql, (from_token, to_token, limit))
             return txn.fetchall()
+
         return self.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
         )
 
 
-AllNewEventsResult = namedtuple("AllNewEventsResult", [
-    "new_forward_events", "new_backfill_events",
-    "forward_ex_outliers", "backward_ex_outliers",
-])
+AllNewEventsResult = namedtuple(
+    "AllNewEventsResult",
+    [
+        "new_forward_events",
+        "new_backfill_events",
+        "forward_ex_outliers",
+        "backward_ex_outliers",
+    ],
+)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 592c1bcd33..57df17bcc2 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -35,28 +35,22 @@ logger = logging.getLogger(__name__)
 
 
 RoomsForUser = namedtuple(
-    "RoomsForUser",
-    ("room_id", "sender", "membership", "event_id", "stream_ordering")
+    "RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering")
 )
 
 GetRoomsForUserWithStreamOrdering = namedtuple(
-    "_GetRoomsForUserWithStreamOrdering",
-    ("room_id", "stream_ordering",)
+    "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering")
 )
 
 
 # We store this using a namedtuple so that we save about 3x space over using a
 # dict.
-ProfileInfo = namedtuple(
-    "ProfileInfo", ("avatar_url", "display_name")
-)
+ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
 
 # "members" points to a truncated list of (user_id, event_id) tuples for users of
 # a given membership type, suitable for use in calculating heroes for a room.
 # "count" points to the total numberr of users of a given membership type.
-MemberSummary = namedtuple(
-    "MemberSummary", ("members", "count")
-)
+MemberSummary = namedtuple("MemberSummary", ("members", "count"))
 
 _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
 
@@ -67,7 +61,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         """Returns the set of all hosts currently in the room
         """
         user_ids = yield self.get_users_in_room(
-            room_id, on_invalidate=cache_context.invalidate,
+            room_id, on_invalidate=cache_context.invalidate
         )
         hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
         defer.returnValue(hosts)
@@ -84,8 +78,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
             )
 
-            txn.execute(sql, (room_id, Membership.JOIN,))
+            txn.execute(sql, (room_id, Membership.JOIN))
             return [to_ascii(r[0]) for r in txn]
+
         return self.runInteraction("get_users_in_room", f)
 
     @cached(max_entries=100000)
@@ -156,9 +151,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             A deferred list of RoomsForUser.
         """
 
-        return self.get_rooms_for_user_where_membership_is(
-            user_id, [Membership.INVITE]
-        )
+        return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
 
     @defer.inlineCallbacks
     def get_invite_for_user_in_room(self, user_id, room_id):
@@ -196,11 +189,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return self.runInteraction(
             "get_rooms_for_user_where_membership_is",
             self._get_rooms_for_user_where_membership_is_txn,
-            user_id, membership_list
+            user_id,
+            membership_list,
         )
 
-    def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
-                                                    membership_list):
+    def _get_rooms_for_user_where_membership_is_txn(
+        self, txn, user_id, membership_list
+    ):
 
         do_invite = Membership.INVITE in membership_list
         membership_list = [m for m in membership_list if m != Membership.INVITE]
@@ -227,9 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             ) % (where_clause,)
 
             txn.execute(sql, args)
-            results = [
-                RoomsForUser(**r) for r in self.cursor_to_dict(txn)
-            ]
+            results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
 
         if do_invite:
             sql = (
@@ -241,13 +234,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             )
 
             txn.execute(sql, (user_id,))
-            results.extend(RoomsForUser(
-                room_id=r["room_id"],
-                sender=r["inviter"],
-                event_id=r["event_id"],
-                stream_ordering=r["stream_ordering"],
-                membership=Membership.INVITE,
-            ) for r in self.cursor_to_dict(txn))
+            results.extend(
+                RoomsForUser(
+                    room_id=r["room_id"],
+                    sender=r["inviter"],
+                    event_id=r["event_id"],
+                    stream_ordering=r["stream_ordering"],
+                    membership=Membership.INVITE,
+                )
+                for r in self.cursor_to_dict(txn)
+            )
 
         return results
 
@@ -264,19 +260,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             of the most recent join for that user and room.
         """
         rooms = yield self.get_rooms_for_user_where_membership_is(
-            user_id, membership_list=[Membership.JOIN],
+            user_id, membership_list=[Membership.JOIN]
+        )
+        defer.returnValue(
+            frozenset(
+                GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
+                for r in rooms
+            )
         )
-        defer.returnValue(frozenset(
-            GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
-            for r in rooms
-        ))
 
     @defer.inlineCallbacks
     def get_rooms_for_user(self, user_id, on_invalidate=None):
         """Returns a set of room_ids the user is currently joined to
         """
         rooms = yield self.get_rooms_for_user_with_stream_ordering(
-            user_id, on_invalidate=on_invalidate,
+            user_id, on_invalidate=on_invalidate
         )
         defer.returnValue(frozenset(r.room_id for r in rooms))
 
@@ -285,13 +283,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         """Returns the set of users who share a room with `user_id`
         """
         room_ids = yield self.get_rooms_for_user(
-            user_id, on_invalidate=cache_context.invalidate,
+            user_id, on_invalidate=cache_context.invalidate
         )
 
         user_who_share_room = set()
         for room_id in room_ids:
             user_ids = yield self.get_users_in_room(
-                room_id, on_invalidate=cache_context.invalidate,
+                room_id, on_invalidate=cache_context.invalidate
             )
             user_who_share_room.update(user_ids)
 
@@ -309,9 +307,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         current_state_ids = yield context.get_current_state_ids(self)
         result = yield self._get_joined_users_from_context(
-            event.room_id, state_group, current_state_ids,
-            event=event,
-            context=context,
+            event.room_id, state_group, current_state_ids, event=event, context=context
         )
         defer.returnValue(result)
 
@@ -325,13 +321,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             state_group = object()
 
         return self._get_joined_users_from_context(
-            room_id, state_group, state_entry.state, context=state_entry,
+            room_id, state_group, state_entry.state, context=state_entry
         )
 
-    @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
-                           max_entries=100000)
-    def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
-                                       cache_context, event=None, context=None):
+    @cachedInlineCallbacks(
+        num_args=2, cache_context=True, iterable=True, max_entries=100000
+    )
+    def _get_joined_users_from_context(
+        self,
+        room_id,
+        state_group,
+        current_state_ids,
+        cache_context,
+        event=None,
+        context=None,
+    ):
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
         # with a state_group of None are likely to be different.
@@ -371,9 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # the hit ratio counts. After all, we don't populate the cache if we
         # miss it here
         event_map = self._get_events_from_cache(
-            member_event_ids,
-            allow_rejected=False,
-            update_metrics=False,
+            member_event_ids, allow_rejected=False, update_metrics=False
         )
 
         missing_member_event_ids = []
@@ -397,21 +399,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 table="room_memberships",
                 column="event_id",
                 iterable=missing_member_event_ids,
-                retcols=('user_id', 'display_name', 'avatar_url',),
-                keyvalues={
-                    "membership": Membership.JOIN,
-                },
+                retcols=('user_id', 'display_name', 'avatar_url'),
+                keyvalues={"membership": Membership.JOIN},
                 batch_size=500,
                 desc="_get_joined_users_from_context",
             )
 
-            users_in_room.update({
-                to_ascii(row["user_id"]): ProfileInfo(
-                    avatar_url=to_ascii(row["avatar_url"]),
-                    display_name=to_ascii(row["display_name"]),
-                )
-                for row in rows
-            })
+            users_in_room.update(
+                {
+                    to_ascii(row["user_id"]): ProfileInfo(
+                        avatar_url=to_ascii(row["avatar_url"]),
+                        display_name=to_ascii(row["display_name"]),
+                    )
+                    for row in rows
+                }
+            )
 
         if event is not None and event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
@@ -505,7 +507,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             state_group = object()
 
         return self._get_joined_hosts(
-            room_id, state_group, state_entry.state, state_entry=state_entry,
+            room_id, state_group, state_entry.state, state_entry=state_entry
         )
 
     @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
@@ -531,6 +533,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         """Returns whether user_id has elected to discard history for room_id.
 
         Returns False if they have since re-joined."""
+
         def f(txn):
             sql = (
                 "SELECT"
@@ -547,6 +550,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn.execute(sql, (user_id, room_id))
             rows = txn.fetchall()
             return rows[0][0]
+
         count = yield self.runInteraction("did_forget_membership", f)
         defer.returnValue(count == 0)
 
@@ -575,13 +579,14 @@ class RoomMemberStore(RoomMemberWorkerStore):
                     "avatar_url": event.content.get("avatar_url", None),
                 }
                 for event in events
-            ]
+            ],
         )
 
         for event in events:
             txn.call_after(
                 self._membership_stream_cache.entity_has_changed,
-                event.state_key, event.internal_metadata.stream_ordering
+                event.state_key,
+                event.internal_metadata.stream_ordering,
             )
             txn.call_after(
                 self.get_invited_rooms_for_user.invalidate, (event.state_key,)
@@ -607,7 +612,7 @@ class RoomMemberStore(RoomMemberWorkerStore):
                             "inviter": event.sender,
                             "room_id": event.room_id,
                             "stream_id": event.internal_metadata.stream_ordering,
-                        }
+                        },
                     )
                 else:
                     sql = (
@@ -616,12 +621,15 @@ class RoomMemberStore(RoomMemberWorkerStore):
                         " AND replaced_by is NULL"
                     )
 
-                    txn.execute(sql, (
-                        event.internal_metadata.stream_ordering,
-                        event.event_id,
-                        event.room_id,
-                        event.state_key,
-                    ))
+                    txn.execute(
+                        sql,
+                        (
+                            event.internal_metadata.stream_ordering,
+                            event.event_id,
+                            event.room_id,
+                            event.state_key,
+                        ),
+                    )
 
     @defer.inlineCallbacks
     def locally_reject_invite(self, user_id, room_id):
@@ -632,18 +640,14 @@ class RoomMemberStore(RoomMemberWorkerStore):
         )
 
         def f(txn, stream_ordering):
-            txn.execute(sql, (
-                stream_ordering,
-                True,
-                room_id,
-                user_id,
-            ))
+            txn.execute(sql, (stream_ordering, True, room_id, user_id))
 
         with self._stream_id_gen.get_next() as stream_ordering:
             yield self.runInteraction("locally_reject_invite", f, stream_ordering)
 
     def forget(self, user_id, room_id):
         """Indicate that user_id wishes to discard history for room_id."""
+
         def f(txn):
             sql = (
                 "UPDATE"
@@ -657,9 +661,8 @@ class RoomMemberStore(RoomMemberWorkerStore):
             )
             txn.execute(sql, (user_id, room_id))
 
-            self._invalidate_cache_and_stream(
-                txn, self.did_forget, (user_id, room_id,),
-            )
+            self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
+
         return self.runInteraction("forget_membership", f)
 
     @defer.inlineCallbacks
@@ -674,7 +677,7 @@ class RoomMemberStore(RoomMemberWorkerStore):
         INSERT_CLUMP_SIZE = 1000
 
         def add_membership_profile_txn(txn):
-            sql = ("""
+            sql = """
                 SELECT stream_ordering, event_id, events.room_id, event_json.json
                 FROM events
                 INNER JOIN event_json USING (event_id)
@@ -683,7 +686,7 @@ class RoomMemberStore(RoomMemberWorkerStore):
                 AND type = 'm.room.member'
                 ORDER BY stream_ordering DESC
                 LIMIT ?
-            """)
+            """
 
             txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
 
@@ -707,16 +710,14 @@ class RoomMemberStore(RoomMemberWorkerStore):
                 avatar_url = content.get("avatar_url", None)
 
                 if display_name or avatar_url:
-                    to_update.append((
-                        display_name, avatar_url, event_id, room_id
-                    ))
+                    to_update.append((display_name, avatar_url, event_id, room_id))
 
-            to_update_sql = ("""
+            to_update_sql = """
                 UPDATE room_memberships SET display_name = ?, avatar_url = ?
                 WHERE event_id = ? AND room_id = ?
-            """)
+            """
             for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
-                clump = to_update[index:index + INSERT_CLUMP_SIZE]
+                clump = to_update[index : index + INSERT_CLUMP_SIZE]
                 txn.executemany(to_update_sql, clump)
 
             progress = {
@@ -789,7 +790,7 @@ class _JoinedHostsCache(object):
                             self.hosts_to_joined_users.pop(host, None)
             else:
                 joined_users = yield self.store.get_joined_users_from_state(
-                    self.room_id, state_entry,
+                    self.room_id, state_entry
                 )
 
                 self.hosts_to_joined_users = {}