summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erikj@element.io>2024-05-30 12:07:32 +0100
committerGitHub <noreply@github.com>2024-05-30 11:07:32 +0000
commitd16910ca021320f0fa09c6cf82a802ee97e22a0c (patch)
treeeea35c87fecdccec394c8afee0ad23182d4c783d
parentClean out invalid destinations from outbox (#17242) (diff)
downloadsynapse-d16910ca021320f0fa09c6cf82a802ee97e22a0c.tar.xz
Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator` (#17229)
Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`, which is safer.
-rw-r--r--changelog.d/17229.misc1
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py71
-rw-r--r--synapse/storage/databases/main/devices.py54
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py19
-rw-r--r--synapse/storage/databases/main/push_rule.py24
-rw-r--r--synapse/storage/databases/main/pusher.py42
-rw-r--r--synapse/storage/schema/main/delta/85/02_add_instance_names.sql27
-rw-r--r--synapse/storage/schema/main/delta/85/03_new_sequences.sql.postgres54
-rw-r--r--synapse/storage/util/id_generators.py158
-rw-r--r--tests/storage/test_id_generators.py140
10 files changed, 227 insertions, 363 deletions
diff --git a/changelog.d/17229.misc b/changelog.d/17229.misc
new file mode 100644
index 0000000000..d411550786
--- /dev/null
+++ b/changelog.d/17229.misc
@@ -0,0 +1 @@
+Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`.
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 1e56f46911..3bb4a34938 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -777,22 +777,74 @@ class Porter:
             await self._setup_events_stream_seqs()
             await self._setup_sequence(
                 "un_partial_stated_event_stream_sequence",
-                ("un_partial_stated_event_stream",),
+                [("un_partial_stated_event_stream", "stream_id")],
             )
             await self._setup_sequence(
-                "device_inbox_sequence", ("device_inbox", "device_federation_outbox")
+                "device_inbox_sequence",
+                [
+                    ("device_inbox", "stream_id"),
+                    ("device_federation_outbox", "stream_id"),
+                ],
             )
             await self._setup_sequence(
                 "account_data_sequence",
-                ("room_account_data", "room_tags_revisions", "account_data"),
+                [
+                    ("room_account_data", "stream_id"),
+                    ("room_tags_revisions", "stream_id"),
+                    ("account_data", "stream_id"),
+                ],
+            )
+            await self._setup_sequence(
+                "receipts_sequence",
+                [
+                    ("receipts_linearized", "stream_id"),
+                ],
+            )
+            await self._setup_sequence(
+                "presence_stream_sequence",
+                [
+                    ("presence_stream", "stream_id"),
+                ],
             )
-            await self._setup_sequence("receipts_sequence", ("receipts_linearized",))
-            await self._setup_sequence("presence_stream_sequence", ("presence_stream",))
             await self._setup_auth_chain_sequence()
             await self._setup_sequence(
                 "application_services_txn_id_seq",
-                ("application_services_txns",),
-                "txn_id",
+                [
+                    (
+                        "application_services_txns",
+                        "txn_id",
+                    )
+                ],
+            )
+            await self._setup_sequence(
+                "device_lists_sequence",
+                [
+                    ("device_lists_stream", "stream_id"),
+                    ("user_signature_stream", "stream_id"),
+                    ("device_lists_outbound_pokes", "stream_id"),
+                    ("device_lists_changes_in_room", "stream_id"),
+                    ("device_lists_remote_pending", "stream_id"),
+                    ("device_lists_changes_converted_stream_position", "stream_id"),
+                ],
+            )
+            await self._setup_sequence(
+                "e2e_cross_signing_keys_sequence",
+                [
+                    ("e2e_cross_signing_keys", "stream_id"),
+                ],
+            )
+            await self._setup_sequence(
+                "push_rules_stream_sequence",
+                [
+                    ("push_rules_stream", "stream_id"),
+                ],
+            )
+            await self._setup_sequence(
+                "pushers_sequence",
+                [
+                    ("pushers", "id"),
+                    ("deleted_pushers", "stream_id"),
+                ],
             )
 
             # Step 3. Get tables.
@@ -1101,12 +1153,11 @@ class Porter:
     async def _setup_sequence(
         self,
         sequence_name: str,
-        stream_id_tables: Iterable[str],
-        column_name: str = "stream_id",
+        stream_id_tables: Iterable[Tuple[str, str]],
     ) -> None:
         """Set a sequence to the correct value."""
         current_stream_ids = []
-        for stream_id_table in stream_id_tables:
+        for stream_id_table, column_name in stream_id_tables:
             max_stream_id = cast(
                 int,
                 await self.sqlite_store.db_pool.simple_select_one_onecol(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 48384e238c..1c771e48f7 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -57,10 +57,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import (
     JsonDict,
     JsonMapping,
@@ -99,19 +96,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._device_list_id_gen = StreamIdGenerator(
-            db_conn,
-            hs.get_replication_notifier(),
-            "device_lists_stream",
-            "stream_id",
-            extra_tables=[
-                ("user_signature_stream", "stream_id"),
-                ("device_lists_outbound_pokes", "stream_id"),
-                ("device_lists_changes_in_room", "stream_id"),
-                ("device_lists_remote_pending", "stream_id"),
-                ("device_lists_changes_converted_stream_position", "stream_id"),
+        self._device_list_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="device_lists_stream",
+            instance_name=self._instance_name,
+            tables=[
+                ("device_lists_stream", "instance_name", "stream_id"),
+                ("user_signature_stream", "instance_name", "stream_id"),
+                ("device_lists_outbound_pokes", "instance_name", "stream_id"),
+                ("device_lists_changes_in_room", "instance_name", "stream_id"),
+                ("device_lists_remote_pending", "instance_name", "stream_id"),
             ],
-            is_writer=hs.config.worker.worker_app is None,
+            sequence_name="device_lists_sequence",
+            writers=["master"],
         )
 
         device_list_max = self._device_list_id_gen.get_current_token()
@@ -762,6 +761,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                 "stream_id": stream_id,
                 "from_user_id": from_user_id,
                 "user_ids": json_encoder.encode(user_ids),
+                "instance_name": self._instance_name,
             },
         )
 
@@ -1582,6 +1582,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
+        self._instance_name = hs.get_instance_name()
+
         self.db_pool.updates.register_background_index_update(
             "device_lists_stream_idx",
             index_name="device_lists_stream_user_id",
@@ -1694,6 +1696,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
                     "device_lists_outbound_pokes",
                     {
                         "stream_id": stream_id,
+                        "instance_name": self._instance_name,
                         "destination": destination,
                         "user_id": user_id,
                         "device_id": device_id,
@@ -1730,10 +1733,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
-    # Because we have write access, this will be a StreamIdGenerator
-    # (see DeviceWorkerStore.__init__)
-    _device_list_id_gen: AbstractStreamIdGenerator
-
     def __init__(
         self,
         database: DatabasePool,
@@ -2092,9 +2091,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         self.db_pool.simple_insert_many_txn(
             txn,
             table="device_lists_stream",
-            keys=("stream_id", "user_id", "device_id"),
+            keys=("instance_name", "stream_id", "user_id", "device_id"),
             values=[
-                (stream_id, user_id, device_id)
+                (self._instance_name, stream_id, user_id, device_id)
                 for stream_id, device_id in zip(stream_ids, device_ids)
             ],
         )
@@ -2124,6 +2123,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         values = [
             (
                 destination,
+                self._instance_name,
                 next(stream_id_iterator),
                 user_id,
                 device_id,
@@ -2139,6 +2139,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             table="device_lists_outbound_pokes",
             keys=(
                 "destination",
+                "instance_name",
                 "stream_id",
                 "user_id",
                 "device_id",
@@ -2157,7 +2158,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 device_id,
                 {
                     stream_id: destination
-                    for (destination, stream_id, _, _, _, _, _) in values
+                    for (destination, _, stream_id, _, _, _, _, _) in values
                 },
             )
 
@@ -2210,6 +2211,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 "device_id",
                 "room_id",
                 "stream_id",
+                "instance_name",
                 "converted_to_destinations",
                 "opentracing_context",
             ),
@@ -2219,6 +2221,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                     device_id,
                     room_id,
                     stream_id,
+                    self._instance_name,
                     # We only need to calculate outbound pokes for local users
                     not self.hs.is_mine_id(user_id),
                     encoded_context,
@@ -2338,7 +2341,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                     "user_id": user_id,
                     "device_id": device_id,
                 },
-                values={"stream_id": stream_id},
+                values={
+                    "stream_id": stream_id,
+                    "instance_name": self._instance_name,
+                },
                 desc="add_remote_device_list_to_pending",
             )
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b219ea70ee..38d8785faa 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -58,7 +58,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import JsonDict, JsonMapping
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -1448,11 +1448,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self._cross_signing_id_gen = StreamIdGenerator(
-            db_conn,
-            hs.get_replication_notifier(),
-            "e2e_cross_signing_keys",
-            "stream_id",
+        self._cross_signing_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="e2e_cross_signing_keys",
+            instance_name=self._instance_name,
+            tables=[
+                ("e2e_cross_signing_keys", "instance_name", "stream_id"),
+            ],
+            sequence_name="e2e_cross_signing_keys_sequence",
+            writers=["master"],
         )
 
     async def set_e2e_device_keys(
@@ -1627,6 +1633,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 "keytype": key_type,
                 "keydata": json_encoder.encode(key),
                 "stream_id": stream_id,
+                "instance_name": self._instance_name,
             },
         )
 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 660c834518..2a39dc9f90 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -53,7 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
 from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
 from synapse.types import JsonDict
 from synapse.util import json_encoder, unwrapFirstError
@@ -126,7 +126,7 @@ class PushRulesWorkerStore(
     `get_max_push_rules_stream_id` which can be called in the initializer.
     """
 
-    _push_rules_stream_id_gen: StreamIdGenerator
+    _push_rules_stream_id_gen: MultiWriterIdGenerator
 
     def __init__(
         self,
@@ -140,14 +140,17 @@ class PushRulesWorkerStore(
             hs.get_instance_name() in hs.config.worker.writers.push_rules
         )
 
-        # In the worker store this is an ID tracker which we overwrite in the non-worker
-        # class below that is used on the main process.
-        self._push_rules_stream_id_gen = StreamIdGenerator(
-            db_conn,
-            hs.get_replication_notifier(),
-            "push_rules_stream",
-            "stream_id",
-            is_writer=self._is_push_writer,
+        self._push_rules_stream_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="push_rules_stream",
+            instance_name=self._instance_name,
+            tables=[
+                ("push_rules_stream", "instance_name", "stream_id"),
+            ],
+            sequence_name="push_rules_stream_sequence",
+            writers=hs.config.worker.writers.push_rules,
         )
 
         push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
@@ -880,6 +883,7 @@ class PushRulesWorkerStore(
             raise Exception("Not a push writer")
 
         values = {
+            "instance_name": self._instance_name,
             "stream_id": stream_id,
             "event_stream_ordering": event_stream_ordering,
             "user_id": user_id,
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 39e22d3b43..a8a37b6c85 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -40,10 +40,7 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -84,15 +81,20 @@ class PusherWorkerStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        # In the worker store this is an ID tracker which we overwrite in the non-worker
-        # class below that is used on the main process.
-        self._pushers_id_gen = StreamIdGenerator(
-            db_conn,
-            hs.get_replication_notifier(),
-            "pushers",
-            "id",
-            extra_tables=[("deleted_pushers", "stream_id")],
-            is_writer=hs.config.worker.worker_app is None,
+        self._instance_name = hs.get_instance_name()
+
+        self._pushers_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="pushers",
+            instance_name=self._instance_name,
+            tables=[
+                ("pushers", "instance_name", "id"),
+                ("deleted_pushers", "instance_name", "stream_id"),
+            ],
+            sequence_name="pushers_sequence",
+            writers=["master"],
         )
 
         self.db_pool.updates.register_background_update_handler(
@@ -655,7 +657,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
 class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
     # Because we have write access, this will be a StreamIdGenerator
     # (see PusherWorkerStore.__init__)
-    _pushers_id_gen: AbstractStreamIdGenerator
+    _pushers_id_gen: MultiWriterIdGenerator
 
     async def add_pusher(
         self,
@@ -688,6 +690,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
                     "last_stream_ordering": last_stream_ordering,
                     "profile_tag": profile_tag,
                     "id": stream_id,
+                    "instance_name": self._instance_name,
                     "enabled": enabled,
                     "device_id": device_id,
                     # XXX(quenting): We're only really persisting the access token ID
@@ -735,6 +738,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
                 table="deleted_pushers",
                 values={
                     "stream_id": stream_id,
+                    "instance_name": self._instance_name,
                     "app_id": app_id,
                     "pushkey": pushkey,
                     "user_id": user_id,
@@ -773,9 +777,15 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
             self.db_pool.simple_insert_many_txn(
                 txn,
                 table="deleted_pushers",
-                keys=("stream_id", "app_id", "pushkey", "user_id"),
+                keys=("stream_id", "instance_name", "app_id", "pushkey", "user_id"),
                 values=[
-                    (stream_id, pusher.app_id, pusher.pushkey, user_id)
+                    (
+                        stream_id,
+                        self._instance_name,
+                        pusher.app_id,
+                        pusher.pushkey,
+                        user_id,
+                    )
                     for stream_id, pusher in zip(stream_ids, pushers)
                 ],
             )
diff --git a/synapse/storage/schema/main/delta/85/02_add_instance_names.sql b/synapse/storage/schema/main/delta/85/02_add_instance_names.sql
new file mode 100644
index 0000000000..d604595f73
--- /dev/null
+++ b/synapse/storage/schema/main/delta/85/02_add_instance_names.sql
@@ -0,0 +1,27 @@
+--
+-- This file is licensed under the Affero General Public License (AGPL) version 3.
+--
+-- Copyright (C) 2024 New Vector, Ltd
+--
+-- This program is free software: you can redistribute it and/or modify
+-- it under the terms of the GNU Affero General Public License as
+-- published by the Free Software Foundation, either version 3 of the
+-- License, or (at your option) any later version.
+--
+-- See the GNU Affero General Public License for more details:
+-- <https://www.gnu.org/licenses/agpl-3.0.html>.
+
+-- Add `instance_name` columns to stream tables to allow them to be used with
+-- `MultiWriterIdGenerator`
+ALTER TABLE device_lists_stream ADD COLUMN instance_name TEXT;
+ALTER TABLE user_signature_stream ADD COLUMN instance_name TEXT;
+ALTER TABLE device_lists_outbound_pokes ADD COLUMN instance_name TEXT;
+ALTER TABLE device_lists_changes_in_room ADD COLUMN instance_name TEXT;
+ALTER TABLE device_lists_remote_pending ADD COLUMN instance_name TEXT;
+
+ALTER TABLE e2e_cross_signing_keys ADD COLUMN instance_name TEXT;
+
+ALTER TABLE push_rules_stream ADD COLUMN instance_name TEXT;
+
+ALTER TABLE pushers ADD COLUMN instance_name TEXT;
+ALTER TABLE deleted_pushers ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/schema/main/delta/85/03_new_sequences.sql.postgres b/synapse/storage/schema/main/delta/85/03_new_sequences.sql.postgres
new file mode 100644
index 0000000000..9d34066bf5
--- /dev/null
+++ b/synapse/storage/schema/main/delta/85/03_new_sequences.sql.postgres
@@ -0,0 +1,54 @@
+--
+-- This file is licensed under the Affero General Public License (AGPL) version 3.
+--
+-- Copyright (C) 2024 New Vector, Ltd
+--
+-- This program is free software: you can redistribute it and/or modify
+-- it under the terms of the GNU Affero General Public License as
+-- published by the Free Software Foundation, either version 3 of the
+-- License, or (at your option) any later version.
+--
+-- See the GNU Affero General Public License for more details:
+-- <https://www.gnu.org/licenses/agpl-3.0.html>.
+
+-- Add squences for stream tables to allow them to be used with
+-- `MultiWriterIdGenerator`
+CREATE SEQUENCE IF NOT EXISTS device_lists_sequence;
+
+-- We need to take the max across all the device lists tables as they share the
+-- ID generator
+SELECT setval('device_lists_sequence', (
+    SELECT GREATEST(
+        (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_stream),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM user_signature_stream),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_outbound_pokes),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_remote_pending),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_converted_stream_position)
+    )
+));
+
+CREATE SEQUENCE IF NOT EXISTS e2e_cross_signing_keys_sequence;
+
+SELECT setval('e2e_cross_signing_keys_sequence', (
+    SELECT COALESCE(MAX(stream_id), 1) FROM e2e_cross_signing_keys
+));
+
+
+CREATE SEQUENCE IF NOT EXISTS push_rules_stream_sequence;
+
+SELECT setval('push_rules_stream_sequence', (
+    SELECT COALESCE(MAX(stream_id), 1) FROM push_rules_stream
+));
+
+
+CREATE SEQUENCE IF NOT EXISTS pushers_sequence;
+
+-- We need to take the max across all the pusher tables as they share the
+-- ID generator
+SELECT setval('pushers_sequence', (
+    SELECT GREATEST(
+        (SELECT COALESCE(MAX(id), 1) FROM pushers),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM deleted_pushers)
+    )
+));
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 0cf5851ad7..59c8e05c39 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -23,15 +23,12 @@ import abc
 import heapq
 import logging
 import threading
-from collections import OrderedDict
-from contextlib import contextmanager
 from types import TracebackType
 from typing import (
     TYPE_CHECKING,
     AsyncContextManager,
     ContextManager,
     Dict,
-    Generator,
     Generic,
     Iterable,
     List,
@@ -179,161 +176,6 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
 
-class StreamIdGenerator(AbstractStreamIdGenerator):
-    """Generates and tracks stream IDs for a stream with a single writer.
-
-    This class must only be used when the current Synapse process is the sole
-    writer for a stream.
-
-    Args:
-        db_conn(connection):  A database connection to use to fetch the
-            initial value of the generator from.
-        table(str): A database table to read the initial value of the id
-            generator from.
-        column(str): The column of the database table to read the initial
-            value from the id generator from.
-        extra_tables(list): List of pairs of database tables and columns to
-            use to source the initial value of the generator from. The value
-            with the largest magnitude is used.
-        step(int): which direction the stream ids grow in. +1 to grow
-            upwards, -1 to grow downwards.
-
-    Usage:
-        async with stream_id_gen.get_next() as stream_id:
-            # ... persist event ...
-    """
-
-    def __init__(
-        self,
-        db_conn: LoggingDatabaseConnection,
-        notifier: "ReplicationNotifier",
-        table: str,
-        column: str,
-        extra_tables: Iterable[Tuple[str, str]] = (),
-        step: int = 1,
-        is_writer: bool = True,
-    ) -> None:
-        assert step != 0
-        self._lock = threading.Lock()
-        self._step: int = step
-        self._current: int = _load_current_id(db_conn, table, column, step)
-        self._is_writer = is_writer
-        for table, column in extra_tables:
-            self._current = (max if step > 0 else min)(
-                self._current, _load_current_id(db_conn, table, column, step)
-            )
-
-        # We use this as an ordered set, as we want to efficiently append items,
-        # remove items and get the first item. Since we insert IDs in order, the
-        # insertion ordering will ensure its in the correct ordering.
-        #
-        # The key and values are the same, but we never look at the values.
-        self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
-
-        self._notifier = notifier
-
-    def advance(self, instance_name: str, new_id: int) -> None:
-        # Advance should never be called on a writer instance, only over replication
-        if self._is_writer:
-            raise Exception("Replication is not supported by writer StreamIdGenerator")
-
-        self._current = (max if self._step > 0 else min)(self._current, new_id)
-
-    def get_next(self) -> AsyncContextManager[int]:
-        with self._lock:
-            self._current += self._step
-            next_id = self._current
-
-            self._unfinished_ids[next_id] = next_id
-
-        @contextmanager
-        def manager() -> Generator[int, None, None]:
-            try:
-                yield next_id
-            finally:
-                with self._lock:
-                    self._unfinished_ids.pop(next_id)
-
-                self._notifier.notify_replication()
-
-        return _AsyncCtxManagerWrapper(manager())
-
-    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
-        with self._lock:
-            next_ids = range(
-                self._current + self._step,
-                self._current + self._step * (n + 1),
-                self._step,
-            )
-            self._current += n * self._step
-
-            for next_id in next_ids:
-                self._unfinished_ids[next_id] = next_id
-
-        @contextmanager
-        def manager() -> Generator[Sequence[int], None, None]:
-            try:
-                yield next_ids
-            finally:
-                with self._lock:
-                    for next_id in next_ids:
-                        self._unfinished_ids.pop(next_id)
-
-                self._notifier.notify_replication()
-
-        return _AsyncCtxManagerWrapper(manager())
-
-    def get_next_txn(self, txn: LoggingTransaction) -> int:
-        """
-        Retrieve the next stream ID from within a database transaction.
-
-        Clean-up functions will be called when the transaction finishes.
-
-        Args:
-            txn: The database transaction object.
-
-        Returns:
-            The next stream ID.
-        """
-        if not self._is_writer:
-            raise Exception("Tried to allocate stream ID on non-writer")
-
-        # Get the next stream ID.
-        with self._lock:
-            self._current += self._step
-            next_id = self._current
-
-            self._unfinished_ids[next_id] = next_id
-
-        def clear_unfinished_id(id_to_clear: int) -> None:
-            """A function to mark processing this ID as finished"""
-            with self._lock:
-                self._unfinished_ids.pop(id_to_clear)
-
-        # Mark this ID as finished once the database transaction itself finishes.
-        txn.call_after(clear_unfinished_id, next_id)
-        txn.call_on_exception(clear_unfinished_id, next_id)
-
-        # Return the new ID.
-        return next_id
-
-    def get_current_token(self) -> int:
-        if not self._is_writer:
-            return self._current
-
-        with self._lock:
-            if self._unfinished_ids:
-                return next(iter(self._unfinished_ids)) - self._step
-
-            return self._current
-
-    def get_current_token_for_writer(self, instance_name: str) -> int:
-        return self.get_current_token()
-
-    def get_minimal_local_current_token(self) -> int:
-        return self.get_current_token()
-
-
 class MultiWriterIdGenerator(AbstractStreamIdGenerator):
     """Generates and tracks stream IDs for a stream with multiple writers.
 
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index fad9511cea..f0307252f3 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -30,7 +30,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.engines import IncorrectDatabaseSetup
 from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.storage.util.sequence import (
     LocalSequenceGenerator,
     PostgresSequenceGenerator,
@@ -42,144 +42,6 @@ from tests.unittest import HomeserverTestCase
 from tests.utils import USE_POSTGRES_FOR_TESTS
 
 
-class StreamIdGeneratorTestCase(HomeserverTestCase):
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastores().main
-        self.db_pool: DatabasePool = self.store.db_pool
-
-        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
-
-    def _setup_db(self, txn: LoggingTransaction) -> None:
-        txn.execute(
-            """
-            CREATE TABLE foobar (
-                stream_id BIGINT NOT NULL,
-                data TEXT
-            );
-            """
-        )
-        txn.execute("INSERT INTO foobar VALUES (123, 'hello world');")
-
-    def _create_id_generator(self) -> StreamIdGenerator:
-        def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
-            return StreamIdGenerator(
-                db_conn=conn,
-                notifier=self.hs.get_replication_notifier(),
-                table="foobar",
-                column="stream_id",
-            )
-
-        return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
-
-    def test_initial_value(self) -> None:
-        """Check that we read the current token from the DB."""
-        id_gen = self._create_id_generator()
-        self.assertEqual(id_gen.get_current_token(), 123)
-
-    def test_single_gen_next(self) -> None:
-        """Check that we correctly increment the current token from the DB."""
-        id_gen = self._create_id_generator()
-
-        async def test_gen_next() -> None:
-            async with id_gen.get_next() as next_id:
-                # We haven't persisted `next_id` yet; current token is still 123
-                self.assertEqual(id_gen.get_current_token(), 123)
-                # But we did learn what the next value is
-                self.assertEqual(next_id, 124)
-
-            # Once the context manager closes we assume that the `next_id` has been
-            # written to the DB.
-            self.assertEqual(id_gen.get_current_token(), 124)
-
-        self.get_success(test_gen_next())
-
-    def test_multiple_gen_nexts(self) -> None:
-        """Check that we handle overlapping calls to gen_next sensibly."""
-        id_gen = self._create_id_generator()
-
-        async def test_gen_next() -> None:
-            ctx1 = id_gen.get_next()
-            ctx2 = id_gen.get_next()
-            ctx3 = id_gen.get_next()
-
-            # Request three new stream IDs.
-            self.assertEqual(await ctx1.__aenter__(), 124)
-            self.assertEqual(await ctx2.__aenter__(), 125)
-            self.assertEqual(await ctx3.__aenter__(), 126)
-
-            # None are persisted: current token unchanged.
-            self.assertEqual(id_gen.get_current_token(), 123)
-
-            # Persist each in turn.
-            await ctx1.__aexit__(None, None, None)
-            self.assertEqual(id_gen.get_current_token(), 124)
-            await ctx2.__aexit__(None, None, None)
-            self.assertEqual(id_gen.get_current_token(), 125)
-            await ctx3.__aexit__(None, None, None)
-            self.assertEqual(id_gen.get_current_token(), 126)
-
-        self.get_success(test_gen_next())
-
-    def test_multiple_gen_nexts_closed_in_different_order(self) -> None:
-        """Check that we handle overlapping calls to gen_next, even when their IDs
-        created and persisted in different orders."""
-        id_gen = self._create_id_generator()
-
-        async def test_gen_next() -> None:
-            ctx1 = id_gen.get_next()
-            ctx2 = id_gen.get_next()
-            ctx3 = id_gen.get_next()
-
-            # Request three new stream IDs.
-            self.assertEqual(await ctx1.__aenter__(), 124)
-            self.assertEqual(await ctx2.__aenter__(), 125)
-            self.assertEqual(await ctx3.__aenter__(), 126)
-
-            # None are persisted: current token unchanged.
-            self.assertEqual(id_gen.get_current_token(), 123)
-
-            # Persist them in a different order, starting with 126 from ctx3.
-            await ctx3.__aexit__(None, None, None)
-            # We haven't persisted 124 from ctx1 yet---current token is still 123.
-            self.assertEqual(id_gen.get_current_token(), 123)
-
-            # Now persist 124 from ctx1.
-            await ctx1.__aexit__(None, None, None)
-            # Current token is then 124, waiting for 125 to be persisted.
-            self.assertEqual(id_gen.get_current_token(), 124)
-
-            # Finally persist 125 from ctx2.
-            await ctx2.__aexit__(None, None, None)
-            # Current token is then 126 (skipping over 125).
-            self.assertEqual(id_gen.get_current_token(), 126)
-
-        self.get_success(test_gen_next())
-
-    def test_gen_next_while_still_waiting_for_persistence(self) -> None:
-        """Check that we handle overlapping calls to gen_next."""
-        id_gen = self._create_id_generator()
-
-        async def test_gen_next() -> None:
-            ctx1 = id_gen.get_next()
-            ctx2 = id_gen.get_next()
-            ctx3 = id_gen.get_next()
-
-            # Request two new stream IDs.
-            self.assertEqual(await ctx1.__aenter__(), 124)
-            self.assertEqual(await ctx2.__aenter__(), 125)
-
-            # Persist ctx2 first.
-            await ctx2.__aexit__(None, None, None)
-            # Still waiting on ctx1's ID to be persisted.
-            self.assertEqual(id_gen.get_current_token(), 123)
-
-            # Now request a third stream ID. It should be 126 (the smallest ID that
-            # we've not yet handed out.)
-            self.assertEqual(await ctx3.__aenter__(), 126)
-
-        self.get_success(test_gen_next())
-
-
 class MultiWriterIdGeneratorBase(HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main