summary refs log tree commit diff
path: root/synapse/storage/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/util')
-rw-r--r--synapse/storage/util/id_generators.py8
-rw-r--r--synapse/storage/util/sequence.py98
2 files changed, 102 insertions, 4 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f89ce0bed2..787cebfbec 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -21,6 +21,7 @@ from typing import Dict, Set, Tuple
 from typing_extensions import Deque
 
 from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.util.sequence import PostgresSequenceGenerator
 
 
 class IdGenerator(object):
@@ -247,7 +248,6 @@ class MultiWriterIdGenerator:
     ):
         self._db = db
         self._instance_name = instance_name
-        self._sequence_name = sequence_name
 
         # We lock as some functions may be called from DB threads.
         self._lock = threading.Lock()
@@ -260,6 +260,8 @@ class MultiWriterIdGenerator:
         # should be less than the minimum of this set (if not empty).
         self._unfinished_ids = set()  # type: Set[int]
 
+        self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+
     def _load_current_ids(
         self, db_conn, table: str, instance_column: str, id_column: str
     ) -> Dict[str, int]:
@@ -283,9 +285,7 @@ class MultiWriterIdGenerator:
         return current_positions
 
     def _load_next_id_txn(self, txn):
-        txn.execute("SELECT nextval(?)", (self._sequence_name,))
-        (next_id,) = txn.fetchone()
-        return next_id
+        return self._sequence_gen.get_next_id_txn(txn)
 
     async def get_next(self):
         """
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
new file mode 100644
index 0000000000..63dfea4220
--- /dev/null
+++ b/synapse/storage/util/sequence.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+import threading
+from typing import Callable, Optional
+
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
+
+
+class SequenceGenerator(metaclass=abc.ABCMeta):
+    """A class which generates a unique sequence of integers"""
+
+    @abc.abstractmethod
+    def get_next_id_txn(self, txn: Cursor) -> int:
+        """Gets the next ID in the sequence"""
+        ...
+
+
+class PostgresSequenceGenerator(SequenceGenerator):
+    """An implementation of SequenceGenerator which uses a postgres sequence"""
+
+    def __init__(self, sequence_name: str):
+        self._sequence_name = sequence_name
+
+    def get_next_id_txn(self, txn: Cursor) -> int:
+        txn.execute("SELECT nextval(?)", (self._sequence_name,))
+        return txn.fetchone()[0]
+
+
+GetFirstCallbackType = Callable[[Cursor], int]
+
+
+class LocalSequenceGenerator(SequenceGenerator):
+    """An implementation of SequenceGenerator which uses local locking
+
+    This only works reliably if there are no other worker processes generating IDs at
+    the same time.
+    """
+
+    def __init__(self, get_first_callback: GetFirstCallbackType):
+        """
+        Args:
+            get_first_callback: a callback which is called on the first call to
+                 get_next_id_txn; should return the curreent maximum id
+        """
+        # the callback. this is cleared after it is called, so that it can be GCed.
+        self._callback = get_first_callback  # type: Optional[GetFirstCallbackType]
+
+        # The current max value, or None if we haven't looked in the DB yet.
+        self._current_max_id = None  # type: Optional[int]
+        self._lock = threading.Lock()
+
+    def get_next_id_txn(self, txn: Cursor) -> int:
+        # We do application locking here since if we're using sqlite then
+        # we are a single process synapse.
+        with self._lock:
+            if self._current_max_id is None:
+                assert self._callback is not None
+                self._current_max_id = self._callback(txn)
+                self._callback = None
+
+            self._current_max_id += 1
+            return self._current_max_id
+
+
+def build_sequence_generator(
+    database_engine: BaseDatabaseEngine,
+    get_first_callback: GetFirstCallbackType,
+    sequence_name: str,
+) -> SequenceGenerator:
+    """Get the best impl of SequenceGenerator available
+
+    This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on
+    sqlite.
+
+    Args:
+        database_engine: the database engine we are connected to
+        get_first_callback: a callback which gets the next sequence ID. Used if
+            we're on sqlite.
+        sequence_name: the name of a postgres sequence to use.
+    """
+    if isinstance(database_engine, PostgresEngine):
+        return PostgresSequenceGenerator(sequence_name)
+    else:
+        return LocalSequenceGenerator(get_first_callback)