diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/storage/util/id_generators.py | 39 |
1 files changed, 28 insertions, 11 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index b27a4843d0..9f3d23f0a5 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -185,6 +185,8 @@ class MultiWriterIdGenerator: id_column: Column that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. + positive: Whether the IDs are positive (true) or negative (false). + When using negative IDs we go backwards from -1 to -2, -3, etc. """ def __init__( @@ -196,13 +198,19 @@ class MultiWriterIdGenerator: instance_column: str, id_column: str, sequence_name: str, + positive: bool = True, ): self._db = db self._instance_name = instance_name + self._positive = positive + self._return_factor = 1 if positive else -1 # We lock as some functions may be called from DB threads. self._lock = threading.Lock() + # Note: If we are a negative stream then we still store all the IDs as + # positive to make life easier for us, and simply negate the IDs when we + # return them. self._current_positions = self._load_current_ids( db_conn, table, instance_column, id_column ) @@ -233,13 +241,16 @@ class MultiWriterIdGenerator: def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str ) -> Dict[str, int]: + # If positive stream aggregate via MAX. For negative stream use MIN + # *and* negate the result to get a positive number. sql = """ - SELECT %(instance)s, MAX(%(id)s) FROM %(table)s + SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s GROUP BY %(instance)s """ % { "instance": instance_column, "id": id_column, "table": table, + "agg": "MAX" if self._positive else "-MIN", } cur = db_conn.cursor() @@ -269,15 +280,16 @@ class MultiWriterIdGenerator: # Assert the fetched ID is actually greater than what we currently # believe the ID to be. If not, then the sequence and table have got # out of sync somehow. - assert self.get_current_token_for_writer(self._instance_name) < next_id - with self._lock: + assert self._current_positions.get(self._instance_name, 0) < next_id + self._unfinished_ids.add(next_id) @contextlib.contextmanager def manager(): try: - yield next_id + # Multiply by the return factor so that the ID has correct sign. + yield self._return_factor * next_id finally: self._mark_id_as_finished(next_id) @@ -296,15 +308,15 @@ class MultiWriterIdGenerator: # Assert the fetched ID is actually greater than any ID we've already # seen. If not, then the sequence and table have got out of sync # somehow. - assert max(self.get_positions().values(), default=0) < min(next_ids) - with self._lock: + assert max(self._current_positions.values(), default=0) < min(next_ids) + self._unfinished_ids.update(next_ids) @contextlib.contextmanager def manager(): try: - yield next_ids + yield [self._return_factor * i for i in next_ids] finally: for i in next_ids: self._mark_id_as_finished(i) @@ -327,7 +339,7 @@ class MultiWriterIdGenerator: txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) - return next_id + return self._return_factor * next_id def _mark_id_as_finished(self, next_id: int): """The ID has finished being processed so we should advance the @@ -359,20 +371,25 @@ class MultiWriterIdGenerator: """ with self._lock: - return self._current_positions.get(instance_name, 0) + return self._return_factor * self._current_positions.get(instance_name, 0) def get_positions(self) -> Dict[str, int]: """Get a copy of the current positon map. """ with self._lock: - return dict(self._current_positions) + return { + name: self._return_factor * i + for name, i in self._current_positions.items() + } def advance(self, instance_name: str, new_id: int): """Advance the postion of the named writer to the given ID, if greater than existing entry. """ + new_id *= self._return_factor + with self._lock: self._current_positions[instance_name] = max( new_id, self._current_positions.get(instance_name, 0) @@ -390,7 +407,7 @@ class MultiWriterIdGenerator: """ with self._lock: - return self._persisted_upto_position + return self._return_factor * self._persisted_upto_position def _add_persisted_position(self, new_id: int): """Record that we have persisted a position. |