summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/4942.bugfix1
-rw-r--r--changelog.d/4947.feature1
-rw-r--r--changelog.d/4959.misc1
-rw-r--r--changelog.d/4965.misc1
-rw-r--r--synapse/federation/send_queue.py73
-rw-r--r--synapse/federation/sender/__init__.py19
-rw-r--r--synapse/handlers/presence.py176
-rw-r--r--synapse/handlers/register.py17
-rw-r--r--synapse/module_api/__init__.py9
-rw-r--r--synapse/rest/client/v1/admin.py2
-rw-r--r--synapse/storage/events.py542
-rw-r--r--synapse/storage/roommember.py173
-rw-r--r--tests/handlers/test_presence.py176
13 files changed, 772 insertions, 419 deletions
diff --git a/changelog.d/4942.bugfix b/changelog.d/4942.bugfix
new file mode 100644
index 0000000000..590d80d58f
--- /dev/null
+++ b/changelog.d/4942.bugfix
@@ -0,0 +1 @@
+Fix bug where presence updates were sent to all servers in a room when a new server joined, rather than to just the new server.
diff --git a/changelog.d/4947.feature b/changelog.d/4947.feature
new file mode 100644
index 0000000000..b9d27b90f1
--- /dev/null
+++ b/changelog.d/4947.feature
@@ -0,0 +1 @@
+Add ability for password provider modules to bind email addresses to users upon registration.
\ No newline at end of file
diff --git a/changelog.d/4959.misc b/changelog.d/4959.misc
new file mode 100644
index 0000000000..dd4275501f
--- /dev/null
+++ b/changelog.d/4959.misc
@@ -0,0 +1 @@
+Run `black` to clean up formatting on `synapse/storage/roommember.py` and `synapse/storage/events.py`.
\ No newline at end of file
diff --git a/changelog.d/4965.misc b/changelog.d/4965.misc
new file mode 100644
index 0000000000..284c58b75e
--- /dev/null
+++ b/changelog.d/4965.misc
@@ -0,0 +1 @@
+Remove log line for password via the admin API.
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 04d04a4457..0240b339b0 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -55,7 +55,12 @@ class FederationRemoteSendQueue(object):
         self.is_mine_id = hs.is_mine_id
 
         self.presence_map = {}  # Pending presence map user_id -> UserPresenceState
-        self.presence_changed = SortedDict()  # Stream position -> user_id
+        self.presence_changed = SortedDict()  # Stream position -> list[user_id]
+
+        # Stores the destinations we need to explicitly send presence to about a
+        # given user.
+        # Stream position -> (user_id, destinations)
+        self.presence_destinations = SortedDict()
 
         self.keyed_edu = {}  # (destination, key) -> EDU
         self.keyed_edu_changed = SortedDict()  # stream position -> (destination, key)
@@ -77,7 +82,7 @@ class FederationRemoteSendQueue(object):
 
         for queue_name in [
             "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
-            "edus", "device_messages", "pos_time",
+            "edus", "device_messages", "pos_time", "presence_destinations",
         ]:
             register(queue_name, getattr(self, queue_name))
 
@@ -121,6 +126,15 @@ class FederationRemoteSendQueue(object):
                 for user_id in uids
             )
 
+            keys = self.presence_destinations.keys()
+            i = self.presence_destinations.bisect_left(position_to_delete)
+            for key in keys[:i]:
+                del self.presence_destinations[key]
+
+            user_ids.update(
+                user_id for user_id, _ in self.presence_destinations.values()
+            )
+
             to_del = [
                 user_id for user_id in self.presence_map if user_id not in user_ids
             ]
@@ -209,6 +223,20 @@ class FederationRemoteSendQueue(object):
 
         self.notifier.on_new_replication_data()
 
+    def send_presence_to_destinations(self, states, destinations):
+        """As per FederationSender
+
+        Args:
+            states (list[UserPresenceState])
+            destinations (list[str])
+        """
+        for state in states:
+            pos = self._next_pos()
+            self.presence_map.update({state.user_id: state for state in states})
+            self.presence_destinations[pos] = (state.user_id, destinations)
+
+        self.notifier.on_new_replication_data()
+
     def send_device_messages(self, destination):
         """As per FederationSender"""
         pos = self._next_pos()
@@ -261,6 +289,16 @@ class FederationRemoteSendQueue(object):
                 state=self.presence_map[user_id],
             )))
 
+        # Fetch presence to send to destinations
+        i = self.presence_destinations.bisect_right(from_token)
+        j = self.presence_destinations.bisect_right(to_token) + 1
+
+        for pos, (user_id, dests) in self.presence_destinations.items()[i:j]:
+            rows.append((pos, PresenceDestinationsRow(
+                state=self.presence_map[user_id],
+                destinations=list(dests),
+            )))
+
         # Fetch changes keyed edus
         i = self.keyed_edu_changed.bisect_right(from_token)
         j = self.keyed_edu_changed.bisect_right(to_token) + 1
@@ -357,6 +395,29 @@ class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
         buff.presence.append(self.state)
 
 
+class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", (
+    "state",  # UserPresenceState
+    "destinations",  # list[str]
+))):
+    TypeId = "pd"
+
+    @staticmethod
+    def from_data(data):
+        return PresenceDestinationsRow(
+            state=UserPresenceState.from_dict(data["state"]),
+            destinations=data["dests"],
+        )
+
+    def to_data(self):
+        return {
+            "state": self.state.as_dict(),
+            "dests": self.destinations,
+        }
+
+    def add_to_buffer(self, buff):
+        buff.presence_destinations.append((self.state, self.destinations))
+
+
 class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
     "key",  # tuple(str) - the edu key passed to send_edu
     "edu",  # Edu
