diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index c5a36990e4..ef81d73573 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -339,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
@@ -387,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index e71217a41f..d42faa3f1f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
- with await self._device_inbox_id_gen.get_next() as stream_id:
+ async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
txn, stream_id, local_messages_by_user_then_device
)
- with await self._device_inbox_id_gen.get_next() as stream_id:
+ async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index c04374e43d..fdf394c612 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
THe new stream ID.
"""
- with await self._device_list_id_gen.get_next() as stream_id:
+ async with self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -1093,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return
- with await self._device_list_id_gen.get_next_mult(
+ async with self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@@ -1108,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
context = get_active_span_text_map()
- with await self._device_list_id_gen.get_next_mult(
+ async with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c8df0bcb3f..22e1ed15d0 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key (dict): the key data
"""
- with await self._cross_signing_id_gen.get_next() as stream_id:
+ async with self._cross_signing_id_gen.get_next() as stream_id:
return await self.db_pool.runInteraction(
"add_e2e_cross_signing_key",
self._set_e2e_cross_signing_key_txn,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 9a80f419e3..7723d82496 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -156,15 +156,15 @@ class PersistEventsStore:
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
if backfilled:
- stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
+ stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
- stream_ordering_manager = await self._stream_id_gen.get_next_mult(
+ stream_ordering_manager = self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
- with stream_ordering_manager as stream_orderings:
+ async with stream_ordering_manager as stream_orderings:
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index ccfbb2135e..7218191965 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- with await self._group_updates_id_gen.get_next() as next_id:
+ async with self._group_updates_id_gen.get_next() as next_id:
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index c9f655dfb7..dbbb99cb95 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
async def update_presence(self, presence_states):
- stream_ordering_manager = await self._presence_id_gen.get_next_mult(
+ stream_ordering_manager = self._presence_id_gen.get_next_mult(
len(presence_states)
)
- with stream_ordering_manager as stream_orderings:
+ async with stream_ordering_manager as stream_orderings:
await self.db_pool.runInteraction(
"update_presence",
self._update_presence_txn,
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index e20a16f907..711d5aa23d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
if before or after:
@@ -585,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
@@ -616,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore):
Raises:
NotFoundError if the rule does not exist.
"""
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
@@ -754,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index c388468273..df8609b97b 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
last_stream_ordering,
profile_tag="",
) -> None:
- with await self._pushers_id_gen.get_next() as stream_id:
+ async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
},
)
- with await self._pushers_id_gen.get_next() as stream_id:
+ async with self._pushers_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index f880b5e562..c79ddff680 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -524,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"insert_receipt_conv", graph_to_linear
)
- with await self._receipts_id_gen.get_next() as stream_id:
+ async with self._receipts_id_gen.get_next() as stream_id:
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 3ee097abf7..3c7630857f 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1137,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"store_room_txn", store_room_txn, next_id
)
@@ -1204,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
@@ -1284,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 96ffe26cc9..9f120d3cb6 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 1de2b91587..b0353ac2dc 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -12,14 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import contextlib
import heapq
import logging
import threading
from collections import deque
-from typing import Dict, List, Set
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Set, Union
+import attr
from typing_extensions import Deque
from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -86,7 +86,7 @@ class StreamIdGenerator:
upwards, -1 to grow downwards.
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
@@ -101,10 +101,10 @@ class StreamIdGenerator:
)
self._unfinished_ids = deque() # type: Deque[int]
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -113,7 +113,7 @@ class StreamIdGenerator:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_id
@@ -121,12 +121,12 @@ class StreamIdGenerator:
with self._lock:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
- async def get_next_mult(self, n):
+ def get_next_mult(self, n):
"""
Usage:
- with await stream_id_gen.get_next(n) as stream_ids:
+ async with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -140,7 +140,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_ids
@@ -149,7 +149,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
@@ -282,59 +282,23 @@ class MultiWriterIdGenerator:
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
return self._sequence_gen.get_next_mult_txn(txn, n)
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
- next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
-
- # Assert the fetched ID is actually greater than what we currently
- # believe the ID to be. If not, then the sequence and table have got
- # out of sync somehow.
- with self._lock:
- assert self._current_positions.get(self._instance_name, 0) < next_id
-
- self._unfinished_ids.add(next_id)
-
- @contextlib.contextmanager
- def manager():
- try:
- # Multiply by the return factor so that the ID has correct sign.
- yield self._return_factor * next_id
- finally:
- self._mark_id_as_finished(next_id)
- return manager()
+ return _MultiWriterCtxManager(self)
- async def get_next_mult(self, n: int):
+ def get_next_mult(self, n: int):
"""
Usage:
- with await stream_id_gen.get_next_mult(5) as stream_ids:
+ async with stream_id_gen.get_next_mult(5) as stream_ids:
# ... persist events ...
"""
- next_ids = await self._db.runInteraction(
- "_load_next_mult_id", self._load_next_mult_id_txn, n
- )
- # Assert the fetched ID is actually greater than any ID we've already
- # seen. If not, then the sequence and table have got out of sync
- # somehow.
- with self._lock:
- assert max(self._current_positions.values(), default=0) < min(next_ids)
-
- self._unfinished_ids.update(next_ids)
-
- @contextlib.contextmanager
- def manager():
- try:
- yield [self._return_factor * i for i in next_ids]
- finally:
- for i in next_ids:
- self._mark_id_as_finished(i)
-
- return manager()
+ return _MultiWriterCtxManager(self, n)
def get_next_txn(self, txn: LoggingTransaction):
"""
@@ -482,3 +446,61 @@ class MultiWriterIdGenerator:
# There was a gap in seen positions, so there is nothing more to
# do.
break
+
+
+@attr.s(slots=True)
+class _AsyncCtxManagerWrapper:
+ """Helper class to convert a plain context manager to an async one.
+
+ This is mainly useful if you have a plain context manager but the interface
+ requires an async one.
+ """
+
+ inner = attr.ib()
+
+ async def __aenter__(self):
+ return self.inner.__enter__()
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return self.inner.__exit__(exc_type, exc, tb)
+
+
+@attr.s(slots=True)
+class _MultiWriterCtxManager:
+ """Async context manager returned by MultiWriterIdGenerator
+ """
+
+ id_gen = attr.ib(type=MultiWriterIdGenerator)
+ multiple_ids = attr.ib(type=Optional[int], default=None)
+ stream_ids = attr.ib(type=List[int], factory=list)
+
+ async def __aenter__(self) -> Union[int, List[int]]:
+ self.stream_ids = await self.id_gen._db.runInteraction(
+ "_load_next_mult_id",
+ self.id_gen._load_next_mult_id_txn,
+ self.multiple_ids or 1,
+ )
+
+ # Assert the fetched ID is actually greater than any ID we've already
+ # seen. If not, then the sequence and table have got out of sync
+ # somehow.
+ with self.id_gen._lock:
+ assert max(self.id_gen._current_positions.values(), default=0) < min(
+ self.stream_ids
+ )
+
+ self.id_gen._unfinished_ids.update(self.stream_ids)
+
+ if self.multiple_ids is None:
+ return self.stream_ids[0] * self.id_gen._return_factor
+ else:
+ return [i * self.id_gen._return_factor for i in self.stream_ids]
+
+ async def __aexit__(self, exc_type, exc, tb):
+ for i in self.stream_ids:
+ self.id_gen._mark_id_as_finished(i)
+
+ if exc_type is not None:
+ return False
+
+ return False
|