diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ad017207aa..3d8da48f2d 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -55,7 +55,7 @@ def _load_current_id(db_conn, table, column, step=1):
"""
# debug logging for https://github.com/matrix-org/synapse/issues/7968
logger.info("initialising stream generator for %s(%s)", table, column)
- cur = db_conn.cursor()
+ cur = db_conn.cursor(txn_name="_load_current_id")
if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else:
@@ -270,7 +270,7 @@ class MultiWriterIdGenerator:
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
):
- cur = db_conn.cursor()
+ cur = db_conn.cursor(txn_name="_load_current_ids")
# Load the current positions of all writers for the stream.
if self._writers:
@@ -284,15 +284,12 @@ class MultiWriterIdGenerator:
stream_name = ?
AND instance_name != ALL(?)
"""
- sql = self._db.engine.convert_param_style(sql)
cur.execute(sql, (self._stream_name, self._writers))
sql = """
SELECT instance_name, stream_id FROM stream_positions
WHERE stream_name = ?
"""
- sql = self._db.engine.convert_param_style(sql)
-
cur.execute(sql, (self._stream_name,))
self._current_positions = {
@@ -341,7 +338,6 @@ class MultiWriterIdGenerator:
"instance": instance_column,
"cmp": "<=" if self._positive else ">=",
}
- sql = self._db.engine.convert_param_style(sql)
cur.execute(sql, (min_stream_id * self._return_factor,))
self._persisted_upto_position = min_stream_id
@@ -422,7 +418,7 @@ class MultiWriterIdGenerator:
self._unfinished_ids.discard(next_id)
self._finished_ids.add(next_id)
- new_cur = None
+ new_cur = None # type: Optional[int]
if self._unfinished_ids:
# If there are unfinished IDs then the new position will be the
@@ -528,6 +524,16 @@ class MultiWriterIdGenerator:
heapq.heappush(self._known_persisted_positions, new_id)
+ # If we're a writer and we don't have any active writes we update our
+ # current position to the latest position seen. This allows the instance
+ # to report a recent position when asked, rather than a potentially old
+ # one (if this instance hasn't written anything for a while).
+ our_current_position = self._current_positions.get(self._instance_name)
+ if our_current_position and not self._unfinished_ids:
+ self._current_positions[self._instance_name] = max(
+ our_current_position, new_id
+ )
+
# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# that that have been persisted).
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 2dd95e2709..4386b6101e 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -17,6 +17,7 @@ import logging
import threading
from typing import Callable, List, Optional
+from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import (
BaseDatabaseEngine,
IncorrectDatabaseSetup,
@@ -53,7 +54,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
@abc.abstractmethod
def check_consistency(
- self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+ self,
+ db_conn: LoggingDatabaseConnection,
+ table: str,
+ id_column: str,
+ positive: bool = True,
):
"""Should be called during start up to test that the current value of
the sequence is greater than or equal to the maximum ID in the table.
@@ -82,9 +87,13 @@ class PostgresSequenceGenerator(SequenceGenerator):
return [i for (i,) in txn]
def check_consistency(
- self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+ self,
+ db_conn: LoggingDatabaseConnection,
+ table: str,
+ id_column: str,
+ positive: bool = True,
):
- txn = db_conn.cursor()
+ txn = db_conn.cursor(txn_name="sequence.check_consistency")
# First we get the current max ID from the table.
table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % {
@@ -117,6 +126,8 @@ class PostgresSequenceGenerator(SequenceGenerator):
if max_stream_id > last_value:
logger.warning(
"Postgres sequence %s is behind table %s: %d < %d",
+ self._sequence_name,
+ table,
last_value,
max_stream_id,
)
|