summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/generic_worker.py53
-rw-r--r--synapse/config/logger.py1
-rw-r--r--synapse/config/workers.py30
-rw-r--r--synapse/handlers/federation.py16
-rw-r--r--synapse/handlers/message.py12
-rw-r--r--synapse/handlers/presence.py6
-rw-r--r--synapse/handlers/room.py7
-rw-r--r--synapse/handlers/room_member.py39
-rw-r--r--synapse/replication/http/__init__.py4
-rw-r--r--synapse/replication/http/_base.py3
-rw-r--r--synapse/replication/http/membership.py40
-rw-r--r--synapse/replication/http/presence.py116
-rw-r--r--synapse/replication/tcp/handler.py10
-rw-r--r--synapse/rest/admin/rooms.py11
-rw-r--r--synapse/storage/data_stores/__init__.py6
-rw-r--r--synapse/storage/data_stores/main/devices.py30
-rw-r--r--synapse/storage/data_stores/main/events.py34
-rw-r--r--synapse/storage/data_stores/main/events_worker.py2
-rw-r--r--synapse/storage/data_stores/main/roommember.py25
-rw-r--r--synapse/storage/persist_events.py6
20 files changed, 378 insertions, 73 deletions
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index a37520000a..2906b93f6a 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -39,7 +39,11 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
 from synapse.federation import send_queue
 from synapse.federation.transport.server import TransportLayerServer
-from synapse.handlers.presence import BasePresenceHandler, get_interested_parties
+from synapse.handlers.presence import (
+    BasePresenceHandler,
+    PresenceState,
+    get_interested_parties,
+)
 from synapse.http.server import JsonResource, OptionsResource
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseSite
@@ -47,6 +51,10 @@ from synapse.logging.context import LoggingContext
 from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
+from synapse.replication.http.presence import (
+    ReplicationBumpPresenceActiveTime,
+    ReplicationPresenceSetState,
+)
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@@ -247,6 +255,9 @@ class GenericWorkerPresence(BasePresenceHandler):
         # but we haven't notified the master of that yet
         self.users_going_offline = {}
 
+        self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
+        self._set_state_client = ReplicationPresenceSetState.make_client(hs)
+
         self._send_stop_syncing_loop = self.clock.looping_call(
             self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
         )
@@ -304,10 +315,6 @@ class GenericWorkerPresence(BasePresenceHandler):
                 self.users_going_offline.pop(user_id, None)
                 self.send_user_sync(user_id, False, last_sync_ms)
 
-    def set_state(self, user, state, ignore_status_msg=False):
-        # TODO Hows this supposed to work?
-        return defer.succeed(None)
-
     async def user_syncing(
         self, user_id: str, affect_presence: bool
     ) -> ContextManager[None]:
@@ -386,6 +393,42 @@ class GenericWorkerPresence(BasePresenceHandler):
             if count > 0
         ]
 
