summary refs log tree commit diff
path: root/synapse/storage/util/sequence.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/util/sequence.py')
-rw-r--r--synapse/storage/util/sequence.py104
1 files changed, 104 insertions, 0 deletions
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
new file mode 100644
index 0000000000..ffc1894748
--- /dev/null
+++ b/synapse/storage/util/sequence.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+import threading
+from typing import Callable, List, Optional
+
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
+
+
+class SequenceGenerator(metaclass=abc.ABCMeta):
+    """A class which generates a unique sequence of integers"""
+
+    @abc.abstractmethod
+    def get_next_id_txn(self, txn: Cursor) -> int:
+        """Gets the next ID in the sequence"""
+        ...
+
+
+class PostgresSequenceGenerator(SequenceGenerator):
+    """An implementation of SequenceGenerator which uses a postgres sequence"""
+
+    def __init__(self, sequence_name: str):
+        self._sequence_name = sequence_name
+
+    def get_next_id_txn(self, txn: Cursor) -> int:
+        txn.execute("SELECT nextval(?)", (self._sequence_name,))
+        return txn.fetchone()[0]
+
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        txn.execute(
+            "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
+        )
+        return [i for (i,) in txn]
+
+
+GetFirstCallbackType = Callable[[Cursor], int]
+
+
+class LocalSequenceGenerator(SequenceGenerator):
+    """An implementation of SequenceGenerator which uses local locking
+
+    This only works reliably if there are no other worker processes generating IDs at
+    the same time.
+    """
+
+    def __init__(self, get_first_callback: GetFirstCallbackType):
+        """
+        Args:
+            get_first_callback: a callback which is called on the first call to
+                 get_next_id_txn; should return the curreent maximum id
+        """
+        # the callback. this is cleared after it is called, so that it can be GCed.
+        self._callback = get_first_callback  # type: Optional[GetFirstCallbackType]
+
+        # The current max value, or None if we haven't looked in the DB yet.
+        self._current_max_id = None  # type: Optional[int]
+        self._lock = threading.Lock()
+
+    def get_next_id_txn(self, txn: Cursor) -> int:
+        # We do application locking here since if we're using sqlite then
+        # we are a single process synapse.
+        with self._lock:
+            if self._current_max_id is None:
+                assert self._callback is not None
+                self._current_max_id = self._callback(txn)
+                self._callback = None
+
+            self._current_max_id += 1
+            return self._current_max_id
+
+
+def build_sequence_generator(
+    database_engine: BaseDatabaseEngine,
+    get_first_callback: GetFirstCallbackType,
+    sequence_name: str,
+) -> SequenceGenerator:
+    """Get the best impl of SequenceGenerator available
+
+    This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on
+    sqlite.
+
+    Args:
+        database_engine: the database engine we are connected to
+        get_first_callback: a callback which gets the next sequence ID. Used if
+            we're on sqlite.
+        sequence_name: the name of a postgres sequence to use.
+    """
+    if isinstance(database_engine, PostgresEngine):
+        return PostgresSequenceGenerator(sequence_name)
+    else:
+        return LocalSequenceGenerator(get_first_callback)