diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 133c0e7a28..39a3ab1162 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -17,7 +17,7 @@ import logging
import threading
from collections import deque
from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
import attr
from typing_extensions import Deque
@@ -186,11 +186,12 @@ class MultiWriterIdGenerator:
Args:
db_conn
db
- stream_name: A name for the stream.
+ stream_name: A name for the stream, for use in the `stream_positions`
+ table. (Does not need to be the same as the replication stream name)
instance_name: The name of this instance.
- table: Database table associated with stream.
- instance_column: Column that stores the row's writer's instance name
- id_column: Column that stores the stream ID.
+ tables: List of tables associated with the stream. Tuple of table
+ name, column name that stores the writer's instance name, and
+ column name that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
writers: A list of known writers to use to populate current positions
@@ -206,9 +207,7 @@ class MultiWriterIdGenerator:
db: DatabasePool,
stream_name: str,
instance_name: str,
- table: str,
- instance_column: str,
- id_column: str,
+ tables: List[Tuple[str, str, str]],
sequence_name: str,
writers: List[str],
positive: bool = True,
@@ -260,15 +259,16 @@ class MultiWriterIdGenerator:
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged.
- self._sequence_gen.check_consistency(
- db_conn, table=table, id_column=id_column, positive=positive
- )
+ for table, _, id_column in tables:
+ self._sequence_gen.check_consistency(
+ db_conn, table=table, id_column=id_column, positive=positive
+ )
# This goes and fills out the above state from the database.
- self._load_current_ids(db_conn, table, instance_column, id_column)
+ self._load_current_ids(db_conn, tables)
def _load_current_ids(
- self, db_conn, table: str, instance_column: str, id_column: str
+ self, db_conn, tables: List[Tuple[str, str, str]],
):
cur = db_conn.cursor(txn_name="_load_current_ids")
@@ -306,17 +306,22 @@ class MultiWriterIdGenerator:
# We add a GREATEST here to ensure that the result is always
# positive. (This can be a problem for e.g. backfill streams where
# the server has never backfilled).
- sql = """
- SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
- FROM %(table)s
- """ % {
- "id": id_column,
- "table": table,
- "agg": "MAX" if self._positive else "-MIN",
- }
- cur.execute(sql)
- (stream_id,) = cur.fetchone()
- self._persisted_upto_position = stream_id
+ max_stream_id = 1
+ for table, _, id_column in tables:
+ sql = """
+ SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+ FROM %(table)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "agg": "MAX" if self._positive else "-MIN",
+ }
+ cur.execute(sql)
+ (stream_id,) = cur.fetchone()
+
+ max_stream_id = max(max_stream_id, stream_id)
+
+ self._persisted_upto_position = max_stream_id
else:
# If we have a min_stream_id then we pull out everything greater
# than it from the DB so that we can prefill
@@ -329,21 +334,28 @@ class MultiWriterIdGenerator:
# stream positions table before restart (or the stream position
# table otherwise got out of date).
- sql = """
- SELECT %(instance)s, %(id)s FROM %(table)s
- WHERE ? %(cmp)s %(id)s
- """ % {
- "id": id_column,
- "table": table,
- "instance": instance_column,
- "cmp": "<=" if self._positive else ">=",
- }
- cur.execute(sql, (min_stream_id * self._return_factor,))
-
self._persisted_upto_position = min_stream_id
+ rows = []
+ for table, instance_column, id_column in tables:
+ sql = """
+ SELECT %(instance)s, %(id)s FROM %(table)s
+ WHERE ? %(cmp)s %(id)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "instance": instance_column,
+ "cmp": "<=" if self._positive else ">=",
+ }
+ cur.execute(sql, (min_stream_id * self._return_factor,))
+
+ rows.extend(cur)
+
+ # Sort so that we handle rows in order for each instance.
+ rows.sort()
+
with self._lock:
- for (instance, stream_id,) in cur:
+ for (instance, stream_id,) in rows:
stream_id = self._return_factor * stream_id
self._add_persisted_position(stream_id)
|