+    async def set_state(self, target_user, state, ignore_status_msg=False):
+        """Set the presence state of the user.
+        """
+        presence = state["presence"]
+
+        valid_presence = (
+            PresenceState.ONLINE,
+            PresenceState.UNAVAILABLE,
+            PresenceState.OFFLINE,
+        )
+        if presence not in valid_presence:
+            raise SynapseError(400, "Invalid presence state")
+
+        user_id = target_user.to_string()
+
+        # If presence is disabled, no-op
+        if not self.hs.config.use_presence:
+            return
+
+        # Proxy request to master
+        await self._set_state_client(
+            user_id=user_id, state=state, ignore_status_msg=ignore_status_msg
+        )
+
+    async def bump_presence_active_time(self, user):
+        """We've seen the user do something that indicates they're interacting
+        with the app.
+        """
+        # If presence is disabled, no-op
+        if not self.hs.config.use_presence:
+            return
+
+        # Proxy request to master
+        user_id = user.to_string()
+        await self._bump_active_client(user_id=user_id)
+
 
 class GenericWorkerTyping(object):
     def __init__(self, hs):
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index a25c70e928..49f6c32beb 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -257,5 +257,6 @@ def setup_logging(
     logging.warning("***** STARTING SERVER *****")
     logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
     logging.info("Server hostname: %s", config.server_name)
+    logging.info("Instance name: %s", hs.get_instance_name())
 
     return logger
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index c80c338584..ed06b91a54 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -15,7 +15,7 @@
 
 import attr
 
-from ._base import Config
+from ._base import Config, ConfigError
 
 
 @attr.s
@@ -27,6 +27,17 @@ class InstanceLocationConfig:
     port = attr.ib(type=int)
 
 
+@attr.s
+class WriterLocations:
+    """Specifies the instances that write various streams.
+
+    Attributes:
+        events: The instance that writes to the event and backfill streams.
+    """
+
+    events = attr.ib(default="master", type=str)
+
+
 class WorkerConfig(Config):
     """The workers are processes run separately to the main synapse process.
     They have their own pid_file and listener configuration. They use the
@@ -83,11 +94,26 @@ class WorkerConfig(Config):
                     bind_addresses.append("")
 
         # A map from instance name to host/port of their HTTP replication endpoint.
-        instance_map = config.get("instance_map", {}) or {}
+        instance_map = config.get("instance_map") or {}
         self.instance_map = {
             name: InstanceLocationConfig(**c) for name, c in instance_map.items()
         }
 
+        # Map from type of streams to source, c.f. WriterLocations.
+        writers = config.get("stream_writers") or {}
+        self.writers = WriterLocations(**writers)
+
+        # Check that the configured writer for events also appears in
+        # `instance_map`.
+        if (
+            self.writers.events != "master"
+            and self.writers.events not in self.instance_map
+        ):
+            raise ConfigError(
+                "Instance %r is configured to write events but does not appear in `instance_map` config."
+                % (self.writers.events,)
+            )
+
     def read_arguments(self, args):
         # We support a bunch of command line arguments that override options in
         # the config. A lot of these options have a worker_* prefix when running
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index e354c803db..75ec90d267 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -126,11 +126,10 @@ class FederationHandler(BaseHandler):
         self._server_notices_mxid = hs.config.server_notices_mxid
         self.config = hs.config
         self.http_client = hs.get_simple_http_client()
+        self._instance_name = hs.get_instance_name()
         self._replication = hs.get_replication_data_handler()
 
-        self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client(
-            hs
-        )
+        self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
         self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
             hs
         )
@@ -1243,6 +1242,10 @@ class FederationHandler(BaseHandler):
 
             content: The event content to use for the join event.
         """
+        # TODO: We should be able to call this on workers, but the upgrading of
+        # room stuff after join currently doesn't work on workers.
+        assert self.config.worker.worker_app is None
+
         logger.debug("Joining %s to %s", joinee, room_id)
 
         origin, event, room_version_obj = await self._make_and_verify_event(
@@ -1314,7 +1317,7 @@ class FederationHandler(BaseHandler):
             #
             # TODO: Currently the events stream is written to from master
             await self._replication.wait_for_stream_position(
-                "master", "events", max_stream_id
+                self.config.worker.writers.events, "events", max_stream_id
             )
 
             # Check whether this room is the result of an upgrade of a room we already know
@@ -2854,8 +2857,9 @@ class FederationHandler(BaseHandler):
             backfilled: Whether these events are a result of
                 backfilling or not
         """
-        if self.config.worker_app:
-            result = await self._send_events_to_master(
+        if self.config.worker.writers.events != self._instance_name:
+            result = await self._send_events(
+                instance_name=self.config.worker.writers.events,
                 store=self.store,
                 event_and_contexts=event_and_contexts,
                 backfilled=backfilled,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index f445e2aa2a..ea25f0515a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -366,10 +366,11 @@ class EventCreationHandler(object):
         self.notifier = hs.get_notifier()
         self.config = hs.config
         self.require_membership_for_aliases = hs.config.require_membership_for_aliases
+        self._instance_name = hs.get_instance_name()
 
         self.room_invite_state_types = self.hs.config.room_invite_state_types
 
-        self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs)
+        self.send_event = ReplicationSendEventRestServlet.make_client(hs)
 
         # This is only used to get at ratelimit function, and maybe_kick_guest_users
         self.base_handler = BaseHandler(hs)
@@ -835,8 +836,9 @@ class EventCreationHandler(object):
         success = False
         try:
             # If we're a worker we need to hit out to the master.
-            if self.config.worker_app:
-                result = await self.send_event_to_master(
+            if self.config.worker.writers.events != self._instance_name:
+                result = await self.send_event(
+                    instance_name=self.config.worker.writers.events,
                     event_id=event.event_id,
                     store=self.store,
                     requester=requester,
@@ -902,9 +904,9 @@ class EventCreationHandler(object):
         """Called when we have fully built the event, have already
         calculated the push actions for the event, and checked auth.
 
-        This should only be run on master.
+        This should only be run on the instance in charge of persisting events.
         """
-        assert not self.config.worker_app
+        assert self.config.worker.writers.events == self._instance_name
 
         if ratelimit:
             # We check if this is a room admin redacting an event so that we
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 9ea11c0754..3594f3b00f 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -193,6 +193,12 @@ class BasePresenceHandler(abc.ABC):
     ) -> None:
         """Set the presence state of the user. """
 
+    @abc.abstractmethod
+    async def bump_presence_active_time(self, user: UserID):
+        """We've seen the user do something that indicates they're interacting
+        with the app.
+        """
+
 
 class PresenceHandler(BasePresenceHandler):
     def __init__(self, hs: "synapse.server.HomeServer"):
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 2698a129ca..61db3ccc43 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -89,6 +89,8 @@ class RoomCreationHandler(BaseHandler):
         self.room_member_handler = hs.get_room_member_handler()
         self.config = hs.config
 
+        self._replication = hs.get_replication_data_handler()
+
         # linearizer to stop two upgrades happening at once
         self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
 
@@ -752,6 +754,11 @@ class RoomCreationHandler(BaseHandler):
         if room_alias:
             result["room_alias"] = room_alias.to_string()
 
+        # Always wait for room creation to progate before returning
+        await self._replication.wait_for_stream_position(
+            self.hs.config.worker.writers.events, "events", last_stream_id
+        )
+
         return result, last_stream_id
 
     async def _send_events_for_new_room(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 691b6705b2..0f7af982f0 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -26,6 +26,9 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, Codes, SynapseError
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
+from synapse.replication.http.membership import (
+    ReplicationLocallyRejectInviteRestServlet,
+)
 from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room, user_left_room
@@ -44,11 +47,6 @@ class RoomMemberHandler(object):
     __metaclass__ = abc.ABCMeta
 
     def __init__(self, hs):
-        """
-
-        Args:
-            hs (synapse.server.HomeServer):
-        """
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
@@ -71,6 +69,17 @@ class RoomMemberHandler(object):
         self._enable_lookup = hs.config.enable_3pid_lookup
         self.allow_per_room_profiles = self.config.allow_per_room_profiles
 
+        self._event_stream_writer_instance = hs.config.worker.writers.events
+        self._is_on_event_persistence_instance = (
+            self._event_stream_writer_instance == hs.get_instance_name()
+        )
+        if self._is_on_event_persistence_instance:
+            self.persist_event_storage = hs.get_storage().persistence
+        else:
+            self._locally_reject_client = ReplicationLocallyRejectInviteRestServlet.make_client(
+                hs
+            )
+
         # This is only used to get at ratelimit function, and
         # maybe_kick_guest_users. It's fine there are multiple of these as
         # it doesn't store state.
@@ -121,6 +130,22 @@ class RoomMemberHandler(object):
         """
         raise NotImplementedError()
 
+    async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
+        """Mark the invite has having been rejected even though we failed to
+        create a leave event for it.
+        """
+        if self._is_on_event_persistence_instance:
+            return await self.persist_event_storage.locally_reject_invite(
+                user_id, room_id
+            )
+        else:
+            result = await self._locally_reject_client(
+                instance_name=self._event_stream_writer_instance,
+                user_id=user_id,
+                room_id=room_id,
+            )
+            return result["stream_id"]
+
     @abc.abstractmethod
     async def _user_joined_room(self, target: UserID, room_id: str) -> None:
         """Notifies distributor on master process that the user has joined the
@@ -1015,9 +1040,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             #
             logger.warning("Failed to reject invite: %s", e)
 
-            stream_id = await self.store.locally_reject_invite(
-                target.to_string(), room_id
-            )
+            stream_id = await self.locally_reject_invite(target.to_string(), room_id)
             return None, stream_id
 
     async def _user_joined_room(self, target: UserID, room_id: str) -> None:
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index a909744e93..19b69e0e11 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -19,6 +19,7 @@ from synapse.replication.http import (
     federation,
     login,
     membership,
+    presence,
     register,
     send_event,
     streams,
@@ -35,10 +36,11 @@ class ReplicationRestResource(JsonResource):
     def register_servlets(self, hs):
         send_event.register_servlets(hs, self)
         federation.register_servlets(hs, self)
+        presence.register_servlets(hs, self)
+        membership.register_servlets(hs, self)
 
         # The following can't currently be instantiated on workers.
         if hs.config.worker.worker_app is None:
-            membership.register_servlets(hs, self)
             login.register_servlets(hs, self)
             register.register_servlets(hs, self)
             devices.register_servlets(hs, self)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index c3136a4eb9..793cef6c26 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -142,6 +142,7 @@ class ReplicationEndpoint(object):
         """
         clock = hs.get_clock()
         client = hs.get_simple_http_client()
+        local_instance_name = hs.get_instance_name()
 
         master_host = hs.config.worker_replication_host
         master_port = hs.config.worker_replication_http_port
@@ -151,6 +152,8 @@ class ReplicationEndpoint(object):
         @trace(opname="outgoing_replication_request")
         @defer.inlineCallbacks
         def send_request(instance_name="master", **kwargs):
+            if instance_name == local_instance_name:
+                raise Exception("Trying to send HTTP request to self")
             if instance_name == "master":
                 host = master_host
                 port = master_port
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 050fd34562..a7174c4a8f 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -14,12 +14,16 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import Requester, UserID
 from synapse.util.distributor import user_joined_room, user_left_room
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -106,6 +110,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         self.federation_handler = hs.get_handlers().federation_handler
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
+        self.member_handler = hs.get_room_member_handler()
 
     @staticmethod
     def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
@@ -149,12 +154,44 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
             #
             logger.warning("Failed to reject invite: %s", e)
 
-            stream_id = await self.store.locally_reject_invite(user_id, room_id)
+            stream_id = await self.member_handler.locally_reject_invite(
+                user_id, room_id
+            )
             event_id = None
 
         return 200, {"event_id": event_id, "stream_id": stream_id}
 
 
+class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint):
+    """Rejects the invite for the user and room locally.
+
+    Request format:
+
+        POST /_synapse/replication/locally_reject_invite/:room_id/:user_id
+
+        {}
+    """
+
+    NAME = "locally_reject_invite"
+    PATH_ARGS = ("room_id", "user_id")
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self.member_handler = hs.get_room_member_handler()
+
+    @staticmethod
+    def _serialize_payload(room_id, user_id):
+        return {}
+
+    async def _handle_request(self, request, room_id, user_id):
+        logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id)
+
+        stream_id = await self.member_handler.locally_reject_invite(user_id, room_id)
+
+        return 200, {"stream_id": stream_id}
+
+
 class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
     """Notifies that a user has joined or left the room
 
@@ -208,3 +245,4 @@ def register_servlets(hs, http_server):
     ReplicationRemoteJoinRestServlet(hs).register(http_server)
     ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
     ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
+    ReplicationLocallyRejectInviteRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
new file mode 100644
index 0000000000..ea1b33331b
--- /dev/null
+++ b/synapse/replication/http/presence.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+from synapse.types import UserID
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
+    """We've seen the user do something that indicates they're interacting
+    with the app.
+
+    The POST looks like:
+
+        POST /_synapse/replication/bump_presence_active_time/<user_id>
+
+        200 OK
+
+        {}
+    """
+
+    NAME = "bump_presence_active_time"
+    PATH_ARGS = ("user_id",)
+    METHOD = "POST"
+    CACHE = False
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self._presence_handler = hs.get_presence_handler()
+
+    @staticmethod
+    def _serialize_payload(user_id):
+        return {}
+
+    async def _handle_request(self, request, user_id):
+        await self._presence_handler.bump_presence_active_time(
+            UserID.from_string(user_id)
+        )
+
+        return (
+            200,
+            {},
+        )
+
+
+class ReplicationPresenceSetState(ReplicationEndpoint):
+    """Set the presence state for a user.
+
+    The POST looks like:
+
+        POST /_synapse/replication/presence_set_state/<user_id>
+
+        {
+            "state": { ... },
+            "ignore_status_msg": false,
+        }
+
+        200 OK
+
+        {}
+    """
+
+    NAME = "presence_set_state"
+    PATH_ARGS = ("user_id",)
+    METHOD = "POST"
+    CACHE = False
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self._presence_handler = hs.get_presence_handler()
+
+    @staticmethod
+    def _serialize_payload(user_id, state, ignore_status_msg=False):
+        return {
+            "state": state,
+            "ignore_status_msg": ignore_status_msg,
+        }
+
+    async def _handle_request(self, request, user_id):
+        content = parse_json_object_from_request(request)
+
+        await self._presence_handler.set_state(
+            UserID.from_string(user_id), content["state"], content["ignore_status_msg"]
+        )
+
+        return (
+            200,
+            {},
+        )
+
+
+def register_servlets(hs, http_server):
+    ReplicationBumpPresenceActiveTime(hs).register(http_server)
+    ReplicationPresenceSetState(hs).register(http_server)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index acfa66a7a8..03300e5336 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -38,7 +38,9 @@ from synapse.replication.tcp.commands import (
 from synapse.replication.tcp.protocol import AbstractConnection
 from synapse.replication.tcp.streams import (
     STREAMS_MAP,
+    BackfillStream,
     CachesStream,
+    EventsStream,
     FederationStream,
     Stream,
 )
@@ -87,6 +89,14 @@ class ReplicationCommandHandler:
                 self._streams_to_replicate.append(stream)
                 continue
 
+            if isinstance(stream, (EventsStream, BackfillStream)):
+                # Only add EventStream and BackfillStream as a source on the
+                # instance in charge of event persistence.
+                if hs.config.worker.writers.events == hs.get_instance_name():
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             # Only add any other streams if we're on master.
             if hs.config.worker_app is not None:
                 continue
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 0a13e1ed34..8173baef8f 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -100,7 +100,9 @@ class ShutdownRoomRestServlet(RestServlet):
         # we try and auto join below.
         #
         # TODO: Currently the events stream is written to from master
-        await self._replication.wait_for_stream_position("master", "events", stream_id)
+        await self._replication.wait_for_stream_position(
+            self.hs.config.worker.writers.events, "events", stream_id
+        )
 
         users = await self.state.get_current_users_in_room(room_id)
         kicked_users = []
@@ -113,7 +115,7 @@ class ShutdownRoomRestServlet(RestServlet):
 
             try:
                 target_requester = create_requester(user_id)
-                await self.room_member_handler.update_membership(
+                _, stream_id = await self.room_member_handler.update_membership(
                     requester=target_requester,
                     target=target_requester.user,
                     room_id=room_id,
@@ -123,6 +125,11 @@ class ShutdownRoomRestServlet(RestServlet):
                     require_consent=False,
                 )
 
+                # Wait for leave to come in over replication before trying to forget.
+                await self._replication.wait_for_stream_position(
+                    self.hs.config.worker.writers.events, "events", stream_id
+                )
+
                 await self.room_member_handler.forget(target_requester.user, room_id)
 
                 await self.room_member_handler.update_membership(
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index 791961b296..599ee470d4 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -66,9 +66,9 @@ class DataStores(object):
 
                     self.main = main_store_class(database, db_conn, hs)
 
-                    # If we're on a process that can persist events (currently
-                    # master), also instantiate a `PersistEventsStore`
-                    if hs.config.worker.worker_app is None:
+                    # If we're on a process that can persist events also
+                    # instantiate a `PersistEventsStore`
+                    if hs.config.worker.writers.events == hs.get_instance_name():
                         self.persist_events = PersistEventsStore(
                             hs, database, self.main
                         )
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 0e8378714a..417ac8dc7c 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -689,6 +689,25 @@ class DeviceWorkerStore(SQLBaseStore):
             desc="make_remote_user_device_cache_as_stale",
         )
 
+    def mark_remote_user_device_list_as_unsubscribed(self, user_id):
+        """Mark that we no longer track device lists for remote user.
+        """
+
+        def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
+            self.db.simple_delete_txn(
+                txn,
+                table="device_lists_remote_extremeties",
+                keyvalues={"user_id": user_id},
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
+            )
+
+        return self.db.runInteraction(
+            "mark_remote_user_device_list_as_unsubscribed",
+            _mark_remote_user_device_list_as_unsubscribed_txn,
+        )
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: Database, db_conn, hs):
@@ -969,17 +988,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             desc="update_device",
         )
 
-    @defer.inlineCallbacks
-    def mark_remote_user_device_list_as_unsubscribed(self, user_id):
-        """Mark that we no longer track device lists for remote user.
-        """
-        yield self.db.simple_delete(
-            table="device_lists_remote_extremeties",
-            keyvalues={"user_id": user_id},
-            desc="mark_remote_user_device_list_as_unsubscribed",
-        )
-        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
-
     def update_remote_device_list_cache_entry(
         self, user_id, device_id, content, stream_id
     ):
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index a97f8b3934..a6572571b4 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -138,10 +138,10 @@ class PersistEventsStore:
         self._backfill_id_gen = self.store._backfill_id_gen  # type: StreamIdGenerator
         self._stream_id_gen = self.store._stream_id_gen  # type: StreamIdGenerator
 
-        # This should only exist on master for now
+        # This should only exist on instances that are configured to write
         assert (
-            hs.config.worker.worker_app is None
-        ), "Can only instantiate PersistEventsStore on master"
+            hs.config.worker.writers.events == hs.get_instance_name()
+        ), "Can only instantiate EventsStore on master"
 
     @_retry_on_integrity_error
     @defer.inlineCallbacks
@@ -1590,3 +1590,31 @@ class PersistEventsStore:
                 if not ev.internal_metadata.is_outlier()
             ],
         )
+
+    async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
+        """Mark the invite has having been rejected even though we failed to
+        create a leave event for it.
+        """
+
+        sql = (
+            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+            " AND replaced_by is NULL"
+        )
+
+        def f(txn, stream_ordering):
+            txn.execute(sql, (stream_ordering, True, room_id, user_id))
+
+            # We also clear this entry from `local_current_membership`.
+            # Ideally we'd point to a leave event, but we don't have one, so
+            # nevermind.
+            self.db.simple_delete_txn(
+                txn,
+                table="local_current_membership",
+                keyvalues={"room_id": room_id, "user_id": user_id},
+            )
+
+        with self._stream_id_gen.get_next() as stream_ordering:
+            await self.db.runInteraction("locally_reject_invite", f, stream_ordering)
+
+        return stream_ordering
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index b880a71782..213d69100a 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -76,7 +76,7 @@ class EventsWorkerStore(SQLBaseStore):
     def __init__(self, database: Database, db_conn, hs):
         super(EventsWorkerStore, self).__init__(database, db_conn, hs)
 
-        if hs.config.worker_app is None:
+        if hs.config.worker.writers.events == hs.get_instance_name():
             # We are the process in charge of generating stream ids for events,
             # so instantiate ID generators based on the database
             self._stream_id_gen = StreamIdGenerator(
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 7c5ca81ae0..137ebac833 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -1046,31 +1046,6 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
     def __init__(self, database: Database, db_conn, hs):
         super(RoomMemberStore, self).__init__(database, db_conn, hs)
 
-    @defer.inlineCallbacks
-    def locally_reject_invite(self, user_id, room_id):
-        sql = (
-            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
-            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-            " AND replaced_by is NULL"
-        )
-
-        def f(txn, stream_ordering):
-            txn.execute(sql, (stream_ordering, True, room_id, user_id))
-
-            # We also clear this entry from `local_current_membership`.
-            # Ideally we'd point to a leave event, but we don't have one, so
-            # nevermind.
-            self.db.simple_delete_txn(
-                txn,
-                table="local_current_membership",
-                keyvalues={"room_id": room_id, "user_id": user_id},
-            )
-
-        with self._stream_id_gen.get_next() as stream_ordering:
-            yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
-
-        return stream_ordering
-
     def forget(self, user_id, room_id):
         """Indicate that user_id wishes to discard history for room_id."""
 
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 12e1ffb9a2..f159400a87 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -786,3 +786,9 @@ class EventsPersistenceStorage(object):
 
         for user_id in left_users:
             await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
+
+    async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
+        """Mark the invite has having been rejected even though we failed to
+        create a leave event for it.
+        """
+        return await self.persist_events_store.locally_reject_invite(user_id, room_id)