From 52ed9655edf8849d68a178e1c76040c79824a353 Mon Sep 17 00:00:00 2001 From: Dan Callahan Date: Fri, 14 May 2021 10:59:10 +0100 Subject: Remove unnecessary SystemRandom from SQLBaseStore (#9987) It's not obvious that instances of SQLBaseStore each need their own instances of random.SystemRandom(); let's just use random directly. Introduced by 52839886d664576831462e033b88e5aba4c019e3 Signed-off-by: Dan Callahan --- synapse/storage/databases/main/registration.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse/storage/databases/main') diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 6e5ee557d2..e5c5cf8ff0 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import random import re from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union @@ -997,7 +998,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): expiration_ts = now_ms + self._account_validity_period if use_delta: - expiration_ts = self.rand.randrange( + expiration_ts = random.randrange( expiration_ts - self._account_validity_startup_job_max_delta, expiration_ts, ) -- cgit 1.5.1 From 5090f26b636bf4439575767a2272d033fb33b2d5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 14 May 2021 11:12:36 +0100 Subject: Minor `@cachedList` enhancements (#9975) - use a tuple rather than a list for the iterable that is passed into the wrapped function, for performance - test that we can pass an iterable and that keys are correctly deduped. --- changelog.d/9975.misc | 1 + synapse/storage/databases/main/devices.py | 2 +- synapse/storage/databases/main/end_to_end_keys.py | 4 ++-- synapse/storage/databases/main/user_erasure_store.py | 13 +++++-------- synapse/util/caches/descriptors.py | 14 ++++++++------ tests/util/caches/test_descriptors.py | 17 ++++++++++++++--- 6 files changed, 31 insertions(+), 20 deletions(-) create mode 100644 changelog.d/9975.misc (limited to 'synapse/storage/databases/main') diff --git a/changelog.d/9975.misc b/changelog.d/9975.misc new file mode 100644 index 0000000000..28b1e40c2b --- /dev/null +++ b/changelog.d/9975.misc @@ -0,0 +1 @@ +Minor enhancements to the `@cachedList` descriptor. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index c9346de316..a1f98b7e38 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -665,7 +665,7 @@ class DeviceWorkerStore(SQLBaseStore): cached_method_name="get_device_list_last_stream_id_for_remote", list_name="user_ids", ) - async def get_device_list_last_stream_id_for_remotes(self, user_ids: str): + async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]): rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 398d6b6acb..9ba5778a88 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -473,7 +473,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): num_args=1, ) async def _get_bare_e2e_cross_signing_keys_bulk( - self, user_ids: List[str] + self, user_ids: Iterable[str] ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if @@ -497,7 +497,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): def _get_bare_e2e_cross_signing_keys_bulk_txn( self, txn: Connection, - user_ids: List[str], + user_ids: Iterable[str], ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index acf6b2fb64..1ecdd40c38 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Iterable + from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList @@ -37,21 +39,16 @@ class UserErasureWorkerStore(SQLBaseStore): return bool(result) @cachedList(cached_method_name="is_user_erased", list_name="user_ids") - async def are_users_erased(self, user_ids): + async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]: """ Checks which users in a list have requested erasure Args: - user_ids (iterable[str]): full user id to check + user_ids: full user ids to check Returns: - dict[str, bool]: - for each user, whether the user has requested erasure. + for each user, whether the user has requested erasure. """ - # this serves the dual purpose of (a) making sure we can do len and - # iterate it multiple times, and (b) avoiding duplicates. - user_ids = tuple(set(user_ids)) - rows = await self.db_pool.simple_select_many_batch( table="erased_users", column="user_id", diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index ac4a078b26..3a4d027095 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -322,8 +322,8 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): class DeferredCacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. - Given a list of keys it looks in the cache to find any hits, then passes - the list of missing keys to the wrapped function. + Given an iterable of keys it looks in the cache to find any hits, then passes + the tuple of missing keys to the wrapped function. Once wrapped, the function returns a Deferred which resolves to the list of results. @@ -437,7 +437,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): return f args_to_call = dict(arg_dict) - args_to_call[self.list_name] = list(missing) + # copy the missing set before sending it to the callee, to guard against + # modification. + args_to_call[self.list_name] = tuple(missing) cached_defers.append( defer.maybeDeferred( @@ -522,14 +524,14 @@ def cachedList( Used to do batch lookups for an already created cache. A single argument is specified as a list that is iterated through to lookup keys in the - original cache. A new list consisting of the keys that weren't in the cache - get passed to the original function, the result of which is stored in the + original cache. A new tuple consisting of the (deduplicated) keys that weren't in + the cache gets passed to the original function, the result of which is stored in the cache. Args: cached_method_name: The name of the single-item lookup method. This is only used to find the cache to use. - list_name: The name of the argument that is the list to use to + list_name: The name of the argument that is the iterable to use to do batch lookups in the cache. num_args: Number of arguments to use as the key in the cache (including list_name). Defaults to all named parameters. diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 178ac8a68c..bbbc276697 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -666,18 +666,20 @@ class CachedListDescriptorTestCase(unittest.TestCase): with LoggingContext("c1") as c1: obj = Cls() obj.mock.return_value = {10: "fish", 20: "chips"} + + # start the lookup off d1 = obj.list_fn([10, 20], 2) self.assertEqual(current_context(), SENTINEL_CONTEXT) r = yield d1 self.assertEqual(current_context(), c1) - obj.mock.assert_called_once_with([10, 20], 2) + obj.mock.assert_called_once_with((10, 20), 2) self.assertEqual(r, {10: "fish", 20: "chips"}) obj.mock.reset_mock() # a call with different params should call the mock again obj.mock.return_value = {30: "peas"} r = yield obj.list_fn([20, 30], 2) - obj.mock.assert_called_once_with([30], 2) + obj.mock.assert_called_once_with((30,), 2) self.assertEqual(r, {20: "chips", 30: "peas"}) obj.mock.reset_mock() @@ -692,6 +694,15 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj.mock.assert_not_called() self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"}) + # we should also be able to use a (single-use) iterable, and should + # deduplicate the keys + obj.mock.reset_mock() + obj.mock.return_value = {40: "gravy"} + iterable = (x for x in [10, 40, 40]) + r = yield obj.list_fn(iterable, 2) + obj.mock.assert_called_once_with((40,), 2) + self.assertEqual(r, {10: "fish", 40: "gravy"}) + @defer.inlineCallbacks def test_invalidate(self): """Make sure that invalidation callbacks are called.""" @@ -717,7 +728,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): # cache miss obj.mock.return_value = {10: "fish", 20: "chips"} r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0) - obj.mock.assert_called_once_with([10, 20], 2) + obj.mock.assert_called_once_with((10, 20), 2) self.assertEqual(r1, {10: "fish", 20: "chips"}) obj.mock.reset_mock() -- cgit 1.5.1 From 4d6e5a5e995590efe44855d10dcd2a89b841dae8 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 18 May 2021 14:13:45 +0100 Subject: Use a database table to hold the users that should have full presence sent to them, instead of something in-memory (#9823) --- changelog.d/9823.misc | 1 + docs/presence_router_module.md | 6 +- synapse/handlers/presence.py | 136 +++++++-- synapse/module_api/__init__.py | 63 ++--- synapse/replication/http/presence.py | 11 +- synapse/rest/admin/server_notice_servlet.py | 8 +- synapse/storage/databases/main/presence.py | 58 +++- .../delta/59/13users_to_send_full_presence_to.sql | 34 +++ tests/events/test_presence_router.py | 15 +- tests/module_api/test_api.py | 303 +++++++++++++++------ tests/replication/test_sharded_event_persister.py | 2 +- 11 files changed, 479 insertions(+), 158 deletions(-) create mode 100644 changelog.d/9823.misc create mode 100644 synapse/storage/schema/main/delta/59/13users_to_send_full_presence_to.sql (limited to 'synapse/storage/databases/main') diff --git a/changelog.d/9823.misc b/changelog.d/9823.misc new file mode 100644 index 0000000000..bf924ab68c --- /dev/null +++ b/changelog.d/9823.misc @@ -0,0 +1 @@ +Allow sending full presence to users via workers other than the one that called `ModuleApi.send_local_online_presence_to`. \ No newline at end of file diff --git a/docs/presence_router_module.md b/docs/presence_router_module.md index d6566d978d..d2844915df 100644 --- a/docs/presence_router_module.md +++ b/docs/presence_router_module.md @@ -28,7 +28,11 @@ async def ModuleApi.send_local_online_presence_to(users: Iterable[str]) -> None which can be given a list of local or remote MXIDs to broadcast known, online user presence to (for those users that the receiving user is considered interested in). It does not include state for users who are currently offline, and it can only be -called on workers that support sending federation. +called on workers that support sending federation. Additionally, this method must +only be called from the process that has been configured to write to the +the [presence stream](https://github.com/matrix-org/synapse/blob/master/docs/workers.md#stream-writers). +By default, this is the main process, but another worker can be configured to do +so. ### Module structure diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 6fd1f34289..f5a049d754 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -222,9 +222,21 @@ class BasePresenceHandler(abc.ABC): @abc.abstractmethod async def set_state( - self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False + self, + target_user: UserID, + state: JsonDict, + ignore_status_msg: bool = False, + force_notify: bool = False, ) -> None: - """Set the presence state of the user. """ + """Set the presence state of the user. + + Args: + target_user: The ID of the user to set the presence state of. + state: The presence state as a JSON dictionary. + ignore_status_msg: True to ignore the "status_msg" field of the `state` dict. + If False, the user's current status will be updated. + force_notify: Whether to force notification of the update to clients. + """ @abc.abstractmethod async def bump_presence_active_time(self, user: UserID): @@ -296,6 +308,51 @@ class BasePresenceHandler(abc.ABC): for destinations, states in hosts_and_states: self._federation.send_presence_to_destinations(states, destinations) + async def send_full_presence_to_users(self, user_ids: Collection[str]): + """ + Adds to the list of users who should receive a full snapshot of presence + upon their next sync. Note that this only works for local users. + + Then, grabs the current presence state for a given set of users and adds it + to the top of the presence stream. + + Args: + user_ids: The IDs of the local users to send full presence to. + """ + # Retrieve one of the users from the given set + if not user_ids: + raise Exception( + "send_full_presence_to_users must be called with at least one user" + ) + user_id = next(iter(user_ids)) + + # Mark all users as receiving full presence on their next sync + await self.store.add_users_to_send_full_presence_to(user_ids) + + # Add a new entry to the presence stream. Since we use stream tokens to determine whether a + # local user should receive a full snapshot of presence when they sync, we need to bump the + # presence stream so that subsequent syncs with no presence activity in between won't result + # in the client receiving multiple full snapshots of presence. + # + # If we bump the stream ID, then the user will get a higher stream token next sync, and thus + # correctly won't receive a second snapshot. + + # Get the current presence state for one of the users (defaults to offline if not found) + current_presence_state = await self.get_state(UserID.from_string(user_id)) + + # Convert the UserPresenceState object into a serializable dict + state = { + "presence": current_presence_state.state, + "status_message": current_presence_state.status_msg, + } + + # Copy the presence state to the tip of the presence stream. + + # We set force_notify=True here so that this presence update is guaranteed to + # increment the presence stream ID (which resending the current user's presence + # otherwise would not do). + await self.set_state(UserID.from_string(user_id), state, force_notify=True) + class _NullContextManager(ContextManager[None]): """A context manager which does nothing.""" @@ -480,8 +537,17 @@ class WorkerPresenceHandler(BasePresenceHandler): target_user: UserID, state: JsonDict, ignore_status_msg: bool = False, + force_notify: bool = False, ) -> None: - """Set the presence state of the user.""" + """Set the presence state of the user. + + Args: + target_user: The ID of the user to set the presence state of. + state: The presence state as a JSON dictionary. + ignore_status_msg: True to ignore the "status_msg" field of the `state` dict. + If False, the user's current status will be updated. + force_notify: Whether to force notification of the update to clients. + """ presence = state["presence"] valid_presence = ( @@ -508,6 +574,7 @@ class WorkerPresenceHandler(BasePresenceHandler): user_id=user_id, state=state, ignore_status_msg=ignore_status_msg, + force_notify=force_notify, ) async def bump_presence_active_time(self, user: UserID) -> None: @@ -677,13 +744,19 @@ class PresenceHandler(BasePresenceHandler): [self.user_to_current_state[user_id] for user_id in unpersisted] ) - async def _update_states(self, new_states: Iterable[UserPresenceState]) -> None: + async def _update_states( + self, new_states: Iterable[UserPresenceState], force_notify: bool = False + ) -> None: """Updates presence of users. Sets the appropriate timeouts. Pokes the notifier and federation if and only if the changed presence state should be sent to clients/servers. Args: new_states: The new user presence state updates to process. + force_notify: Whether to force notifying clients of this presence state update, + even if it doesn't change the state of a user's presence (e.g online -> online). + This is currently used to bump the max presence stream ID without changing any + user's presence (see PresenceHandler.add_users_to_send_full_presence_to). """ now = self.clock.time_msec() @@ -720,6 +793,9 @@ class PresenceHandler(BasePresenceHandler): now=now, ) + if force_notify: + should_notify = True + self.user_to_current_state[user_id] = new_state if should_notify: @@ -1058,9 +1134,21 @@ class PresenceHandler(BasePresenceHandler): await self._update_states(updates) async def set_state( - self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False + self, + target_user: UserID, + state: JsonDict, + ignore_status_msg: bool = False, + force_notify: bool = False, ) -> None: - """Set the presence state of the user.""" + """Set the presence state of the user. + + Args: + target_user: The ID of the user to set the presence state of. + state: The presence state as a JSON dictionary. + ignore_status_msg: True to ignore the "status_msg" field of the `state` dict. + If False, the user's current status will be updated. + force_notify: Whether to force notification of the update to clients. + """ status_msg = state.get("status_msg", None) presence = state["presence"] @@ -1091,7 +1179,9 @@ class PresenceHandler(BasePresenceHandler): ): new_fields["last_active_ts"] = self.clock.time_msec() - await self._update_states([prev_state.copy_and_replace(**new_fields)]) + await self._update_states( + [prev_state.copy_and_replace(**new_fields)], force_notify=force_notify + ) async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool: """Returns whether a user can see another user's presence.""" @@ -1389,11 +1479,10 @@ class PresenceEventSource: # # Presence -> Notifier -> PresenceEventSource -> Presence # - # Same with get_module_api, get_presence_router + # Same with get_presence_router: # # AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler self.get_presence_handler = hs.get_presence_handler - self.get_module_api = hs.get_module_api self.get_presence_router = hs.get_presence_router self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -1424,16 +1513,21 @@ class PresenceEventSource: stream_change_cache = self.store.presence_stream_cache with Measure(self.clock, "presence.get_new_events"): - if user_id in self.get_module_api()._send_full_presence_to_local_users: - # This user has been specified by a module to receive all current, online - # user presence. Removing from_key and setting include_offline to false - # will do effectively this. - from_key = None - include_offline = False - if from_key is not None: from_key = int(from_key) + # Check if this user should receive all current, online user presence. We only + # bother to do this if from_key is set, as otherwise the user will receive all + # user presence anyways. + if await self.store.should_user_receive_full_presence_with_token( + user_id, from_key + ): + # This user has been specified by a module to receive all current, online + # user presence. Removing from_key and setting include_offline to false + # will do effectively this. + from_key = None + include_offline = False + max_token = self.store.get_current_presence_token() if from_key == max_token: # This is necessary as due to the way stream ID generators work @@ -1467,12 +1561,6 @@ class PresenceEventSource: user_id, include_offline, from_key ) - # Remove the user from the list of users to receive all presence - if user_id in self.get_module_api()._send_full_presence_to_local_users: - self.get_module_api()._send_full_presence_to_local_users.remove( - user_id - ) - return presence_updates, max_token # Make mypy happy. users_interested_in should now be a set @@ -1522,10 +1610,6 @@ class PresenceEventSource: ) presence_updates = list(users_to_state.values()) - # Remove the user from the list of users to receive all presence - if user_id in self.get_module_api()._send_full_presence_to_local_users: - self.get_module_api()._send_full_presence_to_local_users.remove(user_id) - if not include_offline: # Filter out offline presence states presence_updates = self._filter_offline_presence_state(presence_updates) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index a1a2b9aecc..cecdc96bf5 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -56,14 +56,6 @@ class ModuleApi: self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient self._public_room_list_manager = PublicRoomListManager(hs) - # The next time these users sync, they will receive the current presence - # state of all local users. Users are added by send_local_online_presence_to, - # and removed after a successful sync. - # - # We make this a private variable to deter modules from accessing it directly, - # though other classes in Synapse will still do so. - self._send_full_presence_to_local_users = set() - @property def http_client(self): """Allows making outbound HTTP requests to remote resources. @@ -405,39 +397,44 @@ class ModuleApi: Updates to remote users will be sent immediately, whereas local users will receive them on their next sync attempt. - Note that this method can only be run on the main or federation_sender worker - processes. + Note that this method can only be run on the process that is configured to write to the + presence stream. By default this is the main process. """ - if not self._hs.should_send_federation(): + if self._hs._instance_name not in self._hs.config.worker.writers.presence: raise Exception( "send_local_online_presence_to can only be run " - "on processes that send federation", + "on the process that is configured to write to the " + "presence stream (by default this is the main process)", ) + local_users = set() + remote_users = set() for user in users: if self._hs.is_mine_id(user): - # Modify SyncHandler._generate_sync_entry_for_presence to call - # presence_source.get_new_events with an empty `from_key` if - # that user's ID were in a list modified by ModuleApi somewhere. - # That user would then get all presence state on next incremental sync. - - # Force a presence initial_sync for this user next time - self._send_full_presence_to_local_users.add(user) + local_users.add(user) else: - # Retrieve presence state for currently online users that this user - # is considered interested in - presence_events, _ = await self._presence_stream.get_new_events( - UserID.from_string(user), from_key=None, include_offline=False - ) - - # Send to remote destinations. - - # We pull out the presence handler here to break a cyclic - # dependency between the presence router and module API. - presence_handler = self._hs.get_presence_handler() - await presence_handler.maybe_send_presence_to_interested_destinations( - presence_events - ) + remote_users.add(user) + + # We pull out the presence handler here to break a cyclic + # dependency between the presence router and module API. + presence_handler = self._hs.get_presence_handler() + + if local_users: + # Force a presence initial_sync for these users next time they sync. + await presence_handler.send_full_presence_to_users(local_users) + + for user in remote_users: + # Retrieve presence state for currently online users that this user + # is considered interested in. + presence_events, _ = await self._presence_stream.get_new_events( + UserID.from_string(user), from_key=None, include_offline=False + ) + + # Send to remote destinations. + destination = UserID.from_string(user).domain + presence_handler.get_federation_queue().send_presence_to_destinations( + presence_events, destination + ) class PublicRoomListManager: diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py index f25307620d..bb00247953 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py @@ -73,6 +73,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint): { "state": { ... }, "ignore_status_msg": false, + "force_notify": false } 200 OK @@ -91,17 +92,23 @@ class ReplicationPresenceSetState(ReplicationEndpoint): self._presence_handler = hs.get_presence_handler() @staticmethod - async def _serialize_payload(user_id, state, ignore_status_msg=False): + async def _serialize_payload( + user_id, state, ignore_status_msg=False, force_notify=False + ): return { "state": state, "ignore_status_msg": ignore_status_msg, + "force_notify": force_notify, } async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) await self._presence_handler.set_state( - UserID.from_string(user_id), content["state"], content["ignore_status_msg"] + UserID.from_string(user_id), + content["state"], + content["ignore_status_msg"], + content["force_notify"], ) return ( diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index cc3ab5854b..b5e4c474ef 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -54,7 +54,6 @@ class SendServerNoticeServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.txns = HttpTransactionCache(hs) - self.snm = hs.get_server_notices_manager() def register(self, json_resource: HttpServer): PATTERN = "/send_server_notice" @@ -77,7 +76,10 @@ class SendServerNoticeServlet(RestServlet): event_type = body.get("type", EventTypes.Message) state_key = body.get("state_key") - if not self.snm.is_enabled(): + # We grab the server notices manager here as its initialisation has a check for worker processes, + # but worker processes still need to initialise SendServerNoticeServlet (as it is part of the + # admin api). + if not self.hs.get_server_notices_manager().is_enabled(): raise SynapseError(400, "Server notices are not enabled on this server") user_id = body["user_id"] @@ -85,7 +87,7 @@ class SendServerNoticeServlet(RestServlet): if not self.hs.is_mine_id(user_id): raise SynapseError(400, "Server notices can only be sent to local users") - event = await self.snm.send_notice( + event = await self.hs.get_server_notices_manager().send_notice( user_id=body["user_id"], type=event_type, state_key=state_key, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index db22fab23e..669a2af884 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream @@ -57,6 +57,7 @@ class PresenceStore(SQLBaseStore): db_conn, "presence_stream", "stream_id" ) + self.hs = hs self._presence_on_startup = self._get_active_presence(db_conn) presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict( @@ -210,6 +211,61 @@ class PresenceStore(SQLBaseStore): return {row["user_id"]: UserPresenceState(**row) for row in rows} + async def should_user_receive_full_presence_with_token( + self, + user_id: str, + from_token: int, + ) -> bool: + """Check whether the given user should receive full presence using the stream token + they're updating from. + + Args: + user_id: The ID of the user to check. + from_token: The stream token included in their /sync token. + + Returns: + True if the user should have full presence sent to them, False otherwise. + """ + + def _should_user_receive_full_presence_with_token_txn(txn): + sql = """ + SELECT 1 FROM users_to_send_full_presence_to + WHERE user_id = ? + AND presence_stream_id >= ? + """ + txn.execute(sql, (user_id, from_token)) + return bool(txn.fetchone()) + + return await self.db_pool.runInteraction( + "should_user_receive_full_presence_with_token", + _should_user_receive_full_presence_with_token_txn, + ) + + async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]): + """Adds to the list of users who should receive a full snapshot of presence + upon their next sync. + + Args: + user_ids: An iterable of user IDs. + """ + # Add user entries to the table, updating the presence_stream_id column if the user already + # exists in the table. + await self.db_pool.simple_upsert_many( + table="users_to_send_full_presence_to", + key_names=("user_id",), + key_values=[(user_id,) for user_id in user_ids], + value_names=("presence_stream_id",), + # We save the current presence stream ID token along with the user ID entry so + # that when a user /sync's, even if they syncing multiple times across separate + # devices at different times, each device will receive full presence once - when + # the presence stream ID in their sync token is less than the one in the table + # for their user ID. + value_values=( + (self._presence_id_gen.get_current_token(),) for _ in user_ids + ), + desc="add_users_to_send_full_presence_to", + ) + async def get_presence_for_all_users( self, include_offline: bool = True, diff --git a/synapse/storage/schema/main/delta/59/13users_to_send_full_presence_to.sql b/synapse/storage/schema/main/delta/59/13users_to_send_full_presence_to.sql new file mode 100644 index 0000000000..07b0f53ecf --- /dev/null +++ b/synapse/storage/schema/main/delta/59/13users_to_send_full_presence_to.sql @@ -0,0 +1,34 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a table that keeps track of a list of users who should, upon their next +-- sync request, receive presence for all currently online users that they are +-- "interested" in. + +-- The motivation for a DB table over an in-memory list is so that this list +-- can be added to and retrieved from by any worker. Specifically, we don't +-- want to duplicate work across multiple sync workers. + +CREATE TABLE IF NOT EXISTS users_to_send_full_presence_to( + -- The user ID to send full presence to. + user_id TEXT PRIMARY KEY, + -- A presence stream ID token - the current presence stream token when the row was last upserted. + -- If a user calls /sync and this token is part of the update they're to receive, we also include + -- full user presence in the response. + -- This allows multiple devices for a user to receive full presence whenever they next call /sync. + presence_stream_id BIGINT, + FOREIGN KEY (user_id) + REFERENCES users (name) +); \ No newline at end of file diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 01d257307c..875b0d0a11 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -302,11 +302,18 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ) # Check that the expected presence updates were sent - expected_users = [ + # We explicitly compare using sets as we expect that calling + # module_api.send_local_online_presence_to will create a presence + # update that is a duplicate of the specified user's current presence. + # These are sent to clients and will be picked up below, thus we use a + # set to deduplicate. We're just interested that non-offline updates were + # sent out for each user ID. + expected_users = { self.other_user_id, self.presence_receiving_user_one_id, self.presence_receiving_user_two_id, - ] + } + found_users = set() calls = ( self.hs.get_federation_transport_client().send_transaction.call_args_list @@ -326,12 +333,12 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): # EDUs can contain multiple presence updates for presence_update in edu["content"]["push"]: # Check for presence updates that contain the user IDs we're after - expected_users.remove(presence_update["user_id"]) + found_users.add(presence_update["user_id"]) # Ensure that no offline states are being sent out self.assertNotEqual(presence_update["presence"], "offline") - self.assertEqual(len(expected_users), 0) + self.assertEqual(found_users, expected_users) def send_presence_update( diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 742ad14b8c..2c68b9a13c 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -13,6 +13,8 @@ # limitations under the License. from unittest.mock import Mock +from twisted.internet import defer + from synapse.api.constants import EduTypes from synapse.events import EventBase from synapse.federation.units import Transaction @@ -22,11 +24,13 @@ from synapse.rest.client.v1 import login, presence, room from synapse.types import create_requester from tests.events.test_presence_router import send_presence_update, sync_presence +from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.test_utils.event_injection import inject_member_event -from tests.unittest import FederatingHomeserverTestCase, override_config +from tests.unittest import HomeserverTestCase, override_config +from tests.utils import USE_POSTGRES_FOR_TESTS -class ModuleApiTestCase(FederatingHomeserverTestCase): +class ModuleApiTestCase(HomeserverTestCase): servlets = [ admin.register_servlets, login.register_servlets, @@ -217,97 +221,16 @@ class ModuleApiTestCase(FederatingHomeserverTestCase): ) self.assertFalse(is_in_public_rooms) - # The ability to send federation is required by send_local_online_presence_to. - @override_config({"send_federation": True}) def test_send_local_online_presence_to(self): - """Tests that send_local_presence_to_users sends local online presence to local users.""" - # Create a user who will send presence updates - self.presence_receiver_id = self.register_user("presence_receiver", "monkey") - self.presence_receiver_tok = self.login("presence_receiver", "monkey") - - # And another user that will send presence updates out - self.presence_sender_id = self.register_user("presence_sender", "monkey") - self.presence_sender_tok = self.login("presence_sender", "monkey") - - # Put them in a room together so they will receive each other's presence updates - room_id = self.helper.create_room_as( - self.presence_receiver_id, - tok=self.presence_receiver_tok, - ) - self.helper.join(room_id, self.presence_sender_id, tok=self.presence_sender_tok) - - # Presence sender comes online - send_presence_update( - self, - self.presence_sender_id, - self.presence_sender_tok, - "online", - "I'm online!", - ) - - # Presence receiver should have received it - presence_updates, sync_token = sync_presence(self, self.presence_receiver_id) - self.assertEqual(len(presence_updates), 1) - - presence_update = presence_updates[0] # type: UserPresenceState - self.assertEqual(presence_update.user_id, self.presence_sender_id) - self.assertEqual(presence_update.state, "online") - - # Syncing again should result in no presence updates - presence_updates, sync_token = sync_presence( - self, self.presence_receiver_id, sync_token - ) - self.assertEqual(len(presence_updates), 0) - - # Trigger sending local online presence - self.get_success( - self.module_api.send_local_online_presence_to( - [ - self.presence_receiver_id, - ] - ) - ) - - # Presence receiver should have received online presence again - presence_updates, sync_token = sync_presence( - self, self.presence_receiver_id, sync_token - ) - self.assertEqual(len(presence_updates), 1) - - presence_update = presence_updates[0] # type: UserPresenceState - self.assertEqual(presence_update.user_id, self.presence_sender_id) - self.assertEqual(presence_update.state, "online") - - # Presence sender goes offline - send_presence_update( - self, - self.presence_sender_id, - self.presence_sender_tok, - "offline", - "I slink back into the darkness.", - ) - - # Trigger sending local online presence - self.get_success( - self.module_api.send_local_online_presence_to( - [ - self.presence_receiver_id, - ] - ) - ) - - # Presence receiver should *not* have received offline state - presence_updates, sync_token = sync_presence( - self, self.presence_receiver_id, sync_token - ) - self.assertEqual(len(presence_updates), 0) + # Test sending local online presence to users from the main process + _test_sending_local_online_presence_to_local_user(self, test_with_workers=False) @override_config({"send_federation": True}) def test_send_local_online_presence_to_federation(self): """Tests that send_local_presence_to_users sends local online presence to remote users.""" # Create a user who will send presence updates - self.presence_sender_id = self.register_user("presence_sender", "monkey") - self.presence_sender_tok = self.login("presence_sender", "monkey") + self.presence_sender_id = self.register_user("presence_sender1", "monkey") + self.presence_sender_tok = self.login("presence_sender1", "monkey") # And a room they're a part of room_id = self.helper.create_room_as( @@ -374,3 +297,209 @@ class ModuleApiTestCase(FederatingHomeserverTestCase): found_update = True self.assertTrue(found_update) + + +class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase): + """For testing ModuleApi functionality in a multi-worker setup""" + + # Testing stream ID replication from the main to worker processes requires postgres + # (due to needing `MultiWriterIdGenerator`). + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + presence.register_servlets, + ] + + def default_config(self): + conf = super().default_config() + conf["redis"] = {"enabled": "true"} + conf["stream_writers"] = {"presence": ["presence_writer"]} + conf["instance_map"] = { + "presence_writer": {"host": "testserv", "port": 1001}, + } + return conf + + def prepare(self, reactor, clock, homeserver): + self.module_api = homeserver.get_module_api() + self.sync_handler = homeserver.get_sync_handler() + + def test_send_local_online_presence_to_workers(self): + # Test sending local online presence to users from a worker process + _test_sending_local_online_presence_to_local_user(self, test_with_workers=True) + + +def _test_sending_local_online_presence_to_local_user( + test_case: HomeserverTestCase, test_with_workers: bool = False +): + """Tests that send_local_presence_to_users sends local online presence to local users. + + This simultaneously tests two different usecases: + * Testing that this method works when either called from a worker or the main process. + - We test this by calling this method from both a TestCase that runs in monolith mode, and one that + runs with a main and generic_worker. + * Testing that multiple devices syncing simultaneously will all receive a snapshot of local, + online presence - but only once per device. + + Args: + test_with_workers: If True, this method will call ModuleApi.send_local_online_presence_to on a + worker process. The test users will still sync with the main process. The purpose of testing + with a worker is to check whether a Synapse module running on a worker can inform other workers/ + the main process that they should include additional presence when a user next syncs. + """ + if test_with_workers: + # Create a worker process to make module_api calls against + worker_hs = test_case.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "presence_writer"} + ) + + # Create a user who will send presence updates + test_case.presence_receiver_id = test_case.register_user( + "presence_receiver1", "monkey" + ) + test_case.presence_receiver_tok = test_case.login("presence_receiver1", "monkey") + + # And another user that will send presence updates out + test_case.presence_sender_id = test_case.register_user("presence_sender2", "monkey") + test_case.presence_sender_tok = test_case.login("presence_sender2", "monkey") + + # Put them in a room together so they will receive each other's presence updates + room_id = test_case.helper.create_room_as( + test_case.presence_receiver_id, + tok=test_case.presence_receiver_tok, + ) + test_case.helper.join( + room_id, test_case.presence_sender_id, tok=test_case.presence_sender_tok + ) + + # Presence sender comes online + send_presence_update( + test_case, + test_case.presence_sender_id, + test_case.presence_sender_tok, + "online", + "I'm online!", + ) + + # Presence receiver should have received it + presence_updates, sync_token = sync_presence( + test_case, test_case.presence_receiver_id + ) + test_case.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) + test_case.assertEqual(presence_update.state, "online") + + if test_with_workers: + # Replicate the current sync presence token from the main process to the worker process. + # We need to do this so that the worker process knows the current presence stream ID to + # insert into the database when we call ModuleApi.send_local_online_presence_to. + test_case.replicate() + + # Syncing again should result in no presence updates + presence_updates, sync_token = sync_presence( + test_case, test_case.presence_receiver_id, sync_token + ) + test_case.assertEqual(len(presence_updates), 0) + + # We do an (initial) sync with a second "device" now, getting a new sync token. + # We'll use this in a moment. + _, sync_token_second_device = sync_presence( + test_case, test_case.presence_receiver_id + ) + + # Determine on which process (main or worker) to call ModuleApi.send_local_online_presence_to on + if test_with_workers: + module_api_to_use = worker_hs.get_module_api() + else: + module_api_to_use = test_case.module_api + + # Trigger sending local online presence. We expect this information + # to be saved to the database where all processes can access it. + # Note that we're syncing via the master. + d = module_api_to_use.send_local_online_presence_to( + [ + test_case.presence_receiver_id, + ] + ) + d = defer.ensureDeferred(d) + + if test_with_workers: + # In order for the required presence_set_state replication request to occur between the + # worker and main process, we need to pump the reactor. Otherwise, the coordinator that + # reads the request on the main process won't do so, and the request will time out. + while not d.called: + test_case.reactor.advance(0.1) + + test_case.get_success(d) + + # The presence receiver should have received online presence again. + presence_updates, sync_token = sync_presence( + test_case, test_case.presence_receiver_id, sync_token + ) + test_case.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) + test_case.assertEqual(presence_update.state, "online") + + # We attempt to sync with the second sync token we received above - just to check that + # multiple syncing devices will each receive the necessary online presence. + presence_updates, sync_token_second_device = sync_presence( + test_case, test_case.presence_receiver_id, sync_token_second_device + ) + test_case.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) + test_case.assertEqual(presence_update.state, "online") + + # However, if we now sync with either "device", we won't receive another burst of online presence + # until the API is called again sometime in the future + presence_updates, sync_token = sync_presence( + test_case, test_case.presence_receiver_id, sync_token + ) + + # Now we check that we don't receive *offline* updates using ModuleApi.send_local_online_presence_to. + + # Presence sender goes offline + send_presence_update( + test_case, + test_case.presence_sender_id, + test_case.presence_sender_tok, + "offline", + "I slink back into the darkness.", + ) + + # Presence receiver should have received the updated, offline state + presence_updates, sync_token = sync_presence( + test_case, test_case.presence_receiver_id, sync_token + ) + test_case.assertEqual(len(presence_updates), 1) + + # Now trigger sending local online presence. + d = module_api_to_use.send_local_online_presence_to( + [ + test_case.presence_receiver_id, + ] + ) + d = defer.ensureDeferred(d) + + if test_with_workers: + # In order for the required presence_set_state replication request to occur between the + # worker and main process, we need to pump the reactor. Otherwise, the coordinator that + # reads the request on the main process won't do so, and the request will time out. + while not d.called: + test_case.reactor.advance(0.1) + + test_case.get_success(d) + + # Presence receiver should *not* have received offline state + presence_updates, sync_token = sync_presence( + test_case, test_case.presence_receiver_id, sync_token + ) + test_case.assertEqual(len(presence_updates), 0) diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index d739eb6b17..5eca5c165d 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -30,7 +30,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): """Checks event persisting sharding works""" # Event persister sharding requires postgres (due to needing - # `MutliWriterIdGenerator`). + # `MultiWriterIdGenerator`). if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" -- cgit 1.5.1 From 6a8643ff3da905568e3f2ec047182753352e39d1 Mon Sep 17 00:00:00 2001 From: Marek Matys <57749215+thermaq@users.noreply.github.com> Date: Fri, 21 May 2021 13:02:06 +0200 Subject: Fixed removal of new presence stream states (#10014) Fixes: https://github.com/matrix-org/synapse/issues/9962 This is a fix for above problem. I fixed it by swaping the order of insertion of new records and deletion of old ones. This ensures that we don't delete fresh database records as we do deletes before inserts. Signed-off-by: Marek Matys --- changelog.d/10014.bugfix | 1 + synapse/storage/databases/main/presence.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 9 deletions(-) create mode 100644 changelog.d/10014.bugfix (limited to 'synapse/storage/databases/main') diff --git a/changelog.d/10014.bugfix b/changelog.d/10014.bugfix new file mode 100644 index 0000000000..7cf3603f94 --- /dev/null +++ b/changelog.d/10014.bugfix @@ -0,0 +1 @@ +Fixed deletion of new presence stream states from database. diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 669a2af884..6a2baa7841 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -97,6 +97,15 @@ class PresenceStore(SQLBaseStore): ) txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) + # Delete old rows to stop database from getting really big + sql = "DELETE FROM presence_stream WHERE stream_id < ? AND " + + for states in batch_iter(presence_states, 50): + clause, args = make_in_list_sql_clause( + self.database_engine, "user_id", [s.user_id for s in states] + ) + txn.execute(sql + clause, [stream_id] + list(args)) + # Actually insert new rows self.db_pool.simple_insert_many_txn( txn, @@ -117,15 +126,6 @@ class PresenceStore(SQLBaseStore): ], ) - # Delete old rows to stop database from getting really big - sql = "DELETE FROM presence_stream WHERE stream_id < ? AND " - - for states in batch_iter(presence_states, 50): - clause, args = make_in_list_sql_clause( - self.database_engine, "user_id", [s.user_id for s in states] - ) - txn.execute(sql + clause, [stream_id] + list(args)) - async def get_all_presence_updates( self, instance_name: str, last_id: int, current_id: int, limit: int ) -> Tuple[List[Tuple[int, list]], int, bool]: -- cgit 1.5.1 From 3e831f24ffc887e174f67ff7b1cfe3a429b7b5c1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 21 May 2021 17:57:08 +0100 Subject: Don't hammer the database for destination retry timings every ~5mins (#10036) --- changelog.d/10036.misc | 1 + synapse/app/generic_worker.py | 2 - synapse/federation/transport/server.py | 2 +- synapse/replication/slave/storage/transactions.py | 21 -------- synapse/storage/databases/main/__init__.py | 4 +- synapse/storage/databases/main/transactions.py | 66 +++++++++++++---------- synapse/util/retryutils.py | 8 ++- tests/handlers/test_typing.py | 8 +-- tests/storage/test_transactions.py | 8 ++- tests/util/test_retryutils.py | 18 ++++--- 10 files changed, 62 insertions(+), 76 deletions(-) create mode 100644 changelog.d/10036.misc delete mode 100644 synapse/replication/slave/storage/transactions.py (limited to 'synapse/storage/databases/main') diff --git a/changelog.d/10036.misc b/changelog.d/10036.misc new file mode 100644 index 0000000000..d2cf1e5473 --- /dev/null +++ b/changelog.d/10036.misc @@ -0,0 +1 @@ +Properly invalidate caches for destination retry timings every (instead of expiring entries every 5 minutes). diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index f730cdbd78..91ad326f19 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -61,7 +61,6 @@ from synapse.replication.slave.storage.pushers import SlavedPusherStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.client.v1 import events, login, presence, room from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet @@ -237,7 +236,6 @@ class GenericWorkerSlavedStore( DirectoryStore, SlavedApplicationServiceStore, SlavedRegistrationStore, - SlavedTransactionStore, SlavedProfileStore, SlavedClientIpStore, SlavedFilteringStore, diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index c17a085a4f..9d50b05d01 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -160,7 +160,7 @@ class Authenticator: # If we get a valid signed request from the other side, its probably # alive retry_timings = await self.store.get_destination_retry_timings(origin) - if retry_timings and retry_timings["retry_last_ts"]: + if retry_timings and retry_timings.retry_last_ts: run_in_background(self._reset_retry_timings, origin) return origin diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py deleted file mode 100644 index a59e543924..0000000000 --- a/synapse/replication/slave/storage/transactions.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.storage.databases.main.transactions import TransactionStore - -from ._base import BaseSlavedStore - - -class SlavedTransactionStore(TransactionStore, BaseSlavedStore): - pass diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 49c7606d51..9cce62ae6c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -67,7 +67,7 @@ from .state import StateStore from .stats import StatsStore from .stream import StreamStore from .tags import TagsStore -from .transactions import TransactionStore +from .transactions import TransactionWorkerStore from .ui_auth import UIAuthStore from .user_directory import UserDirectoryStore from .user_erasure_store import UserErasureStore @@ -83,7 +83,7 @@ class DataStore( StreamStore, ProfileStore, PresenceStore, - TransactionStore, + TransactionWorkerStore, DirectoryStore, KeyStore, StateStore, diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 82335e7a9d..d211c423b2 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -16,13 +16,15 @@ import logging from collections import namedtuple from typing import Iterable, List, Optional, Tuple +import attr from canonicaljson import encode_canonical_json from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage._base import db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.types import JsonDict -from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.caches.descriptors import cached db_binary_type = memoryview @@ -38,10 +40,23 @@ _UpdateTransactionRow = namedtuple( "_TransactionRow", ("response_code", "response_json") ) -SENTINEL = object() +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DestinationRetryTimings: + """The current destination retry timing info for a remote server.""" -class TransactionWorkerStore(SQLBaseStore): + # The first time we tried and failed to reach the remote server, in ms. + failure_ts: int + + # The last time we tried and failed to reach the remote server, in ms. + retry_last_ts: int + + # How long since the last time we tried to reach the remote server before + # trying again, in ms. + retry_interval: int + + +class TransactionWorkerStore(CacheInvalidationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -60,19 +75,6 @@ class TransactionWorkerStore(SQLBaseStore): "_cleanup_transactions", _cleanup_transactions_txn ) - -class TransactionStore(TransactionWorkerStore): - """A collection of queries for handling PDUs.""" - - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - - self._destination_retry_cache = ExpiringCache( - cache_name="get_destination_retry_timings", - clock=self._clock, - expiry_ms=5 * 60 * 1000, - ) - async def get_received_txn_response( self, transaction_id: str, origin: str ) -> Optional[Tuple[int, JsonDict]]: @@ -145,7 +147,11 @@ class TransactionStore(TransactionWorkerStore): desc="set_received_txn_response", ) - async def get_destination_retry_timings(self, destination): + @cached(max_entries=10000) + async def get_destination_retry_timings( + self, + destination: str, + ) -> Optional[DestinationRetryTimings]: """Gets the current retry timings (if any) for a given destination. Args: @@ -156,34 +162,29 @@ class TransactionStore(TransactionWorkerStore): Otherwise a dict for the retry scheme """ - result = self._destination_retry_cache.get(destination, SENTINEL) - if result is not SENTINEL: - return result - result = await self.db_pool.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination, ) - # We don't hugely care about race conditions between getting and - # invalidating the cache, since we time out fairly quickly anyway. - self._destination_retry_cache[destination] = result return result - def _get_destination_retry_timings(self, txn, destination): + def _get_destination_retry_timings( + self, txn, destination: str + ) -> Optional[DestinationRetryTimings]: result = self.db_pool.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, - retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"), + retcols=("failure_ts", "retry_last_ts", "retry_interval"), allow_none=True, ) # check we have a row and retry_last_ts is not null or zero # (retry_last_ts can't be negative) if result and result["retry_last_ts"]: - return result + return DestinationRetryTimings(**result) else: return None @@ -204,7 +205,6 @@ class TransactionStore(TransactionWorkerStore): retry_interval: how long until next retry in ms """ - self._destination_retry_cache.pop(destination, None) if self.database_engine.can_native_upsert: return await self.db_pool.runInteraction( "set_destination_retry_timings", @@ -252,6 +252,10 @@ class TransactionStore(TransactionWorkerStore): txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) + self._invalidate_cache_and_stream( + txn, self.get_destination_retry_timings, (destination,) + ) + def _set_destination_retry_timings_emulated( self, txn, destination, failure_ts, retry_last_ts, retry_interval ): @@ -295,6 +299,10 @@ class TransactionStore(TransactionWorkerStore): }, ) + self._invalidate_cache_and_stream( + txn, self.get_destination_retry_timings, (destination,) + ) + async def store_destination_rooms_entries( self, destinations: Iterable[str], diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index f9c370a814..129b47cd49 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -82,11 +82,9 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k retry_timings = await store.get_destination_retry_timings(destination) if retry_timings: - failure_ts = retry_timings["failure_ts"] - retry_last_ts, retry_interval = ( - retry_timings["retry_last_ts"], - retry_timings["retry_interval"], - ) + failure_ts = retry_timings.failure_ts + retry_last_ts = retry_timings.retry_last_ts + retry_interval = retry_timings.retry_interval now = int(clock.time_msec()) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 0c89487eaf..f58afbc244 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -89,14 +89,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.event_source = hs.get_event_sources().sources["typing"] self.datastore = hs.get_datastore() - retry_timings_res = { - "destination": "", - "retry_last_ts": 0, - "retry_interval": 0, - "failure_ts": None, - } self.datastore.get_destination_retry_timings = Mock( - return_value=defer.succeed(retry_timings_res) + return_value=defer.succeed(None) ) self.datastore.get_device_updates_by_remote = Mock( diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index b7f7eae8d0..bea9091d30 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.databases.main.transactions import DestinationRetryTimings from synapse.util.retryutils import MAX_RETRY_INTERVAL from tests.unittest import HomeserverTestCase @@ -36,8 +37,11 @@ class TransactionStoreTestCase(HomeserverTestCase): d = self.store.get_destination_retry_timings("example.com") r = self.get_success(d) - self.assert_dict( - {"retry_last_ts": 50, "retry_interval": 100, "failure_ts": 1000}, r + self.assertEqual( + DestinationRetryTimings( + retry_last_ts=50, retry_interval=100, failure_ts=1000 + ), + r, ) def test_initial_set_transactions(self): diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 9b2be83a43..9e1bebdc83 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -51,10 +51,12 @@ class RetryLimiterTestCase(HomeserverTestCase): except AssertionError: pass + self.pump() + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) - self.assertEqual(new_timings["failure_ts"], failure_ts) - self.assertEqual(new_timings["retry_last_ts"], failure_ts) - self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL) + self.assertEqual(new_timings.failure_ts, failure_ts) + self.assertEqual(new_timings.retry_last_ts, failure_ts) + self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL) # now if we try again we should get a failure self.get_failure( @@ -77,14 +79,16 @@ class RetryLimiterTestCase(HomeserverTestCase): except AssertionError: pass + self.pump() + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) - self.assertEqual(new_timings["failure_ts"], failure_ts) - self.assertEqual(new_timings["retry_last_ts"], retry_ts) + self.assertEqual(new_timings.failure_ts, failure_ts) + self.assertEqual(new_timings.retry_last_ts, retry_ts) self.assertGreaterEqual( - new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5 + new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5 ) self.assertLessEqual( - new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0 + new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0 ) # -- cgit 1.5.1 From c0df6bae066fe818bb80d41af65503be7a07275d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 24 May 2021 14:02:01 +0100 Subject: Remove `keylen` from `LruCache`. (#9993) `keylen` seems to be a thing that is frequently incorrectly set, and we don't really need it. The only time it was used was to figure out if we had removed a subtree in `del_multi`, which we can do better by changing `TreeCache.pop` to return a different type (`TreeCacheNode`). Commits should be independently reviewable. --- changelog.d/9993.misc | 1 + synapse/replication/slave/storage/client_ips.py | 2 +- synapse/storage/databases/main/client_ips.py | 2 +- synapse/storage/databases/main/devices.py | 2 +- synapse/storage/databases/main/events_worker.py | 1 - synapse/util/caches/deferred_cache.py | 2 - synapse/util/caches/descriptors.py | 1 - synapse/util/caches/lrucache.py | 10 +-- synapse/util/caches/treecache.py | 104 +++++++++++++++--------- tests/util/test_lrucache.py | 4 +- tests/util/test_treecache.py | 6 +- 11 files changed, 80 insertions(+), 55 deletions(-) create mode 100644 changelog.d/9993.misc (limited to 'synapse/storage/databases/main') diff --git a/changelog.d/9993.misc b/changelog.d/9993.misc new file mode 100644 index 0000000000..0dd9244071 --- /dev/null +++ b/changelog.d/9993.misc @@ -0,0 +1 @@ +Remove `keylen` param on `LruCache`. diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 8730966380..13ed87adc4 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -24,7 +24,7 @@ class SlavedClientIpStore(BaseSlavedStore): super().__init__(database, db_conn, hs) self.client_ip_last_seen = LruCache( - cache_name="client_ip_last_seen", keylen=4, max_size=50000 + cache_name="client_ip_last_seen", max_size=50000 ) # type: LruCache[tuple, int] async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index d60010e942..074b077bef 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -436,7 +436,7 @@ class ClientIpStore(ClientIpWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): self.client_ip_last_seen = LruCache( - cache_name="client_ip_last_seen", keylen=4, max_size=50000 + cache_name="client_ip_last_seen", max_size=50000 ) super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a1f98b7e38..fd87ba71ab 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1053,7 +1053,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. self.device_id_exists_cache = LruCache( - cache_name="device_id_exists", keylen=2, max_size=10000 + cache_name="device_id_exists", max_size=10000 ) async def store_device( diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 2c823e09cf..6963bbf7f4 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -157,7 +157,6 @@ class EventsWorkerStore(SQLBaseStore): self._get_event_cache = LruCache( cache_name="*getEvent*", - keylen=3, max_size=hs.config.caches.event_cache_size, ) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 484097a48a..371e7e4dd0 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -70,7 +70,6 @@ class DeferredCache(Generic[KT, VT]): self, name: str, max_entries: int = 1000, - keylen: int = 1, tree: bool = False, iterable: bool = False, apply_cache_factor_from_config: bool = True, @@ -101,7 +100,6 @@ class DeferredCache(Generic[KT, VT]): # a Deferred. self.cache = LruCache( max_size=max_entries, - keylen=keylen, cache_name=name, cache_type=cache_type, size_callback=(lambda d: len(d) or 1) if iterable else None, diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 3a4d027095..2ac24a2f25 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -270,7 +270,6 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): cache = DeferredCache( name=self.orig.__name__, max_entries=self.max_entries, - keylen=self.num_args, tree=self.tree, iterable=self.iterable, ) # type: DeferredCache[CacheKey, Any] diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 1be675e014..54df407ff7 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -34,7 +34,7 @@ from typing_extensions import Literal from synapse.config import cache as cache_config from synapse.util import caches from synapse.util.caches import CacheMetric, register_cache -from synapse.util.caches.treecache import TreeCache +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry try: from pympler.asizeof import Asizer @@ -160,7 +160,6 @@ class LruCache(Generic[KT, VT]): self, max_size: int, cache_name: Optional[str] = None, - keylen: int = 1, cache_type: Type[Union[dict, TreeCache]] = dict, size_callback: Optional[Callable] = None, metrics_collection_callback: Optional[Callable[[], None]] = None, @@ -173,9 +172,6 @@ class LruCache(Generic[KT, VT]): cache_name: The name of this cache, for the prometheus metrics. If unset, no metrics will be reported on this cache. - keylen: The length of the tuple used as the cache key. Ignored unless - cache_type is `TreeCache`. - cache_type (type): type of underlying cache to be used. Typically one of dict or TreeCache. @@ -403,7 +399,9 @@ class LruCache(Generic[KT, VT]): popped = cache.pop(key) if popped is None: return - for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))): + # for each deleted node, we now need to remove it from the linked list + # and run its callbacks. + for leaf in iterate_tree_cache_entry(popped): delete_node(leaf) @synchronized diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index eb4d98f683..73502a8b06 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -1,18 +1,43 @@ -from typing import Dict +# Copyright 2016-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. SENTINEL = object() +class TreeCacheNode(dict): + """The type of nodes in our tree. + + Has its own type so we can distinguish it from real dicts that are stored at the + leaves. + """ + + pass + + class TreeCache: """ Tree-based backing store for LruCache. Allows subtrees of data to be deleted efficiently. Keys must be tuples. + + The data structure is a chain of TreeCacheNodes: + root = {key_1: {key_2: _value}} """ def __init__(self): self.size = 0 - self.root = {} # type: Dict + self.root = TreeCacheNode() def __setitem__(self, key, value): return self.set(key, value) @@ -21,10 +46,23 @@ class TreeCache: return self.get(key, SENTINEL) is not SENTINEL def set(self, key, value): + if isinstance(value, TreeCacheNode): + # this would mean we couldn't tell where our tree ended and the value + # started. + raise ValueError("Cannot store TreeCacheNodes in a TreeCache") + node = self.root for k in key[:-1]: - node = node.setdefault(k, {}) - node[key[-1]] = _Entry(value) + next_node = node.get(k, SENTINEL) + if next_node is SENTINEL: + next_node = node[k] = TreeCacheNode() + elif not isinstance(next_node, TreeCacheNode): + # this suggests that the caller is not being consistent with its key + # length. + raise ValueError("value conflicts with an existing subtree") + node = next_node + + node[key[-1]] = value self.size += 1 def get(self, key, default=None): @@ -33,25 +71,41 @@ class TreeCache: node = node.get(k, None) if node is None: return default - return node.get(key[-1], _Entry(default)).value + return node.get(key[-1], default) def clear(self): self.size = 0 - self.root = {} + self.root = TreeCacheNode() def pop(self, key, default=None): + """Remove the given key, or subkey, from the cache + + Args: + key: key or subkey to remove. + default: value to return if key is not found + + Returns: + If the key is not found, 'default'. If the key is complete, the removed + value. If the key is partial, the TreeCacheNode corresponding to the part + of the tree that was removed. + """ + # a list of the nodes we have touched on the way down the tree nodes = [] node = self.root for k in key[:-1]: node = node.get(k, None) - nodes.append(node) # don't add the root node if node is None: return default + if not isinstance(node, TreeCacheNode): + # we've gone off the end of the tree + raise ValueError("pop() key too long") + nodes.append(node) # don't add the root node popped = node.pop(key[-1], SENTINEL) if popped is SENTINEL: return default + # working back up the tree, clear out any nodes that are now empty node_and_keys = list(zip(nodes, key)) node_and_keys.reverse() node_and_keys.append((self.root, None)) @@ -61,14 +115,15 @@ class TreeCache: if n: break + # found an empty node: remove it from its parent, and loop. node_and_keys[i + 1][0].pop(k) - popped, cnt = _strip_and_count_entires(popped) + cnt = sum(1 for _ in iterate_tree_cache_entry(popped)) self.size -= cnt return popped def values(self): - return list(iterate_tree_cache_entry(self.root)) + return iterate_tree_cache_entry(self.root) def __len__(self): return self.size @@ -78,36 +133,9 @@ def iterate_tree_cache_entry(d): """Helper function to iterate over the leaves of a tree, i.e. a dict of that can contain dicts. """ - if isinstance(d, dict): + if isinstance(d, TreeCacheNode): for value_d in d.values(): for value in iterate_tree_cache_entry(value_d): yield value else: - if isinstance(d, _Entry): - yield d.value - else: - yield d - - -class _Entry: - __slots__ = ["value"] - - def __init__(self, value): - self.value = value - - -def _strip_and_count_entires(d): - """Takes an _Entry or dict with leaves of _Entry's, and either returns the - value or a dictionary with _Entry's replaced by their values. - - Also returns the count of _Entry's - """ - if isinstance(d, dict): - cnt = 0 - for key, value in d.items(): - v, n = _strip_and_count_entires(value) - d[key] = v - cnt += n - return d, cnt - else: - return d.value, 1 + yield d diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index df3e27779f..377904e72e 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -59,7 +59,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase): self.assertEquals(cache.pop("key"), None) def test_del_multi(self): - cache = LruCache(4, keylen=2, cache_type=TreeCache) + cache = LruCache(4, cache_type=TreeCache) cache[("animal", "cat")] = "mew" cache[("animal", "dog")] = "woof" cache[("vehicles", "car")] = "vroom" @@ -165,7 +165,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): m2 = Mock() m3 = Mock() m4 = Mock() - cache = LruCache(4, keylen=2, cache_type=TreeCache) + cache = LruCache(4, cache_type=TreeCache) cache.set(("a", "1"), "value", callbacks=[m1]) cache.set(("a", "2"), "value", callbacks=[m2]) diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py index 3b077af27e..6066372053 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py @@ -13,7 +13,7 @@ # limitations under the License. -from synapse.util.caches.treecache import TreeCache +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from .. import unittest @@ -64,12 +64,14 @@ class TreeCacheTestCase(unittest.TestCase): cache[("a", "b")] = "AB" cache[("b", "a")] = "BA" self.assertEquals(cache.get(("a", "a")), "AA") - cache.pop(("a",)) + popped = cache.pop(("a",)) self.assertEquals(cache.get(("a", "a")), None) self.assertEquals(cache.get(("a", "b")), None) self.assertEquals(cache.get(("b", "a")), "BA") self.assertEquals(len(cache), 1) + self.assertEquals({"AA", "AB"}, set(iterate_tree_cache_entry(popped))) + def test_clear(self): cache = TreeCache() cache[("a",)] = "A" -- cgit 1.5.1 From 7adcb20fc02d614b4a2b03b128b279f25633e2bd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 24 May 2021 15:32:01 -0400 Subject: Add missing type hints to synapse.util (#9982) --- changelog.d/9982.misc | 1 + mypy.ini | 9 +++++++++ synapse/config/saml2.py | 8 +++++++- synapse/storage/databases/main/keys.py | 2 +- synapse/util/hash.py | 10 +++++----- synapse/util/iterutils.py | 11 ++++------- synapse/util/module_loader.py | 9 +++++---- synapse/util/msisdn.py | 10 +++++----- tests/util/test_itertools.py | 4 ++-- 9 files changed, 39 insertions(+), 25 deletions(-) create mode 100644 changelog.d/9982.misc (limited to 'synapse/storage/databases/main') diff --git a/changelog.d/9982.misc b/changelog.d/9982.misc new file mode 100644 index 0000000000..f3821f61a3 --- /dev/null +++ b/changelog.d/9982.misc @@ -0,0 +1 @@ +Add missing type hints to `synapse.util` module. diff --git a/mypy.ini b/mypy.ini index 1d1d1ea0f2..062872020e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -71,8 +71,13 @@ files = synapse/types.py, synapse/util/async_helpers.py, synapse/util/caches, + synapse/util/daemonize.py, + synapse/util/hash.py, + synapse/util/iterutils.py, synapse/util/metrics.py, synapse/util/macaroons.py, + synapse/util/module_loader.py, + synapse/util/msisdn.py, synapse/util/stringutils.py, synapse/visibility.py, tests/replication, @@ -80,6 +85,7 @@ files = tests/handlers/test_password_providers.py, tests/rest/client/v1/test_login.py, tests/rest/client/v2_alpha/test_auth.py, + tests/util/test_itertools.py, tests/util/test_stream_change_cache.py [mypy-pymacaroons.*] @@ -175,5 +181,8 @@ ignore_missing_imports = True [mypy-pympler.*] ignore_missing_imports = True +[mypy-phonenumbers.*] +ignore_missing_imports = True + [mypy-ijson.*] ignore_missing_imports = True diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index 3d1218c8d1..05e983625d 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -164,7 +164,13 @@ class SAML2Config(Config): config_path = saml2_config.get("config_path", None) if config_path is not None: mod = load_python_module(config_path) - _dict_merge(merge_dict=mod.CONFIG, into_dict=saml2_config_dict) + config = getattr(mod, "CONFIG", None) + if config is None: + raise ConfigError( + "Config path specified by saml2_config.config_path does not " + "have a CONFIG property." + ) + _dict_merge(merge_dict=config, into_dict=saml2_config_dict) import saml2.config diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 0e86807834..6990f3ed1d 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -55,7 +55,7 @@ class KeyStore(SQLBaseStore): """ keys = {} - def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None: + def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: """Processes a batch of keys to fetch, and adds the result to `keys`.""" # batch_iter always returns tuples so it's safe to do len(batch) diff --git a/synapse/util/hash.py b/synapse/util/hash.py index ba676e1762..7625ca8c2c 100644 --- a/synapse/util/hash.py +++ b/synapse/util/hash.py @@ -17,15 +17,15 @@ import hashlib import unpaddedbase64 -def sha256_and_url_safe_base64(input_text): +def sha256_and_url_safe_base64(input_text: str) -> str: """SHA256 hash an input string, encode the digest as url-safe base64, and return - :param input_text: string to hash - :type input_text: str + Args: + input_text: string to hash - :returns a sha256 hashed and url-safe base64 encoded digest - :rtype: str + returns: + A sha256 hashed and url-safe base64 encoded digest """ digest = hashlib.sha256(input_text.encode()).digest() return unpaddedbase64.encode_base64(digest, urlsafe=True) diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index abfdc29832..886afa9d19 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -30,12 +30,12 @@ from typing import ( T = TypeVar("T") -def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]: +def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]: """batch an iterable up into tuples with a maximum size Args: - iterable (iterable): the iterable to slice - size (int): the maximum batch size + iterable: the iterable to slice + size: the maximum batch size Returns: an iterator over the chunks @@ -46,10 +46,7 @@ def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]: return iter(lambda: tuple(islice(sourceiter, size)), ()) -ISeq = TypeVar("ISeq", bound=Sequence, covariant=True) - - -def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]: +def chunk_seq(iseq: Sequence[T], maxlen: int) -> Iterable[Sequence[T]]: """Split the given sequence into chunks of the given size The last chunk may be shorter than the given size. diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py index 8acbe276e4..cbfbd097f9 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py @@ -15,6 +15,7 @@ import importlib import importlib.util import itertools +from types import ModuleType from typing import Any, Iterable, Tuple, Type import jsonschema @@ -44,8 +45,8 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]: # We need to import the module, and then pick the class out of # that, so we split based on the last dot. - module, clz = modulename.rsplit(".", 1) - module = importlib.import_module(module) + module_name, clz = modulename.rsplit(".", 1) + module = importlib.import_module(module_name) provider_class = getattr(module, clz) # Load the module config. If None, pass an empty dictionary instead @@ -69,11 +70,11 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]: return provider_class, provider_config -def load_python_module(location: str): +def load_python_module(location: str) -> ModuleType: """Load a python module, and return a reference to its global namespace Args: - location (str): path to the module + location: path to the module Returns: python module object diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py index bbbdebf264..1046224f15 100644 --- a/synapse/util/msisdn.py +++ b/synapse/util/msisdn.py @@ -17,19 +17,19 @@ import phonenumbers from synapse.api.errors import SynapseError -def phone_number_to_msisdn(country, number): +def phone_number_to_msisdn(country: str, number: str) -> str: """ Takes an ISO-3166-1 2 letter country code and phone number and returns an msisdn representing the canonical version of that phone number. Args: - country (str): ISO-3166-1 2 letter country code - number (str): Phone number in a national or international format + country: ISO-3166-1 2 letter country code + number: Phone number in a national or international format Returns: - (str) The canonical form of the phone number, as an msisdn + The canonical form of the phone number, as an msisdn Raises: - SynapseError if the number could not be parsed. + SynapseError if the number could not be parsed. """ try: phoneNumber = phonenumbers.parse(number, country) diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index 1bd0b45d94..e712eb42ea 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Dict, Iterable, List, Sequence from synapse.util.iterutils import chunk_seq, sorted_topologically @@ -44,7 +44,7 @@ class ChunkSeqTests(TestCase): ) def test_empty_input(self): - parts = chunk_seq([], 5) + parts = chunk_seq([], 5) # type: Iterable[Sequence] self.assertEqual( list(parts), -- cgit 1.5.1