@@ -428,6 +489,7 @@ TypeToRow = {
     Row.TypeId: Row
     for Row in (
         PresenceRow,
+        PresenceDestinationsRow,
         KeyedEduRow,
         EduRow,
         DeviceRow,
@@ -437,6 +499,7 @@ TypeToRow = {
 
 ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
     "presence",  # list(UserPresenceState)
+    "presence_destinations",  # list of tuples of UserPresenceState and destinations
     "keyed_edus",  # dict of destination -> { key -> Edu }
     "edus",  # dict of destination -> [Edu]
     "device_destinations",  # set of destinations
@@ -458,6 +521,7 @@ def process_rows_for_federation(transaction_queue, rows):
 
     buff = ParsedFederationStreamData(
         presence=[],
+        presence_destinations=[],
         keyed_edus={},
         edus={},
         device_destinations=set(),
@@ -476,6 +540,11 @@ def process_rows_for_federation(transaction_queue, rows):
     if buff.presence:
         transaction_queue.send_presence(buff.presence)
 
+    for state, destinations in buff.presence_destinations:
+        transaction_queue.send_presence_to_destinations(
+            states=[state], destinations=destinations,
+        )
+
     for destination, edu_map in iteritems(buff.keyed_edus):
         for key, edu in edu_map.items():
             transaction_queue.send_edu(edu, key)
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 1dc041752b..4f0f939102 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -371,7 +371,7 @@ class FederationSender(object):
             return
 
         # First we queue up the new presence by user ID, so multiple presence
-        # updates in quick successtion are correctly handled
+        # updates in quick succession are correctly handled.
         # We only want to send presence for our own users, so lets always just
         # filter here just in case.
         self.pending_presence.update({
@@ -402,6 +402,23 @@ class FederationSender(object):
         finally:
             self._processing_pending_presence = False
 
+    def send_presence_to_destinations(self, states, destinations):
+        """Send the given presence states to the given destinations.
+
+        Args:
+            states (list[UserPresenceState])
+            destinations (list[str])
+        """
+
+        if not states or not self.hs.config.use_presence:
+            # No-op if presence is disabled.
+            return
+
+        for destination in destinations:
+            if destination == self.server_name:
+                continue
+            self._get_per_destination_queue(destination).send_presence(states)
+
     @measure_func("txnqueue._process_presence")
     @defer.inlineCallbacks
     def _process_presence_inner(self, states):
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 37e87fc054..e85c49742d 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -31,9 +31,11 @@ from prometheus_client import Counter
 
 from twisted.internet import defer
 
-from synapse.api.constants import PresenceState
+import synapse.metrics
+from synapse.api.constants import EventTypes, Membership, PresenceState
 from synapse.api.errors import SynapseError
 from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.presence import UserPresenceState
 from synapse.types import UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
@@ -98,6 +100,7 @@ class PresenceHandler(object):
         self.hs = hs
         self.is_mine = hs.is_mine
         self.is_mine_id = hs.is_mine_id
+        self.server_name = hs.hostname
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.wheel_timer = WheelTimer()
@@ -132,9 +135,6 @@ class PresenceHandler(object):
             )
         )
 
-        distributor = hs.get_distributor()
-        distributor.observe("user_joined_room", self.user_joined_room)
-
         active_presence = self.store.take_presence_startup_info()
 
         # A dictionary of the current state of users. This is prefilled with
@@ -220,6 +220,15 @@ class PresenceHandler(object):
         LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [],
                    lambda: len(self.wheel_timer))
 
+        # Used to handle sending of presence to newly joined users/servers
+        if hs.config.use_presence:
+            self.notifier.add_replication_callback(self.notify_new_event)
+
+        # Presence is best effort and quickly heals itself, so lets just always
+        # stream from the current state when we restart.
+        self._event_pos = self.store.get_current_events_token()
+        self._event_processing = False
+
     @defer.inlineCallbacks
     def _on_shutdown(self):
         """Gets called when shutting down. This lets us persist any updates that
@@ -751,31 +760,6 @@ class PresenceHandler(object):
         yield self._update_states([prev_state.copy_and_replace(**new_fields)])
 
     @defer.inlineCallbacks
-    def user_joined_room(self, user, room_id):
-        """Called (via the distributor) when a user joins a room. This funciton
-        sends presence updates to servers, either:
-            1. the joining user is a local user and we send their presence to
-               all servers in the room.
-            2. the joining user is a remote user and so we send presence for all
-               local users in the room.
-        """
-        # We only need to send presence to servers that don't have it yet. We
-        # don't need to send to local clients here, as that is done as part
-        # of the event stream/sync.
-        # TODO: Only send to servers not already in the room.
-        if self.is_mine(user):
-            state = yield self.current_state_for_user(user.to_string())
-
-            self._push_to_remotes([state])
-        else:
-            user_ids = yield self.store.get_users_in_room(room_id)
-            user_ids = list(filter(self.is_mine_id, user_ids))
-
-            states = yield self.current_state_for_users(user_ids)
-
-            self._push_to_remotes(list(states.values()))
-
-    @defer.inlineCallbacks
     def get_presence_list(self, observer_user, accepted=None):
         """Returns the presence for all users in their presence list.
         """
@@ -945,6 +929,140 @@ class PresenceHandler(object):
         rows = yield self.store.get_all_presence_updates(last_id, current_id)
         defer.returnValue(rows)
 
+    def notify_new_event(self):
+        """Called when new events have happened. Handles users and servers
+        joining rooms and require being sent presence.
+        """
+
+        if self._event_processing:
+            return
+
+        @defer.inlineCallbacks
+        def _process_presence():
+            assert not self._event_processing
+
+            self._event_processing = True
+            try:
+                yield self._unsafe_process()
+            finally:
+                self._event_processing = False
+
+        run_as_background_process("presence.notify_new_event", _process_presence)
+
+    @defer.inlineCallbacks
+    def _unsafe_process(self):
+        # Loop round handling deltas until we're up to date
+        while True:
+            with Measure(self.clock, "presence_delta"):
+                deltas = yield self.store.get_current_state_deltas(self._event_pos)
+                if not deltas:
+                    return
+
+                yield self._handle_state_delta(deltas)
+
+                self._event_pos = deltas[-1]["stream_id"]
+
+                # Expose current event processing position to prometheus
+                synapse.metrics.event_processing_positions.labels("presence").set(
+                    self._event_pos
+                )
+
+    @defer.inlineCallbacks
+    def _handle_state_delta(self, deltas):
+        """Process current state deltas to find new joins that need to be
+        handled.
+        """
+        for delta in deltas:
+            typ = delta["type"]
+            state_key = delta["state_key"]
+            room_id = delta["room_id"]
+            event_id = delta["event_id"]
+            prev_event_id = delta["prev_event_id"]
+
+            logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+
+            if typ != EventTypes.Member:
+                continue
+
+            event = yield self.store.get_event(event_id)
+            if event.content.get("membership") != Membership.JOIN:
+                # We only care about joins
+                continue
+
+            if prev_event_id:
+                prev_event = yield self.store.get_event(prev_event_id)
+                if prev_event.content.get("membership") == Membership.JOIN:
+                    # Ignore changes to join events.
+                    continue
+
+            yield self._on_user_joined_room(room_id, state_key)
+
+    @defer.inlineCallbacks
+    def _on_user_joined_room(self, room_id, user_id):
+        """Called when we detect a user joining the room via the current state
+        delta stream.
+
+        Args:
+            room_id (str)
+            user_id (str)
+
+        Returns:
+            Deferred
+        """
+
+        if self.is_mine_id(user_id):
+            # If this is a local user then we need to send their presence
+            # out to hosts in the room (who don't already have it)
+
+            # TODO: We should be able to filter the hosts down to those that
+            # haven't previously seen the user
+
+            state = yield self.current_state_for_user(user_id)
+            hosts = yield self.state.get_current_hosts_in_room(room_id)
+
+            # Filter out ourselves.
+            hosts = set(host for host in hosts if host != self.server_name)
+
+            self.federation.send_presence_to_destinations(
+                states=[state],
+                destinations=hosts,
+            )
+        else:
+            # A remote user has joined the room, so we need to:
+            #   1. Check if this is a new server in the room
+            #   2. If so send any presence they don't already have for
+            #      local users in the room.
+
+            # TODO: We should be able to filter the users down to those that
+            # the server hasn't previously seen
+
+            # TODO: Check that this is actually a new server joining the
+            # room.
+
+            user_ids = yield self.state.get_current_user_in_room(room_id)
+            user_ids = list(filter(self.is_mine_id, user_ids))
+
+            states = yield self.current_state_for_users(user_ids)
+
+            # Filter out old presence, i.e. offline presence states where
+            # the user hasn't been active for a week. We can change this
+            # depending on what we want the UX to be, but at the least we
+            # should filter out offline presence where the state is just the
+            # default state.
+            now = self.clock.time_msec()
+            states = [
+                state for state in states.values()
+                if state.state != PresenceState.OFFLINE
+                or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
+                or state.status_msg is not None
+            ]
+
+            if states:
+                self.federation.send_presence_to_destinations(
+                    states=states,
+                    destinations=[get_domain_from_id(user_id)],
+                )
+
 
 def should_notify(old_state, new_state):
     """Decides if a presence state change should be sent to interested parties.
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 58940e0320..a51d11a257 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -153,6 +153,7 @@ class RegistrationHandler(BaseHandler):
         user_type=None,
         default_display_name=None,
         address=None,
+        bind_emails=[],
     ):
         """Registers a new client on the server.
 
@@ -172,6 +173,7 @@ class RegistrationHandler(BaseHandler):
             default_display_name (unicode|None): if set, the new user's displayname
               will be set to this. Defaults to 'localpart'.
             address (str|None): the IP address used to perform the registration.
+            bind_emails (List[str]): list of emails to bind to this account.
         Returns:
             A tuple of (user_id, access_token).
         Raises:
@@ -261,6 +263,21 @@ class RegistrationHandler(BaseHandler):
         if not self.hs.config.user_consent_at_registration:
             yield self._auto_join_rooms(user_id)
 
+        # Bind any specified emails to this account
+        current_time = self.hs.get_clock().time_msec()
+        for email in bind_emails:
+            # generate threepid dict
+            threepid_dict = {
+                "medium": "email",
+                "address": email,
+                "validated_at": current_time,
+            }
+
+            # Bind email to new account
+            yield self._register_email_threepid(
+                user_id, threepid_dict, None, False,
+            )
+
         defer.returnValue((user_id, token))
 
     @defer.inlineCallbacks
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 235ce8334e..b3abd1b3c6 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -74,14 +74,14 @@ class ModuleApi(object):
         return self._auth_handler.check_user_exists(user_id)
 
     @defer.inlineCallbacks
-    def register(self, localpart, displayname=None):
+    def register(self, localpart, displayname=None, emails=[]):
         """Registers a new user with given localpart and optional
-           displayname.
+           displayname, emails.
 
         Args:
             localpart (str): The localpart of the new user.
-            displayname (str|None): The displayname of the new user. If None,
-                the user's displayname will default to `localpart`.
+            displayname (str|None): The displayname of the new user.
+            emails (List[str]): Emails to bind to the new user.
 
         Returns:
             Deferred: a 2-tuple of (user_id, access_token)
@@ -90,6 +90,7 @@ class ModuleApi(object):
         reg = self.hs.get_registration_handler()
         user_id, access_token = yield reg.register(
             localpart=localpart, default_display_name=displayname,
+            bind_emails=emails,
         )
 
         defer.returnValue((user_id, access_token))
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index e788769639..1a26f5a1a6 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -647,8 +647,6 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
         assert_params_in_dict(params, ["new_password"])
         new_password = params['new_password']
 
-        logger.info("new_password: %r", new_password)
-
         yield self._set_password_handler.set_password(
             target_user_id, new_password, requester
         )
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 0b0a4dcdd3..d0668e39c4 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -30,7 +30,6 @@ from twisted.internet import defer
 import synapse.metrics
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
-# these are only included to make the type annotations work
 from synapse.events import EventBase  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -51,8 +50,11 @@ from synapse.util.metrics import Measure
 logger = logging.getLogger(__name__)
 
 persist_event_counter = Counter("synapse_storage_events_persisted_events", "")
-event_counter = Counter("synapse_storage_events_persisted_events_sep", "",
-                        ["type", "origin_type", "origin_entity"])
+event_counter = Counter(
+    "synapse_storage_events_persisted_events_sep",
+    "",
+    ["type", "origin_type", "origin_entity"],
+)
 
 # The number of times we are recalculating the current state
 state_delta_counter = Counter("synapse_storage_events_state_delta", "")
@@ -60,13 +62,15 @@ state_delta_counter = Counter("synapse_storage_events_state_delta", "")
 # The number of times we are recalculating state when there is only a
 # single forward extremity
 state_delta_single_event_counter = Counter(
-    "synapse_storage_events_state_delta_single_event", "")
+    "synapse_storage_events_state_delta_single_event", ""
+)
 
 # The number of times we are reculating state when we could have resonably
 # calculated the delta when we calculated the state for an event we were
 # persisting.
 state_delta_reuse_delta_counter = Counter(
-    "synapse_storage_events_state_delta_reuse_delta", "")
+    "synapse_storage_events_state_delta_reuse_delta", ""
+)
 
 
 def encode_json(json_object):
