summary refs log tree commit diff
path: root/synapse/storage/util
diff options
context:
space:
mode:
authorBen Banfield-Zanin <benbz@matrix.org>2021-03-01 10:06:09 +0000
committerBen Banfield-Zanin <benbz@matrix.org>2021-03-01 10:06:09 +0000
commitb26bee9faf957643cd34c4146b250b0009be205d (patch)
treea7a7e29f30acb437d010bdf6116c0f2729f21a1b /synapse/storage/util
parentMerge remote-tracking branch 'origin/release-v1.26.0' into toml/keycloak_hints (diff)
parentFixup changelog (diff)
downloadsynapse-toml/keycloak_hints.tar.xz
Merge remote-tracking branch 'origin/release-v1.28.0' into toml/keycloak_hints github/toml/keycloak_hints toml/keycloak_hints
Diffstat (limited to 'synapse/storage/util')
-rw-r--r--synapse/storage/util/id_generators.py45
-rw-r--r--synapse/storage/util/sequence.py27
2 files changed, 49 insertions, 23 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index bb84c0d792..d4643c4fdf 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -15,12 +15,11 @@
 import heapq
 import logging
 import threading
-from collections import deque
+from collections import OrderedDict
 from contextlib import contextmanager
 from typing import Dict, List, Optional, Set, Tuple, Union
 
 import attr
-from typing_extensions import Deque
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -101,7 +100,13 @@ class StreamIdGenerator:
             self._current = (max if step > 0 else min)(
                 self._current, _load_current_id(db_conn, table, column, step)
             )
-        self._unfinished_ids = deque()  # type: Deque[int]
+
+        # We use this as an ordered set, as we want to efficiently append items,
+        # remove items and get the first item. Since we insert IDs in order, the
+        # insertion ordering will ensure its in the correct ordering.
+        #
+        # The key and values are the same, but we never look at the values.
+        self._unfinished_ids = OrderedDict()  # type: OrderedDict[int, int]
 
     def get_next(self):
         """
@@ -113,7 +118,7 @@ class StreamIdGenerator:
             self._current += self._step
             next_id = self._current
 
-            self._unfinished_ids.append(next_id)
+            self._unfinished_ids[next_id] = next_id
 
         @contextmanager
         def manager():
@@ -121,7 +126,7 @@ class StreamIdGenerator:
                 yield next_id
             finally:
                 with self._lock:
-                    self._unfinished_ids.remove(next_id)
+                    self._unfinished_ids.pop(next_id)
 
         return _AsyncCtxManagerWrapper(manager())
 
@@ -140,7 +145,7 @@ class StreamIdGenerator:
             self._current += n * self._step
 
             for next_id in next_ids:
-                self._unfinished_ids.append(next_id)
+                self._unfinished_ids[next_id] = next_id
 
         @contextmanager
         def manager():
@@ -149,7 +154,7 @@ class StreamIdGenerator:
             finally:
                 with self._lock:
                     for next_id in next_ids:
-                        self._unfinished_ids.remove(next_id)
+                        self._unfinished_ids.pop(next_id)
 
         return _AsyncCtxManagerWrapper(manager())
 
@@ -162,7 +167,7 @@ class StreamIdGenerator:
         """
         with self._lock:
             if self._unfinished_ids:
-                return self._unfinished_ids[0] - self._step
+                return next(iter(self._unfinished_ids)) - self._step
 
             return self._current
 
@@ -240,7 +245,7 @@ class MultiWriterIdGenerator:
         # and b) noting that if we have seen a run of persisted positions
         # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
         #
-        # Note: There is no guarentee that the IDs generated by the sequence
+        # Note: There is no guarantee that the IDs generated by the sequence
         # will be gapless; gaps can form when e.g. a transaction was rolled
         # back. This means that sometimes we won't be able to skip forward the
         # position even though everything has been persisted. However, since
