diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b27a4843d0..9f3d23f0a5 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -185,6 +185,8 @@ class MultiWriterIdGenerator:
id_column: Column that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
+ positive: Whether the IDs are positive (true) or negative (false).
+ When using negative IDs we go backwards from -1 to -2, -3, etc.
"""
def __init__(
@@ -196,13 +198,19 @@ class MultiWriterIdGenerator:
instance_column: str,
id_column: str,
sequence_name: str,
+ positive: bool = True,
):
self._db = db
self._instance_name = instance_name
+ self._positive = positive
+ self._return_factor = 1 if positive else -1
# We lock as some functions may be called from DB threads.
self._lock = threading.Lock()
+ # Note: If we are a negative stream then we still store all the IDs as
+ # positive to make life easier for us, and simply negate the IDs when we
+ # return them.
self._current_positions = self._load_current_ids(
db_conn, table, instance_column, id_column
)
@@ -233,13 +241,16 @@ class MultiWriterIdGenerator:
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
) -> Dict[str, int]:
+ # If positive stream aggregate via MAX. For negative stream use MIN
+ # *and* negate the result to get a positive number.
sql = """
- SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+ SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
GROUP BY %(instance)s
""" % {
"instance": instance_column,
"id": id_column,
"table": table,
+ "agg": "MAX" if self._positive else "-MIN",
}
cur = db_conn.cursor()
@@ -269,15 +280,16 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got
# out of sync somehow.
- assert self.get_current_token_for_writer(self._instance_name) < next_id
-
with self._lock:
+ assert self._current_positions.get(self._instance_name, 0) < next_id
+
self._unfinished_ids.add(next_id)
@contextlib.contextmanager
def manager():
try:
- yield next_id
+ # Multiply by the return factor so that the ID has correct sign.
+ yield self._return_factor * next_id
finally:
self._mark_id_as_finished(next_id)
@@ -296,15 +308,15 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than any ID we've already
# seen. If not, then the sequence and table have got out of sync
# somehow.
- assert max(self.get_positions().values(), default=0) < min(next_ids)
-
with self._lock:
+ assert max(self._current_positions.values(), default=0) < min(next_ids)
+
self._unfinished_ids.update(next_ids)
@contextlib.contextmanager
def manager():
try:
- yield next_ids
+ yield [self._return_factor * i for i in next_ids]
finally:
for i in next_ids:
self._mark_id_as_finished(i)
@@ -327,7 +339,7 @@ class MultiWriterIdGenerator:
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
- return next_id
+ return self._return_factor * next_id
def _mark_id_as_finished(self, next_id: int):
"""The ID has finished being processed so we should advance the
@@ -359,20 +371,25 @@ class MultiWriterIdGenerator:
"""
with self._lock:
- return self._current_positions.get(instance_name, 0)
+ return self._return_factor * self._current_positions.get(instance_name, 0)
def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map.
"""
with self._lock:
- return dict(self._current_positions)
+ return {
+ name: self._return_factor * i
+ for name, i in self._current_positions.items()
+ }
def advance(self, instance_name: str, new_id: int):
"""Advance the postion of the named writer to the given ID, if greater
than existing entry.
"""
+ new_id *= self._return_factor
+
with self._lock:
self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0)
@@ -390,7 +407,7 @@ class MultiWriterIdGenerator:
"""
with self._lock:
- return self._persisted_upto_position
+ return self._return_factor * self._persisted_upto_position
def _add_persisted_position(self, new_id: int):
"""Record that we have persisted a position.
|