@@ -84,9 +88,9 @@ class _EventPeristenceQueue(object):
     concurrent transaction per room.
     """
 
-    _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", (
-        "events_and_contexts", "backfilled", "deferred",
-    ))
+    _EventPersistQueueItem = namedtuple(
+        "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
+    )
 
     def __init__(self):
         self._event_persist_queues = {}
@@ -119,11 +123,13 @@ class _EventPeristenceQueue(object):
 
         deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
 
-        queue.append(self._EventPersistQueueItem(
-            events_and_contexts=events_and_contexts,
-            backfilled=backfilled,
-            deferred=deferred,
-        ))
+        queue.append(
+            self._EventPersistQueueItem(
+                events_and_contexts=events_and_contexts,
+                backfilled=backfilled,
+                deferred=deferred,
+            )
+        )
 
         return deferred.observe()
 
@@ -191,6 +197,7 @@ def _retry_on_integrity_error(func):
     Args:
         func: function that returns a Deferred and accepts a `delete_existing` arg
     """
+
     @wraps(func)
     @defer.inlineCallbacks
     def f(self, *args, **kwargs):
@@ -206,8 +213,12 @@ def _retry_on_integrity_error(func):
 
 # inherits from EventFederationStore so that we can call _update_backward_extremities
 # and _handle_mult_prev_events (though arguably those could both be moved in here)
-class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
-                  BackgroundUpdateStore):
+class EventsStore(
+    StateGroupWorkerStore,
+    EventFederationStore,
+    EventsWorkerStore,
+    BackgroundUpdateStore,
+):
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
     EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
 
@@ -265,8 +276,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         deferreds = []
         for room_id, evs_ctxs in iteritems(partitioned):
             d = self._event_persist_queue.add_to_queue(
-                room_id, evs_ctxs,
-                backfilled=backfilled,
+                room_id, evs_ctxs, backfilled=backfilled
             )
             deferreds.append(d)
 
@@ -296,8 +306,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             and the stream ordering of the latest persisted event
         """
         deferred = self._event_persist_queue.add_to_queue(
-            event.room_id, [(event, context)],
-            backfilled=backfilled,
+            event.room_id, [(event, context)], backfilled=backfilled
         )
 
         self._maybe_start_persisting(event.room_id)
@@ -312,16 +321,16 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         def persisting_queue(item):
             with Measure(self._clock, "persist_events"):
                 yield self._persist_events(
-                    item.events_and_contexts,
-                    backfilled=item.backfilled,
+                    item.events_and_contexts, backfilled=item.backfilled
                 )
 
         self._event_persist_queue.handle_queue(room_id, persisting_queue)
 
     @_retry_on_integrity_error
     @defer.inlineCallbacks
-    def _persist_events(self, events_and_contexts, backfilled=False,
-                        delete_existing=False):
+    def _persist_events(
+        self, events_and_contexts, backfilled=False, delete_existing=False
+    ):
         """Persist events to db
 
         Args:
@@ -345,13 +354,11 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             )
 
         with stream_ordering_manager as stream_orderings:
-            for (event, context), stream, in zip(
-                events_and_contexts, stream_orderings
-            ):
+            for (event, context), stream in zip(events_and_contexts, stream_orderings):
                 event.internal_metadata.stream_ordering = stream
 
             chunks = [
-                events_and_contexts[x:x + 100]
+                events_and_contexts[x : x + 100]
                 for x in range(0, len(events_and_contexts), 100)
             ]
 
@@ -445,12 +452,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                                         state_delta_reuse_delta_counter.inc()
                                         break
 
-                            logger.info(
-                                "Calculating state delta for room %s", room_id,
-                            )
+                            logger.info("Calculating state delta for room %s", room_id)
                             with Measure(
-                                self._clock,
-                                "persist_events.get_new_state_after_events",
+                                self._clock, "persist_events.get_new_state_after_events"
                             ):
                                 res = yield self._get_new_state_after_events(
                                     room_id,
@@ -470,11 +474,10 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                                 state_delta_for_room[room_id] = ([], delta_ids)
                             elif current_state is not None:
                                 with Measure(
-                                    self._clock,
-                                    "persist_events.calculate_state_delta",
+                                    self._clock, "persist_events.calculate_state_delta"
                                 ):
                                     delta = yield self._calculate_state_delta(
-                                        room_id, current_state,
+                                        room_id, current_state
                                     )
                                 state_delta_for_room[room_id] = delta
 
@@ -498,7 +501,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                     # backfilled events have negative stream orderings, so we don't
                     # want to set the event_persisted_position to that.
                     synapse.metrics.event_persisted_position.set(
-                        chunk[-1][0].internal_metadata.stream_ordering,
+                        chunk[-1][0].internal_metadata.stream_ordering
                     )
 
                 for event, context in chunk:
@@ -515,9 +518,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                     event_counter.labels(event.type, origin_type, origin_entity).inc()
 
                 for room_id, new_state in iteritems(current_state_for_room):
-                    self.get_current_state_ids.prefill(
-                        (room_id, ), new_state
-                    )
+                    self.get_current_state_ids.prefill((room_id,), new_state)
 
                 for room_id, latest_event_ids in iteritems(new_forward_extremeties):
                     self.get_latest_event_ids_in_room.prefill(
@@ -535,8 +536,10 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         # we're only interested in new events which aren't outliers and which aren't
         # being rejected.
         new_events = [
-            event for event, ctx in event_contexts
-            if not event.internal_metadata.is_outlier() and not ctx.rejected
+            event
+            for event, ctx in event_contexts
+            if not event.internal_metadata.is_outlier()
+            and not ctx.rejected
             and not event.internal_metadata.is_soft_failed()
         ]
 
@@ -544,15 +547,11 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         result = set(latest_event_ids)
 
         # add all the new events to the list
-        result.update(
-            event.event_id for event in new_events
-        )
+        result.update(event.event_id for event in new_events)
 
         # Now remove all events which are prev_events of any of the new events
         result.difference_update(
-            e_id
-            for event in new_events
-            for e_id in event.prev_event_ids()
+            e_id for event in new_events for e_id in event.prev_event_ids()
         )
 
         # Finally, remove any events which are prev_events of any existing events.
@@ -592,17 +591,14 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             results.extend(r[0] for r in txn)
 
         for chunk in batch_iter(event_ids, 100):
-            yield self.runInteraction(
-                "_get_events_which_are_prevs",
-                _get_events,
-                chunk,
-            )
+            yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk)
 
         defer.returnValue(results)
 
     @defer.inlineCallbacks
-    def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
-                                    new_latest_event_ids):
+    def _get_new_state_after_events(
+        self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
+    ):
         """Calculate the current state dict after adding some new events to
         a room
 
@@ -642,7 +638,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                 if not ev.internal_metadata.is_outlier():
                     raise Exception(
                         "Context for new event %s has no state "
-                        "group" % (ev.event_id, ),
+                        "group" % (ev.event_id,)
                     )
                 continue
 
@@ -682,9 +678,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
         if missing_event_ids:
             # Now pull out the state groups for any missing events from DB
-            event_to_groups = yield self._get_state_group_for_events(
-                missing_event_ids,
-            )
+            event_to_groups = yield self._get_state_group_for_events(missing_event_ids)
             event_id_to_state_group.update(event_to_groups)
 
         # State groups of old_latest_event_ids
@@ -710,9 +704,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             new_state_group = next(iter(new_state_groups))
             old_state_group = next(iter(old_state_groups))
 
-            delta_ids = state_group_deltas.get(
-                (old_state_group, new_state_group,), None
-            )
+            delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
             if delta_ids is not None:
                 # We have a delta from the existing to new current state,
                 # so lets just return that. If we happen to already have
@@ -735,9 +727,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
         # Ok, we need to defer to the state handler to resolve our state sets.
 
-        state_groups = {
-            sg: state_groups_map[sg] for sg in new_state_groups
-        }
+        state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
 
         events_map = {ev.event_id: ev for ev, _ in events_context}
 
@@ -755,8 +745,11 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
         logger.debug("calling resolve_state_groups from preserve_events")
         res = yield self._state_resolution_handler.resolve_state_groups(
-            room_id, room_version, state_groups, events_map,
-            state_res_store=StateResolutionStore(self)
+            room_id,
+            room_version,
+            state_groups,
+            events_map,
+            state_res_store=StateResolutionStore(self),
         )
 
         defer.returnValue((res.state, None))
@@ -774,22 +767,26 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         """
         existing_state = yield self.get_current_state_ids(room_id)
 