@@ -272,7 +277,9 @@ class MultiWriterIdGenerator:
         self._load_current_ids(db_conn, tables)
 
     def _load_current_ids(
-        self, db_conn, tables: List[Tuple[str, str, str]],
+        self,
+        db_conn,
+        tables: List[Tuple[str, str, str]],
     ):
         cur = db_conn.cursor(txn_name="_load_current_ids")
 
@@ -359,7 +366,10 @@ class MultiWriterIdGenerator:
             rows.sort()
 
             with self._lock:
-                for (instance, stream_id,) in rows:
+                for (
+                    instance,
+                    stream_id,
+                ) in rows:
                     stream_id = self._return_factor * stream_id
                     self._add_persisted_position(stream_id)
 
@@ -413,7 +423,7 @@ class MultiWriterIdGenerator:
         # bother, as nothing will read it).
         #
         # We only do this on the success path so that the persisted current
-        # position points to a persited row with the correct instance name.
+        # position points to a persisted row with the correct instance name.
         if self._writers:
             txn.call_after(
                 run_as_background_process,
@@ -476,8 +486,7 @@ class MultiWriterIdGenerator:
         return self.get_persisted_upto_position()
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer.
-        """
+        """Returns the position of the given writer."""
 
         # If we don't have an entry for the given instance name, we assume it's a
         # new writer.
@@ -504,7 +513,7 @@ class MultiWriterIdGenerator:
             }
 
     def advance(self, instance_name: str, new_id: int):
-        """Advance the postion of the named writer to the given ID, if greater
+        """Advance the position of the named writer to the given ID, if greater
         than existing entry.
         """
 
@@ -576,8 +585,7 @@ class MultiWriterIdGenerator:
                 break
 
     def _update_stream_positions_table_txn(self, txn: Cursor):
-        """Update the `stream_positions` table with newly persisted position.
-        """
+        """Update the `stream_positions` table with newly persisted position."""
 
         if not self._writers:
             return
@@ -617,8 +625,7 @@ class _AsyncCtxManagerWrapper:
 
 @attr.s(slots=True)
 class _MultiWriterCtxManager:
-    """Async context manager returned by MultiWriterIdGenerator
-    """
+    """Async context manager returned by MultiWriterIdGenerator"""
 
     id_gen = attr.ib(type=MultiWriterIdGenerator)
     multiple_ids = attr.ib(type=Optional[int], default=None)
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index c780ade077..3ea637b281 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -70,6 +70,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        """Get the next `n` IDs in the sequence"""
+        ...
+
+    @abc.abstractmethod
     def check_consistency(
         self,
         db_conn: "LoggingDatabaseConnection",
@@ -101,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
 
     def get_next_id_txn(self, txn: Cursor) -> int:
         txn.execute("SELECT nextval(?)", (self._sequence_name,))
-        return txn.fetchone()[0]
+        fetch_res = txn.fetchone()
+        assert fetch_res is not None
+        return fetch_res[0]
 
     def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
         txn.execute(
@@ -117,8 +124,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
         stream_name: Optional[str] = None,
         positive: bool = True,
     ):
-        """See SequenceGenerator.check_consistency for docstring.
-        """
+        """See SequenceGenerator.check_consistency for docstring."""
 
         txn = db_conn.cursor(txn_name="sequence.check_consistency")
 
@@ -142,7 +148,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
         txn.execute(
             "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
         )
-        last_value, is_called = txn.fetchone()
+        fetch_res = txn.fetchone()
+        assert fetch_res is not None
+        last_value, is_called = fetch_res
 
         # If we have an associated stream check the stream_positions table.
         max_in_stream_positions = None
@@ -219,6 +227,17 @@ class LocalSequenceGenerator(SequenceGenerator):
             self._current_max_id += 1
             return self._current_max_id
 
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        with self._lock:
+            if self._current_max_id is None:
+                assert self._callback is not None
+                self._current_max_id = self._callback(txn)
+                self._callback = None
+
+            first_id = self._current_max_id + 1
+            self._current_max_id += n
+            return [first_id + i for i in range(n)]
+
     def check_consistency(
         self,
         db_conn: Connection,