diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 9bd6275e92..edc584d0cf 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -36,7 +36,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
hostname="matrix.org", # only used by get_groups_for_user
)
self.event = Mock(
- type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
+ event_id="$abc:xyz",
+ type="m.something",
+ room_id="!foo:bar",
+ sender="@someone:somewhere",
)
self.store = Mock()
@@ -50,7 +53,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -62,7 +67,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertFalse(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -76,7 +83,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -90,7 +99,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -104,7 +115,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertFalse(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -121,7 +134,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -174,7 +189,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertFalse(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -191,7 +208,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -207,7 +226,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
- self.service.is_interested(self.event, self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
@@ -225,7 +246,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
- self.service.is_interested(event=self.event, store=self.store)
+ self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
+ )
)
)
)
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index 45e3395b33..00ad19e446 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -16,6 +16,7 @@ from synapse.api.constants import EventContentFields
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.events.utils import (
+ SerializeEventConfig,
copy_power_levels_contents,
prune_event,
serialize_event,
@@ -392,7 +393,9 @@ class PruneEventTestCase(unittest.TestCase):
class SerializeEventTestCase(unittest.TestCase):
def serialize(self, ev, fields):
- return serialize_event(ev, 1479807801915, only_event_fields=fields)
+ return serialize_event(
+ ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
+ )
def test_event_fields_works_with_keys(self):
self.assertEqual(
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 60e0c31f43..e90592855a 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -201,9 +201,12 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.assertEqual(len(self.edus), 1)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
# a second call should produce no new device EDUs
self.hs.get_federation_sender().send_device_messages("host2")
- self.pump()
self.assertEqual(self.edus, [])
# a second device
@@ -232,6 +235,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1")
device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2")
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
# expect two more edus
self.assertEqual(len(self.edus), 2)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
@@ -265,6 +272,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys)
)
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
# expect signing key update edu
self.assertEqual(len(self.edus), 2)
self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update")
@@ -284,6 +295,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
)
self.assertEqual(ret["failures"], {})
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
# expect two edus, in one or two transactions. We don't know what order the
# devices will be updated.
self.assertEqual(len(self.edus), 2)
@@ -307,6 +322,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D2")
self.login("user", "pass", device_id="D3")
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
# expect three edus
self.assertEqual(len(self.edus), 3)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
@@ -318,6 +337,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
# expect three edus, in an unknown order
self.assertEqual(len(self.edus), 3)
for edu in self.edus:
@@ -350,12 +373,19 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
self.assertGreaterEqual(mock_send_txn.call_count, 4)
# recover the server
mock_send_txn.side_effect = self.record_transaction
self.hs.get_federation_sender().send_device_messages("host2")
- self.pump()
+
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
# for each device, there should be a single update
self.assertEqual(len(self.edus), 3)
@@ -390,6 +420,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
self.assertGreaterEqual(mock_send_txn.call_count, 4)
# run the prune job
@@ -401,7 +435,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# recover the server
mock_send_txn.side_effect = self.record_transaction
self.hs.get_federation_sender().send_device_messages("host2")
- self.pump()
+
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
# there should be a single update for this user.
self.assertEqual(len(self.edus), 1)
@@ -435,6 +472,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D2")
self.login("user", "pass", device_id="D3")
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
+
# delete them again
self.get_success(
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
@@ -451,7 +492,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# recover the server
mock_send_txn.side_effect = self.record_transaction
self.hs.get_federation_sender().send_device_messages("host2")
- self.pump()
+
+ # We queue up device list updates to be sent over federation, so we
+ # advance to clear the queue.
+ self.reactor.advance(1)
# ... and we should get a single update for this user.
self.assertEqual(len(self.edus), 1)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 072e6bbcdd..cead9f90df 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -59,11 +59,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.event_source = hs.get_event_sources()
def test_notify_interested_services(self):
- interested_service = self._mkservice(is_interested=True)
+ interested_service = self._mkservice(is_interested_in_event=True)
services = [
- self._mkservice(is_interested=False),
+ self._mkservice(is_interested_in_event=False),
interested_service,
- self._mkservice(is_interested=False),
+ self._mkservice(is_interested_in_event=False),
]
self.mock_as_api.query_user.return_value = make_awaitable(True)
@@ -85,7 +85,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_query_user_exists_unknown_user(self):
user_id = "@someone:anywhere"
- services = [self._mkservice(is_interested=True)]
+ services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = make_awaitable(None)
@@ -102,7 +102,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_query_user_exists_known_user(self):
user_id = "@someone:anywhere"
- services = [self._mkservice(is_interested=True)]
+ services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
@@ -127,11 +127,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
room_id = "!alpha:bet"
servers = ["aperture"]
- interested_service = self._mkservice_alias(is_interested_in_alias=True)
+ interested_service = self._mkservice_alias(is_room_alias_in_namespace=True)
services = [
- self._mkservice_alias(is_interested_in_alias=False),
+ self._mkservice_alias(is_room_alias_in_namespace=False),
interested_service,
- self._mkservice_alias(is_interested_in_alias=False),
+ self._mkservice_alias(is_room_alias_in_namespace=False),
]
self.mock_as_api.query_alias.return_value = make_awaitable(True)
@@ -275,7 +275,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
to be pushed out to interested appservices, and that the stream ID is
updated accordingly.
"""
- interested_service = self._mkservice(is_interested=True)
+ interested_service = self._mkservice(is_interested_in_event=True)
services = [interested_service]
self.mock_store.get_app_services.return_value = services
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
@@ -304,7 +304,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
Test sending out of order ephemeral events to the appservice handler
are ignored.
"""
- interested_service = self._mkservice(is_interested=True)
+ interested_service = self._mkservice(is_interested_in_event=True)
services = [interested_service]
self.mock_store.get_app_services.return_value = services
@@ -325,17 +325,45 @@ class AppServiceHandlerTestCase(unittest.TestCase):
interested_service, ephemeral=[]
)
- def _mkservice(self, is_interested, protocols=None):
+ def _mkservice(
+ self, is_interested_in_event: bool, protocols: Optional[Iterable] = None
+ ) -> Mock:
+ """
+ Create a new mock representing an ApplicationService.
+
+ Args:
+ is_interested_in_event: Whether this application service will be considered
+ interested in all events.
+ protocols: The third-party protocols that this application service claims to
+ support.
+
+ Returns:
+ A mock representing the ApplicationService.
+ """
service = Mock()
- service.is_interested.return_value = make_awaitable(is_interested)
+ service.is_interested_in_event.return_value = make_awaitable(
+ is_interested_in_event
+ )
service.token = "mock_service_token"
service.url = "mock_service_url"
service.protocols = protocols
return service
- def _mkservice_alias(self, is_interested_in_alias):
+ def _mkservice_alias(self, is_room_alias_in_namespace: bool) -> Mock:
+ """
+ Create a new mock representing an ApplicationService that is or is not interested
+ any given room aliase.
+
+ Args:
+ is_room_alias_in_namespace: If true, the application service will be interested
+ in all room aliases that are queried against it. If false, the application
+ service will not be interested in any room aliases.
+
+ Returns:
+ A mock representing the ApplicationService.
+ """
service = Mock()
- service.is_interested_in_alias.return_value = is_interested_in_alias
+ service.is_room_alias_in_namespace.return_value = is_room_alias_in_namespace
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index cff07a8973..d37292ce13 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -172,6 +172,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result_room_ids = []
result_children_ids = []
for result_room in result["rooms"]:
+ # Ensure federation results are not leaking over the client-server API.
+ self.assertNotIn("allowed_room_ids", result_room)
+
result_room_ids.append(result_room["room_id"])
result_children_ids.append(
[
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index fb36aa9940..becec84524 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -155,10 +155,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.001,
+ "average_items_per_ms": 0.1,
"total_duration_ms": 1000.0,
"total_item_count": (
- BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE
),
}
},
@@ -210,10 +210,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.001,
+ "average_items_per_ms": 0.1,
"total_duration_ms": 1000.0,
"total_item_count": (
- BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE
),
}
},
@@ -239,10 +239,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.001,
+ "average_items_per_ms": 0.1,
"total_duration_ms": 1000.0,
"total_item_count": (
- BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ BackgroundUpdater.DEFAULT_BACKGROUND_BATCH_SIZE
),
}
},
@@ -278,11 +278,9 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.001,
+ "average_items_per_ms": 0.05263157894736842,
"total_duration_ms": 2000.0,
- "total_item_count": (
- 2 * BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
- ),
+ "total_item_count": (110),
}
},
"enabled": True,
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 6c4462e74a..def836054d 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -15,11 +15,12 @@ import json
import os
import re
from email.parser import Parser
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock
import pkg_resources
+from twisted.internet.interfaces import IReactorTCP
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -30,6 +31,7 @@ from synapse.rest import admin
from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
@@ -46,7 +48,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
# Email config.
@@ -67,20 +69,27 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config)
async def sendmail(
- reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
- ):
- self.email_attempts.append(msg)
-
- self.email_attempts = []
+ reactor: IReactorTCP,
+ smtphost: str,
+ smtpport: int,
+ from_addr: str,
+ to_addr: str,
+ msg_bytes: bytes,
+ *args: Any,
+ **kwargs: Any,
+ ) -> None:
+ self.email_attempts.append(msg_bytes)
+
+ self.email_attempts: List[bytes] = []
hs.get_send_email_handler()._sendmail = sendmail
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
- def test_basic_password_reset(self):
+ def test_basic_password_reset(self) -> None:
"""Test basic password reset flow"""
old_password = "monkey"
new_password = "kangeroo"
@@ -118,7 +127,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.attempt_wrong_password_login("kermit", old_password)
@override_config({"rc_3pid_validation": {"burst_count": 3}})
- def test_ratelimit_by_email(self):
+ def test_ratelimit_by_email(self) -> None:
"""Test that we ratelimit /requestToken for the same email."""
old_password = "monkey"
new_password = "kangeroo"
@@ -139,7 +148,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
)
)
- def reset(ip):
+ def reset(ip: str) -> None:
client_secret = "foobar"
session_id = self._request_token(email, client_secret, ip)
@@ -166,7 +175,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertEqual(cm.exception.code, 429)
- def test_basic_password_reset_canonicalise_email(self):
+ def test_basic_password_reset_canonicalise_email(self) -> None:
"""Test basic password reset flow
Request password reset with different spelling
"""
@@ -206,7 +215,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password)
- def test_cant_reset_password_without_clicking_link(self):
+ def test_cant_reset_password_without_clicking_link(self) -> None:
"""Test that we do actually need to click the link in the email"""
old_password = "monkey"
new_password = "kangeroo"
@@ -241,7 +250,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the new password
self.attempt_wrong_password_login("kermit", new_password)
- def test_no_valid_token(self):
+ def test_no_valid_token(self) -> None:
"""Test that we do actually need to request a token and can't just
make a session up.
"""
@@ -277,7 +286,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.attempt_wrong_password_login("kermit", new_password)
@unittest.override_config({"request_token_inhibit_3pid_errors": True})
- def test_password_reset_bad_email_inhibit_error(self):
+ def test_password_reset_bad_email_inhibit_error(self) -> None:
"""Test that triggering a password reset with an email address that isn't bound
to an account doesn't leak the lack of binding for that address if configured
that way.
@@ -292,7 +301,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(session_id)
- def _request_token(self, email, client_secret, ip="127.0.0.1"):
+ def _request_token(
+ self,
+ email: str,
+ client_secret: str,
+ ip: str = "127.0.0.1",
+ ) -> str:
channel = self.make_request(
"POST",
b"account/password/email/requestToken",
@@ -309,7 +323,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
return channel.json_body["sid"]
- def _validate_token(self, link):
+ def _validate_token(self, link: str) -> None:
# Remove the host
path = link.replace("https://example.com", "")
@@ -339,7 +353,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(200, channel.code, channel.result)
- def _get_link_from_email(self):
+ def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
raw_msg = self.email_attempts[-1].decode("UTF-8")
@@ -354,14 +368,19 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
if not text:
self.fail("Could not find text portion of email to parse")
+ assert text is not None
match = re.search(r"https://example.com\S+", text)
assert match, "Could not find link in email"
return match.group(0)
def _reset_password(
- self, new_password, session_id, client_secret, expected_code=200
- ):
+ self,
+ new_password: str,
+ session_id: str,
+ client_secret: str,
+ expected_code: int = 200,
+ ) -> None:
channel = self.make_request(
"POST",
b"account/password",
@@ -388,11 +407,11 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver()
return self.hs
- def test_deactivate_account(self):
+ def test_deactivate_account(self) -> None:
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
@@ -407,7 +426,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", "account/whoami", access_token=tok)
self.assertEqual(channel.code, 401)
- def test_pending_invites(self):
+ def test_pending_invites(self) -> None:
"""Tests that deactivating a user rejects every pending invite for them."""
store = self.hs.get_datastores().main
@@ -448,7 +467,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(memberships), 1, memberships)
self.assertEqual(memberships[0].room_id, room_id, memberships)
- def deactivate(self, user_id, tok):
+ def deactivate(self, user_id: str, tok: str) -> None:
request_data = json.dumps(
{
"auth": {
@@ -474,12 +493,12 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
register.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["allow_guest_access"] = True
return config
- def test_GET_whoami(self):
+ def test_GET_whoami(self) -> None:
device_id = "wouldgohere"
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test", device_id=device_id)
@@ -496,7 +515,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
},
)
- def test_GET_whoami_guests(self):
+ def test_GET_whoami_guests(self) -> None:
channel = self.make_request(
b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}"
)
@@ -516,7 +535,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
},
)
- def test_GET_whoami_appservices(self):
+ def test_GET_whoami_appservices(self) -> None:
user_id = "@as:test"
as_token = "i_am_an_app_service"
@@ -541,7 +560,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(hasattr(whoami, "device_id"))
- def _whoami(self, tok):
+ def _whoami(self, tok: str) -> JsonDict:
channel = self.make_request("GET", "account/whoami", {}, access_token=tok)
self.assertEqual(channel.code, 200)
return channel.json_body
@@ -555,7 +574,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
# Email config.
@@ -576,16 +595,23 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.hs = self.setup_test_homeserver(config=config)
async def sendmail(
- reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
- ):
- self.email_attempts.append(msg)
-
- self.email_attempts = []
+ reactor: IReactorTCP,
+ smtphost: str,
+ smtpport: int,
+ from_addr: str,
+ to_addr: str,
+ msg_bytes: bytes,
+ *args: Any,
+ **kwargs: Any,
+ ) -> None:
+ self.email_attempts.append(msg_bytes)
+
+ self.email_attempts: List[bytes] = []
self.hs.get_send_email_handler()._sendmail = sendmail
return self.hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("kermit", "test")
@@ -593,83 +619,73 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.email = "test@example.com"
self.url_3pid = b"account/3pid"
- def test_add_valid_email(self):
- self.get_success(self._add_email(self.email, self.email))
+ def test_add_valid_email(self) -> None:
+ self._add_email(self.email, self.email)
- def test_add_valid_email_second_time(self):
- self.get_success(self._add_email(self.email, self.email))
- self.get_success(
- self._request_token_invalid_email(
- self.email,
- expected_errcode=Codes.THREEPID_IN_USE,
- expected_error="Email is already in use",
- )
+ def test_add_valid_email_second_time(self) -> None:
+ self._add_email(self.email, self.email)
+ self._request_token_invalid_email(
+ self.email,
+ expected_errcode=Codes.THREEPID_IN_USE,
+ expected_error="Email is already in use",
)
- def test_add_valid_email_second_time_canonicalise(self):
- self.get_success(self._add_email(self.email, self.email))
- self.get_success(
- self._request_token_invalid_email(
- "TEST@EXAMPLE.COM",
- expected_errcode=Codes.THREEPID_IN_USE,
- expected_error="Email is already in use",
- )
+ def test_add_valid_email_second_time_canonicalise(self) -> None:
+ self._add_email(self.email, self.email)
+ self._request_token_invalid_email(
+ "TEST@EXAMPLE.COM",
+ expected_errcode=Codes.THREEPID_IN_USE,
+ expected_error="Email is already in use",
)
- def test_add_email_no_at(self):
- self.get_success(
- self._request_token_invalid_email(
- "address-without-at.bar",
- expected_errcode=Codes.UNKNOWN,
- expected_error="Unable to parse email address",
- )
+ def test_add_email_no_at(self) -> None:
+ self._request_token_invalid_email(
+ "address-without-at.bar",
+ expected_errcode=Codes.UNKNOWN,
+ expected_error="Unable to parse email address",
)
- def test_add_email_two_at(self):
- self.get_success(
- self._request_token_invalid_email(
- "foo@foo@test.bar",
- expected_errcode=Codes.UNKNOWN,
- expected_error="Unable to parse email address",
- )
+ def test_add_email_two_at(self) -> None:
+ self._request_token_invalid_email(
+ "foo@foo@test.bar",
+ expected_errcode=Codes.UNKNOWN,
+ expected_error="Unable to parse email address",
)
- def test_add_email_bad_format(self):
- self.get_success(
- self._request_token_invalid_email(
- "user@bad.example.net@good.example.com",
- expected_errcode=Codes.UNKNOWN,
- expected_error="Unable to parse email address",
- )
+ def test_add_email_bad_format(self) -> None:
+ self._request_token_invalid_email(
+ "user@bad.example.net@good.example.com",
+ expected_errcode=Codes.UNKNOWN,
+ expected_error="Unable to parse email address",
)
- def test_add_email_domain_to_lower(self):
- self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar"))
+ def test_add_email_domain_to_lower(self) -> None:
+ self._add_email("foo@TEST.BAR", "foo@test.bar")
- def test_add_email_domain_with_umlaut(self):
- self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com"))
+ def test_add_email_domain_with_umlaut(self) -> None:
+ self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")
- def test_add_email_address_casefold(self):
- self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com"))
+ def test_add_email_address_casefold(self) -> None:
+ self._add_email("Strauß@Example.com", "strauss@example.com")
- def test_address_trim(self):
- self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
+ def test_address_trim(self) -> None:
+ self._add_email(" foo@test.bar ", "foo@test.bar")
@override_config({"rc_3pid_validation": {"burst_count": 3}})
- def test_ratelimit_by_ip(self):
+ def test_ratelimit_by_ip(self) -> None:
"""Tests that adding emails is ratelimited by IP"""
# We expect to be able to set three emails before getting ratelimited.
- self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
- self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
- self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
+ self._add_email("foo1@test.bar", "foo1@test.bar")
+ self._add_email("foo2@test.bar", "foo2@test.bar")
+ self._add_email("foo3@test.bar", "foo3@test.bar")
with self.assertRaises(HttpResponseException) as cm:
- self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
+ self._add_email("foo4@test.bar", "foo4@test.bar")
self.assertEqual(cm.exception.code, 429)
- def test_add_email_if_disabled(self):
+ def test_add_email_if_disabled(self) -> None:
"""Test adding email to profile when doing so is disallowed"""
self.hs.config.registration.enable_3pid_changes = False
@@ -695,7 +711,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@@ -705,10 +721,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
- def test_delete_email(self):
+ def test_delete_email(self) -> None:
"""Test deleting an email from profile"""
# Add a threepid
self.get_success(
@@ -727,7 +743,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": self.email},
access_token=self.user_id_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@@ -736,10 +752,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
- def test_delete_email_if_disabled(self):
+ def test_delete_email_if_disabled(self) -> None:
"""Test deleting an email from profile when disallowed"""
self.hs.config.registration.enable_3pid_changes = False
@@ -761,7 +777,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@@ -771,11 +787,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
- def test_cant_add_email_without_clicking_link(self):
+ def test_cant_add_email_without_clicking_link(self) -> None:
"""Test that we do actually need to click the link in the email"""
client_secret = "foobar"
session_id = self._request_token(self.email, client_secret)
@@ -797,7 +813,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@@ -807,10 +823,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
- def test_no_valid_token(self):
+ def test_no_valid_token(self) -> None:
"""Test that we do actually need to request a token and can't just
make a session up.
"""
@@ -832,7 +848,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@@ -842,11 +858,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@override_config({"next_link_domain_whitelist": None})
- def test_next_link(self):
+ def test_next_link(self) -> None:
"""Tests a valid next_link parameter value with no whitelist (good case)"""
self._request_token(
"something@example.com",
@@ -856,7 +872,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
)
@override_config({"next_link_domain_whitelist": None})
- def test_next_link_exotic_protocol(self):
+ def test_next_link_exotic_protocol(self) -> None:
"""Tests using a esoteric protocol as a next_link parameter value.
Someone may be hosting a client on IPFS etc.
"""
@@ -868,7 +884,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
)
@override_config({"next_link_domain_whitelist": None})
- def test_next_link_file_uri(self):
+ def test_next_link_file_uri(self) -> None:
"""Tests next_link parameters cannot be file URI"""
# Attempt to use a next_link value that points to the local disk
self._request_token(
@@ -879,7 +895,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
)
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
- def test_next_link_domain_whitelist(self):
+ def test_next_link_domain_whitelist(self) -> None:
"""Tests next_link parameters must fit the whitelist if provided"""
# Ensure not providing a next_link parameter still works
@@ -912,7 +928,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
)
@override_config({"next_link_domain_whitelist": []})
- def test_empty_next_link_domain_whitelist(self):
+ def test_empty_next_link_domain_whitelist(self) -> None:
"""Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
disallowed
"""
@@ -962,28 +978,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def _request_token_invalid_email(
self,
- email,
- expected_errcode,
- expected_error,
- client_secret="foobar",
- ):
+ email: str,
+ expected_errcode: str,
+ expected_error: str,
+ client_secret: str = "foobar",
+ ) -> None:
channel = self.make_request(
"POST",
b"account/3pid/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(expected_errcode, channel.json_body["errcode"])
self.assertEqual(expected_error, channel.json_body["error"])
- def _validate_token(self, link):
+ def _validate_token(self, link: str) -> None:
# Remove the host
path = link.replace("https://example.com", "")
channel = self.make_request("GET", path, shorthand=False)
self.assertEqual(200, channel.code, channel.result)
- def _get_link_from_email(self):
+ def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
raw_msg = self.email_attempts[-1].decode("UTF-8")
@@ -998,12 +1014,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
if not text:
self.fail("Could not find text portion of email to parse")
+ assert text is not None
match = re.search(r"https://example.com\S+", text)
assert match, "Could not find link in email"
return match.group(0)
- def _add_email(self, request_email, expected_email):
+ def _add_email(self, request_email: str, expected_email: str) -> None:
"""Test adding an email to profile"""
previous_email_attempts = len(self.email_attempts)
@@ -1030,7 +1047,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@@ -1039,7 +1056,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
@@ -1055,18 +1072,18 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
url = "/_matrix/client/unstable/org.matrix.msc3720/account_status"
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["experimental_features"] = {"msc3720_enabled": True}
return self.setup_test_homeserver(config=config)
- def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.requester = self.register_user("requester", "password")
self.requester_tok = self.login("requester", "password")
- self.server_name = homeserver.config.server.server_name
+ self.server_name = hs.config.server.server_name
- def test_missing_mxid(self):
+ def test_missing_mxid(self) -> None:
"""Tests that not providing any MXID raises an error."""
self._test_status(
users=None,
@@ -1074,7 +1091,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_errcode=Codes.MISSING_PARAM,
)
- def test_invalid_mxid(self):
+ def test_invalid_mxid(self) -> None:
"""Tests that providing an invalid MXID raises an error."""
self._test_status(
users=["bad:test"],
@@ -1082,7 +1099,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_errcode=Codes.INVALID_PARAM,
)
- def test_local_user_not_exists(self):
+ def test_local_user_not_exists(self) -> None:
"""Tests that the account status endpoints correctly reports that a user doesn't
exist.
"""
@@ -1098,7 +1115,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[],
)
- def test_local_user_exists(self):
+ def test_local_user_exists(self) -> None:
"""Tests that the account status endpoint correctly reports that a user doesn't
exist.
"""
@@ -1115,7 +1132,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[],
)
- def test_local_user_deactivated(self):
+ def test_local_user_deactivated(self) -> None:
"""Tests that the account status endpoint correctly reports a deactivated user."""
user = self.register_user("someuser", "password")
self.get_success(
@@ -1135,7 +1152,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[],
)
- def test_mixed_local_and_remote_users(self):
+ def test_mixed_local_and_remote_users(self) -> None:
"""Tests that if some users are remote the account status endpoint correctly
merges the remote responses with the local result.
"""
@@ -1150,7 +1167,13 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"@bad:badremote",
]
- async def post_json(destination, path, data, *a, **kwa):
+ async def post_json(
+ destination: str,
+ path: str,
+ data: Optional[JsonDict] = None,
+ *a: Any,
+ **kwa: Any,
+ ) -> Union[JsonDict, list]:
if destination == "remote":
return {
"account_statuses": {
@@ -1160,9 +1183,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
},
}
}
- if destination == "otherremote":
- return {}
- if destination == "badremote":
+ elif destination == "badremote":
# badremote tries to overwrite the status of a user that doesn't belong
# to it (i.e. users[1]) with false data, which Synapse is expected to
# ignore.
@@ -1176,6 +1197,9 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
},
}
}
+ # if destination == "otherremote"
+ else:
+ return {}
# Register a mock that will return the expected result depending on the remote.
self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
@@ -1205,7 +1229,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
expected_failures: Optional[List[str]] = None,
expected_errcode: Optional[str] = None,
- ):
+ ) -> None:
"""Send a request to the account status endpoint and check that the response
matches with what's expected.
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 5c31a54421..823e8ab8c4 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.rest.client import filter
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
@@ -30,11 +32,11 @@ class FilterTestCase(unittest.HomeserverTestCase):
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
servlets = [filter.register_servlets]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.filtering = hs.get_filtering()
self.store = hs.get_datastores().main
- def test_add_filter(self):
+ def test_add_filter(self) -> None:
channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
@@ -43,11 +45,13 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.json_body, {"filter_id": "0"})
- filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
+ filter = self.get_success(
+ self.store.get_user_filter(user_localpart="apple", filter_id=0)
+ )
self.pump()
- self.assertEqual(filter.result, self.EXAMPLE_FILTER)
+ self.assertEqual(filter, self.EXAMPLE_FILTER)
- def test_add_filter_for_other_user(self):
+ def test_add_filter_for_other_user(self) -> None:
channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
@@ -57,7 +61,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"403")
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
- def test_add_filter_non_local_user(self):
+ def test_add_filter_non_local_user(self) -> None:
_is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False
channel = self.make_request(
@@ -70,14 +74,13 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"403")
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
- def test_get_filter(self):
- filter_id = defer.ensureDeferred(
+ def test_get_filter(self) -> None:
+ filter_id = self.get_success(
self.filtering.add_user_filter(
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
)
)
self.reactor.advance(1)
- filter_id = filter_id.result
channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
)
@@ -85,7 +88,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)
- def test_get_filter_non_existant(self):
+ def test_get_filter_non_existant(self) -> None:
channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
)
@@ -95,7 +98,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
# Currently invalid params do not have an appropriate errcode
# in errors.py
- def test_get_filter_invalid_id(self):
+ def test_get_filter_invalid_id(self) -> None:
channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
)
@@ -103,7 +106,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"400")
# No ID also returns an invalid_id error
- def test_get_filter_no_id(self):
+ def test_get_filter_no_id(self) -> None:
channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index c8db45719e..a40a5de399 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -15,7 +15,7 @@
import itertools
import urllib.parse
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import patch
from twisted.test.proto_helpers import MemoryReactor
@@ -24,8 +24,7 @@ from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room, sync
from synapse.server import HomeServer
-from synapse.storage.relations import RelationPaginationToken
-from synapse.types import JsonDict, StreamToken
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
@@ -34,7 +33,7 @@ from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_event
-class RelationsTestCase(unittest.HomeserverTestCase):
+class BaseRelationsTestCase(unittest.HomeserverTestCase):
servlets = [
relations.register_servlets,
room.register_servlets,
@@ -45,10 +44,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
]
hijack_auth = False
- def default_config(self) -> dict:
+ def default_config(self) -> Dict[str, Any]:
# We need to enable msc1849 support for aggregations
config = super().default_config()
- config["experimental_msc1849_support_enabled"] = True
# We enable frozen dicts as relations/edits change event contents, so we
# want to test that we don't modify the events in the caches.
@@ -67,10 +65,62 @@ class RelationsTestCase(unittest.HomeserverTestCase):
res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
self.parent_id = res["event_id"]
- def test_send_relation(self) -> None:
- """Tests that sending a relation using the new /send_relation works
- creates the right shape of event.
+ def _create_user(self, localpart: str) -> Tuple[str, str]:
+ user_id = self.register_user(localpart, "abc123")
+ access_token = self.login(localpart, "abc123")
+
+ return user_id, access_token
+
+ def _send_relation(
+ self,
+ relation_type: str,
+ event_type: str,
+ key: Optional[str] = None,
+ content: Optional[dict] = None,
+ access_token: Optional[str] = None,
+ parent_id: Optional[str] = None,
+ ) -> FakeChannel:
+ """Helper function to send a relation pointing at `self.parent_id`
+
+ Args:
+ relation_type: One of `RelationTypes`
+ event_type: The type of the event to create
+ key: The aggregation key used for m.annotation relation type.
+ content: The content of the created event. Will be modified to configure
+ the m.relates_to key based on the other provided parameters.
+ access_token: The access token used to send the relation, defaults
+ to `self.user_token`
+ parent_id: The event_id this relation relates to. If None, then self.parent_id
+
+ Returns:
+ FakeChannel
"""
+ if not access_token:
+ access_token = self.user_token
+
+ original_id = parent_id if parent_id else self.parent_id
+
+ if content is None:
+ content = {}
+ content["m.relates_to"] = {
+ "event_id": original_id,
+ "rel_type": relation_type,
+ }
+ if key is not None:
+ content["m.relates_to"]["key"] = key
+
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}",
+ content,
+ access_token=access_token,
+ )
+ return channel
+
+
+class RelationsTestCase(BaseRelationsTestCase):
+ def test_send_relation(self) -> None:
+ """Tests that sending a relation works."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
self.assertEqual(200, channel.code, channel.json_body)
@@ -79,7 +129,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "/rooms/%s/event/%s" % (self.room, event_id),
+ f"/rooms/{self.room}/event/{event_id}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -230,15 +280,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel.json_body["chunk"][0],
)
- def _stream_token_to_relation_token(self, token: str) -> str:
- """Convert a StreamToken into a legacy token (RelationPaginationToken)."""
- room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key
- return self.get_success(
- RelationPaginationToken(
- topological=room_key.topological, stream=room_key.stream
- ).to_string(self.store)
- )
-
def test_repeated_paginate_relations(self) -> None:
"""Test that if we paginate using a limit and tokens then we get the
expected events.
@@ -279,34 +320,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
found_event_ids.reverse()
self.assertEqual(found_event_ids, expected_event_ids)
- # Reset and try again, but convert the tokens to the legacy format.
- prev_token = ""
- found_event_ids = []
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
def test_pagination_from_sync_and_messages(self) -> None:
"""Pagination tokens from /sync and /messages can be used to paginate /relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
@@ -317,9 +330,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# Request /sync, limiting it such that only the latest event is returned
# (and not the relation).
- filter = urllib.parse.quote_plus(
- '{"room": {"timeline": {"limit": 1}}}'.encode()
- )
+ filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
channel = self.make_request(
"GET", f"/sync?filter={filter}", access_token=self.user_token
)
@@ -404,8 +415,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s"
- % (self.room, self.parent_id, from_token),
+ f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -495,39 +505,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
found_event_ids.reverse()
self.assertEqual(found_event_ids, expected_event_ids)
- # Reset and try again, but convert the tokens to the legacy format.
- prev_token = ""
- found_event_ids = []
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}"
- f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
- f"/m.reaction/{encoded_key}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
def test_aggregation(self) -> None:
"""Test that annotations get correctly aggregated."""
@@ -544,8 +521,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "/_matrix/client/unstable/rooms/%s/aggregations/%s"
- % (self.room, self.parent_id),
+ f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -560,47 +536,13 @@ class RelationsTestCase(unittest.HomeserverTestCase):
},
)
- def test_aggregation_redactions(self) -> None:
- """Test that annotations get correctly aggregated after a redaction."""
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
- to_redact_event_id = channel.json_body["event_id"]
-
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # Now lets redact one of the 'a' reactions
- channel = self.make_request(
- "POST",
- "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id),
- access_token=self.user_token,
- content={},
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/unstable/rooms/%s/aggregations/%s"
- % (self.room, self.parent_id),
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(
- channel.json_body,
- {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
- )
-
def test_aggregation_must_be_annotation(self) -> None:
"""Test that aggregations must be annotations."""
channel = self.make_request(
"GET",
- "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1"
- % (self.room, self.parent_id, RelationTypes.REPLACE),
+ f"/_matrix/client/unstable/rooms/{self.room}/aggregations"
+ f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1",
access_token=self.user_token,
)
self.assertEqual(400, channel.code, channel.json_body)
@@ -691,10 +633,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
}
},
"event_id": thread_2,
- "room_id": self.room,
"sender": self.user_id,
"type": "m.room.test",
- "user_id": self.user_id,
},
relations_dict[RelationTypes.THREAD].get("latest_event"),
)
@@ -986,9 +926,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# Request sync, but limit the timeline so it becomes limited (and includes
# bundled aggregations).
- filter = urllib.parse.quote_plus(
- '{"room": {"timeline": {"limit": 2}}}'.encode()
- )
+ filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 2}}}')
channel = self.make_request(
"GET", f"/sync?filter={filter}", access_token=self.user_token
)
@@ -1053,7 +991,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -1096,7 +1034,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "/rooms/%s/event/%s" % (self.room, reply),
+ f"/rooms/{self.room}/event/{reply}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -1198,7 +1136,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# Request the original event.
channel = self.make_request(
"GET",
- "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -1217,102 +1155,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
- def test_relations_redaction_redacts_edits(self) -> None:
- """Test that edits of an event are redacted when the original event
- is redacted.
- """
- # Send a new event
- res = self.helper.send(self.room, body="Heyo!", tok=self.user_token)
- original_event_id = res["event_id"]
-
- # Add a relation
- channel = self._send_relation(
- RelationTypes.REPLACE,
- "m.room.message",
- parent_id=original_event_id,
- content={
- "msgtype": "m.text",
- "body": "Wibble",
- "m.new_content": {"msgtype": "m.text", "body": "First edit"},
- },
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # Check the relation is returned
- channel = self.make_request(
- "GET",
- "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
- % (self.room, original_event_id),
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertIn("chunk", channel.json_body)
- self.assertEqual(len(channel.json_body["chunk"]), 1)
-
- # Redact the original event
- channel = self.make_request(
- "PUT",
- "/rooms/%s/redact/%s/%s"
- % (self.room, original_event_id, "test_relations_redaction_redacts_edits"),
- access_token=self.user_token,
- content="{}",
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # Try to check for remaining m.replace relations
- channel = self.make_request(
- "GET",
- "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
- % (self.room, original_event_id),
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # Check that no relations are returned
- self.assertIn("chunk", channel.json_body)
- self.assertEqual(channel.json_body["chunk"], [])
-
- def test_aggregations_redaction_prevents_access_to_aggregations(self) -> None:
- """Test that annotations of an event are redacted when the original event
- is redacted.
- """
- # Send a new event
- res = self.helper.send(self.room, body="Hello!", tok=self.user_token)
- original_event_id = res["event_id"]
-
- # Add a relation
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # Redact the original
- channel = self.make_request(
- "PUT",
- "/rooms/%s/redact/%s/%s"
- % (
- self.room,
- original_event_id,
- "test_aggregations_redaction_prevents_access_to_aggregations",
- ),
- access_token=self.user_token,
- content="{}",
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # Check that aggregations returns zero
- channel = self.make_request(
- "GET",
- "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction"
- % (self.room, original_event_id),
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertIn("chunk", channel.json_body)
- self.assertEqual(channel.json_body["chunk"], [])
-
def test_unknown_relations(self) -> None:
"""Unknown relations should be accepted."""
channel = self._send_relation("m.relation.test", "m.room.test")
@@ -1321,8 +1163,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
- % (self.room, self.parent_id),
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -1343,7 +1184,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# When bundling the unknown relation is not included.
channel = self.make_request(
"GET",
- "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -1352,8 +1193,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# But unknown relations can be directly queried.
channel = self.make_request(
"GET",
- "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1"
- % (self.room, self.parent_id),
+ f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -1369,58 +1209,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
raise AssertionError(f"Event {self.parent_id} not found in chunk")
- def _send_relation(
- self,
- relation_type: str,
- event_type: str,
- key: Optional[str] = None,
- content: Optional[dict] = None,
- access_token: Optional[str] = None,
- parent_id: Optional[str] = None,
- ) -> FakeChannel:
- """Helper function to send a relation pointing at `self.parent_id`
-
- Args:
- relation_type: One of `RelationTypes`
- event_type: The type of the event to create
- key: The aggregation key used for m.annotation relation type.
- content: The content of the created event. Will be modified to configure
- the m.relates_to key based on the other provided parameters.
- access_token: The access token used to send the relation, defaults
- to `self.user_token`
- parent_id: The event_id this relation relates to. If None, then self.parent_id
-
- Returns:
- FakeChannel
- """
- if not access_token:
- access_token = self.user_token
-
- original_id = parent_id if parent_id else self.parent_id
-
- if content is None:
- content = {}
- content["m.relates_to"] = {
- "event_id": original_id,
- "rel_type": relation_type,
- }
- if key is not None:
- content["m.relates_to"]["key"] = key
-
- channel = self.make_request(
- "POST",
- f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}",
- content,
- access_token=access_token,
- )
- return channel
-
- def _create_user(self, localpart: str) -> Tuple[str, str]:
- user_id = self.register_user(localpart, "abc123")
- access_token = self.login(localpart, "abc123")
-
- return user_id, access_token
-
def test_background_update(self) -> None:
"""Test the event_arbitrary_relations background update."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
@@ -1482,3 +1270,235 @@ class RelationsTestCase(unittest.HomeserverTestCase):
[ev["event_id"] for ev in channel.json_body["chunk"]],
[annotation_event_id_good, thread_event_id],
)
+
+
+class RelationRedactionTestCase(BaseRelationsTestCase):
+ """
+ Test the behaviour of relations when the parent or child event is redacted.
+
+ The behaviour of each relation type is subtly different which causes the tests
+ to be a bit repetitive, they follow a naming scheme of:
+
+ test_redact_(relation|parent)_{relation_type}
+
+ The first bit of "relation" means that the event with the relation defined
+ on it (the child event) is to be redacted. A "parent" means that the target
+ of the relation (the parent event) is to be redacted.
+
+ The relation_type describes which type of relation is under test (i.e. it is
+ related to the value of rel_type in the event content).
+ """
+
+ def _redact(self, event_id: str) -> None:
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{self.room}/redact/{event_id}",
+ access_token=self.user_token,
+ content={},
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ def _make_relation_requests(self) -> Tuple[List[str], JsonDict]:
+ """
+ Makes requests and ensures they result in a 200 response, returns a
+ tuple of results:
+
+ 1. `/relations` -> Returns a list of event IDs.
+ 2. `/event` -> Returns the response's m.relations field (from unsigned),
+ if it exists.
+ """
+
+ # Request the relations of the event.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
+
+ # Fetch the bundled aggregations of the event.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ bundled_relations = channel.json_body["unsigned"].get("m.relations", {})
+
+ return event_ids, bundled_relations
+
+ def _get_aggregations(self) -> List[JsonDict]:
+ """Request /aggregations on the parent ID and includes the returned chunk."""
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ return channel.json_body["chunk"]
+
+ def test_redact_relation_annotation(self) -> None:
+ """
+ Test that annotations of an event are properly handled after the
+ annotation is redacted.
+
+ The redacted relation should not be included in bundled aggregations or
+ the response to relations.
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEqual(200, channel.code, channel.json_body)
+ to_redact_event_id = channel.json_body["event_id"]
+
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ unredacted_event_id = channel.json_body["event_id"]
+
+ # Both relations should exist.
+ event_ids, relations = self._make_relation_requests()
+ self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
+ self.assertEquals(
+ relations["m.annotation"],
+ {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
+ )
+
+ # Both relations appear in the aggregation.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}])
+
+ # Redact one of the reactions.
+ self._redact(to_redact_event_id)
+
+ # The unredacted relation should still exist.
+ event_ids, relations = self._make_relation_requests()
+ self.assertEquals(event_ids, [unredacted_event_id])
+ self.assertEquals(
+ relations["m.annotation"],
+ {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
+ )
+
+ # The unredacted aggregation should still exist.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}])
+
+ @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+ def test_redact_relation_thread(self) -> None:
+ """
+ Test that thread replies are properly handled after the thread reply redacted.
+
+ The redacted event should not be included in bundled aggregations or
+ the response to relations.
+ """
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ content={"body": "reply 1", "msgtype": "m.text"},
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ unredacted_event_id = channel.json_body["event_id"]
+
+ # Note that the *last* event in the thread is redacted, as that gets
+ # included in the bundled aggregation.
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ content={"body": "reply 2", "msgtype": "m.text"},
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ to_redact_event_id = channel.json_body["event_id"]
+
+ # Both relations exist.
+ event_ids, relations = self._make_relation_requests()
+ self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id])
+ self.assertDictContainsSubset(
+ {
+ "count": 2,
+ "current_user_participated": True,
+ },
+ relations[RelationTypes.THREAD],
+ )
+ # And the latest event returned is the event that will be redacted.
+ self.assertEqual(
+ relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+ to_redact_event_id,
+ )
+
+ # Redact one of the reactions.
+ self._redact(to_redact_event_id)
+
+ # The unredacted relation should still exist.
+ event_ids, relations = self._make_relation_requests()
+ self.assertEquals(event_ids, [unredacted_event_id])
+ self.assertDictContainsSubset(
+ {
+ "count": 1,
+ "current_user_participated": True,
+ },
+ relations[RelationTypes.THREAD],
+ )
+ # And the latest event is now the unredacted event.
+ self.assertEqual(
+ relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+ unredacted_event_id,
+ )
+
+ def test_redact_parent_edit(self) -> None:
+ """Test that edits of an event are redacted when the original event
+ is redacted.
+ """
+ # Add a relation
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ parent_id=self.parent_id,
+ content={
+ "msgtype": "m.text",
+ "body": "Wibble",
+ "m.new_content": {"msgtype": "m.text", "body": "First edit"},
+ },
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # Check the relation is returned
+ event_ids, relations = self._make_relation_requests()
+ self.assertEqual(len(event_ids), 1)
+ self.assertIn(RelationTypes.REPLACE, relations)
+
+ # Redact the original event
+ self._redact(self.parent_id)
+
+ # The relations are not returned.
+ event_ids, relations = self._make_relation_requests()
+ self.assertEqual(len(event_ids), 0)
+ self.assertEqual(relations, {})
+
+ def test_redact_parent_annotation(self) -> None:
+ """Test that annotations of an event are redacted when the original event
+ is redacted.
+ """
+ # Add a relation
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # The relations should exist.
+ event_ids, relations = self._make_relation_requests()
+ self.assertEqual(len(event_ids), 1)
+ self.assertIn(RelationTypes.ANNOTATION, relations)
+
+ # The aggregation should exist.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}])
+
+ # Redact the original event.
+ self._redact(self.parent_id)
+
+ # The relations are not returned.
+ event_ids, relations = self._make_relation_requests()
+ self.assertEqual(event_ids, [])
+ self.assertEqual(relations, {})
+
+ # There's nothing to aggregate.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [])
diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index ee6b0b9ebf..20a259fc43 100644
--- a/tests/rest/client/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
@@ -14,8 +14,13 @@
import json
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.rest.client import login, report_event, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
@@ -28,7 +33,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
report_event.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
@@ -42,35 +47,35 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
self.event_id = resp["event_id"]
self.report_path = f"rooms/{self.room_id}/report/{self.event_id}"
- def test_reason_str_and_score_int(self):
+ def test_reason_str_and_score_int(self) -> None:
data = {"reason": "this makes me sad", "score": -100}
self._assert_status(200, data)
- def test_no_reason(self):
+ def test_no_reason(self) -> None:
data = {"score": 0}
self._assert_status(200, data)
- def test_no_score(self):
+ def test_no_score(self) -> None:
data = {"reason": "this makes me sad"}
self._assert_status(200, data)
- def test_no_reason_and_no_score(self):
- data = {}
+ def test_no_reason_and_no_score(self) -> None:
+ data: JsonDict = {}
self._assert_status(200, data)
- def test_reason_int_and_score_str(self):
+ def test_reason_int_and_score_str(self) -> None:
data = {"reason": 10, "score": "string"}
self._assert_status(400, data)
- def test_reason_zero_and_score_blank(self):
+ def test_reason_zero_and_score_blank(self) -> None:
data = {"reason": 0, "score": ""}
self._assert_status(400, data)
- def test_reason_and_score_null(self):
+ def test_reason_and_score_null(self) -> None:
data = {"reason": None, "score": None}
self._assert_status(400, data)
- def _assert_status(self, response_status, data):
+ def _assert_status(self, response_status: int, data: JsonDict) -> None:
channel = self.make_request(
"POST",
self.report_path,
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index e0b11e7264..37866ee330 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,11 +18,12 @@
"""Tests REST events for /rooms paths."""
import json
-from typing import Iterable, List
+from typing import Any, Dict, Iterable, List, Optional
from unittest.mock import Mock, call
from urllib import parse as urlparse
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import (
@@ -35,7 +36,9 @@ from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
from synapse.rest.client import account, directory, login, profile, room, sync
+from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, UserID, create_requester
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
@@ -45,11 +48,11 @@ PATH_PREFIX = b"/_matrix/client/api/v1"
class RoomBase(unittest.HomeserverTestCase):
- rmcreator_id = None
+ rmcreator_id: Optional[str] = None
servlets = [room.register_servlets, room.register_deprecated_servlets]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver(
"red",
@@ -57,15 +60,15 @@ class RoomBase(unittest.HomeserverTestCase):
federation_client=Mock(),
)
- self.hs.get_federation_handler = Mock()
+ self.hs.get_federation_handler = Mock() # type: ignore[assignment]
self.hs.get_federation_handler.return_value.maybe_backfill = Mock(
return_value=make_awaitable(None)
)
- async def _insert_client_ip(*args, **kwargs):
+ async def _insert_client_ip(*args: Any, **kwargs: Any) -> None:
return None
- self.hs.get_datastores().main.insert_client_ip = _insert_client_ip
+ self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment]
return self.hs
@@ -76,7 +79,7 @@ class RoomPermissionsTestCase(RoomBase):
user_id = "@sid1:red"
rmcreator_id = "@notme:red"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.helper.auth_user_id = self.rmcreator_id
# create some rooms under the name rmcreator_id
@@ -108,12 +111,12 @@ class RoomPermissionsTestCase(RoomBase):
# auth as user_id now
self.helper.auth_user_id = self.user_id
- def test_can_do_action(self):
+ def test_can_do_action(self) -> None:
msg_content = b'{"msgtype":"m.text","body":"hello"}'
seq = iter(range(100))
- def send_msg_path():
+ def send_msg_path() -> str:
return "/rooms/%s/send/m.room.message/mid%s" % (
self.created_rmid,
str(next(seq)),
@@ -148,7 +151,7 @@ class RoomPermissionsTestCase(RoomBase):
channel = self.make_request("PUT", send_msg_path(), msg_content)
self.assertEqual(403, channel.code, msg=channel.result["body"])
- def test_topic_perms(self):
+ def test_topic_perms(self) -> None:
topic_content = b'{"topic":"My Topic Name"}'
topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
@@ -214,14 +217,14 @@ class RoomPermissionsTestCase(RoomBase):
self.assertEqual(403, channel.code, msg=channel.result["body"])
def _test_get_membership(
- self, room=None, members: Iterable = frozenset(), expect_code=None
- ):
+ self, room: str, members: Iterable = frozenset(), expect_code: int = 200
+ ) -> None:
for member in members:
path = "/rooms/%s/state/m.room.member/%s" % (room, member)
channel = self.make_request("GET", path)
self.assertEqual(expect_code, channel.code)
- def test_membership_basic_room_perms(self):
+ def test_membership_basic_room_perms(self) -> None:
# === room does not exist ===
room = self.uncreated_rmid
# get membership of self, get membership of other, uncreated room
@@ -241,7 +244,7 @@ class RoomPermissionsTestCase(RoomBase):
self.helper.join(room=room, user=usr, expect_code=404)
self.helper.leave(room=room, user=usr, expect_code=404)
- def test_membership_private_room_perms(self):
+ def test_membership_private_room_perms(self) -> None:
room = self.created_rmid
# get membership of self, get membership of other, private room + invite
# expect all 403s
@@ -264,7 +267,7 @@ class RoomPermissionsTestCase(RoomBase):
members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
)
- def test_membership_public_room_perms(self):
+ def test_membership_public_room_perms(self) -> None:
room = self.created_public_rmid
# get membership of self, get membership of other, public room + invite
# expect 403
@@ -287,7 +290,7 @@ class RoomPermissionsTestCase(RoomBase):
members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
)
- def test_invited_permissions(self):
+ def test_invited_permissions(self) -> None:
room = self.created_rmid
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
@@ -310,7 +313,7 @@ class RoomPermissionsTestCase(RoomBase):
expect_code=403,
)
- def test_joined_permissions(self):
+ def test_joined_permissions(self) -> None:
room = self.created_rmid
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
self.helper.join(room=room, user=self.user_id)
@@ -348,7 +351,7 @@ class RoomPermissionsTestCase(RoomBase):
# set left of self, expect 200
self.helper.leave(room=room, user=self.user_id)
- def test_leave_permissions(self):
+ def test_leave_permissions(self) -> None:
room = self.created_rmid
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
self.helper.join(room=room, user=self.user_id)
@@ -383,7 +386,7 @@ class RoomPermissionsTestCase(RoomBase):
)
# tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember
- def test_member_event_from_ban(self):
+ def test_member_event_from_ban(self) -> None:
room = self.created_rmid
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
self.helper.join(room=room, user=self.user_id)
@@ -475,21 +478,21 @@ class RoomsMemberListTestCase(RoomBase):
user_id = "@sid1:red"
- def test_get_member_list(self):
+ def test_get_member_list(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEqual(200, channel.code, msg=channel.result["body"])
- def test_get_member_list_no_room(self):
+ def test_get_member_list_no_room(self) -> None:
channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
self.assertEqual(403, channel.code, msg=channel.result["body"])
- def test_get_member_list_no_permission(self):
+ def test_get_member_list_no_permission(self) -> None:
room_id = self.helper.create_room_as("@some_other_guy:red")
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEqual(403, channel.code, msg=channel.result["body"])
- def test_get_member_list_no_permission_with_at_token(self):
+ def test_get_member_list_no_permission_with_at_token(self) -> None:
"""
Tests that a stranger to the room cannot get the member list
(in the case that they use an at token).
@@ -509,7 +512,7 @@ class RoomsMemberListTestCase(RoomBase):
)
self.assertEqual(403, channel.code, msg=channel.result["body"])
- def test_get_member_list_no_permission_former_member(self):
+ def test_get_member_list_no_permission_former_member(self) -> None:
"""
Tests that a former member of the room can not get the member list.
"""
@@ -529,7 +532,7 @@ class RoomsMemberListTestCase(RoomBase):
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEqual(403, channel.code, msg=channel.result["body"])
- def test_get_member_list_no_permission_former_member_with_at_token(self):
+ def test_get_member_list_no_permission_former_member_with_at_token(self) -> None:
"""
Tests that a former member of the room can not get the member list
(in the case that they use an at token).
@@ -569,7 +572,7 @@ class RoomsMemberListTestCase(RoomBase):
)
self.assertEqual(403, channel.code, msg=channel.result["body"])
- def test_get_member_list_mixed_memberships(self):
+ def test_get_member_list_mixed_memberships(self) -> None:
room_creator = "@some_other_guy:red"
room_id = self.helper.create_room_as(room_creator)
room_path = "/rooms/%s/members" % room_id
@@ -594,26 +597,26 @@ class RoomsCreateTestCase(RoomBase):
user_id = "@sid1:red"
- def test_post_room_no_keys(self):
+ def test_post_room_no_keys(self) -> None:
# POST with no config keys, expect new room id
channel = self.make_request("POST", "/createRoom", "{}")
self.assertEqual(200, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
- def test_post_room_visibility_key(self):
+ def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
self.assertEqual(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
- def test_post_room_custom_key(self):
+ def test_post_room_custom_key(self) -> None:
# POST with custom config keys, expect new room id
channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
self.assertEqual(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
- def test_post_room_known_and_unknown_keys(self):
+ def test_post_room_known_and_unknown_keys(self) -> None:
# POST with custom + known config keys, expect new room id
channel = self.make_request(
"POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
@@ -621,7 +624,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
- def test_post_room_invalid_content(self):
+ def test_post_room_invalid_content(self) -> None:
# POST with invalid content / paths, expect 400
channel = self.make_request("POST", "/createRoom", b'{"visibili')
self.assertEqual(400, channel.code)
@@ -629,7 +632,7 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request("POST", "/createRoom", b'["hello"]')
self.assertEqual(400, channel.code)
- def test_post_room_invitees_invalid_mxid(self):
+ def test_post_room_invitees_invalid_mxid(self) -> None:
# POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
# Note the trailing space in the MXID here!
channel = self.make_request(
@@ -638,7 +641,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(400, channel.code)
@unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}})
- def test_post_room_invitees_ratelimit(self):
+ def test_post_room_invitees_ratelimit(self) -> None:
"""Test that invites sent when creating a room are ratelimited by a RateLimiter,
which ratelimits them correctly, including by not limiting when the requester is
exempt from ratelimiting.
@@ -674,7 +677,7 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request("POST", "/createRoom", content)
self.assertEqual(200, channel.code)
- def test_spam_checker_may_join_room(self):
+ def test_spam_checker_may_join_room(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly bypassed
when creating a new room.
"""
@@ -704,12 +707,12 @@ class RoomTopicTestCase(RoomBase):
user_id = "@sid1:red"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# create the room
self.room_id = self.helper.create_room_as(self.user_id)
self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,)
- def test_invalid_puts(self):
+ def test_invalid_puts(self) -> None:
# missing keys or invalid json
channel = self.make_request("PUT", self.path, "{}")
self.assertEqual(400, channel.code, msg=channel.result["body"])
@@ -736,7 +739,7 @@ class RoomTopicTestCase(RoomBase):
channel = self.make_request("PUT", self.path, content)
self.assertEqual(400, channel.code, msg=channel.result["body"])
- def test_rooms_topic(self):
+ def test_rooms_topic(self) -> None:
# nothing should be there
channel = self.make_request("GET", self.path)
self.assertEqual(404, channel.code, msg=channel.result["body"])
@@ -751,7 +754,7 @@ class RoomTopicTestCase(RoomBase):
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
- def test_rooms_topic_with_extra_keys(self):
+ def test_rooms_topic_with_extra_keys(self) -> None:
# valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}'
channel = self.make_request("PUT", self.path, content)
@@ -768,10 +771,10 @@ class RoomMemberStateTestCase(RoomBase):
user_id = "@sid1:red"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)
- def test_invalid_puts(self):
+ def test_invalid_puts(self) -> None:
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json
channel = self.make_request("PUT", path, "{}")
@@ -801,7 +804,7 @@ class RoomMemberStateTestCase(RoomBase):
channel = self.make_request("PUT", path, content.encode("ascii"))
self.assertEqual(400, channel.code, msg=channel.result["body"])
- def test_rooms_members_self(self):
+ def test_rooms_members_self(self) -> None:
path = "/rooms/%s/state/m.room.member/%s" % (
urlparse.quote(self.room_id),
self.user_id,
@@ -812,13 +815,13 @@ class RoomMemberStateTestCase(RoomBase):
channel = self.make_request("PUT", path, content.encode("ascii"))
self.assertEqual(200, channel.code, msg=channel.result["body"])
- channel = self.make_request("GET", path, None)
+ channel = self.make_request("GET", path, content=b"")
self.assertEqual(200, channel.code, msg=channel.result["body"])
expected_response = {"membership": Membership.JOIN}
self.assertEqual(expected_response, channel.json_body)
- def test_rooms_members_other(self):
+ def test_rooms_members_other(self) -> None:
self.other_id = "@zzsid1:red"
path = "/rooms/%s/state/m.room.member/%s" % (
urlparse.quote(self.room_id),
@@ -830,11 +833,11 @@ class RoomMemberStateTestCase(RoomBase):
channel = self.make_request("PUT", path, content)
self.assertEqual(200, channel.code, msg=channel.result["body"])
- channel = self.make_request("GET", path, None)
+ channel = self.make_request("GET", path, content=b"")
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual(json.loads(content), channel.json_body)
- def test_rooms_members_other_custom_keys(self):
+ def test_rooms_members_other_custom_keys(self) -> None:
self.other_id = "@zzsid1:red"
path = "/rooms/%s/state/m.room.member/%s" % (
urlparse.quote(self.room_id),
@@ -849,7 +852,7 @@ class RoomMemberStateTestCase(RoomBase):
channel = self.make_request("PUT", path, content)
self.assertEqual(200, channel.code, msg=channel.result["body"])
- channel = self.make_request("GET", path, None)
+ channel = self.make_request("GET", path, content=b"")
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual(json.loads(content), channel.json_body)
@@ -866,7 +869,7 @@ class RoomInviteRatelimitTestCase(RoomBase):
@unittest.override_config(
{"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
)
- def test_invites_by_rooms_ratelimit(self):
+ def test_invites_by_rooms_ratelimit(self) -> None:
"""Tests that invites in a room are actually rate-limited."""
room_id = self.helper.create_room_as(self.user_id)
@@ -878,7 +881,7 @@ class RoomInviteRatelimitTestCase(RoomBase):
@unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
)
- def test_invites_by_users_ratelimit(self):
+ def test_invites_by_users_ratelimit(self) -> None:
"""Tests that invites to a specific user are actually rate-limited."""
for _ in range(3):
@@ -897,7 +900,7 @@ class RoomJoinTestCase(RoomBase):
room.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user1 = self.register_user("thomas", "hackme")
self.tok1 = self.login("thomas", "hackme")
@@ -908,7 +911,7 @@ class RoomJoinTestCase(RoomBase):
self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
- def test_spam_checker_may_join_room(self):
+ def test_spam_checker_may_join_room(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly called
and blocks room joins when needed.
"""
@@ -975,8 +978,8 @@ class RoomJoinRatelimitTestCase(RoomBase):
room.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
- super().prepare(reactor, clock, homeserver)
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ super().prepare(reactor, clock, hs)
# profile changes expect that the user is actually registered
user = UserID.from_string(self.user_id)
self.get_success(self.register_user(user.localpart, "supersecretpassword"))
@@ -984,7 +987,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
- def test_join_local_ratelimit(self):
+ def test_join_local_ratelimit(self) -> None:
"""Tests that local joins are actually rate-limited."""
for _ in range(3):
self.helper.create_room_as(self.user_id)
@@ -994,7 +997,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
- def test_join_local_ratelimit_profile_change(self):
+ def test_join_local_ratelimit_profile_change(self) -> None:
"""Tests that sending a profile update into all of the user's joined rooms isn't
rate-limited by the rate-limiter on joins."""
@@ -1031,7 +1034,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
- def test_join_local_ratelimit_idempotent(self):
+ def test_join_local_ratelimit_idempotent(self) -> None:
"""Tests that the room join endpoints remain idempotent despite rate-limiting
on room joins."""
room_id = self.helper.create_room_as(self.user_id)
@@ -1056,7 +1059,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
"autocreate_auto_join_rooms": True,
},
)
- def test_autojoin_rooms(self):
+ def test_autojoin_rooms(self) -> None:
user_id = self.register_user("testuser", "password")
# Check that the new user successfully joined the four rooms
@@ -1071,10 +1074,10 @@ class RoomMessagesTestCase(RoomBase):
user_id = "@sid1:red"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)
- def test_invalid_puts(self):
+ def test_invalid_puts(self) -> None:
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json
channel = self.make_request("PUT", path, b"{}")
@@ -1095,7 +1098,7 @@ class RoomMessagesTestCase(RoomBase):
channel = self.make_request("PUT", path, b"")
self.assertEqual(400, channel.code, msg=channel.result["body"])
- def test_rooms_messages_sent(self):
+ def test_rooms_messages_sent(self) -> None:
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
content = b'{"body":"test","msgtype":{"type":"a"}}'
@@ -1119,11 +1122,11 @@ class RoomInitialSyncTestCase(RoomBase):
user_id = "@sid1:red"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# create the room
self.room_id = self.helper.create_room_as(self.user_id)
- def test_initial_sync(self):
+ def test_initial_sync(self) -> None:
channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
self.assertEqual(200, channel.code)
@@ -1131,7 +1134,7 @@ class RoomInitialSyncTestCase(RoomBase):
self.assertEqual("join", channel.json_body["membership"])
# Room state is easier to assert on if we unpack it into a dict
- state = {}
+ state: JsonDict = {}
for event in channel.json_body["state"]:
if "state_key" not in event:
continue
@@ -1160,10 +1163,10 @@ class RoomMessageListTestCase(RoomBase):
user_id = "@sid1:red"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)
- def test_topo_token_is_accepted(self):
+ def test_topo_token_is_accepted(self) -> None:
token = "t1-0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
@@ -1174,7 +1177,7 @@ class RoomMessageListTestCase(RoomBase):
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body)
- def test_stream_token_is_accepted_for_fwd_pagianation(self):
+ def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
token = "s0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
@@ -1185,7 +1188,7 @@ class RoomMessageListTestCase(RoomBase):
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body)
- def test_room_messages_purge(self):
+ def test_room_messages_purge(self) -> None:
store = self.hs.get_datastores().main
pagination_handler = self.hs.get_pagination_handler()
@@ -1278,10 +1281,10 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
user_id = True
hijack_auth = False
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Register the user who does the searching
- self.user_id = self.register_user("user", "pass")
+ self.user_id2 = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
# Register the user who sends the message
@@ -1289,12 +1292,12 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
self.other_access_token = self.login("otheruser", "pass")
# Create a room
- self.room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.room = self.helper.create_room_as(self.user_id2, tok=self.access_token)
# Invite the other person
self.helper.invite(
room=self.room,
- src=self.user_id,
+ src=self.user_id2,
tok=self.access_token,
targ=self.other_user_id,
)
@@ -1304,7 +1307,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
room=self.room, user=self.other_user_id, tok=self.other_access_token
)
- def test_finds_message(self):
+ def test_finds_message(self) -> None:
"""
The search functionality will search for content in messages if asked to
do so.
@@ -1333,7 +1336,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
# No context was requested, so we should get none.
self.assertEqual(results["results"][0]["context"], {})
- def test_include_context(self):
+ def test_include_context(self) -> None:
"""
When event_context includes include_profile, profile information will be
included in the search response.
@@ -1379,7 +1382,7 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.url = b"/_matrix/client/r0/publicRooms"
@@ -1389,11 +1392,11 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
return self.hs
- def test_restricted_no_auth(self):
+ def test_restricted_no_auth(self) -> None:
channel = self.make_request("GET", self.url)
self.assertEqual(channel.code, 401, channel.result)
- def test_restricted_auth(self):
+ def test_restricted_auth(self) -> None:
self.register_user("user", "pass")
tok = self.login("user", "pass")
@@ -1412,19 +1415,19 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=Mock())
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.register_user("user", "pass")
self.token = self.login("user", "pass")
self.federation_client = hs.get_federation_client()
- def test_simple(self):
+ def test_simple(self) -> None:
"Simple test for searching rooms over federation"
- self.federation_client.get_public_rooms.side_effect = (
- lambda *a, **k: defer.succeed({})
+ self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined]
+ {}
)
search_filter = {"generic_search_term": "foobar"}
@@ -1437,7 +1440,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
- self.federation_client.get_public_rooms.assert_called_once_with(
+ self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined]
"testserv",
limit=100,
since_token=None,
@@ -1446,12 +1449,12 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
third_party_instance_id=None,
)
- def test_fallback(self):
+ def test_fallback(self) -> None:
"Test that searching public rooms over federation falls back if it gets a 404"
# The `get_public_rooms` should be called again if the first call fails
# with a 404, when using search filters.
- self.federation_client.get_public_rooms.side_effect = (
+ self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
HttpResponseException(404, "Not Found", b""),
defer.succeed({}),
)
@@ -1466,7 +1469,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
- self.federation_client.get_public_rooms.assert_has_calls(
+ self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined]
[
call(
"testserv",
@@ -1497,14 +1500,14 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
profile.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["allow_per_room_profiles"] = False
self.hs = self.setup_test_homeserver(config=config)
return self.hs
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test")
@@ -1522,7 +1525,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
- def test_per_room_profile_forbidden(self):
+ def test_per_room_profile_forbidden(self) -> None:
data = {"membership": "join", "displayname": "other test user"}
request_data = json.dumps(data)
channel = self.make_request(
@@ -1557,7 +1560,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.creator = self.register_user("creator", "test")
self.creator_tok = self.login("creator", "test")
@@ -1566,7 +1569,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok)
- def test_join_reason(self):
+ def test_join_reason(self) -> None:
reason = "hello"
channel = self.make_request(
"POST",
@@ -1578,7 +1581,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
- def test_leave_reason(self):
+ def test_leave_reason(self) -> None:
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello"
@@ -1592,7 +1595,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
- def test_kick_reason(self):
+ def test_kick_reason(self) -> None:
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello"
@@ -1606,7 +1609,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
- def test_ban_reason(self):
+ def test_ban_reason(self) -> None:
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello"
@@ -1620,7 +1623,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
- def test_unban_reason(self):
+ def test_unban_reason(self) -> None:
reason = "hello"
channel = self.make_request(
"POST",
@@ -1632,7 +1635,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
- def test_invite_reason(self):
+ def test_invite_reason(self) -> None:
reason = "hello"
channel = self.make_request(
"POST",
@@ -1644,7 +1647,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
- def test_reject_invite_reason(self):
+ def test_reject_invite_reason(self) -> None:
self.helper.invite(
self.room_id,
src=self.creator,
@@ -1663,7 +1666,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
- def _check_for_reason(self, reason):
+ def _check_for_reason(self, reason: str) -> None:
channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format(
@@ -1704,12 +1707,12 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"org.matrix.not_labels": ["#notfun"],
}
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
- def test_context_filter_labels(self):
+ def test_context_filter_labels(self) -> None:
"""Test that we can filter by a label on a /context request."""
event_id = self._send_labelled_messages_in_room()
@@ -1739,7 +1742,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
events_after[0]["content"]["body"], "with right label", events_after[0]
)
- def test_context_filter_not_labels(self):
+ def test_context_filter_not_labels(self) -> None:
"""Test that we can filter by the absence of a label on a /context request."""
event_id = self._send_labelled_messages_in_room()
@@ -1772,7 +1775,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
events_after[1]["content"]["body"], "with two wrong labels", events_after[1]
)
- def test_context_filter_labels_not_labels(self):
+ def test_context_filter_labels_not_labels(self) -> None:
"""Test that we can filter by both a label and the absence of another label on a
/context request.
"""
@@ -1801,7 +1804,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
events_after[0]["content"]["body"], "with wrong label", events_after[0]
)
- def test_messages_filter_labels(self):
+ def test_messages_filter_labels(self) -> None:
"""Test that we can filter by a label on a /messages request."""
self._send_labelled_messages_in_room()
@@ -1818,7 +1821,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
- def test_messages_filter_not_labels(self):
+ def test_messages_filter_not_labels(self) -> None:
"""Test that we can filter by the absence of a label on a /messages request."""
self._send_labelled_messages_in_room()
@@ -1839,7 +1842,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
events[3]["content"]["body"], "with two wrong labels", events[3]
)
- def test_messages_filter_labels_not_labels(self):
+ def test_messages_filter_labels_not_labels(self) -> None:
"""Test that we can filter by both a label and the absence of another label on a
/messages request.
"""
@@ -1862,7 +1865,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(events), 1, [event["content"] for event in events])
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
- def test_search_filter_labels(self):
+ def test_search_filter_labels(self) -> None:
"""Test that we can filter by a label on a /search request."""
request_data = json.dumps(
{
@@ -1899,7 +1902,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results[1]["result"]["content"]["body"],
)
- def test_search_filter_not_labels(self):
+ def test_search_filter_not_labels(self) -> None:
"""Test that we can filter by the absence of a label on a /search request."""
request_data = json.dumps(
{
@@ -1946,7 +1949,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results[3]["result"]["content"]["body"],
)
- def test_search_filter_labels_not_labels(self):
+ def test_search_filter_labels_not_labels(self) -> None:
"""Test that we can filter by both a label and the absence of another label on a
/search request.
"""
@@ -1980,7 +1983,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results[0]["result"]["content"]["body"],
)
- def _send_labelled_messages_in_room(self):
+ def _send_labelled_messages_in_room(self) -> str:
"""Sends several messages to a room with different labels (or without any) to test
filtering by label.
Returns:
@@ -2056,12 +2059,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["experimental_features"] = {"msc3440_enabled": True}
return config
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
@@ -2136,7 +2139,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
return channel.json_body["chunk"]
- def test_filter_relation_senders(self):
+ def test_filter_relation_senders(self) -> None:
# Messages which second user reacted to.
filter = {"io.element.relation_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
@@ -2159,7 +2162,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
)
- def test_filter_relation_type(self):
+ def test_filter_relation_type(self) -> None:
# Messages which have annotations.
filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
@@ -2185,7 +2188,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
)
- def test_filter_relation_senders_and_type(self):
+ def test_filter_relation_senders_and_type(self) -> None:
# Messages which second user reacted to.
filter = {
"io.element.relation_senders": [self.second_user_id],
@@ -2205,7 +2208,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
account.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("user", "password")
self.tok = self.login("user", "password")
self.room_id = self.helper.create_room_as(
@@ -2218,7 +2221,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok)
self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok)
- def test_erased_sender(self):
+ def test_erased_sender(self) -> None:
"""Test that an erasure request results in the requester's events being hidden
from any new member of the room.
"""
@@ -2332,7 +2335,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_owner = self.register_user("room_owner", "test")
self.room_owner_tok = self.login("room_owner", "test")
@@ -2340,17 +2343,17 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
self.room_owner, tok=self.room_owner_tok
)
- def test_no_aliases(self):
+ def test_no_aliases(self) -> None:
res = self._get_aliases(self.room_owner_tok)
self.assertEqual(res["aliases"], [])
- def test_not_in_room(self):
+ def test_not_in_room(self) -> None:
self.register_user("user", "test")
user_tok = self.login("user", "test")
res = self._get_aliases(user_tok, expected_code=403)
self.assertEqual(res["errcode"], "M_FORBIDDEN")
- def test_admin_user(self):
+ def test_admin_user(self) -> None:
alias1 = self._random_alias()
self._set_alias_via_directory(alias1)
@@ -2360,7 +2363,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
res = self._get_aliases(user_tok)
self.assertEqual(res["aliases"], [alias1])
- def test_with_aliases(self):
+ def test_with_aliases(self) -> None:
alias1 = self._random_alias()
alias2 = self._random_alias()
@@ -2370,7 +2373,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
res = self._get_aliases(self.room_owner_tok)
self.assertEqual(set(res["aliases"]), {alias1, alias2})
- def test_peekable_room(self):
+ def test_peekable_room(self) -> None:
alias1 = self._random_alias()
self._set_alias_via_directory(alias1)
@@ -2404,7 +2407,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
def _random_alias(self) -> str:
return RoomAlias(random_string(5), self.hs.hostname).to_string()
- def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
+ def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
url = "/_matrix/client/r0/directory/room/" + alias
data = {"room_id": self.room_id}
request_data = json.dumps(data)
@@ -2423,7 +2426,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_owner = self.register_user("room_owner", "test")
self.room_owner_tok = self.login("room_owner", "test")
@@ -2434,7 +2437,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
self.alias = "#alias:test"
self._set_alias_via_directory(self.alias)
- def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
+ def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
url = "/_matrix/client/r0/directory/room/" + alias
data = {"room_id": self.room_id}
request_data = json.dumps(data)
@@ -2456,7 +2459,9 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
self.assertIsInstance(res, dict)
return res
- def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
+ def _set_canonical_alias(
+ self, content: JsonDict, expected_code: int = 200
+ ) -> JsonDict:
"""Calls the endpoint under test. returns the json response object."""
channel = self.make_request(
"PUT",
@@ -2469,7 +2474,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
self.assertIsInstance(res, dict)
return res
- def test_canonical_alias(self):
+ def test_canonical_alias(self) -> None:
"""Test a basic alias message."""
# There is no canonical alias to start with.
self._get_canonical_alias(expected_code=404)
@@ -2488,7 +2493,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
res = self._get_canonical_alias()
self.assertEqual(res, {})
- def test_alt_aliases(self):
+ def test_alt_aliases(self) -> None:
"""Test a canonical alias message with alt_aliases."""
# Create an alias.
self._set_canonical_alias({"alt_aliases": [self.alias]})
@@ -2504,7 +2509,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
res = self._get_canonical_alias()
self.assertEqual(res, {})
- def test_alias_alt_aliases(self):
+ def test_alias_alt_aliases(self) -> None:
"""Test a canonical alias message with an alias and alt_aliases."""
# Create an alias.
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
@@ -2520,7 +2525,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
res = self._get_canonical_alias()
self.assertEqual(res, {})
- def test_partial_modify(self):
+ def test_partial_modify(self) -> None:
"""Test removing only the alt_aliases."""
# Create an alias.
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
@@ -2536,7 +2541,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
res = self._get_canonical_alias()
self.assertEqual(res, {"alias": self.alias})
- def test_add_alias(self):
+ def test_add_alias(self) -> None:
"""Test removing only the alt_aliases."""
# Create an additional alias.
second_alias = "#second:test"
@@ -2556,7 +2561,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
)
- def test_bad_data(self):
+ def test_bad_data(self) -> None:
"""Invalid data for alt_aliases should cause errors."""
self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
@@ -2566,7 +2571,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
- def test_bad_alias(self):
+ def test_bad_alias(self) -> None:
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
@@ -2580,13 +2585,13 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("thomas", "hackme")
self.tok = self.login("thomas", "hackme")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
- def test_threepid_invite_spamcheck(self):
+ def test_threepid_invite_spamcheck(self) -> None:
# Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index bfc04785b7..58f1ea11b7 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, LoginType, Membership
from synapse.api.errors import SynapseError
+from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.rest import admin
from synapse.rest.client import account, login, profile, room
+from synapse.server import HomeServer
from synapse.types import JsonDict, Requester, StateMap
+from synapse.util import Clock
from synapse.util.frozenutils import unfreeze
from tests import unittest
@@ -34,7 +40,7 @@ thread_local = threading.local()
class LegacyThirdPartyRulesTestModule:
- def __init__(self, config: Dict, module_api: "ModuleApi"):
+ def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
# keep a record of the "current" rules module, so that the test can patch
# it if desired.
thread_local.rules_module = self
@@ -42,32 +48,36 @@ class LegacyThirdPartyRulesTestModule:
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
- ):
+ ) -> bool:
return True
- async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+ async def check_event_allowed(
+ self, event: EventBase, state: StateMap[EventBase]
+ ) -> Union[bool, dict]:
return True
@staticmethod
- def parse_config(config):
+ def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
return config
class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
- def __init__(self, config: Dict, module_api: "ModuleApi"):
+ def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
super().__init__(config, module_api)
- def on_create_room(
+ async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
- ):
+ ) -> bool:
return False
class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
- def __init__(self, config: Dict, module_api: "ModuleApi"):
+ def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
super().__init__(config, module_api)
- async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+ async def check_event_allowed(
+ self, event: EventBase, state: StateMap[EventBase]
+ ) -> JsonDict:
d = event.get_dict()
content = unfreeze(event.content)
content["foo"] = "bar"
@@ -84,7 +94,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
account.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver()
load_legacy_third_party_event_rules(hs)
@@ -94,22 +104,30 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Note that these checks are not relevant to this test case.
# Have this homeserver auto-approve all event signature checking.
- async def approve_all_signature_checking(_, pdu):
+ async def approve_all_signature_checking(
+ _: RoomVersion, pdu: EventBase
+ ) -> EventBase:
return pdu
- hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking
+ hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking # type: ignore[assignment]
# Have this homeserver skip event auth checks. This is necessary due to
# event auth checks ensuring that events were signed by the sender's homeserver.
- async def _check_event_auth(origin, event, context, *args, **kwargs):
+ async def _check_event_auth(
+ origin: str,
+ event: EventBase,
+ context: EventContext,
+ *args: Any,
+ **kwargs: Any,
+ ) -> EventContext:
return context
- hs.get_federation_event_handler()._check_event_auth = _check_event_auth
+ hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment]
return hs
- def prepare(self, reactor, clock, homeserver):
- super().prepare(reactor, clock, homeserver)
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ super().prepare(reactor, clock, hs)
# Create some users and a room to play with during the tests
self.user_id = self.register_user("kermit", "monkey")
self.invitee = self.register_user("invitee", "hackme")
@@ -121,13 +139,15 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
except Exception:
pass
- def test_third_party_rules(self):
+ def test_third_party_rules(self) -> None:
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
can be sent.
"""
# patch the rules module with a Mock which will return False for some event
# types
- async def check(ev, state):
+ async def check(
+ ev: EventBase, state: StateMap[EventBase]
+ ) -> Tuple[bool, Optional[JsonDict]]:
return ev.type != "foo.bar.forbidden", None
callback = Mock(spec=[], side_effect=check)
@@ -161,7 +181,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
)
self.assertEqual(channel.result["code"], b"403", channel.result)
- def test_third_party_rules_workaround_synapse_errors_pass_through(self):
+ def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None:
"""
Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042
is functional: that SynapseErrors are passed through from check_event_allowed
@@ -172,7 +192,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
class NastyHackException(SynapseError):
- def error_dict(self):
+ def error_dict(self) -> JsonDict:
"""
This overrides SynapseError's `error_dict` to nastily inject
JSON into the error response.
@@ -182,7 +202,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
return result
# add a callback that will raise our hacky exception
- async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]:
+ async def check(
+ ev: EventBase, state: StateMap[EventBase]
+ ) -> Tuple[bool, Optional[JsonDict]]:
raise NastyHackException(429, "message")
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
@@ -202,11 +224,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"},
)
- def test_cannot_modify_event(self):
+ def test_cannot_modify_event(self) -> None:
"""cannot accidentally modify an event before it is persisted"""
# first patch the event checker so that it will try to modify the event
- async def check(ev: EventBase, state):
+ async def check(
+ ev: EventBase, state: StateMap[EventBase]
+ ) -> Tuple[bool, Optional[JsonDict]]:
ev.content = {"x": "y"}
return True, None
@@ -223,10 +247,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# 500 Internal Server Error
self.assertEqual(channel.code, 500, channel.result)
- def test_modify_event(self):
+ def test_modify_event(self) -> None:
"""The module can return a modified version of the event"""
# first patch the event checker so that it will modify the event
- async def check(ev: EventBase, state):
+ async def check(
+ ev: EventBase, state: StateMap[EventBase]
+ ) -> Tuple[bool, Optional[JsonDict]]:
d = ev.get_dict()
d["content"] = {"x": "y"}
return True, d
@@ -253,10 +279,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")
- def test_message_edit(self):
+ def test_message_edit(self) -> None:
"""Ensure that the module doesn't cause issues with edited messages."""
# first patch the event checker so that it will modify the event
- async def check(ev: EventBase, state):
+ async def check(
+ ev: EventBase, state: StateMap[EventBase]
+ ) -> Tuple[bool, Optional[JsonDict]]:
d = ev.get_dict()
d["content"] = {
"msgtype": "m.text",
@@ -315,7 +343,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
ev = channel.json_body
self.assertEqual(ev["content"]["body"], "EDITED BODY")
- def test_send_event(self):
+ def test_send_event(self) -> None:
"""Tests that a module can send an event into a room via the module api"""
content = {
"msgtype": "m.text",
@@ -344,7 +372,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
}
}
)
- def test_legacy_check_event_allowed(self):
+ def test_legacy_check_event_allowed(self) -> None:
"""Tests that the wrapper for legacy check_event_allowed callbacks works
correctly.
"""
@@ -379,13 +407,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
}
}
)
- def test_legacy_on_create_room(self):
+ def test_legacy_on_create_room(self) -> None:
"""Tests that the wrapper for legacy on_create_room callbacks works
correctly.
"""
self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)
- def test_sent_event_end_up_in_room_state(self):
+ def test_sent_event_end_up_in_room_state(self) -> None:
"""Tests that a state event sent by a module while processing another state event
doesn't get dropped from the state of the room. This is to guard against a bug
where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830
@@ -400,7 +428,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
api = self.hs.get_module_api()
# Define a callback that sends a custom event on power levels update.
- async def test_fn(event: EventBase, state_events):
+ async def test_fn(
+ event: EventBase, state_events: StateMap[EventBase]
+ ) -> Tuple[bool, Optional[JsonDict]]:
if event.is_state and event.type == EventTypes.PowerLevels:
await api.create_and_send_event_into_room(
{
@@ -436,7 +466,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["i"], i)
- def test_on_new_event(self):
+ def test_on_new_event(self) -> None:
"""Test that the on_new_event callback is called on new events"""
on_new_event = Mock(make_awaitable(None))
self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
@@ -501,7 +531,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
- def _update_power_levels(self, event_default: int = 0):
+ def _update_power_levels(self, event_default: int = 0) -> None:
"""Updates the room's power levels.
Args:
@@ -533,7 +563,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
tok=self.tok,
)
- def test_on_profile_update(self):
+ def test_on_profile_update(self) -> None:
"""Tests that the on_profile_update module callback is correctly called on
profile updates.
"""
@@ -592,7 +622,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(profile_info.display_name, displayname)
self.assertEqual(profile_info.avatar_url, avatar_url)
- def test_on_profile_update_admin(self):
+ def test_on_profile_update_admin(self) -> None:
"""Tests that the on_profile_update module callback is correctly called on
profile updates triggered by a server admin.
"""
@@ -634,7 +664,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(profile_info.display_name, displayname)
self.assertEqual(profile_info.avatar_url, avatar_url)
- def test_on_user_deactivation_status_changed(self):
+ def test_on_user_deactivation_status_changed(self) -> None:
"""Tests that the on_user_deactivation_status_changed module callback is called
correctly when processing a user's deactivation.
"""
@@ -691,7 +721,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
args = profile_mock.call_args[0]
self.assertTrue(args[3])
- def test_on_user_deactivation_status_changed_admin(self):
+ def test_on_user_deactivation_status_changed_admin(self) -> None:
"""Tests that the on_user_deactivation_status_changed module callback is called
correctly when processing a user's deactivation triggered by a server admin as
well as a reactivation.
diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py
index 8b2da88e8a..d6da510773 100644
--- a/tests/rest/client/test_typing.py
+++ b/tests/rest/client/test_typing.py
@@ -15,10 +15,12 @@
"""Tests REST events for /rooms paths."""
-from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
from synapse.rest.client import room
+from synapse.server import HomeServer
from synapse.types import UserID
+from synapse.util import Clock
from tests import unittest
@@ -33,40 +35,17 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
user = UserID.from_string(user_id)
servlets = [room.register_servlets]
- def make_homeserver(self, reactor, clock):
-
- hs = self.setup_test_homeserver(
- "red",
- federation_http_client=None,
- federation_client=Mock(),
- )
-
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ hs = self.setup_test_homeserver("red")
self.event_source = hs.get_event_sources().sources.typing
-
- hs.get_federation_handler = Mock()
-
- async def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
-
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- async def _insert_client_ip(*args, **kwargs):
- return None
-
- hs.get_datastores().main.insert_client_ip = _insert_client_ip
-
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)
# Need another user to make notifications actually work
self.helper.join(self.room_id, user="@jim:red")
- def test_set_typing(self):
+ def test_set_typing(self) -> None:
channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
@@ -95,7 +74,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
],
)
- def test_set_not_typing(self):
+ def test_set_not_typing(self) -> None:
channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
@@ -103,7 +82,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(200, channel.code)
- def test_typing_timeout(self):
+ def test_typing_timeout(self) -> None:
channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 4672a68596..978c252f84 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -13,19 +13,24 @@
# limitations under the License.
import urllib.parse
from io import BytesIO, StringIO
+from typing import Any, Dict, Optional, Union
from unittest.mock import Mock
import signedjson.key
from canonicaljson import encode_canonical_json
-from nacl.signing import SigningKey
from signedjson.sign import sign_json
+from signedjson.types import SigningKey
-from twisted.web.resource import NoResource
+from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import NoResource, Resource
from synapse.crypto.keyring import PerspectivesKeyFetcher
from synapse.http.site import SynapseRequest
from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
+from synapse.types import JsonDict
+from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.stringutils import random_string
@@ -35,11 +40,11 @@ from tests.utils import default_config
class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock()
return self.setup_test_homeserver(federation_http_client=self.http_client)
- def create_test_resource(self):
+ def create_test_resource(self) -> Resource:
return create_resource_tree(
{"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
)
@@ -51,7 +56,12 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect an outgoing GET request for the given key
"""
- async def get_json(destination, path, ignore_backoff=False, **kwargs):
+ async def get_json(
+ destination: str,
+ path: str,
+ ignore_backoff: bool = False,
+ **kwargs: Any,
+ ) -> Union[JsonDict, list]:
self.assertTrue(ignore_backoff)
self.assertEqual(destination, server_name)
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
@@ -84,7 +94,8 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
Checks that the response is a 200 and returns the decoded json body.
"""
channel = FakeChannel(self.site, self.reactor)
- req = SynapseRequest(channel, self.site)
+ # channel is a `FakeChannel` but `HTTPChannel` is expected
+ req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
req.content = BytesIO(b"")
req.requestReceived(
b"GET",
@@ -97,7 +108,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
resp = channel.json_body
return resp
- def test_get_key(self):
+ def test_get_key(self) -> None:
"""Fetch a remote key"""
SERVER_NAME = "remote.server"
testkey = signedjson.key.generate_signing_key("ver1")
@@ -114,7 +125,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
self.assertIn(SERVER_NAME, keys[0]["signatures"])
self.assertIn(self.hs.hostname, keys[0]["signatures"])
- def test_get_own_key(self):
+ def test_get_own_key(self) -> None:
"""Fetch our own key"""
testkey = signedjson.key.generate_signing_key("ver1")
self.expect_outgoing_key_request(self.hs.hostname, testkey)
@@ -141,7 +152,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
endpoint, to check that the two implementations are compatible.
"""
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# replace the signing key with our own
@@ -152,7 +163,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
return config
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# make a second homeserver, configured to use the first one as a key notary
self.http_client2 = Mock()
config = default_config(name="keyclient")
@@ -175,7 +186,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
# wire up outbound POST /key/v2/query requests from hs2 so that they
# will be forwarded to hs1
- async def post_json(destination, path, data):
+ async def post_json(
+ destination: str, path: str, data: Optional[JsonDict] = None
+ ) -> Union[JsonDict, list]:
self.assertEqual(destination, self.hs.hostname)
self.assertEqual(
path,
@@ -183,7 +196,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
)
channel = FakeChannel(self.site, self.reactor)
- req = SynapseRequest(channel, self.site)
+ # channel is a `FakeChannel` but `HTTPChannel` is expected
+ req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
req.content = BytesIO(encode_canonical_json(data))
req.requestReceived(
@@ -198,7 +212,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
self.http_client2.post_json.side_effect = post_json
- def test_get_key(self):
+ def test_get_key(self) -> None:
"""Fetch a key belonging to a random server"""
# make up a key to be fetched.
testkey = signedjson.key.generate_signing_key("abc")
@@ -218,7 +232,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
signedjson.key.encode_verify_key_base64(testkey.verify_key),
)
- def test_get_notary_key(self):
+ def test_get_notary_key(self) -> None:
"""Fetch a key belonging to the notary server"""
# make up a key to be fetched. We randomise the keyid to try to get it to
# appear before the key server signing key sometimes (otherwise we bail out
@@ -240,7 +254,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
signedjson.key.encode_verify_key_base64(testkey.verify_key),
)
- def test_get_notary_keyserver_key(self):
+ def test_get_notary_keyserver_key(self) -> None:
"""Fetch the notary's keyserver key"""
# we expect hs1 to make a regular key request to itself
self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key)
diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py
index f761e23f1b..c73179151a 100644
--- a/tests/rest/media/v1/test_base.py
+++ b/tests/rest/media/v1/test_base.py
@@ -28,11 +28,11 @@ class GetFileNameFromHeadersTests(unittest.TestCase):
b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar",
}
- def tests(self):
+ def tests(self) -> None:
for hdr, expected in self.TEST_CASES.items():
res = get_filename_from_headers({b"Content-Disposition": [hdr]})
self.assertEqual(
res,
expected,
- "expected output for %s to be %s but was %s" % (hdr, expected, res),
+ f"expected output for {hdr!r} to be {expected} but was {res}",
)
diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py
index 913bc530aa..43e6f0f70a 100644
--- a/tests/rest/media/v1/test_filepath.py
+++ b/tests/rest/media/v1/test_filepath.py
@@ -21,12 +21,12 @@ from tests import unittest
class MediaFilePathsTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
super().setUp()
self.filepaths = MediaFilePaths("/media_store")
- def test_local_media_filepath(self):
+ def test_local_media_filepath(self) -> None:
"""Test local media paths"""
self.assertEqual(
self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
@@ -37,7 +37,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_local_media_thumbnail(self):
+ def test_local_media_thumbnail(self) -> None:
"""Test local media thumbnail paths"""
self.assertEqual(
self.filepaths.local_media_thumbnail_rel(
@@ -52,14 +52,14 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
)
- def test_local_media_thumbnail_dir(self):
+ def test_local_media_thumbnail_dir(self) -> None:
"""Test local media thumbnail directory paths"""
self.assertEqual(
self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"),
"/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_remote_media_filepath(self):
+ def test_remote_media_filepath(self) -> None:
"""Test remote media paths"""
self.assertEqual(
self.filepaths.remote_media_filepath_rel(
@@ -74,7 +74,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_remote_media_thumbnail(self):
+ def test_remote_media_thumbnail(self) -> None:
"""Test remote media thumbnail paths"""
self.assertEqual(
self.filepaths.remote_media_thumbnail_rel(
@@ -99,7 +99,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
)
- def test_remote_media_thumbnail_legacy(self):
+ def test_remote_media_thumbnail_legacy(self) -> None:
"""Test old-style remote media thumbnail paths"""
self.assertEqual(
self.filepaths.remote_media_thumbnail_rel_legacy(
@@ -108,7 +108,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg",
)
- def test_remote_media_thumbnail_dir(self):
+ def test_remote_media_thumbnail_dir(self) -> None:
"""Test remote media thumbnail directory paths"""
self.assertEqual(
self.filepaths.remote_media_thumbnail_dir(
@@ -117,7 +117,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_url_cache_filepath(self):
+ def test_url_cache_filepath(self) -> None:
"""Test URL cache paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"),
@@ -128,7 +128,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar",
)
- def test_url_cache_filepath_legacy(self):
+ def test_url_cache_filepath_legacy(self) -> None:
"""Test old-style URL cache paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
@@ -139,7 +139,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_url_cache_filepath_dirs_to_delete(self):
+ def test_url_cache_filepath_dirs_to_delete(self) -> None:
"""Test URL cache cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_dirs_to_delete(
@@ -148,7 +148,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
["/media_store/url_cache/2020-01-02"],
)
- def test_url_cache_filepath_dirs_to_delete_legacy(self):
+ def test_url_cache_filepath_dirs_to_delete_legacy(self) -> None:
"""Test old-style URL cache cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_dirs_to_delete(
@@ -160,7 +160,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_url_cache_thumbnail(self):
+ def test_url_cache_thumbnail(self) -> None:
"""Test URL cache thumbnail paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_rel(
@@ -175,7 +175,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
)
- def test_url_cache_thumbnail_legacy(self):
+ def test_url_cache_thumbnail_legacy(self) -> None:
"""Test old-style URL cache thumbnail paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_rel(
@@ -190,7 +190,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
)
- def test_url_cache_thumbnail_directory(self):
+ def test_url_cache_thumbnail_directory(self) -> None:
"""Test URL cache thumbnail directory paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_directory_rel(
@@ -203,7 +203,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
)
- def test_url_cache_thumbnail_directory_legacy(self):
+ def test_url_cache_thumbnail_directory_legacy(self) -> None:
"""Test old-style URL cache thumbnail directory paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_directory_rel(
@@ -216,7 +216,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_url_cache_thumbnail_dirs_to_delete(self):
+ def test_url_cache_thumbnail_dirs_to_delete(self) -> None:
"""Test URL cache thumbnail cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_dirs_to_delete(
@@ -228,7 +228,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_url_cache_thumbnail_dirs_to_delete_legacy(self):
+ def test_url_cache_thumbnail_dirs_to_delete_legacy(self) -> None:
"""Test old-style URL cache thumbnail cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_dirs_to_delete(
@@ -241,7 +241,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_server_name_validation(self):
+ def test_server_name_validation(self) -> None:
"""Test validation of server names"""
self._test_path_validation(
[
@@ -274,7 +274,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_file_id_validation(self):
+ def test_file_id_validation(self) -> None:
"""Test validation of local, remote and legacy URL cache file / media IDs"""
# File / media IDs get split into three parts to form paths, consisting of the
# first two characters, next two characters and rest of the ID.
@@ -357,7 +357,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
invalid_values=invalid_file_ids,
)
- def test_url_cache_media_id_validation(self):
+ def test_url_cache_media_id_validation(self) -> None:
"""Test validation of URL cache media IDs"""
self._test_path_validation(
[
@@ -387,7 +387,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_content_type_validation(self):
+ def test_content_type_validation(self) -> None:
"""Test validation of thumbnail content types"""
self._test_path_validation(
[
@@ -410,7 +410,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_thumbnail_method_validation(self):
+ def test_thumbnail_method_validation(self) -> None:
"""Test validation of thumbnail methods"""
self._test_path_validation(
[
@@ -440,7 +440,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
parameter: str,
valid_values: Iterable[str],
invalid_values: Iterable[str],
- ):
+ ) -> None:
"""Test that the specified methods validate the named parameter as expected
Args:
diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py
index a4b57e3d1f..3fb37a2a59 100644
--- a/tests/rest/media/v1/test_html_preview.py
+++ b/tests/rest/media/v1/test_html_preview.py
@@ -32,7 +32,7 @@ class SummarizeTestCase(unittest.TestCase):
if not lxml:
skip = "url preview feature requires lxml"
- def test_long_summarize(self):
+ def test_long_summarize(self) -> None:
example_paras = [
"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in
@@ -90,7 +90,7 @@ class SummarizeTestCase(unittest.TestCase):
" Tromsøya had a population of 36,088. Substantial parts of the urban…",
)
- def test_short_summarize(self):
+ def test_short_summarize(self) -> None:
example_paras = [
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -117,7 +117,7 @@ class SummarizeTestCase(unittest.TestCase):
" most of the year.",
)
- def test_small_then_large_summarize(self):
+ def test_small_then_large_summarize(self) -> None:
example_paras = [
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -150,7 +150,7 @@ class CalcOgTestCase(unittest.TestCase):
if not lxml:
skip = "url preview feature requires lxml"
- def test_simple(self):
+ def test_simple(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -165,7 +165,7 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
- def test_comment(self):
+ def test_comment(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -181,7 +181,7 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
- def test_comment2(self):
+ def test_comment2(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -206,7 +206,7 @@ class CalcOgTestCase(unittest.TestCase):
},
)
- def test_script(self):
+ def test_script(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -222,7 +222,7 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
- def test_missing_title(self):
+ def test_missing_title(self) -> None:
html = b"""
<html>
<body>
@@ -236,7 +236,7 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
- def test_h1_as_title(self):
+ def test_h1_as_title(self) -> None:
html = b"""
<html>
<meta property="og:description" content="Some text."/>
@@ -251,7 +251,7 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
- def test_missing_title_and_broken_h1(self):
+ def test_missing_title_and_broken_h1(self) -> None:
html = b"""
<html>
<body>
@@ -266,19 +266,19 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
- def test_empty(self):
+ def test_empty(self) -> None:
"""Test a body with no data in it."""
html = b""
tree = decode_body(html, "http://example.com/test.html")
self.assertIsNone(tree)
- def test_no_tree(self):
+ def test_no_tree(self) -> None:
"""A valid body with no tree in it."""
html = b"\x00"
tree = decode_body(html, "http://example.com/test.html")
self.assertIsNone(tree)
- def test_xml(self):
+ def test_xml(self) -> None:
"""Test decoding XML and ensure it works properly."""
# Note that the strip() call is important to ensure the xml tag starts
# at the initial byte.
@@ -293,7 +293,7 @@ class CalcOgTestCase(unittest.TestCase):
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
- def test_invalid_encoding(self):
+ def test_invalid_encoding(self) -> None:
"""An invalid character encoding should be ignored and treated as UTF-8, if possible."""
html = b"""
<html>
@@ -307,7 +307,7 @@ class CalcOgTestCase(unittest.TestCase):
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
- def test_invalid_encoding2(self):
+ def test_invalid_encoding2(self) -> None:
"""A body which doesn't match the sent character encoding."""
# Note that this contains an invalid UTF-8 sequence in the title.
html = b"""
@@ -322,7 +322,7 @@ class CalcOgTestCase(unittest.TestCase):
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
- def test_windows_1252(self):
+ def test_windows_1252(self) -> None:
"""A body which uses cp1252, but doesn't declare that."""
html = b"""
<html>
@@ -338,7 +338,7 @@ class CalcOgTestCase(unittest.TestCase):
class MediaEncodingTestCase(unittest.TestCase):
- def test_meta_charset(self):
+ def test_meta_charset(self) -> None:
"""A character encoding is found via the meta tag."""
encodings = _get_html_media_encodings(
b"""
@@ -363,7 +363,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
- def test_meta_charset_underscores(self):
+ def test_meta_charset_underscores(self) -> None:
"""A character encoding contains underscore."""
encodings = _get_html_media_encodings(
b"""
@@ -376,7 +376,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"])
- def test_xml_encoding(self):
+ def test_xml_encoding(self) -> None:
"""A character encoding is found via the meta tag."""
encodings = _get_html_media_encodings(
b"""
@@ -388,7 +388,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
- def test_meta_xml_encoding(self):
+ def test_meta_xml_encoding(self) -> None:
"""Meta tags take precedence over XML encoding."""
encodings = _get_html_media_encodings(
b"""
@@ -402,7 +402,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"])
- def test_content_type(self):
+ def test_content_type(self) -> None:
"""A character encoding is found via the Content-Type header."""
# Test a few variations of the header.
headers = (
@@ -417,12 +417,12 @@ class MediaEncodingTestCase(unittest.TestCase):
encodings = _get_html_media_encodings(b"", header)
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
- def test_fallback(self):
+ def test_fallback(self) -> None:
"""A character encoding cannot be found in the body or header."""
encodings = _get_html_media_encodings(b"", "text/html")
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
- def test_duplicates(self):
+ def test_duplicates(self) -> None:
"""Ensure each encoding is only attempted once."""
encodings = _get_html_media_encodings(
b"""
@@ -436,7 +436,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
- def test_unknown_invalid(self):
+ def test_unknown_invalid(self) -> None:
"""A character encoding should be ignored if it is unknown or invalid."""
encodings = _get_html_media_encodings(
b"""
@@ -451,7 +451,7 @@ class MediaEncodingTestCase(unittest.TestCase):
class RebaseUrlTestCase(unittest.TestCase):
- def test_relative(self):
+ def test_relative(self) -> None:
"""Relative URLs should be resolved based on the context of the base URL."""
self.assertEqual(
rebase_url("subpage", "https://example.com/foo/"),
@@ -466,14 +466,14 @@ class RebaseUrlTestCase(unittest.TestCase):
"https://example.com/bar",
)
- def test_absolute(self):
+ def test_absolute(self) -> None:
"""Absolute URLs should not be modified."""
self.assertEqual(
rebase_url("https://alice.com/a/", "https://example.com/foo/"),
"https://alice.com/a/",
)
- def test_data(self):
+ def test_data(self) -> None:
"""Data URLs should not be modified."""
self.assertEqual(
rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"),
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py
index 048d0ca44a..f38d7225f8 100644
--- a/tests/rest/media/v1/test_oembed.py
+++ b/tests/rest/media/v1/test_oembed.py
@@ -16,7 +16,7 @@ import json
from twisted.test.proto_helpers import MemoryReactor
-from synapse.rest.media.v1.oembed import OEmbedProvider
+from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -25,15 +25,15 @@ from tests.unittest import HomeserverTestCase
class OEmbedTests(HomeserverTestCase):
- def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
- self.oembed = OEmbedProvider(homeserver)
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.oembed = OEmbedProvider(hs)
- def parse_response(self, response: JsonDict):
+ def parse_response(self, response: JsonDict) -> OEmbedResult:
return self.oembed.parse_oembed_response(
"https://test", json.dumps(response).encode("utf-8")
)
- def test_version(self):
+ def test_version(self) -> None:
"""Accept versions that are similar to 1.0 as a string or int (or missing)."""
for version in ("1.0", 1.0, 1):
result = self.parse_response({"version": version, "type": "link"})
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index 01d48c3860..da325955f8 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from http import HTTPStatus
from synapse.rest.health import HealthResource
@@ -19,12 +19,12 @@ from tests import unittest
class HealthCheckTests(unittest.HomeserverTestCase):
- def create_test_resource(self):
+ def create_test_resource(self) -> HealthResource:
# replace the JsonResource with a HealthResource.
return HealthResource()
- def test_health(self):
+ def test_health(self) -> None:
channel = self.make_request("GET", "/health", shorthand=False)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 118aa93a32..11f78f52b8 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from http import HTTPStatus
+
from twisted.web.resource import Resource
from synapse.rest.well_known import well_known_resource
@@ -19,7 +21,7 @@ from tests import unittest
class WellKnownTests(unittest.HomeserverTestCase):
- def create_test_resource(self):
+ def create_test_resource(self) -> Resource:
# replace the JsonResource with a Resource wrapping the WellKnownResource
res = Resource()
res.putChild(b".well-known", well_known_resource(self.hs))
@@ -31,12 +33,12 @@ class WellKnownTests(unittest.HomeserverTestCase):
"default_identity_server": "https://testis",
}
)
- def test_client_well_known(self):
+ def test_client_well_known(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(
channel.json_body,
{
@@ -50,27 +52,27 @@ class WellKnownTests(unittest.HomeserverTestCase):
"public_baseurl": None,
}
)
- def test_client_well_known_no_public_baseurl(self):
+ def test_client_well_known_no_public_baseurl(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
@unittest.override_config({"serve_server_wellknown": True})
- def test_server_well_known(self):
+ def test_server_well_known(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/server", shorthand=False
)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(
channel.json_body,
{"m.server": "test:443"},
)
- def test_server_well_known_disabled(self):
+ def test_server_well_known_disabled(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/server", shorthand=False
)
- self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 39dcc094bd..9fdf54ea31 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -66,13 +66,13 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.reset_mock()
res = self.get_success(
self.updates.do_next_background_update(False),
- by=0.01,
+ by=0.02,
)
self.assertFalse(res)
# on the first call, we should get run with the default background update size
self.update_handler.assert_called_once_with(
- {"my_key": 1}, self.updates.MINIMUM_BACKGROUND_BATCH_SIZE
+ {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE
)
# second step: complete the update
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 219b5660b1..532e3fe9cd 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -13,11 +13,12 @@
# limitations under the License.
import logging
from typing import Optional
+from unittest.mock import patch
from synapse.api.room_versions import RoomVersions
-from synapse.events import EventBase
-from synapse.types import JsonDict
-from synapse.visibility import filter_events_for_server
+from synapse.events import EventBase, make_event_from_dict
+from synapse.types import JsonDict, create_requester
+from synapse.visibility import filter_events_for_client, filter_events_for_server
from tests import unittest
from tests.utils import create_room
@@ -185,3 +186,72 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.get_success(self.storage.persistence.persist_event(event, context))
return event
+
+
+class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
+ def test_out_of_band_invite_rejection(self):
+ # this is where we have received an invite event over federation, and then
+ # rejected it.
+ invite_pdu = {
+ "room_id": "!room:id",
+ "depth": 1,
+ "auth_events": [],
+ "prev_events": [],
+ "origin_server_ts": 1,
+ "sender": "@someone:" + self.OTHER_SERVER_NAME,
+ "type": "m.room.member",
+ "state_key": "@user:test",
+ "content": {"membership": "invite"},
+ }
+ self.add_hashes_and_signatures(invite_pdu)
+ invite_event_id = make_event_from_dict(invite_pdu, RoomVersions.V9).event_id
+
+ self.get_success(
+ self.hs.get_federation_server().on_invite_request(
+ self.OTHER_SERVER_NAME,
+ invite_pdu,
+ "9",
+ )
+ )
+
+ # stub out do_remotely_reject_invite so that we fall back to a locally-
+ # generated rejection
+ with patch.object(
+ self.hs.get_federation_handler(),
+ "do_remotely_reject_invite",
+ side_effect=Exception(),
+ ):
+ reject_event_id, _ = self.get_success(
+ self.hs.get_room_member_handler().remote_reject_invite(
+ invite_event_id,
+ txn_id=None,
+ requester=create_requester("@user:test"),
+ content={},
+ )
+ )
+
+ invite_event, reject_event = self.get_success(
+ self.hs.get_datastores().main.get_events_as_list(
+ [invite_event_id, reject_event_id]
+ )
+ )
+
+ # the invited user should be able to see both the invite and the rejection
+ self.assertEqual(
+ self.get_success(
+ filter_events_for_client(
+ self.hs.get_storage(), "@user:test", [invite_event, reject_event]
+ )
+ ),
+ [invite_event, reject_event],
+ )
+
+ # other users should see neither
+ self.assertEqual(
+ self.get_success(
+ filter_events_for_client(
+ self.hs.get_storage(), "@other:test", [invite_event, reject_event]
+ )
+ ),
+ [],
+ )
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index 362014f4cb..ff53ce114b 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -100,6 +100,34 @@ class ObservableDeferredTest(TestCase):
self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
+ def test_cancellation(self):
+ """Test that cancelling an observer does not affect other observers."""
+ origin_d: "Deferred[int]" = Deferred()
+ observable = ObservableDeferred(origin_d, consumeErrors=True)
+
+ observer1 = observable.observe()
+ observer2 = observable.observe()
+ observer3 = observable.observe()
+
+ self.assertFalse(observer1.called)
+ self.assertFalse(observer2.called)
+ self.assertFalse(observer3.called)
+
+ # cancel the second observer
+ observer2.cancel()
+ self.assertFalse(observer1.called)
+ self.failureResultOf(observer2, CancelledError)
+ self.assertFalse(observer3.called)
+
+ # other observers resolve as normal
+ origin_d.callback(123)
+ self.assertEqual(observer1.result, 123, "observer 1 callback result")
+ self.assertEqual(observer3.result, 123, "observer 3 callback result")
+
+ # additional observers resolve as normal
+ observer4 = observable.observe()
+ self.assertEqual(observer4.result, 123, "observer 4 callback result")
+
class TimeoutDeferredTest(TestCase):
def setUp(self):
|