-        to_delete = [
-            key for key in existing_state
-            if key not in current_state
-        ]
+        to_delete = [key for key in existing_state if key not in current_state]
 
         to_insert = {
-            key: ev_id for key, ev_id in iteritems(current_state)
+            key: ev_id
+            for key, ev_id in iteritems(current_state)
             if ev_id != existing_state.get(key)
         }
 
         defer.returnValue((to_delete, to_insert))
 
     @log_function
-    def _persist_events_txn(self, txn, events_and_contexts, backfilled,
-                            delete_existing=False, state_delta_for_room={},
-                            new_forward_extremeties={}):
+    def _persist_events_txn(
+        self,
+        txn,
+        events_and_contexts,
+        backfilled,
+        delete_existing=False,
+        state_delta_for_room={},
+        new_forward_extremeties={},
+    ):
         """Insert some number of room events into the necessary database tables.
 
         Rejected events are only inserted into the events table, the events_json table,
@@ -828,20 +825,17 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
         # Ensure that we don't have the same event twice.
         events_and_contexts = self._filter_events_and_contexts_for_duplicates(
-            events_and_contexts,
+            events_and_contexts
         )
 
         self._update_room_depths_txn(
-            txn,
-            events_and_contexts=events_and_contexts,
-            backfilled=backfilled,
+            txn, events_and_contexts=events_and_contexts, backfilled=backfilled
         )
 
         # _update_outliers_txn filters out any events which have already been
         # persisted, and returns the filtered list.
         events_and_contexts = self._update_outliers_txn(
-            txn,
-            events_and_contexts=events_and_contexts,
+            txn, events_and_contexts=events_and_contexts
         )
 
         # From this point onwards the events are only events that we haven't
@@ -852,15 +846,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             # for these events so we can reinsert them.
             # This gets around any problems with some tables already having
             # entries.
-            self._delete_existing_rows_txn(
-                txn,
-                events_and_contexts=events_and_contexts,
-            )
+            self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts)
 
-        self._store_event_txn(
-            txn,
-            events_and_contexts=events_and_contexts,
-        )
+        self._store_event_txn(txn, events_and_contexts=events_and_contexts)
 
         # Insert into event_to_state_groups.
         self._store_event_state_mappings_txn(txn, events_and_contexts)
@@ -889,8 +877,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         # _store_rejected_events_txn filters out any events which were
         # rejected, and returns the filtered list.
         events_and_contexts = self._store_rejected_events_txn(
-            txn,
-            events_and_contexts=events_and_contexts,
+            txn, events_and_contexts=events_and_contexts
         )
 
         # From this point onwards the events are only ones that weren't
@@ -920,22 +907,40 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                     WHERE room_id = ? AND type = ? AND state_key = ?
                 )
             """
-            txn.executemany(sql, (
+            txn.executemany(
+                sql,
                 (
-                    max_stream_order, room_id, etype, state_key, None,
-                    room_id, etype, state_key,
-                )
-                for etype, state_key in to_delete
-                # We sanity check that we're deleting rather than updating
-                if (etype, state_key) not in to_insert
-            ))
-            txn.executemany(sql, (
+                    (
+                        max_stream_order,
+                        room_id,
+                        etype,
+                        state_key,
+                        None,
+                        room_id,
+                        etype,
+                        state_key,
+                    )
+                    for etype, state_key in to_delete
+                    # We sanity check that we're deleting rather than updating
+                    if (etype, state_key) not in to_insert
+                ),
+            )
+            txn.executemany(
+                sql,
                 (
-                    max_stream_order, room_id, etype, state_key, ev_id,
-                    room_id, etype, state_key,
-                )
-                for (etype, state_key), ev_id in iteritems(to_insert)
-            ))
+                    (
+                        max_stream_order,
+                        room_id,
+                        etype,
+                        state_key,
+                        ev_id,
+                        room_id,
+                        etype,
+                        state_key,
+                    )
+                    for (etype, state_key), ev_id in iteritems(to_insert)
+                ),
+            )
 
             # Now we actually update the current_state_events table
 
@@ -964,7 +969,8 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
             txn.call_after(
                 self._curr_state_delta_stream_cache.entity_has_changed,
-                room_id, max_stream_order,
+                room_id,
+                max_stream_order,
             )
 
             # Invalidate the various caches
@@ -982,26 +988,20 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
             self._invalidate_state_caches_and_stream(txn, room_id, members_changed)
 
-    def _update_forward_extremities_txn(self, txn, new_forward_extremities,
-                                        max_stream_order):
+    def _update_forward_extremities_txn(
+        self, txn, new_forward_extremities, max_stream_order
+    ):
         for room_id, new_extrem in iteritems(new_forward_extremities):
             self._simple_delete_txn(
-                txn,
-                table="event_forward_extremities",
-                keyvalues={"room_id": room_id},
-            )
-            txn.call_after(
-                self.get_latest_event_ids_in_room.invalidate, (room_id,)
+                txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
             )
+            txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
 
         self._simple_insert_many_txn(
             txn,
             table="event_forward_extremities",
             values=[
-                {
-                    "event_id": ev_id,
-                    "room_id": room_id,
-                }
+                {"event_id": ev_id, "room_id": room_id}
                 for room_id, new_extrem in iteritems(new_forward_extremities)
                 for ev_id in new_extrem
             ],
@@ -1021,7 +1021,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                 }
                 for room_id, new_extrem in iteritems(new_forward_extremities)
                 for event_id in new_extrem
-            ]
+            ],
         )
 
     @classmethod
@@ -1065,7 +1065,8 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             if not backfilled:
                 txn.call_after(
                     self._events_stream_cache.entity_has_changed,
-                    event.room_id, event.internal_metadata.stream_ordering,
+                    event.room_id,
+                    event.internal_metadata.stream_ordering,
                 )
 
             if not event.internal_metadata.is_outlier() and not context.rejected:
