summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/ratelimiting.py31
-rw-r--r--synapse/app/client_reader.py11
-rw-r--r--synapse/config/_base.py27
-rw-r--r--synapse/config/database.py3
-rw-r--r--synapse/config/logger.py4
-rw-r--r--synapse/config/ratelimiting.py18
-rw-r--r--synapse/config/registration.py2
-rw-r--r--synapse/config/server.py4
-rw-r--r--synapse/federation/send_queue.py14
-rw-r--r--synapse/federation/transaction_queue.py31
-rw-r--r--synapse/federation/transport/server.py44
-rw-r--r--synapse/handlers/_base.py4
-rw-r--r--synapse/handlers/device.py342
-rw-r--r--synapse/handlers/federation.py46
-rw-r--r--synapse/handlers/presence.py6
-rw-r--r--synapse/handlers/receipts.py103
-rw-r--r--synapse/handlers/register.py39
-rw-r--r--synapse/handlers/sync.py86
-rw-r--r--synapse/handlers/typing.py2
-rw-r--r--synapse/notifier.py65
-rw-r--r--synapse/replication/http/register.py8
-rw-r--r--synapse/replication/slave/storage/client_ips.py2
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py15
-rw-r--r--synapse/replication/slave/storage/devices.py45
-rw-r--r--synapse/replication/slave/storage/push_rule.py2
-rw-r--r--synapse/replication/tcp/protocol.py31
-rw-r--r--synapse/rest/client/v1/admin.py45
-rw-r--r--synapse/rest/client/v2_alpha/register.py33
-rw-r--r--synapse/server.py11
-rw-r--r--synapse/storage/deviceinbox.py324
-rw-r--r--synapse/storage/devices.py675
-rw-r--r--synapse/storage/end_to_end_keys.py88
-rw-r--r--synapse/storage/event_federation.py22
-rw-r--r--synapse/storage/receipts.py26
-rw-r--r--synapse/storage/stream.py19
-rw-r--r--synapse/visibility.py87
36 files changed, 1302 insertions, 1013 deletions
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 3bb5b3da37..ad68079eeb 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -23,12 +23,13 @@ class Ratelimiter(object):
     def __init__(self):
         self.message_counts = collections.OrderedDict()
 
-    def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True):
-        """Can the user send a message?
+    def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True):
+        """Can the entity (e.g. user or IP address) perform the action?
         Args:
-            user_id: The user sending a message.
+            key: The key we should use when rate limiting. Can be a user ID
+                (when sending events), an IP address, etc.
             time_now_s: The time now.
-            msg_rate_hz: The long term number of messages a user can send in a
+            rate_hz: The long term number of messages a user can send in a
                 second.
             burst_count: How many messages the user can send before being
                 limited.
@@ -41,10 +42,10 @@ class Ratelimiter(object):
         """
         self.prune_message_counts(time_now_s)
         message_count, time_start, _ignored = self.message_counts.get(
-            user_id, (0., time_now_s, None),
+            key, (0., time_now_s, None),
         )
         time_delta = time_now_s - time_start
-        sent_count = message_count - time_delta * msg_rate_hz
+        sent_count = message_count - time_delta * rate_hz
         if sent_count < 0:
             allowed = True
             time_start = time_now_s
@@ -56,13 +57,13 @@ class Ratelimiter(object):
             message_count += 1
 
         if update:
-            self.message_counts[user_id] = (
-                message_count, time_start, msg_rate_hz
+            self.message_counts[key] = (
+                message_count, time_start, rate_hz
             )
 
-        if msg_rate_hz > 0:
+        if rate_hz > 0:
             time_allowed = (
-                time_start + (message_count - burst_count + 1) / msg_rate_hz
+                time_start + (message_count - burst_count + 1) / rate_hz
             )
             if time_allowed < time_now_s:
                 time_allowed = time_now_s
@@ -72,12 +73,12 @@ class Ratelimiter(object):
         return allowed, time_allowed
 
     def prune_message_counts(self, time_now_s):
-        for user_id in list(self.message_counts.keys()):
-            message_count, time_start, msg_rate_hz = (
-                self.message_counts[user_id]
+        for key in list(self.message_counts.keys()):
+            message_count, time_start, rate_hz = (
+                self.message_counts[key]
             )
             time_delta = time_now_s - time_start
-            if message_count - time_delta * msg_rate_hz > 0:
+            if message_count - time_delta * rate_hz > 0:
                 break
             else:
-                del self.message_counts[user_id]
+                del self.message_counts[key]
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 5070094cad..beaea64a61 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -33,9 +33,13 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.directory import DirectoryStore
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.keys import SlavedKeyStore
+from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.slave.storage.transactions import SlavedTransactionStore
@@ -49,6 +53,7 @@ from synapse.rest.client.v1.room import (
     RoomStateRestServlet,
 )
 from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
+from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
 from synapse.server import HomeServer
 from synapse.storage.engines import create_engine
@@ -61,6 +66,10 @@ logger = logging.getLogger("synapse.app.client_reader")
 
 
 class ClientReaderSlavedStore(
+    SlavedDeviceInboxStore,
+    SlavedDeviceStore,
+    SlavedReceiptsStore,
+    SlavedPushRuleStore,
     SlavedAccountDataStore,
     SlavedEventStore,
     SlavedKeyStore,
@@ -98,6 +107,8 @@ class ClientReaderServer(HomeServer):
                     RegisterRestServlet(self).register(resource)
                     LoginRestServlet(self).register(resource)
                     ThreepidRestServlet(self).register(resource)
+                    KeyQueryServlet(self).register(resource)
+                    KeyChangesServlet(self).register(resource)
 
                     resources.update({
                         "/_matrix/client/r0": resource,
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 5aec43b702..c4d3087fa4 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -180,9 +180,7 @@ class Config(object):
         Returns:
             str: the yaml config file
         """
