summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erikj@element.io>2024-05-29 13:19:10 +0100
committerGitHub <noreply@github.com>2024-05-29 12:19:10 +0000
commit466f344547fc6bea2c43257dd65286380fbb512d (patch)
tree40ba5abd666ac6584b7d7cbae3603a3898ce91c5
parentDon't invalidate all `get_relations_for_event` on history purge (#17083) (diff)
downloadsynapse-466f344547fc6bea2c43257dd65286380fbb512d.tar.xz
Move towards using `MultiWriterIdGenerator` everywhere (#17226)
There is a problem with `StreamIdGenerator` where it can go backwards
over restarts when a stream ID is requested but then not inserted into
the DB. This is problematic if we want to land #17215, and is generally
a potential cause for all sorts of nastiness.

Instead of trying to fix `StreamIdGenerator`, we may as well move to
`MultiWriterIdGenerator` that does not suffer from this problem (the
latest positions are stored in `stream_positions` table). This involves
adding SQLite support to the class.

This only changes id generators that were already using
`MultiWriterIdGenerator` under postgres, a separate PR will move the
rest of the uses of `StreamIdGenerator` over.
-rw-r--r--changelog.d/17226.misc1
-rw-r--r--synapse/storage/database.py21
-rw-r--r--synapse/storage/databases/main/account_data.py47
-rw-r--r--synapse/storage/databases/main/deviceinbox.py46
-rw-r--r--synapse/storage/databases/main/events_worker.py101
-rw-r--r--synapse/storage/databases/main/presence.py27
-rw-r--r--synapse/storage/databases/main/receipts.py43
-rw-r--r--synapse/storage/databases/main/room.py34
-rw-r--r--synapse/storage/util/id_generators.py49
-rw-r--r--tests/storage/test_id_generators.py351
10 files changed, 341 insertions, 379 deletions
diff --git a/changelog.d/17226.misc b/changelog.d/17226.misc
new file mode 100644
index 0000000000..7c023a5759
--- /dev/null
+++ b/changelog.d/17226.misc
@@ -0,0 +1 @@
+Move towards using `MultiWriterIdGenerator` everywhere.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d9c85e411e..569f618193 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -2461,7 +2461,11 @@ class DatabasePool:
 
 
 def make_in_list_sql_clause(
-    database_engine: BaseDatabaseEngine, column: str, iterable: Collection[Any]
+    database_engine: BaseDatabaseEngine,
+    column: str,
+    iterable: Collection[Any],
+    *,
+    negative: bool = False,
 ) -> Tuple[str, list]:
     """Returns an SQL clause that checks the given column is in the iterable.
 
@@ -2474,6 +2478,7 @@ def make_in_list_sql_clause(
         database_engine
         column: Name of the column
         iterable: The values to check the column against.
+        negative: Whether we should check for inequality, i.e. `NOT IN`
 
     Returns:
         A tuple of SQL query and the args
@@ -2482,9 +2487,19 @@ def make_in_list_sql_clause(
     if database_engine.supports_using_any_list:
         # This should hopefully be faster, but also makes postgres query
         # stats easier to understand.
-        return "%s = ANY(?)" % (column,), [list(iterable)]
+        if not negative:
+            clause = f"{column} = ANY(?)"
+        else:
+            clause = f"{column} != ALL(?)"
+
+        return clause, [list(iterable)]
     else:
-        return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
+        params = ",".join("?" for _ in iterable)
+        if not negative:
+            clause = f"{column} IN ({params})"
+        else:
+            clause = f"{column} NOT IN ({params})"
+        return clause, list(iterable)
 
 
 # These overloads ensure that `columns` and `iterable` values have the same length.
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 563450a97e..9611a84932 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -43,11 +43,9 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
-from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
     MultiWriterIdGenerator,
-    StreamIdGenerator,
 )
 from synapse.types import JsonDict, JsonMapping
 from synapse.util import json_encoder
@@ -75,37 +73,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
 
         self._account_data_id_gen: AbstractStreamIdGenerator
 
-        if isinstance(database.engine, PostgresEngine):
-            self._account_data_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                notifier=hs.get_replication_notifier(),
-                stream_name="account_data",
-                instance_name=self._instance_name,
-                tables=[
-                    ("room_account_data", "instance_name", "stream_id"),
-                    ("room_tags_revisions", "instance_name", "stream_id"),
-                    ("account_data", "instance_name", "stream_id"),
-                ],
-                sequence_name="account_data_sequence",
-                writers=hs.config.worker.writers.account_data,
-            )
-        else:
-            # Multiple writers are not supported for SQLite.
-            #
-            # We shouldn't be running in worker mode with SQLite, but its useful
-            # to support it for unit tests.
-            self._account_data_id_gen = StreamIdGenerator(
-                db_conn,
-                hs.get_replication_notifier(),
-                "room_account_data",
-                "stream_id",
-                extra_tables=[
-                    ("account_data", "stream_id"),
-                    ("room_tags_revisions", "stream_id"),
-                ],
-                is_writer=self._instance_name in hs.config.worker.writers.account_data,
-            )
+        self._account_data_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="account_data",
+            instance_name=self._instance_name,
+            tables=[
+                ("room_account_data", "instance_name", "stream_id"),
+                ("room_tags_revisions", "instance_name", "stream_id"),
+                ("account_data", "instance_name", "stream_id"),
+            ],
+            sequence_name="account_data_sequence",
+            writers=hs.config.worker.writers.account_data,
+        )
 
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index e17821ff6e..25023b5e7a 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -50,11 +50,9 @@ from synapse.storage.database import (
     LoggingTransaction,
     make_in_list_sql_clause,
 )
-from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
     MultiWriterIdGenerator,
-    StreamIdGenerator,
 )
 from synapse.types import JsonDict
 from synapse.util import json_encoder
@@ -89,35 +87,23 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             expiry_ms=30 * 60 * 1000,
         )
 
-        if isinstance(database.engine, PostgresEngine):
-            self._can_write_to_device = (
-                self._instance_name in hs.config.worker.writers.to_device
-            )
+        self._can_write_to_device = (
+            self._instance_name in hs.config.worker.writers.to_device
+        )
 
-            self._to_device_msg_id_gen: AbstractStreamIdGenerator = (
-                MultiWriterIdGenerator(
-                    db_conn=db_conn,
-                    db=database,
-                    notifier=hs.get_replication_notifier(),
-                    stream_name="to_device",
-                    instance_name=self._instance_name,
-                    tables=[
-                        ("device_inbox", "instance_name", "stream_id"),
-                        ("device_federation_outbox", "instance_name", "stream_id"),
-                    ],
-                    sequence_name="device_inbox_sequence",
-                    writers=hs.config.worker.writers.to_device,
-                )
-            )
-        else:
-            self._can_write_to_device = True
-            self._to_device_msg_id_gen = StreamIdGenerator(
-                db_conn,
-                hs.get_replication_notifier(),
-                "device_inbox",
-                "stream_id",
-                extra_tables=[("device_federation_outbox", "stream_id")],
-            )
+        self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="to_device",
+            instance_name=self._instance_name,
+            tables=[
+                ("device_inbox", "instance_name", "stream_id"),
+                ("device_federation_outbox", "instance_name", "stream_id"),
+            ],
+            sequence_name="device_inbox_sequence",
+            writers=hs.config.worker.writers.to_device,
+        )
 
         max_device_inbox_id = self._to_device_msg_id_gen.get_current_token()
         device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e39d4b9624..426df2a9d2 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -75,12 +75,10 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
-from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
     MultiWriterIdGenerator,
-    StreamIdGenerator,
 )
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
@@ -195,51 +193,28 @@ class EventsWorkerStore(SQLBaseStore):
 
         self._stream_id_gen: AbstractStreamIdGenerator
         self._backfill_id_gen: AbstractStreamIdGenerator
-        if isinstance(database.engine, PostgresEngine):
-            # If we're using Postgres than we can use `MultiWriterIdGenerator`
-            # regardless of whether this process writes to the streams or not.
-            self._stream_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                notifier=hs.get_replication_notifier(),
-                stream_name="events",
-                instance_name=hs.get_instance_name(),
-                tables=[("events", "instance_name", "stream_ordering")],
-                sequence_name="events_stream_seq",
-                writers=hs.config.worker.writers.events,
-            )
-            self._backfill_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                notifier=hs.get_replication_notifier(),
-                stream_name="backfill",
-                instance_name=hs.get_instance_name(),
-                tables=[("events", "instance_name", "stream_ordering")],
-                sequence_name="events_backfill_stream_seq",
-                positive=False,
-                writers=hs.config.worker.writers.events,
-            )
-        else:
-            # Multiple writers are not supported for SQLite.
-            #
-            # We shouldn't be running in worker mode with SQLite, but its useful
-            # to support it for unit tests.
-            self._stream_id_gen = StreamIdGenerator(
-                db_conn,
-                hs.get_replication_notifier(),
-                "events",
-                "stream_ordering",
-                is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
-            )
-            self._backfill_id_gen = StreamIdGenerator(
-                db_conn,
-                hs.get_replication_notifier(),
-                "events",
-                "stream_ordering",
-                step=-1,
-                extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
-                is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
-            )
+
+        self._stream_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="events",
+            instance_name=hs.get_instance_name(),
+            tables=[("events", "instance_name", "stream_ordering")],
+            sequence_name="events_stream_seq",
+            writers=hs.config.worker.writers.events,
+        )
+        self._backfill_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="backfill",
+            instance_name=hs.get_instance_name(),
+            tables=[("events", "instance_name", "stream_ordering")],
+            sequence_name="events_backfill_stream_seq",
+            positive=False,
+            writers=hs.config.worker.writers.events,
+        )
 
         events_max = self._stream_id_gen.get_current_token()
         curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
@@ -309,27 +284,17 @@ class EventsWorkerStore(SQLBaseStore):
 
         self._un_partial_stated_events_stream_id_gen: AbstractStreamIdGenerator
 
-        if isinstance(database.engine, PostgresEngine):
-            self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                notifier=hs.get_replication_notifier(),
-                stream_name="un_partial_stated_event_stream",
-                instance_name=hs.get_instance_name(),
-                tables=[
-                    ("un_partial_stated_event_stream", "instance_name", "stream_id")
-                ],
-                sequence_name="un_partial_stated_event_stream_sequence",
-                # TODO(faster_joins, multiple writers) Support multiple writers.
-                writers=["master"],
-            )
-        else:
-            self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
-                db_conn,
-                hs.get_replication_notifier(),
-                "un_partial_stated_event_stream",
-                "stream_id",
-            )
+        self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="un_partial_stated_event_stream",
+            instance_name=hs.get_instance_name(),
+            tables=[("un_partial_stated_event_stream", "instance_name", "stream_id")],
+            sequence_name="un_partial_stated_event_stream_sequence",
+            # TODO(faster_joins, multiple writers) Support multiple writers.
+            writers=["master"],
+        )
 
     def get_un_partial_stated_events_token(self, instance_name: str) -> int:
         return (
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 567c2d30bd..923e764491 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -40,13 +40,11 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.engines import PostgresEngine
 from synapse.storage.engines._base import IsolationLevel
 from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
     MultiWriterIdGenerator,
-    StreamIdGenerator,
 )
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -91,21 +89,16 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
             self._instance_name in hs.config.worker.writers.presence
         )
 
-        if isinstance(database.engine, PostgresEngine):
-            self._presence_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                notifier=hs.get_replication_notifier(),
-                stream_name="presence_stream",
-                instance_name=self._instance_name,
-                tables=[("presence_stream", "instance_name", "stream_id")],
-                sequence_name="presence_stream_sequence",
-                writers=hs.config.worker.writers.presence,
-            )
-        else:
-            self._presence_id_gen = StreamIdGenerator(
-                db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
-            )
+        self._presence_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="presence_stream",
+            instance_name=self._instance_name,
+            tables=[("presence_stream", "instance_name", "stream_id")],
+            sequence_name="presence_stream_sequence",
+            writers=hs.config.worker.writers.presence,
+        )
 
         self.hs = hs
         self._presence_on_startup = self._get_active_presence(db_conn)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 13387a3839..8432560a89 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -44,12 +44,10 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
-from synapse.storage.engines import PostgresEngine
 from synapse.storage.engines._base import IsolationLevel
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
     MultiWriterIdGenerator,
-    StreamIdGenerator,
 )
 from synapse.types import (
     JsonDict,
@@ -80,35 +78,20 @@ class ReceiptsWorkerStore(SQLBaseStore):
         # class below that is used on the main process.
         self._receipts_id_gen: AbstractStreamIdGenerator
 
-        if isinstance(database.engine, PostgresEngine):
-            self._can_write_to_receipts = (
-                self._instance_name in hs.config.worker.writers.receipts
-            )
+        self._can_write_to_receipts = (
+            self._instance_name in hs.config.worker.writers.receipts
+        )
 
-            self._receipts_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                notifier=hs.get_replication_notifier(),
-                stream_name="receipts",
-                instance_name=self._instance_name,
-                tables=[("receipts_linearized", "instance_name", "stream_id")],
-                sequence_name="receipts_sequence",
-                writers=hs.config.worker.writers.receipts,
-            )
-        else:
-            self._can_write_to_receipts = True
-
-            # Multiple writers are not supported for SQLite.
-            #
-            # We shouldn't be running in worker mode with SQLite, but its useful
-            # to support it for unit tests.
-            self._receipts_id_gen = StreamIdGenerator(
-                db_conn,
-                hs.get_replication_notifier(),
-                "receipts_linearized",
-                "stream_id",
-                is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
-            )
+        self._receipts_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="receipts",
+            instance_name=self._instance_name,
+            tables=[("receipts_linearized", "instance_name", "stream_id")],
+            sequence_name="receipts_sequence",
+            writers=hs.config.worker.writers.receipts,
+        )
 
         super().__init__(database, db_conn, hs)
 
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 8205109548..616c941687 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -58,13 +58,11 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
     IdGenerator,
     MultiWriterIdGenerator,
-    StreamIdGenerator,
 )
 from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
 from synapse.util import json_encoder
@@ -155,27 +153,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
 
-        if isinstance(database.engine, PostgresEngine):
-            self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                notifier=hs.get_replication_notifier(),
-                stream_name="un_partial_stated_room_stream",
-                instance_name=self._instance_name,
-                tables=[
-                    ("un_partial_stated_room_stream", "instance_name", "stream_id")
-                ],
-                sequence_name="un_partial_stated_room_stream_sequence",
-                # TODO(faster_joins, multiple writers) Support multiple writers.
-                writers=["master"],
-            )
-        else:
-            self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
-                db_conn,
-                hs.get_replication_notifier(),
-                "un_partial_stated_room_stream",
-                "stream_id",
-            )
+        self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="un_partial_stated_room_stream",
+            instance_name=self._instance_name,
+            tables=[("un_partial_stated_room_stream", "instance_name", "stream_id")],
+            sequence_name="un_partial_stated_room_stream_sequence",
+            # TODO(faster_joins, multiple writers) Support multiple writers.
+            writers=["master"],
+        )
 
     def process_replication_position(
         self, stream_name: str, instance_name: str, token: int
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index fadc75cc80..0cf5851ad7 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -53,9 +53,11 @@ from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
     LoggingTransaction,
+    make_in_list_sql_clause,
 )
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
-from synapse.storage.util.sequence import PostgresSequenceGenerator
+from synapse.storage.util.sequence import build_sequence_generator
 
 if TYPE_CHECKING:
     from synapse.notifier import ReplicationNotifier
@@ -432,7 +434,22 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         # no active writes in progress.
         self._max_position_of_local_instance = self._max_seen_allocated_stream_id
 
-        self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+        # This goes and fills out the above state from the database.
+        self._load_current_ids(db_conn, tables)
+
+        self._sequence_gen = build_sequence_generator(
+            db_conn=db_conn,
+            database_engine=db.engine,
+            get_first_callback=lambda _: self._persisted_upto_position,
+            sequence_name=sequence_name,
+            # We only need to set the below if we want it to call
+            # `check_consistency`, but we do that ourselves below so we can
+            # leave them blank.
+            table=None,
+            id_column=None,
+            stream_name=None,
+            positive=positive,
+        )
 
         # We check that the table and sequence haven't diverged.
         for table, _, id_column in tables:
@@ -444,9 +461,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
                 positive=positive,
             )
 
-        # This goes and fills out the above state from the database.
-        self._load_current_ids(db_conn, tables)
-
         self._max_seen_allocated_stream_id = max(
             self._current_positions.values(), default=1
         )
@@ -480,13 +494,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             # important if we add back a writer after a long time; we want to
             # consider that a "new" writer, rather than using the old stale
             # entry here.
-            sql = """
+            clause, args = make_in_list_sql_clause(
+                self._db.engine, "instance_name", self._writers, negative=True
+            )
+
+            sql = f"""
                 DELETE FROM stream_positions
                 WHERE
                     stream_name = ?
-                    AND instance_name != ALL(?)
+                    AND {clause}
             """
-            cur.execute(sql, (self._stream_name, self._writers))
+            cur.execute(sql, [self._stream_name] + args)
 
             sql = """
                 SELECT instance_name, stream_id FROM stream_positions
@@ -508,12 +526,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             # We add a GREATEST here to ensure that the result is always
             # positive. (This can be a problem for e.g. backfill streams where
             # the server has never backfilled).
+            greatest_func = (
+                "GREATEST" if isinstance(self._db.engine, PostgresEngine) else "MAX"
+            )
             max_stream_id = 1
             for table, _, id_column in tables:
                 sql = """
-                    SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+                    SELECT %(greatest_func)s(COALESCE(%(agg)s(%(id)s), 1), 1)
                     FROM %(table)s
                 """ % {
+                    "greatest_func": greatest_func,
                     "id": id_column,
                     "table": table,
                     "agg": "MAX" if self._positive else "-MIN",
@@ -913,6 +935,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
 
         # We upsert the value, ensuring on conflict that we always increase the
         # value (or decrease if stream goes backwards).
+        if isinstance(self._db.engine, PostgresEngine):
+            agg = "GREATEST" if self._positive else "LEAST"
+        else:
+            agg = "MAX" if self._positive else "MIN"
+
         sql = """
             INSERT INTO stream_positions (stream_name, instance_name, stream_id)
             VALUES (?, ?, ?)
@@ -920,10 +947,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             DO UPDATE SET
                 stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
         """ % {
-            "agg": "GREATEST" if self._positive else "LEAST",
+            "agg": agg,
         }
 
-        pos = (self.get_current_token_for_writer(self._instance_name),)
+        pos = self.get_current_token_for_writer(self._instance_name)
         txn.execute(sql, (self._stream_name, self._instance_name, pos))
 
 
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 409d856ab9..fad9511cea 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -31,6 +31,11 @@ from synapse.storage.database import (
 from synapse.storage.engines import IncorrectDatabaseSetup
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.sequence import (
+    LocalSequenceGenerator,
+    PostgresSequenceGenerator,
+    SequenceGenerator,
+)
 from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
@@ -175,18 +180,22 @@ class StreamIdGeneratorTestCase(HomeserverTestCase):
         self.get_success(test_gen_next())
 
 
-class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
-    if not USE_POSTGRES_FOR_TESTS:
-        skip = "Requires Postgres"
-
+class MultiWriterIdGeneratorBase(HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
+        if USE_POSTGRES_FOR_TESTS:
+            self.seq_gen: SequenceGenerator = PostgresSequenceGenerator("foobar_seq")
+        else:
+            self.seq_gen = LocalSequenceGenerator(lambda _: 0)
+
     def _setup_db(self, txn: LoggingTransaction) -> None:
-        txn.execute("CREATE SEQUENCE foobar_seq")
+        if USE_POSTGRES_FOR_TESTS:
+            txn.execute("CREATE SEQUENCE foobar_seq")
+
         txn.execute(
             """
             CREATE TABLE foobar (
@@ -221,44 +230,27 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         def _insert(txn: LoggingTransaction) -> None:
             for _ in range(number):
+                next_val = self.seq_gen.get_next_id_txn(txn)
                 txn.execute(
-                    "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
-                    (instance_name,),
+                    "INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
+                    (
+                        next_val,
+                        instance_name,
+                    ),
                 )
+
                 txn.execute(
                     """
-                    INSERT INTO stream_positions VALUES ('test_stream', ?,  lastval())
-                    ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+                    INSERT INTO stream_positions VALUES ('test_stream', ?,  ?)
+                    ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
                     """,
-                    (instance_name,),
+                    (instance_name, next_val, next_val),
                 )
 
         self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
 
-    def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
-        """Insert one row as the given instance with given stream_id, updating
-        the postgres sequence position to match.
-        """
-
-        def _insert(txn: LoggingTransaction) -> None:
-            txn.execute(
-                "INSERT INTO foobar VALUES (?, ?)",
-                (
-                    stream_id,
-                    instance_name,
-                ),
-            )
-            txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
-            txn.execute(
-                """
-                INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
-                ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
-                """,
-                (instance_name, stream_id, stream_id),
-            )
-
-        self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
 
+class MultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
     def test_empty(self) -> None:
         """Test an ID generator against an empty database gives sensible
         current positions.
@@ -347,137 +339,106 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_positions(), {"master": 11})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
 
-    def test_multi_instance(self) -> None:
-        """Test that reads and writes from multiple processes are handled
-        correctly.
-        """
-        self._insert_rows("first", 3)
-        self._insert_rows("second", 4)
+    def test_get_next_txn(self) -> None:
+        """Test that the `get_next_txn` function works correctly."""
 
-        first_id_gen = self._create_id_generator("first", writers=["first", "second"])
-        second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+        # Prefill table with 7 rows written by 'master'
+        self._insert_rows("master", 7)
 
-        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
-        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
+        id_gen = self._create_id_generator()
 
-        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
-        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(id_gen.get_positions(), {"master": 7})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
 
-        async def _get_next_async() -> None:
-            async with first_id_gen.get_next() as stream_id:
-                self.assertEqual(stream_id, 8)
-
-                self.assertEqual(
-                    first_id_gen.get_positions(), {"first": 3, "second": 7}
-                )
-                self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
-
-        self.get_success(_get_next_async())
-
-        self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
-
-        # However the ID gen on the second instance won't have seen the update
-        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
-
-        # ... but calling `get_next` on the second instance should give a unique
-        # stream ID
+        def _get_next_txn(txn: LoggingTransaction) -> None:
+            stream_id = id_gen.get_next_txn(txn)
+            self.assertEqual(stream_id, 8)
 
-        async def _get_next_async2() -> None:
-            async with second_id_gen.get_next() as stream_id:
-                self.assertEqual(stream_id, 9)
+            self.assertEqual(id_gen.get_positions(), {"master": 7})
+            self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
-                self.assertEqual(
-                    second_id_gen.get_positions(), {"first": 3, "second": 7}
-                )
+        self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
 
-        self.get_success(_get_next_async2())
+        self.assertEqual(id_gen.get_positions(), {"master": 8})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
 
-        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
+    def test_restart_during_out_of_order_persistence(self) -> None:
+        """Test that restarting a process while another process is writing out
+        of order updates are handled correctly.
+        """
 
-        # If the second ID gen gets told about the first, it correctly updates
-        second_id_gen.advance("first", 8)
-        self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+        # Prefill table with 7 rows written by 'master'
+        self._insert_rows("master", 7)
 
-    def test_multi_instance_empty_row(self) -> None:
-        """Test that reads and writes from multiple processes are handled
-        correctly, when one of the writers starts without any rows.
-        """
-        # Insert some rows for two out of three of the ID gens.
-        self._insert_rows("first", 3)
-        self._insert_rows("second", 4)
+        id_gen = self._create_id_generator()
 
-        first_id_gen = self._create_id_generator(
-            "first", writers=["first", "second", "third"]
-        )
-        second_id_gen = self._create_id_generator(
-            "second", writers=["first", "second", "third"]
-        )
-        third_id_gen = self._create_id_generator(
-            "third", writers=["first", "second", "third"]
-        )
+        self.assertEqual(id_gen.get_positions(), {"master": 7})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
-        self.assertEqual(
-            first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
-        )
-        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
-        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
-        self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7)
+        # Persist two rows at once
+        ctx1 = id_gen.get_next()
+        ctx2 = id_gen.get_next()
 
-        self.assertEqual(
-            second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
-        )
-        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
-        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
-        self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7)
+        s1 = self.get_success(ctx1.__aenter__())
+        s2 = self.get_success(ctx2.__aenter__())
 
-        # Try allocating a new ID gen and check that we only see position
-        # advanced after we leave the context manager.
+        self.assertEqual(s1, 8)
+        self.assertEqual(s2, 9)
 
-        async def _get_next_async() -> None:
-            async with third_id_gen.get_next() as stream_id:
-                self.assertEqual(stream_id, 8)
+        self.assertEqual(id_gen.get_positions(), {"master": 7})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
-                self.assertEqual(
-                    third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
-                )
-                self.assertEqual(third_id_gen.get_persisted_upto_position(), 7)
+        # We finish persisting the second row before restart
+        self.get_success(ctx2.__aexit__(None, None, None))
 
-        self.get_success(_get_next_async())
+        # We simulate a restart of another worker by just creating a new ID gen.
+        id_gen_worker = self._create_id_generator("worker")
 
-        self.assertEqual(
-            third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
-        )
+        # Restarted worker should not see the second persisted row
+        self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
+        self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
 
-    def test_get_next_txn(self) -> None:
-        """Test that the `get_next_txn` function works correctly."""
+        # Now if we persist the first row then both instances should jump ahead
+        # correctly.
+        self.get_success(ctx1.__aexit__(None, None, None))
 
-        # Prefill table with 7 rows written by 'master'
-        self._insert_rows("master", 7)
+        self.assertEqual(id_gen.get_positions(), {"master": 9})
+        id_gen_worker.advance("master", 9)
+        self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
 
-        id_gen = self._create_id_generator()
 
-        self.assertEqual(id_gen.get_positions(), {"master": 7})
-        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
 
-        # Try allocating a new ID gen and check that we only see position
-        # advanced after we leave the context manager.
+    def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
+        """Insert one row as the given instance with given stream_id, updating
+        the postgres sequence position to match.
+        """
 
-        def _get_next_txn(txn: LoggingTransaction) -> None:
-            stream_id = id_gen.get_next_txn(txn)
-            self.assertEqual(stream_id, 8)
+        def _insert(txn: LoggingTransaction) -> None:
+            txn.execute(
+                "INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
+                (
+                    stream_id,
+                    instance_name,
+                ),
+            )
 
-            self.assertEqual(id_gen.get_positions(), {"master": 7})
-            self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+            txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
 
-        self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
+            txn.execute(
+                """
+                INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+                ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+                """,
+                (instance_name, stream_id, stream_id),
+            )
 
-        self.assertEqual(id_gen.get_positions(), {"master": 8})
-        self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+        self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
 
     def test_get_persisted_upto_position(self) -> None:
         """Test that `get_persisted_upto_position` correctly tracks updates to
@@ -548,49 +509,111 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # `persisted_upto_position` in this case, then it will be correct in the
         # other cases that are tested above (since they'll hit the same code).
 
-    def test_restart_during_out_of_order_persistence(self) -> None:
-        """Test that restarting a process while another process is writing out
-        of order updates are handled correctly.
+    def test_multi_instance(self) -> None:
+        """Test that reads and writes from multiple processes are handled
+        correctly.
         """
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 4)
 
-        # Prefill table with 7 rows written by 'master'
-        self._insert_rows("master", 7)
+        first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+        second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
-        id_gen = self._create_id_generator()
+        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
 
-        self.assertEqual(id_gen.get_positions(), {"master": 7})
-        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
 
-        # Persist two rows at once
-        ctx1 = id_gen.get_next()
-        ctx2 = id_gen.get_next()
+        # Try allocating a new ID gen and check that we only see position
+        # advanced after we leave the context manager.
 
-        s1 = self.get_success(ctx1.__aenter__())
-        s2 = self.get_success(ctx2.__aenter__())
+        async def _get_next_async() -> None:
+            async with first_id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 8)
 
-        self.assertEqual(s1, 8)
-        self.assertEqual(s2, 9)
+                self.assertEqual(
+                    first_id_gen.get_positions(), {"first": 3, "second": 7}
+                )
+                self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
 
-        self.assertEqual(id_gen.get_positions(), {"master": 7})
-        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+        self.get_success(_get_next_async())
 
-        # We finish persisting the second row before restart
-        self.get_success(ctx2.__aexit__(None, None, None))
+        self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
 
-        # We simulate a restart of another worker by just creating a new ID gen.
-        id_gen_worker = self._create_id_generator("worker")
+        # However the ID gen on the second instance won't have seen the update
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
 
-        # Restarted worker should not see the second persisted row
-        self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
-        self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
+        # ... but calling `get_next` on the second instance should give a unique
+        # stream ID
 
-        # Now if we persist the first row then both instances should jump ahead
-        # correctly.
-        self.get_success(ctx1.__aexit__(None, None, None))
+        async def _get_next_async2() -> None:
+            async with second_id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 9)
 
-        self.assertEqual(id_gen.get_positions(), {"master": 9})
-        id_gen_worker.advance("master", 9)
-        self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
+                self.assertEqual(
+                    second_id_gen.get_positions(), {"first": 3, "second": 7}
+                )
+
+        self.get_success(_get_next_async2())
+
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
+
+        # If the second ID gen gets told about the first, it correctly updates
+        second_id_gen.advance("first", 8)
+        self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+
+    def test_multi_instance_empty_row(self) -> None:
+        """Test that reads and writes from multiple processes are handled
+        correctly, when one of the writers starts without any rows.
+        """
+        # Insert some rows for two out of three of the ID gens.
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 4)
+
+        first_id_gen = self._create_id_generator(
+            "first", writers=["first", "second", "third"]
+        )
+        second_id_gen = self._create_id_generator(
+            "second", writers=["first", "second", "third"]
+        )
+        third_id_gen = self._create_id_generator(
+            "third", writers=["first", "second", "third"]
+        )
+
+        self.assertEqual(
+            first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+        )
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7)
+
+        self.assertEqual(
+            second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+        )
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7)
+
+        # Try allocating a new ID gen and check that we only see position
+        # advanced after we leave the context manager.
+
+        async def _get_next_async() -> None:
+            async with third_id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 8)
+
+                self.assertEqual(
+                    third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+                )
+                self.assertEqual(third_id_gen.get_persisted_upto_position(), 7)
+
+        self.get_success(_get_next_async())
+
+        self.assertEqual(
+            third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
+        )
 
     def test_writer_config_change(self) -> None:
         """Test that changing the writer config correctly works."""