@@ -1092,16 +1093,12 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             are already in the events table.
         """
         txn.execute(
-            "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
-                ",".join(["?"] * len(events_and_contexts)),
-            ),
-            [event.event_id for event, _ in events_and_contexts]
+            "SELECT event_id, outlier FROM events WHERE event_id in (%s)"
+            % (",".join(["?"] * len(events_and_contexts)),),
+            [event.event_id for event, _ in events_and_contexts],
         )
 
-        have_persisted = {
-            event_id: outlier
-            for event_id, outlier in txn
-        }
+        have_persisted = {event_id: outlier for event_id, outlier in txn}
 
         to_remove = set()
         for event, context in events_and_contexts:
@@ -1128,18 +1125,12 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                     logger.exception("")
                     raise
 
-                metadata_json = encode_json(
-                    event.internal_metadata.get_dict()
-                )
+                metadata_json = encode_json(event.internal_metadata.get_dict())
 
                 sql = (
-                    "UPDATE event_json SET internal_metadata = ?"
-                    " WHERE event_id = ?"
-                )
-                txn.execute(
-                    sql,
-                    (metadata_json, event.event_id,)
+                    "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?"
                 )
+                txn.execute(sql, (metadata_json, event.event_id))
 
                 # Add an entry to the ex_outlier_stream table to replicate the
                 # change in outlier status to our workers.
@@ -1152,25 +1143,17 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                         "event_stream_ordering": stream_order,
                         "event_id": event.event_id,
                         "state_group": state_group_id,
-                    }
+                    },
                 )
 
-                sql = (
-                    "UPDATE events SET outlier = ?"
-                    " WHERE event_id = ?"
-                )
-                txn.execute(
-                    sql,
-                    (False, event.event_id,)
-                )
+                sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?"
+                txn.execute(sql, (False, event.event_id))
 
                 # Update the event_backward_extremities table now that this
                 # event isn't an outlier any more.
                 self._update_backward_extremeties(txn, [event])
 
-        return [
-            ec for ec in events_and_contexts if ec[0] not in to_remove
-        ]
+        return [ec for ec in events_and_contexts if ec[0] not in to_remove]
 
     @classmethod
     def _delete_existing_rows_txn(cls, txn, events_and_contexts):
@@ -1181,39 +1164,37 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         logger.info("Deleting existing")
 
         for table in (
-                "events",
-                "event_auth",
-                "event_json",
-                "event_content_hashes",
-                "event_destinations",
-                "event_edge_hashes",
-                "event_edges",
-                "event_forward_extremities",
-                "event_reference_hashes",
-                "event_search",
-                "event_signatures",
-                "event_to_state_groups",
-                "guest_access",
-                "history_visibility",
-                "local_invites",
-                "room_names",
-                "state_events",
-                "rejections",
-                "redactions",
-                "room_memberships",
-                "topics"
+            "events",
+            "event_auth",
+            "event_json",
+            "event_content_hashes",
+            "event_destinations",
+            "event_edge_hashes",
+            "event_edges",
+            "event_forward_extremities",
+            "event_reference_hashes",
+            "event_search",
+            "event_signatures",
+            "event_to_state_groups",
+            "guest_access",
+            "history_visibility",
+            "local_invites",
+            "room_names",
+            "state_events",
+            "rejections",
+            "redactions",
+            "room_memberships",
+            "topics",
         ):
             txn.executemany(
                 "DELETE FROM %s WHERE event_id = ?" % (table,),
-                [(ev.event_id,) for ev, _ in events_and_contexts]
+                [(ev.event_id,) for ev, _ in events_and_contexts],
             )
 
-        for table in (
-            "event_push_actions",
-        ):
+        for table in ("event_push_actions",):
             txn.executemany(
                 "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
-                [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts]
+                [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts],
             )
 
     def _store_event_txn(self, txn, events_and_contexts):
@@ -1296,17 +1277,14 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         for event, context in events_and_contexts:
             if context.rejected:
                 # Insert the event_id into the rejections table
-                self._store_rejections_txn(
-                    txn, event.event_id, context.rejected
-                )
+                self._store_rejections_txn(txn, event.event_id, context.rejected)
                 to_remove.add(event)
 
-        return [
-            ec for ec in events_and_contexts if ec[0] not in to_remove
-        ]
+        return [ec for ec in events_and_contexts if ec[0] not in to_remove]
 
-    def _update_metadata_tables_txn(self, txn, events_and_contexts,
-                                    all_events_and_contexts, backfilled):
+    def _update_metadata_tables_txn(
+        self, txn, events_and_contexts, all_events_and_contexts, backfilled
+    ):
         """Update all the miscellaneous tables for new events
 
         Args:
@@ -1342,8 +1320,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
         self._handle_mult_prev_events(
-            txn,
-            events=[event for event, _ in events_and_contexts],
+            txn, events=[event for event, _ in events_and_contexts]
         )
 
         for event, _ in events_and_contexts:
@@ -1401,11 +1378,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
             state_values.append(vals)
 
-        self._simple_insert_many_txn(
-            txn,
-            table="state_events",
-            values=state_values,
-        )
+        self._simple_insert_many_txn(txn, table="state_events", values=state_values)
 
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
@@ -1416,10 +1389,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         rows = []
         N = 200
         for i in range(0, len(events_and_contexts), N):
-            ev_map = {
-                e[0].event_id: e[0]
-                for e in events_and_contexts[i:i + N]
-            }
+            ev_map = {e[0].event_id: e[0] for e in events_and_contexts[i : i + N]}
             if not ev_map:
                 break
 
@@ -1439,14 +1409,14 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             for row in rows:
                 event = ev_map[row["event_id"]]
                 if not row["rejects"] and not row["redacts"]:
-                    to_prefill.append(_EventCacheEntry(
-                        event=event,
-                        redacted_event=None,
-                    ))
+                    to_prefill.append(
+                        _EventCacheEntry(event=event, redacted_event=None)
+                    )
 
         def prefill():
             for cache_entry in to_prefill:
                 self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry)
+
         txn.call_after(prefill)
 
     def _store_redaction(self, txn, event):
@@ -1454,7 +1424,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         txn.call_after(self._invalidate_get_event_cache, event.redacts)
         txn.execute(
             "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
-            (event.event_id, event.redacts)
+            (event.event_id, event.redacts),
         )
 
     @defer.inlineCallbacks
@@ -1465,6 +1435,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         If it has been significantly less or more than one day since the last
         call to this function, it will return None.
         """
+
         def _count_messages(txn):
             sql = """
                 SELECT COALESCE(COUNT(*), 0) FROM events
@@ -1492,7 +1463,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                 AND stream_ordering > ?
             """
 
-            txn.execute(sql, (like_clause, self.stream_ordering_day_ago,))
+            txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
             count, = txn.fetchone()
             return count
 
@@ -1557,18 +1528,16 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
                 update_rows.append((sender, contains_url, event_id))
 
-            sql = (
-                "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
-            )
+            sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
 
             for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
-                clump = update_rows[index:index + INSERT_CLUMP_SIZE]
+                clump = update_rows[index : index + INSERT_CLUMP_SIZE]
                 txn.executemany(sql, clump)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
                 "max_stream_id_exclusive": min_stream_id,
-                "rows_inserted": rows_inserted + len(rows)
+                "rows_inserted": rows_inserted + len(rows),
             }
 
             self._background_update_progress_txn(
@@ -1613,10 +1582,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
             rows_to_update = []
 
-            chunks = [
-                event_ids[i:i + 100]
-                for i in range(0, len(event_ids), 100)
-            ]
+            chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
             for chunk in chunks:
                 ev_rows = self._simple_select_many_txn(
                     txn,
@@ -1639,18 +1605,16 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
                     rows_to_update.append((origin_server_ts, event_id))
 
-            sql = (
-                "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
-            )
+            sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
 
             for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
-                clump = rows_to_update[index:index + INSERT_CLUMP_SIZE]
+                clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
                 txn.executemany(sql, clump)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
                 "max_stream_id_exclusive": min_stream_id,
-                "rows_inserted": rows_inserted + len(rows_to_update)
+                "rows_inserted": rows_inserted + len(rows_to_update),
             }
 
             self._background_update_progress_txn(
@@ -1714,6 +1678,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             new_event_updates.extend(txn)
 
             return new_event_updates
+
         return self.runInteraction(
             "get_all_new_forward_event_rows", get_all_new_forward_event_rows
         )
@@ -1756,13 +1721,20 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             new_event_updates.extend(txn.fetchall())
 
             return new_event_updates
+
         return self.runInteraction(
             "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
         )
 
     @cached(num_args=5, max_entries=10)
-    def get_all_new_events(self, last_backfill_id, last_forward_id,
-                           current_backfill_id, current_forward_id, limit):
+    def get_all_new_events(
+        self,
+        last_backfill_id,
+        last_forward_id,
+        current_backfill_id,
+        current_forward_id,
+        limit,
+    ):
         """Get all the new events that have arrived at the server either as
         new events or as backfilled events"""
         have_backfill_events = last_backfill_id != current_backfill_id
@@ -1837,14 +1809,15 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                 backward_ex_outliers = []
 
             return AllNewEventsResult(
-                new_forward_events, new_backfill_events,
-                forward_ex_outliers, backward_ex_outliers,
+                new_forward_events,
+                new_backfill_events,
+                forward_ex_outliers,
+                backward_ex_outliers,
             )
+
         return self.runInteraction("get_all_new_events", get_all_new_events_txn)
 
-    def purge_history(
-        self, room_id, token, delete_local_events,
-    ):
+    def purge_history(self, room_id, token, delete_local_events):
         """Deletes room history before a certain point
 
         Args:
@@ -1860,13 +1833,13 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
 
         return self.runInteraction(
             "purge_history",
-            self._purge_history_txn, room_id, token,
+            self._purge_history_txn,
+            room_id,
+            token,
             delete_local_events,
         )
 
-    def _purge_history_txn(
-        self, txn, room_id, token_str, delete_local_events,
-    ):
+    def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
         token = RoomStreamToken.parse(token_str)
 
         # Tables that should be pruned:
@@ -1913,7 +1886,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             "ON e.event_id = f.event_id "
             "AND e.room_id = f.room_id "
             "WHERE f.room_id = ?",
-            (room_id,)
+            (room_id,),
         )
         rows = txn.fetchall()
         max_depth = max(row[1] for row in rows)
@@ -1934,10 +1907,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             should_delete_expr += " AND event_id NOT LIKE ?"
 
             # We include the parameter twice since we use the expression twice
-            should_delete_params += (
-                "%:" + self.hs.hostname,
-                "%:" + self.hs.hostname,
-            )
+            should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname)
 
         should_delete_params += (room_id, token.topological)
 
