summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-04-23 12:21:55 +0100
committerGitHub <noreply@github.com>2021-04-23 12:21:55 +0100
commit9d25a0ae65ce8728d0fda1eebaf0b469316f84d7 (patch)
treefca194d9ae52b9e60eeb4f9d72d1efa9b9b90682 /synapse
parentCheck for space membership during a remote join of a restricted room (#9814) (diff)
downloadsynapse-9d25a0ae65ce8728d0fda1eebaf0b469316f84d7.tar.xz
Split presence out of master (#9820)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/generic_worker.py31
-rw-r--r--synapse/config/workers.py27
-rw-r--r--synapse/handlers/presence.py56
-rw-r--r--synapse/replication/http/_base.py5
-rw-r--r--synapse/replication/slave/storage/presence.py50
-rw-r--r--synapse/replication/tcp/handler.py18
-rw-r--r--synapse/replication/tcp/streams/_base.py17
-rw-r--r--synapse/rest/client/v1/presence.py7
-rw-r--r--synapse/server.py6
-rw-r--r--synapse/storage/databases/main/__init__.py47
-rw-r--r--synapse/storage/databases/main/presence.py92
-rw-r--r--synapse/storage/databases/main/schema/delta/59/12presence_stream_instance.sql18
-rw-r--r--synapse/storage/databases/main/schema/delta/59/12presence_stream_instance_seq.sql.postgres20
13 files changed, 236 insertions, 158 deletions
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 26c458dbb6..7b2ac3ca64 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -55,7 +55,6 @@ from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.filtering import SlavedFilteringStore
 from synapse.replication.slave.storage.groups import SlavedGroupServerStore
 from synapse.replication.slave.storage.keys import SlavedKeyStore
-from synapse.replication.slave.storage.presence import SlavedPresenceStore
 from synapse.replication.slave.storage.profile import SlavedProfileStore
 from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
 from synapse.replication.slave.storage.pushers import SlavedPusherStore
@@ -64,7 +63,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.slave.storage.transactions import SlavedTransactionStore
 from synapse.rest.admin import register_servlets_for_media_repo
-from synapse.rest.client.v1 import events, login, room
+from synapse.rest.client.v1 import events, login, presence, room
 from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
 from synapse.rest.client.v1.profile import (
     ProfileAvatarURLRestServlet,
@@ -110,6 +109,7 @@ from synapse.storage.databases.main.metrics import ServerMetricsStore
 from synapse.storage.databases.main.monthly_active_users import (
     MonthlyActiveUsersWorkerStore,
 )
+from synapse.storage.databases.main.presence import PresenceStore
 from synapse.storage.databases.main.search import SearchWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.databases.main.transactions import TransactionWorkerStore
@@ -121,26 +121,6 @@ from synapse.util.versionstring import get_version_string
 logger = logging.getLogger("synapse.app.generic_worker")
 
 
-class PresenceStatusStubServlet(RestServlet):
-    """If presence is disabled this servlet can be used to stub out setting
-    presence status.
-    """
-
-    PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status")
-
-    def __init__(self, hs):
-        super().__init__()
-        self.auth = hs.get_auth()
-
-    async def on_GET(self, request, user_id):
-        await self.auth.get_user_by_req(request)
-        return 200, {"presence": "offline"}
-
-    async def on_PUT(self, request, user_id):
-        await self.auth.get_user_by_req(request)
-        return 200, {}
-
-
 class KeyUploadServlet(RestServlet):
     """An implementation of the `KeyUploadServlet` that responds to read only
     requests, but otherwise proxies through to the master instance.
@@ -241,6 +221,7 @@ class GenericWorkerSlavedStore(
     StatsStore,
     UIAuthWorkerStore,
     EndToEndRoomKeyStore,
+    PresenceStore,
     SlavedDeviceInboxStore,
     SlavedDeviceStore,
     SlavedReceiptsStore,
@@ -259,7 +240,6 @@ class GenericWorkerSlavedStore(
     SlavedTransactionStore,
     SlavedProfileStore,
     SlavedClientIpStore,
-    SlavedPresenceStore,
     SlavedFilteringStore,
     MonthlyActiveUsersWorkerStore,
     MediaRepositoryStore,
@@ -327,10 +307,7 @@ class GenericWorkerServer(HomeServer):
 
                     user_directory.register_servlets(self, resource)
 
-                    # If presence is disabled, use the stub servlet that does
-                    # not allow sending presence
-                    if not self.config.use_presence:
-                        PresenceStatusStubServlet(self).register(resource)
+                    presence.register_servlets(self, resource)
 
                     groups.register_servlets(self, resource)
 
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index b2540163d1..462630201d 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -64,6 +64,14 @@ class WriterLocations:
     Attributes:
         events: The instances that write to the event and backfill streams.
         typing: The instance that writes to the typing stream.
+        to_device: The instances that write to the to_device stream. Currently
+            can only be a single instance.
+        account_data: The instances that write to the account data streams. Currently
+            can only be a single instance.
+        receipts: The instances that write to the receipts stream. Currently
+            can only be a single instance.
+        presence: The instances that write to the presence stream. Currently
+            can only be a single instance.
     """
 
     events = attr.ib(
@@ -85,6 +93,11 @@ class WriterLocations:
         type=List[str],
         converter=_instance_to_list_converter,
     )
+    presence = attr.ib(
+        default=["master"],
+        type=List[str],
+        converter=_instance_to_list_converter,
+    )
 
 
 class WorkerConfig(Config):
@@ -188,7 +201,14 @@ class WorkerConfig(Config):
 
         # Check that the configured writers for events and typing also appears in
         # `instance_map`.
-        for stream in ("events", "typing", "to_device", "account_data", "receipts"):
+        for stream in (
+            "events",
+            "typing",
+            "to_device",
+            "account_data",
+            "receipts",
+            "presence",
+        ):
             instances = _instance_to_list_converter(getattr(self.writers, stream))
             for instance in instances:
                 if instance != "master" and instance not in self.instance_map:
@@ -215,6 +235,11 @@ class WorkerConfig(Config):
         if len(self.writers.events) == 0:
             raise ConfigError("Must specify at least one instance to handle `events`.")
 
+        if len(self.writers.presence) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `presence` messages."
+            )
+
         self.events_shard_config = RoutableShardedWorkerHandlingConfig(
             self.writers.events
         )
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 7fd28ffa54..9938be3821 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -122,7 +122,8 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
 
 
 class BasePresenceHandler(abc.ABC):
-    """Parts of the PresenceHandler that are shared between workers and master"""
+    """Parts of the PresenceHandler that are shared between workers and presence
+    writer"""
 
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
@@ -309,8 +310,16 @@ class WorkerPresenceHandler(BasePresenceHandler):
         super().__init__(hs)
         self.hs = hs
 
+        self._presence_writer_instance = hs.config.worker.writers.presence[0]
+
         self._presence_enabled = hs.config.use_presence
 
+        # Route presence EDUs to the right worker
+        hs.get_federation_registry().register_instances_for_edu(
+            "m.presence",
+            hs.config.worker.writers.presence,
+        )
+
         # The number of ongoing syncs on this process, by user id.
         # Empty if _presence_enabled is false.
         self._user_to_num_current_syncs = {}  # type: Dict[str, int]
@@ -318,8 +327,8 @@ class WorkerPresenceHandler(BasePresenceHandler):
         self.notifier = hs.get_notifier()
         self.instance_id = hs.get_instance_id()
 
-        # user_id -> last_sync_ms. Lists the users that have stopped syncing
-        # but we haven't notified the master of that yet
+        # user_id -> last_sync_ms. Lists the users that have stopped syncing but
+        # we haven't notified the presence writer of that yet
         self.users_going_offline = {}
 
         self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
@@ -352,22 +361,23 @@ class WorkerPresenceHandler(BasePresenceHandler):
             )
 
     def mark_as_coming_online(self, user_id):
-        """A user has started syncing. Send a UserSync to the master, unless they
-        had recently stopped syncing.
+        """A user has started syncing. Send a UserSync to the presence writer,
+        unless they had recently stopped syncing.
 
         Args:
             user_id (str)
         """
         going_offline = self.users_going_offline.pop(user_id, None)
         if not going_offline:
-            # Safe to skip because we haven't yet told the master they were offline
+            # Safe to skip because we haven't yet told the presence writer they
+            # were offline
             self.send_user_sync(user_id, True, self.clock.time_msec())
 
     def mark_as_going_offline(self, user_id):
-        """A user has stopped syncing. We wait before notifying the master as
-        its likely they'll come back soon. This allows us to avoid sending
-        a stopped syncing immediately followed by a started syncing notification
-        to the master
+        """A user has stopped syncing. We wait before notifying the presence
+        writer as its likely they'll come back soon. This allows us to avoid
+        sending a stopped syncing immediately followed by a started syncing
+        notification to the presence writer
 
         Args:
             user_id (str)
@@ -375,8 +385,8 @@ class WorkerPresenceHandler(BasePresenceHandler):
         self.users_going_offline[user_id] = self.clock.time_msec()
 
     def send_stop_syncing(self):
-        """Check if there are any users who have stopped syncing a while ago
-        and haven't come back yet. If there are poke the master about them.
+        """Check if there are any users who have stopped syncing a while ago and
+        haven't come back yet. If there are poke the presence writer about them.
         """
         now = self.clock.time_msec()
         for user_id, last_sync_ms in list(self.users_going_offline.items()):
@@ -492,9 +502,12 @@ class WorkerPresenceHandler(BasePresenceHandler):
         if not self.hs.config.use_presence:
             return
 
-        # Proxy request to master
+        # Proxy request to instance that writes presence
         await self._set_state_client(
-            user_id=user_id, state=state, ignore_status_msg=ignore_status_msg
+            instance_name=self._presence_writer_instance,
+            user_id=user_id,
+            state=state,
+            ignore_status_msg=ignore_status_msg,
         )
 
     async def bump_presence_active_time(self, user):
@@ -505,9 +518,11 @@ class WorkerPresenceHandler(BasePresenceHandler):
         if not self.hs.config.use_presence:
             return
 
-        # Proxy request to master
+        # Proxy request to instance that writes presence
         user_id = user.to_string()
-        await self._bump_active_client(user_id=user_id)
+        await self._bump_active_client(
+            instance_name=self._presence_writer_instance, user_id=user_id
+        )
 
 
 class PresenceHandler(BasePresenceHandler):
@@ -1909,7 +1924,7 @@ class PresenceFederationQueue:
         self._queue_presence_updates = True
 
         # Whether this instance is a presence writer.
-        self._presence_writer = hs.config.worker.worker_app is None
+        self._presence_writer = self._instance_name in hs.config.worker.writers.presence
 
         # The FederationSender instance, if this process sends federation traffic directly.
         self._federation = None
@@ -1957,7 +1972,7 @@ class PresenceFederationQueue:
         Will forward to the local federation sender (if there is one) and queue
         to send over replication (if there are other federation sender instances.).
 
-        Must only be called on the master process.
+        Must only be called on the presence writer process.
         """
 
         # This should only be called on a presence writer.
@@ -2003,10 +2018,11 @@ class PresenceFederationQueue:
         We return rows in the form of `(destination, user_id)` to keep the size
         of each row bounded (rather than returning the sets in a row).
 
-        On workers this will query the master process via HTTP replication.
+        On workers this will query the presence writer process via HTTP replication.
         """
         if instance_name != self._instance_name:
-            # If not local we query over http replication from the master
+            # If not local we query over http replication from the presence
+            # writer
             result = await self._repl_client(
                 instance_name=instance_name,
                 stream_name=PresenceFederationStream.NAME,
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index ece03467b5..5685cf2121 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -158,7 +158,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
     def make_client(cls, hs):
         """Create a client that makes requests.
 
-        Returns a callable that accepts the same parameters as `_serialize_payload`.
+        Returns a callable that accepts the same parameters as
+        `_serialize_payload`, and also accepts an optional `instance_name`
+        parameter to specify which instance to hit (the instance must be in
+        the `instance_map` config).
         """
         clock = hs.get_clock()
         client = hs.get_simple_http_client()
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
deleted file mode 100644
index 57327d910d..0000000000
--- a/synapse/replication/slave/storage/presence.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.replication.tcp.streams import PresenceStream
-from synapse.storage import DataStore
-from synapse.storage.database import DatabasePool
-from synapse.storage.databases.main.presence import PresenceStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
-
-
-class SlavedPresenceStore(BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-        self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
-
-        self._presence_on_startup = self._get_active_presence(db_conn)  # type: ignore
-
-        self.presence_stream_cache = StreamChangeCache(
-            "PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
-        )
-
-    _get_active_presence = DataStore._get_active_presence
-    take_presence_startup_info = DataStore.take_presence_startup_info
-    _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
-    get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]
-
-    def get_current_presence_token(self):
-        return self._presence_id_gen.get_current_token()
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == PresenceStream.NAME:
-            self._presence_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.presence_stream_cache.entity_has_changed(row.user_id, token)
-                self._get_presence_for_user.invalidate((row.user_id,))
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 2ce1b9f222..7ced4c543c 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -55,6 +55,8 @@ from synapse.replication.tcp.streams import (
     CachesStream,
     EventsStream,
     FederationStream,
+    PresenceFederationStream,
+    PresenceStream,
     ReceiptsStream,
     Stream,
     TagAccountDataStream,
@@ -99,6 +101,10 @@ class ReplicationCommandHandler:
         self._instance_id = hs.get_instance_id()
         self._instance_name = hs.get_instance_name()
 
+        self._is_presence_writer = (
+            hs.get_instance_name() in hs.config.worker.writers.presence
+        )
+
         self._streams = {
             stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
         }  # type: Dict[str, Stream]
@@ -153,6 +159,14 @@ class ReplicationCommandHandler:
 
                 continue
 
+            if isinstance(stream, (PresenceStream, PresenceFederationStream)):
+                # Only add PresenceStream as a source on the instance in charge
+                # of presence.
+                if self._is_presence_writer:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             # Only add any other streams if we're on master.
             if hs.config.worker_app is not None:
                 continue
@@ -350,7 +364,7 @@ class ReplicationCommandHandler:
     ) -> Optional[Awaitable[None]]:
         user_sync_counter.inc()
 
-        if self._is_master:
+        if self._is_presence_writer:
             return self._presence_handler.update_external_syncs_row(
                 cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
             )
@@ -360,7 +374,7 @@ class ReplicationCommandHandler:
     def on_CLEAR_USER_SYNC(
         self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
     ) -> Optional[Awaitable[None]]:
-        if self._is_master:
+        if self._is_presence_writer:
             return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
         else:
             return None
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 9d75a89f1c..b03824925a 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -272,15 +272,22 @@ class PresenceStream(Stream):
     NAME = "presence"
     ROW_TYPE = PresenceStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         store = hs.get_datastore()
 
-        if hs.config.worker_app is None:
-            # on the master, query the presence handler
+        if hs.get_instance_name() in hs.config.worker.writers.presence:
+            # on the presence writer, query the presence handler
             presence_handler = hs.get_presence_handler()
-            update_function = presence_handler.get_all_presence_updates
+
+            from synapse.handlers.presence import PresenceHandler
+
+            assert isinstance(presence_handler, PresenceHandler)
+
+            update_function = (
+                presence_handler.get_all_presence_updates
+            )  # type: UpdateFunction
         else:
-            # Query master process
+            # Query presence writer process
             update_function = make_http_update_function(hs, self.NAME)
 
         super().__init__(
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index c232484f29..2b24fe5aa6 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -35,10 +35,15 @@ class PresenceStatusRestServlet(RestServlet):
         self.clock = hs.get_clock()
         self.auth = hs.get_auth()
 
+        self._use_presence = hs.config.server.use_presence
+
     async def on_GET(self, request, user_id):
         requester = await self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)
 
+        if not self._use_presence:
+            return 200, {"presence": "offline"}
+
         if requester.user != user:
             allowed = await self.presence_handler.is_visible(
                 observed_user=user, observer_user=requester.user
@@ -80,7 +85,7 @@ class PresenceStatusRestServlet(RestServlet):
         except Exception:
             raise SynapseError(400, "Unable to parse state")
 
-        if self.hs.config.use_presence:
+        if self._use_presence:
             await self.presence_handler.set_state(user, state)
 
         return 200, {}
diff --git a/synapse/server.py b/synapse/server.py
index 67598fffe3..8c147be2b3 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -418,10 +418,10 @@ class HomeServer(metaclass=abc.ABCMeta):
 
     @cache_in_self
     def get_presence_handler(self) -> BasePresenceHandler:
-        if self.config.worker_app:
-            return WorkerPresenceHandler(self)
-        else:
+        if self.get_instance_name() in self.config.worker.writers.presence:
             return PresenceHandler(self)
+        else:
+            return WorkerPresenceHandler(self)
 
     @cache_in_self
     def get_typing_writer_handler(self) -> TypingWriterHandler:
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 5c50f5f950..49c7606d51 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -17,7 +17,6 @@
 import logging
 from typing import List, Optional, Tuple
 
-from synapse.api.constants import PresenceState
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.stats import UserSortOrder
@@ -51,7 +50,7 @@ from .media_repository import MediaRepositoryStore
 from .metrics import ServerMetricsStore
 from .monthly_active_users import MonthlyActiveUsersStore
 from .openid import OpenIdStore
-from .presence import PresenceStore, UserPresenceState
+from .presence import PresenceStore
 from .profile import ProfileStore
 from .purge_events import PurgeEventsStore
 from .push_rule import PushRuleStore
@@ -126,9 +125,6 @@ class DataStore(
         self._clock = hs.get_clock()
         self.database_engine = database.engine
 
-        self._presence_id_gen = StreamIdGenerator(
-            db_conn, "presence_stream", "stream_id"
-        )
         self._public_room_id_gen = StreamIdGenerator(
             db_conn, "public_room_list_stream", "stream_id"
         )
@@ -177,21 +173,6 @@ class DataStore(
 
         super().__init__(database, db_conn, hs)
 
-        self._presence_on_startup = self._get_active_presence(db_conn)
-
-        presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
-            db_conn,
-            "presence_stream",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=self._presence_id_gen.get_current_token(),
-        )
-        self.presence_stream_cache = StreamChangeCache(
-            "PresenceStreamChangeCache",
-            min_presence_val,
-            prefilled_cache=presence_cache_prefill,
-        )
-
         device_list_max = self._device_list_id_gen.get_current_token()
         self._device_list_stream_cache = StreamChangeCache(
             "DeviceListStreamChangeCache", device_list_max
@@ -238,32 +219,6 @@ class DataStore(
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
-    def take_presence_startup_info(self):
-        active_on_startup = self._presence_on_startup
-        self._presence_on_startup = None
-        return active_on_startup
-
-    def _get_active_presence(self, db_conn):
-        """Fetch non-offline presence from the database so that we can register
-        the appropriate time outs.
-        """
-
-        sql = (
-            "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
-            " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
-            " WHERE state != ?"
-        )
-
-        txn = db_conn.cursor()
-        txn.execute(sql, (PresenceState.OFFLINE,))
-        rows = self.db_pool.cursor_to_dict(txn)
-        txn.close()
-
-        for row in rows:
-            row["currently_active"] = bool(row["currently_active"])
-
-        return [UserPresenceState(**row) for row in rows]
-
     async def get_users(self) -> List[JsonDict]:
         """Function to retrieve a list of users in users table.
 
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index c207d917b1..db22fab23e 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,16 +12,69 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple
 
-from synapse.api.presence import UserPresenceState
+from synapse.api.presence import PresenceState, UserPresenceState
+from synapse.replication.tcp.streams import PresenceStream
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import DatabasePool
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.types import Connection
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.iterutils import batch_iter
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class PresenceStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: Connection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self._can_persist_presence = (
+            hs.get_instance_name() in hs.config.worker.writers.presence
+        )
+
+        if isinstance(database.engine, PostgresEngine):
+            self._presence_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="presence_stream",
+                instance_name=self._instance_name,
+                tables=[("presence_stream", "instance_name", "stream_id")],
+                sequence_name="presence_stream_sequence",
+                writers=hs.config.worker.writers.to_device,
+            )
+        else:
+            self._presence_id_gen = StreamIdGenerator(
+                db_conn, "presence_stream", "stream_id"
+            )
+
+        self._presence_on_startup = self._get_active_presence(db_conn)
+
+        presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
+            db_conn,
+            "presence_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self._presence_id_gen.get_current_token(),
+        )
+        self.presence_stream_cache = StreamChangeCache(
+            "PresenceStreamChangeCache",
+            min_presence_val,
+            prefilled_cache=presence_cache_prefill,
+        )
+
     async def update_presence(self, presence_states):
+        assert self._can_persist_presence
+
         stream_ordering_manager = self._presence_id_gen.get_next_mult(
             len(presence_states)
         )
@@ -57,6 +110,7 @@ class PresenceStore(SQLBaseStore):
                     "last_user_sync_ts": state.last_user_sync_ts,
                     "status_msg": state.status_msg,
                     "currently_active": state.currently_active,
+                    "instance_name": self._instance_name,
                 }
                 for stream_id, state in zip(stream_orderings, presence_states)
             ],
@@ -216,3 +270,37 @@ class PresenceStore(SQLBaseStore):
 
     def get_current_presence_token(self):
         return self._presence_id_gen.get_current_token()
+
+    def _get_active_presence(self, db_conn: Connection):
+        """Fetch non-offline presence from the database so that we can register
+        the appropriate time outs.
+        """
+
+        sql = (
+            "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
+            " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
+            " WHERE state != ?"
+        )
+
+        txn = db_conn.cursor()
+        txn.execute(sql, (PresenceState.OFFLINE,))
+        rows = self.db_pool.cursor_to_dict(txn)
+        txn.close()
+
+        for row in rows:
+            row["currently_active"] = bool(row["currently_active"])
+
+        return [UserPresenceState(**row) for row in rows]
+
+    def take_presence_startup_info(self):
+        active_on_startup = self._presence_on_startup
+        self._presence_on_startup = None
+        return active_on_startup
+
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == PresenceStream.NAME:
+            self._presence_id_gen.advance(instance_name, token)
+            for row in rows:
+                self.presence_stream_cache.entity_has_changed(row.user_id, token)
+                self._get_presence_for_user.invalidate((row.user_id,))
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance.sql b/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance.sql
new file mode 100644
index 0000000000..b6ba0bda1a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a column to specify which instance wrote the row. Historic rows have
+-- `NULL`, which indicates that the master instance wrote them.
+ALTER TABLE presence_stream ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance_seq.sql.postgres b/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance_seq.sql.postgres
new file mode 100644
index 0000000000..02b182adf9
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance_seq.sql.postgres
@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS presence_stream_sequence;
+
+SELECT setval('presence_stream_sequence', (
+    SELECT COALESCE(MAX(stream_id), 1) FROM presence_stream
+));