diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
new file mode 100644
index 0000000000..63dfea4220
--- /dev/null
+++ b/synapse/storage/util/sequence.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+import threading
+from typing import Callable, Optional
+
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
+
+
+class SequenceGenerator(metaclass=abc.ABCMeta):
+ """A class which generates a unique sequence of integers"""
+
+ @abc.abstractmethod
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ """Gets the next ID in the sequence"""
+ ...
+
+
+class PostgresSequenceGenerator(SequenceGenerator):
+ """An implementation of SequenceGenerator which uses a postgres sequence"""
+
+ def __init__(self, sequence_name: str):
+ self._sequence_name = sequence_name
+
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ txn.execute("SELECT nextval(?)", (self._sequence_name,))
+ return txn.fetchone()[0]
+
+
+GetFirstCallbackType = Callable[[Cursor], int]
+
+
+class LocalSequenceGenerator(SequenceGenerator):
+ """An implementation of SequenceGenerator which uses local locking
+
+ This only works reliably if there are no other worker processes generating IDs at
+ the same time.
+ """
+
+ def __init__(self, get_first_callback: GetFirstCallbackType):
+ """
+ Args:
+ get_first_callback: a callback which is called on the first call to
+ get_next_id_txn; should return the curreent maximum id
+ """
+ # the callback. this is cleared after it is called, so that it can be GCed.
+ self._callback = get_first_callback # type: Optional[GetFirstCallbackType]
+
+ # The current max value, or None if we haven't looked in the DB yet.
+ self._current_max_id = None # type: Optional[int]
+ self._lock = threading.Lock()
+
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ # We do application locking here since if we're using sqlite then
+ # we are a single process synapse.
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ self._current_max_id += 1
+ return self._current_max_id
+
+
+def build_sequence_generator(
+ database_engine: BaseDatabaseEngine,
+ get_first_callback: GetFirstCallbackType,
+ sequence_name: str,
+) -> SequenceGenerator:
+ """Get the best impl of SequenceGenerator available
+
+ This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on
+ sqlite.
+
+ Args:
+ database_engine: the database engine we are connected to
+ get_first_callback: a callback which gets the next sequence ID. Used if
+ we're on sqlite.
+ sequence_name: the name of a postgres sequence to use.
+ """
+ if isinstance(database_engine, PostgresEngine):
+ return PostgresSequenceGenerator(sequence_name)
+ else:
+ return LocalSequenceGenerator(get_first_callback)
|