@@ -1948,10 +1918,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             " SELECT event_id, %s"
             " FROM events AS e LEFT JOIN state_events USING (event_id)"
             " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
-            % (
-                should_delete_expr,
-                should_delete_expr,
-            ),
+            % (should_delete_expr, should_delete_expr),
             should_delete_params,
         )
 
@@ -1961,23 +1928,19 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         # the should_delete / shouldn't_delete subsets
         txn.execute(
             "CREATE INDEX events_to_purge_should_delete"
-            " ON events_to_purge(should_delete)",
+            " ON events_to_purge(should_delete)"
         )
 
         # We do joins against events_to_purge for e.g. calculating state
         # groups to purge, etc., so lets make an index.
-        txn.execute(
-            "CREATE INDEX events_to_purge_id"
-            " ON events_to_purge(event_id)",
-        )
+        txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)")
 
-        txn.execute(
-            "SELECT event_id, should_delete FROM events_to_purge"
-        )
+        txn.execute("SELECT event_id, should_delete FROM events_to_purge")
         event_rows = txn.fetchall()
         logger.info(
             "[purge] found %i events before cutoff, of which %i can be deleted",
-            len(event_rows), sum(1 for e in event_rows if e[1]),
+            len(event_rows),
+            sum(1 for e in event_rows if e[1]),
         )
 
         logger.info("[purge] Finding new backward extremities")
@@ -1989,24 +1952,21 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
             " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
             " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
-            " WHERE ep2.event_id IS NULL",
+            " WHERE ep2.event_id IS NULL"
         )
         new_backwards_extrems = txn.fetchall()
 
         logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
 
         txn.execute(
-            "DELETE FROM event_backward_extremities WHERE room_id = ?",
-            (room_id,)
+            "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,)
         )
 
         # Update backward extremeties
         txn.executemany(
             "INSERT INTO event_backward_extremities (room_id, event_id)"
             " VALUES (?, ?)",
-            [
-                (room_id, event_id) for event_id, in new_backwards_extrems
-            ]
+            [(room_id, event_id) for event_id, in new_backwards_extrems],
         )
 
         logger.info("[purge] finding redundant state groups")
@@ -2014,28 +1974,25 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         # Get all state groups that are referenced by events that are to be
         # deleted. We then go and check if they are referenced by other events
         # or state groups, and if not we delete them.
-        txn.execute("""
+        txn.execute(
+            """
             SELECT DISTINCT state_group FROM events_to_purge
             INNER JOIN event_to_state_groups USING (event_id)
-        """)
+        """
+        )
 
         referenced_state_groups = set(sg for sg, in txn)
         logger.info(
-            "[purge] found %i referenced state groups",
-            len(referenced_state_groups),
+            "[purge] found %i referenced state groups", len(referenced_state_groups)
         )
 
         logger.info("[purge] finding state groups that can be deleted")
 
-        state_groups_to_delete, remaining_state_groups = (
-            self._find_unreferenced_groups_during_purge(
-                txn, referenced_state_groups,
-            )
-        )
+        _ = self._find_unreferenced_groups_during_purge(txn, referenced_state_groups)
+        state_groups_to_delete, remaining_state_groups = _
 
         logger.info(
-            "[purge] found %i state groups to delete",
-            len(state_groups_to_delete),
+            "[purge] found %i state groups to delete", len(state_groups_to_delete)
         )
 
         logger.info(
@@ -2047,25 +2004,15 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         # groups to non delta versions.
         for sg in remaining_state_groups:
             logger.info("[purge] de-delta-ing remaining state group %s", sg)
-            curr_state = self._get_state_groups_from_groups_txn(
-                txn, [sg],
-            )
+            curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
             curr_state = curr_state[sg]
 
             self._simple_delete_txn(
-                txn,
-                table="state_groups_state",
-                keyvalues={
-                    "state_group": sg,
-                }
+                txn, table="state_groups_state", keyvalues={"state_group": sg}
             )
 
             self._simple_delete_txn(
-                txn,
-                table="state_group_edges",
-                keyvalues={
-                    "state_group": sg,
-                }
+                txn, table="state_group_edges", keyvalues={"state_group": sg}
             )
 
             self._simple_insert_many_txn(
@@ -2099,9 +2046,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             "WHERE event_id IN (SELECT event_id from events_to_purge)"
         )
         for event_id, _ in event_rows:
-            txn.call_after(self._get_state_group_for_event.invalidate, (
-                event_id,
-            ))
+            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
 
         # Delete all remote non-state events
         for table in (
@@ -2123,21 +2068,19 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             txn.execute(
                 "DELETE FROM %s WHERE event_id IN ("
                 "    SELECT event_id FROM events_to_purge WHERE should_delete"
-                ")" % (table,),
+                ")" % (table,)
             )
 
         # event_push_actions lacks an index on event_id, and has one on
         # (room_id, event_id) instead.
-        for table in (
-            "event_push_actions",
-        ):
+        for table in ("event_push_actions",):
             logger.info("[purge] removing events from %s", table)
 
             txn.execute(
                 "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
                 "    SELECT event_id FROM events_to_purge WHERE should_delete"
                 ")" % (table,),
-                (room_id, )
+                (room_id,),
             )
 
         # Mark all state and own events as outliers
@@ -2162,27 +2105,28 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
         # extremities. However, the events in event_backward_extremities
         # are ones we don't have yet so we need to look at the events that
         # point to it via event_edges table.
-        txn.execute("""
+        txn.execute(
+            """
             SELECT COALESCE(MIN(depth), 0)
             FROM event_backward_extremities AS eb
             INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
             INNER JOIN events AS e ON e.event_id = eg.event_id
             WHERE eb.room_id = ?
-        """, (room_id,))
+        """,
+            (room_id,),
+        )
         min_depth, = txn.fetchone()
 
         logger.info("[purge] updating room_depth to %d", min_depth)
 
         txn.execute(
             "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
-            (min_depth, room_id,)
+            (min_depth, room_id),
         )
 
         # finally, drop the temp table. this will commit the txn in sqlite,
         # so make sure to keep this actually last.
-        txn.execute(
-            "DROP TABLE events_to_purge"
-        )
+        txn.execute("DROP TABLE events_to_purge")
 
         logger.info("[purge] done")
 
@@ -2226,7 +2170,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                 SELECT DISTINCT state_group FROM event_to_state_groups
                 LEFT JOIN events_to_purge AS ep USING (event_id)
                 WHERE state_group IN (%s) AND ep.event_id IS NULL
-            """ % (",".join("?" for _ in current_search),)
+            """ % (
+                ",".join("?" for _ in current_search),
+            )
             txn.execute(sql, list(current_search))
 
             referenced = set(sg for sg, in txn)
@@ -2242,7 +2188,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
                 column="prev_state_group",
                 iterable=current_search,
                 keyvalues={},
-                retcols=("prev_state_group", "state_group",),
+                retcols=("prev_state_group", "state_group"),
             )
 
             prevs = set(row["state_group"] for row in rows)
@@ -2279,13 +2225,15 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             table="events",
             retcols=["topological_ordering", "stream_ordering"],
             keyvalues={"event_id": event_id},
-            allow_none=True
+            allow_none=True,
         )
 
         if not res:
             raise SynapseError(404, "Could not find event %s" % (event_id,))
 
-        defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"])))
+        defer.returnValue(
+            (int(res["topological_ordering"]), int(res["stream_ordering"]))
+        )
 
     def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
         def get_all_updated_current_state_deltas_txn(txn):
@@ -2297,13 +2245,19 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
             """
             txn.execute(sql, (from_token, to_token, limit))
             return txn.fetchall()
+
         return self.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
         )
 
 
-AllNewEventsResult = namedtuple("AllNewEventsResult", [
-    "new_forward_events", "new_backfill_events",
-    "forward_ex_outliers", "backward_ex_outliers",
-])
+AllNewEventsResult = namedtuple(
+    "AllNewEventsResult",
+    [
+        "new_forward_events",
+        "new_backfill_events",
+        "forward_ex_outliers",
+        "backward_ex_outliers",
+    ],
+)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 592c1bcd33..57df17bcc2 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -35,28 +35,22 @@ logger = logging.getLogger(__name__)
 
 
 RoomsForUser = namedtuple(
-    "RoomsForUser",
-    ("room_id", "sender", "membership", "event_id", "stream_ordering")
+    "RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering")
 )
 
 GetRoomsForUserWithStreamOrdering = namedtuple(
-    "_GetRoomsForUserWithStreamOrdering",
-    ("room_id", "stream_ordering",)
+    "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering")
 )
 
 
 # We store this using a namedtuple so that we save about 3x space over using a
 # dict.
-ProfileInfo = namedtuple(
-    "ProfileInfo", ("avatar_url", "display_name")
-)
+ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
 
 # "members" points to a truncated list of (user_id, event_id) tuples for users of
 # a given membership type, suitable for use in calculating heroes for a room.
 # "count" points to the total numberr of users of a given membership type.
-MemberSummary = namedtuple(
-    "MemberSummary", ("members", "count")
-)
+MemberSummary = namedtuple("MemberSummary", ("members", "count"))
 
 _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
 
@@ -67,7 +61,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         """Returns the set of all hosts currently in the room
         """
         user_ids = yield self.get_users_in_room(
-            room_id, on_invalidate=cache_context.invalidate,
+            room_id, on_invalidate=cache_context.invalidate
         )
         hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
         defer.returnValue(hosts)
@@ -84,8 +78,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
             )
 
-            txn.execute(sql, (room_id, Membership.JOIN,))
+            txn.execute(sql, (room_id, Membership.JOIN))
             return [to_ascii(r[0]) for r in txn]
+
         return self.runInteraction("get_users_in_room", f)
 
     @cached(max_entries=100000)
@@ -156,9 +151,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             A deferred list of RoomsForUser.
         """
 
