diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index cdb0048122..ce96574915 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -69,6 +69,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.store.get_user_by_access_token = simple_async_mock(user_info)
self.store.mark_access_token_as_used = simple_async_mock(None)
+ self.store.get_user_locked_status = simple_async_mock(False)
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
@@ -293,6 +294,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.store.insert_client_ip = simple_async_mock(None)
self.store.mark_access_token_as_used = simple_async_mock(None)
+ self.store.get_user_locked_status = simple_async_mock(False)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
@@ -311,6 +313,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
token_used=True,
)
)
+ self.store.get_user_locked_status = simple_async_mock(False)
self.store.insert_client_ip = simple_async_mock(None)
self.store.mark_access_token_as_used = simple_async_mock(None)
request = Mock(args={})
diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py
index 9305b758d7..93af614def 100644
--- a/tests/app/test_phone_stats_home.py
+++ b/tests/app/test_phone_stats_home.py
@@ -26,7 +26,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
def make_homeserver(
self, reactor: ThreadedMemoryReactorClock, clock: Clock
) -> HomeServer:
- hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock)
+ hs = super().make_homeserver(reactor, clock)
# We don't want our tests to actually report statistics, so check
# that it's not enabled
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 7c63b2ea4c..2be341ac7b 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -312,7 +312,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
[("server9", get_key_id(key1))]
)
result = self.get_success(d)
- self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0)
+ self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0)
def test_verify_json_dedupes_key_requests(self) -> None:
"""Two requests for the same key should be deduped."""
@@ -456,24 +456,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
- lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
- [lookup_triplet]
+ SERVER_NAME, [testverifykey_id]
)
)
- res_keys = key_json[lookup_triplet]
- self.assertEqual(len(res_keys), 1)
- res = res_keys[0]
- self.assertEqual(res["key_id"], testverifykey_id)
- self.assertEqual(res["from_server"], SERVER_NAME)
- self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
- self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+ res = key_json[testverifykey_id]
+ self.assertIsNotNone(res)
+ assert res is not None
+ self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+ self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
# we expect it to be encoded as canonical json *before* it hits the db
- self.assertEqual(
- bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
- )
+ self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
# change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER"
@@ -576,23 +571,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
- lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
- [lookup_triplet]
+ SERVER_NAME, [testverifykey_id]
)
)
- res_keys = key_json[lookup_triplet]
- self.assertEqual(len(res_keys), 1)
- res = res_keys[0]
- self.assertEqual(res["key_id"], testverifykey_id)
- self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
- self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
- self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-
- self.assertEqual(
- bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
- )
+ res = key_json[testverifykey_id]
+ self.assertIsNotNone(res)
+ assert res is not None
+ self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+ self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+
+ self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
def test_get_multiple_keys_from_perspectives(self) -> None:
"""Check that we can correctly request multiple keys for the same server"""
@@ -699,23 +689,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
- lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
- [lookup_triplet]
+ SERVER_NAME, [testverifykey_id]
)
)
- res_keys = key_json[lookup_triplet]
- self.assertEqual(len(res_keys), 1)
- res = res_keys[0]
- self.assertEqual(res["key_id"], testverifykey_id)
- self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
- self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
- self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-
- self.assertEqual(
- bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
- )
+ res = key_json[testverifykey_id]
+ self.assertIsNotNone(res)
+ assert res is not None
+ self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+ self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+
+ self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
def test_invalid_perspectives_responses(self) -> None:
"""Check that invalid responses from the perspectives server are rejected"""
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 647ee09279..e1e58fa6e6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -566,15 +566,16 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(res["events"]), 1)
self.assertEqual(res["events"][0]["content"]["body"], "foo")
- # Fetch the message of the dehydrated device again, which should return nothing
- # and delete the old messages
+ # Fetch the message of the dehydrated device again, which should return
+ # the same message as it has not been deleted
res = self.get_success(
self.message_handler.get_events_for_dehydrated_device(
requester=requester,
device_id=stored_dehydrated_device_id,
- since_token=res["next_batch"],
+ since_token=None,
limit=10,
)
)
self.assertTrue(len(res["next_batch"]) > 1)
- self.assertEqual(len(res["events"]), 0)
+ self.assertEqual(len(res["events"]), 1)
+ self.assertEqual(res["events"][0]["content"]["body"], "foo")
diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
index 6309d7b36e..82c26e303f 100644
--- a/tests/handlers/test_oauth_delegation.py
+++ b/tests/handlers/test_oauth_delegation.py
@@ -491,6 +491,68 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)
+ def test_introspection_token_cache(self) -> None:
+ access_token = "open_sesame"
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={"active": "true", "scope": "guest", "jti": access_token},
+ )
+ )
+
+ # first call should cache response
+        # Mypy ignores below are due to mypy not understanding the dynamic substitution of msc3861 auth code
+ # for regular auth code via the config
+ self.get_success(
+ self.auth._introspect_token(access_token) # type: ignore[attr-defined]
+ )
+ introspection_token = self.auth._token_cache.get(access_token) # type: ignore[attr-defined]
+ self.assertEqual(introspection_token["jti"], access_token)
+ # there's been one http request
+ self.http_client.request.assert_called_once()
+
+ # second call should pull from cache, there should still be only one http request
+ token = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
+ self.http_client.request.assert_called_once()
+ self.assertEqual(token["jti"], access_token)
+
+        # advance past five minutes and check that the cache expired - there should be more than one http call now
+ self.reactor.advance(360)
+ token_2 = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
+ self.assertEqual(self.http_client.request.call_count, 2)
+ self.assertEqual(token_2["jti"], access_token)
+
+        # test that if a cached token is expired, a fresh token will be pulled from the authorizing
+        # server - first add a token with a soon-to-expire `exp` field to the cache
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": "true",
+ "scope": "guest",
+ "jti": "stale",
+ "exp": self.clock.time() + 100,
+ },
+ )
+ )
+ self.get_success(
+ self.auth._introspect_token("stale") # type: ignore[attr-defined]
+ )
+ introspection_token = self.auth._token_cache.get("stale") # type: ignore[attr-defined]
+ self.assertEqual(introspection_token["jti"], "stale")
+ self.assertEqual(self.http_client.request.call_count, 1)
+
+ # advance the reactor past the token expiry but less than the cache expiry
+ self.reactor.advance(120)
+ self.assertEqual(self.auth._token_cache.get("stale"), introspection_token) # type: ignore[attr-defined]
+
+ # check that the next call causes another http request (which will fail because the token is technically expired
+ # but the important thing is we discard the token from the cache and try the network)
+ self.get_failure(
+ self.auth._introspect_token("stale"), InvalidClientTokenError # type: ignore[attr-defined]
+ )
+ self.assertEqual(self.http_client.request.call_count, 2)
+
def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
# We only generate a master key to simplify the test.
master_signing_key = generate_signing_key(device_id)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index aed2a4c07a..6a0b5fc0bd 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -514,7 +514,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(response.code, 200)
# Send the body
- request.write('{ "a": 1 }'.encode("ascii"))
+ request.write(b'{ "a": 1 }')
request.finish()
self.reactor.pump((0.1,))
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index b3310abe1b..fe631d7ecb 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -757,7 +757,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
self.assertEqual(channel.json_body["creator"], user_id)
# Check room alias.
- self.assertEquals(room_alias, f"#foo-bar:{self.module_api.server_name}")
+ self.assertEqual(room_alias, f"#foo-bar:{self.module_api.server_name}")
# Let's try a room with no alias.
room_id, room_alias = self.get_success(
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 1527b4a82d..6e78daa830 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -116,7 +116,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(request.method, b"GET")
self.assertEqual(
request.path,
- f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"),
+ f"/_matrix/media/r0/download/{target}/{media_id}".encode(),
)
self.assertEqual(
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 9af9db6e3e..41a959b4d6 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -29,7 +29,16 @@ from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.media.filepath import MediaFilePaths
-from synapse.rest.client import devices, login, logout, profile, register, room, sync
+from synapse.rest.client import (
+ devices,
+ login,
+ logout,
+ profile,
+ register,
+ room,
+ sync,
+ user_directory,
+)
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@@ -1477,6 +1486,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
login.register_servlets,
sync.register_servlets,
register.register_servlets,
+ user_directory.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -2464,6 +2474,105 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# This key was removed intentionally. Ensure it is not accidentally re-included.
self.assertNotIn("password_hash", channel.json_body)
+ def test_locked_user(self) -> None:
+ # User can sync
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/sync",
+ access_token=self.other_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Lock user
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"locked": True},
+ )
+
+ # User is not authorized to sync anymore
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/sync",
+ access_token=self.other_user_token,
+ )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.USER_LOCKED, channel.json_body["errcode"])
+ self.assertTrue(channel.json_body["soft_logout"])
+
+ @override_config({"user_directory": {"enabled": True, "search_all_users": True}})
+ def test_locked_user_not_in_user_dir(self) -> None:
+ # User is available in the user dir
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/user_directory/search",
+ {"search_term": self.other_user},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIn("results", channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+
+ # Lock user
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"locked": True},
+ )
+
+ # User is not available anymore in the user dir
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/user_directory/search",
+ {"search_term": self.other_user},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIn("results", channel.json_body)
+ self.assertEqual(0, len(channel.json_body["results"]))
+
+ @override_config(
+ {
+ "user_directory": {
+ "enabled": True,
+ "search_all_users": True,
+ "show_locked_users": True,
+ }
+ }
+ )
+ def test_locked_user_in_user_dir_with_show_locked_users_option(self) -> None:
+ # User is available in the user dir
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/user_directory/search",
+ {"search_term": self.other_user},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIn("results", channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+
+ # Lock user
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"locked": True},
+ )
+
+ # User is still available in the user dir
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/user_directory/search",
+ {"search_term": self.other_user},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIn("results", channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
def test_change_name_deactivate_user_user_directory(self) -> None:
"""
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index 3cf29c10ea..60099f8c59 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError
from synapse.rest import admin, devices, room, sync
from synapse.rest.client import account, keys, login, register
from synapse.server import HomeServer
-from synapse.types import JsonDict, create_requester
+from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -282,6 +282,17 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
"<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
},
},
+ "fallback_keys": {
+ "alg1:device1": "f4llb4ckk3y",
+ "signed_<algorithm>:<device_id>": {
+ "fallback": "true",
+ "key": "f4llb4ckk3y",
+ "signatures": {
+ "<user_id>": {"<algorithm>:<device_id>": "<key_base64>"}
+ },
+ },
+ },
+ "one_time_keys": {"alg1:k1": "0net1m3k3y"},
}
channel = self.make_request(
"PUT",
@@ -312,6 +323,55 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
}
self.assertEqual(device_data, expected_device_data)
+ # test that the keys are correctly uploaded
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ user: ["device1"],
+ },
+ },
+ token,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body["device_keys"][user][device_id]["keys"],
+ content["device_keys"]["keys"],
+ )
+ # first claim should return the onetime key we uploaded
+ res = self.get_success(
+ self.hs.get_e2e_keys_handler().claim_one_time_keys(
+ {user: {device_id: {"alg1": 1}}},
+ UserID.from_string(user),
+ timeout=None,
+ always_include_fallback_keys=False,
+ )
+ )
+ self.assertEqual(
+ res,
+ {
+ "failures": {},
+ "one_time_keys": {user: {device_id: {"alg1:k1": "0net1m3k3y"}}},
+ },
+ )
+ # second claim should return fallback key
+ res2 = self.get_success(
+ self.hs.get_e2e_keys_handler().claim_one_time_keys(
+ {user: {device_id: {"alg1": 1}}},
+ UserID.from_string(user),
+ timeout=None,
+ always_include_fallback_keys=False,
+ )
+ )
+ self.assertEqual(
+ res2,
+ {
+ "failures": {},
+ "one_time_keys": {user: {device_id: {"alg1:device1": "f4llb4ckk3y"}}},
+ },
+ )
+
# create another device for the user
(
new_device_id,
@@ -348,10 +408,21 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
expected_content = {"body": "test_message"}
self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
+
+ # fetch messages again and make sure that the message was not deleted
+ channel = self.make_request(
+ "POST",
+ f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
+ content={},
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
next_batch_token = channel.json_body.get("next_batch")
- # fetch messages again and make sure that the message was deleted and we are returned an
- # empty array
+ # make sure fetching messages with next batch token works - there are no unfetched
+ # messages so we should receive an empty array
content = {"next_batch": next_batch_token}
channel = self.make_request(
"POST",
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index 180b635ea6..4e0a387bd3 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -627,8 +627,8 @@ class RedactionsTestCase(HomeserverTestCase):
redact_event = timeline[-1]
self.assertEqual(redact_event["type"], EventTypes.Redaction)
# The redacts key should be in the content and the redacts keys.
- self.assertEquals(redact_event["content"]["redacts"], event_id)
- self.assertEquals(redact_event["redacts"], event_id)
+ self.assertEqual(redact_event["content"]["redacts"], event_id)
+ self.assertEqual(redact_event["redacts"], event_id)
# But it isn't actually part of the event.
def get_event(txn: LoggingTransaction) -> JsonDict:
@@ -642,10 +642,10 @@ class RedactionsTestCase(HomeserverTestCase):
event_json = self.get_success(
main_datastore.db_pool.runInteraction("get_event", get_event)
)
- self.assertEquals(event_json["type"], EventTypes.Redaction)
+ self.assertEqual(event_json["type"], EventTypes.Redaction)
if expect_content:
self.assertNotIn("redacts", event_json)
- self.assertEquals(event_json["content"]["redacts"], event_id)
+ self.assertEqual(event_json["content"]["redacts"], event_id)
else:
- self.assertEquals(event_json["redacts"], event_id)
+ self.assertEqual(event_json["redacts"], event_id)
self.assertNotIn("redacts", event_json["content"])
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 75439416c1..9bfe913e45 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -129,7 +129,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
return [ev["event_id"] for ev in channel.json_body["chunk"]]
def _get_bundled_aggregations(self) -> JsonDict:
@@ -142,7 +142,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
f"/_matrix/client/v3/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
return channel.json_body["unsigned"].get("m.relations", {})
def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
@@ -1602,7 +1602,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
threads = channel.json_body["chunk"]
return [
(
@@ -1634,7 +1634,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
##################################################
# Check the test data is configured as expected. #
##################################################
- self.assertEquals(self._get_related_events(), list(reversed(thread_replies)))
+ self.assertEqual(self._get_related_events(), list(reversed(thread_replies)))
relations = self._get_bundled_aggregations()
self.assertDictContainsSubset(
{"count": 3, "current_user_participated": True},
@@ -1655,7 +1655,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(thread_replies.pop())
# The thread should still exist, but the latest event should be updated.
- self.assertEquals(self._get_related_events(), list(reversed(thread_replies)))
+ self.assertEqual(self._get_related_events(), list(reversed(thread_replies)))
relations = self._get_bundled_aggregations()
self.assertDictContainsSubset(
{"count": 2, "current_user_participated": True},
@@ -1674,7 +1674,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(thread_replies.pop(0))
# Nothing should have changed (except the thread count).
- self.assertEquals(self._get_related_events(), thread_replies)
+ self.assertEqual(self._get_related_events(), thread_replies)
relations = self._get_bundled_aggregations()
self.assertDictContainsSubset(
{"count": 1, "current_user_participated": True},
@@ -1691,11 +1691,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
# Redact the last remaining event. #
####################################
self._redact(thread_replies.pop(0))
- self.assertEquals(thread_replies, [])
+ self.assertEqual(thread_replies, [])
# The event should no longer be considered a thread.
- self.assertEquals(self._get_related_events(), [])
- self.assertEquals(self._get_bundled_aggregations(), {})
+ self.assertEqual(self._get_related_events(), [])
+ self.assertEqual(self._get_bundled_aggregations(), {})
self.assertEqual(self._get_threads(), [])
def test_redact_parent_edit(self) -> None:
@@ -1749,8 +1749,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
# The relations are returned.
event_ids = self._get_related_events()
relations = self._get_bundled_aggregations()
- self.assertEquals(event_ids, [related_event_id])
- self.assertEquals(
+ self.assertEqual(event_ids, [related_event_id])
+ self.assertEqual(
relations[RelationTypes.REFERENCE],
{"chunk": [{"event_id": related_event_id}]},
)
@@ -1772,7 +1772,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
# The unredacted relation should still exist.
event_ids = self._get_related_events()
relations = self._get_bundled_aggregations()
- self.assertEquals(len(event_ids), 1)
+ self.assertEqual(len(event_ids), 1)
self.assertDictContainsSubset(
{
"count": 1,
@@ -1816,7 +1816,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
threads = self._get_threads(channel.json_body)
self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)])
@@ -1829,7 +1829,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
# Tuple of (thread ID, latest event ID) for each thread.
threads = self._get_threads(channel.json_body)
self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)])
@@ -1850,7 +1850,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_2])
@@ -1864,7 +1864,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_1], channel.json_body)
@@ -1899,7 +1899,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(
thread_roots, [thread_3, thread_2, thread_1], channel.json_body
@@ -1911,7 +1911,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)
@@ -1943,6 +1943,6 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token,
)
- self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_1], channel.json_body)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 4f6347be15..88e579dc39 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -1362,7 +1362,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
# Ensure the event was persisted with the correct timestamp.
res = self.get_success(self.main_store.get_event(event_id))
- self.assertEquals(ts, res.origin_server_ts)
+ self.assertEqual(ts, res.origin_server_ts)
def test_send_state_event_ts(self) -> None:
"""Test sending a state event with a custom timestamp."""
@@ -1384,7 +1384,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
# Ensure the event was persisted with the correct timestamp.
res = self.get_success(self.main_store.get_event(event_id))
- self.assertEquals(ts, res.origin_server_ts)
+ self.assertEqual(ts, res.origin_server_ts)
def test_send_membership_event_ts(self) -> None:
"""Test sending a membership event with a custom timestamp."""
@@ -1406,7 +1406,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
# Ensure the event was persisted with the correct timestamp.
res = self.get_success(self.main_store.get_event(event_id))
- self.assertEquals(ts, res.origin_server_ts)
+ self.assertEqual(ts, res.origin_server_ts)
class RoomJoinRatelimitTestCase(RoomBase):
diff --git a/tests/server.py b/tests/server.py
index c84a524e8c..481fe34c5c 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -26,6 +26,7 @@ from typing import (
Any,
Awaitable,
Callable,
+ Deque,
Dict,
Iterable,
List,
@@ -41,7 +42,7 @@ from typing import (
from unittest.mock import Mock
import attr
-from typing_extensions import Deque, ParamSpec
+from typing_extensions import ParamSpec
from zope.interface import implementer
from twisted.internet import address, threads, udp
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 383da83dfb..f541f1d6be 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
from twisted.internet import defer, reactor
from twisted.internet.base import ReactorBase
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
-from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
+from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS, _RENEWAL_INTERVAL_MS
from synapse.util import Clock
from tests import unittest
@@ -380,8 +381,8 @@ class ReadWriteLockTestCase(unittest.HomeserverTestCase):
self.get_success(lock.__aenter__())
# Wait for ages with the lock, we should not be able to get the lock.
- self.reactor.advance(5 * _LOCK_TIMEOUT_MS / 1000)
- self.pump()
+ for _ in range(0, 10):
+ self.reactor.advance((_RENEWAL_INTERVAL_MS / 1000))
lock2 = self.get_success(
self.store.try_acquire_read_write_lock("name", "key", write=True)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 5e1324a169..71302facd1 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -40,7 +40,7 @@ from tests.test_utils import make_awaitable
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None:
- super(ApplicationServiceStoreTestCase, self).setUp()
+ super().setUp()
self.as_yaml_files: List[str] = []
@@ -71,7 +71,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
except Exception:
pass
- super(ApplicationServiceStoreTestCase, self).tearDown()
+ super().tearDown()
def _add_appservice(
self, as_token: str, id: str, url: str, hs_token: str, sender: str
@@ -110,7 +110,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None:
- super(ApplicationServiceTransactionStoreTestCase, self).setUp()
+ super().setUp()
self.as_yaml_files: List[str] = []
self.hs.config.appservice.app_service_config_files = self.as_yaml_files
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 27f450e22d..b8823d6993 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -20,7 +20,7 @@ from tests import unittest
class DataStoreTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None:
- super(DataStoreTestCase, self).setUp()
+ super().setUp()
self.store = self.hs.get_datastores().main
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 05ea802008..ba41459d08 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -48,6 +48,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
"creation_ts": 0,
"user_type": None,
"deactivated": 0,
+ "locked": 0,
"shadow_banned": 0,
"approved": 1,
},
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index f183c38477..52ffa91c81 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -318,14 +318,14 @@ class MessageSearchTest(HomeserverTestCase):
result = self.get_success(
store.search_msgs([self.room_id], query, ["content.body"])
)
- self.assertEquals(
+ self.assertEqual(
result["count"],
1 if expect_to_contain else 0,
f"expected '{query}' to match '{self.PHRASE}'"
if expect_to_contain
else f"'{query}' unexpectedly matched '{self.PHRASE}'",
)
- self.assertEquals(
+ self.assertEqual(
len(result["results"]),
1 if expect_to_contain else 0,
"results array length should match count",
@@ -336,14 +336,14 @@ class MessageSearchTest(HomeserverTestCase):
result = self.get_success(
store.search_rooms([self.room_id], query, ["content.body"], 10)
)
- self.assertEquals(
+ self.assertEqual(
result["count"],
1 if expect_to_contain else 0,
f"expected '{query}' to match '{self.PHRASE}'"
if expect_to_contain
else f"'{query}' unexpectedly matched '{self.PHRASE}'",
)
- self.assertEquals(
+ self.assertEqual(
len(result["results"]),
1 if expect_to_contain else 0,
"results array length should match count",
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 9ed330f554..a46c29ddf4 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -31,7 +31,7 @@ TEST_ROOM_ID = "!TEST:ROOM"
class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None:
- super(FilterEventsForServerTestCase, self).setUp()
+ super().setUp()
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self._storage_controllers = self.hs.get_storage_controllers()
|