diff --git a/changelog.d/7836.misc b/changelog.d/7836.misc
new file mode 100644
index 0000000000..a3a97c7590
--- /dev/null
+++ b/changelog.d/7836.misc
@@ -0,0 +1 @@
+Ensure that calls to `json.dumps` are compatible with the standard library json.
diff --git a/changelog.d/7844.bugfix b/changelog.d/7844.bugfix
new file mode 100644
index 0000000000..ad296f1b3c
--- /dev/null
+++ b/changelog.d/7844.bugfix
@@ -0,0 +1 @@
+Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`.
diff --git a/changelog.d/7847.feature b/changelog.d/7847.feature
new file mode 100644
index 0000000000..4b9a8d8569
--- /dev/null
+++ b/changelog.d/7847.feature
@@ -0,0 +1 @@
+Add the ability to re-activate an account from the admin API.
diff --git a/changelog.d/7848.misc b/changelog.d/7848.misc
new file mode 100644
index 0000000000..d9db1d8357
--- /dev/null
+++ b/changelog.d/7848.misc
@@ -0,0 +1 @@
+Remove redundant `retry_on_integrity_error` wrapper for event persistence code.
diff --git a/changelog.d/7851.misc b/changelog.d/7851.misc
new file mode 100644
index 0000000000..e5cf540edf
--- /dev/null
+++ b/changelog.d/7851.misc
@@ -0,0 +1 @@
+Convert E2E keys and room keys handlers to async/await.
diff --git a/changelog.d/7853.misc b/changelog.d/7853.misc
new file mode 100644
index 0000000000..b4f614084d
--- /dev/null
+++ b/changelog.d/7853.misc
@@ -0,0 +1 @@
+Add support for handling registration requests across multiple client reader workers.
diff --git a/changelog.d/7854.bugfix b/changelog.d/7854.bugfix
new file mode 100644
index 0000000000..b11f9dedfe
--- /dev/null
+++ b/changelog.d/7854.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse 1.10.0 which could cause a "no create event in auth events" error during room creation.
diff --git a/changelog.d/7856.misc b/changelog.d/7856.misc
new file mode 100644
index 0000000000..7d99fb67be
--- /dev/null
+++ b/changelog.d/7856.misc
@@ -0,0 +1 @@
+Small performance improvement in typing processing.
diff --git a/changelog.d/7870.misc b/changelog.d/7870.misc
new file mode 100644
index 0000000000..27cce2f2f9
--- /dev/null
+++ b/changelog.d/7870.misc
@@ -0,0 +1 @@
+Add some type annotations to `HomeServer` and `BaseHandler`.
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index 7b030a6285..be05128b3e 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -91,10 +91,14 @@ Body parameters:
- ``admin``, optional, defaults to ``false``.
-- ``deactivated``, optional, defaults to ``false``.
+- ``deactivated``, optional. If unspecified, deactivation state will be left
+ unchanged on existing accounts and set to ``false`` for new accounts.
If the user already exists then optional parameters default to the current value.
+In order to re-activate an account ``deactivated`` must be set to ``false``. If
+users do not login via single-sign-on, a new ``password`` must be provided.
+
List Accounts
=============
diff --git a/docs/jwt.md b/docs/jwt.md
index 93b8d05236..5be9fd26e3 100644
--- a/docs/jwt.md
+++ b/docs/jwt.md
@@ -31,10 +31,7 @@ The `token` field should include the JSON web token with the following claims:
Providing the audience claim when not configured will cause validation to fail.
In the case that the token is not valid, the homeserver must respond with
-`401 Unauthorized` and an error code of `M_UNAUTHORIZED`.
-
-(Note that this differs from the token based logins which return a
-`403 Forbidden` and an error code of `M_FORBIDDEN` if an error occurs.)
+`403 Forbidden` and an error code of `M_FORBIDDEN`.
As with other login types, there are additional fields (e.g. `device_id` and
`initial_device_display_name`) which can be included in the above request.
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 94db877050..bd004350dd 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -16,12 +16,14 @@
# limitations under the License.
"""Contains exceptions and error codes."""
-import json
+
import logging
import typing
from http import HTTPStatus
from typing import Dict, List, Optional, Union
+from canonicaljson import json
+
from twisted.web import http
if typing.TYPE_CHECKING:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2aab9c5f55..8c53330c49 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -14,10 +14,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import logging
from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
+from canonicaljson import json
from prometheus_client import Counter, Histogram
from twisted.internet import defer
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index a2752a54a5..8280f8b900 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -61,8 +61,6 @@ class TransactionManager(object):
# all the edus in that transaction. This needs to be done since there is
# no active span here, so if the edus were not received by the remote the
# span would have no causality and it would be forgotten.
- # The span_contexts is a generator so that it won't be evaluated if
- # opentracing is disabled. (Yay speed!)
span_contexts = []
keep_destination = whitelisted_homeserver(destination)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 61dc4beafe..6a4944467a 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -17,6 +17,8 @@ import logging
from twisted.internet import defer
+import synapse.state
+import synapse.storage
import synapse.types
from synapse.api.constants import EventTypes, Membership
from synapse.api.ratelimiting import Ratelimiter
@@ -28,10 +30,6 @@ logger = logging.getLogger(__name__)
class BaseHandler(object):
"""
Common base class for the event handlers.
-
- Attributes:
- store (synapse.storage.DataStore):
- state_handler (synapse.state.StateHandler):
"""
def __init__(self, hs):
@@ -39,10 +37,10 @@ class BaseHandler(object):
Args:
hs (synapse.server.HomeServer):
"""
- self.store = hs.get_datastore()
+ self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
- self.state_handler = hs.get_state_handler()
+ self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler
self.distributor = hs.get_distributor()
self.clock = hs.get_clock()
self.hs = hs
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 591657a5c2..3789b1b495 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Optional
from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -46,19 +47,20 @@ class DeactivateAccountHandler(BaseHandler):
self._account_validity_enabled = hs.config.account_validity.enabled
- async def deactivate_account(self, user_id, erase_data, id_server=None):
+ async def deactivate_account(
+ self, user_id: str, erase_data: bool, id_server: Optional[str] = None
+ ) -> bool:
"""Deactivate a user's account
Args:
- user_id (str): ID of user to be deactivated
- erase_data (bool): whether to GDPR-erase the user's data
- id_server (str|None): Use the given identity server when unbinding
+ user_id: ID of user to be deactivated
+ erase_data: whether to GDPR-erase the user's data
+ id_server: Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
Returns:
- Deferred[bool]: True if identity server supports removing
- threepids, otherwise False.
+ True if identity server supports removing threepids, otherwise False.
"""
# FIXME: Theoretically there is a race here wherein user resets
# password using threepid.
@@ -138,11 +140,11 @@ class DeactivateAccountHandler(BaseHandler):
return identity_server_supports_unbinding
- async def _reject_pending_invites_for_user(self, user_id):
+ async def _reject_pending_invites_for_user(self, user_id: str):
"""Reject pending invites addressed to a given user ID.
Args:
- user_id (str): The user ID to reject pending invites for.
+ user_id: The user ID to reject pending invites for.
"""
user = UserID.from_string(user_id)
pending_invites = await self.store.get_invited_rooms_for_local_user(user_id)
@@ -170,22 +172,16 @@ class DeactivateAccountHandler(BaseHandler):
room.room_id,
)
- def _start_user_parting(self):
+ def _start_user_parting(self) -> None:
"""
Start the process that goes through the table of users
pending deactivation, if it isn't already running.
-
- Returns:
- None
"""
if not self._user_parter_running:
run_as_background_process("user_parter_loop", self._user_parter_loop)
- async def _user_parter_loop(self):
+ async def _user_parter_loop(self) -> None:
"""Loop that parts deactivated users from rooms
-
- Returns:
- None
"""
self._user_parter_running = True
logger.info("Starting user parter")
@@ -202,11 +198,8 @@ class DeactivateAccountHandler(BaseHandler):
finally:
self._user_parter_running = False
- async def _part_user(self, user_id):
+ async def _part_user(self, user_id: str) -> None:
"""Causes the given user_id to leave all the rooms they're joined to
-
- Returns:
- None
"""
user = UserID.from_string(user_id)
@@ -228,3 +221,18 @@ class DeactivateAccountHandler(BaseHandler):
user_id,
room_id,
)
+
+ async def activate_account(self, user_id: str) -> None:
+ """
+ Activate an account that was previously deactivated.
+
+ This simply marks the user as activate in the database and does not
+ attempt to rejoin rooms, re-add threepids, etc.
+
+ The user will also need a password hash set to actually login.
+
+ Args:
+ user_id: ID of user to be deactivated
+ """
+ # Mark the user as activate.
+ await self.store.set_user_deactivated_status(user_id, False)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index a7e60cbc26..361dd64cd2 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -77,8 +77,7 @@ class E2eKeysHandler(object):
)
@trace
- @defer.inlineCallbacks
- def query_devices(self, query_body, timeout, from_user_id):
+ async def query_devices(self, query_body, timeout, from_user_id):
""" Handle a device key query from a client
{
@@ -124,7 +123,7 @@ class E2eKeysHandler(object):
failures = {}
results = {}
if local_query:
- local_result = yield self.query_local_devices(local_query)
+ local_result = await self.query_local_devices(local_query)
for user_id, keys in local_result.items():
if user_id in local_query:
results[user_id] = keys
@@ -142,7 +141,7 @@ class E2eKeysHandler(object):
(
user_ids_not_in_cache,
remote_results,
- ) = yield self.store.get_user_devices_from_cache(query_list)
+ ) = await self.store.get_user_devices_from_cache(query_list)
for user_id, devices in remote_results.items():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.items():
@@ -161,14 +160,13 @@ class E2eKeysHandler(object):
r[user_id] = remote_queries[user_id]
# Get cached cross-signing keys
- cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
+ cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, from_user_id
)
# Now fetch any devices that we don't have in our cache
@trace
- @defer.inlineCallbacks
- def do_remote_query(destination):
+ async def do_remote_query(destination):
"""This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for
@@ -192,7 +190,7 @@ class E2eKeysHandler(object):
if device_list:
continue
- room_ids = yield self.store.get_rooms_for_user(user_id)
+ room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
continue
@@ -201,11 +199,11 @@ class E2eKeysHandler(object):
# done an initial sync on the device list so we do it now.
try:
if self._is_master:
- user_devices = yield self.device_handler.device_list_updater.user_device_resync(
+ user_devices = await self.device_handler.device_list_updater.user_device_resync(
user_id
)
else:
- user_devices = yield self._user_device_resync_client(
+ user_devices = await self._user_device_resync_client(
user_id=user_id
)
@@ -227,7 +225,7 @@ class E2eKeysHandler(object):
destination_query.pop(user_id)
try:
- remote_result = yield self.federation.query_client_keys(
+ remote_result = await self.federation.query_client_keys(
destination, {"device_keys": destination_query}, timeout=timeout
)
@@ -251,7 +249,7 @@ class E2eKeysHandler(object):
set_tag("error", True)
set_tag("reason", failure)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(do_remote_query, destination)
@@ -267,8 +265,7 @@ class E2eKeysHandler(object):
return ret
- @defer.inlineCallbacks
- def get_cross_signing_keys_from_cache(self, query, from_user_id):
+ async def get_cross_signing_keys_from_cache(self, query, from_user_id):
"""Get cross-signing keys for users from the database
Args:
@@ -289,7 +286,7 @@ class E2eKeysHandler(object):
user_ids = list(query)
- keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
+ keys = await self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
for user_id, user_info in keys.items():
if user_info is None:
@@ -315,8 +312,7 @@ class E2eKeysHandler(object):
}
@trace
- @defer.inlineCallbacks
- def query_local_devices(self, query):
+ async def query_local_devices(self, query):
"""Get E2E device keys for local users
Args:
@@ -354,7 +350,7 @@ class E2eKeysHandler(object):
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
- results = yield self.store.get_e2e_device_keys(local_query)
+ results = await self.store.get_e2e_device_keys(local_query)
# Build the result structure
for user_id, device_keys in results.items():
@@ -364,16 +360,15 @@ class E2eKeysHandler(object):
log_kv(results)
return result_dict
- @defer.inlineCallbacks
- def on_federation_query_client_keys(self, query_body):
+ async def on_federation_query_client_keys(self, query_body):
""" Handle a device key query from a federated server
"""
device_keys_query = query_body.get("device_keys", {})
- res = yield self.query_local_devices(device_keys_query)
+ res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res}
# add in the cross-signing keys
- cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
+ cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, None
)
@@ -382,8 +377,7 @@ class E2eKeysHandler(object):
return ret
@trace
- @defer.inlineCallbacks
- def claim_one_time_keys(self, query, timeout):
+ async def claim_one_time_keys(self, query, timeout):
local_query = []
remote_queries = {}
@@ -399,7 +393,7 @@ class E2eKeysHandler(object):
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
- results = yield self.store.claim_e2e_one_time_keys(local_query)
+ results = await self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
failures = {}
@@ -411,12 +405,11 @@ class E2eKeysHandler(object):
}
@trace
- @defer.inlineCallbacks
- def claim_client_keys(destination):
+ async def claim_client_keys(destination):
set_tag("destination", destination)
device_keys = remote_queries[destination]
try:
- remote_result = yield self.federation.claim_client_keys(
+ remote_result = await self.federation.claim_client_keys(
destination, {"one_time_keys": device_keys}, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
@@ -429,7 +422,7 @@ class E2eKeysHandler(object):
set_tag("error", True)
set_tag("reason", failure)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(claim_client_keys, destination)
@@ -454,9 +447,8 @@ class E2eKeysHandler(object):
log_kv({"one_time_keys": json_result, "failures": failures})
return {"one_time_keys": json_result, "failures": failures}
- @defer.inlineCallbacks
@tag_args
- def upload_keys_for_user(self, user_id, device_id, keys):
+ async def upload_keys_for_user(self, user_id, device_id, keys):
time_now = self.clock.time_msec()
@@ -477,12 +469,12 @@ class E2eKeysHandler(object):
}
)
# TODO: Sign the JSON with the server key
- changed = yield self.store.set_e2e_device_keys(
+ changed = await self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
- yield self.device_handler.notify_device_update(user_id, [device_id])
+ await self.device_handler.notify_device_update(user_id, [device_id])
else:
log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
one_time_keys = keys.get("one_time_keys", None)
@@ -494,7 +486,7 @@ class E2eKeysHandler(object):
"device_id": device_id,
}
)
- yield self._upload_one_time_keys_for_user(
+ await self._upload_one_time_keys_for_user(
user_id, device_id, time_now, one_time_keys
)
else:
@@ -507,15 +499,14 @@ class E2eKeysHandler(object):
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
- yield self.device_handler.check_device_registered(user_id, device_id)
+ await self.device_handler.check_device_registered(user_id, device_id)
- result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+ result = await self.store.count_e2e_one_time_keys(user_id, device_id)
set_tag("one_time_key_counts", result)
return {"one_time_key_counts": result}
- @defer.inlineCallbacks
- def _upload_one_time_keys_for_user(
+ async def _upload_one_time_keys_for_user(
self, user_id, device_id, time_now, one_time_keys
):
logger.info(
@@ -533,7 +524,7 @@ class E2eKeysHandler(object):
key_list.append((algorithm, key_id, key_obj))
# First we check if we have already persisted any of the keys.
- existing_key_map = yield self.store.get_e2e_one_time_keys(
+ existing_key_map = await self.store.get_e2e_one_time_keys(
user_id, device_id, [k_id for _, k_id, _ in key_list]
)
@@ -556,10 +547,9 @@ class E2eKeysHandler(object):
)
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
- yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
+ await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
- @defer.inlineCallbacks
- def upload_signing_keys_for_user(self, user_id, keys):
+ async def upload_signing_keys_for_user(self, user_id, keys):
"""Upload signing keys for cross-signing
Args:
@@ -574,7 +564,7 @@ class E2eKeysHandler(object):
_check_cross_signing_key(master_key, user_id, "master")
else:
- master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
+ master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
# if there is no master key, then we can't do anything, because all the
# other cross-signing keys need to be signed by the master key
@@ -613,10 +603,10 @@ class E2eKeysHandler(object):
# if everything checks out, then store the keys and send notifications
deviceids = []
if "master_key" in keys:
- yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
+ await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
deviceids.append(master_verify_key.version)
if "self_signing_key" in keys:
- yield self.store.set_e2e_cross_signing_key(
+ await self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
try:
@@ -626,23 +616,22 @@ class E2eKeysHandler(object):
except ValueError:
raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM)
if "user_signing_key" in keys:
- yield self.store.set_e2e_cross_signing_key(
+ await self.store.set_e2e_cross_signing_key(
user_id, "user_signing", user_signing_key
)
# the signature stream matches the semantics that we want for
# user-signing key updates: only the user themselves is notified of
# their own user-signing key updates
- yield self.device_handler.notify_user_signature_update(user_id, [user_id])
+ await self.device_handler.notify_user_signature_update(user_id, [user_id])
# master key and self-signing key updates match the semantics of device
# list updates: all users who share an encrypted room are notified
if len(deviceids):
- yield self.device_handler.notify_device_update(user_id, deviceids)
+ await self.device_handler.notify_device_update(user_id, deviceids)
return {}
- @defer.inlineCallbacks
- def upload_signatures_for_device_keys(self, user_id, signatures):
+ async def upload_signatures_for_device_keys(self, user_id, signatures):
"""Upload device signatures for cross-signing
Args:
@@ -667,13 +656,13 @@ class E2eKeysHandler(object):
self_signatures = signatures.get(user_id, {})
other_signatures = {k: v for k, v in signatures.items() if k != user_id}
- self_signature_list, self_failures = yield self._process_self_signatures(
+ self_signature_list, self_failures = await self._process_self_signatures(
user_id, self_signatures
)
signature_list.extend(self_signature_list)
failures.update(self_failures)
- other_signature_list, other_failures = yield self._process_other_signatures(
+ other_signature_list, other_failures = await self._process_other_signatures(
user_id, other_signatures
)
signature_list.extend(other_signature_list)
@@ -681,21 +670,20 @@ class E2eKeysHandler(object):
# store the signature, and send the appropriate notifications for sync
logger.debug("upload signature failures: %r", failures)
- yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list)
+ await self.store.store_e2e_cross_signing_signatures(user_id, signature_list)
self_device_ids = [item.target_device_id for item in self_signature_list]
if self_device_ids:
- yield self.device_handler.notify_device_update(user_id, self_device_ids)
+ await self.device_handler.notify_device_update(user_id, self_device_ids)
signed_users = [item.target_user_id for item in other_signature_list]
if signed_users:
- yield self.device_handler.notify_user_signature_update(
+ await self.device_handler.notify_user_signature_update(
user_id, signed_users
)
return {"failures": failures}
- @defer.inlineCallbacks
- def _process_self_signatures(self, user_id, signatures):
+ async def _process_self_signatures(self, user_id, signatures):
"""Process uploaded signatures of the user's own keys.
Signatures of the user's own keys from this API come in two forms:
@@ -728,7 +716,7 @@ class E2eKeysHandler(object):
_,
self_signing_key_id,
self_signing_verify_key,
- ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
+ ) = await self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
# get our master key, since we may have received a signature of it.
# We need to fetch it here so that we know what its key ID is, so
@@ -738,12 +726,12 @@ class E2eKeysHandler(object):
master_key,
_,
master_verify_key,
- ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master")
+ ) = await self._get_e2e_cross_signing_verify_key(user_id, "master")
# fetch our stored devices. This is used to 1. verify
# signatures on the master key, and 2. to compare with what
# was sent if the device was signed
- devices = yield self.store.get_e2e_device_keys([(user_id, None)])
+ devices = await self.store.get_e2e_device_keys([(user_id, None)])
if user_id not in devices:
raise NotFoundError("No device keys found")
@@ -853,8 +841,7 @@ class E2eKeysHandler(object):
return master_key_signature_list
- @defer.inlineCallbacks
- def _process_other_signatures(self, user_id, signatures):
+ async def _process_other_signatures(self, user_id, signatures):
"""Process uploaded signatures of other users' keys. These will be the
target user's master keys, signed by the uploading user's user-signing
key.
@@ -882,7 +869,7 @@ class E2eKeysHandler(object):
user_signing_key,
user_signing_key_id,
user_signing_verify_key,
- ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
+ ) = await self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
except SynapseError as e:
failure = _exception_to_failure(e)
for user, devicemap in signatures.items():
@@ -905,7 +892,7 @@ class E2eKeysHandler(object):
master_key,
master_key_id,
_,
- ) = yield self._get_e2e_cross_signing_verify_key(
+ ) = await self._get_e2e_cross_signing_verify_key(
target_user, "master", user_id
)
@@ -958,8 +945,7 @@ class E2eKeysHandler(object):
return signature_list, failures
- @defer.inlineCallbacks
- def _get_e2e_cross_signing_verify_key(
+ async def _get_e2e_cross_signing_verify_key(
self, user_id: str, key_type: str, from_user_id: str = None
):
"""Fetch locally or remotely query for a cross-signing public key.
@@ -983,7 +969,7 @@ class E2eKeysHandler(object):
SynapseError: if `user_id` is invalid
"""
user = UserID.from_string(user_id)
- key = yield self.store.get_e2e_cross_signing_key(
+ key = await self.store.get_e2e_cross_signing_key(
user_id, key_type, from_user_id
)
@@ -1009,15 +995,14 @@ class E2eKeysHandler(object):
key,
key_id,
verify_key,
- ) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
+ ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
if key is None:
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
return key, key_id, verify_key
- @defer.inlineCallbacks
- def _retrieve_cross_signing_keys_for_remote_user(
+ async def _retrieve_cross_signing_keys_for_remote_user(
self, user: UserID, desired_key_type: str,
):
"""Queries cross-signing keys for a remote user and saves them to the database
@@ -1035,7 +1020,7 @@ class E2eKeysHandler(object):
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
try:
- remote_result = yield self.federation.query_user_devices(
+ remote_result = await self.federation.query_user_devices(
user.domain, user.to_string()
)
except Exception as e:
@@ -1101,14 +1086,14 @@ class E2eKeysHandler(object):
desired_key_id = key_id
# At the same time, store this key in the db for subsequent queries
- yield self.store.set_e2e_cross_signing_key(
+ await self.store.set_e2e_cross_signing_key(
user.to_string(), key_type, key_content
)
# Notify clients that new devices for this user have been discovered
if retrieved_device_ids:
# XXX is this necessary?
- yield self.device_handler.notify_device_update(
+ await self.device_handler.notify_device_update(
user.to_string(), retrieved_device_ids
)
@@ -1250,8 +1235,7 @@ class SigningKeyEduUpdater(object):
iterable=True,
)
- @defer.inlineCallbacks
- def incoming_signing_key_update(self, origin, edu_content):
+ async def incoming_signing_key_update(self, origin, edu_content):
"""Called on incoming signing key update from federation. Responsible for
parsing the EDU and adding to pending updates list.
@@ -1268,7 +1252,7 @@ class SigningKeyEduUpdater(object):
logger.warning("Got signing key update edu for %r from %r", user_id, origin)
return
- room_ids = yield self.store.get_rooms_for_user(user_id)
+ room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
@@ -1278,10 +1262,9 @@ class SigningKeyEduUpdater(object):
(master_key, self_signing_key)
)
- yield self._handle_signing_key_updates(user_id)
+ await self._handle_signing_key_updates(user_id)
- @defer.inlineCallbacks
- def _handle_signing_key_updates(self, user_id):
+ async def _handle_signing_key_updates(self, user_id):
"""Actually handle pending updates.
Args:
@@ -1291,7 +1274,7 @@ class SigningKeyEduUpdater(object):
device_handler = self.e2e_keys_handler.device_handler
device_list_updater = device_handler.device_list_updater
- with (yield self._remote_edu_linearizer.queue(user_id)):
+ with (await self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates
@@ -1302,9 +1285,9 @@ class SigningKeyEduUpdater(object):
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates:
- new_device_ids = yield device_list_updater.process_cross_signing_key_update(
+ new_device_ids = await device_list_updater.process_cross_signing_key_update(
user_id, master_key, self_signing_key,
)
device_ids = device_ids + new_device_ids
- yield device_handler.notify_device_update(user_id, device_ids)
+ await device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f55470a707..0bb983dc28 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import (
Codes,
NotFoundError,
@@ -50,8 +48,7 @@ class E2eRoomKeysHandler(object):
self._upload_linearizer = Linearizer("upload_room_keys_lock")
@trace
- @defer.inlineCallbacks
- def get_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
@@ -71,17 +68,17 @@ class E2eRoomKeysHandler(object):
# we deliberately take the lock to get keys so that changing the version
# works atomically
- with (yield self._upload_linearizer.queue(user_id)):
+ with (await self._upload_linearizer.queue(user_id)):
# make sure the backup version exists
try:
- yield self.store.get_e2e_room_keys_version_info(user_id, version)
+ await self.store.get_e2e_room_keys_version_info(user_id, version)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown backup version")
else:
raise
- results = yield self.store.get_e2e_room_keys(
+ results = await self.store.get_e2e_room_keys(
user_id, version, room_id, session_id
)
@@ -89,8 +86,7 @@ class E2eRoomKeysHandler(object):
return results
@trace
- @defer.inlineCallbacks
- def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
@@ -109,10 +105,10 @@ class E2eRoomKeysHandler(object):
"""
# lock for consistency with uploading
- with (yield self._upload_linearizer.queue(user_id)):
+ with (await self._upload_linearizer.queue(user_id)):
# make sure the backup version exists
try:
- version_info = yield self.store.get_e2e_room_keys_version_info(
+ version_info = await self.store.get_e2e_room_keys_version_info(
user_id, version
)
except StoreError as e:
@@ -121,19 +117,18 @@ class E2eRoomKeysHandler(object):
else:
raise
- yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
+ await self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
version_etag = version_info["etag"] + 1
- yield self.store.update_e2e_room_keys_version(
+ await self.store.update_e2e_room_keys_version(
user_id, version, None, version_etag
)
- count = yield self.store.count_e2e_room_keys(user_id, version)
+ count = await self.store.count_e2e_room_keys(user_id, version)
return {"etag": str(version_etag), "count": count}
@trace
- @defer.inlineCallbacks
- def upload_room_keys(self, user_id, version, room_keys):
+ async def upload_room_keys(self, user_id, version, room_keys):
"""Bulk upload a list of room keys into a given backup version, asserting
that the given version is the current backup version. room_keys are merged
into the current backup as described in RoomKeysServlet.on_PUT().
@@ -169,11 +164,11 @@ class E2eRoomKeysHandler(object):
# TODO: Validate the JSON to make sure it has the right keys.
# XXX: perhaps we should use a finer grained lock here?
- with (yield self._upload_linearizer.queue(user_id)):
+ with (await self._upload_linearizer.queue(user_id)):
# Check that the version we're trying to upload is the current version
try:
- version_info = yield self.store.get_e2e_room_keys_version_info(user_id)
+ version_info = await self.store.get_e2e_room_keys_version_info(user_id)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Version '%s' not found" % (version,))
@@ -183,7 +178,7 @@ class E2eRoomKeysHandler(object):
if version_info["version"] != version:
# Check that the version we're trying to upload actually exists
try:
- version_info = yield self.store.get_e2e_room_keys_version_info(
+ version_info = await self.store.get_e2e_room_keys_version_info(
user_id, version
)
# if we get this far, the version must exist
@@ -198,7 +193,7 @@ class E2eRoomKeysHandler(object):
# submitted. Then compare them with the submitted keys. If the
# key is new, insert it; if the key should be updated, then update
# it; otherwise, drop it.
- existing_keys = yield self.store.get_e2e_room_keys_multi(
+ existing_keys = await self.store.get_e2e_room_keys_multi(
user_id, version, room_keys["rooms"]
)
to_insert = [] # batch the inserts together
@@ -227,7 +222,7 @@ class E2eRoomKeysHandler(object):
# updates are done one at a time in the DB, so send
# updates right away rather than batching them up,
# like we do with the inserts
- yield self.store.update_e2e_room_key(
+ await self.store.update_e2e_room_key(
user_id, version, room_id, session_id, room_key
)
changed = True
@@ -246,16 +241,16 @@ class E2eRoomKeysHandler(object):
changed = True
if len(to_insert):
- yield self.store.add_e2e_room_keys(user_id, version, to_insert)
+ await self.store.add_e2e_room_keys(user_id, version, to_insert)
version_etag = version_info["etag"]
if changed:
version_etag = version_etag + 1
- yield self.store.update_e2e_room_keys_version(
+ await self.store.update_e2e_room_keys_version(
user_id, version, None, version_etag
)
- count = yield self.store.count_e2e_room_keys(user_id, version)
+ count = await self.store.count_e2e_room_keys(user_id, version)
return {"etag": str(version_etag), "count": count}
@staticmethod
@@ -291,8 +286,7 @@ class E2eRoomKeysHandler(object):
return True
@trace
- @defer.inlineCallbacks
- def create_version(self, user_id, version_info):
+ async def create_version(self, user_id, version_info):
"""Create a new backup version. This automatically becomes the new
backup version for the user's keys; previous backups will no longer be
writeable to.
@@ -313,14 +307,13 @@ class E2eRoomKeysHandler(object):
# TODO: Validate the JSON to make sure it has the right keys.
# lock everyone out until we've switched version
- with (yield self._upload_linearizer.queue(user_id)):
- new_version = yield self.store.create_e2e_room_keys_version(
+ with (await self._upload_linearizer.queue(user_id)):
+ new_version = await self.store.create_e2e_room_keys_version(
user_id, version_info
)
return new_version
- @defer.inlineCallbacks
- def get_version_info(self, user_id, version=None):
+ async def get_version_info(self, user_id, version=None):
"""Get the info about a given version of the user's backup
Args:
@@ -339,22 +332,21 @@ class E2eRoomKeysHandler(object):
}
"""
- with (yield self._upload_linearizer.queue(user_id)):
+ with (await self._upload_linearizer.queue(user_id)):
try:
- res = yield self.store.get_e2e_room_keys_version_info(user_id, version)
+ res = await self.store.get_e2e_room_keys_version_info(user_id, version)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown backup version")
else:
raise
- res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"])
+ res["count"] = await self.store.count_e2e_room_keys(user_id, res["version"])
res["etag"] = str(res["etag"])
return res
@trace
- @defer.inlineCallbacks
- def delete_version(self, user_id, version=None):
+ async def delete_version(self, user_id, version=None):
"""Deletes a given version of the user's e2e_room_keys backup
Args:
@@ -364,9 +356,9 @@ class E2eRoomKeysHandler(object):
NotFoundError: if this backup version doesn't exist
"""
- with (yield self._upload_linearizer.queue(user_id)):
+ with (await self._upload_linearizer.queue(user_id)):
try:
- yield self.store.delete_e2e_room_keys_version(user_id, version)
+ await self.store.delete_e2e_room_keys_version(user_id, version)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown backup version")
@@ -374,8 +366,7 @@ class E2eRoomKeysHandler(object):
raise
@trace
- @defer.inlineCallbacks
- def update_version(self, user_id, version, version_info):
+ async def update_version(self, user_id, version, version_info):
"""Update the info about a given version of the user's backup
Args:
@@ -393,9 +384,9 @@ class E2eRoomKeysHandler(object):
raise SynapseError(
400, "Version in body does not match", Codes.INVALID_PARAM
)
- with (yield self._upload_linearizer.queue(user_id)):
+ with (await self._upload_linearizer.queue(user_id)):
try:
- old_info = yield self.store.get_e2e_room_keys_version_info(
+ old_info = await self.store.get_e2e_room_keys_version_info(
user_id, version
)
except StoreError as e:
@@ -406,7 +397,7 @@ class E2eRoomKeysHandler(object):
if old_info["algorithm"] != version_info["algorithm"]:
raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM)
- yield self.store.update_e2e_room_keys_version(
+ await self.store.update_e2e_room_keys_version(
user_id, version, version_info
)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 879c4c07c6..846ddbdc6c 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -185,7 +185,7 @@ class TypingHandler(object):
async def _push_remote(self, member, typing):
try:
- users = await self.state.get_current_users_in_room(member.room_id)
+ users = await self.store.get_users_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()
now = self.clock.time_msec()
@@ -224,7 +224,7 @@ class TypingHandler(object):
)
return
- users = await self.state.get_current_users_in_room(room_id)
+ users = await self.store.get_users_in_room(room_id)
domains = {get_domain_from_id(u) for u in users}
if self.server_name in domains:
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 8b24a73319..a140e9391e 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
from canonicaljson import json
@@ -117,7 +118,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
- resp_body = json.loads(data)
+ resp_body = json.loads(data.decode("utf-8"))
if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 505872ee90..6bc51202cd 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -13,13 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
+
import logging
import urllib
from io import BytesIO
import treq
-from canonicaljson import encode_canonical_json
+from canonicaljson import encode_canonical_json, json
from netaddr import IPAddress
from prometheus_client import Counter
from zope.interface import implementer, provider
@@ -31,6 +31,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver,
IResolutionReceiver,
)
+from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, HTTPConnectionPool, readBody
@@ -69,6 +70,21 @@ def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
return False
+_EPSILON = 0.00000001
+
+
+def _make_scheduler(reactor):
+ """Makes a schedular suitable for a Cooperator using the given reactor.
+
+ (This is effectively just a copy from `twisted.internet.task`)
+ """
+
+ def _scheduler(x):
+ return reactor.callLater(_EPSILON, x)
+
+ return _scheduler
+
+
class IPBlacklistingResolver(object):
"""
A proxy for reactor.nameResolver which only produces non-blacklisted IP
@@ -212,6 +228,10 @@ class SimpleHttpClient(object):
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
+ # We use this for our body producers to ensure that they use the correct
+ # reactor.
+ self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor()))
+
self.user_agent = self.user_agent.encode("ascii")
if self._ip_blacklist:
@@ -292,7 +312,9 @@ class SimpleHttpClient(object):
try:
body_producer = None
if data is not None:
- body_producer = QuieterFileBodyProducer(BytesIO(data))
+ body_producer = QuieterFileBodyProducer(
+ BytesIO(data), cooperator=self._cooperator,
+ )
request_deferred = treq.request(
method,
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 3cabe9d02e..a34e5ead88 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -14,9 +14,11 @@
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
-import json
+
import logging
+from canonicaljson import json
+
from synapse.api.errors import Codes, SynapseError
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index e4330c39d6..cc0bdfa5c9 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -239,6 +239,15 @@ class UserRestServletV2(RestServlet):
await self.deactivate_account_handler.deactivate_account(
target_user.to_string(), False
)
+ elif not deactivate and user["deactivated"]:
+ if "password" not in body:
+ raise SynapseError(
+ 400, "Must provide a password to re-activate an account."
+ )
+
+ await self.deactivate_account_handler.activate_account(
+ target_user.to_string()
+ )
user = await self.admin_handler.get_user(target_user)
return 200, user
@@ -254,7 +263,6 @@ class UserRestServletV2(RestServlet):
admin = body.get("admin", None)
user_type = body.get("user_type", None)
displayname = body.get("displayname", None)
- threepids = body.get("threepids", None)
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
raise SynapseError(400, "Invalid user type")
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 326ffa0056..379f668d6f 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -371,7 +371,7 @@ class LoginRestServlet(RestServlet):
token = login_submission.get("token", None)
if token is None:
raise LoginError(
- 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
+ 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
)
import jwt
@@ -387,14 +387,12 @@ class LoginRestServlet(RestServlet):
except jwt.PyJWTError as e:
# A JWT error occurred, return some info back to the client.
raise LoginError(
- 401,
- "JWT validation failed: %s" % (str(e),),
- errcode=Codes.UNAUTHORIZED,
+ 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN,
)
user = payload.get("sub", None)
if user is None:
- raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
+ raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
user_id = UserID(user, self.hs.hostname).to_string()
result = await self._complete_login(
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 01b80b86fa..c5a84af047 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -15,6 +15,7 @@
# limitations under the License.
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
+
import logging
import re
from typing import List, Optional
@@ -515,9 +516,9 @@ class RoomMessageListRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request, default_limit=10)
as_client_event = b"raw" not in request.args
- filter_bytes = parse_string(request, b"filter", encoding=None)
- if filter_bytes:
- filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
+ filter_str = parse_string(request, b"filter", encoding="utf-8")
+ if filter_str:
+ filter_json = urlparse.unquote(filter_str)
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
if (
event_filter
@@ -627,9 +628,9 @@ class RoomEventContextServlet(RestServlet):
limit = parse_integer(request, "limit", default=10)
# picking the API shape for symmetry with /messages
- filter_bytes = parse_string(request, "filter")
- if filter_bytes:
- filter_json = urlparse.unquote(filter_bytes)
+ filter_str = parse_string(request, b"filter", encoding="utf-8")
+ if filter_str:
+ filter_json = urlparse.unquote(filter_str)
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
else:
event_filter = None
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index e149ac1733..9b3f85b306 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -202,9 +202,11 @@ class RemoteKey(DirectServeJsonResource):
if miss:
cache_misses.setdefault(server_name, set()).add(key_id)
+ # Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"]))
else:
for ts_added, result in results:
+ # Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"]))
if cache_misses and query_remote_on_cache_miss:
@@ -213,7 +215,7 @@ class RemoteKey(DirectServeJsonResource):
else:
signed_keys = []
for key_json in json_results:
- key_json = json.loads(key_json)
+ key_json = json.loads(key_json.decode("utf-8"))
for signing_key in self.config.key_server_signing_keys:
key_json = sign_json(key_json, self.config.server_name, signing_key)
diff --git a/synapse/server.py b/synapse/server.py
index ca42c2195a..f838a03d71 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -106,7 +106,7 @@ from synapse.server_notices.worker_server_notices_sender import (
WorkerServerNoticesSender,
)
from synapse.state import StateHandler, StateResolutionHandler
-from synapse.storage import DataStores, Storage
+from synapse.storage import DataStore, DataStores, Storage
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
@@ -314,7 +314,7 @@ class HomeServer(object):
def get_clock(self):
return self.clock
- def get_datastore(self):
+ def get_datastore(self) -> DataStore:
return self.datastores.main
def get_datastores(self):
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 58cd099e6d..cd50c721b8 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -20,6 +20,7 @@ import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
import synapse.http.client
+import synapse.http.matrixfederationclient
import synapse.notifier
import synapse.push.pusherpool
import synapse.replication.tcp.client
@@ -143,3 +144,7 @@ class HomeServer(object):
pass
def get_replication_streams(self) -> Dict[str, Stream]:
pass
+ def get_http_client(
+ self,
+ ) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient:
+ pass
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 230fb5cd7f..66f01aad84 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -17,7 +17,6 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from functools import wraps
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
import attr
@@ -69,27 +68,6 @@ def encode_json(json_object):
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-def _retry_on_integrity_error(func):
- """Wraps a database function so that it gets retried on IntegrityError,
- with `delete_existing=True` passed in.
-
- Args:
- func: function that returns a Deferred and accepts a `delete_existing` arg
- """
-
- @wraps(func)
- @defer.inlineCallbacks
- def f(self, *args, **kwargs):
- try:
- res = yield func(self, *args, delete_existing=False, **kwargs)
- except self.database_engine.module.IntegrityError:
- logger.exception("IntegrityError, retrying.")
- res = yield func(self, *args, delete_existing=True, **kwargs)
- return res
-
- return f
-
-
@attr.s(slots=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -134,7 +112,6 @@ class PersistEventsStore:
hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate EventsStore on master"
- @_retry_on_integrity_error
@defer.inlineCallbacks
def _persist_events_and_state_updates(
self,
@@ -143,7 +120,6 @@ class PersistEventsStore:
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
backfilled: bool = False,
- delete_existing: bool = False,
):
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -157,7 +133,6 @@ class PersistEventsStore:
new_forward_extremities: Map from room_id to list of event IDs
that are the new forward extremities of the room.
backfilled
- delete_existing
Returns:
Deferred: resolves when the events have been persisted
@@ -197,7 +172,6 @@ class PersistEventsStore:
self._persist_events_txn,
events_and_contexts=events_and_contexts,
backfilled=backfilled,
- delete_existing=delete_existing,
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
@@ -341,7 +315,6 @@ class PersistEventsStore:
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
- delete_existing: bool = False,
state_delta_for_room: Dict[str, DeltaState] = {},
new_forward_extremeties: Dict[str, List[str]] = {},
):
@@ -393,13 +366,6 @@ class PersistEventsStore:
# From this point onwards the events are only events that we haven't
# seen before.
- if delete_existing:
- # For paranoia reasons, we go and delete all the existing entries
- # for these events so we can reinsert them.
- # This gets around any problems with some tables already having
- # entries.
- self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts)
-
self._store_event_txn(txn, events_and_contexts=events_and_contexts)
# Insert into event_to_state_groups.
@@ -797,39 +763,6 @@ class PersistEventsStore:
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
- @classmethod
- def _delete_existing_rows_txn(cls, txn, events_and_contexts):
- if not events_and_contexts:
- # nothing to do here
- return
-
- logger.info("Deleting existing")
-
- for table in (
- "events",
- "event_auth",
- "event_json",
- "event_edges",
- "event_forward_extremities",
- "event_reference_hashes",
- "event_search",
- "event_to_state_groups",
- "state_events",
- "rejections",
- "redactions",
- "room_memberships",
- ):
- txn.executemany(
- "DELETE FROM %s WHERE event_id = ?" % (table,),
- [(ev.event_id,) for ev, _ in events_and_contexts],
- )
-
- for table in ("event_push_actions",):
- txn.executemany(
- "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
- [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts],
- )
-
def _store_event_txn(self, txn, events_and_contexts):
"""Insert new events into the event and event_json tables
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 347cc50778..bb38a04ede 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -353,6 +353,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
last_room_id = progress.get("last_room_id", "")
def _background_remove_left_rooms_txn(txn):
+ # get a batch of room ids to consider
sql = """
SELECT DISTINCT room_id FROM current_state_events
WHERE room_id > ? ORDER BY room_id LIMIT ?
@@ -363,24 +364,68 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
if not room_ids:
return True, set()
+ ###########################################################################
+ #
+ # exclude rooms where we have active members
+
sql = """
SELECT room_id
- FROM current_state_events
+ FROM local_current_membership
WHERE
room_id > ? AND room_id <= ?
- AND type = 'm.room.member'
AND membership = 'join'
- AND state_key LIKE ?
GROUP BY room_id
"""
- txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name))
-
+ txn.execute(sql, (last_room_id, room_ids[-1]))
joined_room_ids = {row[0] for row in txn}
+ to_delete = set(room_ids) - joined_room_ids
+
+ ###########################################################################
+ #
+ # exclude rooms which we are in the process of constructing; these otherwise
+ # qualify as "rooms with no local users", and would have their
+ # forward extremities cleaned up.
+
+ # the following query will return a list of rooms which have forward
+ # extremities that are *not* also the create event in the room - ie
+ # those that are not being created currently.
+
+ sql = """
+ SELECT DISTINCT efe.room_id
+ FROM event_forward_extremities efe
+ LEFT JOIN current_state_events cse ON
+ cse.event_id = efe.event_id
+ AND cse.type = 'm.room.create'
+ AND cse.state_key = ''
+ WHERE
+ cse.event_id IS NULL
+ AND efe.room_id > ? AND efe.room_id <= ?
+ """
+
+ txn.execute(sql, (last_room_id, room_ids[-1]))
+
+ # build a set of those rooms within `to_delete` that do not appear in
+ # the above, leaving us with the rooms in `to_delete` that *are* being
+ # created.
+ creating_rooms = to_delete.difference(row[0] for row in txn)
+ logger.info("skipping rooms which are being created: %s", creating_rooms)
+
+ # now remove the rooms being created from the list of those to delete.
+ #
+ # (we could have just taken the intersection of `to_delete` with the result
+ # of the sql query, but it's useful to be able to log `creating_rooms`; and
+ # having done so, it's quicker to remove the (few) creating rooms from
+ # `to_delete` than it is to form the intersection with the (larger) list of
+ # not-creating-rooms)
+
+ to_delete -= creating_rooms
- left_rooms = set(room_ids) - joined_room_ids
+ ###########################################################################
+ #
+ # now clear the state for the rooms
- logger.info("Deleting current state left rooms: %r", left_rooms)
+ logger.info("Deleting current state left rooms: %r", to_delete)
# First we get all users that we still think were joined to the
# room. This is so that we can mark those device lists as
@@ -391,7 +436,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="current_state_events",
column="room_id",
- iterable=left_rooms,
+ iterable=to_delete,
keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
retcols=("state_key",),
)
@@ -403,7 +448,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="current_state_events",
column="room_id",
- iterable=left_rooms,
+ iterable=to_delete,
keyvalues={},
)
@@ -411,7 +456,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="event_forward_extremities",
column="room_id",
- iterable=left_rooms,
+ iterable=to_delete,
keyvalues={},
)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 1acf287ca4..cdd093ffa8 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -46,7 +46,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"""If the user has no devices, we expect an empty list.
"""
local_user = "@boris:" + self.hs.hostname
- res = yield self.handler.query_local_devices({local_user: None})
+ res = yield defer.ensureDeferred(
+ self.handler.query_local_devices({local_user: None})
+ )
self.assertDictEqual(res, {local_user: {}})
@defer.inlineCallbacks
@@ -60,15 +62,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"},
}
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
# we should be able to change the signature without a problem
keys["alg2:k2"]["signatures"]["k1"] = "sig2"
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
@@ -84,44 +90,56 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"},
}
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
try:
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+ )
)
self.fail("No error when changing string key")
except errors.SynapseError:
pass
try:
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+ )
)
self.fail("No error when replacing dict key with string")
except errors.SynapseError:
pass
try:
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"one_time_keys": {"alg1:k1": {"key": "key"}}},
+ )
)
self.fail("No error when replacing string key with dict")
except errors.SynapseError:
pass
try:
- yield self.handler.upload_keys_for_user(
- local_user,
- device_id,
- {
- "one_time_keys": {
- "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
- }
- },
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {
+ "one_time_keys": {
+ "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
+ }
+ },
+ )
)
self.fail("No error when replacing dict key")
except errors.SynapseError:
@@ -133,13 +151,17 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_id = "xyz"
keys = {"alg1:k1": "key1"}
- res = yield self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ res = yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys}
+ )
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
- res2 = yield self.handler.claim_one_time_keys(
- {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ res2 = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
)
self.assertEqual(
res2,
@@ -163,7 +185,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys1)
+ )
keys2 = {
"master_key": {
@@ -175,10 +199,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield self.handler.upload_signing_keys_for_user(local_user, keys2)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys2)
+ )
- devices = yield self.handler.query_devices(
- {"device_keys": {local_user: []}}, 0, local_user
+ devices = yield defer.ensureDeferred(
+ self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
@@ -215,7 +241,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
"2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
)
- yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys1)
+ )
# upload two device keys, which will be signed later by the self-signing key
device_key_1 = {
@@ -245,18 +273,24 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"signatures": {local_user: {"ed25519:def": "base64+signature"}},
}
- yield self.handler.upload_keys_for_user(
- local_user, "abc", {"device_keys": device_key_1}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, "abc", {"device_keys": device_key_1}
+ )
)
- yield self.handler.upload_keys_for_user(
- local_user, "def", {"device_keys": device_key_2}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, "def", {"device_keys": device_key_2}
+ )
)
# sign the first device key and upload it
del device_key_1["signatures"]
sign.sign_json(device_key_1, local_user, signing_key)
- yield self.handler.upload_signatures_for_device_keys(
- local_user, {local_user: {"abc": device_key_1}}
+ yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user, {local_user: {"abc": device_key_1}}
+ )
)
# sign the second device key and upload both device keys. The server
@@ -264,14 +298,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# signature for it
del device_key_2["signatures"]
sign.sign_json(device_key_2, local_user, signing_key)
- yield self.handler.upload_signatures_for_device_keys(
- local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
+ yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
+ )
)
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
- devices = yield self.handler.query_devices(
- {"device_keys": {local_user: []}}, 0, local_user
+ devices = yield defer.ensureDeferred(
+ self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
)
del devices["device_keys"][local_user]["abc"]["unsigned"]
del devices["device_keys"][local_user]["def"]["unsigned"]
@@ -292,7 +328,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
- yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, keys1)
+ )
res = None
try:
@@ -305,7 +343,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
res = e.code
self.assertEqual(res, 400)
- res = yield self.handler.query_local_devices({local_user: None})
+ res = yield defer.ensureDeferred(
+ self.handler.query_local_devices({local_user: None})
+ )
self.assertDictEqual(res, {local_user: {}})
@defer.inlineCallbacks
@@ -331,8 +371,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
)
- yield self.handler.upload_keys_for_user(
- local_user, device_id, {"device_keys": device_key}
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"device_keys": device_key}
+ )
)
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
@@ -372,7 +414,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"user_signing_key": usersigning_key,
"self_signing_key": selfsigning_key,
}
- yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
+ )
# set up another user with a master key. This user will be signed by
# the first user
@@ -384,76 +428,90 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"usage": ["master"],
"keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
}
- yield self.handler.upload_signing_keys_for_user(
- other_user, {"master_key": other_master_key}
+ yield defer.ensureDeferred(
+ self.handler.upload_signing_keys_for_user(
+ other_user, {"master_key": other_master_key}
+ )
)
# test various signature failures (see below)
- ret = yield self.handler.upload_signatures_for_device_keys(
- local_user,
- {
- local_user: {
- # fails because the signature is invalid
- # should fail with INVALID_SIGNATURE
- device_id: {
- "user_id": local_user,
- "device_id": device_id,
- "algorithms": [
- "m.olm.curve25519-aes-sha2",
- RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
- ],
- "keys": {
- "curve25519:xyz": "curve25519+key",
- # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
- "ed25519:xyz": device_pubkey,
- },
- "signatures": {
- local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+ ret = yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user,
+ {
+ local_user: {
+ # fails because the signature is invalid
+ # should fail with INVALID_SIGNATURE
+ device_id: {
+ "user_id": local_user,
+ "device_id": device_id,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "curve25519:xyz": "curve25519+key",
+ # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
+ "ed25519:xyz": device_pubkey,
+ },
+ "signatures": {
+ local_user: {
+ "ed25519:" + selfsigning_pubkey: "something"
+ }
+ },
},
- },
- # fails because device is unknown
- # should fail with NOT_FOUND
- "unknown": {
- "user_id": local_user,
- "device_id": "unknown",
- "signatures": {
- local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+ # fails because device is unknown
+ # should fail with NOT_FOUND
+ "unknown": {
+ "user_id": local_user,
+ "device_id": "unknown",
+ "signatures": {
+ local_user: {
+ "ed25519:" + selfsigning_pubkey: "something"
+ }
+ },
},
- },
- # fails because the signature is invalid
- # should fail with INVALID_SIGNATURE
- master_pubkey: {
- "user_id": local_user,
- "usage": ["master"],
- "keys": {"ed25519:" + master_pubkey: master_pubkey},
- "signatures": {
- local_user: {"ed25519:" + device_pubkey: "something"}
+ # fails because the signature is invalid
+ # should fail with INVALID_SIGNATURE
+ master_pubkey: {
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {"ed25519:" + master_pubkey: master_pubkey},
+ "signatures": {
+ local_user: {"ed25519:" + device_pubkey: "something"}
+ },
},
},
- },
- other_user: {
- # fails because the device is not the user's master-signing key
- # should fail with NOT_FOUND
- "unknown": {
- "user_id": other_user,
- "device_id": "unknown",
- "signatures": {
- local_user: {"ed25519:" + usersigning_pubkey: "something"}
+ other_user: {
+ # fails because the device is not the user's master-signing key
+ # should fail with NOT_FOUND
+ "unknown": {
+ "user_id": other_user,
+ "device_id": "unknown",
+ "signatures": {
+ local_user: {
+ "ed25519:" + usersigning_pubkey: "something"
+ }
+ },
},
- },
- other_master_pubkey: {
- # fails because the key doesn't match what the server has
- # should fail with UNKNOWN
- "user_id": other_user,
- "usage": ["master"],
- "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
- "something": "random",
- "signatures": {
- local_user: {"ed25519:" + usersigning_pubkey: "something"}
+ other_master_pubkey: {
+ # fails because the key doesn't match what the server has
+ # should fail with UNKNOWN
+ "user_id": other_user,
+ "usage": ["master"],
+ "keys": {
+ "ed25519:" + other_master_pubkey: other_master_pubkey
+ },
+ "something": "random",
+ "signatures": {
+ local_user: {
+ "ed25519:" + usersigning_pubkey: "something"
+ }
+ },
},
},
},
- },
+ )
)
user_failures = ret["failures"][local_user]
@@ -478,19 +536,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
sign.sign_json(device_key, local_user, selfsigning_signing_key)
sign.sign_json(master_key, local_user, device_signing_key)
sign.sign_json(other_master_key, local_user, usersigning_signing_key)
- ret = yield self.handler.upload_signatures_for_device_keys(
- local_user,
- {
- local_user: {device_id: device_key, master_pubkey: master_key},
- other_user: {other_master_pubkey: other_master_key},
- },
+ ret = yield defer.ensureDeferred(
+ self.handler.upload_signatures_for_device_keys(
+ local_user,
+ {
+ local_user: {device_id: device_key, master_pubkey: master_key},
+ other_user: {other_master_pubkey: other_master_key},
+ },
+ )
)
self.assertEqual(ret["failures"], {})
# fetch the signed keys/devices and make sure that the signatures are there
- ret = yield self.handler.query_devices(
- {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+ ret = yield defer.ensureDeferred(
+ self.handler.query_devices(
+ {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+ )
)
self.assertEqual(
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 822ea42dde..3362050ce0 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -66,7 +66,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.get_version_info(self.local_user)
+ yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -78,7 +78,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.get_version_info(self.local_user, "bogus_version")
+ yield defer.ensureDeferred(
+ self.handler.get_version_info(self.local_user, "bogus_version")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -87,14 +89,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_create_version(self):
"""Check that we can create and then retrieve versions.
"""
- res = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ res = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(res, "1")
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
version_etag = res["etag"]
self.assertIsInstance(version_etag, str)
del res["etag"]
@@ -109,7 +116,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# check we can retrieve it as a specific version
- res = yield self.handler.get_version_info(self.local_user, "1")
+ res = yield defer.ensureDeferred(
+ self.handler.get_version_info(self.local_user, "1")
+ )
self.assertEqual(res["etag"], version_etag)
del res["etag"]
self.assertDictEqual(
@@ -123,17 +132,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# upload a new one...
- res = yield self.handler.create_version(
- self.local_user,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "second_version_auth_data",
- },
+ res = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ },
+ )
)
self.assertEqual(res, "2")
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"]
self.assertDictEqual(
res,
@@ -149,25 +160,32 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_version(self):
"""Check that we can update versions.
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- res = yield self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": version,
- },
+ res = yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": version,
+ },
+ )
)
self.assertDictEqual(res, {})
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"]
self.assertDictEqual(
res,
@@ -185,14 +203,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.update_version(
- self.local_user,
- "1",
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": "1",
- },
+ yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ "1",
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "1",
+ },
+ )
)
except errors.SynapseError as e:
res = e.code
@@ -202,23 +222,30 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_omitted_version(self):
"""Check that the update succeeds if the version is missing from the body
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- yield self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- },
+ yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ },
+ )
)
# check we can retrieve it as the current version
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"] # etag is opaque, so don't test its contents
self.assertDictEqual(
res,
@@ -234,22 +261,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_bad_version(self):
"""Check that we get a 400 if the version in the body doesn't match
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
res = None
try:
- yield self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": "incorrect",
- },
+ yield defer.ensureDeferred(
+ self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "incorrect",
+ },
+ )
)
except errors.SynapseError as e:
res = e.code
@@ -261,7 +295,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.delete_version(self.local_user, "1")
+ yield defer.ensureDeferred(
+ self.handler.delete_version(self.local_user, "1")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -272,7 +308,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.delete_version(self.local_user)
+ yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -281,19 +317,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_delete_version(self):
"""Check that we can create and then delete versions.
"""
- res = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ res = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(res, "1")
# check we can delete it
- yield self.handler.delete_version(self.local_user, "1")
+ yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
# check that it's gone
res = None
try:
- yield self.handler.get_version_info(self.local_user, "1")
+ yield defer.ensureDeferred(
+ self.handler.get_version_info(self.local_user, "1")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -304,7 +347,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.get_room_keys(self.local_user, "bogus_version")
+ yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, "bogus_version")
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -313,13 +358,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_get_missing_room_keys(self):
"""Check we get an empty response from an empty backup
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertDictEqual(res, {"rooms": {}})
# TODO: test the locking semantics when uploading room_keys,
@@ -331,8 +383,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.upload_room_keys(
- self.local_user, "no_version", room_keys
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
)
except errors.SynapseError as e:
res = e.code
@@ -343,16 +395,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""Check that we get a 404 on uploading keys when an nonexistent version
is specified
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
res = None
try:
- yield self.handler.upload_room_keys(
- self.local_user, "bogus_version", room_keys
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(
+ self.local_user, "bogus_version", room_keys
+ )
)
except errors.SynapseError as e:
res = e.code
@@ -362,24 +421,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_wrong_version(self):
"""Check that we get a 403 on uploading keys for an old version
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- version = yield self.handler.create_version(
- self.local_user,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "second_version_auth_data",
- },
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "2")
res = None
try:
- yield self.handler.upload_room_keys(self.local_user, "1", room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, "1", room_keys)
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 403)
@@ -388,26 +456,39 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_insert(self):
"""Check that we can insert and retrieve keys for a session
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertDictEqual(res, room_keys)
# check getting room_keys for a given room
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org"
+ )
)
self.assertDictEqual(res, room_keys)
# check getting room_keys for a given session_id
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, room_keys)
@@ -415,16 +496,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_merge(self):
"""Check that we can upload a new room_key for an existing session and
have it correctly merged"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
# get the etag to compare to future versions
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
backup_etag = res["etag"]
self.assertEqual(res["count"], 1)
@@ -434,29 +522,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# test that increasing the message_index doesn't replace the existing session
new_room_key["first_message_index"] = 2
new_room_key["session_data"] = "new"
- yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"SSBBTSBBIEZJU0gK",
)
# the etag should be the same since the session did not change
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag)
# test that marking the session as verified however /does/ replace it
new_room_key["is_verified"] = True
- yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
# the etag should NOT be equal now, since the key changed
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertNotEqual(res["etag"], backup_etag)
backup_etag = res["etag"]
@@ -464,15 +560,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# with a lower forwarding count
new_room_key["forwarded_count"] = 2
new_room_key["session_data"] = "other"
- yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+ )
- res = yield self.handler.get_room_keys(self.local_user, version)
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
# the etag should be the same since the session did not change
- res = yield self.handler.get_version_info(self.local_user)
+ res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag)
# TODO: check edge cases as well as the common variations here
@@ -481,36 +581,59 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_delete_room_keys(self):
"""Check that we can insert and delete keys for a session
"""
- version = yield self.handler.create_version(
- self.local_user,
- {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ version = yield defer.ensureDeferred(
+ self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
)
self.assertEqual(version, "1")
# check for bulk-delete
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
- yield self.handler.delete_room_keys(self.local_user, version)
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
+ yield defer.ensureDeferred(
+ self.handler.delete_room_keys(self.local_user, version)
+ )
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per room
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
- yield self.handler.delete_room_keys(
- self.local_user, version, room_id="!abc:matrix.org"
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
+ yield defer.ensureDeferred(
+ self.handler.delete_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org"
+ )
)
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per session
- yield self.handler.upload_room_keys(self.local_user, version, room_keys)
- yield self.handler.delete_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ yield defer.ensureDeferred(
+ self.handler.upload_room_keys(self.local_user, version, room_keys)
+ )
+ yield defer.ensureDeferred(
+ self.handler.delete_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
- res = yield self.handler.get_room_keys(
- self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ res = yield defer.ensureDeferred(
+ self.handler.get_room_keys(
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+ )
)
self.assertDictEqual(res, {"rooms": {}})
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 1e6a53bf7f..5878f74175 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -138,10 +138,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
- def get_current_users_in_room(room_id):
+ def get_users_in_room(room_id):
return defer.succeed({str(u) for u in self.room_members})
- hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
+ self.datastore.get_users_in_room = get_users_in_room
self.datastore.get_user_directory_stream_pos.return_value = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 9d4f0bbe44..06575ba0a6 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple
import attr
@@ -26,8 +26,9 @@ from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
GenericWorkerServer,
)
+from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
-from synapse.replication.http import streams
+from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -35,7 +36,7 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.server import FakeTransport
+from tests.server import FakeTransport, render
logger = logging.getLogger(__name__)
@@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(request.method, b"GET")
+class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
+ """Base class for tests running multiple workers.
+
+ Automatically handle HTTP replication requests from workers to master,
+ unlike `BaseStreamTestCase`.
+ """
+
+ servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]]
+
+ def setUp(self):
+ super().setUp()
+
+ # build a replication server
+ self.server_factory = ReplicationStreamProtocolFactory(self.hs)
+ self.streamer = self.hs.get_replication_streamer()
+
+ store = self.hs.get_datastore()
+ self.database = store.db
+
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ self._worker_hs_to_resource = {}
+
+ # When we see a connection attempt to the master replication listener we
+ # automatically set up the connection. This is so that tests don't
+ # manually have to go and explicitly set it up each time (plus sometimes
+ # it is impossible to write the handling explicitly in the tests).
+ self.reactor.add_tcp_client_callback(
+ "1.2.3.4", 8765, self._handle_http_replication_attempt
+ )
+
+ def create_test_json_resource(self):
+ """Overrides `HomeserverTestCase.create_test_json_resource`.
+ """
+ # We override this so that it automatically registers all the HTTP
+ # replication servlets, without having to explicitly do that in all
+ # subclassses.
+
+ resource = ReplicationRestResource(self.hs)
+
+ for servlet in self.servlets:
+ servlet(self.hs, resource)
+
+ return resource
+
+ def make_worker_hs(
+ self, worker_app: str, extra_config: dict = {}, **kwargs
+ ) -> HomeServer:
+ """Make a new worker HS instance, correctly connecting replcation
+ stream to the master HS.
+
+ Args:
+ worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
+ extra_config: Any extra config to use for this instances.
+ **kwargs: Options that get passed to `self.setup_test_homeserver`,
+ useful to e.g. pass some mocks for things like `http_client`
+
+ Returns:
+ The new worker HomeServer instance.
+ """
+
+ config = self._get_worker_hs_config()
+ config["worker_app"] = worker_app
+ config.update(extra_config)
+
+ worker_hs = self.setup_test_homeserver(
+ homeserverToUse=GenericWorkerServer,
+ config=config,
+ reactor=self.reactor,
+ **kwargs
+ )
+
+ store = worker_hs.get_datastore()
+ store.db._db_pool = self.database._db_pool
+
+ repl_handler = ReplicationCommandHandler(worker_hs)
+ client = ClientReplicationStreamProtocol(
+ worker_hs, "client", "test", self.clock, repl_handler,
+ )
+ server = self.server_factory.buildProtocol(None)
+
+ client_transport = FakeTransport(server, self.reactor)
+ client.makeConnection(client_transport)
+
+ server_transport = FakeTransport(client, self.reactor)
+ server.makeConnection(server_transport)
+
+ # Set up a resource for the worker
+ resource = ReplicationRestResource(self.hs)
+
+ for servlet in self.servlets:
+ servlet(worker_hs, resource)
+
+ self._worker_hs_to_resource[worker_hs] = resource
+
+ return worker_hs
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
+ def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
+ render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
+
+ def replicate(self):
+ """Tell the master side of replication that something has happened, and then
+ wait for the replication to occur.
+ """
+ self.streamer.on_notifier_poke()
+ self.pump()
+
+ def _handle_http_replication_attempt(self):
+ """Handles a connection attempt to the master replication HTTP
+ listener.
+ """
+
+ # We should have at least one outbound connection attempt, where the
+ # last is one to the HTTP repication IP/port.
+ clients = self.reactor.tcpClients
+ self.assertGreaterEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8765)
+
+ # Set up client side protocol
+ client_protocol = client_factory.buildProtocol(None)
+
+ request_factory = OneShotRequestFactory()
+
+ # Set up the server side protocol
+ channel = _PushHTTPChannel(self.reactor)
+ channel.requestFactory = request_factory
+ channel.site = self.site
+
+ # Connect client to server and vice versa.
+ client_to_server_transport = FakeTransport(
+ channel, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
+
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, channel
+ )
+ channel.makeConnection(server_to_client_transport)
+
+ # Note: at this point we've wired everything up, but we need to return
+ # before the data starts flowing over the connections as this is called
+ # inside `connecTCP` before the connection has been passed back to the
+ # code that requested the TCP connection.
+
+
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel):
# We need to manually stop the _PullToPushProducer.
self._pull_to_push_producer.stop()
+ def checkPersistence(self, request, version):
+ """Check whether the connection can be re-used
+ """
+ # We hijack this to always say no for ease of wiring stuff up in
+ # `handle_http_replication_attempt`.
+ request.responseHeaders.setRawHeaders(b"connection", [b"close"])
+ return False
+
class _PullToPushProducer:
"""A push producer that wraps a pull producer.
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index b7d753e0a3..86c03fd89c 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -15,63 +15,26 @@
import logging
from synapse.api.constants import LoginType
-from synapse.app.generic_worker import GenericWorkerServer
-from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest.client.v2_alpha import register
-from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
-from tests.server import FakeChannel, render
+from tests.server import FakeChannel
logger = logging.getLogger(__name__)
-class ClientReaderTestCase(unittest.HomeserverTestCase):
+class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Base class for tests of the replication streams"""
- servlets = [
- register.register_servlets,
- ]
+ servlets = [register.register_servlets]
def prepare(self, reactor, clock, hs):
- # build a replication server
- self.server_factory = ReplicationStreamProtocolFactory(hs)
- self.streamer = hs.get_replication_streamer()
-
- store = hs.get_datastore()
- self.database = store.db
-
self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
- self.reactor.lookups["testserv"] = "1.2.3.4"
-
- def make_worker_hs(self, extra_config={}):
- config = self._get_worker_hs_config()
- config.update(extra_config)
-
- worker_hs = self.setup_test_homeserver(
- homeserverToUse=GenericWorkerServer, config=config, reactor=self.reactor,
- )
-
- store = worker_hs.get_datastore()
- store.db._db_pool = self.database._db_pool
-
- # Register the expected servlets, essentially this is HomeserverTestCase.create_test_json_resource.
- resource = JsonResource(self.hs)
-
- for servlet in self.servlets:
- servlet(worker_hs, resource)
-
- # Essentially HomeserverTestCase.render.
- def _render(request):
- render(request, self.resource, self.reactor)
-
- return worker_hs, _render
-
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.client_reader"
@@ -82,14 +45,14 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
def test_register_single_worker(self):
"""Test that registration works when using a single client reader worker.
"""
- _, worker_render = self.make_worker_hs()
+ worker_hs = self.make_worker_hs("synapse.app.client_reader")
request_1, channel_1 = self.make_request(
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel
- worker_render(request_1)
+ self.render_on_worker(worker_hs, request_1)
self.assertEqual(request_1.code, 401)
# Grab the session
@@ -99,7 +62,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
) # type: SynapseRequest, FakeChannel
- worker_render(request_2)
+ self.render_on_worker(worker_hs, request_2)
self.assertEqual(request_2.code, 200)
# We're given a registered user.
@@ -108,15 +71,15 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
def test_register_multi_worker(self):
"""Test that registration works when using multiple client reader workers.
"""
- _, worker_render_1 = self.make_worker_hs()
- _, worker_render_2 = self.make_worker_hs()
+ worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
+ worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
request_1, channel_1 = self.make_request(
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel
- worker_render_1(request_1)
+ self.render_on_worker(worker_hs_1, request_1)
self.assertEqual(request_1.code, 401)
# Grab the session
@@ -126,7 +89,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
) # type: SynapseRequest, FakeChannel
- worker_render_2(request_2)
+ self.render_on_worker(worker_hs_2, request_2)
self.assertEqual(request_2.code, 200)
# We're given a registered user.
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 519a2dc510..8d4dbf232e 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -19,132 +19,40 @@ from mock import Mock
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.app.generic_worker import GenericWorkerServer
from synapse.events.builder import EventBuilderFactory
-from synapse.replication.http import streams
-from synapse.replication.tcp.handler import ReplicationCommandHandler
-from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
from synapse.types import UserID
-from tests import unittest
-from tests.server import FakeTransport
+from tests.replication._base import BaseMultiWorkerStreamTestCase
logger = logging.getLogger(__name__)
-class BaseStreamTestCase(unittest.HomeserverTestCase):
- """Base class for tests of the replication streams"""
-
+class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
servlets = [
- streams.register_servlets,
+ login.register_servlets,
+ register_servlets_for_client_rest_resource,
+ room.register_servlets,
]
- def prepare(self, reactor, clock, hs):
- # build a replication server
- self.server_factory = ReplicationStreamProtocolFactory(hs)
- self.streamer = hs.get_replication_streamer()
-
- store = hs.get_datastore()
- self.database = store.db
-
- self.reactor.lookups["testserv"] = "1.2.3.4"
-
def default_config(self):
conf = super().default_config()
conf["send_federation"] = False
return conf
- def make_worker_hs(self, extra_config={}):
- config = self._get_worker_hs_config()
- config.update(extra_config)
-
- mock_federation_client = Mock(spec=["put_json"])
- mock_federation_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
-
- worker_hs = self.setup_test_homeserver(
- http_client=mock_federation_client,
- homeserverToUse=GenericWorkerServer,
- config=config,
- reactor=self.reactor,
- )
-
- store = worker_hs.get_datastore()
- store.db._db_pool = self.database._db_pool
-
- repl_handler = ReplicationCommandHandler(worker_hs)
- client = ClientReplicationStreamProtocol(
- worker_hs, "client", "test", self.clock, repl_handler,
- )
- server = self.server_factory.buildProtocol(None)
-
- client_transport = FakeTransport(server, self.reactor)
- client.makeConnection(client_transport)
-
- server_transport = FakeTransport(client, self.reactor)
- server.makeConnection(server_transport)
-
- return worker_hs
-
- def _get_worker_hs_config(self) -> dict:
- config = self.default_config()
- config["worker_app"] = "synapse.app.federation_sender"
- config["worker_replication_host"] = "testserv"
- config["worker_replication_http_port"] = "8765"
- return config
-
- def replicate(self):
- """Tell the master side of replication that something has happened, and then
- wait for the replication to occur.
- """
- self.streamer.on_notifier_poke()
- self.pump()
-
- def create_room_with_remote_server(self, user, token, remote_server="other_server"):
- room = self.helper.create_room_as(user, tok=token)
- store = self.hs.get_datastore()
- federation = self.hs.get_handlers().federation_handler
-
- prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
- room_version = self.get_success(store.get_room_version(room))
-
- factory = EventBuilderFactory(self.hs)
- factory.hostname = remote_server
-
- user_id = UserID("user", remote_server).to_string()
-
- event_dict = {
- "type": EventTypes.Member,
- "state_key": user_id,
- "content": {"membership": Membership.JOIN},
- "sender": user_id,
- "room_id": room,
- }
-
- builder = factory.for_room_version(room_version, event_dict)
- join_event = self.get_success(builder.build(prev_event_ids))
-
- self.get_success(federation.on_send_join_request(remote_server, join_event))
- self.replicate()
-
- return room
-
-
-class FederationSenderTestCase(BaseStreamTestCase):
- servlets = [
- login.register_servlets,
- register_servlets_for_client_rest_resource,
- room.register_servlets,
- ]
-
def test_send_event_single_sender(self):
"""Test that using a single federation sender worker correctly sends a
new event.
"""
- worker_hs = self.make_worker_hs({"send_federation": True})
- mock_client = worker_hs.get_http_client()
+ mock_client = Mock(spec=["put_json"])
+ mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
+
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {"send_federation": True},
+ http_client=mock_client,
+ )
user = self.register_user("user", "pass")
token = self.login("user", "pass")
@@ -165,23 +73,29 @@ class FederationSenderTestCase(BaseStreamTestCase):
"""Test that using two federation sender workers correctly sends
new events.
"""
- worker1 = self.make_worker_hs(
+ mock_client1 = Mock(spec=["put_json"])
+ mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
{
"send_federation": True,
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
- }
+ },
+ http_client=mock_client1,
)
- mock_client1 = worker1.get_http_client()
- worker2 = self.make_worker_hs(
+ mock_client2 = Mock(spec=["put_json"])
+ mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
{
"send_federation": True,
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
- }
+ },
+ http_client=mock_client2,
)
- mock_client2 = worker2.get_http_client()
user = self.register_user("user2", "pass")
token = self.login("user2", "pass")
@@ -191,8 +105,8 @@ class FederationSenderTestCase(BaseStreamTestCase):
for i in range(20):
server_name = "other_server_%d" % (i,)
room = self.create_room_with_remote_server(user, token, server_name)
- mock_client1.reset_mock()
- mock_client2.reset_mock()
+ mock_client1.reset_mock() # type: ignore[attr-defined]
+ mock_client2.reset_mock() # type: ignore[attr-defined]
self.create_and_send_event(room, UserID.from_string(user))
self.replicate()
@@ -222,23 +136,29 @@ class FederationSenderTestCase(BaseStreamTestCase):
"""Test that using two federation sender workers correctly sends
new typing EDUs.
"""
- worker1 = self.make_worker_hs(
+ mock_client1 = Mock(spec=["put_json"])
+ mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
{
"send_federation": True,
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
- }
+ },
+ http_client=mock_client1,
)
- mock_client1 = worker1.get_http_client()
- worker2 = self.make_worker_hs(
+ mock_client2 = Mock(spec=["put_json"])
+ mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
{
"send_federation": True,
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
- }
+ },
+ http_client=mock_client2,
)
- mock_client2 = worker2.get_http_client()
user = self.register_user("user3", "pass")
token = self.login("user3", "pass")
@@ -250,8 +170,8 @@ class FederationSenderTestCase(BaseStreamTestCase):
for i in range(20):
server_name = "other_server_%d" % (i,)
room = self.create_room_with_remote_server(user, token, server_name)
- mock_client1.reset_mock()
- mock_client2.reset_mock()
+ mock_client1.reset_mock() # type: ignore[attr-defined]
+ mock_client2.reset_mock() # type: ignore[attr-defined]
self.get_success(
typing_handler.started_typing(
@@ -284,3 +204,32 @@ class FederationSenderTestCase(BaseStreamTestCase):
self.assertTrue(sent_on_1)
self.assertTrue(sent_on_2)
+
+ def create_room_with_remote_server(self, user, token, remote_server="other_server"):
+ room = self.helper.create_room_as(user, tok=token)
+ store = self.hs.get_datastore()
+ federation = self.hs.get_handlers().federation_handler
+
+ prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
+ room_version = self.get_success(store.get_room_version(room))
+
+ factory = EventBuilderFactory(self.hs)
+ factory.hostname = remote_server
+
+ user_id = UserID("user", remote_server).to_string()
+
+ event_dict = {
+ "type": EventTypes.Member,
+ "state_key": user_id,
+ "content": {"membership": Membership.JOIN},
+ "sender": user_id,
+ "room_id": room,
+ }
+
+ builder = factory.for_room_version(room_version, event_dict)
+ join_event = self.get_success(builder.build(prev_event_ids))
+
+ self.get_success(federation.on_send_join_request(remote_server, join_event))
+ self.replicate()
+
+ return room
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index cca5f548e6..f16eef15f7 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -857,6 +857,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
+ def test_reactivate_user(self):
+ """
+ Test reactivating another user.
+ """
+
+ # Deactivate the user.
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Attempt to reactivate the user (without a password).
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
+ )
+ self.render(request)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Reactivate the user.
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=json.dumps({"deactivated": False, "password": "foo"}).encode(
+ encoding="utf_8"
+ ),
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(False, channel.json_body["deactivated"])
+
def test_set_user_as_admin(self):
"""
Test setting the admin flag on a user.
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 4413bb3932..db52725cfe 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -547,8 +547,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, "notsecret")
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
"JWT validation failed: Signature verification failed",
@@ -556,8 +556,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_expired(self):
channel = self.jwt_login({"sub": "frog", "exp": 864000})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Signature has expired"
)
@@ -565,8 +565,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_not_before(self):
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
"JWT validation failed: The token is not yet valid (nbf)",
@@ -574,8 +574,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_no_sub(self):
channel = self.jwt_login({"username": "root"})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
@override_config(
@@ -597,16 +597,16 @@ class JWTTestCase(unittest.HomeserverTestCase):
# An invalid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid issuer"
)
# Not providing an issuer.
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
'JWT validation failed: Token is missing the "iss" claim',
@@ -637,16 +637,16 @@ class JWTTestCase(unittest.HomeserverTestCase):
# An invalid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid audience"
)
# Not providing an audience.
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
'JWT validation failed: Token is missing the "aud" claim',
@@ -655,7 +655,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_aud_no_config(self):
"""Test providing an audience without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid audience"
)
@@ -664,8 +665,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
params = json.dumps({"type": "org.matrix.login.jwt"})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -747,8 +748,8 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
"JWT validation failed: Signature verification failed",
diff --git a/tests/server.py b/tests/server.py
index a5e57c52fa..b6e0b14e78 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -237,6 +237,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def __init__(self):
self.threadpool = ThreadPool(self)
+ self._tcp_callbacks = {}
self._udp = []
lookups = self.lookups = {}
@@ -268,6 +269,29 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def getThreadPool(self):
return self.threadpool
+ def add_tcp_client_callback(self, host, port, callback):
+ """Add a callback that will be invoked when we receive a connection
+ attempt to the given IP/port using `connectTCP`.
+
+ Note that the callback gets run before we return the connection to the
+ client, which means callbacks cannot block while waiting for writes.
+ """
+ self._tcp_callbacks[(host, port)] = callback
+
+ def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
+ """Fake L{IReactorTCP.connectTCP}.
+ """
+
+ conn = super().connectTCP(
+ host, port, factory, timeout=timeout, bindAddress=None
+ )
+
+ callback = self._tcp_callbacks.get((host, port))
+ if callback:
+ callback()
+
+ return conn
+
class ThreadPool:
"""
@@ -486,7 +510,7 @@ class FakeTransport(object):
try:
self.other.dataReceived(to_write)
except Exception as e:
- logger.warning("Exception writing to protocol: %s", e)
+ logger.exception("Exception writing to protocol: %s", e)
return
self.buffer = self.buffer[len(to_write) :]
|