-        default_config = "# vim:ft=yaml\n"
-
-        default_config += "\n\n".join(
+        default_config = "\n\n".join(
             dedent(conf)
             for conf in self.invoke_all(
                 "default_config",
@@ -297,19 +295,26 @@ class Config(object):
                         "Must specify a server_name to a generate config for."
                         " Pass -H server.name."
                     )
+
+                config_str = obj.generate_config(
+                    config_dir_path=config_dir_path,
+                    data_dir_path=os.getcwd(),
+                    server_name=server_name,
+                    report_stats=(config_args.report_stats == "yes"),
+                    generate_secrets=True,
+                )
+
                 if not cls.path_exists(config_dir_path):
                     os.makedirs(config_dir_path)
                 with open(config_path, "w") as config_file:
-                    config_str = obj.generate_config(
-                        config_dir_path=config_dir_path,
-                        data_dir_path=os.getcwd(),
-                        server_name=server_name,
-                        report_stats=(config_args.report_stats == "yes"),
-                        generate_secrets=True,
+                    config_file.write(
+                        "# vim:ft=yaml\n\n"
                     )
-                    config = yaml.load(config_str)
-                    obj.invoke_all("generate_files", config)
                     config_file.write(config_str)
+
+                config = yaml.load(config_str)
+                obj.invoke_all("generate_files", config)
+
                 print(
                     (
                         "A config file has been generated in %r for server name"
diff --git a/synapse/config/database.py b/synapse/config/database.py
index c8890147a6..63e9cb63f8 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -49,7 +49,8 @@ class DatabaseConfig(Config):
     def default_config(self, data_dir_path, **kwargs):
         database_path = os.path.join(data_dir_path, "homeserver.db")
         return """\
-        # Database configuration
+        ## Database ##
+
         database:
           # The database engine name
           name: "sqlite3"
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index f6940b65fd..464c28c2d9 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -81,7 +81,9 @@ class LoggingConfig(Config):
 
     def default_config(self, config_dir_path, server_name, **kwargs):
         log_config = os.path.join(config_dir_path, server_name + ".log.config")
-        return """
+        return """\
+        ## Logging ##
+
         # A yaml python logging config file
         #
         log_config: "%(log_config)s"
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 54b71e6841..093042fdb9 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -27,6 +27,13 @@ class RatelimitConfig(Config):
         self.federation_rc_reject_limit = config["federation_rc_reject_limit"]
         self.federation_rc_concurrent = config["federation_rc_concurrent"]
 
+        self.rc_registration_requests_per_second = config.get(
+            "rc_registration_requests_per_second", 0.17,
+        )
+        self.rc_registration_request_burst_count = config.get(
+            "rc_registration_request_burst_count", 3,
+        )
+
     def default_config(self, **kwargs):
         return """\
         ## Ratelimiting ##
@@ -62,4 +69,15 @@ class RatelimitConfig(Config):
         # single server
         #
         federation_rc_concurrent: 3
+
+        # Number of registration requests a client can send per second.
+        # Defaults to 1/minute (0.17).
+        #
+        #rc_registration_requests_per_second: 0.17
+
+        # Number of registration requests a client can send before being
+        # throttled.
+        # Defaults to 3.
+        #
+        #rc_registration_request_burst_count: 3.0
         """
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 2881482f96..d34dc9e456 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -64,6 +64,8 @@ class RegistrationConfig(Config):
 
         return """\
         ## Registration ##
+        # Registration can be rate-limited using the parameters in the "Ratelimiting"
+        # section of this file.
 
         # Enable registration for new users.
         enable_registration: False
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 4200f10da3..35a322fee0 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -260,9 +260,11 @@ class ServerConfig(Config):
         # This is used by remote servers to connect to this server,
         # e.g. matrix.org, localhost:8080, etc.
         # This is also the last part of your UserID.
+        #
         server_name: "%(server_name)s"
 
         # When running as a daemon, the file to store the pid in
+        #
         pid_file: %(pid_file)s
 
         # CPU affinity mask. Setting this restricts the CPUs on which the
@@ -304,9 +306,11 @@ class ServerConfig(Config):
         # Set the soft limit on the number of file descriptors synapse can use
         # Zero is used to indicate synapse should set the soft limit to the
         # hard limit.
+        #
         soft_file_limit: 0
 
         # Set to false to disable presence tracking on this homeserver.
+        #
         use_presence: true
 
         # The GC threshold parameters to pass to `gc.set_threshold`, if defined
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 6f5995735a..b7d0b25781 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -159,8 +159,12 @@ class FederationRemoteSendQueue(object):
         # stream.
         pass
 
-    def send_edu(self, destination, edu_type, content, key=None):
+    def build_and_send_edu(self, destination, edu_type, content, key=None):
         """As per TransactionQueue"""
+        if destination == self.server_name:
+            logger.info("Not sending EDU to ourselves")
+            return
+
         pos = self._next_pos()
 
         edu = Edu(
@@ -465,15 +469,11 @@ def process_rows_for_federation(transaction_queue, rows):
 
     for destination, edu_map in iteritems(buff.keyed_edus):
         for key, edu in edu_map.items():
-            transaction_queue.send_edu(
-                edu.destination, edu.edu_type, edu.content, key=key,
-            )
+            transaction_queue.send_edu(edu, key)
 
     for destination, edu_list in iteritems(buff.edus):
         for edu in edu_list:
-            transaction_queue.send_edu(
-                edu.destination, edu.edu_type, edu.content, key=None,
-            )
+            transaction_queue.send_edu(edu, None)
 
     for destination in buff.device_destinations:
         transaction_queue.send_device_messages(destination)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 30941f5ad6..e5e42c647d 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -361,7 +361,19 @@ class TransactionQueue(object):
 
                 self._attempt_new_transaction(destination)
 
-    def send_edu(self, destination, edu_type, content, key=None):
+    def build_and_send_edu(self, destination, edu_type, content, key=None):
+        """Construct an Edu object, and queue it for sending
+
+        Args:
+            destination (str): name of server to send to
+            edu_type (str): type of EDU to send
+            content (dict): content of EDU
+            key (Any|None): clobbering key for this edu
+        """
+        if destination == self.server_name:
+            logger.info("Not sending EDU to ourselves")
+            return
+
         edu = Edu(
             origin=self.server_name,
             destination=destination,
@@ -369,18 +381,23 @@ class TransactionQueue(object):
             content=content,
         )
 
-        if destination == self.server_name:
-            logger.info("Not sending EDU to ourselves")
-            return
+        self.send_edu(edu, key)
+
+    def send_edu(self, edu, key):
+        """Queue an EDU for sending
 
+        Args:
+            edu (Edu): edu to send
+            key (Any|None): clobbering key for this edu
+        """
         if key:
             self.pending_edus_keyed_by_dest.setdefault(
-                destination, {}
+                edu.destination, {}
             )[(edu.edu_type, key)] = edu
         else:
-            self.pending_edus_by_dest.setdefault(destination, []).append(edu)
+            self.pending_edus_by_dest.setdefault(edu.destination, []).append(edu)
 
-        self._attempt_new_transaction(destination)
+        self._attempt_new_transaction(edu.destination)
 
     def send_device_messages(self, destination):
         if destination == self.server_name:
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index ebb81be377..96d680a5ad 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -759,7 +759,7 @@ class FederationVersionServlet(BaseFederationServlet):
 class FederationGroupsProfileServlet(BaseFederationServlet):
     """Get/set the basic profile of a group on behalf of a user
     """
-    PATH = "/groups/(?P<group_id>[^/]*)/profile$"
+    PATH = "/groups/(?P<group_id>[^/]*)/profile"
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, group_id):
@@ -787,7 +787,7 @@ class FederationGroupsProfileServlet(BaseFederationServlet):
 
 
 class FederationGroupsSummaryServlet(BaseFederationServlet):
-    PATH = "/groups/(?P<group_id>[^/]*)/summary$"
+    PATH = "/groups/(?P<group_id>[^/]*)/summary"
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, group_id):
@@ -805,7 +805,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
 class FederationGroupsRoomsServlet(BaseFederationServlet):
     """Get the rooms in a group on behalf of a user
     """
-    PATH = "/groups/(?P<group_id>[^/]*)/rooms$"
+    PATH = "/groups/(?P<group_id>[^/]*)/rooms"
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, group_id):
@@ -823,7 +823,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
 class FederationGroupsAddRoomsServlet(BaseFederationServlet):
     """Add/remove room from group
     """
-    PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$"
+    PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
 
     @defer.inlineCallbacks
     def on_POST(self, origin, content, query, group_id, room_id):
@@ -855,7 +855,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
     """
     PATH = (
         "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
-        "/config/(?P<config_key>[^/]*)$"
+        "/config/(?P<config_key>[^/]*)"
     )
 
     @defer.inlineCallbacks
@@ -874,7 +874,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
 class FederationGroupsUsersServlet(BaseFederationServlet):
     """Get the users in a group on behalf of a user
     """
-    PATH = "/groups/(?P<group_id>[^/]*)/users$"
+    PATH = "/groups/(?P<group_id>[^/]*)/users"
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, group_id):
@@ -892,7 +892,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
 class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
     """Get the users that have been invited to a group
     """
-    PATH = "/groups/(?P<group_id>[^/]*)/invited_users$"
+    PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, group_id):
@@ -910,7 +910,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
 class FederationGroupsInviteServlet(BaseFederationServlet):
     """Ask a group server to invite someone to the group
     """
-    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
 
     @defer.inlineCallbacks
     def on_POST(self, origin, content, query, group_id, user_id):
@@ -928,7 +928,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
 class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
     """Accept an invitation from the group server
     """
-    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite$"
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
 
     @defer.inlineCallbacks
     def on_POST(self, origin, content, query, group_id, user_id):
@@ -945,7 +945,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
 class FederationGroupsJoinServlet(BaseFederationServlet):
     """Attempt to join a group
     """
-    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join$"
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
 
     @defer.inlineCallbacks
     def on_POST(self, origin, content, query, group_id, user_id):
@@ -962,7 +962,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
 class FederationGroupsRemoveUserServlet(BaseFederationServlet):
     """Leave or kick a user from the group
     """
-    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
 
     @defer.inlineCallbacks
     def on_POST(self, origin, content, query, group_id, user_id):
@@ -980,7 +980,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
 class FederationGroupsLocalInviteServlet(BaseFederationServlet):
     """A group server has invited a local user
     """
-    PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
+    PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
 
     @defer.inlineCallbacks
     def on_POST(self, origin, content, query, group_id, user_id):
@@ -997,7 +997,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
 class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
     """A group server has removed a local user
     """
-    PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
+    PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
 
     @defer.inlineCallbacks
     def on_POST(self, origin, content, query, group_id, user_id):
@@ -1014,7 +1014,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
 class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
     """A group or user's server renews their attestation
     """
-    PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)$"
+    PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
 
     @defer.inlineCallbacks
     def on_POST(self, origin, content, query, group_id, user_id):
@@ -1037,7 +1037,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
     PATH = (
         "/groups/(?P<group_id>[^/]*)/summary"
         "(/categories/(?P<category_id>[^/]+))?"
-        "/rooms/(?P<room_id>[^/]*)$"
+        "/rooms/(?P<room_id>[^/]*)"
     )
 
     @defer.inlineCallbacks
@@ -1080,7 +1080,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
     """Get all categories for a group
     """
     PATH = (
-        "/groups/(?P<group_id>[^/]*)/categories/$"
+        "/groups/(?P<group_id>[^/]*)/categories/"
     )
 
     @defer.inlineCallbacks
@@ -1100,7 +1100,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
     """Add/remove/get a category in a group
     """
     PATH = (
-        "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
+        "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
     )
 
     @defer.inlineCallbacks
@@ -1150,7 +1150,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
     """Get roles in a group
     """
     PATH = (
-        "/groups/(?P<group_id>[^/]*)/roles/$"
+        "/groups/(?P<group_id>[^/]*)/roles/"
     )
 
     @defer.inlineCallbacks
@@ -1170,7 +1170,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
     """Add/remove/get a role in a group
     """
     PATH = (
-        "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
+        "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
     )
 
     @defer.inlineCallbacks
@@ -1226,7 +1226,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
     PATH = (
         "/groups/(?P<group_id>[^/]*)/summary"
         "(/roles/(?P<role_id>[^/]+))?"
-        "/users/(?P<user_id>[^/]*)$"
+        "/users/(?P<user_id>[^/]*)"
     )
 
     @defer.inlineCallbacks
@@ -1269,7 +1269,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
     """Get roles in a group
     """
     PATH = (
-        "/get_groups_publicised$"
+        "/get_groups_publicised"
     )
 
     @defer.inlineCallbacks
@@ -1284,7 +1284,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
 class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
     """Sets whether a group is joinable without an invite or knock
     """
-    PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy$"
+    PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
 
     @defer.inlineCallbacks
     def on_PUT(self, origin, content, query, group_id):
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 594754cfd8..d8d86d6ff3 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -93,9 +93,9 @@ class BaseHandler(object):
             messages_per_second = self.hs.config.rc_messages_per_second
             burst_count = self.hs.config.rc_message_burst_count
 
-        allowed, time_allowed = self.ratelimiter.send_message(
+        allowed, time_allowed = self.ratelimiter.can_do_action(
             user_id, time_now,
-            msg_rate_hz=messages_per_second,
+            rate_hz=messages_per_second,
             burst_count=burst_count,
             update=update,
         )
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index c708c35d4d..c09a7c6280 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -37,13 +37,185 @@ from ._base import BaseHandler
 logger = logging.getLogger(__name__)
 
 
-class DeviceHandler(BaseHandler):
+class DeviceWorkerHandler(BaseHandler):
     def __init__(self, hs):
-        super(DeviceHandler, self).__init__(hs)
+        super(DeviceWorkerHandler, self).__init__(hs)
 
         self.hs = hs
         self.state = hs.get_state_handler()
         self._auth_handler = hs.get_auth_handler()
+
+    @defer.inlineCallbacks
+    def get_devices_by_user(self, user_id):
+        """
+        Retrieve the given user's devices
+
+        Args:
+            user_id (str):
+        Returns:
+            defer.Deferred: list[dict[str, X]]: info on each device
+        """
+
+        device_map = yield self.store.get_devices_by_user(user_id)
+
+        ips = yield self.store.get_last_client_ip_by_device(
+            user_id, device_id=None
+        )
+
+        devices = list(device_map.values())
+        for device in devices:
+            _update_device_from_client_ips(device, ips)
+
+        defer.returnValue(devices)
+
+    @defer.inlineCallbacks
+    def get_device(self, user_id, device_id):
+        """ Retrieve the given device
+
+        Args:
+            user_id (str):
+            device_id (str):
+
+        Returns:
+            defer.Deferred: dict[str, X]: info on the device
+        Raises:
+            errors.NotFoundError: if the device was not found
+        """
+        try:
+            device = yield self.store.get_device(user_id, device_id)
+        except errors.StoreError:
+            raise errors.NotFoundError
+        ips = yield self.store.get_last_client_ip_by_device(
+            user_id, device_id,
+        )
+        _update_device_from_client_ips(device, ips)
+        defer.returnValue(device)
+
+    @measure_func("device.get_user_ids_changed")
+    @defer.inlineCallbacks
+    def get_user_ids_changed(self, user_id, from_token):
+        """Get list of users that have had the devices updated, or have newly
+        joined a room, that `user_id` may be interested in.
+
+        Args:
+            user_id (str)
+            from_token (StreamToken)
+        """
+        now_room_key = yield self.store.get_room_events_max_id()
+
+        room_ids = yield self.store.get_rooms_for_user(user_id)
+
+        # First we check if any devices have changed
+        changed = yield self.store.get_user_whose_devices_changed(
+            from_token.device_list_key
+        )
+
+        # Then work out if any users have since joined
+        rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
+
+        member_events = yield self.store.get_membership_changes_for_user(
+            user_id, from_token.room_key, now_room_key,
+        )
+        rooms_changed.update(event.room_id for event in member_events)
+
+        stream_ordering = RoomStreamToken.parse_stream_token(
+            from_token.room_key
+        ).stream
+
+        possibly_changed = set(changed)
+        possibly_left = set()
+        for room_id in rooms_changed:
+            current_state_ids = yield self.store.get_current_state_ids(room_id)
+
+            # The user may have left the room
+            # TODO: Check if they actually did or if we were just invited.
+            if room_id not in room_ids:
+                for key, event_id in iteritems(current_state_ids):
+                    etype, state_key = key
+                    if etype != EventTypes.Member:
+                        continue
+                    possibly_left.add(state_key)
+                continue
+
+            # Fetch the current state at the time.
+            try:
+                event_ids = yield self.store.get_forward_extremeties_for_room(
+                    room_id, stream_ordering=stream_ordering
+                )
+            except errors.StoreError:
+                # we have purged the stream_ordering index since the stream
+                # ordering: treat it the same as a new room
+                event_ids = []
+
+            # special-case for an empty prev state: include all members
+            # in the changed list
+            if not event_ids:
+                for key, event_id in iteritems(current_state_ids):
+                    etype, state_key = key
+                    if etype != EventTypes.Member:
+                        continue
+                    possibly_changed.add(state_key)
+                continue
+
+            current_member_id = current_state_ids.get((EventTypes.Member, user_id))
+            if not current_member_id:
+                continue
+
+            # mapping from event_id -> state_dict
+            prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
+
+            # Check if we've joined the room? If so we just blindly add all the users to
+            # the "possibly changed" users.
+            for state_dict in itervalues(prev_state_ids):
+                member_event = state_dict.get((EventTypes.Member, user_id), None)
+                if not member_event or member_event != current_member_id:
+                    for key, event_id in iteritems(current_state_ids):
+                        etype, state_key = key
+                        if etype != EventTypes.Member:
+                            continue
+                        possibly_changed.add(state_key)
+                    break
+
+            # If there has been any change in membership, include them in the
+            # possibly changed list. We'll check if they are joined below,
+            # and we're not toooo worried about spuriously adding users.
+            for key, event_id in iteritems(current_state_ids):
+                etype, state_key = key
+                if etype != EventTypes.Member:
+                    continue
+
+                # check if this member has changed since any of the extremities
+                # at the stream_ordering, and add them to the list if so.
+                for state_dict in itervalues(prev_state_ids):
+                    prev_event_id = state_dict.get(key, None)
+                    if not prev_event_id or prev_event_id != event_id:
+                        if state_key != user_id:
+                            possibly_changed.add(state_key)
+                        break
+
+        if possibly_changed or possibly_left:
+            users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+                user_id
+            )
+
+            # Take the intersection of the users whose devices may have changed
+            # and those that actually still share a room with the user
+            possibly_joined = possibly_changed & users_who_share_room
+            possibly_left = (possibly_changed | possibly_left) - users_who_share_room
+        else:
+            possibly_joined = []
+            possibly_left = []
+
+        defer.returnValue({
+            "changed": list(possibly_joined),
+            "left": list(possibly_left),
+        })
+
+
+class DeviceHandler(DeviceWorkerHandler):
+    def __init__(self, hs):
+        super(DeviceHandler, self).__init__(hs)
+
         self.federation_sender = hs.get_federation_sender()
 
         self._edu_updater = DeviceListEduUpdater(hs, self)
@@ -104,52 +276,6 @@ class DeviceHandler(BaseHandler):
         raise errors.StoreError(500, "Couldn't generate a device ID.")
 
     @defer.inlineCallbacks
-    def get_devices_by_user(self, user_id):
-        """
-        Retrieve the given user's devices
-
-        Args:
-            user_id (str):
-        Returns:
-            defer.Deferred: list[dict[str, X]]: info on each device
-        """
-
-        device_map = yield self.store.get_devices_by_user(user_id)
-
-        ips = yield self.store.get_last_client_ip_by_device(
-            user_id, device_id=None
-        )
-
-        devices = list(device_map.values())
-        for device in devices:
-            _update_device_from_client_ips(device, ips)
-
-        defer.returnValue(devices)
-
-    @defer.inlineCallbacks
-    def get_device(self, user_id, device_id):
-        """ Retrieve the given device
-
-        Args:
-            user_id (str):
-            device_id (str):
-
-        Returns:
-            defer.Deferred: dict[str, X]: info on the device
-        Raises:
-            errors.NotFoundError: if the device was not found
-        """
-        try:
-            device = yield self.store.get_device(user_id, device_id)
-        except errors.StoreError:
-            raise errors.NotFoundError
-        ips = yield self.store.get_last_client_ip_by_device(
-            user_id, device_id,
-        )
-        _update_device_from_client_ips(device, ips)
-        defer.returnValue(device)
-
-    @defer.inlineCallbacks
     def delete_device(self, user_id, device_id):
         """ Delete the given device
 
@@ -287,126 +413,6 @@ class DeviceHandler(BaseHandler):
             for host in hosts:
                 self.federation_sender.send_device_messages(host)
 
-    @measure_func("device.get_user_ids_changed")
-    @defer.inlineCallbacks
-    def get_user_ids_changed(self, user_id, from_token):
-        """Get list of users that have had the devices updated, or have newly
-        joined a room, that `user_id` may be interested in.
-
-        Args:
-            user_id (str)
-            from_token (StreamToken)
-        """
-        now_token = yield self.hs.get_event_sources().get_current_token()
-
-        room_ids = yield self.store.get_rooms_for_user(user_id)
-
-        # First we check if any devices have changed
-        changed = yield self.store.get_user_whose_devices_changed(
-            from_token.device_list_key
-        )
-
-        # Then work out if any users have since joined
-        rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
-
-        member_events = yield self.store.get_membership_changes_for_user(
-            user_id, from_token.room_key, now_token.room_key
-        )
-        rooms_changed.update(event.room_id for event in member_events)
-
-        stream_ordering = RoomStreamToken.parse_stream_token(
-            from_token.room_key
-        ).stream
-
-        possibly_changed = set(changed)
-        possibly_left = set()
-        for room_id in rooms_changed:
-            current_state_ids = yield self.store.get_current_state_ids(room_id)
-
-            # The user may have left the room
-            # TODO: Check if they actually did or if we were just invited.
-            if room_id not in room_ids:
-                for key, event_id in iteritems(current_state_ids):
-                    etype, state_key = key
-                    if etype != EventTypes.Member:
-                        continue
-                    possibly_left.add(state_key)
-                continue
-
-            # Fetch the current state at the time.
-            try:
-                event_ids = yield self.store.get_forward_extremeties_for_room(
-                    room_id, stream_ordering=stream_ordering
-                )
-            except errors.StoreError:
-                # we have purged the stream_ordering index since the stream
-                # ordering: treat it the same as a new room
-                event_ids = []
-
-            # special-case for an empty prev state: include all members
-            # in the changed list
-            if not event_ids:
-                for key, event_id in iteritems(current_state_ids):
-                    etype, state_key = key
-                    if etype != EventTypes.Member:
-                        continue
-                    possibly_changed.add(state_key)
-                continue
-
-            current_member_id = current_state_ids.get((EventTypes.Member, user_id))
-            if not current_member_id:
-                continue
-
-            # mapping from event_id -> state_dict
-            prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
-
-            # Check if we've joined the room? If so we just blindly add all the users to
-            # the "possibly changed" users.
-            for state_dict in itervalues(prev_state_ids):
-                member_event = state_dict.get((EventTypes.Member, user_id), None)
-                if not member_event or member_event != current_member_id:
-                    for key, event_id in iteritems(current_state_ids):
-                        etype, state_key = key
-                        if etype != EventTypes.Member:
-                            continue
-                        possibly_changed.add(state_key)
-                    break
-
-            # If there has been any change in membership, include them in the
-            # possibly changed list. We'll check if they are joined below,
-            # and we're not toooo worried about spuriously adding users.
-            for key, event_id in iteritems(current_state_ids):
-                etype, state_key = key
-                if etype != EventTypes.Member:
-                    continue
-
-                # check if this member has changed since any of the extremities
-                # at the stream_ordering, and add them to the list if so.
-                for state_dict in itervalues(prev_state_ids):
-                    prev_event_id = state_dict.get(key, None)
-                    if not prev_event_id or prev_event_id != event_id:
-                        if state_key != user_id:
-                            possibly_changed.add(state_key)
-                        break
-
-        if possibly_changed or possibly_left:
-            users_who_share_room = yield self.store.get_users_who_share_room_with_user(
-                user_id
-            )
-
-            # Take the intersection of the users whose devices may have changed
-            # and those that actually still share a room with the user
-            possibly_joined = possibly_changed & users_who_share_room
-            possibly_left = (possibly_changed | possibly_left) - users_who_share_room
-        else:
-            possibly_joined = []
-            possibly_left = []
-
-        defer.returnValue({
-            "changed": list(possibly_joined),
-            "left": list(possibly_left),
-        })
-
     @defer.inlineCallbacks
     def on_federation_query_user_devices(self, user_id):
         stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f80486102a..72b63d64d0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -858,6 +858,52 @@ class FederationHandler(BaseHandler):
             logger.debug("Not backfilling as no extremeties found.")
             return
 
+        # We only want to paginate if we can actually see the events we'll get,
+        # as otherwise we'll just spend a lot of resources to get redacted
+        # events.
+        #
+        # We do this by filtering all the backwards extremities and seeing if
+        # any remain. Given we don't have the extremity events themselves, we
+        # need to actually check the events that reference them.
+        #
+        # *Note*: the spec wants us to keep backfilling until we reach the start
+        # of the room in case we are allowed to see some of the history. However
+        # in practice that causes more issues than its worth, as a) its
+        # relatively rare for there to be any visible history and b) even when
+        # there is its often sufficiently long ago that clients would stop
+        # attempting to paginate before backfill reached the visible history.
+        #
+        # TODO: If we do do a backfill then we should filter the backwards
+        #   extremities to only include those that point to visible portions of
+        #   history.
+        #
+        # TODO: Correctly handle the case where we are allowed to see the
+        #   forward event but not the backward extremity, e.g. in the case of
+        #   initial join of the server where we are allowed to see the join
+        #   event but not anything before it. This would require looking at the
+        #   state *before* the event, ignoring the special casing certain event
+        #   types have.
+
+        forward_events = yield self.store.get_successor_events(
+            list(extremities),
+        )
+
+        extremities_events = yield self.store.get_events(
+            forward_events,
+            check_redacted=False,
+            get_prev_content=False,
+        )
+
+        # We set `check_history_visibility_only` as we might otherwise get false
+        # positives from users having been erased.
+        filtered_extremities = yield filter_events_for_server(
+            self.store, self.server_name, list(extremities_events.values()),
+            redact=False, check_history_visibility_only=True,
+        )
+
+        if not filtered_extremities:
+            defer.returnValue(False)
+
         # Check if we reached a point where we should start backfilling.
         sorted_extremeties_tuple = sorted(
             extremities.items(),
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index ba3856674d..37e87fc054 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -816,7 +816,7 @@ class PresenceHandler(object):
         if self.is_mine(observed_user):
             yield self.invite_presence(observed_user, observer_user)
         else:
-            yield self.federation.send_edu(
+            yield self.federation.build_and_send_edu(
                 destination=observed_user.domain,
                 edu_type="m.presence_invite",
                 content={
@@ -836,7 +836,7 @@ class PresenceHandler(object):
         if self.is_mine(observer_user):
             yield self.accept_presence(observed_user, observer_user)
         else:
-            self.federation.send_edu(
+            self.federation.build_and_send_edu(
                 destination=observer_user.domain,
                 edu_type="m.presence_accept",
                 content={
@@ -848,7 +848,7 @@ class PresenceHandler(object):
             state_dict = yield self.get_state(observed_user, as_event=False)
             state_dict = format_user_presence_state(state_dict, self.clock.time_msec())
 
-            self.federation.send_edu(
+            self.federation.build_and_send_edu(
                 destination=observer_user.domain,
                 edu_type="m.presence",
                 content={
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 696469732c..1728089667 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -16,7 +16,6 @@ import logging
 
 from twisted.internet import defer
 
-from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import get_domain_from_id
 
 from ._base import BaseHandler
@@ -39,31 +38,6 @@ class ReceiptsHandler(BaseHandler):
         self.state = hs.get_state_handler()
 
     @defer.inlineCallbacks
-    def received_client_receipt(self, room_id, receipt_type, user_id,
-                                event_id):
-        """Called when a client tells us a local user has read up to the given
-        event_id in the room.
-        """
-        receipt = {
-            "room_id": room_id,
-            "receipt_type": receipt_type,
-            "user_id": user_id,
-            "event_ids": [event_id],
-            "data": {
-                "ts": int(self.clock.time_msec()),
-            }
-        }
-
-        is_new = yield self._handle_new_receipts([receipt])
-
-        if is_new:
-            # fire off a process in the background to send the receipt to
-            # remote servers
-            run_as_background_process(
-                'push_receipts_to_remotes', self._push_remotes, receipt
-            )
-
-    @defer.inlineCallbacks
     def _received_remote_receipt(self, origin, content):
         """Called when we receive an EDU of type m.receipt from a remote HS.
         """
@@ -128,43 +102,54 @@ class ReceiptsHandler(BaseHandler):
         defer.returnValue(True)
 
     @defer.inlineCallbacks
-    def _push_remotes(self, receipt):
-        """Given a receipt, works out which remote servers should be
-        poked and pokes them.
+    def received_client_receipt(self, room_id, receipt_type, user_id,
+                                event_id):
+        """Called when a client tells us a local user has read up to the given
+        event_id in the room.
         """
-        try:
-            # TODO: optimise this to move some of the work to the workers.
-            room_id = receipt["room_id"]
-            receipt_type = receipt["receipt_type"]
-            user_id = receipt["user_id"]
-            event_ids = receipt["event_ids"]
-            data = receipt["data"]
+        receipt = {
+            "room_id": room_id,
+            "receipt_type": receipt_type,
+            "user_id": user_id,
+            "event_ids": [event_id],
+            "data": {
+                "ts": int(self.clock.time_msec()),
+            }
+        }
 
-            users = yield self.state.get_current_user_in_room(room_id)
-            remotedomains = set(get_domain_from_id(u) for u in users)
-            remotedomains = remotedomains.copy()
-            remotedomains.discard(self.server_name)
-
-            logger.debug("Sending receipt to: %r", remotedomains)
-
-            for domain in remotedomains:
-                self.federation.send_edu(
-                    destination=domain,
-                    edu_type="m.receipt",
-                    content={
-                        room_id: {
-                            receipt_type: {
-                                user_id: {
-                                    "event_ids": event_ids,
-                                    "data": data,
-                                }
+        is_new = yield self._handle_new_receipts([receipt])
+        if not is_new:
+            return
+
+        # Work out which remote servers should be poked and poke them.
+
+        # TODO: optimise this to move some of the work to the workers.
+        data = receipt["data"]
+
+        # XXX why does this not use state.get_current_hosts_in_room() ?
+        users = yield self.state.get_current_user_in_room(room_id)
+        remotedomains = set(get_domain_from_id(u) for u in users)
+        remotedomains = remotedomains.copy()
+        remotedomains.discard(self.server_name)
+
+        logger.debug("Sending receipt to: %r", remotedomains)
+
+        for domain in remotedomains:
+            self.federation.build_and_send_edu(
+                destination=domain,
+                edu_type="m.receipt",
+                content={
+                    room_id: {
+                        receipt_type: {
+                            user_id: {
+                                "event_ids": [event_id],
+                                "data": data,
                             }
-                        },
+                        }
                     },
-                    key=(room_id, receipt_type, user_id),
-                )
-        except Exception:
-            logger.exception("Error pushing receipts to remote servers")
+                },
+                key=(room_id, receipt_type, user_id),
+            )
 
     @defer.inlineCallbacks
     def get_receipts_for_room(self, room_id, to_key):
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c0e06929bd..03130edc54 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -24,6 +24,7 @@ from synapse.api.errors import (
     AuthError,
     Codes,
     InvalidCaptchaError,
+    LimitExceededError,
     RegistrationError,
     SynapseError,
 )
@@ -60,6 +61,7 @@ class RegistrationHandler(BaseHandler):
         self.user_directory_handler = hs.get_user_directory_handler()
         self.captcha_client = CaptchaServerHttpClient(hs)
         self.identity_handler = self.hs.get_handlers().identity_handler
+        self.ratelimiter = hs.get_registration_ratelimiter()
 
         self._next_generated_user_id = None
 
@@ -149,6 +151,7 @@ class RegistrationHandler(BaseHandler):
         threepid=None,
         user_type=None,
         default_display_name=None,
+        address=None,
     ):
         """Registers a new client on the server.
 
@@ -167,6 +170,7 @@ class RegistrationHandler(BaseHandler):
               api.constants.UserTypes, or None for a normal user.
             default_display_name (unicode|None): if set, the new user's displayname
               will be set to this. Defaults to 'localpart'.
+            address (str|None): the IP address used to perform the regitration.
         Returns:
             A tuple of (user_id, access_token).
         Raises:
@@ -206,7 +210,7 @@ class RegistrationHandler(BaseHandler):
             token = None
             if generate_token:
                 token = self.macaroon_gen.generate_access_token(user_id)
-            yield self._register_with_store(
+            yield self.register_with_store(
                 user_id=user_id,
                 token=token,
                 password_hash=password_hash,
@@ -215,6 +219,7 @@ class RegistrationHandler(BaseHandler):
                 create_profile_with_displayname=default_display_name,
                 admin=admin,
                 user_type=user_type,
+                address=address,
             )
 
             if self.hs.config.user_directory_search_all_users:
@@ -238,12 +243,13 @@ class RegistrationHandler(BaseHandler):
                 if default_display_name is None:
                     default_display_name = localpart
                 try:
-                    yield self._register_with_store(
+                    yield self.register_with_store(
                         user_id=user_id,
                         token=token,
                         password_hash=password_hash,
                         make_guest=make_guest,
                         create_profile_with_displayname=default_display_name,
+                        address=address,
                     )
                 except SynapseError:
                     # if user id is taken, just generate another
@@ -337,7 +343,7 @@ class RegistrationHandler(BaseHandler):
             user_id, allowed_appservice=service
         )
 
-        yield self._register_with_store(
+        yield self.register_with_store(
             user_id=user_id,
             password_hash="",
             appservice_id=service_id,
@@ -513,7 +519,7 @@ class RegistrationHandler(BaseHandler):
         token = self.macaroon_gen.generate_access_token(user_id)
 
         if need_register:
-            yield self._register_with_store(
+            yield self.register_with_store(
                 user_id=user_id,
                 token=token,
                 password_hash=password_hash,
@@ -590,10 +596,10 @@ class RegistrationHandler(BaseHandler):
             ratelimit=False,
         )
 
-    def _register_with_store(self, user_id, token=None, password_hash=None,
-                             was_guest=False, make_guest=False, appservice_id=None,
-                             create_profile_with_displayname=None, admin=False,
-                             user_type=None):
+    def register_with_store(self, user_id, token=None, password_hash=None,
+                            was_guest=False, make_guest=False, appservice_id=None,
+                            create_profile_with_displayname=None, admin=False,
+                            user_type=None, address=None):
         """Register user in the datastore.
 
         Args:
@@ -612,10 +618,26 @@ class RegistrationHandler(BaseHandler):
             admin (boolean): is an admin user?
             user_type (str|None): type of user. One of the values from
                 api.constants.UserTypes, or None for a normal user.
+            address (str|None): the IP address used to perform the regitration.
 
         Returns:
             Deferred
         """
+        # Don't rate limit for app services
+        if appservice_id is None and address is not None:
+            time_now = self.clock.time()
+
+            allowed, time_allowed = self.ratelimiter.can_do_action(
+                address, time_now_s=time_now,
+                rate_hz=self.hs.config.rc_registration_requests_per_second,
+                burst_count=self.hs.config.rc_registration_request_burst_count,
+            )
+
+            if not allowed:
+                raise LimitExceededError(
+                    retry_after_ms=int(1000 * (time_allowed - time_now)),
+                )
+
         if self.hs.config.worker_app:
             return self._register_client(
                 user_id=user_id,
@@ -627,6 +649,7 @@ class RegistrationHandler(BaseHandler):
                 create_profile_with_displayname=create_profile_with_displayname,
                 admin=admin,
                 user_type=user_type,
+                address=address,
             )
         else:
             return self.store.register(
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index bd97241ab4..57bb996245 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -39,6 +39,9 @@ from synapse.visibility import filter_events_for_client
 
 logger = logging.getLogger(__name__)
 
+# Debug logger for https://github.com/matrix-org/synapse/issues/4422
+issue4422_logger = logging.getLogger("synapse.handler.sync.4422_debug")
+
 
 # Counts the number of times we returned a non-empty sync. `type` is one of
 # "initial_sync", "full_state_sync" or "incremental_sync", `lazy_loaded` is
@@ -962,6 +965,15 @@ class SyncHandler(object):
 
         yield self._generate_sync_entry_for_groups(sync_result_builder)
 
+        # debug for https://github.com/matrix-org/synapse/issues/4422
+        for joined_room in sync_result_builder.joined:
+            room_id = joined_room.room_id
+            if room_id in newly_joined_rooms:
+                issue4422_logger.debug(
+                    "Sync result for newly joined room %s: %r",
+                    room_id, joined_room,
+                )
+
         defer.returnValue(SyncResult(
             presence=sync_result_builder.presence,
             account_data=sync_result_builder.account_data,
@@ -1425,6 +1437,17 @@ class SyncHandler(object):
                     old_mem_ev = yield self.store.get_event(
                         old_mem_ev_id, allow_none=True
                     )
+
+                # debug for #4422
+                if has_join:
+                    prev_membership = None
+                    if old_mem_ev:
+                        prev_membership = old_mem_ev.membership
+                    issue4422_logger.debug(
+                        "Previous membership for room %s with join: %s (event %s)",
+                        room_id, prev_membership, old_mem_ev_id,
+                    )
+
                 if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
                     newly_joined_rooms.append(room_id)
 
@@ -1519,30 +1542,39 @@ class SyncHandler(object):
         for room_id in sync_result_builder.joined_room_ids:
             room_entry = room_to_events.get(room_id, None)
 
+            newly_joined = room_id in newly_joined_rooms
             if room_entry:
                 events, start_key = room_entry
 
                 prev_batch_token = now_token.copy_and_replace("room_key", start_key)
 
-                room_entries.append(RoomSyncResultBuilder(
+                entry = RoomSyncResultBuilder(
                     room_id=room_id,
                     rtype="joined",
                     events=events,
-                    newly_joined=room_id in newly_joined_rooms,
+                    newly_joined=newly_joined,
                     full_state=False,
-                    since_token=None if room_id in newly_joined_rooms else since_token,
+                    since_token=None if newly_joined else since_token,
                     upto_token=prev_batch_token,
-                ))
+                )
             else:
-                room_entries.append(RoomSyncResultBuilder(
+                entry = RoomSyncResultBuilder(
                     room_id=room_id,
                     rtype="joined",
                     events=[],
-                    newly_joined=room_id in newly_joined_rooms,
+                    newly_joined=newly_joined,
                     full_state=False,
                     since_token=since_token,
                     upto_token=since_token,
-                ))
+                )
+
+            if newly_joined:
+                # debugging for https://github.com/matrix-org/synapse/issues/4422
+                issue4422_logger.debug(
+                    "RoomSyncResultBuilder events for newly joined room %s: %r",
+                    room_id, entry.events,
+                )
+            room_entries.append(entry)
 
         defer.returnValue((room_entries, invited, newly_joined_rooms, newly_left_rooms))
 
@@ -1663,6 +1695,13 @@ class SyncHandler(object):
             newly_joined_room=newly_joined,
         )
 
+        if newly_joined:
+            # debug for https://github.com/matrix-org/synapse/issues/4422
+            issue4422_logger.debug(
+                "Timeline events after filtering in newly-joined room %s: %r",
+                room_id, batch,
+            )
+
         # When we join the room (or the client requests full_state), we should
         # send down any existing tags. Usually the user won't have tags in a
         # newly joined room, unless either a) they've joined before or b) the
@@ -1894,15 +1933,34 @@ def _calculate_state(
 
 
 class SyncResultBuilder(object):
-    "Used to help build up a new SyncResult for a user"
+    """Used to help build up a new SyncResult for a user
+
+    Attributes:
+        sync_config (SyncConfig)
+        full_state (bool)
+        since_token (StreamToken)
+        now_token (StreamToken)
+        joined_room_ids (list[str])
+
+        # The following mirror the fields in a sync response
+        presence (list)
+        account_data (list)
+        joined (list[JoinedSyncResult])
+        invited (list[InvitedSyncResult])
+        archived (list[ArchivedSyncResult])
+        device (list)
+        groups (GroupsSyncResult|None)
+        to_device (list)
+    """
     def __init__(self, sync_config, full_state, since_token, now_token,
                  joined_room_ids):
         """
         Args:
-            sync_config(SyncConfig)
-            full_state(bool): The full_state flag as specified by user
-            since_token(StreamToken): The token supplied by user, or None.
-            now_token(StreamToken): The token to sync up to.
+            sync_config (SyncConfig)
+            full_state (bool): The full_state flag as specified by user
+            since_token (StreamToken): The token supplied by user, or None.
+            now_token (StreamToken): The token to sync up to.
+            joined_room_ids (list[str]): List of rooms the user is joined to
         """
         self.sync_config = sync_config
         self.full_state = full_state
@@ -1930,8 +1988,8 @@ class RoomSyncResultBuilder(object):
         Args:
             room_id(str)
             rtype(str): One of `"joined"` or `"archived"`
-            events(list): List of events to include in the room, (more events
-                may be added when generating result).
+            events(list[FrozenEvent]): List of events to include in the room
+                (more events may be added when generating result).
             newly_joined(bool): If the user has newly joined the room
             full_state(bool): Whether the full state should be sent in result
             since_token(StreamToken): Earliest point to return events from, or None
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index a61bbf9392..39df960c31 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -231,7 +231,7 @@ class TypingHandler(object):
             for domain in set(get_domain_from_id(u) for u in users):
                 if domain != self.server_name:
                     logger.debug("sending typing update to %s", domain)
-                    self.federation.send_edu(
+                    self.federation.build_and_send_edu(
                         destination=domain,
                         edu_type="m.typing",
                         content={
diff --git a/synapse/notifier.py b/synapse/notifier.py
index de02b1017e..ff589660da 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -178,8 +178,6 @@ class Notifier(object):
             self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
         )
 
-        self.replication_deferred = ObservableDeferred(defer.Deferred())
-
         # This is not a very cheap test to perform, but it's only executed
         # when rendering the metrics page, which is likely once per minute at
         # most when scraping it.
@@ -205,7 +203,9 @@ class Notifier(object):
 
     def add_replication_callback(self, cb):
         """Add a callback that will be called when some new data is available.
-        Callback is not given any arguments.
+        Callback is not given any arguments. It should *not* return a Deferred - if
+        it needs to do any asynchronous work, a background thread should be started and
+        wrapped with run_as_background_process.
         """
         self.replication_callbacks.append(cb)
 
@@ -517,60 +517,5 @@ class Notifier(object):
 
     def notify_replication(self):
         """Notify the any replication listeners that there's a new event"""
-        with PreserveLoggingContext():
-            deferred = self.replication_deferred
-            self.replication_deferred = ObservableDeferred(defer.Deferred())
-            deferred.callback(None)
-
-            # the callbacks may well outlast the current request, so we run
-            # them in the sentinel logcontext.
-            #
-            # (ideally it would be up to the callbacks to know if they were
-            # starting off background processes and drop the logcontext
-            # accordingly, but that requires more changes)
-            for cb in self.replication_callbacks:
-                cb()
-
-    @defer.inlineCallbacks
-    def wait_for_replication(self, callback, timeout):
-        """Wait for an event to happen.
-
-        Args:
-            callback: Gets called whenever an event happens. If this returns a
-                truthy value then ``wait_for_replication`` returns, otherwise
-                it waits for another event.
-            timeout: How many milliseconds to wait for callback return a truthy
-                value.
-
-        Returns:
-            A deferred that resolves with the value returned by the callback.
-        """
-        listener = _NotificationListener(None)
-
-        end_time = self.clock.time_msec() + timeout
-
-        while True:
-            listener.deferred = self.replication_deferred.observe()
-            result = yield callback()
-            if result:
-                break
-
-            now = self.clock.time_msec()
-            if end_time <= now:
-                break
-
-            listener.deferred = timeout_deferred(
-                listener.deferred,
-                timeout=(end_time - now) / 1000.,
-                reactor=self.hs.get_reactor(),
-            )
-
-            try:
-                with PreserveLoggingContext():
-                    yield listener.deferred
-            except defer.TimeoutError:
-                break
-            except defer.CancelledError:
-                break
-
-        defer.returnValue(result)
+        for cb in self.replication_callbacks:
+            cb()
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 1d27c9221f..912a5ac341 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -33,11 +33,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
     def __init__(self, hs):
         super(ReplicationRegisterServlet, self).__init__(hs)
         self.store = hs.get_datastore()
+        self.registration_handler = hs.get_registration_handler()
 
     @staticmethod
     def _serialize_payload(
         user_id, token, password_hash, was_guest, make_guest, appservice_id,
-        create_profile_with_displayname, admin, user_type,
+        create_profile_with_displayname, admin, user_type, address,
     ):
         """
         Args:
@@ -56,6 +57,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
             admin (boolean): is an admin user?
             user_type (str|None): type of user. One of the values from
                 api.constants.UserTypes, or None for a normal user.
+            address (str|None): the IP address used to perform the regitration.
         """
         return {
             "token": token,
@@ -66,13 +68,14 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
             "create_profile_with_displayname": create_profile_with_displayname,
             "admin": admin,
             "user_type": user_type,
+            "address": address,
         }
 
     @defer.inlineCallbacks
     def _handle_request(self, request, user_id):
         content = parse_json_object_from_request(request)
 
-        yield self.store.register(
+        yield self.registration_handler.register_with_store(
             user_id=user_id,
             token=content["token"],
             password_hash=content["password_hash"],
@@ -82,6 +85,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
             create_profile_with_displayname=content["create_profile_with_displayname"],
             admin=content["admin"],
             user_type=content["user_type"],
+            address=content["address"]
         )
 
         defer.returnValue((200, {}))
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 60641f1a49..5b8521c770 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -43,6 +43,8 @@ class SlavedClientIpStore(BaseSlavedStore):
         if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
             return
 
+        self.client_ip_last_seen.prefill(key, now)
+
         self.hs.get_tcp_replication().send_user_ip(
             user_id, access_token, ip, user_agent, device_id, now
         )
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 4f19fd35aa..4d59778863 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -13,15 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage import DataStore
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.deviceinbox import DeviceInboxWorkerStore
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
-from ._base import BaseSlavedStore, __func__
-from ._slaved_id_tracker import SlavedIdTracker
 
-
-class SlavedDeviceInboxStore(BaseSlavedStore):
+class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
     def __init__(self, db_conn, hs):
         super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
         self._device_inbox_id_gen = SlavedIdTracker(
@@ -43,12 +42,6 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
             expiry_ms=30 * 60 * 1000,
         )
 
-    get_to_device_stream_token = __func__(DataStore.get_to_device_stream_token)
-    get_new_messages_for_device = __func__(DataStore.get_new_messages_for_device)
-    get_new_device_msgs_for_remote = __func__(DataStore.get_new_device_msgs_for_remote)
-    delete_messages_for_device = __func__(DataStore.delete_messages_for_device)
-    delete_device_msgs_for_remote = __func__(DataStore.delete_device_msgs_for_remote)
-
     def stream_positions(self):
         result = super(SlavedDeviceInboxStore, self).stream_positions()
         result["to_device"] = self._device_inbox_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index ec2fd561cc..16c9a162c5 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -13,15 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage import DataStore
-from synapse.storage.end_to_end_keys import EndToEndKeyStore
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.devices import DeviceWorkerStore
+from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
-from ._base import BaseSlavedStore, __func__
-from ._slaved_id_tracker import SlavedIdTracker
 
-
-class SlavedDeviceStore(BaseSlavedStore):
+class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
     def __init__(self, db_conn, hs):
         super(SlavedDeviceStore, self).__init__(db_conn, hs)
 
@@ -38,17 +37,6 @@ class SlavedDeviceStore(BaseSlavedStore):
             "DeviceListFederationStreamChangeCache", device_list_max,
         )
 
-    get_device_stream_token = __func__(DataStore.get_device_stream_token)
-    get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed)
-    get_devices_by_remote = __func__(DataStore.get_devices_by_remote)
-    _get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn)
-    _get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn)
-    mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote)
-    _mark_as_sent_devices_by_remote_txn = (
-        __func__(DataStore._mark_as_sent_devices_by_remote_txn)
-    )
-    count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
-
     def stream_positions(self):
         result = super(SlavedDeviceStore, self).stream_positions()
         result["device_lists"] = self._device_list_id_gen.get_current_token()
@@ -58,14 +46,23 @@ class SlavedDeviceStore(BaseSlavedStore):
         if stream_name == "device_lists":
             self._device_list_id_gen.advance(token)
             for row in rows:
-                self._device_list_stream_cache.entity_has_changed(
-                    row.user_id, token
+                self._invalidate_caches_for_devices(
+                    token, row.user_id, row.destination,
                 )
-
-                if row.destination:
-                    self._device_list_federation_stream_cache.entity_has_changed(
-                        row.destination, token
-                    )
         return super(SlavedDeviceStore, self).process_replication_rows(
             stream_name, token, rows
         )
+
+    def _invalidate_caches_for_devices(self, token, user_id, destination):
+        self._device_list_stream_cache.entity_has_changed(
+            user_id, token
+        )
+
+        if destination:
+            self._device_list_federation_stream_cache.entity_has_changed(
+                destination, token
+            )
+
+        self._get_cached_devices_for_user.invalidate((user_id,))
+        self._get_cached_user_device.invalidate_many((user_id,))
+        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index f0200c1e98..45fc913c52 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -20,7 +20,7 @@ from ._slaved_id_tracker import SlavedIdTracker
 from .events import SlavedEventStore
 
 
-class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
+class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
     def __init__(self, db_conn, hs):
         self._push_rules_stream_id_gen = SlavedIdTracker(
             db_conn, "push_rules_stream", "stream_id",
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 49ae5b3355..55630ba9a7 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -451,7 +451,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
     @defer.inlineCallbacks
     def subscribe_to_stream(self, stream_name, token):
-        """Subscribe the remote to a streams.
+        """Subscribe the remote to a stream.
 
         This invloves checking if they've missed anything and sending those
         updates down if they have. During that time new updates for the stream
@@ -478,11 +478,36 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
             # Now we can send any updates that came in while we were subscribing
             pending_rdata = self.pending_rdata.pop(stream_name, [])
+            updates = []
             for token, update in pending_rdata:
-                # Only send updates newer than the current token
-                if token > current_token:
+                # If the token is null, it is part of a batch update. Batches
+                # are multiple updates that share a single token. To denote
+                # this, the token is set to None for all tokens in the batch
+                # except for the last. If we find a None token, we keep looking
+                # through tokens until we find one that is not None and then
+                # process all previous updates in the batch as if they had the
+                # final token.
+                if token is None:
+                    # Store this update as part of a batch
+                    updates.append(update)
+                    continue
+
+                if token <= current_token:
+                    # This update or batch of updates is older than
+                    # current_token, dismiss it
+                    updates = []
+                    continue
+
+                updates.append(update)
+
+                # Send all updates that are part of this batch with the
+                # found token
+                for update in updates:
                     self.send_command(RdataCommand(stream_name, token, update))
 
+                # Clear stored updates
+                updates = []
+
             # They're now fully subscribed
             self.replication_streams.add(stream_name)
         except Exception as e:
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 82433a2aa9..2a29f0c2af 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -17,12 +17,14 @@
 import hashlib
 import hmac
 import logging
+import platform
 
 from six import text_type
 from six.moves import http_client
 
 from twisted.internet import defer
 
+import synapse
 from synapse.api.constants import Membership, UserTypes
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.servlet import (
@@ -32,6 +34,7 @@ from synapse.http.servlet import (
     parse_string,
 )
 from synapse.types import UserID, create_requester
+from synapse.util.versionstring import get_version_string
 
 from .base import ClientV1RestServlet, client_path_patterns
 
@@ -66,6 +69,25 @@ class UsersRestServlet(ClientV1RestServlet):
         defer.returnValue((200, ret))
 
 
+class VersionServlet(ClientV1RestServlet):
+    PATTERNS = client_path_patterns("/admin/server_version")
+
+    @defer.inlineCallbacks
+    def on_GET(self, request):
+        requester = yield self.auth.get_user_by_req(request)
+        is_admin = yield self.auth.is_server_admin(requester.user)
+
+        if not is_admin:
+            raise AuthError(403, "You are not a server admin")
+
+        ret = {
+            'server_version': get_version_string(synapse),
+            'python_version': platform.python_version(),
+        }
+
+        defer.returnValue((200, ret))
+
+
 class UserRegisterServlet(ClientV1RestServlet):
     """
     Attributes:
@@ -466,17 +488,6 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
         )
         new_room_id = info["room_id"]
 
-        yield self.event_creation_handler.create_and_send_nonmember_event(
-            room_creator_requester,
-            {
-                "type": "m.room.message",
-                "content": {"body": message, "msgtype": "m.text"},
-                "room_id": new_room_id,
-                "sender": new_room_user_id,
-            },
-            ratelimit=False,
-        )
-
         requester_user_id = requester.user.to_string()
 
         logger.info("Shutting down room %r", room_id)
@@ -514,6 +525,17 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
 
             kicked_users.append(user_id)
 
+        yield self.event_creation_handler.create_and_send_nonmember_event(
+            room_creator_requester,
+            {
+                "type": "m.room.message",
+                "content": {"body": message, "msgtype": "m.text"},
+                "room_id": new_room_id,
+                "sender": new_room_user_id,
+            },
+            ratelimit=False,
+        )
+
         aliases_for_room = yield self.store.get_aliases_for_room(room_id)
 
         yield self.store.update_aliases_for_room(
@@ -763,3 +785,4 @@ def register_servlets(hs, http_server):
     QuarantineMediaInRoom(hs).register(http_server)
     ListMediaInRoom(hs).register(http_server)
     UserRegisterServlet(hs).register(http_server)
+    VersionServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 94cbba4303..6f34029431 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -25,7 +25,12 @@ from twisted.internet import defer
 import synapse
 import synapse.types
 from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
+from synapse.api.errors import (
+    Codes,
+    LimitExceededError,
+    SynapseError,
+    UnrecognizedRequestError,
+)
 from synapse.config.server import is_threepid_reserved
 from synapse.http.servlet import (
     RestServlet,
@@ -191,18 +196,36 @@ class RegisterRestServlet(RestServlet):
         self.identity_handler = hs.get_handlers().identity_handler
         self.room_member_handler = hs.get_room_member_handler()
         self.macaroon_gen = hs.get_macaroon_generator()
+        self.ratelimiter = hs.get_registration_ratelimiter()
+        self.clock = hs.get_clock()
 
     @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
         body = parse_json_object_from_request(request)
 
+        client_addr = request.getClientIP()
+
+        time_now = self.clock.time()
+
+        allowed, time_allowed = self.ratelimiter.can_do_action(
+            client_addr, time_now_s=time_now,
+            rate_hz=self.hs.config.rc_registration_requests_per_second,
+            burst_count=self.hs.config.rc_registration_request_burst_count,
+            update=False,
+        )
+
+        if not allowed:
+            raise LimitExceededError(
+                retry_after_ms=int(1000 * (time_allowed - time_now)),
+            )
+
         kind = b"user"
         if b"kind" in request.args:
             kind = request.args[b"kind"][0]
 
         if kind == b"guest":
-            ret = yield self._do_guest_registration(body)
+            ret = yield self._do_guest_registration(body, address=client_addr)
             defer.returnValue(ret)
             return
         elif kind != b"user":
@@ -411,6 +434,7 @@ class RegisterRestServlet(RestServlet):
                 guest_access_token=guest_access_token,
                 generate_token=False,
                 threepid=threepid,
+                address=client_addr,
             )
             # Necessary due to auth checks prior to the threepid being
             # written to the db
@@ -522,12 +546,13 @@ class RegisterRestServlet(RestServlet):
         defer.returnValue(result)
 
     @defer.inlineCallbacks
-    def _do_guest_registration(self, params):
+    def _do_guest_registration(self, params, address=None):
         if not self.hs.config.allow_guest_access:
             raise SynapseError(403, "Guest access is disabled")
         user_id, _ = yield self.registration_handler.register(
             generate_token=False,
-            make_guest=True
+            make_guest=True,
+            address=address,
         )
 
         # we don't allow guests to specify their own device_id, because
diff --git a/synapse/server.py b/synapse/server.py
index 4d364fccce..72835e8c86 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -51,7 +51,7 @@ from synapse.handlers.acme import AcmeHandler
 from synapse.handlers.appservice import ApplicationServicesHandler
 from synapse.handlers.auth import AuthHandler, MacaroonGenerator
 from synapse.handlers.deactivate_account import DeactivateAccountHandler
-from synapse.handlers.device import DeviceHandler
+from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
 from synapse.handlers.devicemessage import DeviceMessageHandler
 from synapse.handlers.e2e_keys import E2eKeysHandler
 from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler
@@ -206,6 +206,7 @@ class HomeServer(object):
         self.clock = Clock(reactor)
         self.distributor = Distributor()
         self.ratelimiter = Ratelimiter()
+        self.registration_ratelimiter = Ratelimiter()
 
         self.datastore = None
 
@@ -251,6 +252,9 @@ class HomeServer(object):
     def get_ratelimiter(self):
         return self.ratelimiter
 
+    def get_registration_ratelimiter(self):
+        return self.registration_ratelimiter
+
     def build_federation_client(self):
         return FederationClient(self)
 
@@ -307,7 +311,10 @@ class HomeServer(object):
         return MacaroonGenerator(self)
 
     def build_device_handler(self):
-        return DeviceHandler(self)
+        if self.config.worker_app:
+            return DeviceWorkerHandler(self)
+        else:
+            return DeviceHandler(self)
 
     def build_device_message_handler(self):
         return DeviceMessageHandler(self)
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index e06b0bc56d..e6a42a53bb 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -19,14 +19,174 @@ from canonicaljson import json
 
 from twisted.internet import defer
 
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.util.caches.expiringcache import ExpiringCache
 
-from .background_updates import BackgroundUpdateStore
-
 logger = logging.getLogger(__name__)
 
 
-class DeviceInboxStore(BackgroundUpdateStore):
+class DeviceInboxWorkerStore(SQLBaseStore):
+    def get_to_device_stream_token(self):
+        return self._device_inbox_id_gen.get_current_token()
+
+    def get_new_messages_for_device(
+        self, user_id, device_id, last_stream_id, current_stream_id, limit=100
+    ):
+        """
+        Args:
+            user_id(str): The recipient user_id.
+            device_id(str): The recipient device_id.
+            current_stream_id(int): The current position of the to device
+                message stream.
+        Returns:
+            Deferred ([dict], int): List of messages for the device and where
+                in the stream the messages got to.
+        """
+        has_changed = self._device_inbox_stream_cache.has_entity_changed(
+            user_id, last_stream_id
+        )
+        if not has_changed:
+            return defer.succeed(([], current_stream_id))
+
+        def get_new_messages_for_device_txn(txn):
+            sql = (
+                "SELECT stream_id, message_json FROM device_inbox"
+                " WHERE user_id = ? AND device_id = ?"
+                " AND ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (
+                user_id, device_id, last_stream_id, current_stream_id, limit
+            ))
+            messages = []
+            for row in txn:
+                stream_pos = row[0]
+                messages.append(json.loads(row[1]))
+            if len(messages) < limit:
+                stream_pos = current_stream_id
+            return (messages, stream_pos)
+
+        return self.runInteraction(
+            "get_new_messages_for_device", get_new_messages_for_device_txn,
+        )
+
+    @defer.inlineCallbacks
+    def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
+        """
+        Args:
+            user_id(str): The recipient user_id.
+            device_id(str): The recipient device_id.
+            up_to_stream_id(int): Where to delete messages up to.
+        Returns:
+            A deferred that resolves to the number of messages deleted.
+        """
+        # If we have cached the last stream id we've deleted up to, we can
+        # check if there is likely to be anything that needs deleting
+        last_deleted_stream_id = self._last_device_delete_cache.get(
+            (user_id, device_id), None
+        )
+        if last_deleted_stream_id:
+            has_changed = self._device_inbox_stream_cache.has_entity_changed(
+                user_id, last_deleted_stream_id
+            )
+            if not has_changed:
+                defer.returnValue(0)
+
+        def delete_messages_for_device_txn(txn):
+            sql = (
+                "DELETE FROM device_inbox"
+                " WHERE user_id = ? AND device_id = ?"
+                " AND stream_id <= ?"
+            )
+            txn.execute(sql, (user_id, device_id, up_to_stream_id))
+            return txn.rowcount
+
+        count = yield self.runInteraction(
+            "delete_messages_for_device", delete_messages_for_device_txn
+        )
+
+        # Update the cache, ensuring that we only ever increase the value
+        last_deleted_stream_id = self._last_device_delete_cache.get(
+            (user_id, device_id), 0
+        )
+        self._last_device_delete_cache[(user_id, device_id)] = max(
+            last_deleted_stream_id, up_to_stream_id
+        )
+
+        defer.returnValue(count)
+
+    def get_new_device_msgs_for_remote(
+        self, destination, last_stream_id, current_stream_id, limit=100
+    ):
+        """
+        Args:
+            destination(str): The name of the remote server.
+            last_stream_id(int|long): The last position of the device message stream
+                that the server sent up to.
+            current_stream_id(int|long): The current position of the device
+                message stream.
+        Returns:
+            Deferred ([dict], int|long): List of messages for the device and where
+                in the stream the messages got to.
+        """
+
+        has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
+            destination, last_stream_id
+        )
+        if not has_changed or last_stream_id == current_stream_id:
+            return defer.succeed(([], current_stream_id))
+
+        def get_new_messages_for_remote_destination_txn(txn):
+            sql = (
+                "SELECT stream_id, messages_json FROM device_federation_outbox"
+                " WHERE destination = ?"
+                " AND ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (
+                destination, last_stream_id, current_stream_id, limit
+            ))
+            messages = []
+            for row in txn:
+                stream_pos = row[0]
+                messages.append(json.loads(row[1]))
+            if len(messages) < limit:
+                stream_pos = current_stream_id
+            return (messages, stream_pos)
+
+        return self.runInteraction(
+            "get_new_device_msgs_for_remote",
+            get_new_messages_for_remote_destination_txn,
+        )
+
+    def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
+        """Used to delete messages when the remote destination acknowledges
+        their receipt.
+
+        Args:
+            destination(str): The destination server_name
+            up_to_stream_id(int): Where to delete messages up to.
+        Returns:
+            A deferred that resolves when the messages have been deleted.
+        """
+        def delete_messages_for_remote_destination_txn(txn):
+            sql = (
+                "DELETE FROM device_federation_outbox"
+                " WHERE destination = ?"
+                " AND stream_id <= ?"
+            )
+            txn.execute(sql, (destination, up_to_stream_id))
+
+        return self.runInteraction(
+            "delete_device_msgs_for_remote",
+            delete_messages_for_remote_destination_txn
+        )
+
+
+class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
     def __init__(self, db_conn, hs):
@@ -220,93 +380,6 @@ class DeviceInboxStore(BackgroundUpdateStore):
 
         txn.executemany(sql, rows)
 
-    def get_new_messages_for_device(
-        self, user_id, device_id, last_stream_id, current_stream_id, limit=100
-    ):
-        """
-        Args:
-            user_id(str): The recipient user_id.
-            device_id(str): The recipient device_id.
-            current_stream_id(int): The current position of the to device
-                message stream.
-        Returns:
-            Deferred ([dict], int): List of messages for the device and where
-                in the stream the messages got to.
-        """
-        has_changed = self._device_inbox_stream_cache.has_entity_changed(
-            user_id, last_stream_id
-        )
-        if not has_changed:
-            return defer.succeed(([], current_stream_id))
-
-        def get_new_messages_for_device_txn(txn):
-            sql = (
-                "SELECT stream_id, message_json FROM device_inbox"
-                " WHERE user_id = ? AND device_id = ?"
-                " AND ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC"
-                " LIMIT ?"
-            )
-            txn.execute(sql, (
-                user_id, device_id, last_stream_id, current_stream_id, limit
-            ))
-            messages = []
-            for row in txn:
-                stream_pos = row[0]
-                messages.append(json.loads(row[1]))
-            if len(messages) < limit:
-                stream_pos = current_stream_id
-            return (messages, stream_pos)
-
-        return self.runInteraction(
-            "get_new_messages_for_device", get_new_messages_for_device_txn,
-        )
-
-    @defer.inlineCallbacks
-    def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
-        """
-        Args:
-            user_id(str): The recipient user_id.
-            device_id(str): The recipient device_id.
-            up_to_stream_id(int): Where to delete messages up to.
-        Returns:
-            A deferred that resolves to the number of messages deleted.
-        """
-        # If we have cached the last stream id we've deleted up to, we can
-        # check if there is likely to be anything that needs deleting
-        last_deleted_stream_id = self._last_device_delete_cache.get(
-            (user_id, device_id), None
-        )
-        if last_deleted_stream_id:
-            has_changed = self._device_inbox_stream_cache.has_entity_changed(
-                user_id, last_deleted_stream_id
-            )
-            if not has_changed:
-                defer.returnValue(0)
-
-        def delete_messages_for_device_txn(txn):
-            sql = (
-                "DELETE FROM device_inbox"
-                " WHERE user_id = ? AND device_id = ?"
-                " AND stream_id <= ?"
-            )
-            txn.execute(sql, (user_id, device_id, up_to_stream_id))
-            return txn.rowcount
-
-        count = yield self.runInteraction(
-            "delete_messages_for_device", delete_messages_for_device_txn
-        )
-
-        # Update the cache, ensuring that we only ever increase the value
-        last_deleted_stream_id = self._last_device_delete_cache.get(
-            (user_id, device_id), 0
-        )
-        self._last_device_delete_cache[(user_id, device_id)] = max(
-            last_deleted_stream_id, up_to_stream_id
-        )
-
-        defer.returnValue(count)
-
     def get_all_new_device_messages(self, last_pos, current_pos, limit):
         """
         Args:
@@ -351,77 +424,6 @@ class DeviceInboxStore(BackgroundUpdateStore):
             "get_all_new_device_messages", get_all_new_device_messages_txn
         )
 
-    def get_to_device_stream_token(self):
-        return self._device_inbox_id_gen.get_current_token()
-
-    def get_new_device_msgs_for_remote(
-        self, destination, last_stream_id, current_stream_id, limit=100
-    ):
-        """
-        Args:
-            destination(str): The name of the remote server.
-            last_stream_id(int|long): The last position of the device message stream
-                that the server sent up to.
-            current_stream_id(int|long): The current position of the device
-                message stream.
-        Returns:
-            Deferred ([dict], int|long): List of messages for the device and where
-                in the stream the messages got to.
-        """
-
-        has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
-            destination, last_stream_id
-        )
-        if not has_changed or last_stream_id == current_stream_id:
-            return defer.succeed(([], current_stream_id))
-
-        def get_new_messages_for_remote_destination_txn(txn):
-            sql = (
-                "SELECT stream_id, messages_json FROM device_federation_outbox"
-                " WHERE destination = ?"
-                " AND ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC"
-                " LIMIT ?"
-            )
-            txn.execute(sql, (
-                destination, last_stream_id, current_stream_id, limit
-            ))
-            messages = []
-            for row in txn:
-                stream_pos = row[0]
-                messages.append(json.loads(row[1]))
-            if len(messages) < limit:
-                stream_pos = current_stream_id
-            return (messages, stream_pos)
-
-        return self.runInteraction(
-            "get_new_device_msgs_for_remote",
-            get_new_messages_for_remote_destination_txn,
-        )
-
-    def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
-        """Used to delete messages when the remote destination acknowledges
-        their receipt.
-
-        Args:
-            destination(str): The destination server_name
-            up_to_stream_id(int): Where to delete messages up to.
-        Returns:
-            A deferred that resolves when the messages have been deleted.
-        """
-        def delete_messages_for_remote_destination_txn(txn):
-            sql = (
-                "DELETE FROM device_federation_outbox"
-                " WHERE destination = ?"
-                " AND stream_id <= ?"
-            )
-            txn.execute(sql, (destination, up_to_stream_id))
-
-        return self.runInteraction(
-            "delete_device_msgs_for_remote",
-            delete_messages_for_remote_destination_txn
-        )
-
     @defer.inlineCallbacks
     def _background_drop_index_device_inbox(self, progress, batch_size):
         def reindex_txn(conn):
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index ecdab34e7d..e716dc1437 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -22,11 +22,10 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import Cache, SQLBaseStore, db_to_json
 from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
-from ._base import Cache, db_to_json
-
 logger = logging.getLogger(__name__)
 
 DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
@@ -34,7 +33,343 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
 )
 
 
-class DeviceStore(BackgroundUpdateStore):
+class DeviceWorkerStore(SQLBaseStore):
+    def get_device(self, user_id, device_id):
+        """Retrieve a device.
+
+        Args:
+            user_id (str): The ID of the user which owns the device
+            device_id (str): The ID of the device to retrieve
+        Returns:
+            defer.Deferred for a dict containing the device information
+        Raises:
+            StoreError: if the device is not found
+        """
+        return self._simple_select_one(
+            table="devices",
+            keyvalues={"user_id": user_id, "device_id": device_id},
+            retcols=("user_id", "device_id", "display_name"),
+            desc="get_device",
+        )
+
+    @defer.inlineCallbacks
+    def get_devices_by_user(self, user_id):
+        """Retrieve all of a user's registered devices.
+
+        Args:
+            user_id (str):
+        Returns:
+            defer.Deferred: resolves to a dict from device_id to a dict
+            containing "device_id", "user_id" and "display_name" for each
+            device.
+        """
+        devices = yield self._simple_select_list(
+            table="devices",
+            keyvalues={"user_id": user_id},
+            retcols=("user_id", "device_id", "display_name"),
+            desc="get_devices_by_user"
+        )
+
+        defer.returnValue({d["device_id"]: d for d in devices})
+
+    def get_devices_by_remote(self, destination, from_stream_id):
+        """Get stream of updates to send to remote servers
+
+        Returns:
+            (int, list[dict]): current stream id and list of updates
+        """
+        now_stream_id = self._device_list_id_gen.get_current_token()
+
+        has_changed = self._device_list_federation_stream_cache.has_entity_changed(
+            destination, int(from_stream_id)
+        )
+        if not has_changed:
+            return (now_stream_id, [])
+
+        return self.runInteraction(
+            "get_devices_by_remote", self._get_devices_by_remote_txn,
+            destination, from_stream_id, now_stream_id,
+        )
+
+    def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
+                                   now_stream_id):
+        sql = """
+            SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
+            WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
+            GROUP BY user_id, device_id
+            LIMIT 20
+        """
+        txn.execute(
+            sql, (destination, from_stream_id, now_stream_id, False)
+        )
+
+        # maps (user_id, device_id) -> stream_id
+        query_map = {(r[0], r[1]): r[2] for r in txn}
+        if not query_map:
+            return (now_stream_id, [])
+
+        if len(query_map) >= 20:
+            now_stream_id = max(stream_id for stream_id in itervalues(query_map))
+
+        devices = self._get_e2e_device_keys_txn(
+            txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
+        )
+
+        prev_sent_id_sql = """
+            SELECT coalesce(max(stream_id), 0) as stream_id
+            FROM device_lists_outbound_last_success
+            WHERE destination = ? AND user_id = ? AND stream_id <= ?
+        """
+
+        results = []
+        for user_id, user_devices in iteritems(devices):
+            # The prev_id for the first row is always the last row before
+            # `from_stream_id`
+            txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
+            rows = txn.fetchall()
+            prev_id = rows[0][0]
+            for device_id, device in iteritems(user_devices):
+                stream_id = query_map[(user_id, device_id)]
+                result = {
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "prev_id": [prev_id] if prev_id else [],
+                    "stream_id": stream_id,
+                }
+
+                prev_id = stream_id
+
+                if device is not None:
+                    key_json = device.get("key_json", None)
+                    if key_json:
+                        result["keys"] = db_to_json(key_json)
+                    device_display_name = device.get("device_display_name", None)
+                    if device_display_name:
+                        result["device_display_name"] = device_display_name
+                else:
+                    result["deleted"] = True
+
+                results.append(result)
+
+        return (now_stream_id, results)
+
+    def mark_as_sent_devices_by_remote(self, destination, stream_id):
+        """Mark that updates have successfully been sent to the destination.
+        """
+        return self.runInteraction(
+            "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
+            destination, stream_id,
+        )
+
+    def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
+        # We update the device_lists_outbound_last_success with the successfully
+        # poked users. We do the join to see which users need to be inserted and
+        # which updated.
+        sql = """
+            SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+            FROM device_lists_outbound_pokes as o
+            LEFT JOIN device_lists_outbound_last_success as s
+                USING (destination, user_id)
+            WHERE destination = ? AND o.stream_id <= ?
+            GROUP BY user_id
+        """
+        txn.execute(sql, (destination, stream_id,))
+        rows = txn.fetchall()
+
+        sql = """
+            UPDATE device_lists_outbound_last_success
+            SET stream_id = ?
+            WHERE destination = ? AND user_id = ?
+        """
+        txn.executemany(
+            sql, ((row[1], destination, row[0],) for row in rows if row[2])
+        )
+
+        sql = """
+            INSERT INTO device_lists_outbound_last_success
+            (destination, user_id, stream_id) VALUES (?, ?, ?)
+        """
+        txn.executemany(
+            sql, ((destination, row[0], row[1],) for row in rows if not row[2])
+        )
+
+        # Delete all sent outbound pokes
+        sql = """
+            DELETE FROM device_lists_outbound_pokes
+            WHERE destination = ? AND stream_id <= ?
+        """
+        txn.execute(sql, (destination, stream_id,))
+
+    def get_device_stream_token(self):
+        return self._device_list_id_gen.get_current_token()
+
+    @defer.inlineCallbacks
+    def get_user_devices_from_cache(self, query_list):
+        """Get the devices (and keys if any) for remote users from the cache.
+
+        Args:
+            query_list(list): List of (user_id, device_ids), if device_ids is
+                falsey then return all device ids for that user.
+
+        Returns:
+            (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
+            a set of user_ids and results_map is a mapping of
+            user_id -> device_id -> device_info
+        """
+        user_ids = set(user_id for user_id, _ in query_list)
+        user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+        user_ids_in_cache = set(
+            user_id for user_id, stream_id in user_map.items() if stream_id
+        )
+        user_ids_not_in_cache = user_ids - user_ids_in_cache
+
+        results = {}
+        for user_id, device_id in query_list:
+            if user_id not in user_ids_in_cache:
+                continue
+
+            if device_id:
+                device = yield self._get_cached_user_device(user_id, device_id)
+                results.setdefault(user_id, {})[device_id] = device
+            else:
+                results[user_id] = yield self._get_cached_devices_for_user(user_id)
+
+        defer.returnValue((user_ids_not_in_cache, results))
+
+    @cachedInlineCallbacks(num_args=2, tree=True)
+    def _get_cached_user_device(self, user_id, device_id):
+        content = yield self._simple_select_one_onecol(
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            retcol="content",
+            desc="_get_cached_user_device",
+        )
+        defer.returnValue(db_to_json(content))
+
+    @cachedInlineCallbacks()
+    def _get_cached_devices_for_user(self, user_id):
+        devices = yield self._simple_select_list(
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+            },
+            retcols=("device_id", "content"),
+            desc="_get_cached_devices_for_user",
+        )
+        defer.returnValue({
+            device["device_id"]: db_to_json(device["content"])
+            for device in devices
+        })
+
+    def get_devices_with_keys_by_user(self, user_id):
+        """Get all devices (with any device keys) for a user
+
+        Returns:
+            (stream_id, devices)
+        """
+        return self.runInteraction(
+            "get_devices_with_keys_by_user",
+            self._get_devices_with_keys_by_user_txn, user_id,
+        )
+
+    def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+        now_stream_id = self._device_list_id_gen.get_current_token()
+
+        devices = self._get_e2e_device_keys_txn(
+            txn, [(user_id, None)], include_all_devices=True
+        )
+
+        if devices:
+            user_devices = devices[user_id]
+            results = []
+            for device_id, device in iteritems(user_devices):
+                result = {
+                    "device_id": device_id,
+                }
+
+                key_json = device.get("key_json", None)
+                if key_json:
+                    result["keys"] = db_to_json(key_json)
+                device_display_name = device.get("device_display_name", None)
+                if device_display_name:
+                    result["device_display_name"] = device_display_name
+
+                results.append(result)
+
+            return now_stream_id, results
+
+        return now_stream_id, []
+
+    @defer.inlineCallbacks
+    def get_user_whose_devices_changed(self, from_key):
+        """Get set of users whose devices have changed since `from_key`.
+        """
+        from_key = int(from_key)
+        changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
+        if changed is not None:
+            defer.returnValue(set(changed))
+
+        sql = """
+            SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
+        """
+        rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
+        defer.returnValue(set(row[0] for row in rows))
+
+    def get_all_device_list_changes_for_remotes(self, from_key, to_key):
+        """Return a list of `(stream_id, user_id, destination)` which is the
+        combined list of changes to devices, and which destinations need to be
+        poked. `destination` may be None if no destinations need to be poked.
+        """
+        # We do a group by here as there can be a large number of duplicate
+        # entries, since we throw away device IDs.
+        sql = """
+            SELECT MAX(stream_id) AS stream_id, user_id, destination
+            FROM device_lists_stream
+            LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+            WHERE ? < stream_id AND stream_id <= ?
+            GROUP BY user_id, destination
+        """
+        return self._execute(
+            "get_all_device_list_changes_for_remotes", None,
+            sql, from_key, to_key
+        )
+
+    @cached(max_entries=10000)
+    def get_device_list_last_stream_id_for_remote(self, user_id):
+        """Get the last stream_id we got for a user. May be None if we haven't
+        got any information for them.
+        """
+        return self._simple_select_one_onecol(
+            table="device_lists_remote_extremeties",
+            keyvalues={"user_id": user_id},
+            retcol="stream_id",
+            desc="get_device_list_last_stream_id_for_remote",
+            allow_none=True,
+        )
+
+    @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
+                list_name="user_ids", inlineCallbacks=True)
+    def get_device_list_last_stream_id_for_remotes(self, user_ids):
+        rows = yield self._simple_select_many_batch(
+            table="device_lists_remote_extremeties",
+            column="user_id",
+            iterable=user_ids,
+            retcols=("user_id", "stream_id",),
+            desc="get_device_list_last_stream_id_for_remotes",
+        )
+
+        results = {user_id: None for user_id in user_ids}
+        results.update({
+            row["user_id"]: row["stream_id"] for row in rows
+        })
+
+        defer.returnValue(results)
+
+
+class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
     def __init__(self, db_conn, hs):
         super(DeviceStore, self).__init__(db_conn, hs)
 
@@ -121,24 +456,6 @@ class DeviceStore(BackgroundUpdateStore):
                          initial_device_display_name, e)
             raise StoreError(500, "Problem storing device.")
 
-    def get_device(self, user_id, device_id):
-        """Retrieve a device.
-
-        Args:
-            user_id (str): The ID of the user which owns the device
-            device_id (str): The ID of the device to retrieve
-        Returns:
-            defer.Deferred for a dict containing the device information
-        Raises:
-            StoreError: if the device is not found
-        """
-        return self._simple_select_one(
-            table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id},
-            retcols=("user_id", "device_id", "display_name"),
-            desc="get_device",
-        )
-
     @defer.inlineCallbacks
     def delete_device(self, user_id, device_id):
         """Delete a device.
@@ -203,57 +520,6 @@ class DeviceStore(BackgroundUpdateStore):
         )
 
     @defer.inlineCallbacks
-    def get_devices_by_user(self, user_id):
-        """Retrieve all of a user's registered devices.
-
-        Args:
-            user_id (str):
-        Returns:
-            defer.Deferred: resolves to a dict from device_id to a dict
-            containing "device_id", "user_id" and "display_name" for each
-            device.
-        """
-        devices = yield self._simple_select_list(
-            table="devices",
-            keyvalues={"user_id": user_id},
-            retcols=("user_id", "device_id", "display_name"),
-            desc="get_devices_by_user"
-        )
-
-        defer.returnValue({d["device_id"]: d for d in devices})
-
-    @cached(max_entries=10000)
-    def get_device_list_last_stream_id_for_remote(self, user_id):
-        """Get the last stream_id we got for a user. May be None if we haven't
-        got any information for them.
-        """
-        return self._simple_select_one_onecol(
-            table="device_lists_remote_extremeties",
-            keyvalues={"user_id": user_id},
-            retcol="stream_id",
-            desc="get_device_list_remote_extremity",
-            allow_none=True,
-        )
-
-    @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
-                list_name="user_ids", inlineCallbacks=True)
-    def get_device_list_last_stream_id_for_remotes(self, user_ids):
-        rows = yield self._simple_select_many_batch(
-            table="device_lists_remote_extremeties",
-            column="user_id",
-            iterable=user_ids,
-            retcols=("user_id", "stream_id",),
-            desc="get_user_devices_from_cache",
-        )
-
-        results = {user_id: None for user_id in user_ids}
-        results.update({
-            row["user_id"]: row["stream_id"] for row in rows
-        })
-
-        defer.returnValue(results)
-
-    @defer.inlineCallbacks
     def mark_remote_user_device_list_as_unsubscribed(self, user_id):
         """Mark that we no longer track device lists for remote user.
         """
@@ -405,268 +671,6 @@ class DeviceStore(BackgroundUpdateStore):
             lock=False,
         )
 
-    def get_devices_by_remote(self, destination, from_stream_id):
-        """Get stream of updates to send to remote servers
-
-        Returns:
-            (int, list[dict]): current stream id and list of updates
-        """
-        now_stream_id = self._device_list_id_gen.get_current_token()
-
-        has_changed = self._device_list_federation_stream_cache.has_entity_changed(
-            destination, int(from_stream_id)
-        )
-        if not has_changed:
-            return (now_stream_id, [])
-
-        return self.runInteraction(
-            "get_devices_by_remote", self._get_devices_by_remote_txn,
-            destination, from_stream_id, now_stream_id,
-        )
-
-    def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
-                                   now_stream_id):
-        sql = """
-            SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
-            WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
-            GROUP BY user_id, device_id
-            LIMIT 20
-        """
-        txn.execute(
-            sql, (destination, from_stream_id, now_stream_id, False)
-        )
-
-        # maps (user_id, device_id) -> stream_id
-        query_map = {(r[0], r[1]): r[2] for r in txn}
-        if not query_map:
-            return (now_stream_id, [])
-
-        if len(query_map) >= 20:
-            now_stream_id = max(stream_id for stream_id in itervalues(query_map))
-
-        devices = self._get_e2e_device_keys_txn(
-            txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
-        )
-
-        prev_sent_id_sql = """
-            SELECT coalesce(max(stream_id), 0) as stream_id
-            FROM device_lists_outbound_last_success
-            WHERE destination = ? AND user_id = ? AND stream_id <= ?
-        """
-
-        results = []
-        for user_id, user_devices in iteritems(devices):
-            # The prev_id for the first row is always the last row before
-            # `from_stream_id`
-            txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
-            rows = txn.fetchall()
-            prev_id = rows[0][0]
-            for device_id, device in iteritems(user_devices):
-                stream_id = query_map[(user_id, device_id)]
-                result = {
-                    "user_id": user_id,
-                    "device_id": device_id,
-                    "prev_id": [prev_id] if prev_id else [],
-                    "stream_id": stream_id,
-                }
-
-                prev_id = stream_id
-
-                if device is not None:
-                    key_json = device.get("key_json", None)
-                    if key_json:
-                        result["keys"] = db_to_json(key_json)
-                    device_display_name = device.get("device_display_name", None)
-                    if device_display_name:
-                        result["device_display_name"] = device_display_name
-                else:
-                    result["deleted"] = True
-
-                results.append(result)
-
-        return (now_stream_id, results)
-
-    @defer.inlineCallbacks
-    def get_user_devices_from_cache(self, query_list):
-        """Get the devices (and keys if any) for remote users from the cache.
-
-        Args:
-            query_list(list): List of (user_id, device_ids), if device_ids is
-                falsey then return all device ids for that user.
-
-        Returns:
-            (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
-            a set of user_ids and results_map is a mapping of
-            user_id -> device_id -> device_info
-        """
-        user_ids = set(user_id for user_id, _ in query_list)
-        user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
-        user_ids_in_cache = set(
-            user_id for user_id, stream_id in user_map.items() if stream_id
-        )
-        user_ids_not_in_cache = user_ids - user_ids_in_cache
-
-        results = {}
-        for user_id, device_id in query_list:
-            if user_id not in user_ids_in_cache:
-                continue
-
-            if device_id:
-                device = yield self._get_cached_user_device(user_id, device_id)
-                results.setdefault(user_id, {})[device_id] = device
-            else:
-                results[user_id] = yield self._get_cached_devices_for_user(user_id)
-
-        defer.returnValue((user_ids_not_in_cache, results))
-
-    @cachedInlineCallbacks(num_args=2, tree=True)
-    def _get_cached_user_device(self, user_id, device_id):
-        content = yield self._simple_select_one_onecol(
-            table="device_lists_remote_cache",
-            keyvalues={
-                "user_id": user_id,
-                "device_id": device_id,
-            },
-            retcol="content",
-            desc="_get_cached_user_device",
-        )
-        defer.returnValue(db_to_json(content))
-
-    @cachedInlineCallbacks()
-    def _get_cached_devices_for_user(self, user_id):
-        devices = yield self._simple_select_list(
-            table="device_lists_remote_cache",
-            keyvalues={
-                "user_id": user_id,
-            },
-            retcols=("device_id", "content"),
-            desc="_get_cached_devices_for_user",
-        )
-        defer.returnValue({
-            device["device_id"]: db_to_json(device["content"])
-            for device in devices
-        })
-
-    def get_devices_with_keys_by_user(self, user_id):
-        """Get all devices (with any device keys) for a user
-
-        Returns:
-            (stream_id, devices)
-        """
-        return self.runInteraction(
-            "get_devices_with_keys_by_user",
-            self._get_devices_with_keys_by_user_txn, user_id,
-        )
-
-    def _get_devices_with_keys_by_user_txn(self, txn, user_id):
-        now_stream_id = self._device_list_id_gen.get_current_token()
-
-        devices = self._get_e2e_device_keys_txn(
-            txn, [(user_id, None)], include_all_devices=True
-        )
-
-        if devices:
-            user_devices = devices[user_id]
-            results = []
-            for device_id, device in iteritems(user_devices):
-                result = {
-                    "device_id": device_id,
-                }
-
-                key_json = device.get("key_json", None)
-                if key_json:
-                    result["keys"] = db_to_json(key_json)
-                device_display_name = device.get("device_display_name", None)
-                if device_display_name:
-                    result["device_display_name"] = device_display_name
-
-                results.append(result)
-
-            return now_stream_id, results
-
-        return now_stream_id, []
-
-    def mark_as_sent_devices_by_remote(self, destination, stream_id):
-        """Mark that updates have successfully been sent to the destination.
-        """
-        return self.runInteraction(
-            "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
-            destination, stream_id,
-        )
-
-    def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
-        # We update the device_lists_outbound_last_success with the successfully
-        # poked users. We do the join to see which users need to be inserted and
-        # which updated.
-        sql = """
-            SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
-            FROM device_lists_outbound_pokes as o
-            LEFT JOIN device_lists_outbound_last_success as s
-                USING (destination, user_id)
-            WHERE destination = ? AND o.stream_id <= ?
-            GROUP BY user_id
-        """
-        txn.execute(sql, (destination, stream_id,))
-        rows = txn.fetchall()
-
-        sql = """
-            UPDATE device_lists_outbound_last_success
-            SET stream_id = ?
-            WHERE destination = ? AND user_id = ?
-        """
-        txn.executemany(
-            sql, ((row[1], destination, row[0],) for row in rows if row[2])
-        )
-
-        sql = """
-            INSERT INTO device_lists_outbound_last_success
-            (destination, user_id, stream_id) VALUES (?, ?, ?)
-        """
-        txn.executemany(
-            sql, ((destination, row[0], row[1],) for row in rows if not row[2])
-        )
-
-        # Delete all sent outbound pokes
-        sql = """
-            DELETE FROM device_lists_outbound_pokes
-            WHERE destination = ? AND stream_id <= ?
-        """
-        txn.execute(sql, (destination, stream_id,))
-
-    @defer.inlineCallbacks
-    def get_user_whose_devices_changed(self, from_key):
-        """Get set of users whose devices have changed since `from_key`.
-        """
-        from_key = int(from_key)
-        changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
-        if changed is not None:
-            defer.returnValue(set(changed))
-
-        sql = """
-            SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
-        """
-        rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
-        defer.returnValue(set(row[0] for row in rows))
-
-    def get_all_device_list_changes_for_remotes(self, from_key, to_key):
-        """Return a list of `(stream_id, user_id, destination)` which is the
-        combined list of changes to devices, and which destinations need to be
-        poked. `destination` may be None if no destinations need to be poked.
-        """
-        # We do a group by here as there can be a large number of duplicate
-        # entries, since we throw away device IDs.
-        sql = """
-            SELECT MAX(stream_id) AS stream_id, user_id, destination
-            FROM device_lists_stream
-            LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
-            WHERE ? < stream_id AND stream_id <= ?
-            GROUP BY user_id, destination
-        """
-        return self._execute(
-            "get_all_device_list_changes_for_remotes", None,
-            sql, from_key, to_key
-        )
-
     @defer.inlineCallbacks
     def add_device_change_to_streams(self, user_id, device_ids, hosts):
         """Persist that a user's devices have been updated, and which hosts
@@ -732,9 +736,6 @@ class DeviceStore(BackgroundUpdateStore):
             ]
         )
 
-    def get_device_stream_token(self):
-        return self._device_list_id_gen.get_current_token()
-
     def _prune_old_outbound_device_pokes(self):
         """Delete old entries out of the device_lists_outbound_pokes to ensure
         that we don't fill up due to dead servers. We keep one entry per
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2a0f6cfca9..e381e472a2 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -23,49 +23,7 @@ from synapse.util.caches.descriptors import cached
 from ._base import SQLBaseStore, db_to_json
 
 
-class EndToEndKeyStore(SQLBaseStore):
-    def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
-        """Stores device keys for a device. Returns whether there was a change
-        or the keys were already in the database.
-        """
-        def _set_e2e_device_keys_txn(txn):
-            old_key_json = self._simple_select_one_onecol_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={
-                    "user_id": user_id,
-                    "device_id": device_id,
-                },
-                retcol="key_json",
-                allow_none=True,
-            )
-
-            # In py3 we need old_key_json to match new_key_json type. The DB
-            # returns unicode while encode_canonical_json returns bytes.
-            new_key_json = encode_canonical_json(device_keys).decode("utf-8")
-
-            if old_key_json == new_key_json:
-                return False
-
-            self._simple_upsert_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={
-                    "user_id": user_id,
-                    "device_id": device_id,
-                },
-                values={
-                    "ts_added_ms": time_now,
-                    "key_json": new_key_json,
-                }
-            )
-
-            return True
-
-        return self.runInteraction(
-            "set_e2e_device_keys", _set_e2e_device_keys_txn
-        )
-
+class EndToEndKeyWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_e2e_device_keys(
         self, query_list, include_all_devices=False,
@@ -238,6 +196,50 @@ class EndToEndKeyStore(SQLBaseStore):
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
 
+
+class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+    def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+        """Stores device keys for a device. Returns whether there was a change
+        or the keys were already in the database.
+        """
+        def _set_e2e_device_keys_txn(txn):
+            old_key_json = self._simple_select_one_onecol_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+                retcol="key_json",
+                allow_none=True,
+            )
+
+            # In py3 we need old_key_json to match new_key_json type. The DB
+            # returns unicode while encode_canonical_json returns bytes.
+            new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+
+            if old_key_json == new_key_json:
+                return False
+
+            self._simple_upsert_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+                values={
+                    "ts_added_ms": time_now,
+                    "key_json": new_key_json,
+                }
+            )
+
+            return True
+
+        return self.runInteraction(
+            "set_e2e_device_keys", _set_e2e_device_keys_txn
+        )
+
     def claim_e2e_one_time_keys(self, query_list):
         """Take a list of one time keys out of the database"""
         def _claim_e2e_one_time_keys(txn):
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 38809ed0fc..a8d90456e3 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -442,6 +442,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
         event_results.reverse()
         return event_results
 
+    @defer.inlineCallbacks
+    def get_successor_events(self, event_ids):
+        """Fetch all events that have the given events as a prev event
+
+        Args:
+            event_ids (iterable[str])
+
+        Returns:
+            Deferred[list[str]]
+        """
+        rows = yield self._simple_select_many_batch(
+            table="event_edges",
+            column="prev_event_id",
+            iterable=event_ids,
+            retcols=("event_id",),
+            desc="get_successor_events"
+        )
+
+        defer.returnValue([
+            row["event_id"] for row in rows
+        ])
+
 
 class EventFederationStore(EventFederationWorkerStore):
     """ Responsible for storing and serving up the various graphs associated
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 0ac665e967..0fd1ccc40a 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -346,15 +346,23 @@ class ReceiptsStore(ReceiptsWorkerStore):
 
     def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
                                       user_id, event_id, data, stream_id):
+        """Inserts a read-receipt into the database if it's newer than the current RR
+
+        Returns: int|None
+            None if the RR is older than the current RR
+            otherwise, the rx timestamp of the event that the RR corresponds to
+                (or 0 if the event is unknown)
+        """
         res = self._simple_select_one_txn(
             txn,
             table="events",
-            retcols=["topological_ordering", "stream_ordering"],
+            retcols=["stream_ordering", "received_ts"],
             keyvalues={"event_id": event_id},
             allow_none=True
         )
 
         stream_ordering = int(res["stream_ordering"]) if res else None
+        rx_ts = res["received_ts"] if res else 0
 
         # We don't want to clobber receipts for more recent events, so we
         # have to compare orderings of existing receipts
@@ -373,7 +381,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                         "one for later event %s",
                         event_id, eid,
                     )
-                    return False
+                    return None
 
         txn.call_after(
             self.get_receipts_for_room.invalidate, (room_id, receipt_type)
@@ -429,7 +437,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 stream_ordering=stream_ordering,
             )
 
-        return True
+        return rx_ts
 
     @defer.inlineCallbacks
     def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
@@ -466,7 +474,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
 
         stream_id_manager = self._receipts_id_gen.get_next()
         with stream_id_manager as stream_id:
-            have_persisted = yield self.runInteraction(
+            event_ts = yield self.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
                 room_id, receipt_type, user_id, linearized_event_id,
@@ -474,8 +482,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 stream_id=stream_id,
             )
 
-            if not have_persisted:
-                defer.returnValue(None)
+        if event_ts is None:
+            defer.returnValue(None)
+
+        now = self._clock.time_msec()
+        logger.debug(
+            "RR for event %s in %s (%i ms old)",
+            linearized_event_id, room_id, now - event_ts,
+        )
 
         yield self.insert_graph_receipt(
             room_id, receipt_type, user_id, event_ids, data
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index d6cfdba519..580fafeb3a 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -191,6 +191,25 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     @defer.inlineCallbacks
     def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
                                          order='DESC'):
+        """Get new room events in stream ordering since `from_key`.
+
+        Args:
+            room_id (str)
+            from_key (str): Token from which no events are returned before
+            to_key (str): Token from which no events are returned after. (This
+                is typically the current stream token)
+            limit (int): Maximum number of events to return
+            order (str): Either "DESC" or "ASC". Determines which events are
+                returned when the result is limited. If "DESC" then the most
+                recent `limit` events are returned, otherwise returns the
+                oldest `limit` events.
+
+        Returns:
+            Deferred[dict[str,tuple[list[FrozenEvent], str]]]
+                A map from room id to a tuple containing:
+                    - list of recent events in the room
+                    - stream ordering key for the start of the chunk of events returned.
+        """
         from_id = RoomStreamToken.parse_stream_token(from_key).stream
 
         room_ids = yield self._events_stream_cache.get_entities_changed(
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 0281a7c919..efec21673b 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -216,28 +216,36 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
 
 
 @defer.inlineCallbacks
-def filter_events_for_server(store, server_name, events):
-    # Whatever else we do, we need to check for senders which have requested
-    # erasure of their data.
-    erased_senders = yield store.are_users_erased(
-        (e.sender for e in events),
-    )
+def filter_events_for_server(store, server_name, events, redact=True,
+                             check_history_visibility_only=False):
+    """Filter a list of events based on whether given server is allowed to
+    see them.
 
-    def redact_disallowed(event, state):
-        # if the sender has been gdpr17ed, always return a redacted
-        # copy of the event.
-        if erased_senders[event.sender]:
+    Args:
+        store (DataStore)
+        server_name (str)
+        events (iterable[FrozenEvent])
+        redact (bool): Whether to return a redacted version of the event, or
+            to filter them out entirely.
+        check_history_visibility_only (bool): Whether to only check the
+            history visibility, rather than things like if the sender has been
+            erased. This is used e.g. during pagination to decide whether to
+            backfill or not.
+
+    Returns
+        Deferred[list[FrozenEvent]]
+    """
+
+    def is_sender_erased(event, erased_senders):
+        if erased_senders and erased_senders[event.sender]:
             logger.info(
                 "Sender of %s has been erased, redacting",
                 event.event_id,
             )
-            return prune_event(event)
-
-        # state will be None if we decided we didn't need to filter by
-        # room membership.
-        if not state:
-            return event
+            return True
+        return False
 
+    def check_event_is_visible(event, state):
         history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
         if history:
             visibility = history.content.get("history_visibility", "shared")
@@ -259,17 +267,17 @@ def filter_events_for_server(store, server_name, events):
 
                     memtype = ev.membership
                     if memtype == Membership.JOIN:
-                        return event
+                        return True
                     elif memtype == Membership.INVITE:
                         if visibility == "invited":
-                            return event
+                            return True
                 else:
                     # server has no users in the room: redact
-                    return prune_event(event)
+                    return False
 
-        return event
+        return True
 
-    # Next lets check to see if all the events have a history visibility
+    # Lets check to see if all the events have a history visibility
     # of "shared" or "world_readable". If thats the case then we don't
     # need to check membership (as we know the server is in the room).
     event_to_state_ids = yield store.get_state_ids_for_events(
@@ -296,16 +304,31 @@ def filter_events_for_server(store, server_name, events):
             for e in itervalues(event_map)
         )
 
+    if not check_history_visibility_only:
+        erased_senders = yield store.are_users_erased(
+            (e.sender for e in events),
+        )
+    else:
+        # We don't want to check whether users are erased, which is equivalent
+        # to no users having been erased.
+        erased_senders = {}
+
     if all_open:
         # all the history_visibility state affecting these events is open, so
         # we don't need to filter by membership state. We *do* need to check
         # for user erasure, though.
         if erased_senders:
-            events = [
-                redact_disallowed(e, None)
-                for e in events
-            ]
+            to_return = []
+            for e in events:
+                if not is_sender_erased(e, erased_senders):
+                    to_return.append(e)
+                elif redact:
+                    to_return.append(prune_event(e))
+
+            defer.returnValue(to_return)
 
+        # If there are no erased users then we can just return the given list
+        # of events without having to copy it.
         defer.returnValue(events)
 
     # Ok, so we're dealing with events that have non-trivial visibility
@@ -361,7 +384,13 @@ def filter_events_for_server(store, server_name, events):
         for e_id, key_to_eid in iteritems(event_to_state_ids)
     }
 
-    defer.returnValue([
-        redact_disallowed(e, event_to_state[e.event_id])
-        for e in events
-    ])
+    to_return = []
+    for e in events:
+        erased = is_sender_erased(e, erased_senders)
+        visible = check_event_is_visible(e, event_to_state[e.event_id])
+        if visible and not erased:
+            to_return.append(e)
+        elif redact:
+            to_return.append(prune_event(e))
+
+    defer.returnValue(to_return)