diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f89ce0bed2..5b07847773 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -14,13 +14,15 @@
# limitations under the License.
import contextlib
+import heapq
import threading
from collections import deque
-from typing import Dict, Set, Tuple
+from typing import Dict, List, Set
from typing_extensions import Deque
-from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.util.sequence import PostgresSequenceGenerator
class IdGenerator(object):
@@ -79,7 +81,7 @@ class StreamIdGenerator(object):
upwards, -1 to grow downwards.
Usage:
- with stream_id_gen.get_next() as stream_id:
+ with await stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
@@ -94,10 +96,10 @@ class StreamIdGenerator(object):
)
self._unfinished_ids = deque() # type: Deque[int]
- def get_next(self):
+ async def get_next(self):
"""
Usage:
- with stream_id_gen.get_next() as stream_id:
+ with await stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -116,10 +118,10 @@ class StreamIdGenerator(object):
return manager()
- def get_next_mult(self, n):
+ async def get_next_mult(self, n):
"""
Usage:
- with stream_id_gen.get_next(n) as stream_ids:
+ with await stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -157,63 +159,13 @@ class StreamIdGenerator(object):
return self._current
+ def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
-class ChainedIdGenerator(object):
- """Used to generate new stream ids where the stream must be kept in sync
- with another stream. It generates pairs of IDs, the first element is an
- integer ID for this stream, the second element is the ID for the stream
- that this stream needs to be kept in sync with."""
-
- def __init__(self, chained_generator, db_conn, table, column):
- self.chained_generator = chained_generator
- self._table = table
- self._lock = threading.Lock()
- self._current_max = _load_current_id(db_conn, table, column)
- self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
-
- def get_next(self):
- """
- Usage:
- with stream_id_gen.get_next() as (stream_id, chained_id):
- # ... persist event ...
+ For streams with single writers this is equivalent to
+ `get_current_token`.
"""
- with self._lock:
- self._current_max += 1
- next_id = self._current_max
- chained_id = self.chained_generator.get_current_token()
-
- self._unfinished_ids.append((next_id, chained_id))
-
- @contextlib.contextmanager
- def manager():
- try:
- yield (next_id, chained_id)
- finally:
- with self._lock:
- self._unfinished_ids.remove((next_id, chained_id))
-
- return manager()
-
- def get_current_token(self):
- """Returns the maximum stream id such that all stream ids less than or
- equal to it have been successfully persisted.
- """
- with self._lock:
- if self._unfinished_ids:
- stream_id, chained_id = self._unfinished_ids[0]
- return stream_id - 1, chained_id
-
- return self._current_max, self.chained_generator.get_current_token()
-
- def advance(self, token: int):
- """Stub implementation for advancing the token when receiving updates
- over replication; raises an exception as this instance should be the
- only source of updates.
- """
-
- raise Exception(
- "Attempted to advance token on source for table %r", self._table
- )
+ return self.get_current_token()
class MultiWriterIdGenerator:
@@ -238,7 +190,7 @@ class MultiWriterIdGenerator:
def __init__(
self,
db_conn,
- db: Database,
+ db: DatabasePool,
instance_name: str,
table: str,
instance_column: str,
@@ -247,7 +199,6 @@ class MultiWriterIdGenerator:
):
self._db = db
self._instance_name = instance_name
- self._sequence_name = sequence_name
# We lock as some functions may be called from DB threads.
self._lock = threading.Lock()
@@ -260,6 +211,25 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
+ # We track the max position where we know everything before has been
+ # persisted. This is done by a) looking at the min across all instances
+ # and b) noting that if we have seen a run of persisted positions
+ # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
+ #
+ # Note: There is no guarantee that the IDs generated by the sequence
+ # will be gapless; gaps can form when e.g. a transaction was rolled
+ # back. This means that sometimes we won't be able to skip forward the
+ # position even though everything has been persisted. However, since
+ # gaps should be relatively rare, it's still worth doing the bookkeeping
+ # that allows us to skip forwards when there are gapless runs of
+ # positions.
+ self._persisted_upto_position = (
+ min(self._current_positions.values()) if self._current_positions else 0
+ )
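+ # Positions we have seen persisted but have not yet been able to fold
+ # into `_persisted_upto_position`; kept as a min-heap via `heapq`.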
+ self._known_persisted_positions = [] # type: List[int]
+
+ self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
) -> Dict[str, int]:
@@ -282,10 +252,11 @@ class MultiWriterIdGenerator:
return current_positions
- def _load_next_id_txn(self, txn):
- txn.execute("SELECT nextval(?)", (self._sequence_name,))
- (next_id,) = txn.fetchone()
- return next_id
+ def _load_next_id_txn(self, txn) -> int:
+ return self._sequence_gen.get_next_id_txn(txn)
+
+ def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
+ return self._sequence_gen.get_next_mult_txn(txn, n)
async def get_next(self):
"""
@@ -298,7 +269,7 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got
# out of sync somehow.
- assert self.get_current_token() < next_id
+ assert self.get_current_token_for_writer(self._instance_name) < next_id
with self._lock:
self._unfinished_ids.add(next_id)
@@ -312,6 +283,34 @@ class MultiWriterIdGenerator:
return manager()
+ async def get_next_mult(self, n: int):
+ """
+ Usage:
+ with await stream_id_gen.get_next_mult(5) as stream_ids:
+ # ... persist events ...
+ """
+ next_ids = await self._db.runInteraction(
+ "_load_next_mult_id", self._load_next_mult_id_txn, n
+ )
+
+ # Assert the fetched ID is actually greater than any ID we've already
+ # seen. If not, then the sequence and table have got out of sync
+ # somehow.
+ assert max(self.get_positions().values(), default=0) < min(next_ids)
+
+ with self._lock:
+ self._unfinished_ids.update(next_ids)
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield next_ids
+ finally:
+ for i in next_ids:
+ self._mark_id_as_finished(i)
+
+ return manager()
+
def get_next_txn(self, txn: LoggingTransaction):
"""
Usage:
@@ -344,16 +343,18 @@ class MultiWriterIdGenerator:
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, next_id)
- def get_current_token(self, instance_name: str = None) -> int:
- """Gets the current position of a named writer (defaults to current
- instance).
-
- Returns 0 if we don't have a position for the named writer (likely due
- to it being a new writer).
+ def get_current_token(self) -> int:
+ """Returns the maximum stream id such that all stream ids less than or
+ equal to it have been successfully persisted.
"""
- if instance_name is None:
- instance_name = self._instance_name
+ # Currently we don't support this operation, as it's not obvious how to
+ # condense the stream positions of multiple writers into a single int.
+ raise NotImplementedError()
+
+ def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
+ """
with self._lock:
return self._current_positions.get(instance_name, 0)
@@ -374,3 +375,53 @@ class MultiWriterIdGenerator:
self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0)
)
+
+ self._add_persisted_position(new_id)
+
+ def get_persisted_upto_position(self) -> int:
+ """Get the max position where all previous positions have been
+ persisted.
+
+ Note: In the worst case scenario this will be equal to the minimum
+ position across writers. This means that the returned position here can
+ lag if one writer doesn't write very often.
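+
+ For example (illustrative): if writer A is at position 10 and writer B is
+ at 5, this returns 10 only if every position from 6 to 10 has been seen
+ as persisted; if 6 is missing (e.g. a sequence gap), it returns 5.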
+ """
+
+ with self._lock:
+ return self._persisted_upto_position
+
+ def _add_persisted_position(self, new_id: int):
+ """Record that we have persisted a position.
+
+ This is used to keep `_persisted_upto_position` up to date.
+ """
+
+ # We require that the lock is already held by the caller
+ assert self._lock.locked()
+
+ heapq.heappush(self._known_persisted_positions, new_id)
+
+ # We move `_persisted_upto_position` up if the minimum current position
+ # across all instances is higher (since by definition all positions less
+ # than that have been persisted).
+ min_curr = min(self._current_positions.values())
+ self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
+
+ # We now iterate through the seen positions, discarding those that are
+ # less than the current min position, and incrementing the min position
+ # if it's exactly one greater.
+ #
+ # This is also where we discard items from `_known_persisted_positions`
+ # (to ensure the list doesn't infinitely grow).
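+ #
+ # For example (illustrative only): if `_persisted_upto_position` is 4 and
+ # the heap holds [5, 6, 8], the loop pops 5 and 6 (advancing the position
+ # to 6) and then stops, because 7 has not been seen yet.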
+ while self._known_persisted_positions:
+ if self._known_persisted_positions[0] <= self._persisted_upto_position:
+ heapq.heappop(self._known_persisted_positions)
+ elif (
+ self._known_persisted_positions[0] == self._persisted_upto_position + 1
+ ):
+ heapq.heappop(self._known_persisted_positions)
+ self._persisted_upto_position += 1
+ else:
+ # There was a gap in seen positions, so there is nothing more to
+ # do.
+ break
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
new file mode 100644
index 0000000000..ffc1894748
--- /dev/null
+++ b/synapse/storage/util/sequence.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+import threading
+from typing import Callable, List, Optional
+
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
+
+
+class SequenceGenerator(metaclass=abc.ABCMeta):
+ """A class which generates a unique sequence of integers"""
+
+ @abc.abstractmethod
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ """Gets the next ID in the sequence"""
+ ...
+
+
+class PostgresSequenceGenerator(SequenceGenerator):
+ """An implementation of SequenceGenerator which uses a postgres sequence"""
+
+ def __init__(self, sequence_name: str):
+ self._sequence_name = sequence_name
+
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ txn.execute("SELECT nextval(?)", (self._sequence_name,))
+ return txn.fetchone()[0]
+
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
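+ # `nextval` is volatile, so it is re-evaluated for each of the n rows
+ # produced by `generate_series`, yielding n fresh IDs in a single query.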
+ txn.execute(
+ "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
+ )
+ return [i for (i,) in txn]
+
+
+GetFirstCallbackType = Callable[[Cursor], int]
+
+
+class LocalSequenceGenerator(SequenceGenerator):
+ """An implementation of SequenceGenerator which uses local locking
+
+ This only works reliably if there are no other worker processes generating IDs at
+ the same time.
+ """
+
+ def __init__(self, get_first_callback: GetFirstCallbackType):
+ """
+ Args:
+ get_first_callback: a callback which is called on the first call to
+ get_next_id_txn; should return the current maximum id
+ """
+ # The callback. This is cleared after it is called, so that it can be GCed.
+ self._callback = get_first_callback # type: Optional[GetFirstCallbackType]
+
+ # The current max value, or None if we haven't looked in the DB yet.
+ self._current_max_id = None # type: Optional[int]
+ self._lock = threading.Lock()
+
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ # We do application-level locking here: if we're using SQLite then
+ # Synapse runs as a single process, so a threading lock is sufficient.
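+ #
+ # For example (illustrative): if the callback reports a current maximum
+ # of 10, successive calls return 11, 12, 13, ...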
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ self._current_max_id += 1
+ return self._current_max_id
+
+
+def build_sequence_generator(
+ database_engine: BaseDatabaseEngine,
+ get_first_callback: GetFirstCallbackType,
+ sequence_name: str,
+) -> SequenceGenerator:
+ """Get the best impl of SequenceGenerator available
+
+ This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on
+ sqlite.
+
+ Args:
+ database_engine: the database engine we are connected to
+ get_first_callback: a callback which returns the current maximum ID in
+ the table; only used if we're on SQLite.
+ sequence_name: the name of a postgres sequence to use.
+ """
+ if isinstance(database_engine, PostgresEngine):
+ return PostgresSequenceGenerator(sequence_name)
+ else:
+ return LocalSequenceGenerator(get_first_callback)
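+
+
+ # Example usage (illustrative only; the sequence name and callback below
+ # are hypothetical, not part of this module):
+ #
+ #   seq_gen = build_sequence_generator(
+ #       database_engine, lambda txn: 0, "my_stream_sequence"
+ #   )
+ #   next_id = seq_gen.get_next_id_txn(txn)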