diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index dfcfaf79b6..e0f363555b 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -284,10 +284,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
TokenLookupResult(
user_id="@baldrick:matrix.org",
device_id="device",
+ token_id=5,
token_owner="@admin:matrix.org",
+ token_used=True,
)
)
self.store.insert_client_ip = simple_async_mock(None)
+ self.store.mark_access_token_as_used = simple_async_mock(None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
@@ -301,10 +304,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
TokenLookupResult(
user_id="@baldrick:matrix.org",
device_id="device",
+ token_id=5,
token_owner="@admin:matrix.org",
+ token_used=True,
)
)
self.store.insert_client_ip = simple_async_mock(None)
+ self.store.mark_access_token_as_used = simple_async_mock(None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
@@ -347,7 +353,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
serialized = macaroon.serialize()
user_info = self.get_success(self.auth.get_user_by_access_token(serialized))
- self.assertEqual(user_id, user_info.user_id)
+ self.assertEqual(user_id, user_info.user.to_string())
self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 264e101082..c7dae58eb5 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -61,7 +61,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
}
# Listen with the config
- self.hs._listen_http(parse_listener_def(config))
+ self.hs._listen_http(parse_listener_def(0, config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
@@ -109,7 +109,7 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
}
# Listen with the config
- self.hs._listener_http(self.hs.config, parse_listener_def(config))
+ self.hs._listener_http(self.hs.config, parse_listener_def(0, config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index ffc3012a86..685a9a6d52 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -141,10 +141,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
hs = self.setup_test_homeserver(
federation_transport_client=fed_transport_client,
)
- # Load the modules into the homeserver
- module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
load_legacy_presence_router(hs)
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 01a1db6115..a5aa500ef8 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -173,17 +173,24 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
return c
def prepare(self, reactor, clock, hs):
- # stub out `get_rooms_for_user` and `get_users_in_room` so that the
+ test_room_id = "!room:host1"
+
+ # stub out `get_rooms_for_user` and `get_current_hosts_in_room` so that the
# server thinks the user shares a room with `@user2:host2`
def get_rooms_for_user(user_id):
- return defer.succeed({"!room:host1"})
+ return defer.succeed({test_room_id})
hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
- def get_users_in_room(room_id):
- return defer.succeed({"@user2:host2"})
+ async def get_current_hosts_in_room(room_id):
+ if room_id == test_room_id:
+ return ["host2"]
+
+ # TODO: We should fail the test when we encounter an unxpected room ID.
+ # We can't just use `self.fail(...)` here because the app code is greedy
+ # with `Exception` and will catch it before the test can see it.
- hs.get_datastores().main.get_users_in_room = get_users_in_room
+ hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room
# whenever send_transaction is called, record the edu data
self.edus = []
diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py
index d33e86db4c..e88e5d8bb3 100644
--- a/tests/federation/transport/server/test__base.py
+++ b/tests/federation/transport/server/test__base.py
@@ -18,9 +18,10 @@ from typing import Dict, List, Tuple
from synapse.api.errors import Codes
from synapse.federation.transport.server import BaseFederationServlet
from synapse.federation.transport.server._base import Authenticator, _parse_auth_header
-from synapse.http.server import JsonResource, cancellable
+from synapse.http.server import JsonResource
from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util.cancellation import cancellable
from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index 82baa8f154..7b9b711521 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -322,3 +322,18 @@ class DeactivateAccountTestCase(HomeserverTestCase):
)
),
)
+
+ def test_deactivate_account_needs_auth(self) -> None:
+ """
+ Tests that making a request to /deactivate with an empty body
+ succeeds in starting the user-interactive auth flow.
+ """
+ req = self.make_request(
+ "POST",
+ "account/deactivate",
+ {},
+ access_token=self.token,
+ )
+
+ self.assertEqual(req.code, 401, req)
+ self.assertEqual(req.json_body["flows"], [{"stages": ["m.login.password"]}])
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 4c62449c89..75934b1707 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -21,7 +21,6 @@ from unittest.mock import Mock
import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
-from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID
@@ -167,16 +166,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
super().setUp()
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver()
- # Load the modules into the homeserver
- module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
- load_legacy_password_auth_providers(hs)
-
- return hs
-
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_progiver_login_legacy(self):
self.password_only_auth_provider_login_test_body()
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index 5f70a2db79..b55238650c 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -15,8 +15,6 @@
from copy import deepcopy
from typing import List
-from parameterized import parameterized
-
from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.types import JsonDict
@@ -27,16 +25,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.event_source = hs.get_event_sources().sources.receipt
- @parameterized.expand(
- [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
- )
- def test_filters_out_private_receipt(self, receipt_type: str) -> None:
+ def test_filters_out_private_receipt(self) -> None:
self._test_filters_private(
[
{
"content": {
"$1435641916114394fHBLK:matrix.org": {
- receipt_type: {
+ ReceiptTypes.READ_PRIVATE: {
"@rikj:jki.re": {
"ts": 1436451550453,
}
@@ -50,18 +45,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
[],
)
- @parameterized.expand(
- [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
- )
- def test_filters_out_private_receipt_and_ignores_rest(
- self, receipt_type: str
- ) -> None:
+ def test_filters_out_private_receipt_and_ignores_rest(self) -> None:
self._test_filters_private(
[
{
"content": {
"$1dgdgrd5641916114394fHBLK:matrix.org": {
- receipt_type: {
+ ReceiptTypes.READ_PRIVATE: {
"@rikj:jki.re": {
"ts": 1436451550453,
},
@@ -94,18 +84,15 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- @parameterized.expand(
- [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
- )
def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest(
- self, receipt_type: str
+ self,
) -> None:
self._test_filters_private(
[
{
"content": {
"$14356419edgd14394fHBLK:matrix.org": {
- receipt_type: {
+ ReceiptTypes.READ_PRIVATE: {
"@rikj:jki.re": {
"ts": 1436451550453,
},
@@ -175,18 +162,15 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- @parameterized.expand(
- [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
- )
def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest(
- self, receipt_type: str
+ self,
) -> None:
self._test_filters_private(
[
{
"content": {
"$14356419edgd14394fHBLK:matrix.org": {
- receipt_type: {
+ ReceiptTypes.READ_PRIVATE: {
"@rikj:jki.re": {
"ts": 1436451550453,
},
@@ -262,16 +246,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- @parameterized.expand(
- [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
- )
- def test_leaves_our_private_and_their_public(self, receipt_type: str) -> None:
+ def test_leaves_our_private_and_their_public(self) -> None:
self._test_filters_private(
[
{
"content": {
"$1dgdgrd5641916114394fHBLK:matrix.org": {
- receipt_type: {
+ ReceiptTypes.READ_PRIVATE: {
"@me:server.org": {
"ts": 1436451550453,
},
@@ -296,7 +277,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1dgdgrd5641916114394fHBLK:matrix.org": {
- receipt_type: {
+ ReceiptTypes.READ_PRIVATE: {
"@me:server.org": {
"ts": 1436451550453,
},
@@ -319,16 +300,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- @parameterized.expand(
- [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
- )
- def test_we_do_not_mutate(self, receipt_type: str) -> None:
+ def test_we_do_not_mutate(self) -> None:
"""Ensure the input values are not modified."""
events = [
{
"content": {
"$1435641916114394fHBLK:matrix.org": {
- receipt_type: {
+ ReceiptTypes.READ_PRIVATE: {
"@rikj:jki.re": {
"ts": 1436451550453,
}
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 23f35d5bf5..86b3d51975 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -22,7 +22,6 @@ from synapse.api.errors import (
ResourceLimitError,
SynapseError,
)
-from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, RoomID, UserID, create_requester
@@ -144,12 +143,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
config=hs_config, federation_client=self.mock_federation_client
)
- load_legacy_spam_checkers(hs)
-
- module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
-
return hs
def prepare(self, reactor, clock, hs):
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index b4e1405aee..6bbfd5dc84 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -6,7 +6,7 @@ import synapse.rest.admin
import synapse.rest.client.login
import synapse.rest.client.room
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import LimitExceededError
+from synapse.api.errors import LimitExceededError, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import FrozenEventV3
from synapse.federation.federation_client import SendJoinResult
@@ -14,10 +14,14 @@ from synapse.server import HomeServer
from synapse.types import UserID, create_requester
from synapse.util import Clock
-from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
from tests.test_utils import make_awaitable
-from tests.unittest import FederatingHomeserverTestCase, override_config
+from tests.unittest import (
+ FederatingHomeserverTestCase,
+ HomeserverTestCase,
+ override_config,
+)
class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
@@ -216,7 +220,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
# - trying to remote-join again.
-class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestCase):
+class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCase):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.client.login.register_servlets,
@@ -287,3 +291,88 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestC
),
LimitExceededError,
)
+
+
+class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.client.login.register_servlets,
+ synapse.rest.client.room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.handler = hs.get_room_member_handler()
+ self.store = hs.get_datastores().main
+
+ # Create two users.
+ self.alice = self.register_user("alice", "pass")
+ self.alice_ID = UserID.from_string(self.alice)
+ self.alice_token = self.login("alice", "pass")
+ self.bob = self.register_user("bob", "pass")
+ self.bob_ID = UserID.from_string(self.bob)
+ self.bob_token = self.login("bob", "pass")
+
+ # Create a room on this homeserver.
+ self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+
+ def test_leave_and_forget(self) -> None:
+ """Tests that forget a room is successfully. The test is performed with two users,
+ as forgetting by the last user respectively after all users had left the
+ is a special edge case."""
+ self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
+
+ # alice is not the last room member that leaves and forgets the room
+ self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token)
+ self.get_success(self.handler.forget(self.alice_ID, self.room_id))
+ self.assertTrue(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
+
+ # the server has not forgotten the room
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room_id))
+ )
+
+ def test_leave_and_forget_last_user(self) -> None:
+ """Tests that forget a room is successfully when the last user has left the room."""
+
+ # alice is the last room member that leaves and forgets the room
+ self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token)
+ self.get_success(self.handler.forget(self.alice_ID, self.room_id))
+ self.assertTrue(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
+
+ # the server has forgotten the room
+ self.assertTrue(
+ self.get_success(self.store.is_locally_forgotten_room(self.room_id))
+ )
+
+ def test_forget_when_not_left(self) -> None:
+ """Tests that a user cannot not forgets a room that has not left."""
+ self.get_failure(self.handler.forget(self.alice_ID, self.room_id), SynapseError)
+
+ def test_rejoin_forgotten_by_user(self) -> None:
+ """Test that a user that has forgotten a room can do a re-join.
+ The room was not forgotten from the local server.
+ One local user is still member of the room."""
+ self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
+
+ self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token)
+ self.get_success(self.handler.forget(self.alice_ID, self.room_id))
+ self.assertTrue(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
+
+ # the server has not forgotten the room
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room_id))
+ )
+
+ self.helper.join(self.room_id, user=self.alice, tok=self.alice_token)
+ # TODO: A join to a room does not invalidate the forgotten cache
+ # see https://github.com/matrix-org/synapse/issues/13262
+ self.store.did_forget.invalidate_all()
+ self.assertFalse(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 7af1333126..8adba29d7f 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -25,7 +25,7 @@ from synapse.api.constants import EduTypes
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -117,8 +117,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- async def check_user_in_room(room_id: str, user_id: str) -> None:
- if user_id not in [u.to_string() for u in self.room_members]:
+ async def check_user_in_room(room_id: str, requester: Requester) -> None:
+ if requester.user.to_string() not in [
+ u.to_string() for u in self.room_members
+ ]:
raise AuthError(401, "User is not in the room")
return None
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 5726e60cee..5071f83574 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -140,6 +140,8 @@ def make_request_with_cancellation_test(
method: str,
path: str,
content: Union[bytes, str, JsonDict] = b"",
+ *,
+ token: Optional[str] = None,
) -> FakeChannel:
"""Performs a request repeatedly, disconnecting at successive `await`s, until
one completes.
@@ -211,7 +213,13 @@ def make_request_with_cancellation_test(
with deferred_patch.patch():
# Start the request.
channel = make_request(
- reactor, site, method, path, content, await_result=False
+ reactor,
+ site,
+ method,
+ path,
+ content,
+ await_result=False,
+ access_token=token,
)
request = channel.request
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index bb966c80c6..3cbca0f5a3 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -18,7 +18,6 @@ from typing import Tuple
from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import cancellable
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
@@ -28,6 +27,7 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util.cancellation import cancellable
from tests import unittest
from tests.http.server._base import test_disconnect
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 106159fa65..02cef6f876 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -30,7 +30,6 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils import simple_async_mock
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
-from tests.utils import USE_POSTGRES_FOR_TESTS
class ModuleApiTestCase(HomeserverTestCase):
@@ -738,11 +737,6 @@ class ModuleApiTestCase(HomeserverTestCase):
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""
- # Testing stream ID replication from the main to worker processes requires postgres
- # (due to needing `MultiWriterIdGenerator`).
- if not USE_POSTGRES_FOR_TESTS:
- skip = "Requires Postgres"
-
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -752,7 +746,6 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
def default_config(self):
conf = super().default_config()
- conf["redis"] = {"enabled": "true"}
conf["stream_writers"] = {"presence": ["presence_writer"]}
conf["instance_map"] = {
"presence_writer": {"host": "testserv", "port": 1001},
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 970d5e533b..ce53f808db 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -24,11 +24,11 @@ from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
-from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import (
- ReplicationStreamProtocolFactory,
+from synapse.replication.tcp.protocol import (
+ ClientReplicationStreamProtocol,
ServerReplicationStreamProtocol,
)
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from tests import unittest
@@ -220,15 +220,34 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests running multiple workers.
+ Enables Redis, providing a fake Redis server.
+
Automatically handle HTTP replication requests from workers to master,
unlike `BaseStreamTestCase`.
"""
+ if not hiredis:
+ skip = "Requires hiredis"
+
+ if not USE_POSTGRES_FOR_TESTS:
+ # Redis replication only takes place on Postgres
+ skip = "Requires Postgres"
+
+ def default_config(self) -> Dict[str, Any]:
+ """
+ Overrides the default config to enable Redis.
+ Even if the test only uses make_worker_hs, the main process needs Redis
+ enabled otherwise it won't create a Fake Redis server to listen on the
+ Redis port and accept fake TCP connections.
+ """
+ base = super().default_config()
+ base["redis"] = {"enabled": True}
+ return base
+
def setUp(self):
super().setUp()
# build a replication server
- self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()
# Fake in memory Redis server that servers can connect to.
@@ -247,15 +266,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# handling inbound HTTP requests to that instance.
self._hs_to_site = {self.hs: self.site}
- if self.hs.config.redis.redis_enabled:
- # Handle attempts to connect to fake redis server.
- self.reactor.add_tcp_client_callback(
- "localhost",
- 6379,
- self.connect_any_redis_attempts,
- )
+ # Handle attempts to connect to fake redis server.
+ self.reactor.add_tcp_client_callback(
+ "localhost",
+ 6379,
+ self.connect_any_redis_attempts,
+ )
- self.hs.get_replication_command_handler().start_replication(self.hs)
+ self.hs.get_replication_command_handler().start_replication(self.hs)
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
@@ -339,27 +357,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
store = worker_hs.get_datastores().main
store.db_pool._db_pool = self.database_pool._db_pool
- # Set up TCP replication between master and the new worker if we don't
- # have Redis support enabled.
- if not worker_hs.config.redis.redis_enabled:
- repl_handler = ReplicationCommandHandler(worker_hs)
- client = ClientReplicationStreamProtocol(
- worker_hs,
- "client",
- "test",
- self.clock,
- repl_handler,
- )
- server = self.server_factory.buildProtocol(
- IPv4Address("TCP", "127.0.0.1", 0)
- )
-
- client_transport = FakeTransport(server, self.reactor)
- client.makeConnection(client_transport)
-
- server_transport = FakeTransport(client, self.reactor)
- server.makeConnection(server_transport)
-
# Set up a resource for the worker
resource = ReplicationRestResource(worker_hs)
@@ -378,8 +375,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
reactor=self.reactor,
)
- if worker_hs.config.redis.redis_enabled:
- worker_hs.get_replication_command_handler().start_replication(worker_hs)
+ worker_hs.get_replication_command_handler().start_replication(worker_hs)
return worker_hs
@@ -582,27 +578,3 @@ class FakeRedisPubSubProtocol(Protocol):
def connectionLost(self, reason):
self._server.remove_subscriber(self)
-
-
-class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
- """
- A test case that enables Redis, providing a fake Redis server.
- """
-
- if not hiredis:
- skip = "Requires hiredis"
-
- if not USE_POSTGRES_FOR_TESTS:
- # Redis replication only takes place on Postgres
- skip = "Requires Postgres"
-
- def default_config(self) -> Dict[str, Any]:
- """
- Overrides the default config to enable Redis.
- Even if the test only uses make_worker_hs, the main process needs Redis
- enabled otherwise it won't create a Fake Redis server to listen on the
- Redis port and accept fake TCP connections.
- """
- base = super().default_config()
- base["redis"] = {"enabled": True}
- return base
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index 822a957c3a..936ab4504a 100644
--- a/tests/replication/http/test__base.py
+++ b/tests/replication/http/test__base.py
@@ -18,11 +18,12 @@ from typing import Tuple
from twisted.web.server import Request
from synapse.api.errors import Codes
-from synapse.http.server import JsonResource, cancellable
+from synapse.http.server import JsonResource
from synapse.replication.http import REPLICATION_PREFIX
from synapse.replication.http._base import ReplicationEndpoint
from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util.cancellation import cancellable
from tests import unittest
from tests.http.server._base import test_disconnect
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index e6a19eafd5..1e299d2d67 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.replication._base import BaseMultiWorkerStreamTestCase
-class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
+class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
def test_subscribed_to_enough_redis_channels(self) -> None:
# The default main process is subscribed to the USER_IP channel.
self.assertCountEqual(
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index a7ca68069e..541d390286 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -20,7 +20,6 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
-from tests.utils import USE_POSTGRES_FOR_TESTS
logger = logging.getLogger(__name__)
@@ -28,11 +27,6 @@ logger = logging.getLogger(__name__)
class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
"""Checks event persisting sharding works"""
- # Event persister sharding requires postgres (due to needing
- # `MultiWriterIdGenerator`).
- if not USE_POSTGRES_FOR_TESTS:
- skip = "Requires Postgres"
-
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -50,7 +44,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
def default_config(self):
conf = super().default_config()
- conf["redis"] = {"enabled": "true"}
conf["stream_writers"] = {"events": ["worker1", "worker2"]}
conf["instance_map"] = {
"worker1": {"host": "testserv", "port": 1001},
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index fd6da557c1..d156be82b0 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
+import time
import urllib.parse
from typing import List, Optional
from unittest.mock import Mock
@@ -22,10 +24,11 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import EventTypes, Membership, RoomTypes
from synapse.api.errors import Codes
-from synapse.handlers.pagination import PaginationHandler
+from synapse.handlers.pagination import PaginationHandler, PurgeStatus
from synapse.rest.client import directory, events, login, room
from synapse.server import HomeServer
from synapse.util import Clock
+from synapse.util.stringutils import random_string
from tests import unittest
@@ -1080,7 +1083,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
room_ids = []
for _ in range(total_rooms):
room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
+ self.admin_user,
+ tok=self.admin_user_tok,
+ is_public=True,
)
room_ids.append(room_id)
@@ -1119,8 +1124,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("version", r)
self.assertIn("creator", r)
self.assertIn("encryption", r)
- self.assertIn("federatable", r)
- self.assertIn("public", r)
+ self.assertIs(r["federatable"], True)
+ self.assertIs(r["public"], True)
self.assertIn("join_rules", r)
self.assertIn("guest_access", r)
self.assertIn("history_visibility", r)
@@ -1587,8 +1592,12 @@ class RoomTestCase(unittest.HomeserverTestCase):
def test_single_room(self) -> None:
"""Test that a single room can be requested correctly"""
# Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_1 = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok, is_public=True
+ )
+ room_id_2 = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok, is_public=False
+ )
room_name_1 = "something"
room_name_2 = "else"
@@ -1634,7 +1643,10 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("state_events", channel.json_body)
self.assertIn("room_type", channel.json_body)
self.assertIn("forgotten", channel.json_body)
+
self.assertEqual(room_id_1, channel.json_body["room_id"])
+ self.assertIs(True, channel.json_body["federatable"])
+ self.assertIs(True, channel.json_body["public"])
def test_single_room_devices(self) -> None:
"""Test that `joined_local_devices` can be requested correctly"""
@@ -1784,6 +1796,159 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+class RoomMessagesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.user = self.register_user("foo", "pass")
+ self.user_tok = self.login("foo", "pass")
+ self.room_id = self.helper.create_room_as(self.user, tok=self.user_tok)
+
+ def test_timestamp_to_event(self) -> None:
+ """Test that providing the current timestamp can get the last event."""
+ self.helper.send(self.room_id, body="message 1", tok=self.user_tok)
+ second_event_id = self.helper.send(
+ self.room_id, body="message 2", tok=self.user_tok
+ )["event_id"]
+ ts = str(round(time.time() * 1000))
+
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/timestamp_to_event?dir=b&ts=%s"
+ % (self.room_id, ts),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ self.assertIn("event_id", channel.json_body)
+ self.assertEqual(second_event_id, channel.json_body["event_id"])
+
+ def test_topo_token_is_accepted(self) -> None:
+ """Test Topo Token is accepted."""
+ token = "t1-0_0_0_0_0_0_0_0_0"
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ self.assertIn("start", channel.json_body)
+ self.assertEqual(token, channel.json_body["start"])
+ self.assertIn("chunk", channel.json_body)
+ self.assertIn("end", channel.json_body)
+
+ def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
+ """Test that stream token is accepted for forward pagination."""
+ token = "s0_0_0_0_0_0_0_0_0"
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ self.assertIn("start", channel.json_body)
+ self.assertEqual(token, channel.json_body["start"])
+ self.assertIn("chunk", channel.json_body)
+ self.assertIn("end", channel.json_body)
+
+ def test_room_messages_purge(self) -> None:
+ """Test room messages can be retrieved by an admin that isn't in the room."""
+ store = self.hs.get_datastores().main
+ pagination_handler = self.hs.get_pagination_handler()
+
+ # Send a first message in the room, which will be removed by the purge.
+ first_event_id = self.helper.send(
+ self.room_id, body="message 1", tok=self.user_tok
+ )["event_id"]
+ first_token = self.get_success(
+ store.get_topological_token_for_event(first_event_id)
+ )
+ first_token_str = self.get_success(first_token.to_string(store))
+
+ # Send a second message in the room, which won't be removed, and which we'll
+ # use as the marker to purge events before.
+ second_event_id = self.helper.send(
+ self.room_id, body="message 2", tok=self.user_tok
+ )["event_id"]
+ second_token = self.get_success(
+ store.get_topological_token_for_event(second_event_id)
+ )
+ second_token_str = self.get_success(second_token.to_string(store))
+
+ # Send a third event in the room to ensure we don't fall under any edge case
+ # due to our marker being the latest forward extremity in the room.
+ self.helper.send(self.room_id, body="message 3", tok=self.user_tok)
+
+ # Check that we get the first and second message when querying /messages.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?from=%s&dir=b&filter=%s"
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
+
+ # Purge every event before the second event.
+ purge_id = random_string(16)
+ pagination_handler._purges_by_id[purge_id] = PurgeStatus()
+ self.get_success(
+ pagination_handler._purge_history(
+ purge_id=purge_id,
+ room_id=self.room_id,
+ token=second_token_str,
+ delete_local_events=True,
+ )
+ )
+
+ # Check that we only get the second message through /message now that the first
+ # has been purged.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?from=%s&dir=b&filter=%s"
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 1, [event["content"] for event in chunk])
+
+ # Check that we get no event, but also no error, when querying /messages with
+ # the token that was pointing at the first event, because we don't have it
+ # anymore.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?from=%s&dir=b&filter=%s"
+ % (
+ self.room_id,
+ first_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
+
+
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index 81e125e27d..a2f347f666 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -159,6 +159,62 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("'msgtype' not in content", channel.json_body["error"])
+ @override_config(
+ {
+ "server_notices": {
+ "system_mxid_localpart": "notices",
+ "system_mxid_avatar_url": "somthingwrong",
+ },
+ "max_avatar_size": "10M",
+ }
+ )
+ def test_invalid_avatar_url(self) -> None:
+ """If avatar url in homeserver.yaml is invalid and
+ "check avatar size and mime type" is set, an error is returned.
+ TODO: Should be checked when reading the configuration."""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg"},
+ },
+ )
+
+ self.assertEqual(500, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ @override_config(
+ {
+ "server_notices": {
+ "system_mxid_localpart": "notices",
+ "system_mxid_display_name": "test display name",
+ "system_mxid_avatar_url": None,
+ },
+ "max_avatar_size": "10M",
+ }
+ )
+ def test_displayname_is_set_avatar_is_none(self) -> None:
+ """
+ Tests that sending a server notices is successfully,
+ if a display_name is set, avatar_url is `None` and
+ "check avatar size and mime type" is set.
+ """
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ self._check_invite_and_join_status(self.other_user, 1, 0)
+
def test_server_notice_disabled(self) -> None:
"""Tests that server returns error if server notice is disabled"""
channel = self.make_request(
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 411e4ec005..ec5ccf6fca 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1,4 +1,4 @@
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2018-2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -904,6 +904,96 @@ class UsersListTestCase(unittest.HomeserverTestCase):
)
+class UserDevicesTestCase(unittest.HomeserverTestCase):
+ """
+ Tests user device management-related Admin APIs.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ # Set up an Admin user to query the Admin API with.
+ self.admin_user_id = self.register_user("admin", "pass", admin=True)
+ self.admin_user_token = self.login("admin", "pass")
+
+ # Set up a test user to query the devices of.
+ self.other_user_device_id = "TESTDEVICEID"
+ self.other_user_device_display_name = "My Test Device"
+ self.other_user_client_ip = "1.2.3.4"
+ self.other_user_user_agent = "EquestriaTechnology/123.0"
+
+ self.other_user_id = self.register_user("user", "pass", displayname="User1")
+ self.other_user_token = self.login(
+ "user",
+ "pass",
+ device_id=self.other_user_device_id,
+ additional_request_fields={
+ "initial_device_display_name": self.other_user_device_display_name,
+ },
+ )
+
+ # Have the "other user" make a request so that the "last_seen_*" fields are
+ # populated in the tests below.
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/sync",
+ access_token=self.other_user_token,
+ client_ip=self.other_user_client_ip,
+ custom_headers=[
+ ("User-Agent", self.other_user_user_agent),
+ ],
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ def test_list_user_devices(self) -> None:
+ """
+ Tests that a user's devices and attributes are listed correctly via the Admin API.
+ """
+ # Request all devices of "other user"
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v2/users/{self.other_user_id}/devices",
+ access_token=self.admin_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Double-check we got the single device expected
+ user_devices = channel.json_body["devices"]
+ self.assertEqual(len(user_devices), 1)
+ self.assertEqual(channel.json_body["total"], 1)
+
+ # Check that all the attributes of the device reported are as expected.
+ self._validate_attributes_of_device_response(user_devices[0])
+
+ # Request just a single device for "other user" by its ID
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v2/users/{self.other_user_id}/devices/"
+ f"{self.other_user_device_id}",
+ access_token=self.admin_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check that all the attributes of the device reported are as expected.
+ self._validate_attributes_of_device_response(channel.json_body)
+
+ def _validate_attributes_of_device_response(self, response: JsonDict) -> None:
+ # Check that all device expected attributes are present
+ self.assertEqual(response["user_id"], self.other_user_id)
+ self.assertEqual(response["device_id"], self.other_user_device_id)
+ self.assertEqual(response["display_name"], self.other_user_device_display_name)
+ self.assertEqual(response["last_seen_ip"], self.other_user_client_ip)
+ self.assertEqual(response["last_seen_user_agent"], self.other_user_user_agent)
+ self.assertIsInstance(response["last_seen_ts"], int)
+ self.assertGreater(response["last_seen_ts"], 0)
+
+
class DeactivateAccountTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -2490,6 +2580,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertIn("appservice_id", content)
self.assertIn("consent_server_notice_sent", content)
self.assertIn("consent_version", content)
+ self.assertIn("consent_ts", content)
self.assertIn("external_ids", content)
# This key was removed intentionally. Ensure it is not accidentally re-included.
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index dc17c9d113..b0c8215744 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -25,7 +25,6 @@ from tests import unittest
class IdentityTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -33,7 +32,6 @@ class IdentityTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
@@ -54,6 +52,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
"id_server": "testis",
"medium": "email",
"address": "test@example.com",
+ "id_access_token": tok,
}
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
channel = self.make_request(
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index bbc8e74243..741fecea77 100644
--- a/tests/rest/client/test_keys.py
+++ b/tests/rest/client/test_keys.py
@@ -19,6 +19,7 @@ from synapse.rest import admin
from synapse.rest.client import keys, login
from tests import unittest
+from tests.http.server._base import make_request_with_cancellation_test
class KeyQueryTestCase(unittest.HomeserverTestCase):
@@ -89,3 +90,31 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
Codes.BAD_JSON,
channel.result,
)
+
+ def test_key_query_cancellation(self) -> None:
+ """
+ Tests that /keys/query is cancellable and does not swallow the
+ CancelledError.
+ """
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+
+ bob = self.register_user("bob", "uncle")
+
+ channel = make_request_with_cancellation_test(
+ "test_key_query_cancellation",
+ self.reactor,
+ self.site,
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ # Empty list means we request keys for all bob's devices
+ bob: [],
+ },
+ },
+ token=alice_token,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertIn(bob, channel.json_body["device_keys"])
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index ab4277dd31..b781875d52 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -586,9 +586,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"require_at_registration": True,
},
"account_threepid_delegates": {
- "email": "https://id_server",
"msisdn": "https://id_server",
},
+ "email": {"notif_from": "Synapse <synapse@example.com>"},
}
)
def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index d589f07314..651f4f415d 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -999,7 +999,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6)
+ self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
def test_annotation_to_annotation(self) -> None:
"""Any relation to an annotation should be ignored."""
@@ -1035,7 +1035,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
def test_thread(self) -> None:
"""
@@ -1080,21 +1080,21 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated.
- self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
+ self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9)
# The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated.
#
# Note that this re-uses some cached values, so the total number of
# queries is much smaller.
self._test_bundled_aggregations(
- RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
+ RelationTypes.THREAD, _gen_assert(True), 3, access_token=self.user2_token
)
# A user with no interactions with the thread: they have not participated.
user3_id, user3_token = self._create_user("charlie")
self.helper.join(self.room, user=user3_id, tok=user3_token)
self._test_bundled_aggregations(
- RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
+ RelationTypes.THREAD, _gen_assert(False), 3, access_token=user3_token
)
def test_thread_with_bundled_aggregations_for_latest(self) -> None:
@@ -1142,7 +1142,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations["latest_event"].get("unsigned"),
)
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
def test_nested_thread(self) -> None:
"""
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index ac9c113354..9c8c1889d3 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from synapse.visibility import filter_events_for_client
@@ -188,7 +188,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
message_handler.get_room_data(
- self.user_id, room_id, EventTypes.Create, state_key=""
+ create_requester(self.user_id), room_id, EventTypes.Create, state_key=""
)
)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index aa2f578441..c7eb88d33f 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -3461,3 +3461,21 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Also check that it stopped before calling _make_and_store_3pid_invite.
make_invite_mock.assert_called_once()
+
+ def test_400_missing_param_without_id_access_token(self) -> None:
+ """
+ Test that a 3pid invite request returns 400 M_MISSING_PARAM
+ if we do not include id_access_token.
+ """
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "medium": "email",
+ "address": "teresa@example.com",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 400)
+ self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM")
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index d9bd8c4a28..c807a37bc2 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -26,7 +26,7 @@ from synapse.rest.client import (
room_upgrade_rest_servlet,
)
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -97,7 +97,12 @@ class RoomTestCase(_ShadowBannedBase):
channel = self.make_request(
"POST",
"/rooms/%s/invite" % (room_id,),
- {"id_server": "test", "medium": "email", "address": "test@test.test"},
+ {
+ "id_server": "test",
+ "medium": "email",
+ "address": "test@test.test",
+ "id_access_token": "anytoken",
+ },
access_token=self.banned_access_token,
)
self.assertEqual(200, channel.code, channel.result)
@@ -275,7 +280,7 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id,
+ create_requester(self.banned_user_id),
room_id,
"m.room.member",
self.banned_user_id,
@@ -310,7 +315,7 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id,
+ create_requester(self.banned_user_id),
room_id,
"m.room.member",
self.banned_user_id,
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index de0dec8539..0af643ecd9 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -391,7 +391,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
- config["experimental_features"] = {"msc2285_enabled": True}
return self.setup_test_homeserver(config=config)
@@ -413,17 +412,14 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Join the second user
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
- @parameterized.expand(
- [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
- )
- def test_private_read_receipts(self, receipt_type: str) -> None:
+ def test_private_read_receipts(self) -> None:
# Send a message as the first user
res = self.helper.send(self.room_id, body="hello", tok=self.tok)
# Send a private read receipt to tell the server the first user's message was read
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
{},
access_token=self.tok2,
)
@@ -432,10 +428,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that the first user can't see the other user's private read receipt
self.assertIsNone(self._get_read_receipt())
- @parameterized.expand(
- [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
- )
- def test_public_receipt_can_override_private(self, receipt_type: str) -> None:
+ def test_public_receipt_can_override_private(self) -> None:
"""
Sending a public read receipt to the same event which has a private read
receipt should cause that receipt to become public.
@@ -446,7 +439,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Send a private read receipt
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
{},
access_token=self.tok2,
)
@@ -465,10 +458,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that we did override the private read receipt
self.assertNotEqual(self._get_read_receipt(), None)
- @parameterized.expand(
- [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
- )
- def test_private_receipt_cannot_override_public(self, receipt_type: str) -> None:
+ def test_private_receipt_cannot_override_public(self) -> None:
"""
Sending a private read receipt to the same event which has a public read
receipt should cause no change.
@@ -489,7 +479,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Send a private read receipt
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
{},
access_token=self.tok2,
)
@@ -554,7 +544,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
config = super().default_config()
config["experimental_features"] = {
"msc2654_enabled": True,
- "msc2285_enabled": True,
}
return config
@@ -601,10 +590,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
tok=self.tok,
)
- @parameterized.expand(
- [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
- )
- def test_unread_counts(self, receipt_type: str) -> None:
+ def test_unread_counts(self) -> None:
"""Tests that /sync returns the right value for the unread count (MSC2654)."""
# Check that our own messages don't increase the unread count.
@@ -638,7 +624,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Send a read receipt to tell the server we've read the latest event.
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
{},
access_token=self.tok,
)
@@ -726,7 +712,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}",
{},
access_token=self.tok,
)
@@ -738,7 +724,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
[
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
- ReceiptTypes.UNSTABLE_READ_PRIVATE,
]
)
def test_read_receipts_only_go_down(self, receipt_type: str) -> None:
@@ -752,7 +737,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Read last event
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}",
{},
access_token=self.tok,
)
@@ -763,7 +748,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# read receipt go up to an older event
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/{receipt_type}/{res1['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res1['event_id']}",
{},
access_token=self.tok,
)
diff --git a/tests/server.py b/tests/server.py
index 9689e6a0cd..c447d5e4c4 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -61,6 +61,10 @@ from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
+from synapse.events.presence_router import load_legacy_presence_router
+from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.events.third_party_rules import load_legacy_third_party_event_rules
+from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.http.site import SynapseRequest
from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer
@@ -913,4 +917,14 @@ def setup_test_homeserver(
# Make the threadpool and database transactions synchronous for testing.
_make_test_homeserver_synchronous(hs)
+ # Load any configured modules into the homeserver
+ module_api = hs.get_module_api()
+ for module, config in hs.config.modules.loaded_modules:
+ module(config=config, api=module_api)
+
+ load_legacy_spam_checkers(hs)
+ load_legacy_third_party_event_rules(hs)
+ load_legacy_presence_router(hs)
+ load_legacy_password_auth_providers(hs)
+
return hs
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index e07ae78fc4..bf403045e9 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -11,16 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
from synapse.api.errors import ResourceLimitError
from synapse.rest import admin
from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
)
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -52,7 +55,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return config
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.server_notices_sender = self.hs.get_server_notices_sender()
# relying on [1] is far from ideal, but the only case where
@@ -251,7 +254,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
c["admin_contact"] = "mailto:user@test.com"
return c
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.server_notices_sender = self.hs.get_server_notices_sender()
self.server_notices_manager = self.hs.get_server_notices_manager()
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 46d829b062..67401272ac 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -254,7 +254,7 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase):
"room_id": self.room_id,
"json": json.dumps(event_json),
"internal_metadata": "{}",
- "format_version": EventFormatVersions.V3,
+ "format_version": EventFormatVersions.ROOM_V4_PLUS,
},
)
)
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index cce8e75c74..40e58f8199 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -54,7 +54,6 @@ class SQLBaseStoreTestCase(unittest.TestCase):
sqlite_config = {"name": "sqlite3"}
engine = create_engine(sqlite_config)
fake_engine = Mock(wraps=engine)
- fake_engine.can_native_upsert = False
fake_engine.in_transaction.return_value = False
db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index d92a9ac5b7..a6679e1312 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -513,7 +513,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def prev_event_format(prev_event_id: str) -> Union[Tuple[str, dict], str]:
"""Account for differences in prev_events format across room versions"""
- if room_version.event_format == EventFormatVersions.V1:
+ if room_version.event_format == EventFormatVersions.ROOM_V1_V2:
return prev_event_id, {}
return prev_event_id
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 62fd4aeb2f..fc43d7edd1 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -67,9 +67,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
last_event_id: str
- def _assert_counts(
- noitf_count: int, unread_count: int, highlight_count: int
- ) -> None:
+ def _assert_counts(noitf_count: int, highlight_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
@@ -82,7 +80,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
counts,
NotifCounts(
notify_count=noitf_count,
- unread_count=unread_count,
+ unread_count=0,
highlight_count=highlight_count,
),
)
@@ -112,27 +110,27 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
)
)
- _assert_counts(0, 0, 0)
+ _assert_counts(0, 0)
_create_event()
- _assert_counts(1, 1, 0)
+ _assert_counts(1, 0)
_rotate()
- _assert_counts(1, 1, 0)
+ _assert_counts(1, 0)
event_id = _create_event()
- _assert_counts(2, 2, 0)
+ _assert_counts(2, 0)
_rotate()
- _assert_counts(2, 2, 0)
+ _assert_counts(2, 0)
_create_event()
_mark_read(event_id)
- _assert_counts(1, 1, 0)
+ _assert_counts(1, 0)
_mark_read(last_event_id)
- _assert_counts(0, 0, 0)
+ _assert_counts(0, 0)
_create_event()
_rotate()
- _assert_counts(1, 1, 0)
+ _assert_counts(1, 0)
# Delete old event push actions, this should not affect the (summarised) count.
#
@@ -151,35 +149,35 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
)
)
self.assertEqual(result, [])
- _assert_counts(1, 1, 0)
+ _assert_counts(1, 0)
_mark_read(last_event_id)
- _assert_counts(0, 0, 0)
+ _assert_counts(0, 0)
event_id = _create_event(True)
- _assert_counts(1, 1, 1)
+ _assert_counts(1, 1)
_rotate()
- _assert_counts(1, 1, 1)
+ _assert_counts(1, 1)
# Check that adding another notification and rotating after highlight
# works.
_create_event()
_rotate()
- _assert_counts(2, 2, 1)
+ _assert_counts(2, 1)
# Check that sending read receipts at different points results in the
# right counts.
_mark_read(event_id)
- _assert_counts(1, 1, 0)
+ _assert_counts(1, 0)
_mark_read(last_event_id)
- _assert_counts(0, 0, 0)
+ _assert_counts(0, 0)
_create_event(True)
- _assert_counts(1, 1, 1)
+ _assert_counts(1, 1)
_mark_read(last_event_id)
- _assert_counts(0, 0, 0)
+ _assert_counts(0, 0)
_rotate()
- _assert_counts(0, 0, 0)
+ _assert_counts(0, 0)
def test_find_first_stream_ordering_after_ts(self) -> None:
def add_event(so: int, ts: int) -> None:
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index 191c957fb5..c89bfff241 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from parameterized import parameterized
from synapse.api.constants import ReceiptTypes
from synapse.types import UserID, create_requester
@@ -92,7 +91,6 @@ class ReceiptTestCase(HomeserverTestCase):
[
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
- ReceiptTypes.UNSTABLE_READ_PRIVATE,
],
)
)
@@ -104,7 +102,6 @@ class ReceiptTestCase(HomeserverTestCase):
[
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
- ReceiptTypes.UNSTABLE_READ_PRIVATE,
],
)
)
@@ -117,16 +114,12 @@ class ReceiptTestCase(HomeserverTestCase):
[
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
- ReceiptTypes.UNSTABLE_READ_PRIVATE,
],
)
)
self.assertEqual(res, None)
- @parameterized.expand(
- [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
- )
- def test_get_receipts_for_user(self, receipt_type: str) -> None:
+ def test_get_receipts_for_user(self) -> None:
# Send some events into the first room
event1_1_id = self.create_and_send_event(
self.room_id1, UserID.from_string(OTHER_USER_ID)
@@ -144,14 +137,14 @@ class ReceiptTestCase(HomeserverTestCase):
# Send private read receipt for the second event
self.get_success(
self.store.insert_receipt(
- self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {}
+ self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
)
)
# Test we get the latest event when we want both private and public receipts
res = self.get_success(
self.store.get_receipts_for_user(
- OUR_USER_ID, [ReceiptTypes.READ, receipt_type]
+ OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
)
self.assertEqual(res, {self.room_id1: event1_2_id})
@@ -164,7 +157,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Test we get the latest event when we want only the public receipt
res = self.get_success(
- self.store.get_receipts_for_user(OUR_USER_ID, [receipt_type])
+ self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ_PRIVATE])
)
self.assertEqual(res, {self.room_id1: event1_2_id})
@@ -187,20 +180,17 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns
self.get_success(
self.store.insert_receipt(
- self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {}
+ self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
)
)
res = self.get_success(
self.store.get_receipts_for_user(
- OUR_USER_ID, [ReceiptTypes.READ, receipt_type]
+ OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
)
self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id})
- @parameterized.expand(
- [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
- )
- def test_get_last_receipt_event_id_for_user(self, receipt_type: str) -> None:
+ def test_get_last_receipt_event_id_for_user(self) -> None:
# Send some events into the first room
event1_1_id = self.create_and_send_event(
self.room_id1, UserID.from_string(OTHER_USER_ID)
@@ -218,7 +208,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Send private read receipt for the second event
self.get_success(
self.store.insert_receipt(
- self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {}
+ self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
)
)
@@ -227,7 +217,7 @@ class ReceiptTestCase(HomeserverTestCase):
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID,
self.room_id1,
- [ReceiptTypes.READ, receipt_type],
+ [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
)
)
self.assertEqual(res, event1_2_id)
@@ -243,7 +233,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Test we get the latest event when we want only the private receipt
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID, self.room_id1, [receipt_type]
+ OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE]
)
)
self.assertEqual(res, event1_2_id)
@@ -269,14 +259,14 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns
self.get_success(
self.store.insert_receipt(
- self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {}
+ self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
)
)
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID,
self.room_id2,
- [ReceiptTypes.READ, receipt_type],
+ [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
)
)
self.assertEqual(res, event2_1_id)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index a49ac1525e..853a93afab 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -11,15 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class RegistrationStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = "@my-user:test"
@@ -27,7 +30,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
self.pwhash = "{xx1}123456789"
self.device_id = "akgjhdjklgshg"
- def test_register(self):
+ def test_register(self) -> None:
self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.assertEqual(
@@ -38,6 +41,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
"admin": 0,
"is_guest": 0,
"consent_version": None,
+ "consent_ts": None,
"consent_server_notice_sent": None,
"appservice_id": None,
"creation_ts": 0,
@@ -48,7 +52,20 @@ class RegistrationStoreTestCase(HomeserverTestCase):
(self.get_success(self.store.get_user_by_id(self.user_id))),
)
- def test_add_tokens(self):
+ def test_consent(self) -> None:
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
+ before_consent = self.clock.time_msec()
+ self.reactor.advance(5)
+ self.get_success(self.store.user_set_consent_version(self.user_id, "1"))
+ self.reactor.advance(5)
+
+ user = self.get_success(self.store.get_user_by_id(self.user_id))
+ assert user
+ self.assertEqual(user["consent_version"], "1")
+ self.assertGreater(user["consent_ts"], before_consent)
+ self.assertLess(user["consent_ts"], self.clock.time_msec())
+
+ def test_add_tokens(self) -> None:
self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.get_success(
self.store.add_access_token_to_user(
@@ -58,11 +75,12 @@ class RegistrationStoreTestCase(HomeserverTestCase):
result = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
+ assert result
self.assertEqual(result.user_id, self.user_id)
self.assertEqual(result.device_id, self.device_id)
self.assertIsNotNone(result.token_id)
- def test_user_delete_access_tokens(self):
+ def test_user_delete_access_tokens(self) -> None:
# add some tokens
self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.get_success(
@@ -87,6 +105,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
# check the one not associated with the device was not deleted
user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
+ assert user
self.assertEqual(self.user_id, user.user_id)
# now delete the rest
@@ -95,11 +114,11 @@ class RegistrationStoreTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.assertIsNone(user, "access token was not deleted without device_id")
- def test_is_support_user(self):
+ def test_is_support_user(self) -> None:
TEST_USER = "@test:test"
SUPPORT_USER = "@support:test"
- res = self.get_success(self.store.is_support_user(None))
+ res = self.get_success(self.store.is_support_user(None)) # type: ignore[arg-type]
self.assertFalse(res)
self.get_success(
self.store.register_user(user_id=TEST_USER, password_hash=None)
@@ -115,7 +134,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
res = self.get_success(self.store.is_support_user(SUPPORT_USER))
self.assertTrue(res)
- def test_3pid_inhibit_invalid_validation_session_error(self):
+ def test_3pid_inhibit_invalid_validation_session_error(self) -> None:
"""Tests that enabling the configuration option to inhibit 3PID errors on
/requestToken also inhibits validation errors caused by an unknown session ID.
"""
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index ceec690285..8794401823 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -158,7 +158,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# Check that alice's display name is now None
self.assertEqual(row[0]["display_name"], None)
- def test_room_is_locally_forgotten(self):
+ def test_room_is_locally_forgotten(self) -> None:
"""Test that when the last local user has forgotten a room it is known as forgotten."""
# join two local and one remote user
self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
@@ -199,7 +199,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.is_locally_forgotten_room(self.room))
)
- def test_join_locally_forgotten_room(self):
+ def test_join_locally_forgotten_room(self) -> None:
"""Tests if a user joins a forgotten room the room is not forgotten anymore."""
self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
self.assertFalse(
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index e42d7b9ba0..f4d9fba0a1 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -821,7 +821,7 @@ def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase:
def _build_auth_dict_for_room_version(
room_version: RoomVersion, auth_events: Iterable[EventBase]
) -> List:
- if room_version.event_format == EventFormatVersions.V1:
+ if room_version.event_format == EventFormatVersions.ROOM_V1_V2:
return [(e.event_id, "not_used") for e in auth_events]
else:
return [e.event_id for e in auth_events]
@@ -871,7 +871,7 @@ event_count = 0
def _maybe_get_event_id_dict_for_room_version(room_version: RoomVersion) -> dict:
"""If this room version needs it, generate an event id"""
- if room_version.event_format != EventFormatVersions.V1:
+ if room_version.event_format != EventFormatVersions.ROOM_V1_V2:
return {}
global event_count
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index b4574b2ffe..1a70eddc9b 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -12,7 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+try:
+ from importlib import metadata
+except ImportError:
+ import importlib_metadata as metadata # type: ignore[no-redef]
+from unittest.mock import patch
+
+from pkg_resources import parse_version
+
+from synapse.app._base import _set_prometheus_client_use_created_metrics
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
from synapse.util.caches.deferred_cache import DeferredCache
@@ -162,3 +171,30 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
self.assertEqual(items["synapse_util_caches_cache_size"], "1.0")
self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0")
+
+
+class PrometheusMetricsHackTestCase(unittest.HomeserverTestCase):
+ if parse_version(metadata.version("prometheus_client")) < parse_version("0.14.0"):
+ skip = "prometheus-client too old"
+
+ def test_created_metrics_disabled(self) -> None:
+ """
+ Tests that a brittle hack, to disable `_created` metrics, works.
+ This involves poking at the internals of prometheus-client.
+ It's not the end of the world if this doesn't work.
+
+ This test gives us a way to notice if prometheus-client changes
+ their internals.
+ """
+ import prometheus_client.metrics
+
+ PRIVATE_FLAG_NAME = "_use_created"
+
+ # By default, the pesky `_created` metrics are enabled.
+ # Check this assumption is still valid.
+ self.assertTrue(getattr(prometheus_client.metrics, PRIVATE_FLAG_NAME))
+
+ with patch("prometheus_client.metrics") as mock:
+ setattr(mock, PRIVATE_FLAG_NAME, True)
+ _set_prometheus_client_use_created_metrics(False)
+ self.assertFalse(getattr(mock, PRIVATE_FLAG_NAME, False))
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index b01cae6e5d..cc1a98f1c4 100644
--- a/tests/test_phone_home.py
+++ b/tests/test_phone_home.py
@@ -15,8 +15,14 @@
import resource
from unittest import mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.app.phone_stats_home import phone_stats_home
+from synapse.rest import admin
+from synapse.rest.client import login, sync
+from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -47,5 +53,43 @@ class PhoneHomeStatsTestCase(HomeserverTestCase):
stats: JsonDict = {}
self.reactor.advance(1)
# `old_resource` has type `Mock` instead of `struct_rusage`
- self.get_success(phone_stats_home(self.hs, stats, past_stats)) # type: ignore[arg-type]
+ self.get_success(
+ phone_stats_home(self.hs, stats, past_stats) # type: ignore[arg-type]
+ )
self.assertApproximates(stats["cpu_average"], 100, tolerance=2.5)
+
+
+class CommonMetricsTestCase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.metrics_manager = hs.get_common_usage_metrics_manager()
+ self.get_success(self.metrics_manager.setup())
+
+ def test_dau(self) -> None:
+ """Tests that the daily active users count is correctly updated."""
+ self._assert_metric_value("daily_active_users", 0)
+
+ self.register_user("user", "password")
+ tok = self.login("user", "password")
+ self.make_request("GET", "/sync", access_token=tok)
+
+ self.pump(1)
+
+ self._assert_metric_value("daily_active_users", 1)
+
+ def _assert_metric_value(self, metric_name: str, expected: int) -> None:
+ """Compare the given value to the current value of the common usage metric with
+ the given name.
+
+ Args:
+ metric_name: The metric to look up.
+ expected: Expected value for this metric.
+ """
+ metrics = self.get_success(self.metrics_manager.get_metrics())
+ value = getattr(metrics, metric_name)
+ self.assertEqual(value, expected)
diff --git a/tests/test_rust.py b/tests/test_rust.py
new file mode 100644
index 0000000000..55d8b6b28c
--- /dev/null
+++ b/tests/test_rust.py
@@ -0,0 +1,11 @@
+from synapse.synapse_rust import sum_as_string
+
+from tests import unittest
+
+
+class RustTestCase(unittest.TestCase):
+ """Basic tests to ensure that we can call into Rust code."""
+
+ def test_basic(self):
+ result = sum_as_string(1, 2)
+ self.assertEqual("3", result)
diff --git a/tests/test_server.py b/tests/test_server.py
index d2b2d8344a..7c66448245 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -26,12 +26,12 @@ from synapse.http.server import (
DirectServeJsonResource,
JsonResource,
OptionsResource,
- cancellable,
)
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict
from synapse.util import Clock
+from synapse.util.cancellation import cancellable
from tests import unittest
from tests.http.server._base import test_disconnect
@@ -228,7 +228,7 @@ class OptionsResourceTests(unittest.TestCase):
site = SynapseSite(
"test",
"site_tag",
- parse_listener_def({"type": "http", "port": 0}),
+ parse_listener_def(0, {"type": "http", "port": 0}),
self.resource,
"1.0",
max_request_body_size=4096,
diff --git a/tests/test_types.py b/tests/test_types.py
index d8d82a517e..1111169384 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -13,11 +13,35 @@
# limitations under the License.
from synapse.api.errors import SynapseError
-from synapse.types import RoomAlias, UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ RoomAlias,
+ UserID,
+ get_domain_from_id,
+ get_localpart_from_id,
+ map_username_to_mxid_localpart,
+)
from tests import unittest
+class IsMineIDTests(unittest.HomeserverTestCase):
+ def test_is_mine_id(self) -> None:
+ self.assertTrue(self.hs.is_mine_id("@user:test"))
+ self.assertTrue(self.hs.is_mine_id("#room:test"))
+ self.assertTrue(self.hs.is_mine_id("invalid:test"))
+
+ self.assertFalse(self.hs.is_mine_id("@user:test\0"))
+ self.assertFalse(self.hs.is_mine_id("@user"))
+
+ def test_two_colons(self) -> None:
+ """Test handling of IDs containing more than one colon."""
+ # The domain starts after the first colon.
+ # These functions must interpret things consistently.
+ self.assertFalse(self.hs.is_mine_id("@user:test:test"))
+ self.assertEqual("user", get_localpart_from_id("@user:test:test"))
+ self.assertEqual("test:test", get_domain_from_id("@user:test:test"))
+
+
class UserIDTestCase(unittest.HomeserverTestCase):
def test_parse(self):
user = UserID.from_string("@1234abcd:test")
diff --git a/tests/unittest.py b/tests/unittest.py
index bec4a3d023..975b0a23a7 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -677,14 +677,29 @@ class HomeserverTestCase(TestCase):
username: str,
password: str,
device_id: Optional[str] = None,
+ additional_request_fields: Optional[Dict[str, str]] = None,
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
) -> str:
"""
Log in a user, and get an access token. Requires the Login API be registered.
+
+ Args:
+ username: The localpart to assign to the new user.
+ password: The password to assign to the new user.
+ device_id: An optional device ID to assign to the new device created during
+ login.
+ additional_request_fields: A dictionary containing any additional /login
+ request fields and their values.
+ custom_headers: Custom HTTP headers and values to add to the /login request.
+
+ Returns:
+ The newly registered user's Matrix ID.
"""
body = {"type": "m.login.password", "user": username, "password": password}
if device_id:
body["device_id"] = device_id
+ if additional_request_fields:
+ body.update(additional_request_fields)
channel = self.make_request(
"POST",
diff --git a/tests/utils.py b/tests/utils.py
index d2c6d1e852..65db437697 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -135,7 +135,6 @@ def default_config(
"enable_registration_captcha": False,
"macaroon_secret_key": "not even a little secret",
"password_providers": [],
- "worker_replication_url": "",
"worker_app": None,
"block_non_admin_invites": False,
"federation_domain_whitelist": None,
|