diff options
39 files changed, 1029 insertions, 691 deletions
diff --git a/changelog.d/7836.misc b/changelog.d/7836.misc new file mode 100644 index 0000000000..a3a97c7590 --- /dev/null +++ b/changelog.d/7836.misc @@ -0,0 +1 @@ +Ensure that calls to `json.dumps` are compatible with the standard library json. diff --git a/changelog.d/7844.bugfix b/changelog.d/7844.bugfix new file mode 100644 index 0000000000..ad296f1b3c --- /dev/null +++ b/changelog.d/7844.bugfix @@ -0,0 +1 @@ +Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`. diff --git a/changelog.d/7847.feature b/changelog.d/7847.feature new file mode 100644 index 0000000000..4b9a8d8569 --- /dev/null +++ b/changelog.d/7847.feature @@ -0,0 +1 @@ +Add the ability to re-activate an account from the admin API. diff --git a/changelog.d/7848.misc b/changelog.d/7848.misc new file mode 100644 index 0000000000..d9db1d8357 --- /dev/null +++ b/changelog.d/7848.misc @@ -0,0 +1 @@ +Remove redundant `retry_on_integrity_error` wrapper for event persistence code. diff --git a/changelog.d/7851.misc b/changelog.d/7851.misc new file mode 100644 index 0000000000..e5cf540edf --- /dev/null +++ b/changelog.d/7851.misc @@ -0,0 +1 @@ +Convert E2E keys and room keys handlers to async/await. diff --git a/changelog.d/7853.misc b/changelog.d/7853.misc new file mode 100644 index 0000000000..b4f614084d --- /dev/null +++ b/changelog.d/7853.misc @@ -0,0 +1 @@ +Add support for handling registration requests across multiple client reader workers. diff --git a/changelog.d/7854.bugfix b/changelog.d/7854.bugfix new file mode 100644 index 0000000000..b11f9dedfe --- /dev/null +++ b/changelog.d/7854.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.10.0 which could cause a "no create event in auth events" error during room creation. diff --git a/changelog.d/7856.misc b/changelog.d/7856.misc new file mode 100644 index 0000000000..7d99fb67be --- /dev/null +++ b/changelog.d/7856.misc @@ -0,0 +1 @@ +Small performance improvement in typing processing. diff --git a/changelog.d/7870.misc b/changelog.d/7870.misc new file mode 100644 index 0000000000..27cce2f2f9 --- /dev/null +++ b/changelog.d/7870.misc @@ -0,0 +1 @@ +Add some type annotations to `HomeServer` and `BaseHandler`. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 7b030a6285..be05128b3e 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -91,10 +91,14 @@ Body parameters: - ``admin``, optional, defaults to ``false``. -- ``deactivated``, optional, defaults to ``false``. +- ``deactivated``, optional. If unspecified, deactivation state will be left + unchanged on existing accounts and set to ``false`` for new accounts. If the user already exists then optional parameters default to the current value. +In order to re-activate an account ``deactivated`` must be set to ``false``. If +users do not login via single-sign-on, a new ``password`` must be provided. + List Accounts ============= diff --git a/docs/jwt.md b/docs/jwt.md index 93b8d05236..5be9fd26e3 100644 --- a/docs/jwt.md +++ b/docs/jwt.md @@ -31,10 +31,7 @@ The `token` field should include the JSON web token with the following claims: Providing the audience claim when not configured will cause validation to fail. In the case that the token is not valid, the homeserver must respond with -`401 Unauthorized` and an error code of `M_UNAUTHORIZED`. - -(Note that this differs from the token based logins which return a -`403 Forbidden` and an error code of `M_FORBIDDEN` if an error occurs.) +`403 Forbidden` and an error code of `M_FORBIDDEN`. As with other login types, there are additional fields (e.g. `device_id` and `initial_device_display_name`) which can be included in the above request. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 94db877050..bd004350dd 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -16,12 +16,14 @@ # limitations under the License. """Contains exceptions and error codes.""" -import json + import logging import typing from http import HTTPStatus from typing import Dict, List, Optional, Union +from canonicaljson import json + from twisted.web import http if typing.TYPE_CHECKING: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2aab9c5f55..8c53330c49 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -14,10 +14,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import logging from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union +from canonicaljson import json from prometheus_client import Counter, Histogram from twisted.internet import defer diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index a2752a54a5..8280f8b900 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -61,8 +61,6 @@ class TransactionManager(object): # all the edus in that transaction. This needs to be done since there is # no active span here, so if the edus were not received by the remote the # span would have no causality and it would be forgotten. - # The span_contexts is a generator so that it won't be evaluated if - # opentracing is disabled. (Yay speed!) span_contexts = [] keep_destination = whitelisted_homeserver(destination) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 61dc4beafe..6a4944467a 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -17,6 +17,8 @@ import logging from twisted.internet import defer +import synapse.state +import synapse.storage import synapse.types from synapse.api.constants import EventTypes, Membership from synapse.api.ratelimiting import Ratelimiter @@ -28,10 +30,6 @@ logger = logging.getLogger(__name__) class BaseHandler(object): """ Common base class for the event handlers. - - Attributes: - store (synapse.storage.DataStore): - state_handler (synapse.state.StateHandler): """ def __init__(self, hs): @@ -39,10 +37,10 @@ class BaseHandler(object): Args: hs (synapse.server.HomeServer): """ - self.store = hs.get_datastore() + self.store = hs.get_datastore() # type: synapse.storage.DataStore self.auth = hs.get_auth() self.notifier = hs.get_notifier() - self.state_handler = hs.get_state_handler() + self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler self.distributor = hs.get_distributor() self.clock = hs.get_clock() self.hs = hs diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 591657a5c2..3789b1b495 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process @@ -46,19 +47,20 @@ class DeactivateAccountHandler(BaseHandler): self._account_validity_enabled = hs.config.account_validity.enabled - async def deactivate_account(self, user_id, erase_data, id_server=None): + async def deactivate_account( + self, user_id: str, erase_data: bool, id_server: Optional[str] = None + ) -> bool: """Deactivate a user's account Args: - user_id (str): ID of user to be deactivated - erase_data (bool): whether to GDPR-erase the user's data - id_server (str|None): Use the given identity server when unbinding + user_id: ID of user to be deactivated + erase_data: whether to GDPR-erase the user's data + id_server: Use the given identity server when unbinding any threepids. If None then will attempt to unbind using the identity server specified when binding (if known). Returns: - Deferred[bool]: True if identity server supports removing - threepids, otherwise False. + True if identity server supports removing threepids, otherwise False. """ # FIXME: Theoretically there is a race here wherein user resets # password using threepid. @@ -138,11 +140,11 @@ class DeactivateAccountHandler(BaseHandler): return identity_server_supports_unbinding - async def _reject_pending_invites_for_user(self, user_id): + async def _reject_pending_invites_for_user(self, user_id: str): """Reject pending invites addressed to a given user ID. Args: - user_id (str): The user ID to reject pending invites for. + user_id: The user ID to reject pending invites for. """ user = UserID.from_string(user_id) pending_invites = await self.store.get_invited_rooms_for_local_user(user_id) @@ -170,22 +172,16 @@ class DeactivateAccountHandler(BaseHandler): room.room_id, ) - def _start_user_parting(self): + def _start_user_parting(self) -> None: """ Start the process that goes through the table of users pending deactivation, if it isn't already running. - - Returns: - None """ if not self._user_parter_running: run_as_background_process("user_parter_loop", self._user_parter_loop) - async def _user_parter_loop(self): + async def _user_parter_loop(self) -> None: """Loop that parts deactivated users from rooms - - Returns: - None """ self._user_parter_running = True logger.info("Starting user parter") @@ -202,11 +198,8 @@ class DeactivateAccountHandler(BaseHandler): finally: self._user_parter_running = False - async def _part_user(self, user_id): + async def _part_user(self, user_id: str) -> None: """Causes the given user_id to leave all the rooms they're joined to - - Returns: - None """ user = UserID.from_string(user_id) @@ -228,3 +221,18 @@ class DeactivateAccountHandler(BaseHandler): user_id, room_id, ) + + async def activate_account(self, user_id: str) -> None: + """ + Activate an account that was previously deactivated. + + This simply marks the user as activate in the database and does not + attempt to rejoin rooms, re-add threepids, etc. + + The user will also need a password hash set to actually login. + + Args: + user_id: ID of user to be deactivated + """ + # Mark the user as activate. + await self.store.set_user_deactivated_status(user_id, False) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index a7e60cbc26..361dd64cd2 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -77,8 +77,7 @@ class E2eKeysHandler(object): ) @trace - @defer.inlineCallbacks - def query_devices(self, query_body, timeout, from_user_id): + async def query_devices(self, query_body, timeout, from_user_id): """ Handle a device key query from a client { @@ -124,7 +123,7 @@ class E2eKeysHandler(object): failures = {} results = {} if local_query: - local_result = yield self.query_local_devices(local_query) + local_result = await self.query_local_devices(local_query) for user_id, keys in local_result.items(): if user_id in local_query: results[user_id] = keys @@ -142,7 +141,7 @@ class E2eKeysHandler(object): ( user_ids_not_in_cache, remote_results, - ) = yield self.store.get_user_devices_from_cache(query_list) + ) = await self.store.get_user_devices_from_cache(query_list) for user_id, devices in remote_results.items(): user_devices = results.setdefault(user_id, {}) for device_id, device in devices.items(): @@ -161,14 +160,13 @@ class E2eKeysHandler(object): r[user_id] = remote_queries[user_id] # Get cached cross-signing keys - cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + cross_signing_keys = await self.get_cross_signing_keys_from_cache( device_keys_query, from_user_id ) # Now fetch any devices that we don't have in our cache @trace - @defer.inlineCallbacks - def do_remote_query(destination): + async def do_remote_query(destination): """This is called when we are querying the device list of a user on a remote homeserver and their device list is not in the device list cache. If we share a room with this user and we're not querying for @@ -192,7 +190,7 @@ class E2eKeysHandler(object): if device_list: continue - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: continue @@ -201,11 +199,11 @@ class E2eKeysHandler(object): # done an initial sync on the device list so we do it now. try: if self._is_master: - user_devices = yield self.device_handler.device_list_updater.user_device_resync( + user_devices = await self.device_handler.device_list_updater.user_device_resync( user_id ) else: - user_devices = yield self._user_device_resync_client( + user_devices = await self._user_device_resync_client( user_id=user_id ) @@ -227,7 +225,7 @@ class E2eKeysHandler(object): destination_query.pop(user_id) try: - remote_result = yield self.federation.query_client_keys( + remote_result = await self.federation.query_client_keys( destination, {"device_keys": destination_query}, timeout=timeout ) @@ -251,7 +249,7 @@ class E2eKeysHandler(object): set_tag("error", True) set_tag("reason", failure) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(do_remote_query, destination) @@ -267,8 +265,7 @@ class E2eKeysHandler(object): return ret - @defer.inlineCallbacks - def get_cross_signing_keys_from_cache(self, query, from_user_id): + async def get_cross_signing_keys_from_cache(self, query, from_user_id): """Get cross-signing keys for users from the database Args: @@ -289,7 +286,7 @@ class E2eKeysHandler(object): user_ids = list(query) - keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) + keys = await self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) for user_id, user_info in keys.items(): if user_info is None: @@ -315,8 +312,7 @@ class E2eKeysHandler(object): } @trace - @defer.inlineCallbacks - def query_local_devices(self, query): + async def query_local_devices(self, query): """Get E2E device keys for local users Args: @@ -354,7 +350,7 @@ class E2eKeysHandler(object): # make sure that each queried user appears in the result dict result_dict[user_id] = {} - results = yield self.store.get_e2e_device_keys(local_query) + results = await self.store.get_e2e_device_keys(local_query) # Build the result structure for user_id, device_keys in results.items(): @@ -364,16 +360,15 @@ class E2eKeysHandler(object): log_kv(results) return result_dict - @defer.inlineCallbacks - def on_federation_query_client_keys(self, query_body): + async def on_federation_query_client_keys(self, query_body): """ Handle a device key query from a federated server """ device_keys_query = query_body.get("device_keys", {}) - res = yield self.query_local_devices(device_keys_query) + res = await self.query_local_devices(device_keys_query) ret = {"device_keys": res} # add in the cross-signing keys - cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + cross_signing_keys = await self.get_cross_signing_keys_from_cache( device_keys_query, None ) @@ -382,8 +377,7 @@ class E2eKeysHandler(object): return ret @trace - @defer.inlineCallbacks - def claim_one_time_keys(self, query, timeout): + async def claim_one_time_keys(self, query, timeout): local_query = [] remote_queries = {} @@ -399,7 +393,7 @@ class E2eKeysHandler(object): set_tag("local_key_query", local_query) set_tag("remote_key_query", remote_queries) - results = yield self.store.claim_e2e_one_time_keys(local_query) + results = await self.store.claim_e2e_one_time_keys(local_query) json_result = {} failures = {} @@ -411,12 +405,11 @@ class E2eKeysHandler(object): } @trace - @defer.inlineCallbacks - def claim_client_keys(destination): + async def claim_client_keys(destination): set_tag("destination", destination) device_keys = remote_queries[destination] try: - remote_result = yield self.federation.claim_client_keys( + remote_result = await self.federation.claim_client_keys( destination, {"one_time_keys": device_keys}, timeout=timeout ) for user_id, keys in remote_result["one_time_keys"].items(): @@ -429,7 +422,7 @@ class E2eKeysHandler(object): set_tag("error", True) set_tag("reason", failure) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(claim_client_keys, destination) @@ -454,9 +447,8 @@ class E2eKeysHandler(object): log_kv({"one_time_keys": json_result, "failures": failures}) return {"one_time_keys": json_result, "failures": failures} - @defer.inlineCallbacks @tag_args - def upload_keys_for_user(self, user_id, device_id, keys): + async def upload_keys_for_user(self, user_id, device_id, keys): time_now = self.clock.time_msec() @@ -477,12 +469,12 @@ class E2eKeysHandler(object): } ) # TODO: Sign the JSON with the server key - changed = yield self.store.set_e2e_device_keys( + changed = await self.store.set_e2e_device_keys( user_id, device_id, time_now, device_keys ) if changed: # Only notify about device updates *if* the keys actually changed - yield self.device_handler.notify_device_update(user_id, [device_id]) + await self.device_handler.notify_device_update(user_id, [device_id]) else: log_kv({"message": "Not updating device_keys for user", "user_id": user_id}) one_time_keys = keys.get("one_time_keys", None) @@ -494,7 +486,7 @@ class E2eKeysHandler(object): "device_id": device_id, } ) - yield self._upload_one_time_keys_for_user( + await self._upload_one_time_keys_for_user( user_id, device_id, time_now, one_time_keys ) else: @@ -507,15 +499,14 @@ class E2eKeysHandler(object): # old access_token without an associated device_id. Either way, we # need to double-check the device is registered to avoid ending up with # keys without a corresponding device. - yield self.device_handler.check_device_registered(user_id, device_id) + await self.device_handler.check_device_registered(user_id, device_id) - result = yield self.store.count_e2e_one_time_keys(user_id, device_id) + result = await self.store.count_e2e_one_time_keys(user_id, device_id) set_tag("one_time_key_counts", result) return {"one_time_key_counts": result} - @defer.inlineCallbacks - def _upload_one_time_keys_for_user( + async def _upload_one_time_keys_for_user( self, user_id, device_id, time_now, one_time_keys ): logger.info( @@ -533,7 +524,7 @@ class E2eKeysHandler(object): key_list.append((algorithm, key_id, key_obj)) # First we check if we have already persisted any of the keys. - existing_key_map = yield self.store.get_e2e_one_time_keys( + existing_key_map = await self.store.get_e2e_one_time_keys( user_id, device_id, [k_id for _, k_id, _ in key_list] ) @@ -556,10 +547,9 @@ class E2eKeysHandler(object): ) log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) - yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) + await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) - @defer.inlineCallbacks - def upload_signing_keys_for_user(self, user_id, keys): + async def upload_signing_keys_for_user(self, user_id, keys): """Upload signing keys for cross-signing Args: @@ -574,7 +564,7 @@ class E2eKeysHandler(object): _check_cross_signing_key(master_key, user_id, "master") else: - master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") + master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") # if there is no master key, then we can't do anything, because all the # other cross-signing keys need to be signed by the master key @@ -613,10 +603,10 @@ class E2eKeysHandler(object): # if everything checks out, then store the keys and send notifications deviceids = [] if "master_key" in keys: - yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + await self.store.set_e2e_cross_signing_key(user_id, "master", master_key) deviceids.append(master_verify_key.version) if "self_signing_key" in keys: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) try: @@ -626,23 +616,22 @@ class E2eKeysHandler(object): except ValueError: raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM) if "user_signing_key" in keys: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "user_signing", user_signing_key ) # the signature stream matches the semantics that we want for # user-signing key updates: only the user themselves is notified of # their own user-signing key updates - yield self.device_handler.notify_user_signature_update(user_id, [user_id]) + await self.device_handler.notify_user_signature_update(user_id, [user_id]) # master key and self-signing key updates match the semantics of device # list updates: all users who share an encrypted room are notified if len(deviceids): - yield self.device_handler.notify_device_update(user_id, deviceids) + await self.device_handler.notify_device_update(user_id, deviceids) return {} - @defer.inlineCallbacks - def upload_signatures_for_device_keys(self, user_id, signatures): + async def upload_signatures_for_device_keys(self, user_id, signatures): """Upload device signatures for cross-signing Args: @@ -667,13 +656,13 @@ class E2eKeysHandler(object): self_signatures = signatures.get(user_id, {}) other_signatures = {k: v for k, v in signatures.items() if k != user_id} - self_signature_list, self_failures = yield self._process_self_signatures( + self_signature_list, self_failures = await self._process_self_signatures( user_id, self_signatures ) signature_list.extend(self_signature_list) failures.update(self_failures) - other_signature_list, other_failures = yield self._process_other_signatures( + other_signature_list, other_failures = await self._process_other_signatures( user_id, other_signatures ) signature_list.extend(other_signature_list) @@ -681,21 +670,20 @@ class E2eKeysHandler(object): # store the signature, and send the appropriate notifications for sync logger.debug("upload signature failures: %r", failures) - yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list) + await self.store.store_e2e_cross_signing_signatures(user_id, signature_list) self_device_ids = [item.target_device_id for item in self_signature_list] if self_device_ids: - yield self.device_handler.notify_device_update(user_id, self_device_ids) + await self.device_handler.notify_device_update(user_id, self_device_ids) signed_users = [item.target_user_id for item in other_signature_list] if signed_users: - yield self.device_handler.notify_user_signature_update( + await self.device_handler.notify_user_signature_update( user_id, signed_users ) return {"failures": failures} - @defer.inlineCallbacks - def _process_self_signatures(self, user_id, signatures): + async def _process_self_signatures(self, user_id, signatures): """Process uploaded signatures of the user's own keys. Signatures of the user's own keys from this API come in two forms: @@ -728,7 +716,7 @@ class E2eKeysHandler(object): _, self_signing_key_id, self_signing_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "self_signing") # get our master key, since we may have received a signature of it. # We need to fetch it here so that we know what its key ID is, so @@ -738,12 +726,12 @@ class E2eKeysHandler(object): master_key, _, master_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "master") # fetch our stored devices. This is used to 1. verify # signatures on the master key, and 2. to compare with what # was sent if the device was signed - devices = yield self.store.get_e2e_device_keys([(user_id, None)]) + devices = await self.store.get_e2e_device_keys([(user_id, None)]) if user_id not in devices: raise NotFoundError("No device keys found") @@ -853,8 +841,7 @@ class E2eKeysHandler(object): return master_key_signature_list - @defer.inlineCallbacks - def _process_other_signatures(self, user_id, signatures): + async def _process_other_signatures(self, user_id, signatures): """Process uploaded signatures of other users' keys. These will be the target user's master keys, signed by the uploading user's user-signing key. @@ -882,7 +869,7 @@ class E2eKeysHandler(object): user_signing_key, user_signing_key_id, user_signing_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "user_signing") except SynapseError as e: failure = _exception_to_failure(e) for user, devicemap in signatures.items(): @@ -905,7 +892,7 @@ class E2eKeysHandler(object): master_key, master_key_id, _, - ) = yield self._get_e2e_cross_signing_verify_key( + ) = await self._get_e2e_cross_signing_verify_key( target_user, "master", user_id ) @@ -958,8 +945,7 @@ class E2eKeysHandler(object): return signature_list, failures - @defer.inlineCallbacks - def _get_e2e_cross_signing_verify_key( + async def _get_e2e_cross_signing_verify_key( self, user_id: str, key_type: str, from_user_id: str = None ): """Fetch locally or remotely query for a cross-signing public key. @@ -983,7 +969,7 @@ class E2eKeysHandler(object): SynapseError: if `user_id` is invalid """ user = UserID.from_string(user_id) - key = yield self.store.get_e2e_cross_signing_key( + key = await self.store.get_e2e_cross_signing_key( user_id, key_type, from_user_id ) @@ -1009,15 +995,14 @@ class E2eKeysHandler(object): key, key_id, verify_key, - ) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type) + ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type) if key is None: raise NotFoundError("No %s key found for %s" % (key_type, user_id)) return key, key_id, verify_key - @defer.inlineCallbacks - def _retrieve_cross_signing_keys_for_remote_user( + async def _retrieve_cross_signing_keys_for_remote_user( self, user: UserID, desired_key_type: str, ): """Queries cross-signing keys for a remote user and saves them to the database @@ -1035,7 +1020,7 @@ class E2eKeysHandler(object): If the key cannot be retrieved, all values in the tuple will instead be None. """ try: - remote_result = yield self.federation.query_user_devices( + remote_result = await self.federation.query_user_devices( user.domain, user.to_string() ) except Exception as e: @@ -1101,14 +1086,14 @@ class E2eKeysHandler(object): desired_key_id = key_id # At the same time, store this key in the db for subsequent queries - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user.to_string(), key_type, key_content ) # Notify clients that new devices for this user have been discovered if retrieved_device_ids: # XXX is this necessary? - yield self.device_handler.notify_device_update( + await self.device_handler.notify_device_update( user.to_string(), retrieved_device_ids ) @@ -1250,8 +1235,7 @@ class SigningKeyEduUpdater(object): iterable=True, ) - @defer.inlineCallbacks - def incoming_signing_key_update(self, origin, edu_content): + async def incoming_signing_key_update(self, origin, edu_content): """Called on incoming signing key update from federation. Responsible for parsing the EDU and adding to pending updates list. @@ -1268,7 +1252,7 @@ class SigningKeyEduUpdater(object): logger.warning("Got signing key update edu for %r from %r", user_id, origin) return - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We don't share any rooms with this user. Ignore update, as we # probably won't get any further updates. @@ -1278,10 +1262,9 @@ class SigningKeyEduUpdater(object): (master_key, self_signing_key) ) - yield self._handle_signing_key_updates(user_id) + await self._handle_signing_key_updates(user_id) - @defer.inlineCallbacks - def _handle_signing_key_updates(self, user_id): + async def _handle_signing_key_updates(self, user_id): """Actually handle pending updates. Args: @@ -1291,7 +1274,7 @@ class SigningKeyEduUpdater(object): device_handler = self.e2e_keys_handler.device_handler device_list_updater = device_handler.device_list_updater - with (yield self._remote_edu_linearizer.queue(user_id)): + with (await self._remote_edu_linearizer.queue(user_id)): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: # This can happen since we batch updates @@ -1302,9 +1285,9 @@ class SigningKeyEduUpdater(object): logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = yield device_list_updater.process_cross_signing_key_update( + new_device_ids = await device_list_updater.process_cross_signing_key_update( user_id, master_key, self_signing_key, ) device_ids = device_ids + new_device_ids - yield device_handler.notify_device_update(user_id, device_ids) + await device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index f55470a707..0bb983dc28 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import ( Codes, NotFoundError, @@ -50,8 +48,7 @@ class E2eRoomKeysHandler(object): self._upload_linearizer = Linearizer("upload_room_keys_lock") @trace - @defer.inlineCallbacks - def get_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. See EndToEndRoomKeyStore.get_e2e_room_keys for full details. @@ -71,17 +68,17 @@ class E2eRoomKeysHandler(object): # we deliberately take the lock to get keys so that changing the version # works atomically - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # make sure the backup version exists try: - yield self.store.get_e2e_room_keys_version_info(user_id, version) + await self.store.get_e2e_room_keys_version_info(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") else: raise - results = yield self.store.get_e2e_room_keys( + results = await self.store.get_e2e_room_keys( user_id, version, room_id, session_id ) @@ -89,8 +86,7 @@ class E2eRoomKeysHandler(object): return results @trace - @defer.inlineCallbacks - def delete_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. @@ -109,10 +105,10 @@ class E2eRoomKeysHandler(object): """ # lock for consistency with uploading - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # make sure the backup version exists try: - version_info = yield self.store.get_e2e_room_keys_version_info( + version_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) except StoreError as e: @@ -121,19 +117,18 @@ class E2eRoomKeysHandler(object): else: raise - yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + await self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) version_etag = version_info["etag"] + 1 - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, None, version_etag ) - count = yield self.store.count_e2e_room_keys(user_id, version) + count = await self.store.count_e2e_room_keys(user_id, version) return {"etag": str(version_etag), "count": count} @trace - @defer.inlineCallbacks - def upload_room_keys(self, user_id, version, room_keys): + async def upload_room_keys(self, user_id, version, room_keys): """Bulk upload a list of room keys into a given backup version, asserting that the given version is the current backup version. room_keys are merged into the current backup as described in RoomKeysServlet.on_PUT(). @@ -169,11 +164,11 @@ class E2eRoomKeysHandler(object): # TODO: Validate the JSON to make sure it has the right keys. # XXX: perhaps we should use a finer grained lock here? - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # Check that the version we're trying to upload is the current version try: - version_info = yield self.store.get_e2e_room_keys_version_info(user_id) + version_info = await self.store.get_e2e_room_keys_version_info(user_id) except StoreError as e: if e.code == 404: raise NotFoundError("Version '%s' not found" % (version,)) @@ -183,7 +178,7 @@ class E2eRoomKeysHandler(object): if version_info["version"] != version: # Check that the version we're trying to upload actually exists try: - version_info = yield self.store.get_e2e_room_keys_version_info( + version_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) # if we get this far, the version must exist @@ -198,7 +193,7 @@ class E2eRoomKeysHandler(object): # submitted. Then compare them with the submitted keys. If the # key is new, insert it; if the key should be updated, then update # it; otherwise, drop it. - existing_keys = yield self.store.get_e2e_room_keys_multi( + existing_keys = await self.store.get_e2e_room_keys_multi( user_id, version, room_keys["rooms"] ) to_insert = [] # batch the inserts together @@ -227,7 +222,7 @@ class E2eRoomKeysHandler(object): # updates are done one at a time in the DB, so send # updates right away rather than batching them up, # like we do with the inserts - yield self.store.update_e2e_room_key( + await self.store.update_e2e_room_key( user_id, version, room_id, session_id, room_key ) changed = True @@ -246,16 +241,16 @@ class E2eRoomKeysHandler(object): changed = True if len(to_insert): - yield self.store.add_e2e_room_keys(user_id, version, to_insert) + await self.store.add_e2e_room_keys(user_id, version, to_insert) version_etag = version_info["etag"] if changed: version_etag = version_etag + 1 - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, None, version_etag ) - count = yield self.store.count_e2e_room_keys(user_id, version) + count = await self.store.count_e2e_room_keys(user_id, version) return {"etag": str(version_etag), "count": count} @staticmethod @@ -291,8 +286,7 @@ class E2eRoomKeysHandler(object): return True @trace - @defer.inlineCallbacks - def create_version(self, user_id, version_info): + async def create_version(self, user_id, version_info): """Create a new backup version. This automatically becomes the new backup version for the user's keys; previous backups will no longer be writeable to. @@ -313,14 +307,13 @@ class E2eRoomKeysHandler(object): # TODO: Validate the JSON to make sure it has the right keys. # lock everyone out until we've switched version - with (yield self._upload_linearizer.queue(user_id)): - new_version = yield self.store.create_e2e_room_keys_version( + with (await self._upload_linearizer.queue(user_id)): + new_version = await self.store.create_e2e_room_keys_version( user_id, version_info ) return new_version - @defer.inlineCallbacks - def get_version_info(self, user_id, version=None): + async def get_version_info(self, user_id, version=None): """Get the info about a given version of the user's backup Args: @@ -339,22 +332,21 @@ class E2eRoomKeysHandler(object): } """ - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - res = yield self.store.get_e2e_room_keys_version_info(user_id, version) + res = await self.store.get_e2e_room_keys_version_info(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") else: raise - res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) + res["count"] = await self.store.count_e2e_room_keys(user_id, res["version"]) res["etag"] = str(res["etag"]) return res @trace - @defer.inlineCallbacks - def delete_version(self, user_id, version=None): + async def delete_version(self, user_id, version=None): """Deletes a given version of the user's e2e_room_keys backup Args: @@ -364,9 +356,9 @@ class E2eRoomKeysHandler(object): NotFoundError: if this backup version doesn't exist """ - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - yield self.store.delete_e2e_room_keys_version(user_id, version) + await self.store.delete_e2e_room_keys_version(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") @@ -374,8 +366,7 @@ class E2eRoomKeysHandler(object): raise @trace - @defer.inlineCallbacks - def update_version(self, user_id, version, version_info): + async def update_version(self, user_id, version, version_info): """Update the info about a given version of the user's backup Args: @@ -393,9 +384,9 @@ class E2eRoomKeysHandler(object): raise SynapseError( 400, "Version in body does not match", Codes.INVALID_PARAM ) - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - old_info = yield self.store.get_e2e_room_keys_version_info( + old_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) except StoreError as e: @@ -406,7 +397,7 @@ class E2eRoomKeysHandler(object): if old_info["algorithm"] != version_info["algorithm"]: raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM) - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, version_info ) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 879c4c07c6..846ddbdc6c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -185,7 +185,7 @@ class TypingHandler(object): async def _push_remote(self, member, typing): try: - users = await self.state.get_current_users_in_room(member.room_id) + users = await self.store.get_users_in_room(member.room_id) self._member_last_federation_poke[member] = self.clock.time_msec() now = self.clock.time_msec() @@ -224,7 +224,7 @@ class TypingHandler(object): ) return - users = await self.state.get_current_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) domains = {get_domain_from_id(u) for u in users} if self.server_name in domains: diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 8b24a73319..a140e9391e 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging from canonicaljson import json @@ -117,7 +118,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): except PartialDownloadError as pde: # Twisted is silly data = pde.response - resp_body = json.loads(data) + resp_body = json.loads(data.decode("utf-8")) if "success" in resp_body: # Note that we do NOT check the hostname here: we explicitly diff --git a/synapse/http/client.py b/synapse/http/client.py index 505872ee90..6bc51202cd 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -13,13 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json + import logging import urllib from io import BytesIO import treq -from canonicaljson import encode_canonical_json +from canonicaljson import encode_canonical_json, json from netaddr import IPAddress from prometheus_client import Counter from zope.interface import implementer, provider @@ -31,6 +31,7 @@ from twisted.internet.interfaces import ( IReactorPluggableNameResolver, IResolutionReceiver, ) +from twisted.internet.task import Cooperator from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone from twisted.web.client import Agent, HTTPConnectionPool, readBody @@ -69,6 +70,21 @@ def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist): return False +_EPSILON = 0.00000001 + + +def _make_scheduler(reactor): + """Makes a schedular suitable for a Cooperator using the given reactor. + + (This is effectively just a copy from `twisted.internet.task`) + """ + + def _scheduler(x): + return reactor.callLater(_EPSILON, x) + + return _scheduler + + class IPBlacklistingResolver(object): """ A proxy for reactor.nameResolver which only produces non-blacklisted IP @@ -212,6 +228,10 @@ class SimpleHttpClient(object): if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) + # We use this for our body producers to ensure that they use the correct + # reactor. + self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor())) + self.user_agent = self.user_agent.encode("ascii") if self._ip_blacklist: @@ -292,7 +312,9 @@ class SimpleHttpClient(object): try: body_producer = None if data is not None: - body_producer = QuieterFileBodyProducer(BytesIO(data)) + body_producer = QuieterFileBodyProducer( + BytesIO(data), cooperator=self._cooperator, + ) request_deferred = treq.request( method, diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 3cabe9d02e..a34e5ead88 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -14,9 +14,11 @@ # limitations under the License. """ This module contains base REST classes for constructing REST servlets. """ -import json + import logging +from canonicaljson import json + from synapse.api.errors import Codes, SynapseError logger = logging.getLogger(__name__) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index e4330c39d6..cc0bdfa5c9 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -239,6 +239,15 @@ class UserRestServletV2(RestServlet): await self.deactivate_account_handler.deactivate_account( target_user.to_string(), False ) + elif not deactivate and user["deactivated"]: + if "password" not in body: + raise SynapseError( + 400, "Must provide a password to re-activate an account." + ) + + await self.deactivate_account_handler.activate_account( + target_user.to_string() + ) user = await self.admin_handler.get_user(target_user) return 200, user @@ -254,7 +263,6 @@ class UserRestServletV2(RestServlet): admin = body.get("admin", None) user_type = body.get("user_type", None) displayname = body.get("displayname", None) - threepids = body.get("threepids", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: raise SynapseError(400, "Invalid user type") diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 326ffa0056..379f668d6f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -371,7 +371,7 @@ class LoginRestServlet(RestServlet): token = login_submission.get("token", None) if token is None: raise LoginError( - 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED + 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN ) import jwt @@ -387,14 +387,12 @@ class LoginRestServlet(RestServlet): except jwt.PyJWTError as e: # A JWT error occurred, return some info back to the client. raise LoginError( - 401, - "JWT validation failed: %s" % (str(e),), - errcode=Codes.UNAUTHORIZED, + 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN, ) user = payload.get("sub", None) if user is None: - raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) + raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) user_id = UserID(user, self.hs.hostname).to_string() result = await self._complete_login( diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 01b80b86fa..c5a84af047 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -15,6 +15,7 @@ # limitations under the License. """ This module contains REST servlets to do with rooms: /rooms/<paths> """ + import logging import re from typing import List, Optional @@ -515,9 +516,9 @@ class RoomMessageListRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request, default_limit=10) as_client_event = b"raw" not in request.args - filter_bytes = parse_string(request, b"filter", encoding=None) - if filter_bytes: - filter_json = urlparse.unquote(filter_bytes.decode("UTF-8")) + filter_str = parse_string(request, b"filter", encoding="utf-8") + if filter_str: + filter_json = urlparse.unquote(filter_str) event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] if ( event_filter @@ -627,9 +628,9 @@ class RoomEventContextServlet(RestServlet): limit = parse_integer(request, "limit", default=10) # picking the API shape for symmetry with /messages - filter_bytes = parse_string(request, "filter") - if filter_bytes: - filter_json = urlparse.unquote(filter_bytes) + filter_str = parse_string(request, b"filter", encoding="utf-8") + if filter_str: + filter_json = urlparse.unquote(filter_str) event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] else: event_filter = None diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index e149ac1733..9b3f85b306 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -202,9 +202,11 @@ class RemoteKey(DirectServeJsonResource): if miss: cache_misses.setdefault(server_name, set()).add(key_id) + # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(most_recent_result["key_json"])) else: for ts_added, result in results: + # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(result["key_json"])) if cache_misses and query_remote_on_cache_miss: @@ -213,7 +215,7 @@ class RemoteKey(DirectServeJsonResource): else: signed_keys = [] for key_json in json_results: - key_json = json.loads(key_json) + key_json = json.loads(key_json.decode("utf-8")) for signing_key in self.config.key_server_signing_keys: key_json = sign_json(key_json, self.config.server_name, signing_key) diff --git a/synapse/server.py b/synapse/server.py index ca42c2195a..f838a03d71 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -106,7 +106,7 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import DataStores, Storage +from synapse.storage import DataStore, DataStores, Storage from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -314,7 +314,7 @@ class HomeServer(object): def get_clock(self): return self.clock - def get_datastore(self): + def get_datastore(self) -> DataStore: return self.datastores.main def get_datastores(self): diff --git a/synapse/server.pyi b/synapse/server.pyi index 58cd099e6d..cd50c721b8 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -20,6 +20,7 @@ import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password import synapse.http.client +import synapse.http.matrixfederationclient import synapse.notifier import synapse.push.pusherpool import synapse.replication.tcp.client @@ -143,3 +144,7 @@ class HomeServer(object): pass def get_replication_streams(self) -> Dict[str, Stream]: pass + def get_http_client( + self, + ) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient: + pass diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 230fb5cd7f..66f01aad84 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -17,7 +17,6 @@ import itertools import logging from collections import OrderedDict, namedtuple -from functools import wraps from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple import attr @@ -69,27 +68,6 @@ def encode_json(json_object): _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) -def _retry_on_integrity_error(func): - """Wraps a database function so that it gets retried on IntegrityError, - with `delete_existing=True` passed in. - - Args: - func: function that returns a Deferred and accepts a `delete_existing` arg - """ - - @wraps(func) - @defer.inlineCallbacks - def f(self, *args, **kwargs): - try: - res = yield func(self, *args, delete_existing=False, **kwargs) - except self.database_engine.module.IntegrityError: - logger.exception("IntegrityError, retrying.") - res = yield func(self, *args, delete_existing=True, **kwargs) - return res - - return f - - @attr.s(slots=True) class DeltaState: """Deltas to use to update the `current_state_events` table. @@ -134,7 +112,6 @@ class PersistEventsStore: hs.config.worker.writers.events == hs.get_instance_name() ), "Can only instantiate EventsStore on master" - @_retry_on_integrity_error @defer.inlineCallbacks def _persist_events_and_state_updates( self, @@ -143,7 +120,6 @@ class PersistEventsStore: state_delta_for_room: Dict[str, DeltaState], new_forward_extremeties: Dict[str, List[str]], backfilled: bool = False, - delete_existing: bool = False, ): """Persist a set of events alongside updates to the current state and forward extremities tables. @@ -157,7 +133,6 @@ class PersistEventsStore: new_forward_extremities: Map from room_id to list of event IDs that are the new forward extremities of the room. backfilled - delete_existing Returns: Deferred: resolves when the events have been persisted @@ -197,7 +172,6 @@ class PersistEventsStore: self._persist_events_txn, events_and_contexts=events_and_contexts, backfilled=backfilled, - delete_existing=delete_existing, state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, ) @@ -341,7 +315,6 @@ class PersistEventsStore: txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool, - delete_existing: bool = False, state_delta_for_room: Dict[str, DeltaState] = {}, new_forward_extremeties: Dict[str, List[str]] = {}, ): @@ -393,13 +366,6 @@ class PersistEventsStore: # From this point onwards the events are only events that we haven't # seen before. - if delete_existing: - # For paranoia reasons, we go and delete all the existing entries - # for these events so we can reinsert them. - # This gets around any problems with some tables already having - # entries. - self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts) - self._store_event_txn(txn, events_and_contexts=events_and_contexts) # Insert into event_to_state_groups. @@ -797,39 +763,6 @@ class PersistEventsStore: return [ec for ec in events_and_contexts if ec[0] not in to_remove] - @classmethod - def _delete_existing_rows_txn(cls, txn, events_and_contexts): - if not events_and_contexts: - # nothing to do here - return - - logger.info("Deleting existing") - - for table in ( - "events", - "event_auth", - "event_json", - "event_edges", - "event_forward_extremities", - "event_reference_hashes", - "event_search", - "event_to_state_groups", - "state_events", - "rejections", - "redactions", - "room_memberships", - ): - txn.executemany( - "DELETE FROM %s WHERE event_id = ?" % (table,), - [(ev.event_id,) for ev, _ in events_and_contexts], - ) - - for table in ("event_push_actions",): - txn.executemany( - "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,), - [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts], - ) - def _store_event_txn(self, txn, events_and_contexts): """Insert new events into the event and event_json tables diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 347cc50778..bb38a04ede 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -353,6 +353,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): last_room_id = progress.get("last_room_id", "") def _background_remove_left_rooms_txn(txn): + # get a batch of room ids to consider sql = """ SELECT DISTINCT room_id FROM current_state_events WHERE room_id > ? ORDER BY room_id LIMIT ? @@ -363,24 +364,68 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): if not room_ids: return True, set() + ########################################################################### + # + # exclude rooms where we have active members + sql = """ SELECT room_id - FROM current_state_events + FROM local_current_membership WHERE room_id > ? AND room_id <= ? - AND type = 'm.room.member' AND membership = 'join' - AND state_key LIKE ? GROUP BY room_id """ - txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name)) - + txn.execute(sql, (last_room_id, room_ids[-1])) joined_room_ids = {row[0] for row in txn} + to_delete = set(room_ids) - joined_room_ids + + ########################################################################### + # + # exclude rooms which we are in the process of constructing; these otherwise + # qualify as "rooms with no local users", and would have their + # forward extremities cleaned up. + + # the following query will return a list of rooms which have forward + # extremities that are *not* also the create event in the room - ie + # those that are not being created currently. + + sql = """ + SELECT DISTINCT efe.room_id + FROM event_forward_extremities efe + LEFT JOIN current_state_events cse ON + cse.event_id = efe.event_id + AND cse.type = 'm.room.create' + AND cse.state_key = '' + WHERE + cse.event_id IS NULL + AND efe.room_id > ? AND efe.room_id <= ? + """ + + txn.execute(sql, (last_room_id, room_ids[-1])) + + # build a set of those rooms within `to_delete` that do not appear in + # the above, leaving us with the rooms in `to_delete` that *are* being + # created. + creating_rooms = to_delete.difference(row[0] for row in txn) + logger.info("skipping rooms which are being created: %s", creating_rooms) + + # now remove the rooms being created from the list of those to delete. + # + # (we could have just taken the intersection of `to_delete` with the result + # of the sql query, but it's useful to be able to log `creating_rooms`; and + # having done so, it's quicker to remove the (few) creating rooms from + # `to_delete` than it is to form the intersection with the (larger) list of + # not-creating-rooms) + + to_delete -= creating_rooms - left_rooms = set(room_ids) - joined_room_ids + ########################################################################### + # + # now clear the state for the rooms - logger.info("Deleting current state left rooms: %r", left_rooms) + logger.info("Deleting current state left rooms: %r", to_delete) # First we get all users that we still think were joined to the # room. This is so that we can mark those device lists as @@ -391,7 +436,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): txn, table="current_state_events", column="room_id", - iterable=left_rooms, + iterable=to_delete, keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN}, retcols=("state_key",), ) @@ -403,7 +448,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): txn, table="current_state_events", column="room_id", - iterable=left_rooms, + iterable=to_delete, keyvalues={}, ) @@ -411,7 +456,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): txn, table="event_forward_extremities", column="room_id", - iterable=left_rooms, + iterable=to_delete, keyvalues={}, ) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 1acf287ca4..cdd093ffa8 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -46,7 +46,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): """If the user has no devices, we expect an empty list. """ local_user = "@boris:" + self.hs.hostname - res = yield self.handler.query_local_devices({local_user: None}) + res = yield defer.ensureDeferred( + self.handler.query_local_devices({local_user: None}) + ) self.assertDictEqual(res, {local_user: {}}) @defer.inlineCallbacks @@ -60,15 +62,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) # we should be able to change the signature without a problem keys["alg2:k2"]["signatures"]["k1"] = "sig2" - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) @@ -84,44 +90,56 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} + ) ) self.fail("No error when changing string key") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} + ) ) self.fail("No error when replacing dict key with string") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"one_time_keys": {"alg1:k1": {"key": "key"}}}, + ) ) self.fail("No error when replacing string key with dict") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, - device_id, - { - "one_time_keys": { - "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} - } - }, + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, + device_id, + { + "one_time_keys": { + "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} + } + }, + ) ) self.fail("No error when replacing dict key") except errors.SynapseError: @@ -133,13 +151,17 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_id = "xyz" keys = {"alg1:k1": "key1"} - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}}) - res2 = yield self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + res2 = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) ) self.assertEqual( res2, @@ -163,7 +185,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) keys2 = { "master_key": { @@ -175,10 +199,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys2) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys2) + ) - devices = yield self.handler.query_devices( - {"device_keys": {local_user: []}}, 0, local_user + devices = yield defer.ensureDeferred( + self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) @@ -215,7 +241,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0", ) - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) # upload two device keys, which will be signed later by the self-signing key device_key_1 = { @@ -245,18 +273,24 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "signatures": {local_user: {"ed25519:def": "base64+signature"}}, } - yield self.handler.upload_keys_for_user( - local_user, "abc", {"device_keys": device_key_1} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, "abc", {"device_keys": device_key_1} + ) ) - yield self.handler.upload_keys_for_user( - local_user, "def", {"device_keys": device_key_2} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, "def", {"device_keys": device_key_2} + ) ) # sign the first device key and upload it del device_key_1["signatures"] sign.sign_json(device_key_1, local_user, signing_key) - yield self.handler.upload_signatures_for_device_keys( - local_user, {local_user: {"abc": device_key_1}} + yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1}} + ) ) # sign the second device key and upload both device keys. The server @@ -264,14 +298,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # signature for it del device_key_2["signatures"] sign.sign_json(device_key_2, local_user, signing_key) - yield self.handler.upload_signatures_for_device_keys( - local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + ) ) device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" - devices = yield self.handler.query_devices( - {"device_keys": {local_user: []}}, 0, local_user + devices = yield defer.ensureDeferred( + self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) del devices["device_keys"][local_user]["abc"]["unsigned"] del devices["device_keys"][local_user]["def"]["unsigned"] @@ -292,7 +328,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) res = None try: @@ -305,7 +343,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): res = e.code self.assertEqual(res, 400) - res = yield self.handler.query_local_devices({local_user: None}) + res = yield defer.ensureDeferred( + self.handler.query_local_devices({local_user: None}) + ) self.assertDictEqual(res, {local_user: {}}) @defer.inlineCallbacks @@ -331,8 +371,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" ) - yield self.handler.upload_keys_for_user( - local_user, device_id, {"device_keys": device_key} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"device_keys": device_key} + ) ) # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 @@ -372,7 +414,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_signing_key": usersigning_key, "self_signing_key": selfsigning_key, } - yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + ) # set up another user with a master key. This user will be signed by # the first user @@ -384,76 +428,90 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "usage": ["master"], "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, } - yield self.handler.upload_signing_keys_for_user( - other_user, {"master_key": other_master_key} + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user( + other_user, {"master_key": other_master_key} + ) ) # test various signature failures (see below) - ret = yield self.handler.upload_signatures_for_device_keys( - local_user, - { - local_user: { - # fails because the signature is invalid - # should fail with INVALID_SIGNATURE - device_id: { - "user_id": local_user, - "device_id": device_id, - "algorithms": [ - "m.olm.curve25519-aes-sha2", - RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, - ], - "keys": { - "curve25519:xyz": "curve25519+key", - # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA - "ed25519:xyz": device_pubkey, - }, - "signatures": { - local_user: {"ed25519:" + selfsigning_pubkey: "something"} + ret = yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: { + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + device_id: { + "user_id": local_user, + "device_id": device_id, + "algorithms": [ + "m.olm.curve25519-aes-sha2", + RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, + ], + "keys": { + "curve25519:xyz": "curve25519+key", + # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA + "ed25519:xyz": device_pubkey, + }, + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something" + } + }, }, - }, - # fails because device is unknown - # should fail with NOT_FOUND - "unknown": { - "user_id": local_user, - "device_id": "unknown", - "signatures": { - local_user: {"ed25519:" + selfsigning_pubkey: "something"} + # fails because device is unknown + # should fail with NOT_FOUND + "unknown": { + "user_id": local_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something" + } + }, }, - }, - # fails because the signature is invalid - # should fail with INVALID_SIGNATURE - master_pubkey: { - "user_id": local_user, - "usage": ["master"], - "keys": {"ed25519:" + master_pubkey: master_pubkey}, - "signatures": { - local_user: {"ed25519:" + device_pubkey: "something"} + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + master_pubkey: { + "user_id": local_user, + "usage": ["master"], + "keys": {"ed25519:" + master_pubkey: master_pubkey}, + "signatures": { + local_user: {"ed25519:" + device_pubkey: "something"} + }, }, }, - }, - other_user: { - # fails because the device is not the user's master-signing key - # should fail with NOT_FOUND - "unknown": { - "user_id": other_user, - "device_id": "unknown", - "signatures": { - local_user: {"ed25519:" + usersigning_pubkey: "something"} + other_user: { + # fails because the device is not the user's master-signing key + # should fail with NOT_FOUND + "unknown": { + "user_id": other_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something" + } + }, }, - }, - other_master_pubkey: { - # fails because the key doesn't match what the server has - # should fail with UNKNOWN - "user_id": other_user, - "usage": ["master"], - "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, - "something": "random", - "signatures": { - local_user: {"ed25519:" + usersigning_pubkey: "something"} + other_master_pubkey: { + # fails because the key doesn't match what the server has + # should fail with UNKNOWN + "user_id": other_user, + "usage": ["master"], + "keys": { + "ed25519:" + other_master_pubkey: other_master_pubkey + }, + "something": "random", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something" + } + }, }, }, }, - }, + ) ) user_failures = ret["failures"][local_user] @@ -478,19 +536,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase): sign.sign_json(device_key, local_user, selfsigning_signing_key) sign.sign_json(master_key, local_user, device_signing_key) sign.sign_json(other_master_key, local_user, usersigning_signing_key) - ret = yield self.handler.upload_signatures_for_device_keys( - local_user, - { - local_user: {device_id: device_key, master_pubkey: master_key}, - other_user: {other_master_pubkey: other_master_key}, - }, + ret = yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: {device_id: device_key, master_pubkey: master_key}, + other_user: {other_master_pubkey: other_master_key}, + }, + ) ) self.assertEqual(ret["failures"], {}) # fetch the signed keys/devices and make sure that the signatures are there - ret = yield self.handler.query_devices( - {"device_keys": {local_user: [], other_user: []}}, 0, local_user + ret = yield defer.ensureDeferred( + self.handler.query_devices( + {"device_keys": {local_user: [], other_user: []}}, 0, local_user + ) ) self.assertEqual( diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 822ea42dde..3362050ce0 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -66,7 +66,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_version_info(self.local_user) + yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -78,7 +78,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_version_info(self.local_user, "bogus_version") + yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "bogus_version") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -87,14 +89,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_create_version(self): """Check that we can create and then retrieve versions. """ - res = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(res, "1") # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) version_etag = res["etag"] self.assertIsInstance(version_etag, str) del res["etag"] @@ -109,7 +116,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # check we can retrieve it as a specific version - res = yield self.handler.get_version_info(self.local_user, "1") + res = yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "1") + ) self.assertEqual(res["etag"], version_etag) del res["etag"] self.assertDictEqual( @@ -123,17 +132,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # upload a new one... - res = yield self.handler.create_version( - self.local_user, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) ) self.assertEqual(res, "2") # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -149,25 +160,32 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_version(self): """Check that we can update versions. """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - res = yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": version, - }, + res = yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + }, + ) ) self.assertDictEqual(res, {}) # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -185,14 +203,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.update_version( - self.local_user, - "1", - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "1", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + "1", + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "1", + }, + ) ) except errors.SynapseError as e: res = e.code @@ -202,23 +222,30 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_omitted_version(self): """Check that the update succeeds if the version is missing from the body """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + }, + ) ) # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] # etag is opaque, so don't test its contents self.assertDictEqual( res, @@ -234,22 +261,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_bad_version(self): """Check that we get a 400 if the version in the body doesn't match """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") res = None try: - yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "incorrect", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "incorrect", + }, + ) ) except errors.SynapseError as e: res = e.code @@ -261,7 +295,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.delete_version(self.local_user, "1") + yield defer.ensureDeferred( + self.handler.delete_version(self.local_user, "1") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -272,7 +308,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.delete_version(self.local_user) + yield defer.ensureDeferred(self.handler.delete_version(self.local_user)) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -281,19 +317,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_version(self): """Check that we can create and then delete versions. """ - res = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(res, "1") # check we can delete it - yield self.handler.delete_version(self.local_user, "1") + yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1")) # check that it's gone res = None try: - yield self.handler.get_version_info(self.local_user, "1") + yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "1") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -304,7 +347,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_room_keys(self.local_user, "bogus_version") + yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, "bogus_version") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -313,13 +358,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_get_missing_room_keys(self): """Check we get an empty response from an empty backup """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertDictEqual(res, {"rooms": {}}) # TODO: test the locking semantics when uploading room_keys, @@ -331,8 +383,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.upload_room_keys( - self.local_user, "no_version", room_keys + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, "no_version", room_keys) ) except errors.SynapseError as e: res = e.code @@ -343,16 +395,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """Check that we get a 404 on uploading keys when an nonexistent version is specified """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") res = None try: - yield self.handler.upload_room_keys( - self.local_user, "bogus_version", room_keys + yield defer.ensureDeferred( + self.handler.upload_room_keys( + self.local_user, "bogus_version", room_keys + ) ) except errors.SynapseError as e: res = e.code @@ -362,24 +421,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_wrong_version(self): """Check that we get a 403 on uploading keys for an old version """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - version = yield self.handler.create_version( - self.local_user, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) ) self.assertEqual(version, "2") res = None try: - yield self.handler.upload_room_keys(self.local_user, "1", room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, "1", room_keys) + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 403) @@ -388,26 +456,39 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_insert(self): """Check that we can insert and retrieve keys for a session """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given room - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org" + ) ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given session_id - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, room_keys) @@ -415,16 +496,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_merge(self): """Check that we can upload a new room_key for an existing session and have it correctly merged""" - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) # get the etag to compare to future versions - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) backup_etag = res["etag"] self.assertEqual(res["count"], 1) @@ -434,29 +522,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # test that increasing the message_index doesn't replace the existing session new_room_key["first_message_index"] = 2 new_room_key["session_data"] = "new" - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "SSBBTSBBIEZJU0gK", ) # the etag should be the same since the session did not change - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should NOT be equal now, since the key changed - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertNotEqual(res["etag"], backup_etag) backup_etag = res["etag"] @@ -464,15 +560,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # with a lower forwarding count new_room_key["forwarded_count"] = 2 new_room_key["session_data"] = "other" - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should be the same since the session did not change - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # TODO: check edge cases as well as the common variations here @@ -481,36 +581,59 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_room_keys(self): """Check that we can insert and delete keys for a session """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") # check for bulk-delete - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys(self.local_user, version) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys(self.local_user, version) + ) + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per room - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys( - self.local_user, version, room_id="!abc:matrix.org" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys( + self.local_user, version, room_id="!abc:matrix.org" + ) ) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per session - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 1e6a53bf7f..5878f74175 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -138,10 +138,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room - def get_current_users_in_room(room_id): + def get_users_in_room(room_id): return defer.succeed({str(u) for u in self.room_members}) - hs.get_state_handler().get_current_users_in_room = get_current_users_in_room + self.datastore.get_users_in_room = get_users_in_room self.datastore.get_user_directory_stream_pos.return_value = ( # we deliberately return a non-None stream pos to avoid doing an initial_spam diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 9d4f0bbe44..06575ba0a6 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple import attr @@ -26,8 +26,9 @@ from synapse.app.generic_worker import ( GenericWorkerReplicationHandler, GenericWorkerServer, ) +from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest -from synapse.replication.http import streams +from synapse.replication.http import ReplicationRestResource, streams from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -35,7 +36,7 @@ from synapse.server import HomeServer from synapse.util import Clock from tests import unittest -from tests.server import FakeTransport +from tests.server import FakeTransport, render logger = logging.getLogger(__name__) @@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.assertEqual(request.method, b"GET") +class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): + """Base class for tests running multiple workers. + + Automatically handle HTTP replication requests from workers to master, + unlike `BaseStreamTestCase`. + """ + + servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]] + + def setUp(self): + super().setUp() + + # build a replication server + self.server_factory = ReplicationStreamProtocolFactory(self.hs) + self.streamer = self.hs.get_replication_streamer() + + store = self.hs.get_datastore() + self.database = store.db + + self.reactor.lookups["testserv"] = "1.2.3.4" + + self._worker_hs_to_resource = {} + + # When we see a connection attempt to the master replication listener we + # automatically set up the connection. This is so that tests don't + # manually have to go and explicitly set it up each time (plus sometimes + # it is impossible to write the handling explicitly in the tests). + self.reactor.add_tcp_client_callback( + "1.2.3.4", 8765, self._handle_http_replication_attempt + ) + + def create_test_json_resource(self): + """Overrides `HomeserverTestCase.create_test_json_resource`. + """ + # We override this so that it automatically registers all the HTTP + # replication servlets, without having to explicitly do that in all + # subclassses. + + resource = ReplicationRestResource(self.hs) + + for servlet in self.servlets: + servlet(self.hs, resource) + + return resource + + def make_worker_hs( + self, worker_app: str, extra_config: dict = {}, **kwargs + ) -> HomeServer: + """Make a new worker HS instance, correctly connecting replcation + stream to the master HS. + + Args: + worker_app: Type of worker, e.g. `synapse.app.federation_sender`. + extra_config: Any extra config to use for this instances. + **kwargs: Options that get passed to `self.setup_test_homeserver`, + useful to e.g. pass some mocks for things like `http_client` + + Returns: + The new worker HomeServer instance. + """ + + config = self._get_worker_hs_config() + config["worker_app"] = worker_app + config.update(extra_config) + + worker_hs = self.setup_test_homeserver( + homeserverToUse=GenericWorkerServer, + config=config, + reactor=self.reactor, + **kwargs + ) + + store = worker_hs.get_datastore() + store.db._db_pool = self.database._db_pool + + repl_handler = ReplicationCommandHandler(worker_hs) + client = ClientReplicationStreamProtocol( + worker_hs, "client", "test", self.clock, repl_handler, + ) + server = self.server_factory.buildProtocol(None) + + client_transport = FakeTransport(server, self.reactor) + client.makeConnection(client_transport) + + server_transport = FakeTransport(client, self.reactor) + server.makeConnection(server_transport) + + # Set up a resource for the worker + resource = ReplicationRestResource(self.hs) + + for servlet in self.servlets: + servlet(worker_hs, resource) + + self._worker_hs_to_resource[worker_hs] = resource + + return worker_hs + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + return config + + def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): + render(request, self._worker_hs_to_resource[worker_hs], self.reactor) + + def replicate(self): + """Tell the master side of replication that something has happened, and then + wait for the replication to occur. + """ + self.streamer.on_notifier_poke() + self.pump() + + def _handle_http_replication_attempt(self): + """Handles a connection attempt to the master replication HTTP + listener. + """ + + # We should have at least one outbound connection attempt, where the + # last is one to the HTTP repication IP/port. + clients = self.reactor.tcpClients + self.assertGreaterEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop() + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8765) + + # Set up client side protocol + client_protocol = client_factory.buildProtocol(None) + + request_factory = OneShotRequestFactory() + + # Set up the server side protocol + channel = _PushHTTPChannel(self.reactor) + channel.requestFactory = request_factory + channel.site = self.site + + # Connect client to server and vice versa. + client_to_server_transport = FakeTransport( + channel, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) + + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, channel + ) + channel.makeConnection(server_to_client_transport) + + # Note: at this point we've wired everything up, but we need to return + # before the data starts flowing over the connections as this is called + # inside `connecTCP` before the connection has been passed back to the + # code that requested the TCP connection. + + class TestReplicationDataHandler(GenericWorkerReplicationHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" @@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel): # We need to manually stop the _PullToPushProducer. self._pull_to_push_producer.stop() + def checkPersistence(self, request, version): + """Check whether the connection can be re-used + """ + # We hijack this to always say no for ease of wiring stuff up in + # `handle_http_replication_attempt`. + request.responseHeaders.setRawHeaders(b"connection", [b"close"]) + return False + class _PullToPushProducer: """A push producer that wraps a pull producer. diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index b7d753e0a3..86c03fd89c 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -15,63 +15,26 @@ import logging from synapse.api.constants import LoginType -from synapse.app.generic_worker import GenericWorkerServer -from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest -from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest.client.v2_alpha import register -from tests import unittest +from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker -from tests.server import FakeChannel, render +from tests.server import FakeChannel logger = logging.getLogger(__name__) -class ClientReaderTestCase(unittest.HomeserverTestCase): +class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): """Base class for tests of the replication streams""" - servlets = [ - register.register_servlets, - ] + servlets = [register.register_servlets] def prepare(self, reactor, clock, hs): - # build a replication server - self.server_factory = ReplicationStreamProtocolFactory(hs) - self.streamer = hs.get_replication_streamer() - - store = hs.get_datastore() - self.database = store.db - self.recaptcha_checker = DummyRecaptchaChecker(hs) auth_handler = hs.get_auth_handler() auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker - self.reactor.lookups["testserv"] = "1.2.3.4" - - def make_worker_hs(self, extra_config={}): - config = self._get_worker_hs_config() - config.update(extra_config) - - worker_hs = self.setup_test_homeserver( - homeserverToUse=GenericWorkerServer, config=config, reactor=self.reactor, - ) - - store = worker_hs.get_datastore() - store.db._db_pool = self.database._db_pool - - # Register the expected servlets, essentially this is HomeserverTestCase.create_test_json_resource. - resource = JsonResource(self.hs) - - for servlet in self.servlets: - servlet(worker_hs, resource) - - # Essentially HomeserverTestCase.render. - def _render(request): - render(request, self.resource, self.reactor) - - return worker_hs, _render - def _get_worker_hs_config(self) -> dict: config = self.default_config() config["worker_app"] = "synapse.app.client_reader" @@ -82,14 +45,14 @@ class ClientReaderTestCase(unittest.HomeserverTestCase): def test_register_single_worker(self): """Test that registration works when using a single client reader worker. """ - _, worker_render = self.make_worker_hs() + worker_hs = self.make_worker_hs("synapse.app.client_reader") request_1, channel_1 = self.make_request( "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, ) # type: SynapseRequest, FakeChannel - worker_render(request_1) + self.render_on_worker(worker_hs, request_1) self.assertEqual(request_1.code, 401) # Grab the session @@ -99,7 +62,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase): request_2, channel_2 = self.make_request( "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} ) # type: SynapseRequest, FakeChannel - worker_render(request_2) + self.render_on_worker(worker_hs, request_2) self.assertEqual(request_2.code, 200) # We're given a registered user. @@ -108,15 +71,15 @@ class ClientReaderTestCase(unittest.HomeserverTestCase): def test_register_multi_worker(self): """Test that registration works when using multiple client reader workers. """ - _, worker_render_1 = self.make_worker_hs() - _, worker_render_2 = self.make_worker_hs() + worker_hs_1 = self.make_worker_hs("synapse.app.client_reader") + worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") request_1, channel_1 = self.make_request( "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, ) # type: SynapseRequest, FakeChannel - worker_render_1(request_1) + self.render_on_worker(worker_hs_1, request_1) self.assertEqual(request_1.code, 401) # Grab the session @@ -126,7 +89,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase): request_2, channel_2 = self.make_request( "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} ) # type: SynapseRequest, FakeChannel - worker_render_2(request_2) + self.render_on_worker(worker_hs_2, request_2) self.assertEqual(request_2.code, 200) # We're given a registered user. diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 519a2dc510..8d4dbf232e 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -19,132 +19,40 @@ from mock import Mock from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.app.generic_worker import GenericWorkerServer from synapse.events.builder import EventBuilderFactory -from synapse.replication.http import streams -from synapse.replication.tcp.handler import ReplicationCommandHandler -from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol -from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room from synapse.types import UserID -from tests import unittest -from tests.server import FakeTransport +from tests.replication._base import BaseMultiWorkerStreamTestCase logger = logging.getLogger(__name__) -class BaseStreamTestCase(unittest.HomeserverTestCase): - """Base class for tests of the replication streams""" - +class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): servlets = [ - streams.register_servlets, + login.register_servlets, + register_servlets_for_client_rest_resource, + room.register_servlets, ] - def prepare(self, reactor, clock, hs): - # build a replication server - self.server_factory = ReplicationStreamProtocolFactory(hs) - self.streamer = hs.get_replication_streamer() - - store = hs.get_datastore() - self.database = store.db - - self.reactor.lookups["testserv"] = "1.2.3.4" - def default_config(self): conf = super().default_config() conf["send_federation"] = False return conf - def make_worker_hs(self, extra_config={}): - config = self._get_worker_hs_config() - config.update(extra_config) - - mock_federation_client = Mock(spec=["put_json"]) - mock_federation_client.put_json.side_effect = lambda *_, **__: defer.succeed({}) - - worker_hs = self.setup_test_homeserver( - http_client=mock_federation_client, - homeserverToUse=GenericWorkerServer, - config=config, - reactor=self.reactor, - ) - - store = worker_hs.get_datastore() - store.db._db_pool = self.database._db_pool - - repl_handler = ReplicationCommandHandler(worker_hs) - client = ClientReplicationStreamProtocol( - worker_hs, "client", "test", self.clock, repl_handler, - ) - server = self.server_factory.buildProtocol(None) - - client_transport = FakeTransport(server, self.reactor) - client.makeConnection(client_transport) - - server_transport = FakeTransport(client, self.reactor) - server.makeConnection(server_transport) - - return worker_hs - - def _get_worker_hs_config(self) -> dict: - config = self.default_config() - config["worker_app"] = "synapse.app.federation_sender" - config["worker_replication_host"] = "testserv" - config["worker_replication_http_port"] = "8765" - return config - - def replicate(self): - """Tell the master side of replication that something has happened, and then - wait for the replication to occur. - """ - self.streamer.on_notifier_poke() - self.pump() - - def create_room_with_remote_server(self, user, token, remote_server="other_server"): - room = self.helper.create_room_as(user, tok=token) - store = self.hs.get_datastore() - federation = self.hs.get_handlers().federation_handler - - prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) - room_version = self.get_success(store.get_room_version(room)) - - factory = EventBuilderFactory(self.hs) - factory.hostname = remote_server - - user_id = UserID("user", remote_server).to_string() - - event_dict = { - "type": EventTypes.Member, - "state_key": user_id, - "content": {"membership": Membership.JOIN}, - "sender": user_id, - "room_id": room, - } - - builder = factory.for_room_version(room_version, event_dict) - join_event = self.get_success(builder.build(prev_event_ids)) - - self.get_success(federation.on_send_join_request(remote_server, join_event)) - self.replicate() - - return room - - -class FederationSenderTestCase(BaseStreamTestCase): - servlets = [ - login.register_servlets, - register_servlets_for_client_rest_resource, - room.register_servlets, - ] - def test_send_event_single_sender(self): """Test that using a single federation sender worker correctly sends a new event. """ - worker_hs = self.make_worker_hs({"send_federation": True}) - mock_client = worker_hs.get_http_client() + mock_client = Mock(spec=["put_json"]) + mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({}) + + self.make_worker_hs( + "synapse.app.federation_sender", + {"send_federation": True}, + http_client=mock_client, + ) user = self.register_user("user", "pass") token = self.login("user", "pass") @@ -165,23 +73,29 @@ class FederationSenderTestCase(BaseStreamTestCase): """Test that using two federation sender workers correctly sends new events. """ - worker1 = self.make_worker_hs( + mock_client1 = Mock(spec=["put_json"]) + mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", { "send_federation": True, "worker_name": "sender1", "federation_sender_instances": ["sender1", "sender2"], - } + }, + http_client=mock_client1, ) - mock_client1 = worker1.get_http_client() - worker2 = self.make_worker_hs( + mock_client2 = Mock(spec=["put_json"]) + mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", { "send_federation": True, "worker_name": "sender2", "federation_sender_instances": ["sender1", "sender2"], - } + }, + http_client=mock_client2, ) - mock_client2 = worker2.get_http_client() user = self.register_user("user2", "pass") token = self.login("user2", "pass") @@ -191,8 +105,8 @@ class FederationSenderTestCase(BaseStreamTestCase): for i in range(20): server_name = "other_server_%d" % (i,) room = self.create_room_with_remote_server(user, token, server_name) - mock_client1.reset_mock() - mock_client2.reset_mock() + mock_client1.reset_mock() # type: ignore[attr-defined] + mock_client2.reset_mock() # type: ignore[attr-defined] self.create_and_send_event(room, UserID.from_string(user)) self.replicate() @@ -222,23 +136,29 @@ class FederationSenderTestCase(BaseStreamTestCase): """Test that using two federation sender workers correctly sends new typing EDUs. """ - worker1 = self.make_worker_hs( + mock_client1 = Mock(spec=["put_json"]) + mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", { "send_federation": True, "worker_name": "sender1", "federation_sender_instances": ["sender1", "sender2"], - } + }, + http_client=mock_client1, ) - mock_client1 = worker1.get_http_client() - worker2 = self.make_worker_hs( + mock_client2 = Mock(spec=["put_json"]) + mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", { "send_federation": True, "worker_name": "sender2", "federation_sender_instances": ["sender1", "sender2"], - } + }, + http_client=mock_client2, ) - mock_client2 = worker2.get_http_client() user = self.register_user("user3", "pass") token = self.login("user3", "pass") @@ -250,8 +170,8 @@ class FederationSenderTestCase(BaseStreamTestCase): for i in range(20): server_name = "other_server_%d" % (i,) room = self.create_room_with_remote_server(user, token, server_name) - mock_client1.reset_mock() - mock_client2.reset_mock() + mock_client1.reset_mock() # type: ignore[attr-defined] + mock_client2.reset_mock() # type: ignore[attr-defined] self.get_success( typing_handler.started_typing( @@ -284,3 +204,32 @@ class FederationSenderTestCase(BaseStreamTestCase): self.assertTrue(sent_on_1) self.assertTrue(sent_on_2) + + def create_room_with_remote_server(self, user, token, remote_server="other_server"): + room = self.helper.create_room_as(user, tok=token) + store = self.hs.get_datastore() + federation = self.hs.get_handlers().federation_handler + + prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) + room_version = self.get_success(store.get_room_version(room)) + + factory = EventBuilderFactory(self.hs) + factory.hostname = remote_server + + user_id = UserID("user", remote_server).to_string() + + event_dict = { + "type": EventTypes.Member, + "state_key": user_id, + "content": {"membership": Membership.JOIN}, + "sender": user_id, + "room_id": room, + } + + builder = factory.for_room_version(room_version, event_dict) + join_event = self.get_success(builder.build(prev_event_ids)) + + self.get_success(federation.on_send_join_request(remote_server, join_event)) + self.replicate() + + return room diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index cca5f548e6..f16eef15f7 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -857,6 +857,53 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + def test_reactivate_user(self): + """ + Test reactivating another user. + """ + + # Deactivate the user. + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=json.dumps({"deactivated": True}).encode(encoding="utf_8"), + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Attempt to reactivate the user (without a password). + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=json.dumps({"deactivated": False}).encode(encoding="utf_8"), + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + + # Reactivate the user. + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=json.dumps({"deactivated": False, "password": "foo"}).encode( + encoding="utf_8" + ), + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + def test_set_user_as_admin(self): """ Test setting the admin flag on a user. diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 4413bb3932..db52725cfe 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -547,8 +547,8 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_jwt_invalid_signature(self): channel = self.jwt_login({"sub": "frog"}, "notsecret") - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], "JWT validation failed: Signature verification failed", @@ -556,8 +556,8 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_jwt_expired(self): channel = self.jwt_login({"sub": "frog", "exp": 864000}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], "JWT validation failed: Signature has expired" ) @@ -565,8 +565,8 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_jwt_not_before(self): now = int(time.time()) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], "JWT validation failed: The token is not yet valid (nbf)", @@ -574,8 +574,8 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_no_sub(self): channel = self.jwt_login({"username": "root"}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") @override_config( @@ -597,16 +597,16 @@ class JWTTestCase(unittest.HomeserverTestCase): # An invalid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], "JWT validation failed: Invalid issuer" ) # Not providing an issuer. channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], 'JWT validation failed: Token is missing the "iss" claim', @@ -637,16 +637,16 @@ class JWTTestCase(unittest.HomeserverTestCase): # An invalid audience. channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], "JWT validation failed: Invalid audience" ) # Not providing an audience. channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], 'JWT validation failed: Token is missing the "aud" claim', @@ -655,7 +655,8 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_aud_no_config(self): """Test providing an audience without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], "JWT validation failed: Invalid audience" ) @@ -664,8 +665,8 @@ class JWTTestCase(unittest.HomeserverTestCase): params = json.dumps({"type": "org.matrix.login.jwt"}) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") @@ -747,8 +748,8 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): def test_login_jwt_invalid_signature(self): channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], "JWT validation failed: Signature verification failed", diff --git a/tests/server.py b/tests/server.py index a5e57c52fa..b6e0b14e78 100644 --- a/tests/server.py +++ b/tests/server.py @@ -237,6 +237,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def __init__(self): self.threadpool = ThreadPool(self) + self._tcp_callbacks = {} self._udp = [] lookups = self.lookups = {} @@ -268,6 +269,29 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def getThreadPool(self): return self.threadpool + def add_tcp_client_callback(self, host, port, callback): + """Add a callback that will be invoked when we receive a connection + attempt to the given IP/port using `connectTCP`. + + Note that the callback gets run before we return the connection to the + client, which means callbacks cannot block while waiting for writes. + """ + self._tcp_callbacks[(host, port)] = callback + + def connectTCP(self, host, port, factory, timeout=30, bindAddress=None): + """Fake L{IReactorTCP.connectTCP}. + """ + + conn = super().connectTCP( + host, port, factory, timeout=timeout, bindAddress=None + ) + + callback = self._tcp_callbacks.get((host, port)) + if callback: + callback() + + return conn + class ThreadPool: """ @@ -486,7 +510,7 @@ class FakeTransport(object): try: self.other.dataReceived(to_write) except Exception as e: - logger.warning("Exception writing to protocol: %s", e) + logger.exception("Exception writing to protocol: %s", e) return self.buffer = self.buffer[len(to_write) :] |