summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8161.misc1
-rw-r--r--synapse/storage/databases/main/account_data.py4
-rw-r--r--synapse/storage/databases/main/deviceinbox.py4
-rw-r--r--synapse/storage/databases/main/devices.py8
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py43
-rw-r--r--synapse/storage/databases/main/events.py4
-rw-r--r--synapse/storage/databases/main/group_server.py2
-rw-r--r--synapse/storage/databases/main/presence.py2
-rw-r--r--synapse/storage/databases/main/push_rule.py8
-rw-r--r--synapse/storage/databases/main/pusher.py4
-rw-r--r--synapse/storage/databases/main/receipts.py3
-rw-r--r--synapse/storage/databases/main/room.py6
-rw-r--r--synapse/storage/databases/main/tags.py4
-rw-r--r--synapse/storage/util/id_generators.py10
14 files changed, 54 insertions, 49 deletions
diff --git a/changelog.d/8161.misc b/changelog.d/8161.misc
new file mode 100644
index 0000000000..89ff274de3
--- /dev/null
+++ b/changelog.d/8161.misc
@@ -0,0 +1 @@
+Refactor `StreamIdGenerator` and `MultiWriterIdGenerator` to have the same interface.
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 82aac2bbf3..04042a2c98 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as room_account_data has a unique constraint
             # on (user_id, room_id, account_data_type) so simple_upsert will
             # retry if there is a conflict.
@@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as account_data has a unique constraint on
             # (user_id, account_data_type) so simple_upsert will retry if
             # there is a conflict.
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1f6e995c4f..bb85637a95 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 rows.append((destination, stream_id, now_ms, edu_json))
             txn.executemany(sql, rows)
 
-        with self._device_inbox_id_gen.get_next() as stream_id:
+        with await self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 txn, stream_id, local_messages_by_user_then_device
             )
 
-        with self._device_inbox_id_gen.get_next() as stream_id:
+        with await self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9a786e2929..03b45dbc4d 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore):
             THe new stream ID.
         """
 
-        with self._device_list_id_gen.get_next() as stream_id:
+        with await self._device_list_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
@@ -1146,7 +1146,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         if not device_ids:
             return
 
-        with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
+        with await self._device_list_id_gen.get_next_mult(
+            len(device_ids)
+        ) as stream_ids:
             await self.db_pool.runInteraction(
                 "add_device_change_to_stream",
                 self._add_device_change_to_stream_txn,
@@ -1159,7 +1161,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return stream_ids[-1]
 
         context = get_active_span_text_map()
-        with self._device_list_id_gen.get_next_mult(
+        with await self._device_list_id_gen.get_next_mult(
             len(hosts) * len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f93e0d320d..385868bdab 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
         )
 
-    def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
+    def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
         """Set a user's cross-signing key.
 
         Args:
@@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 for a master key, 'self_signing' for a self-signing key, or
                 'user_signing' for a user-signing key
             key (dict): the key data
+            stream_id (int)
         """
         # the 'key' dict will look something like:
         # {
@@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             )
 
         # and finally, store the key itself
-        with self._cross_signing_id_gen.get_next() as stream_id:
-            self.db_pool.simple_insert_txn(
-                txn,
-                "e2e_cross_signing_keys",
-                values={
-                    "user_id": user_id,
-                    "keytype": key_type,
-                    "keydata": json_encoder.encode(key),
-                    "stream_id": stream_id,
-                },
-            )
+        self.db_pool.simple_insert_txn(
+            txn,
+            "e2e_cross_signing_keys",
+            values={
+                "user_id": user_id,
+                "keytype": key_type,
+                "keydata": json_encoder.encode(key),
+                "stream_id": stream_id,
+            },
+        )
 
         self._invalidate_cache_and_stream(
             txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
         )
 
-    def set_e2e_cross_signing_key(self, user_id, key_type, key):
+    async def set_e2e_cross_signing_key(self, user_id, key_type, key):
         """Set a user's cross-signing key.
 
         Args:
