commit     6201ea56eeae22258d58db2f7381f96759f5e85e (patch)
author:    Will Hunt <will@half-shot.uk>  2020-09-28 11:11:57 +0100
committer: GitHub <noreply@github.com>   2020-09-28 11:11:57 +0100
tree:      37cb630a46fe96e416110de253ad50d256c0482e /synapse/storage
parent:    Add support for device messages, start support for device lists (diff)
parent:    typo (diff)
download:  synapse-6201ea56eeae22258d58db2f7381f96759f5e85e.tar.xz

Merge branch 'develop' into hs/super-wip-edus-down-sync
Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/databases/main/__init__.py                                          |   8
-rw-r--r--  synapse/storage/databases/main/account_data.py                                      |   4
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py                                       |   4
-rw-r--r--  synapse/storage/databases/main/devices.py                                           |   6
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py                                   |   2
-rw-r--r--  synapse/storage/databases/main/events.py                                            |  16
-rw-r--r--  synapse/storage/databases/main/events_worker.py                                     |   4
-rw-r--r--  synapse/storage/databases/main/group_server.py                                      |   2
-rw-r--r--  synapse/storage/databases/main/presence.py                                          |   4
-rw-r--r--  synapse/storage/databases/main/push_rule.py                                         |   8
-rw-r--r--  synapse/storage/databases/main/pusher.py                                            |   4
-rw-r--r--  synapse/storage/databases/main/receipts.py                                          |   2
-rw-r--r--  synapse/storage/databases/main/registration.py                                      |  18
-rw-r--r--  synapse/storage/databases/main/room.py                                              | 101
-rw-r--r--  synapse/storage/databases/main/roommember.py                                        |  14
-rw-r--r--  synapse/storage/databases/main/schema/delta/56/event_labels.sql                     |   2
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres |   4
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/18stream_positions.sql               |  22
-rw-r--r--  synapse/storage/databases/main/stats.py                                             |   2
-rw-r--r--  synapse/storage/databases/main/tags.py                                              |   4
-rw-r--r--  synapse/storage/databases/main/user_erasure_store.py                                |   2
-rw-r--r--  synapse/storage/persist_events.py                                                   |  14
-rw-r--r--  synapse/storage/roommember.py                                                       |   2
-rw-r--r--  synapse/storage/util/id_generators.py                                               | 282
24 files changed, 410 insertions, 121 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index ccb3384db9..0cb12f4c61 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -160,14 +160,20 @@ class DataStore(
         )
 
         if isinstance(self.database_engine, PostgresEngine):
+            # We set the `writers` to an empty list here as we don't care about
+            # missing updates over restarts, as we'll not have anything in our
+            # caches to invalidate. (This reduces the number of writes to the
+            # DB that happen.)
             self._cache_id_gen = MultiWriterIdGenerator(
                 db_conn,
                 database,
-                instance_name="master",
+                stream_name="caches",
+                instance_name=hs.get_instance_name(),
                 table="cache_invalidation_stream_by_instance",
                 instance_column="instance_name",
                 id_column="stream_id",
                 sequence_name="cache_invalidation_stream_seq",
+                writers=[],
             )
         else:
             self._cache_id_gen = None
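
With `writers=[]` the generator skips both the startup read of `stream_positions` and the per-write upsert (both are guarded by `if self._writers:` later in this diff); only the in-memory position of this instance matters. A toy illustration of the guard, for illustration only:

    def load_positions(writers, read_stream_positions):
        # With no known writers (as for the caches stream), there is nothing
        # worth reading back: restarts re-seed from the table's max ID instead.
        if not writers:
            return {}
        return {
            instance: stream_id
            for instance, stream_id in read_stream_positions()
            if instance in writers
        }
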
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index c5a36990e4..ef81d73573 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -339,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as room_account_data has a unique constraint
             # on (user_id, room_id, account_data_type) so simple_upsert will
             # retry if there is a conflict.
@@ -387,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as account_data has a unique constraint on
             # (user_id, account_data_type) so simple_upsert will retry if
             # there is a conflict.
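
A recurring change in this merge: `get_next()` (and `get_next_mult()`) become plain synchronous methods returning an async context manager, so every call site moves from `with await gen.get_next()` to `async with gen.get_next()`. A minimal sketch of the shape this relies on (hypothetical names, not the actual Synapse classes):

    class _IdContext:
        """Async context manager that allocates an ID on entry."""

        def __init__(self, allocate):
            self._allocate = allocate  # coroutine function returning an int

        async def __aenter__(self) -> int:
            # The DB allocation is deferred until the block is entered.
            self._id = await self._allocate()
            return self._id

        async def __aexit__(self, exc_type, exc, tb):
            # ... mark self._id as finished here ...
            return False  # never swallow exceptions

    class IdGen:
        def get_next(self) -> _IdContext:
            # Not a coroutine: callers write `async with gen.get_next(): ...`
            return _IdContext(self._allocate_from_db)

        async def _allocate_from_db(self) -> int:
            return 1  # stand-in for a real sequence allocation
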
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index e4fd979a33..2d151b9134 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -394,7 +394,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 rows.append((destination, stream_id, now_ms, edu_json))
             txn.executemany(sql, rows)
 
-        with await self._device_inbox_id_gen.get_next() as stream_id:
+        async with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -443,7 +443,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 txn, stream_id, local_messages_by_user_then_device
             )
 
-        with await self._device_inbox_id_gen.get_next() as stream_id:
+        async with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index c04374e43d..fdf394c612 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
             The new stream ID.
         """
 
-        with await self._device_list_id_gen.get_next() as stream_id:
+        async with self._device_list_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
@@ -1093,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         if not device_ids:
             return
 
-        with await self._device_list_id_gen.get_next_mult(
+        async with self._device_list_id_gen.get_next_mult(
             len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
@@ -1108,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return stream_ids[-1]
 
         context = get_active_span_text_map()
-        with await self._device_list_id_gen.get_next_mult(
+        async with self._device_list_id_gen.get_next_mult(
             len(hosts) * len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c8df0bcb3f..22e1ed15d0 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             key (dict): the key data
         """
 
-        with await self._cross_signing_id_gen.get_next() as stream_id:
+        async with self._cross_signing_id_gen.get_next() as stream_id:
             return await self.db_pool.runInteraction(
                 "add_e2e_cross_signing_key",
                 self._set_e2e_cross_signing_key_txn,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 9a80f419e3..18def01f50 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,7 @@
 import itertools
 import logging
 from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
 
 import attr
 from prometheus_client import Counter
@@ -156,15 +156,15 @@ class PersistEventsStore:
         # Note: Multiple instances of this function cannot be in flight at
         # the same time for the same room.
         if backfilled:
-            stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
+            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
         else:
-            stream_ordering_manager = await self._stream_id_gen.get_next_mult(
+            stream_ordering_manager = self._stream_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
 
-        with stream_ordering_manager as stream_orderings:
+        async with stream_ordering_manager as stream_orderings:
             for (event, context), stream in zip(events_and_contexts, stream_orderings):
                 event.internal_metadata.stream_ordering = stream
 
@@ -1108,6 +1108,10 @@ class PersistEventsStore:
     def _store_room_members_txn(self, txn, events, backfilled):
         """Store a room member in the database.
         """
+
+        def str_or_none(val: Any) -> Optional[str]:
+            return val if isinstance(val, str) else None
+
         self.db_pool.simple_insert_many_txn(
             txn,
             table="room_memberships",
@@ -1118,8 +1122,8 @@ class PersistEventsStore:
                     "sender": event.user_id,
                     "room_id": event.room_id,
                     "membership": event.membership,
-                    "display_name": event.content.get("displayname", None),
-                    "avatar_url": event.content.get("avatar_url", None),
+                    "display_name": str_or_none(event.content.get("displayname")),
+                    "avatar_url": str_or_none(event.content.get("avatar_url")),
                 }
                 for event in events
             ],
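
Event content is client-supplied JSON, so `displayname` and `avatar_url` are not guaranteed to be strings; the `str_or_none` helper above coerces anything else to NULL before the row is inserted. For example:

    from typing import Any, Optional

    def str_or_none(val: Any) -> Optional[str]:
        return val if isinstance(val, str) else None

    assert str_or_none("Alice") == "Alice"
    assert str_or_none({"not": "a string"}) is None
    assert str_or_none(42) is None
    assert str_or_none(None) is None
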
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index de9e8d1dc6..f95679ebc4 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -83,21 +83,25 @@ class EventsWorkerStore(SQLBaseStore):
             self._stream_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                stream_name="events",
                 instance_name=hs.get_instance_name(),
                 table="events",
                 instance_column="instance_name",
                 id_column="stream_ordering",
                 sequence_name="events_stream_seq",
+                writers=hs.config.worker.writers.events,
             )
             self._backfill_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                stream_name="backfill",
                 instance_name=hs.get_instance_name(),
                 table="events",
                 instance_column="instance_name",
                 id_column="stream_ordering",
                 sequence_name="events_backfill_stream_seq",
                 positive=False,
+                writers=hs.config.worker.writers.events,
             )
         else:
             # We shouldn't be running in worker mode with SQLite, but it's useful
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index ccfbb2135e..7218191965 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
 
             return next_id
 
-        with await self._group_updates_id_gen.get_next() as next_id:
+        async with self._group_updates_id_gen.get_next() as next_id:
             res = await self.db_pool.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index c9f655dfb7..dbbb99cb95 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter
 
 class PresenceStore(SQLBaseStore):
     async def update_presence(self, presence_states):
-        stream_ordering_manager = await self._presence_id_gen.get_next_mult(
+        stream_ordering_manager = self._presence_id_gen.get_next_mult(
             len(presence_states)
         )
 
-        with stream_ordering_manager as stream_orderings:
+        async with stream_ordering_manager as stream_orderings:
             await self.db_pool.runInteraction(
                 "update_presence",
                 self._update_presence_txn,
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index e20a16f907..711d5aa23d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
     ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             if before or after:
@@ -585,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
             )
 
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(
@@ -616,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore):
         Raises:
             NotFoundError if the rule does not exist.
         """
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
             await self.db_pool.runInteraction(
                 "_set_push_rule_enabled_txn",
@@ -754,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 data={"actions": actions_json},
             )
 
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index c388468273..df8609b97b 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
         last_stream_ordering,
         profile_tag="",
     ) -> None:
-        with await self._pushers_id_gen.get_next() as stream_id:
+        async with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
             # (app_id, pushkey, user_name) so simple_upsert will retry
             await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
                 },
             )
 
-        with await self._pushers_id_gen.get_next() as stream_id:
+        async with self._pushers_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "delete_pusher", delete_pusher_txn, stream_id
             )
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 5867d52b62..c10a16ffa3 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -577,7 +577,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "insert_receipt_conv", graph_to_linear
             )
 
-        with await self._receipts_id_gen.get_next() as stream_id:
+        async with self._receipts_id_gen.get_next() as stream_id:
             event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 675e81fe34..48ce7ecd16 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -116,6 +116,20 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_expiration_ts_for_user",
         )
 
+    async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
+        """
+        Returns whether a user account is expired.
+
+        Args:
+            user_id: The user's ID
+            current_ts: The current timestamp
+
+        Returns:
+            Whether the user account has expired
+        """
+        expiration_ts = await self.get_expiration_ts_for_user(user_id)
+        return expiration_ts is not None and current_ts >= expiration_ts
+
     async def set_account_validity_for_user(
         self,
         user_id: str,
@@ -379,7 +393,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
-    ) -> str:
+    ) -> Optional[str]:
         """Look up a user by their external auth id
 
         Args:
@@ -387,7 +401,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             external_id: id on that system
 
         Returns:
-            str|None: the mxid of the user, or None if they are not known
+            the mxid of the user, or None if they are not known
         """
         return await self.db_pool.simple_select_one_onecol(
             table="user_external_ids",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index bd6f9553c6..3c7630857f 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1137,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                         },
                     )
 
-            with await self._public_room_id_gen.get_next() as next_id:
+            async with self._public_room_id_gen.get_next() as next_id:
                 await self.db_pool.runInteraction(
                     "store_room_txn", store_room_txn, next_id
                 )
@@ -1204,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with await self._public_room_id_gen.get_next() as next_id:
+        async with self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public", set_room_is_public_txn, next_id
             )
@@ -1284,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with await self._public_room_id_gen.get_next() as next_id:
+        async with self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public_appservice",
                 set_room_is_public_appservice_txn,
@@ -1328,6 +1328,101 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             desc="add_event_report",
         )
 
+    async def get_event_reports_paginate(
+        self,
+        start: int,
+        limit: int,
+        direction: str = "b",
+        user_id: Optional[str] = None,
+        room_id: Optional[str] = None,
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        """Retrieve a paginated list of event reports
+
+        Args:
+            start: event offset to begin the query from
+            limit: number of rows to retrieve
+            direction: Whether to fetch the most recent first (`"b"`) or the
+                oldest first (`"f"`)
+            user_id: search for user_id. Ignored if user_id is None
+            room_id: search for room_id. Ignored if room_id is None
+        Returns:
+            event_reports: a list of event reports, as dicts
+            count: total number of event reports matching the filter criteria
+        """
+
+        def _get_event_reports_paginate_txn(txn):
+            filters = []
+            args = []
+
+            if user_id:
+                filters.append("er.user_id LIKE ?")
+                args.extend(["%" + user_id + "%"])
+            if room_id:
+                filters.append("er.room_id LIKE ?")
+                args.extend(["%" + room_id + "%"])
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
+            where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+            sql = """
+                SELECT COUNT(*) as total_event_reports
+                FROM event_reports AS er
+                {}
+                """.format(
+                where_clause
+            )
+            txn.execute(sql, args)
+            count = txn.fetchone()[0]
+
+            sql = """
+                SELECT
+                    er.id,
+                    er.received_ts,
+                    er.room_id,
+                    er.event_id,
+                    er.user_id,
+                    er.reason,
+                    er.content,
+                    events.sender,
+                    room_aliases.room_alias,
+                    event_json.json AS event_json
+                FROM event_reports AS er
+                LEFT JOIN room_aliases
+                    ON room_aliases.room_id = er.room_id
+                JOIN events
+                    ON events.event_id = er.event_id
+                JOIN event_json
+                    ON event_json.event_id = er.event_id
+                {where_clause}
+                ORDER BY er.received_ts {order}
+                LIMIT ?
+                OFFSET ?
+            """.format(
+                where_clause=where_clause, order=order,
+            )
+
+            args += [limit, start]
+            txn.execute(sql, args)
+            event_reports = self.db_pool.cursor_to_dict(txn)
+
+            if count > 0:
+                for row in event_reports:
+                    try:
+                        row["content"] = db_to_json(row["content"])
+                        row["event_json"] = db_to_json(row["event_json"])
+                    except Exception:
+                        continue
+
+            return event_reports, count
+
+        return await self.db_pool.runInteraction(
+            "get_event_reports_paginate", _get_event_reports_paginate_txn
+        )
+
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
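
A sketch of driving the new pagination method from an admin endpoint (the wrapper below is an assumption; the argument names come from the docstring above):

    async def fetch_report_page(store, page: int, page_size: int = 50):
        reports, total = await store.get_event_reports_paginate(
            start=page * page_size,
            limit=page_size,
            direction="b",   # most recent first
            user_id=None,    # no filter on reporter
            room_id=None,    # no filter on room
        )
        has_more = (page + 1) * page_size < total
        return reports, total, has_more
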
 
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4fa8767b01..86ffe2479e 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
 
@@ -37,7 +36,7 @@ from synapse.storage.roommember import (
     ProfileInfo,
     RoomsForUser,
 )
-from synapse.types import Collection, get_domain_from_id
+from synapse.types import Collection, PersistedEventPosition, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -387,7 +386,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # for rooms the server is participating in.
         if self._current_state_events_membership_up_to_date:
             sql = """
-                SELECT room_id, e.stream_ordering
+                SELECT room_id, e.instance_name, e.stream_ordering
                 FROM current_state_events AS c
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
@@ -397,7 +396,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             """
         else:
             sql = """
-                SELECT room_id, e.stream_ordering
+                SELECT room_id, e.instance_name, e.stream_ordering
                 FROM current_state_events AS c
                 INNER JOIN room_memberships AS m USING (room_id, event_id)
                 INNER JOIN events AS e USING (room_id, event_id)
@@ -408,7 +407,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             """
 
         txn.execute(sql, (user_id, Membership.JOIN))
-        return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
+        return frozenset(
+            GetRoomsForUserWithStreamOrdering(
+                room_id, PersistedEventPosition(instance, stream_id)
+            )
+            for room_id, instance, stream_id in txn
+        )
 
     async def get_users_server_still_shares_room_with(
         self, user_ids: Collection[str]
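
`GetRoomsForUserWithStreamOrdering` now carries a `PersistedEventPosition` (the `(instance_name, stream)` pair built above) instead of a bare stream ordering. A rough sketch of what callers now receive (this attrs definition is an assumption matching the constructor call in the diff):

    import attr

    @attr.s(slots=True, frozen=True)
    class PersistedEventPosition:
        instance_name = attr.ib(type=str)
        stream = attr.ib(type=int)

    # Rows unpack to (room_id, instance, stream_id) and are wrapped:
    rows = [("!a:example.org", "worker1", 42), ("!b:example.org", "master", 7)]
    positions = {
        room_id: PersistedEventPosition(instance, stream_id)
        for room_id, instance, stream_id in rows
    }
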
diff --git a/synapse/storage/databases/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
index 5e29c1da19..ccf287971c 100644
--- a/synapse/storage/databases/main/schema/delta/56/event_labels.sql
+++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
@@ -13,7 +13,7 @@
  * limitations under the License.
  */
 
--- room_id and topoligical_ordering are denormalised from the events table in order to
+-- room_id and topological_ordering are denormalised from the events table in order to
 -- make the index work.
 CREATE TABLE IF NOT EXISTS event_labels (
     event_id TEXT,
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
index 97c1e6a0c5..c31f9af82a 100644
--- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
@@ -21,6 +21,8 @@ SELECT setval('events_stream_seq', (
 
 CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
 
+-- If the server has never backfilled a room then `-MIN(...)` will give a
+-- negative result, hence the `GREATEST(...)`.
 SELECT setval('events_backfill_stream_seq', (
-    SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
+    SELECT GREATEST(COALESCE(-MIN(stream_ordering), 1), 1) FROM events
 ));
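
The backfill stream counts downwards, so its sequence is seeded from `-MIN(stream_ordering)`; on a server that has never backfilled, all orderings are positive and the old expression seeded the sequence with a negative number. A Python rendering of the guarded expression, for illustration:

    from typing import Iterable, Optional

    def backfill_seq_seed(stream_orderings: Iterable[int]) -> int:
        """Mirrors GREATEST(COALESCE(-MIN(stream_ordering), 1), 1)."""
        orderings = list(stream_orderings)
        neg_min: Optional[int] = -min(orderings) if orderings else None
        return max(neg_min if neg_min is not None else 1, 1)

    assert backfill_seq_seed([]) == 1          # empty events table
    assert backfill_seq_seed([5, 10]) == 1     # never backfilled: clamp -5 to 1
    assert backfill_seq_seed([-20, 3]) == 20   # backfilled down to -20
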
diff --git a/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
new file mode 100644
index 0000000000..985fd949a2
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE stream_positions (
+    stream_name TEXT NOT NULL,
+    instance_name TEXT NOT NULL,
+    stream_id BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name);
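
The unique index on `(stream_name, instance_name)` gives each writer exactly one row per stream and is what the `ON CONFLICT` upsert in `id_generators.py` (below) targets. A small Python model of the table's semantics, for illustration only:

    # One row per (stream, writer); the upsert keeps stream_id monotonic.
    stream_positions = {}  # type: dict

    def upsert(stream_name: str, instance_name: str, stream_id: int) -> None:
        # Models INSERT ... ON CONFLICT DO UPDATE SET
        #   stream_id = GREATEST(stream_positions.stream_id, EXCLUDED.stream_id)
        key = (stream_name, instance_name)
        stream_positions[key] = max(stream_positions.get(key, stream_id), stream_id)

    upsert("events", "worker1", 100)
    upsert("events", "worker1", 90)   # stale write: position stays at 100
    assert stream_positions[("events", "worker1")] == 100
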
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index d7816a8606..5beb302be3 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -210,6 +210,7 @@ class StatsStore(StateDeltasStore):
         * topic
         * avatar
         * canonical_alias
+        * guest_access
 
         A is_federatable key can also be included with a boolean value.
 
@@ -234,6 +235,7 @@ class StatsStore(StateDeltasStore):
             "topic",
             "avatar",
             "canonical_alias",
+            "guest_access",
         ):
             field = fields.get(col, sentinel)
             if field is not sentinel and (not isinstance(field, str) or "\0" in field):
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 96ffe26cc9..9f120d3cb6 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
             )
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
             txn.execute(sql, (user_id, room_id, tag))
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index 2f7c95fc74..f9575b1f1f 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -100,7 +100,7 @@ class UserErasureStore(UserErasureWorkerStore):
                 return
 
             # They are there, delete them.
-            self.simple_delete_one_txn(
+            self.db_pool.simple_delete_one_txn(
                 txn, "erased_users", keyvalues={"user_id": user_id}
             )
 
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index d89f6ed128..603cd7d825 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main.events import DeltaState
-from synapse.types import Collection, StateMap
+from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.metrics import Measure
 
@@ -190,6 +190,7 @@ class EventsPersistenceStorage:
         self.persist_events_store = stores.persist_events
 
         self._clock = hs.get_clock()
+        self._instance_name = hs.get_instance_name()
         self.is_mine_id = hs.is_mine_id
         self._event_persist_queue = _EventPeristenceQueue()
         self._state_resolution_handler = hs.get_state_resolution_handler()
@@ -198,7 +199,7 @@ class EventsPersistenceStorage:
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
-    ) -> int:
+    ) -> RoomStreamToken:
         """
         Write events to the database
         Args:
@@ -228,11 +229,11 @@ class EventsPersistenceStorage:
             defer.gatherResults(deferreds, consumeErrors=True)
         )
 
-        return self.main_store.get_current_events_token()
+        return RoomStreamToken(None, self.main_store.get_current_events_token())
 
     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
-    ) -> Tuple[int, int]:
+    ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
         """
         Returns:
             The stream ordering of `event`, and the stream ordering of the
@@ -247,7 +248,10 @@ class EventsPersistenceStorage:
         await make_deferred_yieldable(deferred)
 
         max_persisted_id = self.main_store.get_current_events_token()
-        return (event.internal_metadata.stream_ordering, max_persisted_id)
+        event_stream_id = event.internal_metadata.stream_ordering
+
+        pos = PersistedEventPosition(self._instance_name, event_stream_id)
+        return pos, RoomStreamToken(None, max_persisted_id)
 
     def _maybe_start_persisting(self, room_id: str):
         async def persisting_queue(item):
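
With this change `persist_event` returns a `(PersistedEventPosition, RoomStreamToken)` pair rather than two bare integers. A rough call-site sketch (the surrounding handler is an assumption):

    async def persist_and_get_position(storage, event, context):
        pos, max_token = await storage.persistence.persist_event(event, context)
        # `pos` pins the event to an (instance_name, stream_ordering) pair;
        # `max_token` is a RoomStreamToken covering everything persisted so far.
        return pos, max_token
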
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 8c4a83a840..f152f63321 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
 )
 
 GetRoomsForUserWithStreamOrdering = namedtuple(
-    "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering")
+    "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
 )
 
 
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 1de2b91587..4269eaf918 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -12,16 +12,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import contextlib
 import heapq
 import logging
 import threading
 from collections import deque
-from typing import Dict, List, Set
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Set, Union
 
+import attr
 from typing_extensions import Deque
 
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.util.sequence import PostgresSequenceGenerator
 
@@ -86,7 +87,7 @@ class StreamIdGenerator:
             upwards, -1 to grow downwards.
 
     Usage:
-        with await stream_id_gen.get_next() as stream_id:
+        async with stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
 
@@ -101,10 +102,10 @@ class StreamIdGenerator:
             )
         self._unfinished_ids = deque()  # type: Deque[int]
 
-    async def get_next(self):
+    def get_next(self):
         """
         Usage:
-            with await stream_id_gen.get_next() as stream_id:
+            async with stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
         with self._lock:
@@ -113,7 +114,7 @@ class StreamIdGenerator:
 
             self._unfinished_ids.append(next_id)
 
-        @contextlib.contextmanager
+        @contextmanager
         def manager():
             try:
                 yield next_id
@@ -121,12 +122,12 @@ class StreamIdGenerator:
                 with self._lock:
                     self._unfinished_ids.remove(next_id)
 
-        return manager()
+        return _AsyncCtxManagerWrapper(manager())
 
-    async def get_next_mult(self, n):
+    def get_next_mult(self, n):
         """
         Usage:
-            with await stream_id_gen.get_next(n) as stream_ids:
+            async with stream_id_gen.get_next(n) as stream_ids:
                 # ... persist events ...
         """
         with self._lock:
@@ -140,7 +141,7 @@ class StreamIdGenerator:
             for next_id in next_ids:
                 self._unfinished_ids.append(next_id)
 
-        @contextlib.contextmanager
+        @contextmanager
         def manager():
             try:
                 yield next_ids
@@ -149,7 +150,7 @@ class StreamIdGenerator:
                     for next_id in next_ids:
                         self._unfinished_ids.remove(next_id)
 
-        return manager()
+        return _AsyncCtxManagerWrapper(manager())
 
     def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
@@ -184,12 +185,16 @@ class MultiWriterIdGenerator:
     Args:
         db_conn
         db
+        stream_name: A name for the stream.
         instance_name: The name of this instance.
         table: Database table associated with stream.
         instance_column: Column that stores the row's writer's instance name
         id_column: Column that stores the stream ID.
         sequence_name: The name of the postgres sequence used to generate new
             IDs.
+        writers: A list of known writers to use to populate current positions
+            on startup. Can be empty if nothing uses `get_current_token` or
+            `get_positions` (e.g. caches stream).
         positive: Whether the IDs are positive (true) or negative (false).
             When using negative IDs we go backwards from -1 to -2, -3, etc.
     """
@@ -198,16 +203,20 @@ class MultiWriterIdGenerator:
         self,
         db_conn,
         db: DatabasePool,
+        stream_name: str,
         instance_name: str,
         table: str,
         instance_column: str,
         id_column: str,
         sequence_name: str,
+        writers: List[str],
         positive: bool = True,
     ):
         self._db = db
+        self._stream_name = stream_name
         self._instance_name = instance_name
         self._positive = positive
+        self._writers = writers
         self._return_factor = 1 if positive else -1
 
         # We lock as some functions may be called from DB threads.
@@ -216,9 +225,7 @@ class MultiWriterIdGenerator:
         # Note: If we are a negative stream then we still store all the IDs as
         # positive to make life easier for us, and simply negate the IDs when we
         # return them.
-        self._current_positions = self._load_current_ids(
-            db_conn, table, instance_column, id_column
-        )
+        self._current_positions = {}  # type: Dict[str, int]
 
         # Set of local IDs that we're still processing. The current position
         # should be less than the minimum of this set (if not empty).
@@ -251,30 +258,84 @@ class MultiWriterIdGenerator:
 
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
+        # This goes and fills out the above state from the database.
+        self._load_current_ids(db_conn, table, instance_column, id_column)
+
     def _load_current_ids(
         self, db_conn, table: str, instance_column: str, id_column: str
-    ) -> Dict[str, int]:
-        # If positive stream aggregate via MAX. For negative stream use MIN
-        # *and* negate the result to get a positive number.
-        sql = """
-            SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
-            GROUP BY %(instance)s
-        """ % {
-            "instance": instance_column,
-            "id": id_column,
-            "table": table,
-            "agg": "MAX" if self._positive else "-MIN",
-        }
-
+    ):
         cur = db_conn.cursor()
-        cur.execute(sql)
 
-        # `cur` is an iterable over returned rows, which are 2-tuples.
-        current_positions = dict(cur)
+        # Load the current positions of all writers for the stream.
+        if self._writers:
+            sql = """
+                SELECT instance_name, stream_id FROM stream_positions
+                WHERE stream_name = ?
+            """
+            sql = self._db.engine.convert_param_style(sql)
 
-        cur.close()
+            cur.execute(sql, (self._stream_name,))
 
-        return current_positions
+            self._current_positions = {
+                instance: stream_id * self._return_factor
+                for instance, stream_id in cur
+                if instance in self._writers
+            }
+
+        # We set the `_persisted_upto_position` to be the minimum of all current
+        # positions. If empty we use the max stream ID from the DB table.
+        min_stream_id = min(self._current_positions.values(), default=None)
+
+        if min_stream_id is None:
+            # We add a GREATEST here to ensure that the result is always
+            # positive. (This can be a problem for e.g. backfill streams where
+            # the server has never backfilled).
+            sql = """
+                SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+                FROM %(table)s
+            """ % {
+                "id": id_column,
+                "table": table,
+                "agg": "MAX" if self._positive else "-MIN",
+            }
+            cur.execute(sql)
+            (stream_id,) = cur.fetchone()
+            self._persisted_upto_position = stream_id
+        else:
+            # If we have a min_stream_id then we pull out everything greater
+            # than it from the DB so that we can prefill
+            # `_known_persisted_positions` and get a more accurate
+            # `_persisted_upto_position`.
+            #
+            # We also check if any of the later rows are from this instance, in
+            # which case we use that for this instance's current position. This
+            # is to handle the case where we didn't finish persisting to the
+            # stream positions table before restart (or the stream position
+            # table otherwise got out of date).
+
+            sql = """
+                SELECT %(instance)s, %(id)s FROM %(table)s
+                WHERE ? %(cmp)s %(id)s
+            """ % {
+                "id": id_column,
+                "table": table,
+                "instance": instance_column,
+                "cmp": "<=" if self._positive else ">=",
+            }
+            sql = self._db.engine.convert_param_style(sql)
+            cur.execute(sql, (min_stream_id,))
+
+            self._persisted_upto_position = min_stream_id
+
+            with self._lock:
+                for (instance, stream_id,) in cur:
+                    stream_id = self._return_factor * stream_id
+                    self._add_persisted_position(stream_id)
+
+                    if instance == self._instance_name:
+                        self._current_positions[instance] = stream_id
+
+        cur.close()
 
     def _load_next_id_txn(self, txn) -> int:
         return self._sequence_gen.get_next_id_txn(txn)
@@ -282,59 +343,23 @@ class MultiWriterIdGenerator:
     def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
         return self._sequence_gen.get_next_mult_txn(txn, n)
 
-    async def get_next(self):
+    def get_next(self):
         """
         Usage:
-            with await stream_id_gen.get_next() as stream_id:
+            async with stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
-        next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
-
-        # Assert the fetched ID is actually greater than what we currently
-        # believe the ID to be. If not, then the sequence and table have got
-        # out of sync somehow.
-        with self._lock:
-            assert self._current_positions.get(self._instance_name, 0) < next_id
-
-            self._unfinished_ids.add(next_id)
 
-        @contextlib.contextmanager
-        def manager():
-            try:
-                # Multiply by the return factor so that the ID has correct sign.
-                yield self._return_factor * next_id
-            finally:
-                self._mark_id_as_finished(next_id)
+        return _MultiWriterCtxManager(self)
 
-        return manager()
-
-    async def get_next_mult(self, n: int):
+    def get_next_mult(self, n: int):
         """
         Usage:
-            with await stream_id_gen.get_next_mult(5) as stream_ids:
+            async with stream_id_gen.get_next_mult(5) as stream_ids:
                 # ... persist events ...
         """
-        next_ids = await self._db.runInteraction(
-            "_load_next_mult_id", self._load_next_mult_id_txn, n
-        )
-
-        # Assert the fetched ID is actually greater than any ID we've already
-        # seen. If not, then the sequence and table have got out of sync
-        # somehow.
-        with self._lock:
-            assert max(self._current_positions.values(), default=0) < min(next_ids)
-
-            self._unfinished_ids.update(next_ids)
-
-        @contextlib.contextmanager
-        def manager():
-            try:
-                yield [self._return_factor * i for i in next_ids]
-            finally:
-                for i in next_ids:
-                    self._mark_id_as_finished(i)
 
-        return manager()
+        return _MultiWriterCtxManager(self, n)
 
     def get_next_txn(self, txn: LoggingTransaction):
         """
@@ -352,6 +377,21 @@ class MultiWriterIdGenerator:
         txn.call_after(self._mark_id_as_finished, next_id)
         txn.call_on_exception(self._mark_id_as_finished, next_id)
 
+        # Update the `stream_positions` table with the newly updated stream
+        # ID (unless `self._writers` is empty, in which case we don't
+        # bother, as nothing will read it).
+        #
+        # We only do this on the success path so that the persisted current
+        # position points to a persisted row with the correct instance name.
+        if self._writers:
+            txn.call_after(
+                run_as_background_process,
+                "MultiWriterIdGenerator._update_table",
+                self._db.runInteraction,
+                "MultiWriterIdGenerator._update_table",
+                self._update_stream_positions_table_txn,
+            )
+
         return self._return_factor * next_id
 
     def _mark_id_as_finished(self, next_id: int):
@@ -482,3 +522,95 @@ class MultiWriterIdGenerator:
                 # There was a gap in seen positions, so there is nothing more to
                 # do.
                 break
+
+    def _update_stream_positions_table_txn(self, txn):
+        """Update the `stream_positions` table with newly persisted position.
+        """
+
+        if not self._writers:
+            return
+
+        # We upsert the value, ensuring on conflict that we always increase the
+        # value (or decrease if stream goes backwards).
+        sql = """
+            INSERT INTO stream_positions (stream_name, instance_name, stream_id)
+            VALUES (?, ?, ?)
+            ON CONFLICT (stream_name, instance_name)
+            DO UPDATE SET
+                stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
+        """ % {
+            "agg": "GREATEST" if self._positive else "LEAST",
+        }
+
+        pos = self.get_current_token_for_writer(self._instance_name)
+        txn.execute(sql, (self._stream_name, self._instance_name, pos))
+
+
+@attr.s(slots=True)
+class _AsyncCtxManagerWrapper:
+    """Helper class to convert a plain context manager to an async one.
+
+    This is mainly useful if you have a plain context manager but the interface
+    requires an async one.
+    """
+
+    inner = attr.ib()
+
+    async def __aenter__(self):
+        return self.inner.__enter__()
+
+    async def __aexit__(self, exc_type, exc, tb):
+        return self.inner.__exit__(exc_type, exc, tb)
+
+
+@attr.s(slots=True)
+class _MultiWriterCtxManager:
+    """Async context manager returned by MultiWriterIdGenerator
+    """
+
+    id_gen = attr.ib(type=MultiWriterIdGenerator)
+    multiple_ids = attr.ib(type=Optional[int], default=None)
+    stream_ids = attr.ib(type=List[int], factory=list)
+
+    async def __aenter__(self) -> Union[int, List[int]]:
+        self.stream_ids = await self.id_gen._db.runInteraction(
+            "_load_next_mult_id",
+            self.id_gen._load_next_mult_id_txn,
+            self.multiple_ids or 1,
+        )
+
+        # Assert the fetched ID is actually greater than any ID we've already
+        # seen. If not, then the sequence and table have got out of sync
+        # somehow.
+        with self.id_gen._lock:
+            assert max(self.id_gen._current_positions.values(), default=0) < min(
+                self.stream_ids
+            )
+
+            self.id_gen._unfinished_ids.update(self.stream_ids)
+
+        if self.multiple_ids is None:
+            return self.stream_ids[0] * self.id_gen._return_factor
+        else:
+            return [i * self.id_gen._return_factor for i in self.stream_ids]
+
+    async def __aexit__(self, exc_type, exc, tb):
+        for i in self.stream_ids:
+            self.id_gen._mark_id_as_finished(i)
+
+        if exc_type is not None:
+            return False
+
+        # Update the `stream_positions` table with the newly updated stream
+        # ID (unless `self.id_gen._writers` is empty, in which case we don't
+        # bother, as nothing will read it).
+        #
+        # We only do this on the success path so that the persisted current
+        # position points to a persisted row with the correct instance name.
+        if self.id_gen._writers:
+            await self.id_gen._db.runInteraction(
+                "MultiWriterIdGenerator._update_table",
+                self.id_gen._update_stream_positions_table_txn,
+            )
+
+        return False
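
Putting the id_generators rework together: both `StreamIdGenerator` and `MultiWriterIdGenerator` now expose the same synchronous `get_next()` / `get_next_mult()` returning async context managers. A minimal consumer sketch under these changes (the store, generator, and table names are hypothetical):

    async def persist_thing(store, payload: str) -> int:
        # Allocation happens in __aenter__; on exit the ID is marked finished
        # and (for MultiWriterIdGenerator with writers) stream_positions is
        # updated via the upsert above.
        async with store._thing_id_gen.get_next() as stream_id:
            await store.db_pool.simple_insert(
                table="things",
                values={"stream_id": stream_id, "payload": payload},
                desc="persist_thing",
            )
        return stream_id

    async def persist_many(store, payloads) -> list:
        async with store._thing_id_gen.get_next_mult(len(payloads)) as stream_ids:
            # ... insert each payload with its allocated stream ID ...
            return list(stream_ids)
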