diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 2cd3383f71..c6660fe687 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -218,11 +218,7 @@ class Auth:
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
- expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
- if (
- expiration_ts is not None
- and self.clock.time_msec() >= expiration_ts
- ):
+ if await self.store.is_account_expired(user_id, self.clock.time_msec()):
raise AuthError(
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7f9ae94840..dfd414d886 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1185,54 +1185,7 @@ class EventCreationHandler:
)
for room_id in room_ids:
- # For each room we need to find a joined member we can use to send
- # the dummy event with.
-
- latest_event_ids = await self.store.get_prev_events_for_room(room_id)
-
- members = await self.state.get_current_users_in_room(
- room_id, latest_event_ids=latest_event_ids
- )
- dummy_event_sent = False
- for user_id in members:
- if not self.hs.is_mine_id(user_id):
- continue
- requester = create_requester(user_id)
- try:
- event, context = await self.create_event(
- requester,
- {
- "type": "org.matrix.dummy_event",
- "content": {},
- "room_id": room_id,
- "sender": user_id,
- },
- prev_event_ids=latest_event_ids,
- )
-
- event.internal_metadata.proactively_send = False
-
- # Since this is a dummy-event it is OK if it is sent by a
- # shadow-banned user.
- await self.send_nonmember_event(
- requester,
- event,
- context,
- ratelimit=False,
- ignore_shadow_ban=True,
- )
- dummy_event_sent = True
- break
- except ConsentNotGivenError:
- logger.info(
- "Failed to send dummy event into room %s for user %s due to "
- "lack of consent. Will try another user" % (room_id, user_id)
- )
- except AuthError:
- logger.info(
- "Failed to send dummy event into room %s for user %s due to "
- "lack of power. Will try another user" % (room_id, user_id)
- )
+ dummy_event_sent = await self._send_dummy_event_for_room(room_id)
if not dummy_event_sent:
# Did not find a valid user in the room, so remove from future attempts
@@ -1245,6 +1198,59 @@ class EventCreationHandler:
now = self.clock.time_msec()
self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now
+ async def _send_dummy_event_for_room(self, room_id: str) -> bool:
+ """Attempt to send a dummy event for the given room.
+
+ Args:
+ room_id: room to try to send an event from
+
+ Returns:
+ True if a dummy event was successfully sent. False if no user was able
+ to send an event.
+ """
+
+ # For each room we need to find a joined member we can use to send
+ # the dummy event with.
+ latest_event_ids = await self.store.get_prev_events_for_room(room_id)
+ members = await self.state.get_current_users_in_room(
+ room_id, latest_event_ids=latest_event_ids
+ )
+ for user_id in members:
+ if not self.hs.is_mine_id(user_id):
+ continue
+ requester = create_requester(user_id)
+ try:
+ event, context = await self.create_event(
+ requester,
+ {
+ "type": "org.matrix.dummy_event",
+ "content": {},
+ "room_id": room_id,
+ "sender": user_id,
+ },
+ prev_event_ids=latest_event_ids,
+ )
+
+ event.internal_metadata.proactively_send = False
+
+ # Since this is a dummy-event it is OK if it is sent by a
+ # shadow-banned user.
+ await self.send_nonmember_event(
+ requester, event, context, ratelimit=False, ignore_shadow_ban=True,
+ )
+ return True
+ except ConsentNotGivenError:
+ logger.info(
+ "Failed to send dummy event into room %s for user %s due to "
+ "lack of consent. Will try another user" % (room_id, user_id)
+ )
+ except AuthError:
+ logger.info(
+ "Failed to send dummy event into room %s for user %s due to "
+ "lack of power. Will try another user" % (room_id, user_id)
+ )
+ return False
+
def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
to_expire = set()
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index cc839ffce4..76150e117b 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -60,6 +60,8 @@ class PusherPool:
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
+ self._account_validity = hs.config.account_validity
+
# We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name()
@@ -202,6 +204,14 @@ class PusherPool:
)
for u in users_affected:
+ # Don't push if the user account has expired
+ if self._account_validity.enabled:
+ expired = await self.store.is_account_expired(
+ u, self.clock.time_msec()
+ )
+ if expired:
+ continue
+
if u in self.pushers:
for p in self.pushers[u].values():
p.on_new_notifications(max_stream_id)
@@ -222,6 +232,14 @@ class PusherPool:
)
for u in users_affected:
+ # Don't push if the user account has expired
+ if self._account_validity.enabled:
+ expired = await self.store.is_account_expired(
+ u, self.clock.time_msec()
+ )
+ if expired:
+ continue
+
if u in self.pushers:
for p in self.pushers[u].values():
p.on_new_receipts(min_stream_id, max_stream_id)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 67f019fd22..288631477e 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -37,6 +37,9 @@ logger = logging.getLogger(__name__)
# installed when that optional dependency requirement is specified. It is passed
# to setup() as extras_require in setup.py
#
+# Note that these both represent runtime dependencies (and the versions
+# installed are checked at runtime).
+#
# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
REQUIREMENTS = [
@@ -92,20 +95,12 @@ CONDITIONAL_REQUIREMENTS = {
"oidc": ["authlib>=0.14.0"],
"systemd": ["systemd-python>=231"],
"url_preview": ["lxml>=3.5.0"],
- # Dependencies which are exclusively required by unit test code. This is
- # NOT a list of all modules that are necessary to run the unit tests.
- # Tests assume that all optional dependencies are installed.
- #
- # parameterized_class decorator was introduced in parameterized 0.7.0
- "test": ["mock>=2.0", "parameterized>=0.7.0"],
"sentry": ["sentry-sdk>=0.7.2"],
"opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"],
"jwt": ["pyjwt>=1.6.4"],
# hiredis is not a *strict* dependency, but it makes things much faster.
# (if it is not installed, we fall back to slow code.)
"redis": ["txredisapi>=1.4.7", "hiredis"],
- # We pin black so that our tests don't start failing on new releases.
- "lint": ["isort==5.0.3", "black==19.10b0", "flake8-comprehensions", "flake8"],
}
ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
@@ -113,7 +108,7 @@ ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
# Exclude systemd as it's a system-based requirement.
# Exclude lint as it's a dev-based requirement.
- if name not in ["systemd", "lint"]:
+ if name not in ["systemd"]:
ALL_OPTIONAL_REQUIREMENTS = set(optional_deps) | ALL_OPTIONAL_REQUIREMENTS
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 56d6afb863..5a5ea39e01 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -25,7 +25,6 @@ from typing import (
Sequence,
Set,
Union,
- cast,
overload,
)
@@ -42,7 +41,7 @@ from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
-from synapse.types import Collection, MutableStateMap, StateMap
+from synapse.types import Collection, StateMap
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -472,10 +471,9 @@ class StateResolutionHandler:
def __init__(self, hs):
self.clock = hs.get_clock()
- # dict of set of event_ids -> _StateCacheEntry.
- self._state_cache = None
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+ # dict of set of event_ids -> _StateCacheEntry.
self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
@@ -519,57 +517,28 @@ class StateResolutionHandler:
Returns:
The resolved state
"""
- logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
-
group_names = frozenset(state_groups_ids.keys())
with (await self.resolve_linearizer.queue(group_names)):
- if self._state_cache is not None:
- cache = self._state_cache.get(group_names, None)
- if cache:
- return cache
+ cache = self._state_cache.get(group_names, None)
+ if cache:
+ return cache
logger.info(
- "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
+ "Resolving state for %s with groups %s", room_id, list(group_names),
)
state_groups_histogram.observe(len(state_groups_ids))
- # start by assuming we won't have any conflicted state, and build up the new
- # state map by iterating through the state groups. If we discover a conflict,
- # we give up and instead use `resolve_events_with_store`.
- #
- # XXX: is this actually worthwhile, or should we just let
- # resolve_events_with_store do it?
- new_state = {} # type: MutableStateMap[str]
- conflicted_state = False
- for st in state_groups_ids.values():
- for key, e_id in st.items():
- if key in new_state:
- conflicted_state = True
- break
- new_state[key] = e_id
- if conflicted_state:
- break
-
- if conflicted_state:
- logger.info("Resolving conflicted state for %r", room_id)
- with Measure(self.clock, "state._resolve_events"):
- # resolve_events_with_store returns a StateMap, but we can
- # treat it as a MutableStateMap as it is above. It isn't
- # actually mutated anymore (and is frozen in
- # _make_state_cache_entry below).
- new_state = cast(
- MutableStateMap,
- await resolve_events_with_store(
- self.clock,
- room_id,
- room_version,
- list(state_groups_ids.values()),
- event_map=event_map,
- state_res_store=state_res_store,
- ),
- )
+ with Measure(self.clock, "state._resolve_events"):
+ new_state = await resolve_events_with_store(
+ self.clock,
+ room_id,
+ room_version,
+ list(state_groups_ids.values()),
+ event_map=event_map,
+ state_res_store=state_res_store,
+ )
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
@@ -579,8 +548,7 @@ class StateResolutionHandler:
with Measure(self.clock, "state.create_group_ids"):
cache = _make_state_cache_entry(new_state, state_groups_ids)
- if self._state_cache is not None:
- self._state_cache[group_names] = cache
+ self._state_cache[group_names] = cache
return cache
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index c5a36990e4..ef81d73573 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -339,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
@@ -387,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index e71217a41f..d42faa3f1f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
- with await self._device_inbox_id_gen.get_next() as stream_id:
+ async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
txn, stream_id, local_messages_by_user_then_device
)
- with await self._device_inbox_id_gen.get_next() as stream_id:
+ async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index c04374e43d..fdf394c612 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
THe new stream ID.
"""
- with await self._device_list_id_gen.get_next() as stream_id:
+ async with self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -1093,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return
- with await self._device_list_id_gen.get_next_mult(
+ async with self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@@ -1108,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
context = get_active_span_text_map()
- with await self._device_list_id_gen.get_next_mult(
+ async with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c8df0bcb3f..22e1ed15d0 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key (dict): the key data
"""
- with await self._cross_signing_id_gen.get_next() as stream_id:
+ async with self._cross_signing_id_gen.get_next() as stream_id:
return await self.db_pool.runInteraction(
"add_e2e_cross_signing_key",
self._set_e2e_cross_signing_key_txn,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 9a80f419e3..18def01f50 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,7 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter
@@ -156,15 +156,15 @@ class PersistEventsStore:
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
if backfilled:
- stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
+ stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
- stream_ordering_manager = await self._stream_id_gen.get_next_mult(
+ stream_ordering_manager = self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
- with stream_ordering_manager as stream_orderings:
+ async with stream_ordering_manager as stream_orderings:
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
@@ -1108,6 +1108,10 @@ class PersistEventsStore:
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
+
+ def str_or_none(val: Any) -> Optional[str]:
+ return val if isinstance(val, str) else None
+
self.db_pool.simple_insert_many_txn(
txn,
table="room_memberships",
@@ -1118,8 +1122,8 @@ class PersistEventsStore:
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
- "display_name": event.content.get("displayname", None),
- "avatar_url": event.content.get("avatar_url", None),
+ "display_name": str_or_none(event.content.get("displayname")),
+ "avatar_url": str_or_none(event.content.get("avatar_url")),
}
for event in events
],
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index ccfbb2135e..7218191965 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- with await self._group_updates_id_gen.get_next() as next_id:
+ async with self._group_updates_id_gen.get_next() as next_id:
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index c9f655dfb7..dbbb99cb95 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
async def update_presence(self, presence_states):
- stream_ordering_manager = await self._presence_id_gen.get_next_mult(
+ stream_ordering_manager = self._presence_id_gen.get_next_mult(
len(presence_states)
)
- with stream_ordering_manager as stream_orderings:
+ async with stream_ordering_manager as stream_orderings:
await self.db_pool.runInteraction(
"update_presence",
self._update_presence_txn,
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index e20a16f907..711d5aa23d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
if before or after:
@@ -585,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
@@ -616,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore):
Raises:
NotFoundError if the rule does not exist.
"""
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
@@ -754,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index c388468273..df8609b97b 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
last_stream_ordering,
profile_tag="",
) -> None:
- with await self._pushers_id_gen.get_next() as stream_id:
+ async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
},
)
- with await self._pushers_id_gen.get_next() as stream_id:
+ async with self._pushers_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index f880b5e562..c79ddff680 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -524,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"insert_receipt_conv", graph_to_linear
)
- with await self._receipts_id_gen.get_next() as stream_id:
+ async with self._receipts_id_gen.get_next() as stream_id:
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 4d5aa13a89..a06451b7f0 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -116,6 +116,20 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_expiration_ts_for_user",
)
+ async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
+ """
+ Returns whether an user account is expired.
+
+ Args:
+ user_id: The user's ID
+ current_ts: The current timestamp
+
+ Returns:
+ Whether the user account has expired
+ """
+ expiration_ts = await self.get_expiration_ts_for_user(user_id)
+ return expiration_ts is not None and current_ts >= expiration_ts
+
async def set_account_validity_for_user(
self,
user_id: str,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index a3d506a8c9..8fab8de973 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1159,7 +1159,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"store_room_txn", store_room_txn, next_id
)
@@ -1226,7 +1226,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
@@ -1306,7 +1306,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 96ffe26cc9..9f120d3cb6 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 1de2b91587..b0353ac2dc 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -12,14 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import contextlib
import heapq
import logging
import threading
from collections import deque
-from typing import Dict, List, Set
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Set, Union
+import attr
from typing_extensions import Deque
from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -86,7 +86,7 @@ class StreamIdGenerator:
upwards, -1 to grow downwards.
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
@@ -101,10 +101,10 @@ class StreamIdGenerator:
)
self._unfinished_ids = deque() # type: Deque[int]
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -113,7 +113,7 @@ class StreamIdGenerator:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_id
@@ -121,12 +121,12 @@ class StreamIdGenerator:
with self._lock:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
- async def get_next_mult(self, n):
+ def get_next_mult(self, n):
"""
Usage:
- with await stream_id_gen.get_next(n) as stream_ids:
+ async with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -140,7 +140,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_ids
@@ -149,7 +149,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
@@ -282,59 +282,23 @@ class MultiWriterIdGenerator:
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
return self._sequence_gen.get_next_mult_txn(txn, n)
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
- next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
-
- # Assert the fetched ID is actually greater than what we currently
- # believe the ID to be. If not, then the sequence and table have got
- # out of sync somehow.
- with self._lock:
- assert self._current_positions.get(self._instance_name, 0) < next_id
-
- self._unfinished_ids.add(next_id)
-
- @contextlib.contextmanager
- def manager():
- try:
- # Multiply by the return factor so that the ID has correct sign.
- yield self._return_factor * next_id
- finally:
- self._mark_id_as_finished(next_id)
- return manager()
+ return _MultiWriterCtxManager(self)
- async def get_next_mult(self, n: int):
+ def get_next_mult(self, n: int):
"""
Usage:
- with await stream_id_gen.get_next_mult(5) as stream_ids:
+ async with stream_id_gen.get_next_mult(5) as stream_ids:
# ... persist events ...
"""
- next_ids = await self._db.runInteraction(
- "_load_next_mult_id", self._load_next_mult_id_txn, n
- )
- # Assert the fetched ID is actually greater than any ID we've already
- # seen. If not, then the sequence and table have got out of sync
- # somehow.
- with self._lock:
- assert max(self._current_positions.values(), default=0) < min(next_ids)
-
- self._unfinished_ids.update(next_ids)
-
- @contextlib.contextmanager
- def manager():
- try:
- yield [self._return_factor * i for i in next_ids]
- finally:
- for i in next_ids:
- self._mark_id_as_finished(i)
-
- return manager()
+ return _MultiWriterCtxManager(self, n)
def get_next_txn(self, txn: LoggingTransaction):
"""
@@ -482,3 +446,61 @@ class MultiWriterIdGenerator:
# There was a gap in seen positions, so there is nothing more to
# do.
break
+
+
+@attr.s(slots=True)
+class _AsyncCtxManagerWrapper:
+ """Helper class to convert a plain context manager to an async one.
+
+ This is mainly useful if you have a plain context manager but the interface
+ requires an async one.
+ """
+
+ inner = attr.ib()
+
+ async def __aenter__(self):
+ return self.inner.__enter__()
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return self.inner.__exit__(exc_type, exc, tb)
+
+
+@attr.s(slots=True)
+class _MultiWriterCtxManager:
+ """Async context manager returned by MultiWriterIdGenerator
+ """
+
+ id_gen = attr.ib(type=MultiWriterIdGenerator)
+ multiple_ids = attr.ib(type=Optional[int], default=None)
+ stream_ids = attr.ib(type=List[int], factory=list)
+
+ async def __aenter__(self) -> Union[int, List[int]]:
+ self.stream_ids = await self.id_gen._db.runInteraction(
+ "_load_next_mult_id",
+ self.id_gen._load_next_mult_id_txn,
+ self.multiple_ids or 1,
+ )
+
+ # Assert the fetched ID is actually greater than any ID we've already
+ # seen. If not, then the sequence and table have got out of sync
+ # somehow.
+ with self.id_gen._lock:
+ assert max(self.id_gen._current_positions.values(), default=0) < min(
+ self.stream_ids
+ )
+
+ self.id_gen._unfinished_ids.update(self.stream_ids)
+
+ if self.multiple_ids is None:
+ return self.stream_ids[0] * self.id_gen._return_factor
+ else:
+ return [i * self.id_gen._return_factor for i in self.stream_ids]
+
+ async def __aexit__(self, exc_type, exc, tb):
+ for i in self.stream_ids:
+ self.id_gen._mark_id_as_finished(i)
+
+ if exc_type is not None:
+ return False
+
+ return False
|