@@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             key_type (str): the type of cross-signing key to set
             key (dict): the key data
         """
-        return self.db_pool.runInteraction(
-            "add_e2e_cross_signing_key",
-            self._set_e2e_cross_signing_key_txn,
-            user_id,
-            key_type,
-            key,
-        )
+
+        with await self._cross_signing_id_gen.get_next() as stream_id:
+            return await self.db_pool.runInteraction(
+                "add_e2e_cross_signing_key",
+                self._set_e2e_cross_signing_key_txn,
+                user_id,
+                key_type,
+                key,
+                stream_id,
+            )
 
     def store_e2e_cross_signing_signatures(self, user_id, signatures):
         """Stores cross-signing signatures.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b90e6de2d5..6313b41eef 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -153,11 +153,11 @@ class PersistEventsStore:
         # Note: Multiple instances of this function cannot be in flight at
         # the same time for the same room.
         if backfilled:
-            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+            stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
         else:
-            stream_ordering_manager = self._stream_id_gen.get_next_mult(
+            stream_ordering_manager = await self._stream_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
 
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 0e3b8739c6..a488e0924b 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1182,7 +1182,7 @@ class GroupServerStore(GroupServerWorkerStore):
 
             return next_id
 
-        with self._group_updates_id_gen.get_next() as next_id:
+        with await self._group_updates_id_gen.get_next() as next_id:
             res = await self.db_pool.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 4e3ec02d14..c9f655dfb7 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter
 
 class PresenceStore(SQLBaseStore):
     async def update_presence(self, presence_states):
-        stream_ordering_manager = self._presence_id_gen.get_next_mult(
+        stream_ordering_manager = await self._presence_id_gen.get_next_mult(
             len(presence_states)
         )
 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index a585e54812..2fb5b02d7d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
     ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
-        with self._push_rules_stream_id_gen.get_next() as stream_id:
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             if before or after:
@@ -560,7 +560,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
             )
 
-        with self._push_rules_stream_id_gen.get_next() as stream_id:
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(
@@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore):
             )
 
     async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
-        with self._push_rules_stream_id_gen.get_next() as stream_id:
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(
@@ -646,7 +646,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 data={"actions": actions_json},
             )
 
-        with self._push_rules_stream_id_gen.get_next() as stream_id:
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 1126fd0751..c388468273 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
         last_stream_ordering,
         profile_tag="",
     ) -> None:
-        with self._pushers_id_gen.get_next() as stream_id:
+        with await self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
             # (app_id, pushkey, user_name) so simple_upsert will retry
             await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
                 },
             )
 
-        with self._pushers_id_gen.get_next() as stream_id:
+        with await self._pushers_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "delete_pusher", delete_pusher_txn, stream_id
             )
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 19ad1c056f..6821476ee0 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -520,8 +520,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "insert_receipt_conv", graph_to_linear
             )
 
-        stream_id_manager = self._receipts_id_gen.get_next()
-        with stream_id_manager as stream_id:
+        with await self._receipts_id_gen.get_next() as stream_id:
             event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 7d3ac47261..b3772be2b2 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1129,7 +1129,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                         },
                     )
 
-            with self._public_room_id_gen.get_next() as next_id:
+            with await self._public_room_id_gen.get_next() as next_id:
                 await self.db_pool.runInteraction(
                     "store_room_txn", store_room_txn, next_id
                 )
@@ -1196,7 +1196,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with self._public_room_id_gen.get_next() as next_id:
+        with await self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public", set_room_is_public_txn, next_id
             )
@@ -1276,7 +1276,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with self._public_room_id_gen.get_next() as next_id:
+        with await self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public_appservice",
                 set_room_is_public_appservice_txn,
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index ade7abc927..0c34bbf21a 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
             )
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
             txn.execute(sql, (user_id, room_id, tag))
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 0bf772d4d1..ddb5c8c60c 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -80,7 +80,7 @@ class StreamIdGenerator(object):
             upwards, -1 to grow downwards.
 
     Usage:
-        with stream_id_gen.get_next() as stream_id:
+        with await stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
 
@@ -95,10 +95,10 @@ class StreamIdGenerator(object):
             )
         self._unfinished_ids = deque()  # type: Deque[int]
 
-    def get_next(self):
+    async def get_next(self):
         """
         Usage:
-            with stream_id_gen.get_next() as stream_id:
+            with await stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
         with self._lock:
@@ -117,10 +117,10 @@ class StreamIdGenerator(object):
 
         return manager()
 
-    def get_next_mult(self, n):
+    async def get_next_mult(self, n):
         """
         Usage:
-            with stream_id_gen.get_next(n) as stream_ids:
+            with await stream_id_gen.get_next(n) as stream_ids:
                 # ... persist events ...
         """
         with self._lock: