diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index bb84c0d792..71ef5a72dc 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -15,12 +15,11 @@
import heapq
import logging
import threading
-from collections import deque
+from collections import OrderedDict
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple, Union
import attr
-from typing_extensions import Deque
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -101,7 +100,13 @@ class StreamIdGenerator:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
- self._unfinished_ids = deque() # type: Deque[int]
+
+ # We use this as an ordered set, as we want to efficiently append items,
+ # remove items and get the first item. Since we insert IDs in order, the
+ # insertion ordering will ensure its in the correct ordering.
+ #
+ # The key and values are the same, but we never look at the values.
+ self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int]
def get_next(self):
"""
@@ -113,7 +118,7 @@ class StreamIdGenerator:
self._current += self._step
next_id = self._current
- self._unfinished_ids.append(next_id)
+ self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@@ -121,7 +126,7 @@ class StreamIdGenerator:
yield next_id
finally:
with self._lock:
- self._unfinished_ids.remove(next_id)
+ self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -140,7 +145,7 @@ class StreamIdGenerator:
self._current += n * self._step
for next_id in next_ids:
- self._unfinished_ids.append(next_id)
+ self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@@ -149,7 +154,7 @@ class StreamIdGenerator:
finally:
with self._lock:
for next_id in next_ids:
- self._unfinished_ids.remove(next_id)
+ self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -162,7 +167,7 @@ class StreamIdGenerator:
"""
with self._lock:
if self._unfinished_ids:
- return self._unfinished_ids[0] - self._step
+ return next(iter(self._unfinished_ids)) - self._step
return self._current
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index c780ade077..0ec4dc2918 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -70,6 +70,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
...
@abc.abstractmethod
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ """Get the next `n` IDs in the sequence"""
+ ...
+
+ @abc.abstractmethod
def check_consistency(
self,
db_conn: "LoggingDatabaseConnection",
@@ -219,6 +224,17 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1
return self._current_max_id
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ first_id = self._current_max_id + 1
+ self._current_max_id += n
+ return [first_id + i for i in range(n)]
+
def check_consistency(
self,
db_conn: Connection,
|