summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/account_data.py10
-rw-r--r--synapse/storage/databases/main/deviceinbox.py10
-rw-r--r--synapse/storage/databases/main/devices.py3
-rw-r--r--synapse/storage/databases/main/events_worker.py4
-rw-r--r--synapse/storage/databases/main/presence.py10
-rw-r--r--synapse/storage/databases/main/push_rule.py3
-rw-r--r--synapse/storage/databases/main/receipts.py10
-rw-r--r--synapse/storage/databases/main/room.py11
-rw-r--r--synapse/storage/databases/main/stream.py3
-rw-r--r--synapse/storage/util/id_generators.py5
-rw-r--r--synapse/storage/util/sequence.py24
11 files changed, 65 insertions, 28 deletions
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 9611a84932..966393869b 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -43,10 +43,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import JsonDict, JsonMapping
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -71,7 +68,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             self._instance_name in hs.config.worker.writers.account_data
         )
 
-        self._account_data_id_gen: AbstractStreamIdGenerator
+        self._account_data_id_gen: MultiWriterIdGenerator
 
         self._account_data_id_gen = MultiWriterIdGenerator(
             db_conn=db_conn,
@@ -113,6 +110,9 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         """
         return self._account_data_id_gen.get_current_token()
 
+    def get_account_data_id_generator(self) -> MultiWriterIdGenerator:
+        return self._account_data_id_gen
+
     @cached()
     async def get_global_account_data_for_user(
         self, user_id: str
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 07333efff8..304ac42411 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -50,10 +50,7 @@ from synapse.storage.database import (
     LoggingTransaction,
     make_in_list_sql_clause,
 )
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -92,7 +89,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             self._instance_name in hs.config.worker.writers.to_device
         )
 
-        self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator(
+        self._to_device_msg_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator(
             db_conn=db_conn,
             db=database,
             notifier=hs.get_replication_notifier(),
@@ -169,6 +166,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     def get_to_device_stream_token(self) -> int:
         return self._to_device_msg_id_gen.get_current_token()
 
+    def get_to_device_id_generator(self) -> MultiWriterIdGenerator:
+        return self._to_device_msg_id_gen
+
     async def get_messages_for_user_devices(
         self,
         user_ids: Collection[str],
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 59a035dd62..53024bddc3 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -243,6 +243,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
+    def get_device_stream_id_generator(self) -> MultiWriterIdGenerator:
+        return self._device_list_id_gen
+
     async def count_devices_by_users(
         self, user_ids: Optional[Collection[str]] = None
     ) -> int:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e264d36f02..198e65cfa5 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -192,8 +192,8 @@ class EventsWorkerStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self._stream_id_gen: AbstractStreamIdGenerator
-        self._backfill_id_gen: AbstractStreamIdGenerator
+        self._stream_id_gen: MultiWriterIdGenerator
+        self._backfill_id_gen: MultiWriterIdGenerator
 
         self._stream_id_gen = MultiWriterIdGenerator(
             db_conn=db_conn,
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 923e764491..065c885603 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -42,10 +42,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines._base import IsolationLevel
 from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.iterutils import batch_iter
@@ -83,7 +80,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
-        self._presence_id_gen: AbstractStreamIdGenerator
+        self._presence_id_gen: MultiWriterIdGenerator
 
         self._can_persist_presence = (
             self._instance_name in hs.config.worker.writers.presence
@@ -455,6 +452,9 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
     def get_current_presence_token(self) -> int:
         return self._presence_id_gen.get_current_token()
 
+    def get_presence_stream_id_gen(self) -> MultiWriterIdGenerator:
+        return self._presence_id_gen
+
     def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
         """Fetch non-offline presence from the database so that we can register
         the appropriate time outs.
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 2a39dc9f90..bbdde17711 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -178,6 +178,9 @@ class PushRulesWorkerStore(
         """
         return self._push_rules_stream_id_gen.get_current_token()
 
+    def get_push_rules_stream_id_gen(self) -> MultiWriterIdGenerator:
+        return self._push_rules_stream_id_gen
+
     def process_replication_rows(
         self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
     ) -> None:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 8432560a89..3bde0ae0d4 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -45,10 +45,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines._base import IsolationLevel
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import (
     JsonDict,
     JsonMapping,
@@ -76,7 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._receipts_id_gen: AbstractStreamIdGenerator
+        self._receipts_id_gen: MultiWriterIdGenerator
 
         self._can_write_to_receipts = (
             self._instance_name in hs.config.worker.writers.receipts
@@ -136,6 +133,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
     def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
         return self._receipts_id_gen.get_current_token_for_writer(instance_name)
 
+    def get_receipts_stream_id_gen(self) -> MultiWriterIdGenerator:
+        return self._receipts_id_gen
+
     def get_last_unthreaded_receipt_for_user_txn(
         self,
         txn: LoggingTransaction,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index d5627b1d6e..80a4bf95f2 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -59,11 +59,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    IdGenerator,
-    MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
 from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -151,7 +147,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         self.config: HomeServerConfig = hs.config
 
-        self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
+        self._un_partial_stated_rooms_stream_id_gen: MultiWriterIdGenerator
 
         self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
             db_conn=db_conn,
@@ -1409,6 +1405,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             instance_name
         )
 
+    def get_un_partial_stated_rooms_id_generator(self) -> MultiWriterIdGenerator:
+        return self._un_partial_stated_rooms_stream_id_gen
+
     async def get_un_partial_stated_rooms_between(
         self, last_id: int, current_id: int, room_ids: Collection[str]
     ) -> Set[str]:
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index ff0d723684..b7eb3116ae 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -577,6 +577,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))
 
+    def get_events_stream_id_generator(self) -> MultiWriterIdGenerator:
+        return self._stream_id_gen
+
     async def get_room_events_stream_for_rooms(
         self,
         room_ids: Collection[str],
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 48f88a6f8a..e8588f33cf 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -812,6 +812,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         pos = self.get_current_token_for_writer(self._instance_name)
         txn.execute(sql, (self._stream_name, self._instance_name, pos))
 
+    async def get_max_allocated_token(self) -> int:
+        return await self._db.runInteraction(
+            "get_max_allocated_token", self._sequence_gen.get_max_allocated
+        )
+
 
 @attr.s(frozen=True, auto_attribs=True)
 class _AsyncCtxManagerWrapper(Generic[T]):
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index c4c0602b28..cac3eba1a5 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -88,6 +88,10 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         """
         ...
 
+    @abc.abstractmethod
+    def get_max_allocated(self, txn: Cursor) -> int:
+        """Get the maximum ID that we have allocated"""
+
 
 class PostgresSequenceGenerator(SequenceGenerator):
     """An implementation of SequenceGenerator which uses a postgres sequence"""
@@ -190,6 +194,17 @@ class PostgresSequenceGenerator(SequenceGenerator):
                 % {"seq": self._sequence_name, "stream_name": stream_name}
             )
 
+    def get_max_allocated(self, txn: Cursor) -> int:
+        # We just read from the sequence what the last value we fetched was.
+        txn.execute(f"SELECT last_value, is_called FROM {self._sequence_name}")
+        row = txn.fetchone()
+        assert row is not None
+
+        last_value, is_called = row
+        if not is_called:
+            last_value -= 1
+        return last_value
+
 
 GetFirstCallbackType = Callable[[Cursor], int]
 
@@ -248,6 +263,15 @@ class LocalSequenceGenerator(SequenceGenerator):
         # There is nothing to do for in memory sequences
         pass
 
+    def get_max_allocated(self, txn: Cursor) -> int:
+        with self._lock:
+            if self._current_max_id is None:
+                assert self._callback is not None
+                self._current_max_id = self._callback(txn)
+                self._callback = None
+
+            return self._current_max_id
+
 
 def build_sequence_generator(
     db_conn: "LoggingDatabaseConnection",