diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index febcc1499d..e2a3bad065 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast
+from typing import List, Optional, Sequence, Tuple, cast
from unittest.mock import Mock
from typing_extensions import TypeAlias
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.appservice import (
ApplicationService,
@@ -40,9 +41,6 @@ from tests.test_utils import simple_async_mock
from ..utils import MockClock
-if TYPE_CHECKING:
- from twisted.internet.testing import MemoryReactor
-
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self) -> None:
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 35dd9a20df..33af8770fd 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -24,7 +24,6 @@ from tests.test_utils import make_awaitable
class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
-
servlets = [
admin.register_servlets,
room.register_servlets,
@@ -37,7 +36,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
return config
def test_complexity_simple(self) -> None:
-
u1 = self.register_user("u1", "pass")
u1_token = self.login("u1", "pass")
@@ -71,7 +69,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(complexity, 1.23)
def test_join_too_large(self) -> None:
-
u1 = self.register_user("u1", "pass")
handler = self.hs.get_room_member_handler()
@@ -131,7 +128,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_join_too_large_once_joined(self) -> None:
-
u1 = self.register_user("u1", "pass")
u1_token = self.login("u1", "pass")
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index bba6469b55..6c7738d810 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -34,7 +34,6 @@ from tests.unittest import override_config
class FederationServerTests(unittest.FederatingHomeserverTestCase):
-
servlets = [
admin.register_servlets,
room.register_servlets,
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 6f300b8e11..5569ccef8a 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes, JoinRules
from synapse.api.room_versions import RoomVersions
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
+from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
@@ -296,3 +297,58 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(args[0][0]["user_agent"], "user_agent")
self.assertGreater(args[0][0]["last_seen"], 0)
self.assertNotIn("access_token", args[0][0])
+
+ def test_account_data(self) -> None:
+ """Tests that user account data get exported."""
+ # add account data
+ self.get_success(
+ self._store.add_account_data_for_user(self.user2, "m.global", {"a": 1})
+ )
+ self.get_success(
+ self._store.add_account_data_to_room(
+ self.user2, "test_room", "m.per_room", {"b": 2}
+ )
+ )
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ # two calls, one call for user data and one call for room data
+ writer.write_account_data.assert_called()
+
+ args = writer.write_account_data.call_args_list[0][0]
+ self.assertEqual(args[0], "global")
+ self.assertEqual(args[1]["m.global"]["a"], 1)
+
+ args = writer.write_account_data.call_args_list[1][0]
+ self.assertEqual(args[0], "test_room")
+ self.assertEqual(args[1]["m.per_room"]["b"], 2)
+
+ def test_media_ids(self) -> None:
+ """Tests that media's metadata get exported."""
+
+ self.get_success(
+ self._store.store_local_media(
+ media_id="media_1",
+ media_type="image/png",
+ time_now_ms=self.clock.time_msec(),
+ upload_name=None,
+ media_length=50,
+ user_id=UserID.from_string(self.user2),
+ )
+ )
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ writer.write_media_id.assert_called_once()
+
+ args = writer.write_media_id.call_args[0]
+ self.assertEqual(args[0], "media_1")
+ self.assertEqual(args[1]["media_id"], "media_1")
+ self.assertEqual(args[1]["media_length"], 50)
+ self.assertGreater(args[1]["created_ts"], 0)
+ self.assertIsNone(args[1]["upload_name"])
+ self.assertIsNone(args[1]["last_access_ts"])
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 179c96adf2..2e838e6572 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
@@ -81,7 +81,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def _create_duplicate_event(
self, txn_id: str
- ) -> Tuple[EventBase, EventContext, Optional[dict]]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]:
"""Create a new event with the given transaction ID. All events produced
by this method will be considered duplicates.
"""
@@ -109,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_suitably_random"
- event1, context, _ = self._create_duplicate_event(txn_id)
+ event1, unpersisted_context, _ = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event1))
ret_event1 = self.get_success(
self.handler.handle_new_client_event(
@@ -121,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertEqual(event1.event_id, ret_event1.event_id)
- event2, context, _ = self._create_duplicate_event(txn_id)
+ event2, unpersisted_context, _ = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event2))
# We want to test that the deduplication at the persit event end works,
# so we want to make sure we test with different events.
@@ -142,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_event` directly also does the right
# thing.
- event3, context, _ = self._create_duplicate_event(txn_id)
+ event3, unpersisted_context, _ = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event3))
+
self.assertNotEqual(event1.event_id, event3.event_id)
ret_event3, event_pos3, _ = self.get_success(
@@ -156,7 +160,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_events` directly also does the right
# thing.
- event4, context, _ = self._create_duplicate_event(txn_id)
+ event4, unpersisted_context, _ = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event4))
+
self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success(
@@ -176,8 +182,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_else_suitably_random"
# Create two duplicate events to persist at the same time
- event1, context1, _ = self._create_duplicate_event(txn_id)
- event2, context2, _ = self._create_duplicate_event(txn_id)
+ event1, unpersisted_context1, _ = self._create_duplicate_event(txn_id)
+ context1 = self.get_success(unpersisted_context1.persist(event1))
+ event2, unpersisted_context2, _ = self._create_duplicate_event(txn_id)
+ context2 = self.get_success(unpersisted_context2.persist(event2))
# Ensure their event IDs are different to start with
self.assertNotEqual(event1.event_id, event2.event_id)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index ad42a7183e..161ff0a6c1 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -62,7 +62,7 @@ class TestSpamChecker:
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str],
) -> RegistrationBehaviour:
- pass
+ return RegistrationBehaviour.ALLOW
class DenyAll(TestSpamChecker):
@@ -111,7 +111,7 @@ class TestLegacyRegistrationSpamChecker:
username: Optional[str],
request_info: Collection[Tuple[str, str]],
) -> RegistrationBehaviour:
- pass
+ return RegistrationBehaviour.ALLOW
class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
@@ -507,7 +507,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Lower the permissions of the inviter.
event_creation_handler = self.hs.get_event_creation_handler()
requester = create_requester(inviter)
- event, context, _ = self.get_success(
+
+ event, unpersisted_context, _ = self.get_success(
event_creation_handler.create_event(
requester,
{
@@ -519,6 +520,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
},
)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
event_creation_handler.handle_new_client_event(
requester, events_and_context=[(event, context)]
diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py
index 137deab138..d6f43a98fc 100644
--- a/tests/handlers/test_sso.py
+++ b/tests/handlers/test_sso.py
@@ -113,7 +113,6 @@ async def mock_get_file(
headers: Optional[RawHeaders] = None,
is_allowed_content_type: Optional[Callable[[str], bool]] = None,
) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
-
fake_response = FakeResponse(code=404)
if url == "http://my.server/me.png":
fake_response = FakeResponse(
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index f1a50c5bcb..d11ded6c5b 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -31,7 +31,6 @@ EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6
class StatsRoomTests(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 9bb9c1afc6..c5746005b5 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -192,6 +192,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.helper.join(room, self.appservice.sender, tok=self.appservice.token)
self._check_only_one_user_in_directory(user, room)
+ def test_search_term_with_colon_in_it_does_not_raise(self) -> None:
+ """
+ Regression test: Test that search terms with colons in them are acceptable.
+ """
+ u1 = self.register_user("user1", "pass")
+ self.get_success(self.handler.search_users(u1, "haha:paamayim-nekudotayim", 10))
+
def test_user_not_in_users_table(self) -> None:
"""Unclear how it happens, but on matrix.org we've seen join events
for users who aren't in the users table. Test that we don't fall over
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index acfdcd3bca..eb7f53fee5 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -30,7 +30,7 @@ from twisted.internet.interfaces import (
IOpenSSLClientConnectionCreator,
IProtocolFactory,
)
-from twisted.internet.protocol import Factory
+from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web._newclient import ResponseNeverReceived
from twisted.web.client import Agent
@@ -63,7 +63,7 @@ from tests.http import (
get_test_ca_cert_file,
)
from tests.server import FakeTransport, ThreadedMemoryReactorClock
-from tests.utils import default_config
+from tests.utils import checked_cast, default_config
logger = logging.getLogger(__name__)
@@ -146,8 +146,10 @@ class MatrixFederationAgentTests(unittest.TestCase):
#
# Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here.
- client_protocol = client_factory.buildProtocol(dummy_address)
- assert isinstance(client_protocol, _WrappingProtocol)
+ # NB: we use a checked_cast here to workaround https://github.com/Shoobx/mypy-zope/issues/91)
+ client_protocol = checked_cast(
+ _WrappingProtocol, client_factory.buildProtocol(dummy_address)
+ )
client_protocol.makeConnection(
FakeTransport(server_protocol, self.reactor, client_protocol)
)
@@ -446,7 +448,6 @@ class MatrixFederationAgentTests(unittest.TestCase):
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
).buildProtocol(dummy_address)
- assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
proxy_server_transport = proxy_server.transport
@@ -465,7 +466,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
else:
assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other
- c2s_transport = client_protocol.transport
+ assert isinstance(client_protocol, Protocol)
+ c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol
self.reactor.advance(0)
@@ -1529,7 +1531,7 @@ def _check_logcontext(context: LoggingContextOrSentinel) -> None:
def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
-) -> IProtocolFactory:
+) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 7748f56ee6..6ab13357f9 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -46,7 +46,6 @@ class SrvResolverTestCase(unittest.TestCase):
@defer.inlineCallbacks
def do_lookup() -> Generator["Deferred[object]", object, List[Server]]:
-
with LoggingContext("one") as ctx:
resolve_d = resolver.resolve_service(service_name)
result: List[Server]
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 9cfe1ad0de..f6d6684985 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -149,7 +149,7 @@ class BlacklistingAgentTest(TestCase):
self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
# Configure the reactor's DNS resolver.
- for (domain, ip) in (
+ for domain, ip in (
(self.safe_domain, self.safe_ip),
(self.unsafe_domain, self.unsafe_ip),
(self.allowed_domain, self.allowed_ip),
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index a817940730..cc175052ac 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -28,7 +28,7 @@ from twisted.internet.endpoints import (
_WrappingProtocol,
)
from twisted.internet.interfaces import IProtocol, IProtocolFactory
-from twisted.internet.protocol import Factory
+from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel
@@ -43,6 +43,7 @@ from tests.http import (
)
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase
+from tests.utils import checked_cast
logger = logging.getLogger(__name__)
@@ -620,7 +621,6 @@ class MatrixFederationAgentTests(TestCase):
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
).buildProtocol(dummy_address)
- assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
proxy_server_transport = proxy_server.transport
@@ -644,7 +644,8 @@ class MatrixFederationAgentTests(TestCase):
else:
assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other
- c2s_transport = client_protocol.transport
+ assert isinstance(client_protocol, Protocol)
+ c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol
self.reactor.advance(0)
@@ -757,12 +758,14 @@ class MatrixFederationAgentTests(TestCase):
assert isinstance(proxy_server, HTTPChannel)
# fish the transports back out so that we can do the old switcheroo
- s2c_transport = proxy_server.transport
- assert isinstance(s2c_transport, FakeTransport)
- client_protocol = s2c_transport.other
- assert isinstance(client_protocol, _WrappingProtocol)
- c2s_transport = client_protocol.transport
- assert isinstance(c2s_transport, FakeTransport)
+ # To help mypy out with the various Protocols and wrappers and mocks, we do
+ # some explicit casting. Without the casts, we hit the bug I reported at
+ # https://github.com/Shoobx/mypy-zope/issues/91 .
+ # We also double-checked these casts at runtime (test-time) because I found it
+ # quite confusing to deduce these types in the first place!
+ s2c_transport = checked_cast(FakeTransport, proxy_server.transport)
+ client_protocol = checked_cast(_WrappingProtocol, s2c_transport.other)
+ c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
# the FakeTransport is async, so we need to pump the reactor
self.reactor.advance(0)
@@ -822,9 +825,9 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
def test_proxy_with_no_scheme(self) -> None:
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
- assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
+ proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
+ self.assertEqual(proxy_ep._hostStr, "proxy.com")
+ self.assertEqual(proxy_ep._port, 8888)
@patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"})
def test_proxy_with_unsupported_scheme(self) -> None:
@@ -834,25 +837,21 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"})
def test_proxy_with_http_scheme(self) -> None:
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
- assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
+ proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
+ self.assertEqual(proxy_ep._hostStr, "proxy.com")
+ self.assertEqual(proxy_ep._port, 8888)
@patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"})
def test_proxy_with_https_scheme(self) -> None:
https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
- assert isinstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint)
- self.assertEqual(
- https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com"
- )
- self.assertEqual(
- https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._port, 8888
- )
+ proxy_ep = checked_cast(_WrapperEndpoint, https_proxy_agent.http_proxy_endpoint)
+ self.assertEqual(proxy_ep._wrappedEndpoint._hostStr, "proxy.com")
+ self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888)
def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
-) -> IProtocolFactory:
+) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
index c08954d887..5191e31a8a 100644
--- a/tests/logging/test_remote_handler.py
+++ b/tests/logging/test_remote_handler.py
@@ -21,6 +21,7 @@ from synapse.logging import RemoteHandler
from tests.logging import LoggerCleanupMixin
from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
+from tests.utils import checked_cast
def connect_logging_client(
@@ -56,8 +57,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
client, server = connect_logging_client(self.reactor, 0)
# Trigger data being sent
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# One log message, with a single trailing newline
logs = server.data.decode("utf8").splitlines()
@@ -89,8 +90,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# Only the 7 infos made it through, the debugs were elided
logs = server.data.splitlines()
@@ -123,8 +124,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# The 10 warnings made it through, the debugs and infos were elided
logs = server.data.splitlines()
@@ -148,8 +149,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# The first five and last five warnings made it through, the debugs and
# infos were elided
diff --git a/tests/rest/media/v1/__init__.py b/tests/media/__init__.py
index b1ee10cfcc..68910cbf5b 100644
--- a/tests/rest/media/v1/__init__.py
+++ b/tests/media/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/rest/media/v1/test_base.py b/tests/media/test_base.py
index c73179151a..66498c744d 100644
--- a/tests/rest/media/v1/test_base.py
+++ b/tests/media/test_base.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.media._base import get_filename_from_headers
from tests import unittest
diff --git a/tests/rest/media/v1/test_filepath.py b/tests/media/test_filepath.py
index 43e6f0f70a..95e3b83d5a 100644
--- a/tests/rest/media/v1/test_filepath.py
+++ b/tests/media/test_filepath.py
@@ -15,7 +15,7 @@ import inspect
import os
from typing import Iterable
-from synapse.rest.media.v1.filepath import MediaFilePaths, _wrap_with_jail_check
+from synapse.media.filepath import MediaFilePaths, _wrap_with_jail_check
from tests import unittest
diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/media/test_html_preview.py
index 1062081a06..e7da75db3e 100644
--- a/tests/rest/media/v1/test_html_preview.py
+++ b/tests/media/test_html_preview.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.rest.media.v1.preview_html import (
+from synapse.media.preview_html import (
_get_html_media_encodings,
decode_body,
parse_html_to_open_graph,
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/media/test_media_storage.py
index 17a3b06a8e..870047d0f2 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -34,13 +34,13 @@ from synapse.events import EventBase
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
+from synapse.media._base import FileInfo
+from synapse.media.filepath import MediaFilePaths
+from synapse.media.media_storage import MediaStorage, ReadableFileWrapper
+from synapse.media.storage_provider import FileStorageProviderBackend
from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login
-from synapse.rest.media.v1._base import FileInfo
-from synapse.rest.media.v1.filepath import MediaFilePaths
-from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
-from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias
from synapse.util import Clock
@@ -52,7 +52,6 @@ from tests.utils import default_config
class MediaStorageTests(unittest.HomeserverTestCase):
-
needs_threadpool = True
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -207,7 +206,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
user_id = "@test:user"
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
self.fetches: List[
Tuple[
"Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]",
@@ -255,7 +253,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
config["max_image_pixels"] = 2000000
provider_config = {
- "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+ "module": "synapse.media.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
@@ -268,7 +266,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
media_resource = hs.get_media_repository_resource()
self.download_resource = media_resource.children[b"download"]
self.thumbnail_resource = media_resource.children[b"thumbnail"]
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/media/test_oembed.py
index 3f7f1dbab9..c8bf8421da 100644
--- a/tests/rest/media/v1/test_oembed.py
+++ b/tests/media/test_oembed.py
@@ -18,7 +18,7 @@ from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
-from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult
+from synapse.media.oembed import OEmbedProvider, OEmbedResult
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 2fe879df47..af0341808d 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -33,7 +33,6 @@ from tests.unittest import HomeserverTestCase, override_config
class TestBulkPushRuleEvaluator(HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -131,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
# Create a new message event, and try to evaluate it under the dodgy
# power level event.
- event, context, _ = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -146,6 +145,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
prev_event_ids=[pl_event_id],
)
)
+ context = self.get_success(unpersisted_context.persist(event))
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# should not raise
@@ -171,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
"""Ensure that push rules are not calculated when disabled in the config"""
# Create a new message event which should cause a notification.
- event, context, _ = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -185,6 +185,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
},
)
)
+ context = self.get_success(unpersisted_context.persist(event))
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# Mock the method which calculates push rules -- we do this instead of
@@ -201,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
) -> bool:
"""Returns true iff the `mentions` trigger an event push action."""
# Create a new message event which should cause a notification.
- event, context, _ = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -212,7 +213,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
},
)
)
-
+ context = self.get_success(unpersisted_context.persist(event))
# Execute the push rule machinery.
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
@@ -227,7 +228,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
)
return len(result) > 0
- @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3952_intentional_mentions": True,
+ "msc3966_exact_event_property_contains": True,
+ }
+ }
+ )
def test_user_mentions(self) -> None:
"""Test the behavior of an event which includes invalid user mentions."""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
@@ -323,7 +331,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
)
)
- @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3952_intentional_mentions": True,
+ "msc3966_exact_event_property_contains": True,
+ }
+ }
+ )
def test_room_mentions(self) -> None:
"""Test the behavior of an event which includes invalid room mentions."""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
@@ -377,7 +392,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# Create & persist an event to use as the parent of the relation.
- event, context, _ = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -391,6 +406,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
},
)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
self.event_creation_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)]
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 7563f33fdc..4ea5472eb4 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -39,7 +39,6 @@ class _User:
class EmailPusherTests(HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -48,7 +47,6 @@ class EmailPusherTests(HomeserverTestCase):
hijack_auth = False
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["email"] = {
"enable_notifs": True,
@@ -371,10 +369,8 @@ class EmailPusherTests(HomeserverTestCase):
# disassociate the user's email address
self.get_success(
- self.auth_handler.delete_threepid(
- user_id=self.user_id,
- medium="email",
- address="a@example.com",
+ self.auth_handler.delete_local_threepid(
+ user_id=self.user_id, medium="email", address="a@example.com"
)
)
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 0554d247bc..ff5a9a66f5 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Set, Union, cast
+from typing import Any, Dict, List, Optional, Union, cast
import frozendict
@@ -54,7 +54,7 @@ class FlattenDictTestCase(unittest.TestCase):
self.assertEqual({"m.foo.b\\ar": "abc"}, _flatten_dict(input))
self.assertEqual(
{"m\\.foo.b\\\\ar": "abc"},
- _flatten_dict(input, msc3783_escape_event_match_key=True),
+ _flatten_dict(input, msc3873_escape_event_match_key=True),
)
def test_non_string(self) -> None:
@@ -147,9 +147,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
self,
content: JsonMapping,
*,
- has_mentions: bool = False,
- user_mentions: Optional[Set[str]] = None,
- room_mention: bool = False,
related_events: Optional[JsonDict] = None,
) -> PushRuleEvaluator:
event = FrozenEvent(
@@ -168,9 +165,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluator(
_flatten_dict(event),
- has_mentions,
- user_mentions or set(),
- room_mention,
+ False,
room_member_count,
sender_power_level,
cast(Dict[str, int], power_levels.get("notifications", {})),
@@ -178,7 +173,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
related_event_match_enabled=True,
room_version_feature_flags=event.room_version.msc3931_push_features,
msc3931_enabled=True,
- msc3758_exact_event_match=True,
msc3966_exact_event_property_contains=True,
)
@@ -206,53 +200,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
# A display name with spaces should work fine.
self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
- def test_user_mentions(self) -> None:
- """Check for user mentions."""
- condition = {"kind": "org.matrix.msc3952.is_user_mention"}
-
- # No mentions shouldn't match.
- evaluator = self._get_evaluator({}, has_mentions=True)
- self.assertFalse(evaluator.matches(condition, "@user:test", None))
-
- # An empty set shouldn't match
- evaluator = self._get_evaluator({}, has_mentions=True, user_mentions=set())
- self.assertFalse(evaluator.matches(condition, "@user:test", None))
-
- # The Matrix ID appearing anywhere in the mentions list should match
- evaluator = self._get_evaluator(
- {}, has_mentions=True, user_mentions={"@user:test"}
- )
- self.assertTrue(evaluator.matches(condition, "@user:test", None))
-
- evaluator = self._get_evaluator(
- {}, has_mentions=True, user_mentions={"@another:test", "@user:test"}
- )
- self.assertTrue(evaluator.matches(condition, "@user:test", None))
-
- # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
- # since the BulkPushRuleEvaluator is what handles data sanitisation.
-
- def test_room_mentions(self) -> None:
- """Check for room mentions."""
- condition = {"kind": "org.matrix.msc3952.is_room_mention"}
-
- # No room mention shouldn't match.
- evaluator = self._get_evaluator({}, has_mentions=True)
- self.assertFalse(evaluator.matches(condition, None, None))
-
- # Room mention should match.
- evaluator = self._get_evaluator({}, has_mentions=True, room_mention=True)
- self.assertTrue(evaluator.matches(condition, None, None))
-
- # A room mention and user mention is valid.
- evaluator = self._get_evaluator(
- {}, has_mentions=True, user_mentions={"@another:test"}, room_mention=True
- )
- self.assertTrue(evaluator.matches(condition, None, None))
-
- # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
- # since the BulkPushRuleEvaluator is what handles data sanitisation.
-
def _assert_matches(
self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None
) -> None:
@@ -424,12 +371,39 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"pattern should not match before a newline",
)
+ def test_event_match_pattern(self) -> None:
+ """Check that event_match conditions do not use a "pattern_type" from user data."""
+
+ # The pattern_type should not be deserialized into anything valid.
+ condition = {
+ "kind": "event_match",
+ "key": "content.value",
+ "pattern_type": "user_id",
+ }
+ self._assert_not_matches(
+ condition,
+ {"value": "@user:test"},
+ "should not be possible to pass a pattern_type in",
+ )
+
+ # This is an internal-only condition which shouldn't get deserialized.
+ condition = {
+ "kind": "event_match_type",
+ "key": "content.value",
+ "pattern_type": "user_id",
+ }
+ self._assert_not_matches(
+ condition,
+ {"value": "@user:test"},
+ "should not be possible to pass a pattern_type in",
+ )
+
def test_exact_event_match_string(self) -> None:
"""Check that exact_event_match conditions work as expected for strings."""
# Test against a string value.
condition = {
- "kind": "com.beeper.msc3758.exact_event_match",
+ "kind": "event_property_is",
"key": "content.value",
"value": "foobaz",
}
@@ -467,11 +441,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"""Check that exact_event_match conditions work as expected for booleans."""
# Test against a True boolean value.
- condition = {
- "kind": "com.beeper.msc3758.exact_event_match",
- "key": "content.value",
- "value": True,
- }
+ condition = {"kind": "event_property_is", "key": "content.value", "value": True}
self._assert_matches(
condition,
{"value": True},
@@ -491,7 +461,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
# Test against a False boolean value.
condition = {
- "kind": "com.beeper.msc3758.exact_event_match",
+ "kind": "event_property_is",
"key": "content.value",
"value": False,
}
@@ -516,11 +486,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
def test_exact_event_match_null(self) -> None:
"""Check that exact_event_match conditions work as expected for null."""
- condition = {
- "kind": "com.beeper.msc3758.exact_event_match",
- "key": "content.value",
- "value": None,
- }
+ condition = {"kind": "event_property_is", "key": "content.value", "value": None}
self._assert_matches(
condition,
{"value": None},
@@ -536,11 +502,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
def test_exact_event_match_integer(self) -> None:
"""Check that exact_event_match conditions work as expected for integers."""
- condition = {
- "kind": "com.beeper.msc3758.exact_event_match",
- "key": "content.value",
- "value": 1,
- }
+ condition = {"kind": "event_property_is", "key": "content.value", "value": 1}
self._assert_matches(
condition,
{"value": 1},
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index ddca9d696c..57c781a0c3 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -64,7 +64,6 @@ def patch__eq__(cls: object) -> Callable[[], None]:
class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
-
STORE_TYPE = EventsWorkerStore
def setUp(self) -> None:
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index bf927beb6a..bab77b2df7 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -141,3 +141,64 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
self.get_success(ctx_worker1.__aexit__(None, None, None))
self.assertTrue(d.called)
+
+ def test_wait_for_stream_position_rdata(self) -> None:
+ """Check that wait for stream position correctly waits for an update
+ from the correct instance, when RDATA is sent.
+ """
+ store = self.hs.get_datastores().main
+ cmd_handler = self.hs.get_replication_command_handler()
+ data_handler = self.hs.get_replication_data_handler()
+
+ worker1 = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ extra_config={
+ "worker_name": "worker1",
+ "run_background_tasks_on": "worker1",
+ "redis": {"enabled": True},
+ },
+ )
+
+ cache_id_gen = worker1.get_datastores().main._cache_id_gen
+ assert cache_id_gen is not None
+
+ self.replicate()
+
+ # First, make sure the master knows that `worker1` exists.
+ initial_token = cache_id_gen.get_current_token()
+ cmd_handler.send_command(
+ PositionCommand("caches", "worker1", initial_token, initial_token)
+ )
+ self.replicate()
+
+ # `wait_for_stream_position` should only return once master receives a
+ # notification that `next_token2` has persisted.
+ ctx_worker1 = cache_id_gen.get_next_mult(2)
+ next_token1, next_token2 = self.get_success(ctx_worker1.__aenter__())
+
+ d = defer.ensureDeferred(
+ data_handler.wait_for_stream_position("worker1", "caches", next_token2)
+ )
+ self.assertFalse(d.called)
+
+ # Insert an entry into the cache stream with token `next_token1`, but
+ # not `next_token2`.
+ self.get_success(
+ store.db_pool.simple_insert(
+ table="cache_invalidation_stream_by_instance",
+ values={
+ "stream_id": next_token1,
+ "instance_name": "worker1",
+ "cache_func": "foo",
+ "keys": [],
+ "invalidation_ts": 0,
+ },
+ )
+ )
+
+ # Finish the context manager, triggering the data to be sent to master.
+ self.get_success(ctx_worker1.__aexit__(None, None, None))
+
+ # Master should get told about `next_token2`, so the deferred should
+ # resolve.
+ self.assertTrue(d.called)
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index 03f2112b07..aaa488bced 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -28,7 +28,6 @@ from tests import unittest
class DeviceRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -291,7 +290,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
class DevicesRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -415,7 +413,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 233eba3516..f189b07769 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -78,7 +78,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
Try to get an event report without authentication.
"""
- channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, {})
self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -473,7 +473,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"""
Try to get event report without authentication.
"""
- channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, {})
self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -599,3 +599,142 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
self.assertIn("room_id", content["event_json"])
self.assertIn("sender", content["event_json"])
self.assertIn("content", content["event_json"])
+
+
+class DeleteEventReportTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self._store = hs.get_datastores().main
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ # create report
+ event_id = self.get_success(
+ self._store.add_event_report(
+ "room_id",
+ "event_id",
+ self.other_user,
+ "this makes me sad",
+ {},
+ self.clock.time_msec(),
+ )
+ )
+
+ self.url = f"/_synapse/admin/v1/event_reports/{event_id}"
+
+ def test_no_auth(self) -> None:
+ """
+ Try to delete event report without authentication.
+ """
+ channel = self.make_request("DELETE", self.url)
+
+ self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self) -> None:
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ access_token=self.other_user_tok,
+ )
+
+ self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_delete_success(self) -> None:
+ """
+ Testing delete a report.
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual({}, channel.json_body)
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ # check that report was deleted
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_invalid_report_id(self) -> None:
+ """
+ Testing that an invalid `report_id` returns a 400.
+ """
+
+ # `report_id` is negative
+ channel = self.make_request(
+ "DELETE",
+ "/_synapse/admin/v1/event_reports/-123",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "The report_id parameter must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ # `report_id` is a non-numerical string
+ channel = self.make_request(
+ "DELETE",
+ "/_synapse/admin/v1/event_reports/abcdef",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "The report_id parameter must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ # `report_id` is undefined
+ channel = self.make_request(
+ "DELETE",
+ "/_synapse/admin/v1/event_reports/",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "The report_id parameter must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ def test_report_id_not_found(self) -> None:
+ """
+ Testing that a not existing `report_id` returns a 404.
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ "/_synapse/admin/v1/event_reports/123",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+ self.assertEqual("Event report not found", channel.json_body["error"])
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index db77a45ae3..6d04911d67 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -20,8 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes
+from synapse.media.filepath import MediaFilePaths
from synapse.rest.client import login, profile, room
-from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.server import HomeServer
from synapse.util import Clock
@@ -34,7 +34,6 @@ INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds
class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
@@ -196,7 +195,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
@@ -594,7 +592,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
@@ -724,7 +721,6 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
@@ -821,7 +817,6 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 453a6e979c..9dbb778679 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1990,7 +1990,6 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
room.register_servlets,
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index f71ff46d87..28b999573e 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -28,7 +28,6 @@ from tests.unittest import override_config
class ServerNoticeTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 87efab59bb..c278f6bbad 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -28,8 +28,8 @@ import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
+from synapse.media.filepath import MediaFilePaths
from synapse.rest.client import devices, login, logout, profile, register, room, sync
-from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index e2ee1a1766..2b05dffc7d 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -40,7 +40,6 @@ from tests.unittest import override_config
class PasswordResetTestCase(unittest.HomeserverTestCase):
-
servlets = [
account.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -408,7 +407,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
class DeactivateTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -492,7 +490,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
class WhoamiTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -567,7 +564,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
account.register_servlets,
login.register_servlets,
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 208ec44829..0d8fe77b88 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -34,7 +34,7 @@ from synapse.util import Clock
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
-from tests.server import FakeChannel, make_request
+from tests.server import FakeChannel
from tests.unittest import override_config, skip_unless
@@ -43,13 +43,15 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
super().__init__(hs)
self.recaptcha_attempts: List[Tuple[dict, str]] = []
+ def is_enabled(self) -> bool:
+ return True
+
def check_auth(self, authdict: dict, clientip: str) -> Any:
self.recaptcha_attempts.append((authdict, clientip))
return succeed(True)
class FallbackAuthTests(unittest.HomeserverTestCase):
-
servlets = [
auth.register_servlets,
register.register_servlets,
@@ -57,7 +59,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
hijack_auth = False
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["enable_registration_captcha"] = True
@@ -1319,16 +1320,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
- # Now try to exchange the login token
- channel = make_request(
- self.hs.get_reactor(),
- self.site,
- "POST",
- "/login",
- content={"type": "m.login.token", "token": login_token},
- )
- # It should have failed
- self.assertEqual(channel.code, 403)
+ # Now try to exchange the login token, it should fail.
+ self.helper.login_via_token(login_token, 403)
@override_config(
{
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index d1751e1557..c16e8d43f4 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -26,7 +26,6 @@ from tests.unittest import override_config
class CapabilitiesTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
capabilities.register_servlets,
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index b1ca81a911..bb845179d3 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -38,7 +38,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
hijack_auth = False
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["form_secret"] = "123abc"
diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index 7a88aa2cda..6490e883bf 100644
--- a/tests/rest/client/test_directory.py
+++ b/tests/rest/client/test_directory.py
@@ -28,7 +28,6 @@ from tests.unittest import override_config
class DirectoryTestCase(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
directory.register_servlets,
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index 9fa1f82dfe..f31ebc8021 100644
--- a/tests/rest/client/test_ephemeral_message.py
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -26,7 +26,6 @@ from tests import unittest
class EphemeralMessageTestCase(unittest.HomeserverTestCase):
-
user_id = "@user:test"
servlets = [
diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index a9b7db9db2..54df2a252c 100644
--- a/tests/rest/client/test_events.py
+++ b/tests/rest/client/test_events.py
@@ -38,7 +38,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["enable_registration_captcha"] = False
config["enable_registration"] = True
@@ -51,7 +50,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
# register an account
self.user_id = self.register_user("sid1", "pass")
self.token = self.login(self.user_id, "pass")
@@ -142,7 +140,6 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
# register an account
self.user_id = self.register_user("sid1", "pass")
self.token = self.login(self.user_id, "pass")
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 830762fd53..91678abf13 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -25,7 +25,6 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.HomeserverTestCase):
-
user_id = "@apple:test"
hijack_auth = True
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index 741fecea77..8ee5489057 100644
--- a/tests/rest/client/test_keys.py
+++ b/tests/rest/client/test_keys.py
@@ -14,12 +14,21 @@
from http import HTTPStatus
+from signedjson.key import (
+ encode_verify_key_base64,
+ generate_signing_key,
+ get_verify_key,
+)
+from signedjson.sign import sign_json
+
from synapse.api.errors import Codes
from synapse.rest import admin
from synapse.rest.client import keys, login
+from synapse.types import JsonDict
from tests import unittest
from tests.http.server._base import make_request_with_cancellation_test
+from tests.unittest import override_config
class KeyQueryTestCase(unittest.HomeserverTestCase):
@@ -118,3 +127,135 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertIn(bob, channel.json_body["device_keys"])
+
+ def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
+ # We only generate a master key to simplify the test.
+ master_signing_key = generate_signing_key(device_id)
+ master_verify_key = encode_verify_key_base64(get_verify_key(master_signing_key))
+
+ return {
+ "master_key": sign_json(
+ {
+ "user_id": user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + master_verify_key: master_verify_key},
+ },
+ user_id,
+ master_signing_key,
+ ),
+ }
+
+ def test_device_signing_with_uia(self) -> None:
+ """Device signing key upload requires UIA."""
+ password = "wonderland"
+ device_id = "ABCDEFGHI"
+ alice_id = self.register_user("alice", password)
+ alice_token = self.login("alice", password, device_id=device_id)
+
+ content = self.make_device_keys(alice_id, device_id)
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ content,
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # add UI auth
+ content["auth"] = {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": alice_id},
+ "password": password,
+ "session": session,
+ }
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ content,
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ @override_config({"ui_auth": {"session_timeout": "15m"}})
+ def test_device_signing_with_uia_session_timeout(self) -> None:
+ """Device signing key upload requires UIA buy passes with grace period."""
+ password = "wonderland"
+ device_id = "ABCDEFGHI"
+ alice_id = self.register_user("alice", password)
+ alice_token = self.login("alice", password, device_id=device_id)
+
+ content = self.make_device_keys(alice_id, device_id)
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ content,
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ @override_config(
+ {
+ "experimental_features": {"msc3967_enabled": True},
+ "ui_auth": {"session_timeout": "15s"},
+ }
+ )
+ def test_device_signing_with_msc3967(self) -> None:
+ """Device signing key follows MSC3967 behaviour when enabled."""
+ password = "wonderland"
+ device_id = "ABCDEFGHI"
+ alice_id = self.register_user("alice", password)
+ alice_token = self.login("alice", password, device_id=device_id)
+
+ keys1 = self.make_device_keys(alice_id, device_id)
+
+ # Initial request should succeed as no existing keys are present.
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ keys1,
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ keys2 = self.make_device_keys(alice_id, device_id)
+
+ # Subsequent request should require UIA as keys already exist even though session_timeout is set.
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ keys2,
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # add UI auth
+ keys2["auth"] = {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": alice_id},
+ "password": password,
+ "session": session,
+ }
+
+ # Request should complete
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ keys2,
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index ff5baa9f0a..62acf4f44e 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -89,7 +89,6 @@ ADDITIONAL_LOGIN_FLOWS = [
class LoginRestServletTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -737,7 +736,6 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
class CASTestCase(unittest.HomeserverTestCase):
-
servlets = [
login.register_servlets,
]
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index 6aedc1a11c..b8187db982 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -26,7 +26,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
-
servlets = [
login.register_servlets,
admin.register_servlets,
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index 67e16880e6..dcbb125a3b 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -35,7 +35,6 @@ class PresenceTestCase(unittest.HomeserverTestCase):
servlets = [presence.register_servlets]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
self.presence_handler = Mock(spec=PresenceHandler)
self.presence_handler.set_state.return_value = make_awaitable(None)
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 8de5a342ae..27c93ad761 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -30,7 +30,6 @@ from tests import unittest
class ProfileTestCase(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -324,7 +323,6 @@ class ProfileTestCase(unittest.HomeserverTestCase):
class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -404,7 +402,6 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 4c561f9525..b228dba861 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -40,7 +40,6 @@ from tests.unittest import override_config
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
-
servlets = [
login.register_servlets,
register.register_servlets,
@@ -797,7 +796,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
class AccountValidityTestCase(unittest.HomeserverTestCase):
-
servlets = [
register.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -913,7 +911,6 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
-
servlets = [
register.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -1132,7 +1129,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
-
servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index c8a6911d5e..fbbbcb23f1 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -30,7 +30,6 @@ from tests import unittest
from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_event
-from tests.unittest import override_config
class BaseRelationsTestCase(unittest.HomeserverTestCase):
@@ -403,7 +402,7 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_edit(self) -> None:
"""Test that a simple edit works."""
-
+ orig_body = {"body": "Hi!", "msgtype": "m.text"}
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
edit_event_content = {
"msgtype": "m.text",
@@ -424,9 +423,7 @@ class RelationsTestCase(BaseRelationsTestCase):
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(
- channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"}
- )
+ self.assertEqual(channel.json_body["content"], orig_body)
self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content)
# Request the room messages.
@@ -443,7 +440,7 @@ class RelationsTestCase(BaseRelationsTestCase):
)
# Request the room context.
- # /context should return the edited event.
+ # /context should return the event.
channel = self.make_request(
"GET",
f"/rooms/{self.room}/context/{self.parent_id}",
@@ -453,7 +450,7 @@ class RelationsTestCase(BaseRelationsTestCase):
self._assert_edit_bundle(
channel.json_body["event"], edit_event_id, edit_event_content
)
- self.assertEqual(channel.json_body["event"]["content"], new_body)
+ self.assertEqual(channel.json_body["event"]["content"], orig_body)
# Request sync, but limit the timeline so it becomes limited (and includes
# bundled aggregations).
@@ -491,45 +488,11 @@ class RelationsTestCase(BaseRelationsTestCase):
edit_event_content,
)
- @override_config({"experimental_features": {"msc3925_inhibit_edit": True}})
- def test_edit_inhibit_replace(self) -> None:
- """
- If msc3925_inhibit_edit is enabled, then the original event should not be
- replaced.
- """
-
- new_body = {"msgtype": "m.text", "body": "I've been edited!"}
- edit_event_content = {
- "msgtype": "m.text",
- "body": "foo",
- "m.new_content": new_body,
- }
- channel = self._send_relation(
- RelationTypes.REPLACE,
- "m.room.message",
- content=edit_event_content,
- )
- edit_event_id = channel.json_body["event_id"]
-
- # /context should return the *original* event.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/context/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(
- channel.json_body["event"]["content"], {"body": "Hi!", "msgtype": "m.text"}
- )
- self._assert_edit_bundle(
- channel.json_body["event"], edit_event_id, edit_event_content
- )
-
def test_multi_edit(self) -> None:
"""Test that multiple edits, including attempts by people who
shouldn't be allowed, are correctly handled.
"""
-
+ orig_body = orig_body = {"body": "Hi!", "msgtype": "m.text"}
self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
@@ -570,7 +533,7 @@ class RelationsTestCase(BaseRelationsTestCase):
)
self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(channel.json_body["event"]["content"], new_body)
+ self.assertEqual(channel.json_body["event"]["content"], orig_body)
self._assert_edit_bundle(
channel.json_body["event"], edit_event_id, edit_event_content
)
@@ -642,6 +605,7 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_edit_edit(self) -> None:
"""Test that an edit cannot be edited."""
+ orig_body = {"body": "Hi!", "msgtype": "m.text"}
new_body = {"msgtype": "m.text", "body": "Initial edit"}
edit_event_content = {
"msgtype": "m.text",
@@ -675,14 +639,12 @@ class RelationsTestCase(BaseRelationsTestCase):
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(
- channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"}
- )
+ self.assertEqual(channel.json_body["content"], orig_body)
# The relations information should not include the edit to the edit.
self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content)
- # /context should return the event updated for the *first* edit
+ # /context should return the bundled edit for the *first* edit
# (The edit to the edit should be ignored.)
channel = self.make_request(
"GET",
@@ -690,7 +652,7 @@ class RelationsTestCase(BaseRelationsTestCase):
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(channel.json_body["event"]["content"], new_body)
+ self.assertEqual(channel.json_body["event"]["content"], orig_body)
self._assert_edit_bundle(
channel.json_body["event"], edit_event_id, edit_event_content
)
@@ -1080,48 +1042,6 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
]
assert_bundle(self._find_event_in_chunk(chunk))
- def test_annotation(self) -> None:
- """
- Test that annotations get correctly bundled.
- """
- # Setup by sending a variety of relations.
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
- )
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-
- def assert_annotations(bundled_aggregations: JsonDict) -> None:
- self.assertEqual(
- {
- "chunk": [
- {"type": "m.reaction", "key": "a", "count": 2},
- {"type": "m.reaction", "key": "b", "count": 1},
- ]
- },
- bundled_aggregations,
- )
-
- self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
-
- def test_annotation_to_annotation(self) -> None:
- """Any relation to an annotation should be ignored."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- event_id = channel.json_body["event_id"]
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=event_id
- )
-
- # Fetch the initial annotation event to see if it has bundled aggregations.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/v3/rooms/{self.room}/event/{event_id}",
- access_token=self.user_token,
- )
- self.assertEquals(200, channel.code, channel.json_body)
- # The first annotationt should not have any bundled aggregations.
- self.assertNotIn("m.relations", channel.json_body["unsigned"])
-
def test_reference(self) -> None:
"""
Test that references get correctly bundled.
@@ -1138,7 +1058,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
def test_thread(self) -> None:
"""
@@ -1183,7 +1103,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated.
- self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7)
+ self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 6)
# The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated.
#
@@ -1208,9 +1128,10 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
thread_2 = channel.json_body["event_id"]
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2
+ channel = self._send_relation(
+ RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_2
)
+ reference_event_id = channel.json_body["event_id"]
def assert_thread(bundled_aggregations: JsonDict) -> None:
self.assertEqual(2, bundled_aggregations.get("count"))
@@ -1235,17 +1156,15 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
self.assert_dict(
{
"m.relations": {
- RelationTypes.ANNOTATION: {
- "chunk": [
- {"type": "m.reaction", "key": "a", "count": 1},
- ]
+ RelationTypes.REFERENCE: {
+ "chunk": [{"event_id": reference_event_id}]
},
}
},
bundled_aggregations["latest_event"].get("unsigned"),
)
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7)
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 6)
def test_nested_thread(self) -> None:
"""
@@ -1330,7 +1249,6 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
thread_summary = relations_dict[RelationTypes.THREAD]
self.assertIn("latest_event", thread_summary)
latest_event_in_thread = thread_summary["latest_event"]
- self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!")
# The latest event in the thread should have the edit appear under the
# bundled aggregations.
self.assertDictContainsSubset(
@@ -1363,10 +1281,11 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
thread_id = channel.json_body["event_id"]
- # Annotate the thread.
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
+ # Make a reference to the thread.
+ channel = self._send_relation(
+ RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_id
)
+ reference_event_id = channel.json_body["event_id"]
channel = self.make_request(
"GET",
@@ -1377,9 +1296,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
self.assertEqual(
channel.json_body["unsigned"].get("m.relations"),
{
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
+ RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]},
},
)
@@ -1396,9 +1313,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
self.assertEqual(
thread_message["unsigned"].get("m.relations"),
{
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
+ RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]},
},
)
@@ -1410,7 +1325,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
Note that the spec allows for a server to return additional fields beyond
what is specified.
"""
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test")
+ reference_event_id = channel.json_body["event_id"]
# Note that the sync filter does not include "unsigned" as a field.
filter = urllib.parse.quote_plus(
@@ -1428,7 +1344,12 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# Ensure there's bundled aggregations on it.
self.assertIn("unsigned", parent_event)
- self.assertIn("m.relations", parent_event["unsigned"])
+ self.assertEqual(
+ parent_event["unsigned"].get("m.relations"),
+ {
+ RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]},
+ },
+ )
class RelationIgnoredUserTestCase(BaseRelationsTestCase):
@@ -1475,53 +1396,8 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase):
return before_aggregations[relation_type], after_aggregations[relation_type]
- def test_annotation(self) -> None:
- """Annotations should ignore"""
- # Send 2 from us, 2 from the to be ignored user.
- allowed_event_ids = []
- ignored_event_ids = []
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
- allowed_event_ids.append(channel.json_body["event_id"])
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b")
- allowed_event_ids.append(channel.json_body["event_id"])
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key="a",
- access_token=self.user2_token,
- )
- ignored_event_ids.append(channel.json_body["event_id"])
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key="c",
- access_token=self.user2_token,
- )
- ignored_event_ids.append(channel.json_body["event_id"])
-
- before_aggregations, after_aggregations = self._test_ignored_user(
- RelationTypes.ANNOTATION, allowed_event_ids, ignored_event_ids
- )
-
- self.assertCountEqual(
- before_aggregations["chunk"],
- [
- {"type": "m.reaction", "key": "a", "count": 2},
- {"type": "m.reaction", "key": "b", "count": 1},
- {"type": "m.reaction", "key": "c", "count": 1},
- ],
- )
-
- self.assertCountEqual(
- after_aggregations["chunk"],
- [
- {"type": "m.reaction", "key": "a", "count": 1},
- {"type": "m.reaction", "key": "b", "count": 1},
- ],
- )
-
def test_reference(self) -> None:
- """Annotations should ignore"""
+ """Aggregations should exclude reference relations from ignored users"""
channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
allowed_event_ids = [channel.json_body["event_id"]]
@@ -1544,7 +1420,7 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase):
)
def test_thread(self) -> None:
- """Annotations should ignore"""
+ """Aggregations should exclude thread releations from ignored users"""
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
allowed_event_ids = [channel.json_body["event_id"]]
@@ -1618,43 +1494,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
for t in threads
]
- def test_redact_relation_annotation(self) -> None:
- """
- Test that annotations of an event are properly handled after the
- annotation is redacted.
-
- The redacted relation should not be included in bundled aggregations or
- the response to relations.
- """
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- to_redact_event_id = channel.json_body["event_id"]
-
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
- )
- unredacted_event_id = channel.json_body["event_id"]
-
- # Both relations should exist.
- event_ids = self._get_related_events()
- relations = self._get_bundled_aggregations()
- self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
- self.assertEquals(
- relations["m.annotation"],
- {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
- )
-
- # Redact one of the reactions.
- self._redact(to_redact_event_id)
-
- # The unredacted relation should still exist.
- event_ids = self._get_related_events()
- relations = self._get_bundled_aggregations()
- self.assertEquals(event_ids, [unredacted_event_id])
- self.assertEquals(
- relations["m.annotation"],
- {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
- )
-
def test_redact_relation_thread(self) -> None:
"""
Test that thread replies are properly handled after the thread reply redacted.
@@ -1775,14 +1614,14 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
is redacted.
"""
# Add a relation
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
+ channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test")
related_event_id = channel.json_body["event_id"]
# The relations should exist.
event_ids = self._get_related_events()
relations = self._get_bundled_aggregations()
self.assertEqual(len(event_ids), 1)
- self.assertIn(RelationTypes.ANNOTATION, relations)
+ self.assertIn(RelationTypes.REFERENCE, relations)
# Redact the original event.
self._redact(self.parent_id)
@@ -1792,8 +1631,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
relations = self._get_bundled_aggregations()
self.assertEquals(event_ids, [related_event_id])
self.assertEquals(
- relations["m.annotation"],
- {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
+ relations[RelationTypes.REFERENCE],
+ {"chunk": [{"event_id": related_event_id}]},
)
def test_redact_parent_thread(self) -> None:
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
index c0eb5d01a6..8dbd64be55 100644
--- a/tests/rest/client/test_rendezvous.py
+++ b/tests/rest/client/test_rendezvous.py
@@ -25,7 +25,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous"
class RendezvousServletTestCase(unittest.HomeserverTestCase):
-
servlets = [
rendezvous.register_servlets,
]
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index cfad182b2f..a4900703c4 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -65,7 +65,6 @@ class RoomBase(unittest.HomeserverTestCase):
servlets = [room.register_servlets, room.register_deprecated_servlets]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
self.hs = self.setup_test_homeserver(
"red",
federation_http_client=None,
@@ -92,7 +91,6 @@ class RoomPermissionsTestCase(RoomBase):
rmcreator_id = "@notme:red"
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
self.helper.auth_user_id = self.rmcreator_id
# create some rooms under the name rmcreator_id
self.uncreated_rmid = "!aa:test"
@@ -715,7 +713,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(33, channel.resource_usage.db_txn_count)
+ self.assertEqual(30, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@@ -728,7 +726,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(36, channel.resource_usage.db_txn_count)
+ self.assertEqual(32, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
@@ -1127,7 +1125,6 @@ class RoomInviteRatelimitTestCase(RoomBase):
class RoomJoinTestCase(RoomBase):
-
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -2102,7 +2099,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
hijack_auth = False
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
# Register the user who does the searching
self.user_id2 = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
@@ -2195,7 +2191,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -2203,7 +2198,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
self.url = b"/_matrix/client/r0/publicRooms"
config = self.default_config()
@@ -2225,7 +2219,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -2233,7 +2226,6 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["allow_public_rooms_without_auth"] = True
self.hs = self.setup_test_homeserver(config=config)
@@ -2414,7 +2406,6 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -2983,7 +2974,6 @@ class RelationsTestCase(PaginationTestCase):
class ContextTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -3359,7 +3349,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
class ThreepidInviteTestCase(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -3438,7 +3427,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
"""
Test allowing/blocking threepid invites with a spam-check module.
- In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`."""
+ In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.
+ """
# Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index b9047194dd..9c876c7a32 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -41,7 +41,6 @@ from tests.server import TimedOutException
class FilterTestCase(unittest.HomeserverTestCase):
-
user_id = "@apple:test"
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -191,7 +190,6 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
class SyncTypingTests(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -892,7 +890,6 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
class ExcludeRoomTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 3277a116e8..7245830b01 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -137,6 +137,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
can be sent.
"""
+
# patch the rules module with a Mock which will return False for some event
# types
async def check(
@@ -243,6 +244,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def test_modify_event(self) -> None:
"""The module can return a modified version of the event"""
+
# first patch the event checker so that it will modify the event
async def check(
ev: EventBase, state: StateMap[EventBase]
@@ -315,6 +317,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def test_message_edit(self) -> None:
"""Ensure that the module doesn't cause issues with edited messages."""
+
# first patch the event checker so that it will modify the event
async def check(
ev: EventBase, state: StateMap[EventBase]
@@ -465,7 +468,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
async def test_fn(
event: EventBase, state_events: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
- if event.is_state and event.type == EventTypes.PowerLevels:
+ if event.is_state() and event.type == EventTypes.PowerLevels:
await api.create_and_send_event_into_room(
{
"room_id": event.room_id,
@@ -971,3 +974,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Check that the mock was called with the right parameters
self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+ def test_on_add_and_remove_user_third_party_identifier(self) -> None:
+ """Tests that the on_add_user_third_party_identifier and
+ on_remove_user_third_party_identifier module callbacks are called
+ just before associating and removing a 3PID to/from an account.
+ """
+ # Pretend to be a Synapse module and register both callbacks as mocks.
+ third_party_rules = self.hs.get_third_party_event_rules()
+ on_add_user_third_party_identifier_callback_mock = Mock(
+ return_value=make_awaitable(None)
+ )
+ on_remove_user_third_party_identifier_callback_mock = Mock(
+ return_value=make_awaitable(None)
+ )
+ third_party_rules._on_threepid_bind_callbacks.append(
+ on_add_user_third_party_identifier_callback_mock
+ )
+ third_party_rules._on_threepid_bind_callbacks.append(
+ on_remove_user_third_party_identifier_callback_mock
+ )
+
+ # Register an admin user.
+ self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Also register a normal user we can modify.
+ user_id = self.register_user("user", "password")
+
+ # Add a 3PID to the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [
+ {
+ "medium": "email",
+ "address": "foo@example.com",
+ },
+ ],
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the mocked add callback was called with the appropriate
+ # 3PID details.
+ self.assertEqual(channel.code, 200, channel.json_body)
+ on_add_user_third_party_identifier_callback_mock.assert_called_once()
+ args = on_add_user_third_party_identifier_callback_mock.call_args[0]
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+ # Now remove the 3PID from the user
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [],
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the mocked remove callback was called with the appropriate
+ # 3PID details.
+ self.assertEqual(channel.code, 200, channel.json_body)
+ on_remove_user_third_party_identifier_callback_mock.assert_called_once()
+ args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+ def test_on_remove_user_third_party_identifier_is_called_on_deactivate(
+ self,
+ ) -> None:
+ """Tests that the on_remove_user_third_party_identifier module callback is called
+ when a user is deactivated and their third-party ID associations are deleted.
+ """
+ # Pretend to be a Synapse module and register both callbacks as mocks.
+ third_party_rules = self.hs.get_third_party_event_rules()
+ on_remove_user_third_party_identifier_callback_mock = Mock(
+ return_value=make_awaitable(None)
+ )
+ third_party_rules._on_threepid_bind_callbacks.append(
+ on_remove_user_third_party_identifier_callback_mock
+ )
+
+ # Register an admin user.
+ self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Also register a normal user we can modify.
+ user_id = self.register_user("user", "password")
+
+ # Add a 3PID to the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [
+ {
+ "medium": "email",
+ "address": "foo@example.com",
+ },
+ ],
+ },
+ access_token=admin_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Now deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "deactivated": True,
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the mocked remove callback was called with the appropriate
+ # 3PID details.
+ self.assertEqual(channel.code, 200, channel.json_body)
+ on_remove_user_third_party_identifier_callback_mock.assert_called_once()
+ args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8d6f2b6ff9..9532e5ddc1 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -36,6 +36,7 @@ from urllib.parse import urlencode
import attr
from typing_extensions import Literal
+from twisted.test.proto_helpers import MemoryReactorClock
from twisted.web.resource import Resource
from twisted.web.server import Site
@@ -67,6 +68,7 @@ class RestHelper:
"""
hs: HomeServer
+ reactor: MemoryReactorClock
site: Site
auth_user_id: Optional[str]
@@ -142,7 +144,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
path,
@@ -216,7 +218,7 @@ class RestHelper:
data["reason"] = reason
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
path,
@@ -313,7 +315,7 @@ class RestHelper:
data.update(extra_data or {})
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"PUT",
path,
@@ -394,7 +396,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"PUT",
path,
@@ -433,7 +435,7 @@ class RestHelper:
path = path + f"?access_token={tok}"
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
path,
@@ -488,7 +490,7 @@ class RestHelper:
if body is not None:
content = json.dumps(body).encode("utf8")
- channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
+ channel = make_request(self.reactor, self.site, method, path, content)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
@@ -573,8 +575,8 @@ class RestHelper:
image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request(
- self.hs.get_reactor(),
- FakeSite(resource, self.hs.get_reactor()),
+ self.reactor,
+ FakeSite(resource, self.reactor),
"POST",
path,
content=image_data,
@@ -603,7 +605,7 @@ class RestHelper:
expect_code: The return code to expect from attempting the whoami request
"""
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
"account/whoami",
@@ -642,7 +644,7 @@ class RestHelper:
) -> Tuple[JsonDict, FakeAuthorizationGrant]:
"""Log in (as a new user) via OIDC
- Returns the result of the final token login.
+ Returns the result of the final token login and the fake authorization grant.
Requires that "oidc_config" in the homeserver config be set appropriately
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
@@ -672,10 +674,28 @@ class RestHelper:
assert m, channel.text_body
login_token = m.group(1)
- # finally, submit the matrix login token to the login API, which gives us our
- # matrix access token and device id.
+ return self.login_via_token(login_token, expected_status), grant
+
+ def login_via_token(
+ self,
+ login_token: str,
+ expected_status: int = 200,
+ ) -> JsonDict:
+ """Submit the matrix login token to the login API, which gives us our
+ matrix access token and device id.Log in (as a new user) via OIDC
+
+ Returns the result of the token login.
+
+ Requires that "oidc_config" in the homeserver config be set appropriately
+ (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+ "public_base_url".
+
+ Also requires the login servlet and the OIDC callback resource to be mounted at
+ the normal places.
+ """
+
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
"/login",
@@ -684,7 +704,7 @@ class RestHelper:
assert (
channel.code == expected_status
), f"unexpected status in response: {channel.code}"
- return channel.json_body, grant
+ return channel.json_body
def auth_via_oidc(
self,
@@ -805,7 +825,7 @@ class RestHelper:
with fake_serer.patch_homeserver(hs=self.hs):
# now hit the callback URI with the right params and a made-up code
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
callback_uri,
@@ -849,7 +869,7 @@ class RestHelper:
# is the easiest way of figuring out what the Host header ought to be set to
# to keep Synapse happy.
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
uri,
@@ -867,7 +887,7 @@ class RestHelper:
location = get_location(channel)
parts = urllib.parse.urlsplit(location)
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
urllib.parse.urlunsplit(("", "") + parts[2:]),
@@ -900,9 +920,7 @@ class RestHelper:
+ urllib.parse.urlencode({"session": ui_auth_session_id})
)
# hit the redirect url (which will issue a cookie and state)
- channel = make_request(
- self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
- )
+ channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
# that should serve a confirmation page
assert channel.code == HTTPStatus.OK, channel.text_body
channel.extract_cookies(cookies)
diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
index 23f227aed6..b59d9dfd4d 100644
--- a/tests/rest/media/test_media_retention.py
+++ b/tests/rest/media/test_media_retention.py
@@ -31,7 +31,6 @@ from tests.utils import MockClock
class MediaRetentionTestCase(unittest.HomeserverTestCase):
-
ONE_DAY_IN_MS = 24 * 60 * 60 * 1000
THIRTY_DAYS_IN_MS = 30 * ONE_DAY_IN_MS
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/test_url_preview.py
index 2c321f8d04..e91dc581c2 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/test_url_preview.py
@@ -26,8 +26,8 @@ from twisted.internet.interfaces import IAddress, IResolutionReceiver
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
from synapse.config.oembed import OEmbedEndpointConfig
-from synapse.rest.media.v1.media_repository import MediaRepositoryResource
-from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
+from synapse.rest.media.media_repository_resource import MediaRepositoryResource
+from synapse.rest.media.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -58,7 +58,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["url_preview_enabled"] = True
config["max_spider_size"] = 9999999
@@ -83,7 +82,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
config["media_store_path"] = self.media_store_path
provider_config = {
- "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+ "module": "synapse.media.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
@@ -118,7 +117,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
self.media_repo = hs.get_media_repository_resource()
self.preview_url = self.media_repo.children[b"preview_url"]
@@ -133,7 +131,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
addressTypes: Optional[Sequence[Type[IAddress]]] = None,
transportSemantics: str = "TCP",
) -> IResolutionReceiver:
-
resolution = HostResolution(hostName)
resolutionReceiver.resolutionBegan(resolution)
if hostName not in self.lookups:
@@ -660,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""If the preview image doesn't exist, ensure some data is returned."""
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
- end_content = (
+ result = (
b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>"""
)
@@ -681,8 +678,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
b'Content-Type: text/html; charset="utf8"\r\n\r\n'
)
- % (len(end_content),)
- + end_content
+ % (len(result),)
+ + result
)
self.pump()
@@ -691,6 +688,44 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# The image should not be in the result.
self.assertNotIn("og:image", channel.json_body)
+ def test_oembed_failure(self) -> None:
+ """If the autodiscovered oEmbed URL fails, ensure some data is returned."""
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ result = b"""
+ <title>oEmbed Autodiscovery Fail</title>
+ <link rel="alternate" type="application/json+oembed"
+ href="http://example.com/oembed?url=http%3A%2F%2Fmatrix.org&format=json"
+ title="matrixdotorg" />
+ """
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(result),)
+ + result
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ # The image should not be in the result.
+ self.assertEqual(channel.json_body["og:title"], "oEmbed Autodiscovery Fail")
+
def test_data_url(self) -> None:
"""
Requesting to preview a data URL is not supported.
diff --git a/tests/server.py b/tests/server.py
index 237bcad8ba..5de9722766 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -22,20 +22,25 @@ import warnings
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
+ Any,
+ Awaitable,
Callable,
Dict,
Iterable,
List,
MutableMapping,
Optional,
+ Sequence,
Tuple,
Type,
+ TypeVar,
Union,
+ cast,
)
from unittest.mock import Mock
import attr
-from typing_extensions import Deque
+from typing_extensions import Deque, ParamSpec
from zope.interface import implementer
from twisted.internet import address, threads, udp
@@ -44,8 +49,10 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
IAddress,
+ IConnector,
IConsumer,
IHostnameResolver,
+ IProducer,
IProtocol,
IPullProducer,
IPushProducer,
@@ -54,6 +61,8 @@ from twisted.internet.interfaces import (
IResolverSimple,
ITransport,
)
+from twisted.internet.protocol import ClientFactory, DatagramProtocol
+from twisted.python import threadpool
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http_headers import Headers
@@ -61,6 +70,7 @@ from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
+from synapse.config.homeserver import HomeServerConfig
from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
@@ -88,6 +98,9 @@ from tests.utils import (
logger = logging.getLogger(__name__)
+R = TypeVar("R")
+P = ParamSpec("P")
+
# the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
@@ -98,12 +111,14 @@ class TimedOutException(Exception):
"""
-@implementer(IConsumer)
+@implementer(ITransport, IPushProducer, IConsumer)
@attr.s(auto_attribs=True)
class FakeChannel:
"""
A fake Twisted Web Channel (the part that interfaces with the
wire).
+
+ See twisted.web.http.HTTPChannel.
"""
site: Union[Site, "FakeSite"]
@@ -142,7 +157,7 @@ class FakeChannel:
Raises an exception if the request has not yet completed.
"""
- if not self.is_finished:
+ if not self.is_finished():
raise Exception("Request not yet completed")
return self.result["body"].decode("utf8")
@@ -165,27 +180,36 @@ class FakeChannel:
h.addRawHeader(*i)
return h
- def writeHeaders(self, version, code, reason, headers):
+ def writeHeaders(
+ self, version: bytes, code: bytes, reason: bytes, headers: Headers
+ ) -> None:
self.result["version"] = version
self.result["code"] = code
self.result["reason"] = reason
self.result["headers"] = headers
- def write(self, content: bytes) -> None:
- assert isinstance(content, bytes), "Should be bytes! " + repr(content)
+ def write(self, data: bytes) -> None:
+ assert isinstance(data, bytes), "Should be bytes! " + repr(data)
if "body" not in self.result:
self.result["body"] = b""
- self.result["body"] += content
+ self.result["body"] += data
+
+ def writeSequence(self, data: Iterable[bytes]) -> None:
+ for x in data:
+ self.write(x)
+
+ def loseConnection(self) -> None:
+ self.unregisterProducer()
+ self.transport.loseConnection()
# Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
- def registerProducer( # type: ignore[override]
- self,
- producer: Union[IPullProducer, IPushProducer],
- streaming: bool,
- ) -> None:
- self._producer = producer
+ def registerProducer(self, producer: IProducer, streaming: bool) -> None:
+ # TODO This should ensure that the IProducer is an IPushProducer or
+ # IPullProducer, unfortunately twisted.protocols.basic.FileSender does
+ # implement those, but doesn't declare it.
+ self._producer = cast(Union[IPushProducer, IPullProducer], producer)
self.producerStreaming = streaming
def _produce() -> None:
@@ -202,6 +226,16 @@ class FakeChannel:
self._producer = None
+ def stopProducing(self) -> None:
+ if self._producer is not None:
+ self._producer.stopProducing()
+
+ def pauseProducing(self) -> None:
+ raise NotImplementedError()
+
+ def resumeProducing(self) -> None:
+ raise NotImplementedError()
+
def requestDone(self, _self: Request) -> None:
self.result["done"] = True
if isinstance(_self, SynapseRequest):
@@ -281,12 +315,12 @@ class FakeSite:
self.reactor = reactor
self.experimental_cors_msc3886 = experimental_cors_msc3886
- def getResourceFor(self, request):
+ def getResourceFor(self, request: Request) -> IResource:
return self._resource
def make_request(
- reactor,
+ reactor: MemoryReactorClock,
site: Union[Site, FakeSite],
method: Union[bytes, str],
path: Union[bytes, str],
@@ -409,19 +443,21 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
A MemoryReactorClock that supports callFromThread.
"""
- def __init__(self):
+ def __init__(self) -> None:
self.threadpool = ThreadPool(self)
self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
- self._udp = []
+ self._udp: List[udp.Port] = []
self.lookups: Dict[str, str] = {}
- self._thread_callbacks: Deque[Callable[[], None]] = deque()
+ self._thread_callbacks: Deque[Callable[..., R]] = deque()
lookups = self.lookups
@implementer(IResolverSimple)
class FakeResolver:
- def getHostByName(self, name, timeout=None):
+ def getHostByName(
+ self, name: str, timeout: Optional[Sequence[int]] = None
+ ) -> "Deferred[str]":
if name not in lookups:
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
return succeed(lookups[name])
@@ -432,25 +468,44 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
raise NotImplementedError()
- def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
+ def listenUDP(
+ self,
+ port: int,
+ protocol: DatagramProtocol,
+ interface: str = "",
+ maxPacketSize: int = 8196,
+ ) -> udp.Port:
p = udp.Port(port, protocol, interface, maxPacketSize, self)
p.startListening()
self._udp.append(p)
return p
- def callFromThread(self, callback, *args, **kwargs):
+ def callFromThread(
+ self, callable: Callable[..., Any], *args: object, **kwargs: object
+ ) -> None:
"""
Make the callback fire in the next reactor iteration.
"""
- cb = lambda: callback(*args, **kwargs)
+ cb = lambda: callable(*args, **kwargs)
# it's not safe to call callLater() here, so we append the callback to a
# separate queue.
self._thread_callbacks.append(cb)
- def getThreadPool(self):
- return self.threadpool
+ def callInThread(
+ self, callable: Callable[..., Any], *args: object, **kwargs: object
+ ) -> None:
+ raise NotImplementedError()
+
+ def suggestThreadPoolSize(self, size: int) -> None:
+ raise NotImplementedError()
+
+ def getThreadPool(self) -> "threadpool.ThreadPool":
+ # Cast to match super-class.
+ return cast(threadpool.ThreadPool, self.threadpool)
- def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
+ def add_tcp_client_callback(
+ self, host: str, port: int, callback: Callable[[], None]
+ ) -> None:
"""Add a callback that will be invoked when we receive a connection
attempt to the given IP/port using `connectTCP`.
@@ -459,7 +514,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
self._tcp_callbacks[(host, port)] = callback
- def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
+ def connectTCP(
+ self,
+ host: str,
+ port: int,
+ factory: ClientFactory,
+ timeout: float = 30,
+ bindAddress: Optional[Tuple[str, int]] = None,
+ ) -> IConnector:
"""Fake L{IReactorTCP.connectTCP}."""
conn = super().connectTCP(
@@ -472,7 +534,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return conn
- def advance(self, amount):
+ def advance(self, amount: float) -> None:
# first advance our reactor's time, and run any "callLater" callbacks that
# makes ready
super().advance(amount)
@@ -500,25 +562,33 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
class ThreadPool:
"""
Threadless thread pool.
+
+ See twisted.python.threadpool.ThreadPool
"""
- def __init__(self, reactor):
+ def __init__(self, reactor: IReactorTime):
self._reactor = reactor
- def start(self):
+ def start(self) -> None:
pass
- def stop(self):
+ def stop(self) -> None:
pass
- def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
- def _(res):
+ def callInThreadWithCallback(
+ self,
+ onResult: Callable[[bool, Union[Failure, R]], None],
+ function: Callable[P, R],
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> "Deferred[None]":
+ def _(res: Any) -> None:
if isinstance(res, Failure):
onResult(False, res)
else:
onResult(True, res)
- d = Deferred()
+ d: "Deferred[None]" = Deferred()
d.addCallback(lambda x: function(*args, **kwargs))
d.addBoth(_)
self._reactor.callLater(0, d.callback, True)
@@ -535,7 +605,9 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
for database in server.get_datastores().databases:
pool = database._db_pool
- def runWithConnection(func, *args, **kwargs):
+ def runWithConnection(
+ func: Callable[..., R], *args: Any, **kwargs: Any
+ ) -> Awaitable[R]:
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
@@ -545,20 +617,23 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
**kwargs,
)
- def runInteraction(interaction, *args, **kwargs):
+ def runInteraction(
+ desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
+ ) -> Awaitable[R]:
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
pool._runInteraction,
- interaction,
+ desc,
+ func,
*args,
**kwargs,
)
- pool.runWithConnection = runWithConnection
- pool.runInteraction = runInteraction
+ pool.runWithConnection = runWithConnection # type: ignore[assignment]
+ pool.runInteraction = runInteraction # type: ignore[assignment]
# Replace the thread pool with a threadless 'thread' pool
- pool.threadpool = ThreadPool(clock._reactor)
+ pool.threadpool = ThreadPool(clock._reactor) # type: ignore[assignment]
pool.running = True
# We've just changed the Databases to run DB transactions on the same
@@ -573,7 +648,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
@implementer(ITransport)
-@attr.s(cmp=False)
+@attr.s(cmp=False, auto_attribs=True)
class FakeTransport:
"""
A twisted.internet.interfaces.ITransport implementation which sends all its data
@@ -588,48 +663,50 @@ class FakeTransport:
If you want bidirectional communication, you'll need two instances.
"""
- other = attr.ib()
+ other: IProtocol
"""The Protocol object which will receive any data written to this transport.
-
- :type: twisted.internet.interfaces.IProtocol
"""
- _reactor = attr.ib()
+ _reactor: IReactorTime
"""Test reactor
-
- :type: twisted.internet.interfaces.IReactorTime
"""
- _protocol = attr.ib(default=None)
+ _protocol: Optional[IProtocol] = None
"""The Protocol which is producing data for this transport. Optional, but if set
will get called back for connectionLost() notifications etc.
"""
- _peer_address: Optional[IAddress] = attr.ib(default=None)
+ _peer_address: IAddress = attr.Factory(
+ lambda: address.IPv4Address("TCP", "127.0.0.1", 5678)
+ )
"""The value to be returned by getPeer"""
- _host_address: Optional[IAddress] = attr.ib(default=None)
+ _host_address: IAddress = attr.Factory(
+ lambda: address.IPv4Address("TCP", "127.0.0.1", 1234)
+ )
"""The value to be returned by getHost"""
disconnecting = False
disconnected = False
connected = True
- buffer = attr.ib(default=b"")
- producer = attr.ib(default=None)
- autoflush = attr.ib(default=True)
+ buffer: bytes = b""
+ producer: Optional[IPushProducer] = None
+ autoflush: bool = True
- def getPeer(self) -> Optional[IAddress]:
+ def getPeer(self) -> IAddress:
return self._peer_address
- def getHost(self) -> Optional[IAddress]:
+ def getHost(self) -> IAddress:
return self._host_address
- def loseConnection(self, reason=None):
+ def loseConnection(self) -> None:
if not self.disconnecting:
- logger.info("FakeTransport: loseConnection(%s)", reason)
+ logger.info("FakeTransport: loseConnection()")
self.disconnecting = True
if self._protocol:
- self._protocol.connectionLost(reason)
+ self._protocol.connectionLost(
+ Failure(RuntimeError("FakeTransport.loseConnection()"))
+ )
# if we still have data to write, delay until that is done
if self.buffer:
@@ -640,38 +717,38 @@ class FakeTransport:
self.connected = False
self.disconnected = True
- def abortConnection(self):
+ def abortConnection(self) -> None:
logger.info("FakeTransport: abortConnection()")
if not self.disconnecting:
self.disconnecting = True
if self._protocol:
- self._protocol.connectionLost(None)
+ self._protocol.connectionLost(None) # type: ignore[arg-type]
self.disconnected = True
- def pauseProducing(self):
+ def pauseProducing(self) -> None:
if not self.producer:
return
self.producer.pauseProducing()
- def resumeProducing(self):
+ def resumeProducing(self) -> None:
if not self.producer:
return
self.producer.resumeProducing()
- def unregisterProducer(self):
+ def unregisterProducer(self) -> None:
if not self.producer:
return
self.producer = None
- def registerProducer(self, producer, streaming):
+ def registerProducer(self, producer: IPushProducer, streaming: bool) -> None:
self.producer = producer
self.producerStreaming = streaming
- def _produce():
+ def _produce() -> None:
if not self.producer:
# we've been unregistered
return
@@ -683,7 +760,7 @@ class FakeTransport:
if not streaming:
self._reactor.callLater(0.0, _produce)
- def write(self, byt):
+ def write(self, byt: bytes) -> None:
if self.disconnecting:
raise Exception("Writing to disconnecting FakeTransport")
@@ -695,11 +772,11 @@ class FakeTransport:
if self.autoflush:
self._reactor.callLater(0.0, self.flush)
- def writeSequence(self, seq):
+ def writeSequence(self, seq: Iterable[bytes]) -> None:
for x in seq:
self.write(x)
- def flush(self, maxbytes=None):
+ def flush(self, maxbytes: Optional[int] = None) -> None:
if not self.buffer:
# nothing to do. Don't write empty buffers: it upsets the
# TLSMemoryBIOProtocol
@@ -750,17 +827,17 @@ def connect_client(
class TestHomeServer(HomeServer):
- DATASTORE_CLASS = DataStore
+ DATASTORE_CLASS = DataStore # type: ignore[assignment]
def setup_test_homeserver(
- cleanup_func,
- name="test",
- config=None,
- reactor=None,
+ cleanup_func: Callable[[Callable[[], None]], None],
+ name: str = "test",
+ config: Optional[HomeServerConfig] = None,
+ reactor: Optional[ISynapseReactor] = None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
- **kwargs,
-):
+ **kwargs: Any,
+) -> HomeServer:
"""
Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor.
@@ -775,13 +852,14 @@ def setup_test_homeserver(
HomeserverTestCase.
"""
if reactor is None:
- from twisted.internet import reactor
+ from twisted.internet import reactor as _reactor
+
+ reactor = cast(ISynapseReactor, _reactor)
if config is None:
config = default_config(name, parse=True)
config.caches.resize_all_caches()
- config.ldap_enabled = False
if "clock" not in kwargs:
kwargs["clock"] = MockClock()
@@ -832,6 +910,8 @@ def setup_test_homeserver(
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
if isinstance(db_engine, PostgresEngine):
+ import psycopg2.extensions
+
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
@@ -839,6 +919,7 @@ def setup_test_homeserver(
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
)
+ assert isinstance(db_conn, psycopg2.extensions.connection)
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
@@ -867,14 +948,15 @@ def setup_test_homeserver(
hs.setup_background_tasks()
if isinstance(db_engine, PostgresEngine):
- database = hs.get_datastores().databases[0]
+ database_pool = hs.get_datastores().databases[0]
# We need to do cleanup on PostgreSQL
- def cleanup():
+ def cleanup() -> None:
import psycopg2
+ import psycopg2.extensions
# Close all the db pools
- database._db_pool.close()
+ database_pool._db_pool.close()
dropped = False
@@ -886,6 +968,7 @@ def setup_test_homeserver(
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
)
+ assert isinstance(db_conn, psycopg2.extensions.connection)
db_conn.autocommit = True
cur = db_conn.cursor()
@@ -918,23 +1001,23 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
- async def hash(p):
+ async def hash(p: str) -> str:
return hashlib.md5(p.encode("utf8")).hexdigest()
- hs.get_auth_handler().hash = hash
+ hs.get_auth_handler().hash = hash # type: ignore[assignment]
- async def validate_hash(p, h):
+ async def validate_hash(p: str, h: str) -> bool:
return hashlib.md5(p.encode("utf8")).hexdigest() == h
- hs.get_auth_handler().validate_hash = validate_hash
+ hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment]
# Make the threadpool and database transactions synchronous for testing.
_make_test_homeserver_synchronous(hs)
# Load any configured modules into the homeserver
module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
+ for module, module_config in hs.config.modules.loaded_modules:
+ module(config=module_config, api=module_api)
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index 6540ed53f1..3fdf5a6d52 100644
--- a/tests/server_notices/test_consent.py
+++ b/tests/server_notices/test_consent.py
@@ -25,7 +25,6 @@ from tests import unittest
class ConsentNoticesTests(unittest.HomeserverTestCase):
-
servlets = [
sync.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -34,7 +33,6 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
tmpdir = self.mktemp()
os.mkdir(tmpdir)
self.consent_notice_message = "consent %(consent_uri)s"
diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py
index 373707b275..b6d5c474b0 100644
--- a/tests/storage/databases/main/test_deviceinbox.py
+++ b/tests/storage/databases/main/test_deviceinbox.py
@@ -23,7 +23,6 @@ from tests.unittest import HomeserverTestCase
class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
-
servlets = [
admin.register_servlets,
devices.register_servlets,
diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py
index ac77aec003..71db47405e 100644
--- a/tests/storage/databases/main/test_receipts.py
+++ b/tests/storage/databases/main/test_receipts.py
@@ -26,7 +26,6 @@ from tests.unittest import HomeserverTestCase
class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
-
servlets = [
admin.register_servlets,
room.register_servlets,
@@ -62,6 +61,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
keys and expected receipt key-values after duplicate receipts have been
removed.
"""
+
# First, undo the background update.
def drop_receipts_unique_index(txn: LoggingTransaction) -> None:
txn.execute(f"DROP INDEX IF EXISTS {index_name}")
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 3108ca3444..dbd8f3a85e 100644
--- a/tests/storage/databases/main/test_room.py
+++ b/tests/storage/databases/main/test_room.py
@@ -27,7 +27,6 @@ from tests.unittest import HomeserverTestCase
class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
-
servlets = [
admin.register_servlets,
room.register_servlets,
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 1bfd11ceae..b12691a9d3 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -140,3 +140,25 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
# No one ignores the user now.
self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", set())
+
+ def test_ignoring_users_with_latest_stream_ids(self) -> None:
+ """Test that ignoring users updates the latest stream ID for the ignored
+ user list account data."""
+
+ def get_latest_ignore_streampos(user_id: str) -> Optional[int]:
+ return self.get_success(
+ self.store.get_latest_stream_id_for_global_account_data_by_type_for_user(
+ user_id, AccountDataTypes.IGNORED_USER_LIST
+ )
+ )
+
+ self.assertIsNone(get_latest_ignore_streampos("@user:test"))
+
+ self._update_ignore_list("@other:test", "@another:remote")
+
+ self.assertEqual(get_latest_ignore_streampos("@user:test"), 2)
+
+ # Add one user, remove one user, and leave one user.
+ self._update_ignore_list("@foo:test", "@another:remote")
+
+ self.assertEqual(get_latest_ignore_streampos("@user:test"), 3)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index d570684c99..7de109966d 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -43,8 +43,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
self.requester = create_requester(self.user)
- info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
- self.room_id = info["room_id"]
+ self.room_id, _, _ = self.get_success(
+ self.room_creator.create_room(self.requester, {})
+ )
def run_background_update(self) -> None:
"""Re run the background update to clean up the extremities."""
@@ -275,10 +276,9 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
self.requester = create_requester(self.user)
- info, _ = self.get_success(
+ self.room_id, _, _ = self.get_success(
self.room_creator.create_room(self.requester, {"visibility": "public"})
)
- self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
homeserver.config.consent.user_consent_version = self.CONSENT_VERSION
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 7f7f4ef892..cd0079871c 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -656,7 +656,6 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
class ClientIpAuthTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 8d7310d8e5..2a9aa9e21c 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -417,7 +417,6 @@ class EventChainStoreTestCase(HomeserverTestCase):
def fetch_chains(
self, events: List[EventBase]
) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
-
# Fetch the map from event ID -> (chain ID, sequence number)
rows = self.get_success(
self.store.db_pool.simple_select_many_batch(
@@ -492,7 +491,6 @@ class LinkMapTestCase(unittest.TestCase):
class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
-
servlets = [
admin.register_servlets,
room.register_servlets,
@@ -524,7 +522,8 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_prev_events_for_room(room_id)
)
- event, context, _ = self.get_success(
+
+ event, unpersisted_context, _ = self.get_success(
event_handler.create_event(
self.requester,
{
@@ -537,6 +536,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
prev_event_ids=latest_event_ids,
)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
event_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)]
@@ -546,7 +546,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
assert state_ids1 is not None
state1 = set(state_ids1.values())
- event, context, _ = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
event_handler.create_event(
self.requester,
{
@@ -559,6 +559,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
prev_event_ids=latest_event_ids,
)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
event_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)]
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 8fc7936ab0..3e1984c15c 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -672,7 +672,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
complete_event_dict_map: Dict[str, JsonDict] = {}
stream_ordering = 0
- for (event_id, prev_event_ids) in event_graph.items():
+ for event_id, prev_event_ids in event_graph.items():
depth = depth_map[event_id]
complete_event_dict_map[event_id] = {
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index a91411168c..6897addbd3 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -33,8 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
events = [(3, 2), (6, 2), (4, 6)]
for event_count, extrems in events:
- info, _ = self.get_success(room_creator.create_room(requester, {}))
- room_id = info["room_id"]
+ room_id, _, _ = self.get_success(room_creator.create_room(requester, {}))
last_event = None
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 76c06a9d1e..aa19c3bd30 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -774,7 +774,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual(r, 3)
# add a bunch of dummy events to the events table
- for (stream_ordering, ts) in (
+ for stream_ordering, ts in (
(3, 110),
(4, 120),
(5, 120),
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index d8f42c5d05..857e2caf2e 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -23,7 +23,6 @@ from tests.unittest import HomeserverTestCase
class PurgeTests(HomeserverTestCase):
-
user_id = "@red:server"
servlets = [room.register_servlets]
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index 12c17f1073..1b52eef23f 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -50,12 +50,14 @@ class ReceiptTestCase(HomeserverTestCase):
self.otherRequester = create_requester(self.otherUser)
# Create a test room
- info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {}))
- self.room_id1 = info["room_id"]
+ self.room_id1, _, _ = self.get_success(
+ self.room_creator.create_room(self.ourRequester, {})
+ )
# Create a second test room
- info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {}))
- self.room_id2 = info["room_id"]
+ self.room_id2, _, _ = self.get_success(
+ self.room_creator.create_room(self.ourRequester, {})
+ )
# Join the second user to the first room
memberEvent, memberEventContext = self.get_success(
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 8794401823..f4c4661aaf 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -27,7 +27,6 @@ from tests.test_utils import event_injection
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
-
servlets = [
login.register_servlets,
register_servlets_for_client_rest_resource,
@@ -35,7 +34,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None: # type: ignore[override]
-
# We can't test the RoomMemberStore on its own without the other event
# storage logic
self.store = hs.get_datastores().main
@@ -48,7 +46,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.u_charlie = UserID.from_string("@charlie:elsewhere")
def test_one_member(self) -> None:
-
# Alice creates the room, and is automatically joined
self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 333a67e84d..8b30832a35 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -242,7 +242,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -259,7 +259,7 @@ class StateStoreTestCase(HomeserverTestCase):
state_dict,
)
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -272,7 +272,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with wildcard types
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -289,7 +289,7 @@ class StateStoreTestCase(HomeserverTestCase):
state_dict,
)
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -309,7 +309,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -327,7 +327,7 @@ class StateStoreTestCase(HomeserverTestCase):
state_dict,
)
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -341,7 +341,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -392,7 +392,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
room_id = self.room.to_string()
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -404,7 +404,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertDictEqual({}, state_dict)
room_id = self.room.to_string()
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -417,7 +417,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters in members
# wildcard types
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -428,7 +428,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -447,7 +447,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -459,7 +459,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -473,7 +473,7 @@ class StateStoreTestCase(HomeserverTestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -485,7 +485,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+ state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -496,3 +496,129 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+ def test_batched_state_group_storing(self) -> None:
+ creation_event = self.inject_state_event(
+ self.room, self.u_alice, EventTypes.Create, "", {}
+ )
+ state_to_event = self.get_success(
+ self.storage.state.get_state_groups(
+ self.room.to_string(), [creation_event.event_id]
+ )
+ )
+ current_state_group = list(state_to_event.keys())[0]
+
+ # create some unpersisted events and event contexts to store against room
+ events_and_context = []
+ builder = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Name,
+ "sender": self.u_alice.to_string(),
+ "state_key": "",
+ "room_id": self.room.to_string(),
+ "content": {"name": "first rename of room"},
+ },
+ )
+
+ event1, unpersisted_context1 = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
+ )
+ events_and_context.append((event1, unpersisted_context1))
+
+ builder2 = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.JoinRules,
+ "sender": self.u_alice.to_string(),
+ "state_key": "",
+ "room_id": self.room.to_string(),
+ "content": {"join_rule": "private"},
+ },
+ )
+
+ event2, unpersisted_context2 = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder2)
+ )
+ events_and_context.append((event2, unpersisted_context2))
+
+ builder3 = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Message,
+ "sender": self.u_alice.to_string(),
+ "room_id": self.room.to_string(),
+ "content": {"body": "hello from event 3", "msgtype": "m.text"},
+ },
+ )
+
+ event3, unpersisted_context3 = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder3)
+ )
+ events_and_context.append((event3, unpersisted_context3))
+
+ builder4 = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.JoinRules,
+ "sender": self.u_alice.to_string(),
+ "state_key": "",
+ "room_id": self.room.to_string(),
+ "content": {"join_rule": "public"},
+ },
+ )
+
+ event4, unpersisted_context4 = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder4)
+ )
+ events_and_context.append((event4, unpersisted_context4))
+
+ processed_events_and_context = self.get_success(
+ self.hs.get_datastores().state.store_state_deltas_for_batched(
+ events_and_context, self.room.to_string(), current_state_group
+ )
+ )
+
+ # check that only state events are in state_groups, and all state events are in state_groups
+ res = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_groups",
+ keyvalues=None,
+ retcols=("event_id",),
+ )
+ )
+
+ events = []
+ for result in res:
+ self.assertNotIn(event3.event_id, result)
+ events.append(result.get("event_id"))
+
+ for event, _ in processed_events_and_context:
+ if event.is_state():
+ self.assertIn(event.event_id, events)
+
+ # check that each unique state has state group in state_groups_state and that the
+ # type/state key is correct, and check that each state event's state group
+ # has an entry and prev event in state_group_edges
+ for event, context in processed_events_and_context:
+ if event.is_state():
+ state = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_groups_state",
+ keyvalues={"state_group": context.state_group_after_event},
+ retcols=("type", "state_key"),
+ )
+ )
+ self.assertEqual(event.type, state[0].get("type"))
+ self.assertEqual(event.state_key, state[0].get("state_key"))
+
+ groups = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_group_edges",
+ keyvalues={"state_group": str(context.state_group_after_event)},
+ retcols=("*",),
+ )
+ )
+ self.assertEqual(
+ context.state_group_before_event, groups[0].get("prev_state_group")
+ )
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index f1ca523d23..8c72aa1722 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -25,6 +25,11 @@ from synapse.rest.client import login, register, room
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.background_updates import _BackgroundUpdateHandler
+from synapse.storage.databases.main import user_directory
+from synapse.storage.databases.main.user_directory import (
+ _parse_words_with_icu,
+ _parse_words_with_regex,
+)
from synapse.storage.roommember import ProfileInfo
from synapse.util import Clock
@@ -42,7 +47,7 @@ ALICE = "@alice:a"
BOB = "@bob:b"
BOBBY = "@bobby:a"
# The localpart isn't 'Bela' on purpose so we can test looking up display names.
-BELA = "@somenickname:a"
+BELA = "@somenickname:example.org"
class GetUserDirectoryTables:
@@ -423,6 +428,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
class UserDirectoryStoreTestCase(HomeserverTestCase):
+ use_icu = False
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
@@ -434,6 +441,12 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
+ self._restore_use_icu = user_directory.USE_ICU
+ user_directory.USE_ICU = self.use_icu
+
+ def tearDown(self) -> None:
+ user_directory.USE_ICU = self._restore_use_icu
+
def test_search_user_dir(self) -> None:
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
@@ -478,6 +491,159 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
{"user_id": BELA, "display_name": "Bela", "avatar_url": None},
)
+ @override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_dir_start_of_user_id(self) -> None:
+ """Tests that a user can look up another user by searching for the start
+ of their user ID.
+ """
+ r = self.get_success(self.store.search_user_dir(ALICE, "somenickname:exa", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+ )
+
+ @override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_dir_ascii_case_insensitivity(self) -> None:
+ """Tests that a user can look up another user by searching for their name in a
+ different case.
+ """
+ CHARLIE = "@someuser:example.org"
+ self.get_success(
+ self.store.update_profile_in_user_dir(CHARLIE, "Charlie", None)
+ )
+
+ r = self.get_success(self.store.search_user_dir(ALICE, "cHARLIE", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": CHARLIE, "display_name": "Charlie", "avatar_url": None},
+ )
+
+ @override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_dir_unicode_case_insensitivity(self) -> None:
+ """Tests that a user can look up another user by searching for their name in a
+ different case.
+ """
+ IVAN = "@someuser:example.org"
+ self.get_success(self.store.update_profile_in_user_dir(IVAN, "Иван", None))
+
+ r = self.get_success(self.store.search_user_dir(ALICE, "иВАН", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": IVAN, "display_name": "Иван", "avatar_url": None},
+ )
+
+ @override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_dir_dotted_dotless_i_case_insensitivity(self) -> None:
+ """Tests that a user can look up another user by searching for their name in a
+ different case, when their name contains dotted or dotless "i"s.
+
+ Some languages have dotted and dotless versions of "i", which are considered to
+ be different letters: i <-> İ, ı <-> I. To make things difficult, they reuse the
+ ASCII "i" and "I" code points, despite having different lowercase / uppercase
+ forms.
+ """
+ USER = "@someuser:example.org"
+
+ expected_matches = [
+ # (search_term, display_name)
+ # A search for "i" should match "İ".
+ ("iiiii", "İİİİİ"),
+ # A search for "I" should match "ı".
+ ("IIIII", "ııııı"),
+ # A search for "ı" should match "I".
+ ("ııııı", "IIIII"),
+ # A search for "İ" should match "i".
+ ("İİİİİ", "iiiii"),
+ ]
+
+ for search_term, display_name in expected_matches:
+ self.get_success(
+ self.store.update_profile_in_user_dir(USER, display_name, None)
+ )
+
+ r = self.get_success(self.store.search_user_dir(ALICE, search_term, 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(
+ 1,
+ len(r["results"]),
+ f"searching for {search_term!r} did not match {display_name!r}",
+ )
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": USER, "display_name": display_name, "avatar_url": None},
+ )
+
+ # We don't test for negative matches, to allow implementations that consider all
+ # the i variants to be the same.
+
+ test_search_user_dir_dotted_dotless_i_case_insensitivity.skip = "not supported" # type: ignore
+
+ @override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_dir_unicode_normalization(self) -> None:
+ """Tests that a user can look up another user by searching for their name with
+ either composed or decomposed accents.
+ """
+ AMELIE = "@someuser:example.org"
+
+ expected_matches = [
+ # (search_term, display_name)
+ ("Ame\u0301lie", "Amélie"),
+ ("Amélie", "Ame\u0301lie"),
+ ]
+
+ for search_term, display_name in expected_matches:
+ self.get_success(
+ self.store.update_profile_in_user_dir(AMELIE, display_name, None)
+ )
+
+ r = self.get_success(self.store.search_user_dir(ALICE, search_term, 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(
+ 1,
+ len(r["results"]),
+ f"searching for {search_term!r} did not match {display_name!r}",
+ )
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": AMELIE, "display_name": display_name, "avatar_url": None},
+ )
+
+ @override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_dir_accent_insensitivity(self) -> None:
+ """Tests that a user can look up another user by searching for their name
+ without any accents.
+ """
+ AMELIE = "@someuser:example.org"
+ self.get_success(self.store.update_profile_in_user_dir(AMELIE, "Amélie", None))
+
+ r = self.get_success(self.store.search_user_dir(ALICE, "amelie", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": AMELIE, "display_name": "Amélie", "avatar_url": None},
+ )
+
+ # It may be desirable for "é"s in search terms to not match plain "e"s and we
+ # really don't want "é"s in search terms to match "e"s with different accents.
+ # But we don't test for this to allow implementations that consider all
+ # "e"-lookalikes to be the same.
+
+ test_search_user_dir_accent_insensitivity.skip = "not supported yet" # type: ignore
+
+
+class UserDirectoryStoreTestCaseWithIcu(UserDirectoryStoreTestCase):
+ use_icu = True
+
+ if not icu:
+ skip = "Requires PyICU"
+
class UserDirectoryICUTestCase(HomeserverTestCase):
if not icu:
@@ -513,3 +679,33 @@ class UserDirectoryICUTestCase(HomeserverTestCase):
r["results"][0],
{"user_id": ALICE, "display_name": display_name, "avatar_url": None},
)
+
+ def test_icu_word_boundary_punctuation(self) -> None:
+ """
+ Tests the behaviour of punctuation with the ICU tokeniser.
+
+ Seems to depend on underlying version of ICU.
+ """
+
+ # Note: either tokenisation is fine, because Postgres actually splits
+ # words itself afterwards.
+ self.assertIn(
+ _parse_words_with_icu("lazy'fox jumped:over the.dog"),
+ (
+ # ICU 66 on Ubuntu 20.04
+ ["lazy'fox", "jumped", "over", "the", "dog"],
+ # ICU 70 on Ubuntu 22.04
+ ["lazy'fox", "jumped:over", "the.dog"],
+ # pyicu 2.10.2 on Alpine edge / macOS
+ ["lazy'fox", "jumped", "over", "the.dog"],
+ ),
+ )
+
+ def test_regex_word_boundary_punctuation(self) -> None:
+ """
+ Tests the behaviour of punctuation with the non-ICU tokeniser
+ """
+ self.assertEqual(
+ _parse_words_with_regex("lazy'fox jumped:over the.dog"),
+ ["lazy", "fox", "jumped", "over", "the", "dog"],
+ )
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 82dfd88b99..46d2f99eac 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -47,7 +47,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
room_creator.create_room(
our_user, room_creator._presets_dict["public_chat"], ratelimit=False
)
- )[0]["room_id"]
+ )[0]
self.store = self.hs.get_datastores().main
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 4e7665a22b..ff21098a59 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -32,7 +32,6 @@ from tests.utils import default_config
class TestMauLimit(unittest.HomeserverTestCase):
-
servlets = [register.register_servlets, sync.register_servlets]
def default_config(self) -> JsonDict:
diff --git a/tests/unittest.py b/tests/unittest.py
index 2d73911747..4b31f84494 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -45,7 +45,7 @@ from typing_extensions import Concatenate, ParamSpec, Protocol
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
-from twisted.test.proto_helpers import MemoryReactor
+from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock
from twisted.trial import unittest
from twisted.web.resource import Resource
from twisted.web.server import Request
@@ -82,7 +82,7 @@ from tests.server import (
)
from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
-from tests.utils import default_config, setupdb
+from tests.utils import checked_cast, default_config, setupdb
setupdb()
setup_logging()
@@ -296,7 +296,12 @@ class HomeserverTestCase(TestCase):
from tests.rest.client.utils import RestHelper
- self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
+ self.helper = RestHelper(
+ self.hs,
+ checked_cast(MemoryReactorClock, self.hs.get_reactor()),
+ self.site,
+ getattr(self, "user_id", None),
+ )
if hasattr(self, "user_id"):
if self.hijack_auth:
@@ -718,7 +723,7 @@ class HomeserverTestCase(TestCase):
event_creator = self.hs.get_event_creation_handler()
requester = create_requester(user)
- event, context, _ = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
event_creator.create_event(
requester,
{
@@ -730,7 +735,7 @@ class HomeserverTestCase(TestCase):
prev_event_ids=prev_event_ids,
)
)
-
+ context = self.get_success(unpersisted_context.persist(event))
if soft_failed:
event.internal_metadata.soft_failed = True
diff --git a/tests/utils.py b/tests/utils.py
index a22f1d1241..3badfb7d41 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,7 +15,7 @@
import atexit
import os
-from typing import Any, Callable, Dict, List, Tuple, Union, overload
+from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload
import attr
from typing_extensions import Literal, ParamSpec
@@ -343,3 +343,27 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
context = await unpersisted_context.persist(event)
await persistence_store.persist_event(event, context)
+
+
+T = TypeVar("T")
+
+
+def checked_cast(type: Type[T], x: object) -> T:
+ """A version of typing.cast that is checked at runtime.
+
+ We have our own function for this for two reasons:
+
+ 1. typing.cast itself is deliberately a no-op at runtime, see
+ https://docs.python.org/3/library/typing.html#typing.cast
+ 2. To help workaround a mypy-zope bug https://github.com/Shoobx/mypy-zope/issues/91
+ where mypy would erroneously consider `isinstance(x, type)` to be false in all
+ circumstances.
+
+ For this to make sense, `T` needs to be something that `isinstance` can check; see
+ https://docs.python.org/3/library/functions.html?highlight=isinstance#isinstance
+ https://docs.python.org/3/glossary.html#term-abstract-base-class
+ https://docs.python.org/3/library/typing.html#typing.runtime_checkable
+ for more details.
+ """
+ assert isinstance(x, type)
+ return x
|