-        return self.get_rooms_for_user_where_membership_is(
-            user_id, [Membership.INVITE]
-        )
+        return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
 
     @defer.inlineCallbacks
     def get_invite_for_user_in_room(self, user_id, room_id):
@@ -196,11 +189,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return self.runInteraction(
             "get_rooms_for_user_where_membership_is",
             self._get_rooms_for_user_where_membership_is_txn,
-            user_id, membership_list
+            user_id,
+            membership_list,
         )
 
-    def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
-                                                    membership_list):
+    def _get_rooms_for_user_where_membership_is_txn(
+        self, txn, user_id, membership_list
+    ):
 
         do_invite = Membership.INVITE in membership_list
         membership_list = [m for m in membership_list if m != Membership.INVITE]
@@ -227,9 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             ) % (where_clause,)
 
             txn.execute(sql, args)
-            results = [
-                RoomsForUser(**r) for r in self.cursor_to_dict(txn)
-            ]
+            results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
 
         if do_invite:
             sql = (
@@ -241,13 +234,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             )
 
             txn.execute(sql, (user_id,))
-            results.extend(RoomsForUser(
-                room_id=r["room_id"],
-                sender=r["inviter"],
-                event_id=r["event_id"],
-                stream_ordering=r["stream_ordering"],
-                membership=Membership.INVITE,
-            ) for r in self.cursor_to_dict(txn))
+            results.extend(
+                RoomsForUser(
+                    room_id=r["room_id"],
+                    sender=r["inviter"],
+                    event_id=r["event_id"],
+                    stream_ordering=r["stream_ordering"],
+                    membership=Membership.INVITE,
+                )
+                for r in self.cursor_to_dict(txn)
+            )
 
         return results
 
@@ -264,19 +260,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             of the most recent join for that user and room.
         """
         rooms = yield self.get_rooms_for_user_where_membership_is(
-            user_id, membership_list=[Membership.JOIN],
+            user_id, membership_list=[Membership.JOIN]
+        )
+        defer.returnValue(
+            frozenset(
+                GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
+                for r in rooms
+            )
         )
-        defer.returnValue(frozenset(
-            GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
-            for r in rooms
-        ))
 
     @defer.inlineCallbacks
     def get_rooms_for_user(self, user_id, on_invalidate=None):
         """Returns a set of room_ids the user is currently joined to
         """
         rooms = yield self.get_rooms_for_user_with_stream_ordering(
-            user_id, on_invalidate=on_invalidate,
+            user_id, on_invalidate=on_invalidate
         )
         defer.returnValue(frozenset(r.room_id for r in rooms))
 
@@ -285,13 +283,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         """Returns the set of users who share a room with `user_id`
         """
         room_ids = yield self.get_rooms_for_user(
-            user_id, on_invalidate=cache_context.invalidate,
+            user_id, on_invalidate=cache_context.invalidate
         )
 
         user_who_share_room = set()
         for room_id in room_ids:
             user_ids = yield self.get_users_in_room(
-                room_id, on_invalidate=cache_context.invalidate,
+                room_id, on_invalidate=cache_context.invalidate
             )
             user_who_share_room.update(user_ids)
 
@@ -309,9 +307,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         current_state_ids = yield context.get_current_state_ids(self)
         result = yield self._get_joined_users_from_context(
-            event.room_id, state_group, current_state_ids,
-            event=event,
-            context=context,
+            event.room_id, state_group, current_state_ids, event=event, context=context
         )
         defer.returnValue(result)
 
@@ -325,13 +321,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             state_group = object()
 
         return self._get_joined_users_from_context(
-            room_id, state_group, state_entry.state, context=state_entry,
+            room_id, state_group, state_entry.state, context=state_entry
         )
 
-    @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
-                           max_entries=100000)
-    def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
-                                       cache_context, event=None, context=None):
+    @cachedInlineCallbacks(
+        num_args=2, cache_context=True, iterable=True, max_entries=100000
+    )
+    def _get_joined_users_from_context(
+        self,
+        room_id,
+        state_group,
+        current_state_ids,
+        cache_context,
+        event=None,
+        context=None,
+    ):
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
         # with a state_group of None are likely to be different.
@@ -371,9 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # the hit ratio counts. After all, we don't populate the cache if we
         # miss it here
         event_map = self._get_events_from_cache(
-            member_event_ids,
-            allow_rejected=False,
-            update_metrics=False,
+            member_event_ids, allow_rejected=False, update_metrics=False
         )
 
         missing_member_event_ids = []
@@ -397,21 +399,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 table="room_memberships",
                 column="event_id",
                 iterable=missing_member_event_ids,
-                retcols=('user_id', 'display_name', 'avatar_url',),
-                keyvalues={
-                    "membership": Membership.JOIN,
-                },
+                retcols=('user_id', 'display_name', 'avatar_url'),
+                keyvalues={"membership": Membership.JOIN},
                 batch_size=500,
                 desc="_get_joined_users_from_context",
             )
 
-            users_in_room.update({
-                to_ascii(row["user_id"]): ProfileInfo(
-                    avatar_url=to_ascii(row["avatar_url"]),
-                    display_name=to_ascii(row["display_name"]),
-                )
-                for row in rows
-            })
+            users_in_room.update(
+                {
+                    to_ascii(row["user_id"]): ProfileInfo(
+                        avatar_url=to_ascii(row["avatar_url"]),
+                        display_name=to_ascii(row["display_name"]),
+                    )
+                    for row in rows
+                }
+            )
 
         if event is not None and event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
@@ -505,7 +507,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             state_group = object()
 
         return self._get_joined_hosts(
-            room_id, state_group, state_entry.state, state_entry=state_entry,
+            room_id, state_group, state_entry.state, state_entry=state_entry
         )
 
     @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
@@ -531,6 +533,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         """Returns whether user_id has elected to discard history for room_id.
 
         Returns False if they have since re-joined."""
+
         def f(txn):
             sql = (
                 "SELECT"
@@ -547,6 +550,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn.execute(sql, (user_id, room_id))
             rows = txn.fetchall()
             return rows[0][0]
+
         count = yield self.runInteraction("did_forget_membership", f)
         defer.returnValue(count == 0)
 
@@ -575,13 +579,14 @@ class RoomMemberStore(RoomMemberWorkerStore):
                     "avatar_url": event.content.get("avatar_url", None),
                 }
                 for event in events
-            ]
+            ],
         )
 
         for event in events:
             txn.call_after(
                 self._membership_stream_cache.entity_has_changed,
-                event.state_key, event.internal_metadata.stream_ordering
+                event.state_key,
+                event.internal_metadata.stream_ordering,
             )
             txn.call_after(
                 self.get_invited_rooms_for_user.invalidate, (event.state_key,)
@@ -607,7 +612,7 @@ class RoomMemberStore(RoomMemberWorkerStore):
                             "inviter": event.sender,
                             "room_id": event.room_id,
                             "stream_id": event.internal_metadata.stream_ordering,
-                        }
+                        },
                     )
                 else:
                     sql = (
@@ -616,12 +621,15 @@ class RoomMemberStore(RoomMemberWorkerStore):
                         " AND replaced_by is NULL"
                     )
 
-                    txn.execute(sql, (
-                        event.internal_metadata.stream_ordering,
-                        event.event_id,
-                        event.room_id,
-                        event.state_key,
-                    ))
+                    txn.execute(
+                        sql,
+                        (
+                            event.internal_metadata.stream_ordering,
+                            event.event_id,
+                            event.room_id,
+                            event.state_key,
+                        ),
+                    )
 
     @defer.inlineCallbacks
     def locally_reject_invite(self, user_id, room_id):
@@ -632,18 +640,14 @@ class RoomMemberStore(RoomMemberWorkerStore):
         )
 
         def f(txn, stream_ordering):
-            txn.execute(sql, (
-                stream_ordering,
-                True,
-                room_id,
-                user_id,
-            ))
+            txn.execute(sql, (stream_ordering, True, room_id, user_id))
 
         with self._stream_id_gen.get_next() as stream_ordering:
             yield self.runInteraction("locally_reject_invite", f, stream_ordering)
 
     def forget(self, user_id, room_id):
         """Indicate that user_id wishes to discard history for room_id."""
+
         def f(txn):
             sql = (
                 "UPDATE"
@@ -657,9 +661,8 @@ class RoomMemberStore(RoomMemberWorkerStore):
             )
             txn.execute(sql, (user_id, room_id))
 
-            self._invalidate_cache_and_stream(
-                txn, self.did_forget, (user_id, room_id,),
-            )
+            self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
+
         return self.runInteraction("forget_membership", f)
 
     @defer.inlineCallbacks
@@ -674,7 +677,7 @@ class RoomMemberStore(RoomMemberWorkerStore):
         INSERT_CLUMP_SIZE = 1000
 
         def add_membership_profile_txn(txn):
-            sql = ("""
+            sql = """
                 SELECT stream_ordering, event_id, events.room_id, event_json.json
                 FROM events
                 INNER JOIN event_json USING (event_id)
@@ -683,7 +686,7 @@ class RoomMemberStore(RoomMemberWorkerStore):
                 AND type = 'm.room.member'
                 ORDER BY stream_ordering DESC
                 LIMIT ?
-            """)
+            """
 
             txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
 
@@ -707,16 +710,14 @@ class RoomMemberStore(RoomMemberWorkerStore):
                 avatar_url = content.get("avatar_url", None)
 
                 if display_name or avatar_url:
-                    to_update.append((
-                        display_name, avatar_url, event_id, room_id
-                    ))
+                    to_update.append((display_name, avatar_url, event_id, room_id))
 
-            to_update_sql = ("""
+            to_update_sql = """
                 UPDATE room_memberships SET display_name = ?, avatar_url = ?
                 WHERE event_id = ? AND room_id = ?
-            """)
+            """
             for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
-                clump = to_update[index:index + INSERT_CLUMP_SIZE]
+                clump = to_update[index : index + INSERT_CLUMP_SIZE]
                 txn.executemany(to_update_sql, clump)
 
             progress = {
@@ -789,7 +790,7 @@ class _JoinedHostsCache(object):
                             self.hosts_to_joined_users.pop(host, None)
             else:
                 joined_users = yield self.store.get_joined_users_from_state(
-                    self.room_id, state_entry,
+                    self.room_id, state_entry
                 )
 
                 self.hosts_to_joined_users = {}
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index fc2b646ba2..94c6080e34 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -16,7 +16,11 @@
 
 from mock import Mock, call
 
-from synapse.api.constants import PresenceState
+from signedjson.key import generate_signing_key
+
+from synapse.api.constants import EventTypes, Membership, PresenceState
+from synapse.events import room_version_to_event_format
+from synapse.events.builder import EventBuilder
 from synapse.handlers.presence import (
     FEDERATION_PING_INTERVAL,
     FEDERATION_TIMEOUT,
@@ -26,7 +30,9 @@ from synapse.handlers.presence import (
     handle_timeout,
     handle_update,
 )
+from synapse.rest.client.v1 import room
 from synapse.storage.presence import UserPresenceState
+from synapse.types import UserID, get_domain_from_id
 
 from tests import unittest
 
@@ -405,3 +411,171 @@ class PresenceTimeoutTestCase(unittest.TestCase):
 
         self.assertIsNotNone(new_state)
         self.assertEquals(state, new_state)
+
+
+class PresenceJoinTestCase(unittest.HomeserverTestCase):
+    """Tests remote servers get told about presence of users in the room when
+    they join and when new local users join.
+    """
+
+    user_id = "@test:server"
+
+    servlets = [room.register_servlets]
+
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver(
+            "server", http_client=None,
+            federation_sender=Mock(),
+        )
+        return hs
+
+    def prepare(self, reactor, clock, hs):
+        self.federation_sender = hs.get_federation_sender()
+        self.event_builder_factory = hs.get_event_builder_factory()
+        self.federation_handler = hs.get_handlers().federation_handler
+        self.presence_handler = hs.get_presence_handler()
+
+        # self.event_builder_for_2 = EventBuilderFactory(hs)
+        # self.event_builder_for_2.hostname = "test2"
+
+        self.store = hs.get_datastore()
+        self.state = hs.get_state_handler()
+        self.auth = hs.get_auth()
+
+        # We don't actually check signatures in tests, so lets just create a
+        # random key to use.
+        self.random_signing_key = generate_signing_key("ver")
+
+    def test_remote_joins(self):
+        # We advance time to something that isn't 0, as we use 0 as a special
+        # value.
+        self.reactor.advance(1000000000000)
+
+        # Create a room with two local users
+        room_id = self.helper.create_room_as(self.user_id)
+        self.helper.join(room_id, "@test2:server")
+
+        # Mark test2 as online, test will be offline with a last_active of 0
+        self.presence_handler.set_state(
+            UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE},
+        )
+        self.reactor.pump([0])  # Wait for presence updates to be handled
+
+        #
+        # Test that a new server gets told about existing presence
+        #
+
+        self.federation_sender.reset_mock()
+
+        # Add a new remote server to the room
+        self._add_new_user(room_id, "@alice:server2")
+
+        # We shouldn't have sent out any local presence *updates*
+        self.federation_sender.send_presence.assert_not_called()
+
+        # When new server is joined we send it the local users presence states.
+        # We expect to only see user @test2:server, as @test:server is offline
+        # and has a zero last_active_ts
+        expected_state = self.get_success(
+            self.presence_handler.current_state_for_user("@test2:server")
+        )
+        self.assertEqual(expected_state.state, PresenceState.ONLINE)
+        self.federation_sender.send_presence_to_destinations.assert_called_once_with(
+            destinations=["server2"], states=[expected_state]
+        )
+
+        #
+        # Test that only the new server gets sent presence and not existing servers
+        #
+
+        self.federation_sender.reset_mock()
+        self._add_new_user(room_id, "@bob:server3")
+
+        self.federation_sender.send_presence.assert_not_called()
+        self.federation_sender.send_presence_to_destinations.assert_called_once_with(
+            destinations=["server3"], states=[expected_state]
+        )
+
+    def test_remote_gets_presence_when_local_user_joins(self):
+        # We advance time to something that isn't 0, as we use 0 as a special
+        # value.
+        self.reactor.advance(1000000000000)
+
+        # Create a room with one local users
+        room_id = self.helper.create_room_as(self.user_id)
+
+        # Mark test as online
+        self.presence_handler.set_state(
+            UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE},
+        )
+
+        # Mark test2 as online, test will be offline with a last_active of 0.
+        # Note we don't join them to the room yet
+        self.presence_handler.set_state(
+            UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE},
+        )
+
+        # Add servers to the room
+        self._add_new_user(room_id, "@alice:server2")
+        self._add_new_user(room_id, "@bob:server3")
+
+        self.reactor.pump([0])  # Wait for presence updates to be handled
+
+        #
+        # Test that when a local join happens remote servers get told about it
+        #
+
+        self.federation_sender.reset_mock()
+
+        # Join local user to room
+        self.helper.join(room_id, "@test2:server")
+
+        self.reactor.pump([0])  # Wait for presence updates to be handled
+
+        # We shouldn't have sent out any local presence *updates*
+        self.federation_sender.send_presence.assert_not_called()
+
+        # We expect to only send test2 presence to server2 and server3
+        expected_state = self.get_success(
+            self.presence_handler.current_state_for_user("@test2:server")
+        )
+        self.assertEqual(expected_state.state, PresenceState.ONLINE)
+        self.federation_sender.send_presence_to_destinations.assert_called_once_with(
+            destinations=set(("server2", "server3")),
+            states=[expected_state]
+        )
+
+    def _add_new_user(self, room_id, user_id):
+        """Add new user to the room by creating an event and poking the federation API.
+        """
+
+        hostname = get_domain_from_id(user_id)
+
+        room_version = self.get_success(self.store.get_room_version(room_id))
+
+        builder = EventBuilder(
+            state=self.state,
+            auth=self.auth,
+            store=self.store,
+            clock=self.clock,
+            hostname=hostname,
+            signing_key=self.random_signing_key,
+            format_version=room_version_to_event_format(room_version),
+            room_id=room_id,
+            type=EventTypes.Member,
+            sender=user_id,
+            state_key=user_id,
+            content={"membership": Membership.JOIN}
+        )
+
+        prev_event_ids = self.get_success(
+            self.store.get_latest_event_ids_in_room(room_id)
+        )
+
+        event = self.get_success(builder.build(prev_event_ids))
+
+        self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
+
+        # Check that it was successfully persisted.
+        self.get_success(self.store.get_event(event.event_id))
+        self.get_success(self.store.get_event(event.event_id))