diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index dfcfaf79b6..e0f363555b 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -284,10 +284,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
TokenLookupResult(
user_id="@baldrick:matrix.org",
device_id="device",
+ token_id=5,
token_owner="@admin:matrix.org",
+ token_used=True,
)
)
self.store.insert_client_ip = simple_async_mock(None)
+ self.store.mark_access_token_as_used = simple_async_mock(None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
@@ -301,10 +304,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
TokenLookupResult(
user_id="@baldrick:matrix.org",
device_id="device",
+ token_id=5,
token_owner="@admin:matrix.org",
+ token_used=True,
)
)
self.store.insert_client_ip = simple_async_mock(None)
+ self.store.mark_access_token_as_used = simple_async_mock(None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
@@ -347,7 +353,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
serialized = macaroon.serialize()
user_info = self.get_success(self.auth.get_user_by_access_token(serialized))
- self.assertEqual(user_id, user_info.user_id)
+ self.assertEqual(user_id, user_info.user.to_string())
self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index a269c477fb..d5524d296e 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -35,6 +35,8 @@ def MockEvent(**kwargs):
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
+ if "content" not in kwargs:
+ kwargs["content"] = {}
return make_event_from_dict(kwargs)
@@ -44,19 +46,36 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.datastore = hs.get_datastores().main
def test_errors_on_invalid_filters(self):
+ # See USER_FILTER_SCHEMA for the filter schema.
invalid_filters = [
- {"boom": {}},
+ # `account_data` must be a dictionary
{"account_data": "Hello World"},
+ # `event_fields` entries must not contain backslashes
{"event_fields": [r"\\foo"]},
- {"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
+ # `event_format` must be "client" or "federation"
{"event_format": "other"},
+ # `not_rooms` must contain valid room IDs
{"room": {"not_rooms": ["#foo:pik-test"]}},
+ # `senders` must contain valid user IDs
{"presence": {"senders": ["@bar;pik.test.com"]}},
]
for filter in invalid_filters:
with self.assertRaises(SynapseError):
self.filtering.check_valid_filter(filter)
+ def test_ignores_unknown_filter_fields(self):
+ # For forward compatibility, we must ignore unknown filter fields.
+ # See USER_FILTER_SCHEMA for the filter schema.
+ filters = [
+ {"org.matrix.msc9999.future_option": True},
+ {"presence": {"org.matrix.msc9999.future_option": True}},
+ {"room": {"org.matrix.msc9999.future_option": True}},
+ {"room": {"timeline": {"org.matrix.msc9999.future_option": True}}},
+ ]
+ for filter in filters:
+ self.filtering.check_valid_filter(filter)
+ # Must not raise.
+
def test_valid_filters(self):
valid_filters = [
{
@@ -357,6 +376,66 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertTrue(Filter(self.hs, definition)._check(event))
+ @unittest.override_config({"experimental_features": {"msc3874_enabled": True}})
+ def test_filter_rel_type(self):
+ definition = {"org.matrix.msc3874.rel_types": ["m.thread"]}
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={},
+ )
+
+ self.assertFalse(Filter(self.hs, definition)._check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}},
+ )
+
+ self.assertFalse(Filter(self.hs, definition)._check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}},
+ )
+
+ self.assertTrue(Filter(self.hs, definition)._check(event))
+
+ @unittest.override_config({"experimental_features": {"msc3874_enabled": True}})
+ def test_filter_not_rel_type(self):
+ definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]}
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}},
+ )
+
+ self.assertFalse(Filter(self.hs, definition)._check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={},
+ )
+
+ self.assertTrue(Filter(self.hs, definition)._check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}},
+ )
+
+ self.assertTrue(Filter(self.hs, definition)._check(event))
+
def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
filter_id = self.get_success(
@@ -456,7 +535,6 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertEqual(filtered_room_ids, ["!allowed:example.com"])
- @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_filter_relations(self):
events = [
# An event without a relation.
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 264e101082..8d03da7f96 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -61,7 +61,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
}
# Listen with the config
- self.hs._listen_http(parse_listener_def(config))
+ self.hs._listen_http(parse_listener_def(0, config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
@@ -79,7 +79,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
self.assertEqual(channel.code, 401)
-@patch("synapse.app.homeserver.KeyApiV2Resource", new=Mock())
+@patch("synapse.app.homeserver.KeyResource", new=Mock())
class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
@@ -109,7 +109,7 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
}
# Listen with the config
- self.hs._listener_http(self.hs.config, parse_listener_def(config))
+ self.hs._listener_http(self.hs.config, parse_listener_def(0, config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py
index 532b676365..89ee79396f 100644
--- a/tests/appservice/test_api.py
+++ b/tests/appservice/test_api.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List, Mapping
+from typing import Any, List, Mapping, Sequence, Union
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
@@ -69,10 +69,16 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
self.request_url = None
- async def get_json(url: str, args: Mapping[Any, Any]) -> List[JsonDict]:
- if not args.get(b"access_token"):
+ async def get_json(
+ url: str,
+ args: Mapping[Any, Any],
+ headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
+ ) -> List[JsonDict]:
+ # Ensure the access token is passed as both a header and query arg.
+ if not headers.get("Authorization") or not args.get(b"access_token"):
raise RuntimeError("Access token not provided")
+ self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
self.assertEqual(args.get(b"access_token"), TOKEN)
self.request_url = url
if url == URL_USER:
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 3018d3fc6f..d4dccfc2f0 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -43,7 +43,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.store = Mock()
self.store.get_aliases_for_room = simple_async_mock([])
- self.store.get_users_in_room = simple_async_mock([])
+ self.store.get_local_users_in_room = simple_async_mock([])
@defer.inlineCallbacks
def test_regex_user_id_prefix_match(self):
@@ -129,7 +129,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.store.get_aliases_for_room = simple_async_mock(
["#irc_foobar:matrix.org", "#athing:matrix.org"]
)
- self.store.get_users_in_room = simple_async_mock([])
+ self.store.get_local_users_in_room = simple_async_mock([])
self.assertTrue(
(
yield defer.ensureDeferred(
@@ -184,7 +184,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.store.get_aliases_for_room = simple_async_mock(
["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
)
- self.store.get_users_in_room = simple_async_mock([])
+ self.store.get_local_users_in_room = simple_async_mock([])
self.assertFalse(
(
yield defer.ensureDeferred(
@@ -203,7 +203,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org"
self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"])
- self.store.get_users_in_room = simple_async_mock([])
+ self.store.get_local_users_in_room = simple_async_mock([])
self.assertTrue(
(
yield defer.ensureDeferred(
@@ -236,7 +236,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
# Note that @irc_fo:here is the AS user.
- self.store.get_users_in_room = simple_async_mock(
+ self.store.get_local_users_in_room = simple_async_mock(
["@alice:here", "@irc_fo:here", "@bob:here"]
)
self.store.get_aliases_for_room = simple_async_mock([])
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 820a1a54e2..63628aa6b0 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -469,6 +469,18 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
self.assertEqual(keys, {})
+ def test_keyid_containing_forward_slash(self) -> None:
+ """We should url-encode any url unsafe chars in key ids.
+
+ Detects https://github.com/matrix-org/synapse/issues/14488.
+ """
+ fetcher = ServerKeyFetcher(self.hs)
+ self.get_success(fetcher.get_keys("example.com", ["key/potato"], 0))
+
+ self.http_client.get_json.assert_called_once()
+ args, kwargs = self.http_client.get_json.call_args
+ self.assertEqual(kwargs["path"], "/_matrix/key/v2/server/key%2Fpotato")
+
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index ffc3012a86..685a9a6d52 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -141,10 +141,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
hs = self.setup_test_homeserver(
federation_transport_client=fed_transport_client,
)
- # Load the modules into the homeserver
- module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
load_legacy_presence_router(hs)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index c6dd99316a..9f1115dd23 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
@@ -51,7 +50,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
- self.assertEqual(HTTPStatus.OK, channel.code)
+ self.assertEqual(200, channel.code)
complexity = channel.json_body["v1"]
self.assertTrue(complexity > 0, complexity)
@@ -63,7 +62,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
- self.assertEqual(HTTPStatus.OK, channel.code)
+ self.assertEqual(200, channel.code)
complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23)
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index 50e376f695..e67f405826 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -12,25 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from unittest import mock
import twisted.web.client
from twisted.internet import defer
-from twisted.internet.protocol import Protocol
-from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.types import JsonDict
from synapse.util import Clock
+from tests.test_utils import FakeResponse, event_injection
from tests.unittest import FederatingHomeserverTestCase
class FederationClientTest(FederatingHomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
super().prepare(reactor, clock, homeserver)
@@ -89,8 +94,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
# mock up the response, and have the agent return it
self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
- _mock_response(
- {
+ FakeResponse.json(
+ payload={
"pdus": [
create_event_dict,
member_event_dict,
@@ -137,14 +142,14 @@ class FederationClientTest(FederatingHomeserverTestCase):
def test_get_pdu_returns_nothing_when_event_does_not_exist(self):
"""No event should be returned when the event does not exist"""
- remote_pdu = self.get_success(
+ pulled_pdu_info = self.get_success(
self.hs.get_federation_client().get_pdu(
["yet.another.server"],
"event_should_not_exist",
RoomVersions.V9,
)
)
- self.assertEqual(remote_pdu, None)
+ self.assertEqual(pulled_pdu_info, None)
def test_get_pdu(self):
"""Test to make sure an event is returned by `get_pdu()`"""
@@ -164,13 +169,15 @@ class FederationClientTest(FederatingHomeserverTestCase):
remote_pdu.internal_metadata.outlier = True
# Get the event again. This time it should read it from cache.
- remote_pdu2 = self.get_success(
+ pulled_pdu_info2 = self.get_success(
self.hs.get_federation_client().get_pdu(
["yet.another.server"],
remote_pdu.event_id,
RoomVersions.V9,
)
)
+ self.assertIsNotNone(pulled_pdu_info2)
+ remote_pdu2 = pulled_pdu_info2.pdu
# Sanity check that we are working against the same event
self.assertEqual(remote_pdu.event_id, remote_pdu2.event_id)
@@ -199,8 +206,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
# mock up the response, and have the agent return it
self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
- _mock_response(
- {
+ FakeResponse.json(
+ payload={
"origin": "yet.another.server",
"origin_server_ts": 900,
"pdus": [
@@ -210,13 +217,15 @@ class FederationClientTest(FederatingHomeserverTestCase):
)
)
- remote_pdu = self.get_success(
+ pulled_pdu_info = self.get_success(
self.hs.get_federation_client().get_pdu(
["yet.another.server"],
"event_id",
RoomVersions.V9,
)
)
+ self.assertIsNotNone(pulled_pdu_info)
+ remote_pdu = pulled_pdu_info.pdu
# check the right call got made to the agent
self._mock_agent.request.assert_called_once_with(
@@ -231,20 +240,68 @@ class FederationClientTest(FederatingHomeserverTestCase):
return remote_pdu
+ def test_backfill_invalid_signature_records_failed_pull_attempts(
+ self,
+ ) -> None:
+ """
+ Test to make sure that events from /backfill with invalid signatures get
+ recorded as failed pull attempts.
+ """
+ OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+ main_store = self.hs.get_datastores().main
+
+ # Create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ # We purposely don't run `add_hashes_and_signatures_from_other_server`
+ # over this because we want the signature check to fail.
+ pulled_event, _ = self.get_success(
+ event_injection.create_event(
+ self.hs,
+ room_id=room_id,
+ sender=OTHER_USER,
+ type="test_event_type",
+ content={"body": "garply"},
+ )
+ )
-def _mock_response(resp: JsonDict):
- body = json.dumps(resp).encode("utf-8")
+ # We expect an outbound request to /backfill, so stub that out
+ self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
+ FakeResponse.json(
+ payload={
+ "origin": "yet.another.server",
+ "origin_server_ts": 900,
+ # Mimic the other server returning our new `pulled_event`
+ "pdus": [pulled_event.get_pdu_json()],
+ }
+ )
+ )
- def deliver_body(p: Protocol):
- p.dataReceived(body)
- p.connectionLost(Failure(twisted.web.client.ResponseDone()))
+ self.get_success(
+ self.hs.get_federation_client().backfill(
+ # We use "yet.another.server" instead of
+ # `self.OTHER_SERVER_NAME` because we want to see the behavior
+ # from `_check_sigs_and_hash_and_fetch_one` where it tries to
+ # fetch the PDU again from the origin server if the signature
+ # fails. Just want to make sure that the failure is counted from
+ # both code paths.
+ dest="yet.another.server",
+ room_id=room_id,
+ limit=1,
+ extremities=[pulled_event.event_id],
+ ),
+ )
- response = mock.Mock(
- code=200,
- phrase=b"OK",
- headers=twisted.web.client.Headers({"content-Type": ["application/json"]}),
- length=len(body),
- deliverBody=deliver_body,
- )
- mock.seal(response)
- return response
+ # Make sure our failed pull attempt was recorded
+ backfill_num_attempts = self.get_success(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event.event_id},
+ retcol="num_attempts",
+ )
+ )
+ # This is 2 because it failed once from `self.OTHER_SERVER_NAME` and the
+ # other from "yet.another.server"
+ self.assertEqual(backfill_num_attempts, 2)
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 01a1db6115..f1e357764f 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -49,7 +49,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
- "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
+ "room_id",
+ "m.read",
+ "user_id",
+ ["event_id"],
+ thread_id=None,
+ data={"ts": 1234},
)
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
@@ -89,7 +94,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
- "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
+ "room_id",
+ "m.read",
+ "user_id",
+ ["event_id"],
+ thread_id=None,
+ data={"ts": 1234},
)
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
@@ -121,7 +131,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
# send the second RR
receipt = ReadReceipt(
- "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
+ "room_id",
+ "m.read",
+ "user_id",
+ ["other_id"],
+ thread_id=None,
+ data={"ts": 1234},
)
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()
@@ -173,17 +188,24 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
return c
def prepare(self, reactor, clock, hs):
- # stub out `get_rooms_for_user` and `get_users_in_room` so that the
+ test_room_id = "!room:host1"
+
+ # stub out `get_rooms_for_user` and `get_current_hosts_in_room` so that the
# server thinks the user shares a room with `@user2:host2`
def get_rooms_for_user(user_id):
- return defer.succeed({"!room:host1"})
+ return defer.succeed({test_room_id})
hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
- def get_users_in_room(room_id):
- return defer.succeed({"@user2:host2"})
+ async def get_current_hosts_in_room(room_id):
+ if room_id == test_room_id:
+ return ["host2"]
+
+ # TODO: We should fail the test when we encounter an unxpected room ID.
+ # We can't just use `self.fail(...)` here because the app code is greedy
+ # with `Exception` and will catch it before the test can see it.
- hs.get_datastores().main.get_users_in_room = get_users_in_room
+ hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room
# whenever send_transaction is called, record the edu data
self.edus = []
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 3a6ef221ae..177e5b5afc 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -212,7 +212,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
@override_config({"experimental_features": {"msc3706_enabled": True}})
- def test_send_join_partial_state(self):
+ def test_send_join_partial_state(self) -> None:
"""When MSC3706 support is enabled, /send_join should return partial state"""
joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
join_result = self._make_join(joining_user)
@@ -240,6 +240,9 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
("m.room.power_levels", ""),
("m.room.join_rules", ""),
("m.room.history_visibility", ""),
+ # Users included here because they're heroes.
+ ("m.room.member", "@kermit:test"),
+ ("m.room.member", "@fozzie:test"),
],
)
@@ -249,9 +252,9 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
]
self.assertCountEqual(
returned_auth_chain_events,
- [
- ("m.room.member", "@kermit:test"),
- ],
+ # TODO: change the test so that we get at least one event in the auth chain
+ # here.
+ [],
)
# the room should show that the new user is a member
diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py
index d33e86db4c..e88e5d8bb3 100644
--- a/tests/federation/transport/server/test__base.py
+++ b/tests/federation/transport/server/test__base.py
@@ -18,9 +18,10 @@ from typing import Dict, List, Tuple
from synapse.api.errors import Codes
from synapse.federation.transport.server import BaseFederationServlet
from synapse.federation.transport.server._base import Authenticator, _parse_auth_header
-from synapse.http.server import JsonResource, cancellable
+from synapse.http.server import JsonResource
from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util.cancellation import cancellable
from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest
diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py
index c2320ce133..b84c74fc0e 100644
--- a/tests/federation/transport/test_client.py
+++ b/tests/federation/transport/test_client.py
@@ -13,9 +13,13 @@
# limitations under the License.
import json
+from unittest.mock import Mock
+
+import ijson.common
from synapse.api.room_versions import RoomVersions
from synapse.federation.transport.client import SendJoinParser
+from synapse.util import ExceptionBundle
from tests.unittest import TestCase
@@ -94,3 +98,46 @@ class SendJoinParserTestCase(TestCase):
# Retrieve and check the parsed SendJoinResponse
parsed_response = parser.finish()
self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"])
+
+ def test_errors_closing_coroutines(self) -> None:
+ """Check we close all coroutines, even if closing the first raises an Exception.
+
+ We also check that an Exception of some kind is raised, but we don't make any
+ assertions about its attributes or type.
+ """
+ parser = SendJoinParser(RoomVersions.V1, False)
+ response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}
+ serialisation = json.dumps(response).encode()
+
+ # Mock the coroutines managed by this parser.
+ # The first one will error when we try to close it.
+ coro_1 = Mock()
+ coro_1.close = Mock(side_effect=RuntimeError("Couldn't close coro 1"))
+
+ coro_2 = Mock()
+
+ coro_3 = Mock()
+ coro_3.close = Mock(side_effect=RuntimeError("Couldn't close coro 3"))
+
+ original_coros = parser._coros
+ parser._coros = [coro_1, coro_2, coro_3]
+
+ # Close the original coroutines. If we don't, when we garbage collect them
+ # they will throw, failing the test. (Oddly, this only started in CPython 3.11).
+ for coro in original_coros:
+ try:
+ coro.close()
+ except ijson.common.IncompleteJSONError:
+ pass
+
+ # Send half of the data to the parser
+ parser.write(serialisation[: len(serialisation) // 2])
+
+ # Close the parser. There should be _some_ kind of exception.
+ with self.assertRaises(ExceptionBundle):
+ parser.finish()
+
+ # In any case, we should have tried to close both coros.
+ coro_1.close.assert_called()
+ coro_2.close.assert_called()
+ coro_3.close.assert_called()
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 0d048207b7..d21c11b716 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
-from http import HTTPStatus
from typing import Dict, List
from synapse.api.constants import EventTypes, JoinRules, Membership
@@ -256,7 +255,7 @@ class FederationKnockingTestCase(
RoomVersions.V7.identifier,
),
)
- self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ self.assertEqual(200, channel.code, channel.result)
# Note: We don't expect the knock membership event to be sent over federation as
# part of the stripped room state, as the knocking homeserver already has that
@@ -294,7 +293,7 @@ class FederationKnockingTestCase(
% (room_id, signed_knock_event.event_id),
signed_knock_event_json,
)
- self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ self.assertEqual(200, channel.code, channel.result)
# Check that we got the stripped room state in return
room_state_events = channel.json_body["knock_state_events"]
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index b17af2725b..144e49d0fd 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -22,7 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
import synapse.storage
-from synapse.api.constants import EduTypes
+from synapse.api.constants import EduTypes, EventTypes
from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeyCounts,
@@ -36,7 +36,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
-from tests.test_utils import make_awaitable, simple_async_mock
+from tests.test_utils import event_injection, make_awaitable, simple_async_mock
from tests.unittest import override_config
from tests.utils import MockClock
@@ -76,9 +76,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
- self.mock_store.get_all_new_events_stream.side_effect = [
- make_awaitable((0, [], {})),
- make_awaitable((1, [event], {event.event_id: 0})),
+ self.mock_store.get_all_new_event_ids_stream.side_effect = [
+ make_awaitable((0, {})),
+ make_awaitable((1, {event.event_id: 0})),
+ ]
+ self.mock_store.get_events_as_list.side_effect = [
+ make_awaitable([]),
+ make_awaitable([event]),
]
self.handler.notify_interested_services(RoomStreamToken(None, 1))
@@ -95,10 +99,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_all_new_events_stream.side_effect = [
- make_awaitable((0, [event], {event.event_id: 0})),
+ self.mock_store.get_all_new_event_ids_stream.side_effect = [
+ make_awaitable((0, {event.event_id: 0})),
]
-
+ self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@@ -112,7 +116,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_all_new_events_stream.side_effect = [
+ self.mock_store.get_all_new_event_ids_stream.side_effect = [
make_awaitable((0, [event], {event.event_id: 0})),
]
@@ -386,15 +390,16 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
receipts.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ self.hs = hs
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track any outgoing ephemeral events
self.send_mock = simple_async_mock()
- hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock
+ hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment]
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
- self.hs.get_datastores().main.get_app_services = Mock(
+ self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
return_value=self._services
)
@@ -412,6 +417,157 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
"exclusive_as_user", "password", self.exclusive_as_user_device_id
)
+ def _notify_interested_services(self):
+ # This is normally set in `notify_interested_services` but we need to call the
+ # internal async version so the reactor gets pushed to completion.
+ self.hs.get_application_service_handler().current_max += 1
+ self.get_success(
+ self.hs.get_application_service_handler()._notify_interested_services(
+ RoomStreamToken(
+ None, self.hs.get_application_service_handler().current_max
+ )
+ )
+ )
+
+ @parameterized.expand(
+ [
+ ("@local_as_user:test", True),
+ # Defining remote users in an application service user namespace regex is a
+ # footgun since the appservice might assume that it'll receive all events
+ # sent by that remote user, but it will only receive events in rooms that
+ # are shared with a local user. So we just remove this footgun possibility
+ # entirely and we won't notify the application service based on remote
+ # users.
+ ("@remote_as_user:remote", False),
+ ]
+ )
+ def test_match_interesting_room_members(
+ self, interesting_user: str, should_notify: bool
+ ):
+ """
+ Test to make sure that a interesting user (local or remote) in the room is
+ notified as expected when someone else in the room sends a message.
+ """
+ # Register an application service that's interested in the `interesting_user`
+ interested_appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": interesting_user,
+ "exclusive": False,
+ },
+ ],
+ },
+ )
+
+ # Create a room
+ alice = self.register_user("alice", "pass")
+ alice_access_token = self.login("alice", "pass")
+ room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token)
+
+ # Join the interesting user to the room
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, room_id, interesting_user, "join"
+ )
+ )
+ # Kick the appservice into checking this membership event to get the event out
+ # of the way
+ self._notify_interested_services()
+ # We don't care about the interesting user join event (this test is making sure
+ # the next thing works)
+ self.send_mock.reset_mock()
+
+ # Send a message from an uninteresting user
+ self.helper.send_event(
+ room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "message from uninteresting user",
+ },
+ tok=alice_access_token,
+ )
+ # Kick the appservice into checking this new event
+ self._notify_interested_services()
+
+ if should_notify:
+ self.send_mock.assert_called_once()
+ (
+ service,
+ events,
+ _ephemeral,
+ _to_device_messages,
+ _otks,
+ _fbks,
+ _device_list_summary,
+ ) = self.send_mock.call_args[0]
+
+ # Even though the message came from an uninteresting user, it should still
+ # notify us because the interesting user is joined to the room where the
+ # message was sent.
+ self.assertEqual(service, interested_appservice)
+ self.assertEqual(events[0]["type"], "m.room.message")
+ self.assertEqual(events[0]["sender"], alice)
+ else:
+ self.send_mock.assert_not_called()
+
+ def test_application_services_receive_events_sent_by_interesting_local_user(self):
+ """
+ Test to make sure that a messages sent from a local user can be interesting and
+ picked up by the appservice.
+ """
+ # Register an application service that's interested in all local users
+ interested_appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": ".*",
+ "exclusive": False,
+ },
+ ],
+ },
+ )
+
+ # Create a room
+ alice = self.register_user("alice", "pass")
+ alice_access_token = self.login("alice", "pass")
+ room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token)
+
+ # We don't care about interesting events before this (this test is making sure
+ # the next thing works)
+ self.send_mock.reset_mock()
+
+ # Send a message from the interesting local user
+ self.helper.send_event(
+ room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "message from interesting local user",
+ },
+ tok=alice_access_token,
+ )
+ # Kick the appservice into checking this new event
+ self._notify_interested_services()
+
+ self.send_mock.assert_called_once()
+ (
+ service,
+ events,
+ _ephemeral,
+ _to_device_messages,
+ _otks,
+ _fbks,
+ _device_list_summary,
+ ) = self.send_mock.call_args[0]
+
+ # Events sent from an interesting local user should also be picked up as
+ # interesting to the appservice.
+ self.assertEqual(service, interested_appservice)
+ self.assertEqual(events[0]["type"], "m.room.message")
+ self.assertEqual(events[0]["sender"], alice)
+
def test_sending_read_receipt_batches_to_application_services(self):
"""Tests that a large batch of read receipts are sent correctly to
interested application services.
@@ -447,6 +603,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
receipt_type="m.read",
user_id=self.local_user,
event_ids=[f"$eventid_{i}"],
+ thread_id=None,
data={},
)
)
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 7106799d44..036dbbc45b 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
from unittest.mock import Mock
import pymacaroons
@@ -19,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import AuthError, ResourceLimitError
from synapse.rest import admin
+from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
@@ -29,6 +31,7 @@ from tests.test_utils import make_awaitable
class AuthTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
+ login.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -46,6 +49,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.user1 = self.register_user("a_user", "pass")
+ def token_login(self, token: str) -> Optional[str]:
+ body = {
+ "type": "m.login.token",
+ "token": token,
+ }
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/login",
+ body,
+ )
+
+ if channel.code == 200:
+ return channel.json_body["user_id"]
+
+ return None
+
def test_macaroon_caveats(self) -> None:
token = self.macaroon_generator.generate_guest_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
@@ -73,49 +93,62 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.satisfy_general(verify_guest)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
- def test_short_term_login_token_gives_user_id(self) -> None:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, "", duration_in_ms=5000
+ def test_login_token_gives_user_id(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ duration_ms=(5 * 1000),
+ )
)
- res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+
+ res = self.get_success(self.auth_handler.consume_login_token(token))
self.assertEqual(self.user1, res.user_id)
- self.assertEqual("", res.auth_provider_id)
+ self.assertEqual(None, res.auth_provider_id)
- # when we advance the clock, the token should be rejected
- self.reactor.advance(6)
- self.get_failure(
- self.auth_handler.validate_short_term_login_token(token),
- AuthError,
+ def test_login_token_reuse_fails(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ duration_ms=(5 * 1000),
+ )
)
- def test_short_term_login_token_gives_auth_provider(self) -> None:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, auth_provider_id="my_idp"
- )
- res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
- self.assertEqual(self.user1, res.user_id)
- self.assertEqual("my_idp", res.auth_provider_id)
+ self.get_success(self.auth_handler.consume_login_token(token))
- def test_short_term_login_token_cannot_replace_user_id(self) -> None:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, "", duration_in_ms=5000
+ self.get_failure(
+ self.auth_handler.consume_login_token(token),
+ AuthError,
)
- macaroon = pymacaroons.Macaroon.deserialize(token)
- res = self.get_success(
- self.auth_handler.validate_short_term_login_token(macaroon.serialize())
+ def test_login_token_expires(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ duration_ms=(5 * 1000),
+ )
)
- self.assertEqual(self.user1, res.user_id)
-
- # add another "user_id" caveat, which might allow us to override the
- # user_id.
- macaroon.add_first_party_caveat("user_id = b_user")
+ # when we advance the clock, the token should be rejected
+ self.reactor.advance(6)
self.get_failure(
- self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
+ self.auth_handler.consume_login_token(token),
AuthError,
)
+ def test_login_token_gives_auth_provider(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ auth_provider_id="my_idp",
+ auth_provider_session_id="11-22-33-44",
+ duration_ms=(5 * 1000),
+ )
+ )
+ res = self.get_success(self.auth_handler.consume_login_token(token))
+ self.assertEqual(self.user1, res.user_id)
+ self.assertEqual("my_idp", res.auth_provider_id)
+ self.assertEqual("11-22-33-44", res.auth_provider_session_id)
+
def test_mau_limits_disabled(self) -> None:
self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception
@@ -125,12 +158,12 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
- self.get_success(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- )
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNotNone(self.token_login(token))
+
def test_mau_limits_exceeded_large(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock(
@@ -147,12 +180,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users)
)
- self.get_failure(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- ),
- ResourceLimitError,
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNone(self.token_login(token))
def test_mau_limits_parity(self) -> None:
# Ensure we're not at the unix epoch.
@@ -171,12 +202,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
),
ResourceLimitError,
)
- self.get_failure(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- ),
- ResourceLimitError,
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNone(self.token_login(token))
# If in monthly active cohort
self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
@@ -187,11 +216,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.user1, device_id=None, valid_until_ms=None
)
)
- self.get_success(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- )
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNotNone(self.token_login(token))
def test_mau_limits_not_exceeded(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
@@ -209,14 +237,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.small_number_of_users)
)
- self.get_success(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- )
- )
-
- def _get_macaroon(self) -> pymacaroons.Macaroon:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, "", duration_in_ms=5000
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
- return pymacaroons.Macaroon.deserialize(token)
+ self.assertIsNotNone(self.token_login(token))
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index 7586e472b5..bce65fab7d 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
-from typing import Any, Dict
from twisted.test.proto_helpers import MemoryReactor
@@ -21,6 +19,7 @@ from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest import admin
from synapse.rest.client import account, login
from synapse.server import HomeServer
+from synapse.synapse_rust.push import PushRule
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -58,7 +57,7 @@ class DeactivateAccountTestCase(HomeserverTestCase):
access_token=self.token,
)
- self.assertEqual(req.code, HTTPStatus.OK, req)
+ self.assertEqual(req.code, 200, req)
def test_global_account_data_deleted_upon_deactivation(self) -> None:
"""
@@ -131,12 +130,12 @@ class DeactivateAccountTestCase(HomeserverTestCase):
),
)
- def _is_custom_rule(self, push_rule: Dict[str, Any]) -> bool:
+ def _is_custom_rule(self, push_rule: PushRule) -> bool:
"""
Default rules start with a dot: such as .m.rule and .im.vector.
This function returns true iff a rule is custom (not default).
"""
- return "/." not in push_rule["rule_id"]
+ return "/." not in push_rule.rule_id
def test_push_rules_deleted_upon_account_deactivation(self) -> None:
"""
@@ -158,32 +157,30 @@ class DeactivateAccountTestCase(HomeserverTestCase):
)
# Test the rule exists
- push_rules = self.get_success(self._store.get_push_rules_for_user(self.user))
+ filtered_push_rules = self.get_success(
+ self._store.get_push_rules_for_user(self.user)
+ )
# Filter out default rules; we don't care
- push_rules = list(filter(self._is_custom_rule, push_rules))
+ push_rules = [
+ r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r)
+ ]
# Check our rule made it
- self.assertEqual(
- push_rules,
- [
- {
- "user_name": "@user:test",
- "rule_id": "personal.override.rule1",
- "priority_class": 5,
- "priority": 0,
- "conditions": [],
- "actions": [],
- "default": False,
- }
- ],
- push_rules,
- )
+ self.assertEqual(len(push_rules), 1)
+ self.assertEqual(push_rules[0].rule_id, "personal.override.rule1")
+ self.assertEqual(push_rules[0].priority_class, 5)
+ self.assertEqual(push_rules[0].conditions, [])
+ self.assertEqual(push_rules[0].actions, [])
# Request the deactivation of our account
self._deactivate_my_account()
- push_rules = self.get_success(self._store.get_push_rules_for_user(self.user))
+ filtered_push_rules = self.get_success(
+ self._store.get_push_rules_for_user(self.user)
+ )
# Filter out default rules; we don't care
- push_rules = list(filter(self._is_custom_rule, push_rules))
+ push_rules = [
+ r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r)
+ ]
# Check our rule no longer exists
self.assertEqual(push_rules, [], push_rules)
@@ -322,3 +319,18 @@ class DeactivateAccountTestCase(HomeserverTestCase):
)
),
)
+
+ def test_deactivate_account_needs_auth(self) -> None:
+ """
+ Tests that making a request to /deactivate with an empty body
+ succeeds in starting the user-interactive auth flow.
+ """
+ req = self.make_request(
+ "POST",
+ "account/deactivate",
+ {},
+ access_token=self.token,
+ )
+
+ self.assertEqual(req.code, 401, req)
+ self.assertEqual(req.json_body["flows"], [{"stages": ["m.login.password"]}])
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index b8b465d35b..ce7525e29c 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -19,7 +19,7 @@ from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError, SynapseError
-from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
+from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.server import HomeServer
from synapse.util import Clock
@@ -32,7 +32,9 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
- self.handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.handler = handler
self.store = hs.get_datastores().main
return hs
@@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(res, "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
+ assert dev is not None
self.assertEqual(dev["display_name"], "display name")
def test_device_is_preserved_if_exists(self) -> None:
@@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(res2, "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
+ assert dev is not None
self.assertEqual(dev["display_name"], "display name")
def test_device_id_is_made_up_if_unspecified(self) -> None:
@@ -95,6 +99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
+ assert dev is not None
self.assertEqual(dev["display_name"], "display")
def test_get_devices_by_user(self) -> None:
@@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
- self.handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.handler = handler
self.registration = hs.get_registration_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
@@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
)
)
- retrieved_device_id, device_data = self.get_success(
- self.handler.get_dehydrated_device(user_id=user_id)
- )
+ result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
+ assert result is not None
+ retrieved_device_id, device_data = result
self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 1e6ad4b663..95698bc275 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -891,6 +891,12 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
new_callable=mock.MagicMock,
return_value=make_awaitable(["some_room_id"]),
)
+ mock_get_users = mock.patch.object(
+ self.store,
+ "get_users_server_still_shares_room_with",
+ new_callable=mock.MagicMock,
+ return_value=make_awaitable({remote_user_id}),
+ )
mock_request = mock.patch.object(
self.hs.get_federation_client(),
"query_user_devices",
@@ -898,7 +904,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(response_body),
)
- with mock_get_rooms, mock_request as mocked_federation_request:
+ with mock_get_rooms, mock_get_users, mock_request as mocked_federation_request:
# Make the first query and sanity check it succeeds.
response_1 = self.get_success(
e2e_handler.query_devices(
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 745750b1d7..d00c69c229 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -19,7 +19,13 @@ from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ LimitExceededError,
+ NotFoundError,
+ SynapseError,
+)
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_base import event_from_pdu_json
@@ -28,6 +34,7 @@ from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
+from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -322,6 +329,102 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
)
self.get_success(d)
+ def test_backfill_ignores_known_events(self) -> None:
+ """
+ Tests that events that we already know about are ignored when backfilling.
+ """
+ # Set up users
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ other_server = "otherserver"
+ other_user = "@otheruser:" + other_server
+
+ # Create a room to backfill events into
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ # Build an event to backfill
+ event = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {"body": "hello world", "msgtype": "m.text"},
+ "room_id": room_id,
+ "sender": other_user,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ # Ensure the event is not already in the DB
+ self.get_failure(
+ self.store.get_event(event.event_id),
+ NotFoundError,
+ )
+
+ # Backfill the event and check that it has entered the DB.
+
+ # We mock out the FederationClient.backfill method, to pretend that a remote
+ # server has returned our fake event.
+ federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
+ self.hs.get_federation_client().backfill = federation_client_backfill_mock
+
+ # We also mock the persist method with a side effect of itself. This allows us
+ # to track when it has been called while preserving its function.
+ persist_events_and_notify_mock = Mock(
+ side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
+ )
+ self.hs.get_federation_event_handler().persist_events_and_notify = (
+ persist_events_and_notify_mock
+ )
+
+ # Small side-tangent. We populate the event cache with the event, even though
+ # it is not yet in the DB. This is an invalid scenario that can currently occur
+ # due to not properly invalidating the event cache.
+ # See https://github.com/matrix-org/synapse/issues/13476.
+ #
+ # As a result, backfill should not rely on the event cache to check whether
+ # we already have an event in the DB.
+ # TODO: Remove this bit when the event cache is properly invalidated.
+ cache_entry = EventCacheEntry(
+ event=event,
+ redacted_event=None,
+ )
+ self.store._get_event_cache.set_local((event.event_id,), cache_entry)
+
+ # We now call FederationEventHandler.backfill (a separate method) to trigger
+ # a backfill request. It should receive the fake event.
+ self.get_success(
+ self.hs.get_federation_event_handler().backfill(
+ other_user,
+ room_id,
+ limit=10,
+ extremities=[],
+ )
+ )
+
+ # Check that our fake event was persisted.
+ persist_events_and_notify_mock.assert_called_once()
+ persist_events_and_notify_mock.reset_mock()
+
+ # Now we repeat the backfill, having the homeserver receive the fake event
+ # again.
+ self.get_success(
+ self.hs.get_federation_event_handler().backfill(
+ other_user,
+ room_id,
+ limit=10,
+ extremities=[],
+ ),
+ )
+
+ # This time, we expect no event persistence to have occurred, as we already
+ # have this event.
+ persist_events_and_notify_mock.assert_not_called()
+
@unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
)
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 51c8dd6498..e448cb1901 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -11,14 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
from unittest import mock
+from synapse.api.errors import AuthError, StoreError
+from synapse.api.room_versions import RoomVersion
+from synapse.event_auth import (
+ check_state_dependent_auth_rules,
+ check_state_independent_auth_rules,
+)
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.federation.transport.client import StateRequestResponse
from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
+from synapse.types import JsonDict
from tests import unittest
from tests.test_utils import event_injection, make_awaitable
@@ -34,7 +43,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
def make_homeserver(self, reactor, clock):
# mock out the federation transport client
self.mock_federation_transport_client = mock.Mock(
- spec=["get_room_state_ids", "get_room_state", "get_event"]
+ spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
)
return super().setup_test_homeserver(
federation_transport_client=self.mock_federation_transport_client
@@ -227,3 +236,812 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
if prev_exists_as_outlier:
self.mock_federation_transport_client.get_event.assert_not_called()
+
+ def test_process_pulled_event_records_failed_backfill_attempts(
+ self,
+ ) -> None:
+ """
+ Test to make sure that failed backfill attempts for an event are
+ recorded in the `event_failed_pull_attempts` table.
+
+ In this test, we pretend we are processing a "pulled" event via
+ backfill. The pulled event has a fake `prev_event` which our server has
+ obviously never seen before so it attempts to request the state at that
+ `prev_event` which expectedly fails because it's a fake event. Because
+ the server can't fetch the state at the missing `prev_event`, the
+ "pulled" event fails the history check and is fails to process.
+
+ We check that we correctly record the number of failed pull attempts
+ of the pulled event and as a sanity check, that the "pulled" event isn't
+ persisted.
+ """
+ OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+ main_store = self.hs.get_datastores().main
+
+ # Create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(main_store.get_room_version(room_id))
+
+ # We expect an outbound request to /state_ids, so stub that out
+ self.mock_federation_transport_client.get_room_state_ids.return_value = make_awaitable(
+ {
+ # Mimic the other server not knowing about the state at all.
+ # We want to cause Synapse to throw an error (`Unable to get
+ # missing prev_event $fake_prev_event`) and fail to backfill
+ # the pulled event.
+ "pdu_ids": [],
+ "auth_chain_ids": [],
+ }
+ )
+ # We also expect an outbound request to /state
+ self.mock_federation_transport_client.get_room_state.return_value = make_awaitable(
+ StateRequestResponse(
+ # Mimic the other server not knowing about the state at all.
+ # We want to cause Synapse to throw an error (`Unable to get
+ # missing prev_event $fake_prev_event`) and fail to backfill
+ # the pulled event.
+ auth_events=[],
+ state=[],
+ )
+ )
+
+ pulled_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [
+ # The fake prev event will make the pulled event fail
+ # the history check (`Unable to get missing prev_event
+ # $fake_prev_event`)
+ "$fake_prev_event"
+ ],
+ "auth_events": [],
+ "origin_server_ts": 1,
+ "depth": 12,
+ "content": {"body": "pulled"},
+ }
+ ),
+ room_version,
+ )
+
+ # The function under test: try to process the pulled event
+ with LoggingContext("test"):
+ self.get_success(
+ self.hs.get_federation_event_handler()._process_pulled_event(
+ self.OTHER_SERVER_NAME, pulled_event, backfilled=True
+ )
+ )
+
+ # Make sure our failed pull attempt was recorded
+ backfill_num_attempts = self.get_success(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event.event_id},
+ retcol="num_attempts",
+ )
+ )
+ self.assertEqual(backfill_num_attempts, 1)
+
+ # The function under test: try to process the pulled event again
+ with LoggingContext("test"):
+ self.get_success(
+ self.hs.get_federation_event_handler()._process_pulled_event(
+ self.OTHER_SERVER_NAME, pulled_event, backfilled=True
+ )
+ )
+
+ # Make sure our second failed pull attempt was recorded (`num_attempts` was incremented)
+ backfill_num_attempts = self.get_success(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event.event_id},
+ retcol="num_attempts",
+ )
+ )
+ self.assertEqual(backfill_num_attempts, 2)
+
+ # And as a sanity check, make sure the event was not persisted through all of this.
+ persisted = self.get_success(
+ main_store.get_event(pulled_event.event_id, allow_none=True)
+ )
+ self.assertIsNone(
+ persisted,
+ "pulled event that fails the history check should not be persisted at all",
+ )
+
+ def test_process_pulled_event_clears_backfill_attempts_after_being_successfully_persisted(
+ self,
+ ) -> None:
+ """
+ Test to make sure that failed pull attempts
+ (`event_failed_pull_attempts` table) for an event are cleared after the
+ event is successfully persisted.
+
+ In this test, we pretend we are processing a "pulled" event via
+ backfill. The pulled event succesfully processes and the backward
+ extremeties are updated along with clearing out any failed pull attempts
+ for those old extremities.
+
+ We check that we correctly cleared failed pull attempts of the
+ pulled event.
+ """
+ OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+ main_store = self.hs.get_datastores().main
+
+ # Create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(main_store.get_room_version(room_id))
+
+ # allow the remote user to send state events
+ self.helper.send_state(
+ room_id,
+ "m.room.power_levels",
+ {"events_default": 0, "state_default": 0},
+ tok=tok,
+ )
+
+ # add the remote user to the room
+ member_event = self.get_success(
+ event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
+ )
+
+ initial_state_map = self.get_success(
+ main_store.get_partial_current_state_ids(room_id)
+ )
+
+ auth_event_ids = [
+ initial_state_map[("m.room.create", "")],
+ initial_state_map[("m.room.power_levels", "")],
+ member_event.event_id,
+ ]
+
+ pulled_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [member_event.event_id],
+ "auth_events": auth_event_ids,
+ "origin_server_ts": 1,
+ "depth": 12,
+ "content": {"body": "pulled"},
+ }
+ ),
+ room_version,
+ )
+
+ # Fake the "pulled" event failing to backfill once so we can test
+ # if it's cleared out later on.
+ self.get_success(
+ main_store.record_event_failed_pull_attempt(
+ pulled_event.room_id, pulled_event.event_id, "fake cause"
+ )
+ )
+ # Make sure we have a failed pull attempt recorded for the pulled event
+ backfill_num_attempts = self.get_success(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event.event_id},
+ retcol="num_attempts",
+ )
+ )
+ self.assertEqual(backfill_num_attempts, 1)
+
+ # The function under test: try to process the pulled event
+ with LoggingContext("test"):
+ self.get_success(
+ self.hs.get_federation_event_handler()._process_pulled_event(
+ self.OTHER_SERVER_NAME, pulled_event, backfilled=True
+ )
+ )
+
+ # Make sure the failed pull attempts for the pulled event are cleared
+ backfill_num_attempts = self.get_success(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event.event_id},
+ retcol="num_attempts",
+ allow_none=True,
+ )
+ )
+ self.assertIsNone(backfill_num_attempts)
+
+ # And as a sanity check, make sure the "pulled" event was persisted.
+ persisted = self.get_success(
+ main_store.get_event(pulled_event.event_id, allow_none=True)
+ )
+ self.assertIsNotNone(persisted, "pulled event was not persisted at all")
+
+ def test_backfill_signature_failure_does_not_fetch_same_prev_event_later(
+ self,
+ ) -> None:
+ """
+ Test to make sure we backoff and don't try to fetch a missing prev_event when we
+ already know it has a invalid signature from checking the signatures of all of
+ the events in the backfill response.
+ """
+ OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+ main_store = self.hs.get_datastores().main
+
+ # Create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(main_store.get_room_version(room_id))
+
+ # Allow the remote user to send state events
+ self.helper.send_state(
+ room_id,
+ "m.room.power_levels",
+ {"events_default": 0, "state_default": 0},
+ tok=tok,
+ )
+
+ # Add the remote user to the room
+ member_event = self.get_success(
+ event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
+ )
+
+ initial_state_map = self.get_success(
+ main_store.get_partial_current_state_ids(room_id)
+ )
+
+ auth_event_ids = [
+ initial_state_map[("m.room.create", "")],
+ initial_state_map[("m.room.power_levels", "")],
+ member_event.event_id,
+ ]
+
+ # We purposely don't run `add_hashes_and_signatures_from_other_server`
+ # over this because we want the signature check to fail.
+ pulled_event_without_signatures = make_event_from_dict(
+ {
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [member_event.event_id],
+ "auth_events": auth_event_ids,
+ "origin_server_ts": 1,
+ "depth": 12,
+ "content": {"body": "pulled_event_without_signatures"},
+ },
+ room_version,
+ )
+
+ # Create a regular event that should pass except for the
+ # `pulled_event_without_signatures` in the `prev_event`.
+ pulled_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [
+ member_event.event_id,
+ pulled_event_without_signatures.event_id,
+ ],
+ "auth_events": auth_event_ids,
+ "origin_server_ts": 1,
+ "depth": 12,
+ "content": {"body": "pulled_event"},
+ }
+ ),
+ room_version,
+ )
+
+ # We expect an outbound request to /backfill, so stub that out
+ self.mock_federation_transport_client.backfill.return_value = make_awaitable(
+ {
+ "origin": self.OTHER_SERVER_NAME,
+ "origin_server_ts": 123,
+ "pdus": [
+ # This is one of the important aspects of this test: we include
+ # `pulled_event_without_signatures` so it fails the signature check
+ # when we filter down the backfill response down to events which
+ # have valid signatures in
+ # `_check_sigs_and_hash_for_pulled_events_and_fetch`
+ pulled_event_without_signatures.get_pdu_json(),
+ # Then later when we process this valid signature event, when we
+ # fetch the missing `prev_event`s, we want to make sure that we
+ # backoff and don't try and fetch `pulled_event_without_signatures`
+ # again since we know it just had an invalid signature.
+ pulled_event.get_pdu_json(),
+ ],
+ }
+ )
+
+ # Keep track of the count and make sure we don't make any of these requests
+ event_endpoint_requested_count = 0
+ room_state_ids_endpoint_requested_count = 0
+ room_state_endpoint_requested_count = 0
+
+ async def get_event(
+ destination: str, event_id: str, timeout: Optional[int] = None
+ ) -> None:
+ nonlocal event_endpoint_requested_count
+ event_endpoint_requested_count += 1
+
+ async def get_room_state_ids(
+ destination: str, room_id: str, event_id: str
+ ) -> None:
+ nonlocal room_state_ids_endpoint_requested_count
+ room_state_ids_endpoint_requested_count += 1
+
+ async def get_room_state(
+ room_version: RoomVersion, destination: str, room_id: str, event_id: str
+ ) -> None:
+ nonlocal room_state_endpoint_requested_count
+ room_state_endpoint_requested_count += 1
+
+ # We don't expect an outbound request to `/event`, `/state_ids`, or `/state` in
+ # the happy path but if the logic is sneaking around what we expect, stub that
+ # out so we can detect that failure
+ self.mock_federation_transport_client.get_event.side_effect = get_event
+ self.mock_federation_transport_client.get_room_state_ids.side_effect = (
+ get_room_state_ids
+ )
+ self.mock_federation_transport_client.get_room_state.side_effect = (
+ get_room_state
+ )
+
+ # The function under test: try to backfill and process the pulled event
+ with LoggingContext("test"):
+ self.get_success(
+ self.hs.get_federation_event_handler().backfill(
+ self.OTHER_SERVER_NAME,
+ room_id,
+ limit=1,
+ extremities=["$some_extremity"],
+ )
+ )
+
+ if event_endpoint_requested_count > 0:
+ self.fail(
+ "We don't expect an outbound request to /event in the happy path but if "
+ "the logic is sneaking around what we expect, make sure to fail the test. "
+ "We don't expect it because the signature failure should cause us to backoff "
+ "and not asking about pulled_event_without_signatures="
+ f"{pulled_event_without_signatures.event_id} again"
+ )
+
+ if room_state_ids_endpoint_requested_count > 0:
+ self.fail(
+ "We don't expect an outbound request to /state_ids in the happy path but if "
+ "the logic is sneaking around what we expect, make sure to fail the test. "
+ "We don't expect it because the signature failure should cause us to backoff "
+ "and not asking about pulled_event_without_signatures="
+ f"{pulled_event_without_signatures.event_id} again"
+ )
+
+ if room_state_endpoint_requested_count > 0:
+ self.fail(
+ "We don't expect an outbound request to /state in the happy path but if "
+ "the logic is sneaking around what we expect, make sure to fail the test. "
+ "We don't expect it because the signature failure should cause us to backoff "
+ "and not asking about pulled_event_without_signatures="
+ f"{pulled_event_without_signatures.event_id} again"
+ )
+
+ # Make sure we only recorded a single failure which corresponds to the signature
+ # failure initially in `_check_sigs_and_hash_for_pulled_events_and_fetch` before
+ # we process all of the pulled events.
+ backfill_num_attempts_for_event_without_signatures = self.get_success(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event_without_signatures.event_id},
+ retcol="num_attempts",
+ )
+ )
+ self.assertEqual(backfill_num_attempts_for_event_without_signatures, 1)
+
+ # And make sure we didn't record a failure for the event that has the missing
+ # prev_event because we don't want to cause a cascade of failures. Not being
+ # able to fetch the `prev_events` just means we won't be able to de-outlier the
+ # pulled event. But we can still use an `outlier` in the state/auth chain for
+ # another event. So we shouldn't stop a downstream event from trying to pull it.
+ self.get_failure(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event.event_id},
+ retcol="num_attempts",
+ ),
+ # StoreError: 404: No row found
+ StoreError,
+ )
+
+ def test_process_pulled_event_with_rejected_missing_state(self) -> None:
+ """Ensure that we correctly handle pulled events with missing state containing a
+ rejected state event
+
+ In this test, we pretend we are processing a "pulled" event (eg, via backfill
+ or get_missing_events). The pulled event has a prev_event we haven't previously
+ seen, so the server requests the state at that prev_event. We expect the server
+ to make a /state request.
+
+ We simulate a remote server whose /state includes a rejected kick event for a
+ local user. Notably, the kick event is rejected only because it cites a rejected
+ auth event and would otherwise be accepted based on the room state. During state
+ resolution, we re-run auth and can potentially introduce such rejected events
+ into the state if we are not careful.
+
+ We check that the pulled event is correctly persisted, and that the state
+ afterwards does not include the rejected kick.
+ """
+ # The DAG we are testing looks like:
+ #
+ # ...
+ # |
+ # v
+ # remote admin user joins
+ # | |
+ # +-------+ +-------+
+ # | |
+ # | rejected power levels
+ # | from remote server
+ # | |
+ # | v
+ # | rejected kick of local user
+ # v from remote server
+ # new power levels |
+ # | v
+ # | missing event
+ # | from remote server
+ # | |
+ # +-------+ +-------+
+ # | |
+ # v v
+ # pulled event
+ # from remote server
+ #
+ # (arrows are in the opposite direction to prev_events.)
+
+ OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+ main_store = self.hs.get_datastores().main
+
+ # Create the room.
+ kermit_user_id = self.register_user("kermit", "test")
+ kermit_tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(
+ room_creator=kermit_user_id, tok=kermit_tok
+ )
+ room_version = self.get_success(main_store.get_room_version(room_id))
+
+ # Add another local user to the room. This user is going to be kicked in a
+ # rejected event.
+ bert_user_id = self.register_user("bert", "test")
+ bert_tok = self.login("bert", "test")
+ self.helper.join(room_id, user=bert_user_id, tok=bert_tok)
+
+ # Allow the remote user to kick bert.
+ # The remote user is going to send a rejected power levels event later on and we
+ # need state resolution to order it before another power levels event kermit is
+ # going to send later on. Hence we give both users the same power level, so that
+ # ties are broken by `origin_server_ts`.
+ self.helper.send_state(
+ room_id,
+ "m.room.power_levels",
+ {"users": {kermit_user_id: 100, OTHER_USER: 100}},
+ tok=kermit_tok,
+ )
+
+ # Add the remote user to the room.
+ other_member_event = self.get_success(
+ event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
+ )
+
+ initial_state_map = self.get_success(
+ main_store.get_partial_current_state_ids(room_id)
+ )
+ create_event = self.get_success(
+ main_store.get_event(initial_state_map[("m.room.create", "")])
+ )
+ bert_member_event = self.get_success(
+ main_store.get_event(initial_state_map[("m.room.member", bert_user_id)])
+ )
+ power_levels_event = self.get_success(
+ main_store.get_event(initial_state_map[("m.room.power_levels", "")])
+ )
+
+ # We now need a rejected state event that will fail
+ # `check_state_independent_auth_rules` but pass
+ # `check_state_dependent_auth_rules`.
+
+ # First, we create a power levels event that we pretend the remote server has
+ # accepted, but the local homeserver will reject.
+ next_depth = 100
+ next_timestamp = other_member_event.origin_server_ts + 100
+ rejected_power_levels_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "m.room.power_levels",
+ "state_key": "",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [other_member_event.event_id],
+ "auth_events": [
+ initial_state_map[("m.room.create", "")],
+ initial_state_map[("m.room.power_levels", "")],
+ # The event will be rejected because of the duplicated auth
+ # event.
+ other_member_event.event_id,
+ other_member_event.event_id,
+ ],
+ "origin_server_ts": next_timestamp,
+ "depth": next_depth,
+ "content": power_levels_event.content,
+ }
+ ),
+ room_version,
+ )
+ next_depth += 1
+ next_timestamp += 100
+
+ with LoggingContext("send_rejected_power_levels_event"):
+ self.get_success(
+ self.hs.get_federation_event_handler()._process_pulled_event(
+ self.OTHER_SERVER_NAME,
+ rejected_power_levels_event,
+ backfilled=False,
+ )
+ )
+ self.assertEqual(
+ self.get_success(
+ main_store.get_rejection_reason(
+ rejected_power_levels_event.event_id
+ )
+ ),
+ "auth_error",
+ )
+
+ # Then we create a kick event for a local user that cites the rejected power
+ # levels event in its auth events. The kick event will be rejected solely
+ # because of the rejected auth event and would otherwise be accepted.
+ rejected_kick_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "m.room.member",
+ "state_key": bert_user_id,
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [rejected_power_levels_event.event_id],
+ "auth_events": [
+ initial_state_map[("m.room.create", "")],
+ rejected_power_levels_event.event_id,
+ initial_state_map[("m.room.member", bert_user_id)],
+ initial_state_map[("m.room.member", OTHER_USER)],
+ ],
+ "origin_server_ts": next_timestamp,
+ "depth": next_depth,
+ "content": {"membership": "leave"},
+ }
+ ),
+ room_version,
+ )
+ next_depth += 1
+ next_timestamp += 100
+
+ # The kick event must fail the state-independent auth rules, but pass the
+ # state-dependent auth rules, so that it has a chance of making it through state
+ # resolution.
+ self.get_failure(
+ check_state_independent_auth_rules(main_store, rejected_kick_event),
+ AuthError,
+ )
+ check_state_dependent_auth_rules(
+ rejected_kick_event,
+ [create_event, power_levels_event, other_member_event, bert_member_event],
+ )
+
+ # The kick event must also win over the original member event during state
+ # resolution.
+ self.assertEqual(
+ self.get_success(
+ _mainline_sort(
+ self.clock,
+ room_id,
+ event_ids=[
+ bert_member_event.event_id,
+ rejected_kick_event.event_id,
+ ],
+ resolved_power_event_id=power_levels_event.event_id,
+ event_map={
+ bert_member_event.event_id: bert_member_event,
+ rejected_kick_event.event_id: rejected_kick_event,
+ },
+ state_res_store=main_store,
+ )
+ ),
+ [bert_member_event.event_id, rejected_kick_event.event_id],
+ "The rejected kick event will not be applied after bert's join event "
+ "during state resolution. The test setup is incorrect.",
+ )
+
+ with LoggingContext("send_rejected_kick_event"):
+ self.get_success(
+ self.hs.get_federation_event_handler()._process_pulled_event(
+ self.OTHER_SERVER_NAME, rejected_kick_event, backfilled=False
+ )
+ )
+ self.assertEqual(
+ self.get_success(
+ main_store.get_rejection_reason(rejected_kick_event.event_id)
+ ),
+ "auth_error",
+ )
+
+ # We need another power levels event which will win over the rejected one during
+ # state resolution, otherwise we hit other issues where we end up with rejected
+ # a power levels event during state resolution.
+ self.reactor.advance(100) # ensure the `origin_server_ts` is larger
+ new_power_levels_event = self.get_success(
+ main_store.get_event(
+ self.helper.send_state(
+ room_id,
+ "m.room.power_levels",
+ {"users": {kermit_user_id: 100, OTHER_USER: 100, bert_user_id: 1}},
+ tok=kermit_tok,
+ )["event_id"]
+ )
+ )
+ self.assertEqual(
+ self.get_success(
+ _reverse_topological_power_sort(
+ self.clock,
+ room_id,
+ event_ids=[
+ new_power_levels_event.event_id,
+ rejected_power_levels_event.event_id,
+ ],
+ event_map={},
+ state_res_store=main_store,
+ full_conflicted_set=set(),
+ )
+ ),
+ [rejected_power_levels_event.event_id, new_power_levels_event.event_id],
+ "The power levels events will not have the desired ordering during state "
+ "resolution. The test setup is incorrect.",
+ )
+
+ # Create a missing event, so that the local homeserver has to do a `/state` or
+ # `/state_ids` request to pull state from the remote homeserver.
+ missing_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "m.room.message",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [rejected_kick_event.event_id],
+ "auth_events": [
+ initial_state_map[("m.room.create", "")],
+ initial_state_map[("m.room.power_levels", "")],
+ initial_state_map[("m.room.member", OTHER_USER)],
+ ],
+ "origin_server_ts": next_timestamp,
+ "depth": next_depth,
+ "content": {"msgtype": "m.text", "body": "foo"},
+ }
+ ),
+ room_version,
+ )
+ next_depth += 1
+ next_timestamp += 100
+
+ # The pulled event has two prev events, one of which is missing. We will make a
+ # `/state` or `/state_ids` request to the remote homeserver to ask it for the
+ # state before the missing prev event.
+ pulled_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "m.room.message",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [
+ new_power_levels_event.event_id,
+ missing_event.event_id,
+ ],
+ "auth_events": [
+ initial_state_map[("m.room.create", "")],
+ new_power_levels_event.event_id,
+ initial_state_map[("m.room.member", OTHER_USER)],
+ ],
+ "origin_server_ts": next_timestamp,
+ "depth": next_depth,
+ "content": {"msgtype": "m.text", "body": "bar"},
+ }
+ ),
+ room_version,
+ )
+ next_depth += 1
+ next_timestamp += 100
+
+ # Prepare the response for the `/state` or `/state_ids` request.
+ # The remote server believes bert has been kicked, while the local server does
+ # not.
+ state_before_missing_event = self.get_success(
+ main_store.get_events_as_list(initial_state_map.values())
+ )
+ state_before_missing_event = [
+ event
+ for event in state_before_missing_event
+ if event.event_id != bert_member_event.event_id
+ ]
+ state_before_missing_event.append(rejected_kick_event)
+
+ # We have to bump the clock a bit, to keep the retry logic in
+ # `FederationClient.get_pdu` happy
+ self.reactor.advance(60000)
+ with LoggingContext("send_pulled_event"):
+
+ async def get_event(
+ destination: str, event_id: str, timeout: Optional[int] = None
+ ) -> JsonDict:
+ self.assertEqual(destination, self.OTHER_SERVER_NAME)
+ self.assertEqual(event_id, missing_event.event_id)
+ return {"pdus": [missing_event.get_pdu_json()]}
+
+ async def get_room_state_ids(
+ destination: str, room_id: str, event_id: str
+ ) -> JsonDict:
+ self.assertEqual(destination, self.OTHER_SERVER_NAME)
+ self.assertEqual(event_id, missing_event.event_id)
+ return {
+ "pdu_ids": [event.event_id for event in state_before_missing_event],
+ "auth_chain_ids": [],
+ }
+
+ async def get_room_state(
+ room_version: RoomVersion, destination: str, room_id: str, event_id: str
+ ) -> StateRequestResponse:
+ self.assertEqual(destination, self.OTHER_SERVER_NAME)
+ self.assertEqual(event_id, missing_event.event_id)
+ return StateRequestResponse(
+ state=state_before_missing_event,
+ auth_events=[],
+ )
+
+ self.mock_federation_transport_client.get_event.side_effect = get_event
+ self.mock_federation_transport_client.get_room_state_ids.side_effect = (
+ get_room_state_ids
+ )
+ self.mock_federation_transport_client.get_room_state.side_effect = (
+ get_room_state
+ )
+
+ self.get_success(
+ self.hs.get_federation_event_handler()._process_pulled_event(
+ self.OTHER_SERVER_NAME, pulled_event, backfilled=False
+ )
+ )
+ self.assertIsNone(
+ self.get_success(
+ main_store.get_rejection_reason(pulled_event.event_id)
+ ),
+ "Pulled event was unexpectedly rejected, likely due to a problem with "
+ "the test setup.",
+ )
+ self.assertEqual(
+ {pulled_event.event_id},
+ self.get_success(
+ main_store.have_events_in_timeline([pulled_event.event_id])
+ ),
+ "Pulled event was not persisted, likely due to a problem with the test "
+ "setup.",
+ )
+
+ # We must not accept rejected events into the room state, so we expect bert
+ # to not be kicked, even if the remote server believes so.
+ new_state_map = self.get_success(
+ main_store.get_partial_current_state_ids(room_id)
+ )
+ self.assertEqual(
+ new_state_map[("m.room.member", bert_user_id)],
+ bert_member_event.event_id,
+ "Rejected kick event unexpectedly became part of room state.",
+ )
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 44da96c792..99384837d0 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -105,7 +105,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
event1, context = self._create_duplicate_event(txn_id)
ret_event1 = self.get_success(
- self.handler.handle_new_client_event(self.requester, event1, context)
+ self.handler.handle_new_client_event(
+ self.requester,
+ events_and_context=[(event1, context)],
+ )
)
stream_id1 = ret_event1.internal_metadata.stream_ordering
@@ -118,7 +121,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event2.event_id)
ret_event2 = self.get_success(
- self.handler.handle_new_client_event(self.requester, event2, context)
+ self.handler.handle_new_client_event(
+ self.requester,
+ events_and_context=[(event2, context)],
+ )
)
stream_id2 = ret_event2.internal_metadata.stream_ordering
@@ -314,4 +320,4 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", path, content={}, access_token=self.access_token
)
- self.assertEqual(int(channel.result["code"]), 403)
+ self.assertEqual(channel.code, 403)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e6cd3af7b7..5955410524 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import os
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
@@ -22,12 +21,15 @@ import pymacaroons
from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.sso import MappingException
+from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import UserID
from synapse.util import Clock
-from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon
+from synapse.util.macaroons import get_value_from_macaroon
+from synapse.util.stringutils import random_string
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
from tests.unittest import HomeserverTestCase, override_config
try:
@@ -46,12 +48,6 @@ BASE_URL = "https://synapse/"
CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
SCOPES = ["openid"]
-AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
-TOKEN_ENDPOINT = ISSUER + "token"
-USERINFO_ENDPOINT = ISSUER + "userinfo"
-WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
-JWKS_URI = ISSUER + ".well-known/jwks.json"
-
# config for common cases
DEFAULT_CONFIG = {
"enabled": True,
@@ -66,9 +62,9 @@ DEFAULT_CONFIG = {
EXPLICIT_ENDPOINT_CONFIG = {
**DEFAULT_CONFIG,
"discover": False,
- "authorization_endpoint": AUTHORIZATION_ENDPOINT,
- "token_endpoint": TOKEN_ENDPOINT,
- "jwks_uri": JWKS_URI,
+ "authorization_endpoint": ISSUER + "authorize",
+ "token_endpoint": ISSUER + "token",
+ "jwks_uri": ISSUER + "jwks",
}
@@ -102,27 +98,6 @@ class TestMappingProviderFailures(TestMappingProvider):
}
-async def get_json(url: str) -> JsonDict:
- # Mock get_json calls to handle jwks & oidc discovery endpoints
- if url == WELL_KNOWN:
- # Minimal discovery document, as defined in OpenID.Discovery
- # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
- return {
- "issuer": ISSUER,
- "authorization_endpoint": AUTHORIZATION_ENDPOINT,
- "token_endpoint": TOKEN_ENDPOINT,
- "jwks_uri": JWKS_URI,
- "userinfo_endpoint": USERINFO_ENDPOINT,
- "response_types_supported": ["code"],
- "subject_types_supported": ["public"],
- "id_token_signing_alg_values_supported": ["RS256"],
- }
- elif url == JWKS_URI:
- return {"keys": []}
-
- return {}
-
-
def _key_file_path() -> str:
"""path to a file containing the private half of a test key"""
@@ -159,11 +134,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
return config
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.http_client = Mock(spec=["get_json"])
- self.http_client.get_json.side_effect = get_json
- self.http_client.user_agent = b"Synapse Test"
+ self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER)
- hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
+ hs = self.setup_test_homeserver()
+ self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
+ self.hs_patcher.start()
self.handler = hs.get_oidc_handler()
self.provider = self.handler._providers["oidc"]
@@ -175,18 +150,51 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Reduce the number of attempts when generating MXIDs.
sso_handler._MAP_USERNAME_RETRIES = 3
+ auth_handler = hs.get_auth_handler()
+ # Mock the complete SSO login method.
+ self.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment]
+
return hs
+ def tearDown(self) -> None:
+ self.hs_patcher.stop()
+ return super().tearDown()
+
+ def reset_mocks(self):
+ """Reset all the Mocks."""
+ self.fake_server.reset_mocks()
+ self.render_error.reset_mock()
+ self.complete_sso_login.reset_mock()
+
def metadata_edit(self, values):
"""Modify the result that will be returned by the well-known query"""
- async def patched_get_json(uri):
- res = await get_json(uri)
- if uri == WELL_KNOWN:
- res.update(values)
- return res
+ metadata = self.fake_server.get_metadata()
+ metadata.update(values)
+ return patch.object(self.fake_server, "get_metadata", return_value=metadata)
- return patch.object(self.http_client, "get_json", patched_get_json)
+ def start_authorization(
+ self,
+ userinfo: dict,
+ client_redirect_url: str = "http://client/redirect",
+ scope: str = "openid",
+ with_sid: bool = False,
+ ) -> Tuple[SynapseRequest, FakeAuthorizationGrant]:
+ """Start an authorization request, and get the callback request back."""
+ nonce = random_string(10)
+ state = random_string(10)
+
+ code, grant = self.fake_server.start_authorization(
+ userinfo=userinfo,
+ scope=scope,
+ client_id=self.provider._client_auth.client_id,
+ redirect_uri=self.provider._callback_url,
+ nonce=nonce,
+ with_sid=with_sid,
+ )
+ session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
+ return _build_callback_request(code, state, session), grant
def assertRenderedError(self, error, error_description=None):
self.render_error.assert_called_once()
@@ -210,52 +218,54 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
metadata = self.get_success(self.provider.load_metadata())
- self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.fake_server.get_metadata_handler.assert_called_once()
- self.assertEqual(metadata.issuer, ISSUER)
- self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT)
- self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT)
- self.assertEqual(metadata.jwks_uri, JWKS_URI)
- # FIXME: it seems like authlib does not have that defined in its metadata models
- # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT)
+ self.assertEqual(metadata.issuer, self.fake_server.issuer)
+ self.assertEqual(
+ metadata.authorization_endpoint,
+ self.fake_server.authorization_endpoint,
+ )
+ self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint)
+ self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri)
+ # It seems like authlib does not have that defined in its metadata models
+ self.assertEqual(
+ metadata.get("userinfo_endpoint"),
+ self.fake_server.userinfo_endpoint,
+ )
# subsequent calls should be cached
- self.http_client.reset_mock()
+ self.reset_mocks()
self.get_success(self.provider.load_metadata())
- self.http_client.get_json.assert_not_called()
+ self.fake_server.get_metadata_handler.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self) -> None:
"""When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata())
- self.http_client.get_json.assert_not_called()
+ self.fake_server.get_metadata_handler.assert_not_called()
- @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_load_jwks(self) -> None:
"""JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks())
- self.http_client.get_json.assert_called_once_with(JWKS_URI)
- self.assertEqual(jwks, {"keys": []})
+ self.fake_server.get_jwks_handler.assert_called_once()
+ self.assertEqual(jwks, self.fake_server.get_jwks())
# subsequent calls should be cached…
- self.http_client.reset_mock()
+ self.reset_mocks()
self.get_success(self.provider.load_jwks())
- self.http_client.get_json.assert_not_called()
+ self.fake_server.get_jwks_handler.assert_not_called()
# …unless forced
- self.http_client.reset_mock()
+ self.reset_mocks()
self.get_success(self.provider.load_jwks(force=True))
- self.http_client.get_json.assert_called_once_with(JWKS_URI)
+ self.fake_server.get_jwks_handler.assert_called_once()
- # Throw if the JWKS uri is missing
- original = self.provider.load_metadata
-
- async def patched_load_metadata():
- m = (await original()).copy()
- m.update({"jwks_uri": None})
- return m
-
- with patch.object(self.provider, "load_metadata", patched_load_metadata):
+ with self.metadata_edit({"jwks_uri": None}):
+ # If we don't do this, the load_metadata call will throw because of the
+ # missing jwks_uri
+ self.provider._user_profile_method = "userinfo_endpoint"
+ self.get_success(self.provider.load_metadata(force=True))
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
@override_config({"oidc_config": DEFAULT_CONFIG})
@@ -359,7 +369,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.provider.handle_redirect_request(req, b"http://client/redirect")
)
)
- auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
+ auth_endpoint = urlparse(self.fake_server.authorization_endpoint)
self.assertEqual(url.scheme, auth_endpoint.scheme)
self.assertEqual(url.netloc, auth_endpoint.netloc)
@@ -424,48 +434,34 @@ class OidcHandlerTestCase(HomeserverTestCase):
with self.assertRaises(AttributeError):
_ = mapping_provider.get_extra_attributes
- token = {
- "type": "bearer",
- "id_token": "id_token",
- "access_token": "access_token",
- }
username = "bar"
userinfo = {
"sub": "foo",
"username": username,
}
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
- code = "code"
- state = "state"
- nonce = "nonce"
client_redirect_url = "http://client/redirect"
- ip_address = "10.0.0.1"
- session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
- request = _build_callback_request(code, state, session, ip_address=ip_address)
-
+ request, _ = self.start_authorization(
+ userinfo, client_redirect_url=client_redirect_url
+ )
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
expected_user_id,
- "oidc",
+ self.provider.idp_id,
request,
client_redirect_url,
None,
new_user=True,
auth_provider_session_id=None,
)
- self.provider._exchange_code.assert_called_once_with(code)
- self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.provider._fetch_userinfo.assert_not_called()
+ self.fake_server.post_token_handler.assert_called_once()
+ self.fake_server.get_userinfo_handler.assert_not_called()
self.render_error.assert_not_called()
# Handle mapping errors
+ request, _ = self.start_authorization(userinfo)
with patch.object(
self.provider,
"_remote_id_from_userinfo",
@@ -475,81 +471,63 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("mapping_error")
# Handle ID token errors
- self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
- self.get_success(self.handler.handle_oidc_callback(request))
+ request, _ = self.start_authorization(userinfo)
+ with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}):
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
- auth_handler.complete_sso_login.reset_mock()
- self.provider._exchange_code.reset_mock()
- self.provider._parse_id_token.reset_mock()
- self.provider._fetch_userinfo.reset_mock()
+ self.reset_mocks()
# With userinfo fetching
self.provider._user_profile_method = "userinfo_endpoint"
- token = {
- "type": "bearer",
- "access_token": "access_token",
- }
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
+ # Without the "openid" scope, the FakeProvider does not generate an id_token
+ request, _ = self.start_authorization(userinfo, scope="")
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
expected_user_id,
- "oidc",
+ self.provider.idp_id,
request,
- client_redirect_url,
+ ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- self.provider._exchange_code.assert_called_once_with(code)
- self.provider._parse_id_token.assert_not_called()
- self.provider._fetch_userinfo.assert_called_once_with(token)
+ self.fake_server.post_token_handler.assert_called_once()
+ self.fake_server.get_userinfo_handler.assert_called_once()
self.render_error.assert_not_called()
+ self.reset_mocks()
+
# With an ID token, userinfo fetching and sid in the ID token
self.provider._user_profile_method = "userinfo_endpoint"
- token = {
- "type": "bearer",
- "access_token": "access_token",
- "id_token": "id_token",
- }
- id_token = {
- "sid": "abcdefgh",
- }
- self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
- auth_handler.complete_sso_login.reset_mock()
- self.provider._fetch_userinfo.reset_mock()
+ request, grant = self.start_authorization(userinfo, with_sid=True)
+ self.assertIsNotNone(grant.sid)
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
expected_user_id,
- "oidc",
+ self.provider.idp_id,
request,
- client_redirect_url,
+ ANY,
None,
new_user=False,
- auth_provider_session_id=id_token["sid"],
+ auth_provider_session_id=grant.sid,
)
- self.provider._exchange_code.assert_called_once_with(code)
- self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.provider._fetch_userinfo.assert_called_once_with(token)
+ self.fake_server.post_token_handler.assert_called_once()
+ self.fake_server.get_userinfo_handler.assert_called_once()
self.render_error.assert_not_called()
# Handle userinfo fetching error
- self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
- self.get_success(self.handler.handle_oidc_callback(request))
+ request, _ = self.start_authorization(userinfo)
+ with self.fake_server.buggy_endpoint(userinfo=True):
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
- # Handle code exchange failure
- from synapse.handlers.oidc import OidcError
-
- self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
- raises=OidcError("invalid_request")
- )
- self.get_success(self.handler.handle_oidc_callback(request))
- self.assertRenderedError("invalid_request")
+ request, _ = self.start_authorization(userinfo)
+ with self.fake_server.buggy_endpoint(token=True):
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("server_error")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_session(self) -> None:
@@ -599,18 +577,22 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_exchange_code(self) -> None:
"""Code exchange behaves correctly and handles various error scenarios."""
- token = {"type": "bearer"}
- token_json = json.dumps(token).encode("utf-8")
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
+ token = {
+ "type": "Bearer",
+ "access_token": "aabbcc",
+ }
+
+ self.fake_server.post_token_handler.side_effect = None
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ payload=token
)
code = "code"
ret = self.get_success(self.provider._exchange_code(code))
- kwargs = self.http_client.request.call_args[1]
+ kwargs = self.fake_server.request.call_args[1]
self.assertEqual(ret, token)
self.assertEqual(kwargs["method"], "POST")
- self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+ self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
args = parse_qs(kwargs["data"].decode("utf-8"))
self.assertEqual(args["grant_type"], ["authorization_code"])
@@ -620,12 +602,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
# Test error handling
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=400,
- phrase=b"Bad Request",
- body=b'{"error": "foo", "error_description": "bar"}',
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=400, payload={"error": "foo", "error_description": "bar"}
)
from synapse.handlers.oidc import OidcError
@@ -634,46 +612,30 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(exc.value.error_description, "bar")
# Internal server error with no JSON body
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=500,
- phrase=b"Internal Server Error",
- body=b"Not JSON",
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse(
+ code=500, body=b"Not JSON"
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=500,
- phrase=b"Internal Server Error",
- body=b'{"error": "internal_server_error"}',
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=500, payload={"error": "internal_server_error"}
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=400,
- phrase=b"Bad request",
- body=b"{}",
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=400, payload={}
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=200,
- phrase=b"OK",
- body=b'{"error": "some_error"}',
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=200, payload={"error": "some_error"}
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "some_error")
@@ -697,11 +659,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""Test that code exchange works with a JWK client secret."""
from authlib.jose import jwt
- token = {"type": "bearer"}
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
- )
+ token = {
+ "type": "Bearer",
+ "access_token": "aabbcc",
+ }
+
+ self.fake_server.post_token_handler.side_effect = None
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ payload=token
)
code = "code"
@@ -714,9 +679,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(ret, token)
# the request should have hit the token endpoint
- kwargs = self.http_client.request.call_args[1]
+ kwargs = self.fake_server.request.call_args[1]
self.assertEqual(kwargs["method"], "POST")
- self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+ self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
# the client secret provided to the should be a jwt which can be checked with
# the public key
@@ -750,11 +715,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_exchange_code_no_auth(self) -> None:
"""Test that code exchange works with no client secret."""
- token = {"type": "bearer"}
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
- )
+ token = {
+ "type": "Bearer",
+ "access_token": "aabbcc",
+ }
+
+ self.fake_server.post_token_handler.side_effect = None
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ payload=token
)
code = "code"
ret = self.get_success(self.provider._exchange_code(code))
@@ -762,9 +730,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(ret, token)
# the request should have hit the token endpoint
- kwargs = self.http_client.request.call_args[1]
+ kwargs = self.fake_server.request.call_args[1]
self.assertEqual(kwargs["method"], "POST")
- self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+ self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
# check the POSTed data
args = parse_qs(kwargs["data"].decode("utf-8"))
@@ -787,37 +755,19 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""
Login while using a mapping provider that implements get_extra_attributes.
"""
- token = {
- "type": "bearer",
- "id_token": "id_token",
- "access_token": "access_token",
- }
userinfo = {
"sub": "foo",
"username": "foo",
"phone": "1234567",
}
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
- state = "state"
- client_redirect_url = "http://client/redirect"
- session = self._generate_oidc_session_token(
- state=state,
- nonce="nonce",
- client_redirect_url=client_redirect_url,
- )
- request = _build_callback_request("code", state, session)
-
+ request, _ = self.start_authorization(userinfo)
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@foo:test",
- "oidc",
+ self.provider.idp_id,
request,
- client_redirect_url,
+ ANY,
{"phone": "1234567"},
new_user=True,
auth_provider_session_id=None,
@@ -826,41 +776,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_user(self) -> None:
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
userinfo: dict = {
"sub": "test_user",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
"@test_user:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Some providers return an integer ID.
userinfo = {
"sub": 1234,
"username": "test_user_2",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
"@test_user_2:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Test if the mxid is already taken
store = self.hs.get_datastores().main
@@ -869,8 +818,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None)
)
userinfo = {"sub": "test3", "username": "test_user_3"}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error",
"Mapping provider does not support de-duplicating Matrix IDs",
@@ -885,38 +835,37 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user.to_string(), password_hash=None)
)
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
# Map a user via SSO.
userinfo = {
"sub": "test",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
user.to_string(),
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Subsequent calls should map to the same mxid.
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
user.to_string(),
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID,
@@ -927,17 +876,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
user.to_string(),
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test")
@@ -954,8 +904,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2",
"username": "TEST_USER_2",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
args = self.assertRenderedError("mapping_error")
self.assertTrue(
args[2].startswith(
@@ -969,11 +920,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None)
)
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
"@TEST_USER_2:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
@@ -983,9 +935,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected."""
- self.get_success(
- _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
- )
+ userinfo = {"sub": "test2", "username": "föö"}
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error", "localpart is invalid: föö")
@override_config(
@@ -1000,9 +952,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_map_userinfo_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
store = self.hs.get_datastores().main
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
@@ -1011,19 +960,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
# test_user is already taken, so test_user1 gets registered instead.
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@test_user1:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Register all of the potential mxids for a particular OIDC username.
self.get_success(
@@ -1039,8 +989,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "tester",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
)
@@ -1052,7 +1003,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error", "localpart is invalid: ")
@override_config(
@@ -1071,7 +1023,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": None,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error", "localpart is invalid: ")
@override_config(
@@ -1084,16 +1037,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the OIDC userinfo response."""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
# userinfo lacking "test": "foobar" attribute should fail.
userinfo = {
"sub": "tester",
"username": "tester",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": "foobar" attribute should succeed.
userinfo = {
@@ -1101,13 +1052,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": "foobar",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
# check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@tester:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
@@ -1124,21 +1076,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_attribute_requirements_contains(self) -> None:
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed.
userinfo = {
"sub": "tester",
"username": "tester",
"test": ["foobar", "foo", "bar"],
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
# check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@tester:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
@@ -1158,16 +1109,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
Test that auth fails if attributes exist but don't match,
or are non-string values.
"""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": "not_foobar" attribute should fail
userinfo: dict = {
"sub": "tester",
"username": "tester",
"test": "not_foobar",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": ["foo", "bar"] attribute should fail
userinfo = {
@@ -1175,8 +1125,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": ["foo", "bar"],
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": False attribute should fail
# this is largely just to ensure we don't crash here
@@ -1185,8 +1136,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": False,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": None attribute should fail
# a value of None breaks the OIDC spec, but it's important to not crash here
@@ -1195,8 +1147,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": None,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": 1 attribute should fail
# this is largely just to ensure we don't crash here
@@ -1205,8 +1158,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": 1,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": 3.14 attribute should fail
# this is largely just to ensure we don't crash here
@@ -1215,8 +1169,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": 3.14,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
def _generate_oidc_session_token(
self,
@@ -1230,7 +1185,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return self.handler._macaroon_generator.generate_oidc_session_token(
state=state,
session_data=OidcSessionData(
- idp_id="oidc",
+ idp_id=self.provider.idp_id,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=ui_auth_session_id,
@@ -1238,41 +1193,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
-async def _make_callback_with_userinfo(
- hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
-) -> None:
- """Mock up an OIDC callback with the given userinfo dict
-
- We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
- and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
-
- Args:
- hs: the HomeServer impl to send the callback to.
- userinfo: the OIDC userinfo dict
- client_redirect_url: the URL to redirect to on success.
- """
-
- handler = hs.get_oidc_handler()
- provider = handler._providers["oidc"]
- provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
- provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
-
- state = "state"
- session = handler._macaroon_generator.generate_oidc_session_token(
- state=state,
- session_data=OidcSessionData(
- idp_id="oidc",
- nonce="nonce",
- client_redirect_url=client_redirect_url,
- ui_auth_session_id="",
- ),
- )
- request = _build_callback_request("code", state, session)
-
- await handler.handle_oidc_callback(request)
-
-
def _build_callback_request(
code: str,
state: str,
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 4c62449c89..75934b1707 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -21,7 +21,6 @@ from unittest.mock import Mock
import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
-from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID
@@ -167,16 +166,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
super().setUp()
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver()
- # Load the modules into the homeserver
- module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
- load_legacy_password_auth_providers(hs)
-
- return hs
-
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_progiver_login_legacy(self):
self.password_only_auth_provider_login_test_body()
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index c96dc6caf2..c5981ff965 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -15,6 +15,7 @@
from typing import Optional
from unittest.mock import Mock, call
+from parameterized import parameterized
from signedjson.key import generate_signing_key
from synapse.api.constants import EventTypes, Membership, PresenceState
@@ -37,6 +38,7 @@ from synapse.rest.client import room
from synapse.types import UserID, get_domain_from_id
from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
@@ -505,7 +507,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(state, new_state)
-class PresenceHandlerTestCase(unittest.HomeserverTestCase):
+class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
def prepare(self, reactor, clock, hs):
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
@@ -716,20 +718,47 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
# our status message should be the same as it was before
self.assertEqual(state.status_msg, status_msg)
- def test_set_presence_from_syncing_keeps_busy(self):
- """Test that presence set by syncing doesn't affect busy status"""
- # while this isn't the default
- self.presence_handler._busy_presence_enabled = True
+ @parameterized.expand([(False,), (True,)])
+ @unittest.override_config(
+ {
+ "experimental_features": {
+ "msc3026_enabled": True,
+ },
+ }
+ )
+ def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool):
+ """Test that presence set by syncing doesn't affect busy status
+ Args:
+ test_with_workers: If True, check the presence state of the user by calling
+ /sync against a worker, rather than the main process.
+ """
user_id = "@test:server"
status_msg = "I'm busy!"
+ # By default, we call /sync against the main process.
+ worker_to_sync_against = self.hs
+ if test_with_workers:
+ # Create a worker and use it to handle /sync traffic instead.
+ # This is used to test that presence changes get replicated from workers
+ # to the main process correctly.
+ worker_to_sync_against = self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "presence_writer"}
+ )
+
+ # Set presence to BUSY
self._set_presencestate_with_status_msg(user_id, PresenceState.BUSY, status_msg)
+ # Perform a sync with a presence state other than busy. This should NOT change
+ # our presence status; we only change from busy if we explicitly set it via
+ # /presence/*.
self.get_success(
- self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE)
+ worker_to_sync_against.get_presence_handler().user_syncing(
+ user_id, True, PresenceState.ONLINE
+ )
)
+ # Check against the main process that the user's presence did not change.
state = self.get_success(
self.presence_handler.get_state(UserID.from_string(user_id))
)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index f88c725a42..675aa023ac 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -14,6 +14,8 @@
from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
+from parameterized import parameterized
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.types
@@ -327,6 +329,53 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(res)
+ @unittest.override_config(
+ {"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]}
+ )
+ def test_avatar_constraint_on_local_server_with_port(self):
+ """Test that avatar metadata is correctly fetched when the media is on a local
+ server and the server has an explicit port.
+
+ (This was previously a bug)
+ """
+ local_server_name = self.hs.config.server.server_name
+ media_id = "local"
+ local_mxc = f"mxc://{local_server_name}/{media_id}"
+
+ # mock up the existence of the avatar file
+ self._setup_local_files({media_id: {"mimetype": "image/png"}})
+
+ # and now check that check_avatar_size_and_mime_type is happy
+ self.assertTrue(
+ self.get_success(self.handler.check_avatar_size_and_mime_type(local_mxc))
+ )
+
+ @parameterized.expand([("remote",), ("remote:1234",)])
+ @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
+ def test_check_avatar_on_remote_server(self, remote_server_name: str) -> None:
+ """Test that avatar metadata is correctly fetched from a remote server"""
+ media_id = "remote"
+ remote_mxc = f"mxc://{remote_server_name}/{media_id}"
+
+ # if the media is remote, check_avatar_size_and_mime_type just checks the
+ # media cache, so we don't need to instantiate a real remote server. It is
+ # sufficient to poke an entry into the db.
+ self.get_success(
+ self.hs.get_datastores().main.store_cached_remote_media(
+ media_id=media_id,
+ media_type="image/png",
+ media_length=50,
+ origin=remote_server_name,
+ time_now_ms=self.clock.time_msec(),
+ upload_name=None,
+ filesystem_id="xyz",
+ )
+ )
+
+ self.assertTrue(
+ self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc))
+ )
+
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]):
"""Stores metadata about files in the database.
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index a95868b5c0..b55238650c 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -25,7 +25,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.event_source = hs.get_event_sources().sources.receipt
- def test_filters_out_private_receipt(self):
+ def test_filters_out_private_receipt(self) -> None:
self._test_filters_private(
[
{
@@ -45,7 +45,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
[],
)
- def test_filters_out_private_receipt_and_ignores_rest(self):
+ def test_filters_out_private_receipt_and_ignores_rest(self) -> None:
self._test_filters_private(
[
{
@@ -84,7 +84,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest(self):
+ def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest(
+ self,
+ ) -> None:
self._test_filters_private(
[
{
@@ -125,7 +127,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_handles_empty_event(self):
+ def test_handles_empty_event(self) -> None:
self._test_filters_private(
[
{
@@ -160,7 +162,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest(self):
+ def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest(
+ self,
+ ) -> None:
self._test_filters_private(
[
{
@@ -207,7 +211,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_handles_string_data(self):
+ def test_handles_string_data(self) -> None:
"""
Tests that an invalid shape for read-receipts is handled.
Context: https://github.com/matrix-org/synapse/issues/10603
@@ -242,7 +246,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_leaves_our_private_and_their_public(self):
+ def test_leaves_our_private_and_their_public(self) -> None:
self._test_filters_private(
[
{
@@ -296,7 +300,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_we_do_not_mutate(self):
+ def test_we_do_not_mutate(self) -> None:
"""Ensure the input values are not modified."""
events = [
{
@@ -320,7 +324,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
def _test_filters_private(
self, events: List[JsonDict], expected_output: List[JsonDict]
- ):
+ ) -> None:
"""Tests that the _filter_out_private returns the expected output"""
filtered_events = self.event_source.filter_out_private_receipts(
events, "@me:server.org"
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 23f35d5bf5..765df75d91 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -22,7 +22,6 @@ from synapse.api.errors import (
ResourceLimitError,
SynapseError,
)
-from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, RoomID, UserID, create_requester
@@ -144,12 +143,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
config=hs_config, federation_client=self.mock_federation_client
)
- load_legacy_spam_checkers(hs)
-
- module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
-
return hs
def prepare(self, reactor, clock, hs):
@@ -504,7 +497,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- event_creation_handler.handle_new_client_event(requester, event, context)
+ event_creation_handler.handle_new_client_event(
+ requester, events_and_context=[(event, context)]
+ )
)
# Register a second user, which won't be be in the room (or even have an invite)
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index 254e7e4b80..6bbfd5dc84 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -1,4 +1,3 @@
-from http import HTTPStatus
from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor
@@ -7,7 +6,7 @@ import synapse.rest.admin
import synapse.rest.client.login
import synapse.rest.client.room
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import LimitExceededError
+from synapse.api.errors import LimitExceededError, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import FrozenEventV3
from synapse.federation.federation_client import SendJoinResult
@@ -15,10 +14,14 @@ from synapse.server import HomeServer
from synapse.types import UserID, create_requester
from synapse.util import Clock
-from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
from tests.test_utils import make_awaitable
-from tests.unittest import FederatingHomeserverTestCase, override_config
+from tests.unittest import (
+ FederatingHomeserverTestCase,
+ HomeserverTestCase,
+ override_config,
+)
class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
@@ -217,7 +220,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
# - trying to remote-join again.
-class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestCase):
+class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCase):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.client.login.register_servlets,
@@ -260,7 +263,7 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestC
f"/_matrix/client/v3/rooms/{self.room_id}/join",
access_token=self.bob_token,
)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+ self.assertEqual(channel.code, 200, channel.json_body)
# wait for join to arrive over replication
self.replicate()
@@ -288,3 +291,88 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestC
),
LimitExceededError,
)
+
+
+class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.client.login.register_servlets,
+ synapse.rest.client.room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.handler = hs.get_room_member_handler()
+ self.store = hs.get_datastores().main
+
+ # Create two users.
+ self.alice = self.register_user("alice", "pass")
+ self.alice_ID = UserID.from_string(self.alice)
+ self.alice_token = self.login("alice", "pass")
+ self.bob = self.register_user("bob", "pass")
+ self.bob_ID = UserID.from_string(self.bob)
+ self.bob_token = self.login("bob", "pass")
+
+ # Create a room on this homeserver.
+ self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+
+ def test_leave_and_forget(self) -> None:
+ """Tests that forget a room is successfully. The test is performed with two users,
+ as forgetting by the last user respectively after all users had left the
+ is a special edge case."""
+ self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
+
+ # alice is not the last room member that leaves and forgets the room
+ self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token)
+ self.get_success(self.handler.forget(self.alice_ID, self.room_id))
+ self.assertTrue(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
+
+ # the server has not forgotten the room
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room_id))
+ )
+
+ def test_leave_and_forget_last_user(self) -> None:
+ """Tests that forget a room is successfully when the last user has left the room."""
+
+ # alice is the last room member that leaves and forgets the room
+ self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token)
+ self.get_success(self.handler.forget(self.alice_ID, self.room_id))
+ self.assertTrue(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
+
+ # the server has forgotten the room
+ self.assertTrue(
+ self.get_success(self.store.is_locally_forgotten_room(self.room_id))
+ )
+
+ def test_forget_when_not_left(self) -> None:
+ """Tests that a user cannot not forgets a room that has not left."""
+ self.get_failure(self.handler.forget(self.alice_ID, self.room_id), SynapseError)
+
+ def test_rejoin_forgotten_by_user(self) -> None:
+ """Test that a user that has forgotten a room can do a re-join.
+ The room was not forgotten from the local server.
+ One local user is still member of the room."""
+ self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
+
+ self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token)
+ self.get_success(self.handler.forget(self.alice_ID, self.room_id))
+ self.assertTrue(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
+
+ # the server has not forgotten the room
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room_id))
+ )
+
+ self.helper.join(self.room_id, user=self.alice, tok=self.alice_token)
+ # TODO: A join to a room does not invalidate the forgotten cache
+ # see https://github.com/matrix-org/synapse/issues/13262
+ self.store.did_forget.invalidate_all()
+ self.assertFalse(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py
new file mode 100644
index 0000000000..137deab138
--- /dev/null
+++ b/tests/handlers/test_sso.py
@@ -0,0 +1,145 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from http import HTTPStatus
+from typing import BinaryIO, Callable, Dict, List, Optional, Tuple
+from unittest.mock import Mock
+
+from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.http_headers import Headers
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.client import RawHeaders
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.test_utils import SMALL_PNG, FakeResponse
+
+
+class TestSSOHandler(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.http_client = Mock(spec=["get_file"])
+ self.http_client.get_file.side_effect = mock_get_file
+ self.http_client.user_agent = b"Synapse Test"
+ hs = self.setup_test_homeserver(
+ proxied_blacklisted_http_client=self.http_client
+ )
+ return hs
+
+ async def test_set_avatar(self) -> None:
+ """Tests successfully setting the avatar of a newly created user"""
+ handler = self.hs.get_sso_handler()
+
+ # Create a new user to set avatar for
+ reg_handler = self.hs.get_registration_handler()
+ user_id = self.get_success(reg_handler.register_user(approved=True))
+
+ self.assertTrue(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ # Ensure avatar is set on this newly created user,
+ # so no need to compare for the exact image
+ profile_handler = self.hs.get_profile_handler()
+ profile = self.get_success(profile_handler.get_profile(user_id))
+ self.assertIsNot(profile["avatar_url"], None)
+
+ @unittest.override_config({"max_avatar_size": 1})
+ async def test_set_avatar_too_big_image(self) -> None:
+ """Tests that saving an avatar fails when it is too big"""
+ handler = self.hs.get_sso_handler()
+
+ # any random user works since image check is supposed to fail
+ user_id = "@sso-user:test"
+
+ self.assertFalse(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ @unittest.override_config({"allowed_avatar_mimetypes": ["image/jpeg"]})
+ async def test_set_avatar_incorrect_mime_type(self) -> None:
+ """Tests that saving an avatar fails when its mime type is not allowed"""
+ handler = self.hs.get_sso_handler()
+
+ # any random user works since image check is supposed to fail
+ user_id = "@sso-user:test"
+
+ self.assertFalse(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ async def test_skip_saving_avatar_when_not_changed(self) -> None:
+ """Tests whether saving of avatar correctly skips if the avatar hasn't
+ changed"""
+ handler = self.hs.get_sso_handler()
+
+ # Create a new user to set avatar for
+ reg_handler = self.hs.get_registration_handler()
+ user_id = self.get_success(reg_handler.register_user(approved=True))
+
+ # set avatar for the first time, should be a success
+ self.assertTrue(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ # get avatar picture for comparison after another attempt
+ profile_handler = self.hs.get_profile_handler()
+ profile = self.get_success(profile_handler.get_profile(user_id))
+ url_to_match = profile["avatar_url"]
+
+ # set same avatar for the second time, should be a success
+ self.assertTrue(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ # compare avatar picture's url from previous step
+ profile = self.get_success(profile_handler.get_profile(user_id))
+ self.assertEqual(profile["avatar_url"], url_to_match)
+
+
+async def mock_get_file(
+ url: str,
+ output_stream: BinaryIO,
+ max_size: Optional[int] = None,
+ headers: Optional[RawHeaders] = None,
+ is_allowed_content_type: Optional[Callable[[str], bool]] = None,
+) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
+
+ fake_response = FakeResponse(code=404)
+ if url == "http://my.server/me.png":
+ fake_response = FakeResponse(
+ code=200,
+ headers=Headers(
+ {"Content-Type": ["image/png"], "Content-Length": [str(len(SMALL_PNG))]}
+ ),
+ body=SMALL_PNG,
+ )
+
+ if max_size is not None and max_size < len(SMALL_PNG):
+ raise SynapseError(
+ HTTPStatus.BAD_GATEWAY,
+ "Requested file is too large > %r bytes" % (max_size,),
+ Codes.TOO_LARGE,
+ )
+
+ if is_allowed_content_type and not is_allowed_content_type("image/png"):
+ raise SynapseError(
+ HTTPStatus.BAD_GATEWAY,
+ (
+ "Requested file's content type not allowed for this operation: %s"
+ % "image/png"
+ ),
+ )
+
+ output_stream.write(fake_response.body)
+
+ return len(SMALL_PNG), {b"Content-Type": [b"image/png"]}, "", 200
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index e3f38fbcc5..ab5c101eb7 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -159,6 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Blow away caches (supported room versions can only change due to a restart).
self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
+ self.store.get_rooms_for_user.invalidate_all()
self.get_success(self.store._get_event_cache.clear())
self.store._event_ref.clear()
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 7af1333126..9c821b3042 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -25,7 +25,7 @@ from synapse.api.constants import EduTypes
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -117,8 +117,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- async def check_user_in_room(room_id: str, user_id: str) -> None:
- if user_id not in [u.to_string() for u in self.room_members]:
+ async def check_user_in_room(room_id: str, requester: Requester) -> None:
+ if requester.user.to_string() not in [
+ u.to_string() for u in self.room_members
+ ]:
raise AuthError(401, "User is not in the room")
return None
@@ -127,7 +129,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
async def check_host_in_room(room_id: str, server_name: str) -> bool:
return room_id == ROOM_ID
- hs.get_event_auth_handler().check_host_in_room = check_host_in_room
+ hs.get_event_auth_handler().is_host_in_room = check_host_in_room
async def get_current_hosts_in_room(room_id: str):
return {member.domain for member in self.room_members}
@@ -136,6 +138,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
get_current_hosts_in_room
)
+ hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (
+ get_current_hosts_in_room
+ )
+
async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members}
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index e74f7f5b48..093537adef 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
import os.path
import subprocess
+from typing import List
from zope.interface import implementer
@@ -70,14 +71,14 @@ subjectAltName = %(sanentries)s
"""
-def create_test_cert_file(sanlist):
+def create_test_cert_file(sanlist: List[bytes]) -> str:
"""build an x509 certificate file
Args:
- sanlist: list[bytes]: a list of subjectAltName values for the cert
+ sanlist: a list of subjectAltName values for the cert
Returns:
- str: the path to the file
+ The path to the file
"""
global cert_file_count
csr_filename = "server.csr"
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 994d8880b0..5071f83574 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -15,7 +15,6 @@
import inspect
import itertools
import logging
-from http import HTTPStatus
from typing import (
Any,
Callable,
@@ -78,7 +77,7 @@ def test_disconnect(
if expect_cancellation:
expected_code = HTTP_STATUS_REQUEST_CANCELLED
else:
- expected_code = HTTPStatus.OK
+ expected_code = 200
request = channel.request
if channel.is_finished():
@@ -141,6 +140,8 @@ def make_request_with_cancellation_test(
method: str,
path: str,
content: Union[bytes, str, JsonDict] = b"",
+ *,
+ token: Optional[str] = None,
) -> FakeChannel:
"""Performs a request repeatedly, disconnecting at successive `await`s, until
one completes.
@@ -212,7 +213,13 @@ def make_request_with_cancellation_test(
with deferred_patch.patch():
# Start the request.
channel = make_request(
- reactor, site, method, path, content, await_result=False
+ reactor,
+ site,
+ method,
+ path,
+ content,
+ await_result=False,
+ access_token=token,
)
request = channel.request
diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py
index c8cc21cadd..a801f002a0 100644
--- a/tests/http/test_endpoint.py
+++ b/tests/http/test_endpoint.py
@@ -25,6 +25,8 @@ class ServerNameTestCase(unittest.TestCase):
"[0abc:1def::1234]": ("[0abc:1def::1234]", None),
"1.2.3.4:1": ("1.2.3.4", 1),
"[0abc:1def::1234]:8080": ("[0abc:1def::1234]", 8080),
+ ":80": ("", 80),
+ "": ("", None),
}
for i, o in test_data.items():
@@ -42,6 +44,7 @@ class ServerNameTestCase(unittest.TestCase):
"newline.com\n",
".empty-label.com",
"1234:5678:80", # too many colons
+ ":80",
]
for i in test_data:
try:
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index bb966c80c6..46166292fe 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -18,7 +18,6 @@ from typing import Tuple
from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import cancellable
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
@@ -28,6 +27,7 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util.cancellation import cancellable
from tests import unittest
from tests.http.server._base import test_disconnect
@@ -35,11 +35,13 @@ from tests.http.server._base import test_disconnect
def make_request(content):
"""Make an object that acts enough like a request."""
- request = Mock(spec=["content"])
+ request = Mock(spec=["method", "uri", "content"])
if isinstance(content, dict):
content = json.dumps(content).encode("utf8")
+ request.method = bytes("STUB_METHOD", "ascii")
+ request.uri = bytes("/test_stub_uri", "ascii")
request.content = BytesIO(content)
return request
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
index 3b14c76d7e..0917e478a5 100644
--- a/tests/logging/test_opentracing.py
+++ b/tests/logging/test_opentracing.py
@@ -25,6 +25,8 @@ from synapse.logging.context import (
from synapse.logging.opentracing import (
start_active_span,
start_active_span_follows_from,
+ tag_args,
+ trace_with_opname,
)
from synapse.util import Clock
@@ -38,8 +40,12 @@ try:
except ImportError:
jaeger_client = None # type: ignore
+import logging
+
from tests.unittest import TestCase
+logger = logging.getLogger(__name__)
+
class LogContextScopeManagerTestCase(TestCase):
"""
@@ -194,3 +200,80 @@ class LogContextScopeManagerTestCase(TestCase):
self._reporter.get_spans(),
[scopes[1].span, scopes[2].span, scopes[0].span],
)
+
+ def test_trace_decorator_sync(self) -> None:
+ """
+ Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+ with sync functions
+ """
+ with LoggingContext("root context"):
+
+ @trace_with_opname("fixture_sync_func", tracer=self._tracer)
+ @tag_args
+ def fixture_sync_func() -> str:
+ return "foo"
+
+ result = fixture_sync_func()
+ self.assertEqual(result, "foo")
+
+ # the span should have been reported
+ self.assertEqual(
+ [span.operation_name for span in self._reporter.get_spans()],
+ ["fixture_sync_func"],
+ )
+
+ def test_trace_decorator_deferred(self) -> None:
+ """
+ Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+ with functions that return deferreds
+ """
+ reactor = MemoryReactorClock()
+
+ with LoggingContext("root context"):
+
+ @trace_with_opname("fixture_deferred_func", tracer=self._tracer)
+ @tag_args
+ def fixture_deferred_func() -> "defer.Deferred[str]":
+ d1: defer.Deferred[str] = defer.Deferred()
+ d1.callback("foo")
+ return d1
+
+ result_d1 = fixture_deferred_func()
+
+ # let the tasks complete
+ reactor.pump((2,) * 8)
+
+ self.assertEqual(self.successResultOf(result_d1), "foo")
+
+ # the span should have been reported
+ self.assertEqual(
+ [span.operation_name for span in self._reporter.get_spans()],
+ ["fixture_deferred_func"],
+ )
+
+ def test_trace_decorator_async(self) -> None:
+ """
+ Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+ with async functions
+ """
+ reactor = MemoryReactorClock()
+
+ with LoggingContext("root context"):
+
+ @trace_with_opname("fixture_async_func", tracer=self._tracer)
+ @tag_args
+ async def fixture_async_func() -> str:
+ return "foo"
+
+ d1 = defer.ensureDeferred(fixture_async_func())
+
+ # let the tasks complete
+ reactor.pump((2,) * 8)
+
+ self.assertEqual(self.successResultOf(d1), "foo")
+
+ # the span should have been reported
+ self.assertEqual(
+ [span.operation_name for span in self._reporter.get_spans()],
+ ["fixture_async_func"],
+ )
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 96f399b7ab..0b0d8737c1 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -153,6 +153,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
site.site_tag = "test-site"
site.server_version_string = "Server v1"
site.reactor = Mock()
+ site.experimental_cors_msc3886 = False
request = SynapseRequest(FakeChannel(site, None), site)
# Call requestReceived to finish instantiating the object.
request.content = BytesIO()
diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/metrics/__init__.py
diff --git a/tests/metrics/test_background_process_metrics.py b/tests/metrics/test_background_process_metrics.py
new file mode 100644
index 0000000000..f0f6cb2912
--- /dev/null
+++ b/tests/metrics/test_background_process_metrics.py
@@ -0,0 +1,19 @@
+from unittest import TestCase as StdlibTestCase
+from unittest.mock import Mock
+
+from synapse.logging.context import ContextResourceUsage, LoggingContext
+from synapse.metrics.background_process_metrics import _BackgroundProcess
+
+
+class TestBackgroundProcessMetrics(StdlibTestCase):
+ def test_update_metrics_with_negative_time_diff(self) -> None:
+ """We should ignore negative reported utime and stime differences"""
+ usage = ContextResourceUsage()
+ usage.ru_stime = usage.ru_utime = -1.0
+
+ mock_logging_context = Mock(spec=LoggingContext)
+ mock_logging_context.get_resource_usage.return_value = usage
+
+ process = _BackgroundProcess("test process", mock_logging_context)
+ # Should not raise
+ process.update_metrics()
diff --git a/tests/test_metrics.py b/tests/metrics/test_metrics.py
index b4574b2ffe..bddc4228bc 100644
--- a/tests/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -12,7 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing_extensions import Protocol
+try:
+ from importlib import metadata
+except ImportError:
+ import importlib_metadata as metadata # type: ignore[no-redef]
+
+from unittest.mock import patch
+
+from pkg_resources import parse_version
+
+from synapse.app._base import _set_prometheus_client_use_created_metrics
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
from synapse.util.caches.deferred_cache import DeferredCache
@@ -43,7 +54,11 @@ def get_sample_labels_value(sample):
class TestMauLimit(unittest.TestCase):
def test_basic(self):
- gauge = InFlightGauge(
+ class MetricEntry(Protocol):
+ foo: int
+ bar: int
+
+ gauge: InFlightGauge[MetricEntry] = InFlightGauge(
"test1", "", labels=["test_label"], sub_metrics=["foo", "bar"]
)
@@ -137,7 +152,7 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
Caches produce metrics reflecting their state when scraped.
"""
CACHE_NAME = "cache_metrics_test_fgjkbdfg"
- cache = DeferredCache(CACHE_NAME, max_entries=777)
+ cache: DeferredCache[str, str] = DeferredCache(CACHE_NAME, max_entries=777)
items = {
x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")
@@ -162,3 +177,30 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
self.assertEqual(items["synapse_util_caches_cache_size"], "1.0")
self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0")
+
+
+class PrometheusMetricsHackTestCase(unittest.HomeserverTestCase):
+ if parse_version(metadata.version("prometheus_client")) < parse_version("0.14.0"):
+ skip = "prometheus-client too old"
+
+ def test_created_metrics_disabled(self) -> None:
+ """
+ Tests that a brittle hack, to disable `_created` metrics, works.
+ This involves poking at the internals of prometheus-client.
+ It's not the end of the world if this doesn't work.
+
+ This test gives us a way to notice if prometheus-client changes
+ their internals.
+ """
+ import prometheus_client.metrics
+
+ PRIVATE_FLAG_NAME = "_use_created"
+
+ # By default, the pesky `_created` metrics are enabled.
+ # Check this assumption is still valid.
+ self.assertTrue(getattr(prometheus_client.metrics, PRIVATE_FLAG_NAME))
+
+ with patch("prometheus_client.metrics") as mock:
+ setattr(mock, PRIVATE_FLAG_NAME, True)
+ _set_prometheus_client_use_created_metrics(False)
+ self.assertFalse(getattr(mock, PRIVATE_FLAG_NAME, False))
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 8e05590230..058ca57e55 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -16,6 +16,7 @@ from unittest.mock import Mock
from twisted.internet import defer
from synapse.api.constants import EduTypes, EventTypes
+from synapse.api.errors import NotFoundError
from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
@@ -29,7 +30,6 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils import simple_async_mock
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
-from tests.utils import USE_POSTGRES_FOR_TESTS
class ModuleApiTestCase(HomeserverTestCase):
@@ -532,6 +532,34 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(res["displayname"], "simone")
self.assertIsNone(res["avatar_url"])
+ def test_update_room_membership_remote_join(self):
+ """Test that the module API can join a remote room."""
+ # Necessary to fake a remote join.
+ fake_stream_id = 1
+ mocked_remote_join = simple_async_mock(
+ return_value=("fake-event-id", fake_stream_id)
+ )
+ self.hs.get_room_member_handler()._remote_join = mocked_remote_join
+ fake_remote_host = f"{self.module_api.server_name}-remote"
+
+ # Given that the join is to be faked, we expect the relevant join event not to
+ # be persisted and the module API method to raise that.
+ self.get_failure(
+ defer.ensureDeferred(
+ self.module_api.update_room_membership(
+ sender=f"@user:{self.module_api.server_name}",
+ target=f"@user:{self.module_api.server_name}",
+ room_id=f"!nonexistent:{fake_remote_host}",
+ new_membership="join",
+ remote_room_hosts=[fake_remote_host],
+ )
+ ),
+ NotFoundError,
+ )
+
+ # Check that a remote join was attempted.
+ self.assertEqual(mocked_remote_join.call_count, 1)
+
def test_get_room_state(self):
"""Tests that a module can retrieve the state of a room through the module API."""
user_id = self.register_user("peter", "hackme")
@@ -654,15 +682,61 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(room_id, reference_room_id)
+ def test_create_room(self) -> None:
+ """Test that modules can create a room."""
+ # First test user validation (i.e. user is local).
+ self.get_failure(
+ self.module_api.create_room(
+ user_id=f"@user:{self.module_api.server_name}abc",
+ config={},
+ ratelimit=False,
+ ),
+ RuntimeError,
+ )
+
+ # Now do the happy path.
+ user_id = self.register_user("user", "password")
+ access_token = self.login(user_id, "password")
+
+ room_id, room_alias = self.get_success(
+ self.module_api.create_room(
+ user_id=user_id, config={"room_alias_name": "foo-bar"}, ratelimit=False
+ )
+ )
+
+ # Check room creator.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/rooms/{room_id}/state/m.room.create",
+ access_token=access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.json_body["creator"], user_id)
+
+ # Check room alias.
+ self.assertEquals(room_alias, f"#foo-bar:{self.module_api.server_name}")
+
+ # Let's try a room with no alias.
+ room_id, room_alias = self.get_success(
+ self.module_api.create_room(user_id=user_id, config={}, ratelimit=False)
+ )
+
+ # Check room creator.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/rooms/{room_id}/state/m.room.create",
+ access_token=access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.json_body["creator"], user_id)
+
+ # Check room alias.
+ self.assertIsNone(room_alias)
+
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""
- # Testing stream ID replication from the main to worker processes requires postgres
- # (due to needing `MultiWriterIdGenerator`).
- if not USE_POSTGRES_FOR_TESTS:
- skip = "Requires Postgres"
-
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -672,7 +746,6 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
def default_config(self):
conf = super().default_config()
- conf["redis"] = {"enabled": "true"}
conf["stream_writers"] = {"presence": ["presence_writer"]}
conf["instance_map"] = {
"presence_writer": {"host": "testserv", "port": 1001},
@@ -705,8 +778,11 @@ def _test_sending_local_online_presence_to_local_user(
worker process. The test users will still sync with the main process. The purpose of testing
with a worker is to check whether a Synapse module running on a worker can inform other workers/
the main process that they should include additional presence when a user next syncs.
+ If this argument is True, `test_case` MUST be an instance of BaseMultiWorkerStreamTestCase.
"""
if test_with_workers:
+ assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
+
# Create a worker process to make module_api calls against
worker_hs = test_case.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "presence_writer"}
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
new file mode 100644
index 0000000000..594e7937a8
--- /dev/null
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -0,0 +1,74 @@
+from unittest.mock import patch
+
+from synapse.api.room_versions import RoomVersions
+from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
+from synapse.rest import admin
+from synapse.rest.client import login, register, room
+from synapse.types import create_requester
+
+from tests import unittest
+
+
+class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ ]
+
+ def test_action_for_event_by_user_handles_noninteger_power_levels(self) -> None:
+ """We should convert floats and strings to integers before passing to Rust.
+
+ Reproduces #14060.
+
+ A lack of validation: the gift that keeps on giving.
+ """
+ # Create a new user and room.
+ alice = self.register_user("alice", "pass")
+ token = self.login(alice, "pass")
+
+ room_id = self.helper.create_room_as(
+ alice, room_version=RoomVersions.V9.identifier, tok=token
+ )
+
+ # Alter the power levels in that room to include stringy and floaty levels.
+ # We need to suppress the validation logic or else it will reject these dodgy
+ # values. (Presumably this validation was not always present.)
+ event_creation_handler = self.hs.get_event_creation_handler()
+ requester = create_requester(alice)
+ with patch("synapse.events.validator.validate_canonicaljson"), patch(
+ "synapse.events.validator.jsonschema.validate"
+ ):
+ self.helper.send_state(
+ room_id,
+ "m.room.power_levels",
+ {
+ "users": {alice: "100"}, # stringy
+ "notifications": {"room": 100.0}, # float
+ },
+ token,
+ state_key="",
+ )
+
+ # Create a new message event, and try to evaluate it under the dodgy
+ # power level event.
+ event, context = self.get_success(
+ event_creation_handler.create_event(
+ requester,
+ {
+ "type": "m.room.message",
+ "room_id": room_id,
+ "content": {
+ "msgtype": "m.text",
+ "body": "helo",
+ },
+ "sender": alice,
+ },
+ )
+ )
+
+ bulk_evaluator = BulkPushRuleEvaluator(self.hs)
+ # should not raise
+ self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 7a3b0d6755..fd14568f55 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase):
)
self.pusher = self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.user_id,
access_token=self.token_id,
kind="email",
@@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase):
"""
with self.assertRaises(SynapseError) as cm:
self.get_success_or_raise(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.user_id,
access_token=self.token_id,
kind="email",
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index d9c68cdd2d..b383b8401f 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -19,9 +19,10 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
-from synapse.push import PusherConfigException
-from synapse.rest.client import login, push_rule, receipts, room
+from synapse.push import PusherConfig, PusherConfigException
+from synapse.rest.client import login, push_rule, pusher, receipts, room
from synapse.server import HomeServer
+from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import JsonDict
from synapse.util import Clock
@@ -35,6 +36,7 @@ class HTTPPusherTests(HomeserverTestCase):
login.register_servlets,
receipts.register_servlets,
push_rule.register_servlets,
+ pusher.register_servlets,
]
user_id = True
hijack_auth = False
@@ -74,7 +76,7 @@ class HTTPPusherTests(HomeserverTestCase):
def test_data(data: Optional[JsonDict]) -> None:
self.get_failure(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -119,7 +121,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -235,7 +237,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -355,7 +357,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -441,7 +443,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -518,7 +520,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -624,7 +626,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -728,18 +730,38 @@ class HTTPPusherTests(HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.json_body)
- def _make_user_with_pusher(self, username: str) -> Tuple[str, str]:
+ def _make_user_with_pusher(
+ self, username: str, enabled: bool = True
+ ) -> Tuple[str, str]:
+ """Registers a user and creates a pusher for them.
+
+ Args:
+ username: the localpart of the new user's Matrix ID.
+ enabled: whether to create the pusher in an enabled or disabled state.
+ """
user_id = self.register_user(username, "pass")
access_token = self.login(username, "pass")
# Register the pusher
+ self._set_pusher(user_id, access_token, enabled)
+
+ return user_id, access_token
+
+ def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None:
+ """Creates or updates the pusher for the given user.
+
+ Args:
+ user_id: the user's Matrix ID.
+ access_token: the access token associated with the pusher.
+ enabled: whether to enable or disable the pusher.
+ """
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -749,11 +771,11 @@ class HTTPPusherTests(HomeserverTestCase):
pushkey="a@example.com",
lang=None,
data={"url": "http://example.com/_matrix/push/v1/notify"},
+ enabled=enabled,
+ device_id=user_tuple.device_id,
)
)
- return user_id, access_token
-
def test_dont_notify_rule_overrides_message(self) -> None:
"""
The override push rule will suppress notification
@@ -791,3 +813,148 @@ class HTTPPusherTests(HomeserverTestCase):
# The user sends a message back (sends a notification)
self.helper.send(room, body="Hello", tok=access_token)
self.assertEqual(len(self.push_attempts), 1)
+
+ @override_config({"experimental_features": {"msc3881_enabled": True}})
+ def test_disable(self) -> None:
+ """Tests that disabling a pusher means it's not pushed to anymore."""
+ user_id, access_token = self._make_user_with_pusher("user")
+ other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
+
+ room = self.helper.create_room_as(user_id, tok=access_token)
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # Send a message and check that it generated a push.
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.assertEqual(len(self.push_attempts), 1)
+
+ # Disable the pusher.
+ self._set_pusher(user_id, access_token, enabled=False)
+
+ # Send another message and check that it did not generate a push.
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.assertEqual(len(self.push_attempts), 1)
+
+ # Get the pushers for the user and check that it is marked as disabled.
+ channel = self.make_request("GET", "/pushers", access_token=access_token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["pushers"]), 1)
+
+ enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
+ self.assertFalse(enabled)
+ self.assertTrue(isinstance(enabled, bool))
+
+ @override_config({"experimental_features": {"msc3881_enabled": True}})
+ def test_enable(self) -> None:
+ """Tests that enabling a disabled pusher means it gets pushed to."""
+ # Create the user with the pusher already disabled.
+ user_id, access_token = self._make_user_with_pusher("user", enabled=False)
+ other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
+
+ room = self.helper.create_room_as(user_id, tok=access_token)
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # Send a message and check that it did not generate a push.
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.assertEqual(len(self.push_attempts), 0)
+
+ # Enable the pusher.
+ self._set_pusher(user_id, access_token, enabled=True)
+
+ # Send another message and check that it did generate a push.
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.assertEqual(len(self.push_attempts), 1)
+
+ # Get the pushers for the user and check that it is marked as enabled.
+ channel = self.make_request("GET", "/pushers", access_token=access_token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["pushers"]), 1)
+
+ enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
+ self.assertTrue(enabled)
+ self.assertTrue(isinstance(enabled, bool))
+
+ @override_config({"experimental_features": {"msc3881_enabled": True}})
+ def test_null_enabled(self) -> None:
+ """Tests that a pusher that has an 'enabled' column set to NULL (eg pushers
+ created before the column was introduced) is considered enabled.
+ """
+ # We intentionally set 'enabled' to None so that it's stored as NULL in the
+ # database.
+ user_id, access_token = self._make_user_with_pusher("user", enabled=None) # type: ignore[arg-type]
+
+ channel = self.make_request("GET", "/pushers", access_token=access_token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["pushers"]), 1)
+ self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"])
+
+ def test_update_different_device_access_token_device_id(self) -> None:
+ """Tests that if we create a pusher from one device, the update it from another
+ device, the access token and device ID associated with the pusher stays the
+ same.
+ """
+ # Create a user with a pusher.
+ user_id, access_token = self._make_user_with_pusher("user")
+
+ # Get the token ID for the current access token, since that's what we store in
+ # the pushers table. Also get the device ID from it.
+ user_tuple = self.get_success(
+ self.hs.get_datastores().main.get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple.token_id
+ device_id = user_tuple.device_id
+
+ # Generate a new access token, and update the pusher with it.
+ new_token = self.login("user", "pass")
+ self._set_pusher(user_id, new_token, enabled=False)
+
+ # Get the current list of pushers for the user.
+ ret = self.get_success(
+ self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ )
+ pushers: List[PusherConfig] = list(ret)
+
+ # Check that we still have one pusher, and that the access token and device ID
+ # associated with it didn't change.
+ self.assertEqual(len(pushers), 1)
+ self.assertEqual(pushers[0].access_token, token_id)
+ self.assertEqual(pushers[0].device_id, device_id)
+
+ @override_config({"experimental_features": {"msc3881_enabled": True}})
+ def test_device_id(self) -> None:
+ """Tests that a pusher created with a given device ID shows that device ID in
+ GET /pushers requests.
+ """
+ self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # We create the pusher with an HTTP request rather than with
+ # _make_user_with_pusher so that we can test the device ID is correctly set when
+ # creating a pusher via an API call.
+ self.make_request(
+ method="POST",
+ path="/pushers/set",
+ content={
+ "kind": "http",
+ "app_id": "m.http",
+ "app_display_name": "HTTP Push Notifications",
+ "device_display_name": "pushy push",
+ "pushkey": "a@example.com",
+ "lang": "en",
+ "data": {"url": "http://example.com/_matrix/push/v1/notify"},
+ },
+ access_token=access_token,
+ )
+
+ # Look up the user info for the access token so we can compare the device ID.
+ lookup_result: TokenLookupResult = self.get_success(
+ self.hs.get_datastores().main.get_user_by_access_token(access_token)
+ )
+
+ # Get the user's devices and check it has the correct device ID.
+ channel = self.make_request("GET", "/pushers", access_token=access_token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["pushers"]), 1)
+ self.assertEqual(
+ channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"],
+ lookup_result.device_id,
+ )
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 718f489577..fe7c145840 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,23 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Optional, Set, Tuple, Union
+from typing import Dict, Optional, Union
import frozendict
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.api.room_versions import RoomVersions
from synapse.appservice import ApplicationService
from synapse.events import FrozenEvent
-from synapse.push import push_rule_evaluator
-from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
+from synapse.push.bulk_push_rule_evaluator import _flatten_dict
+from synapse.push.httppusher import tweaks_for_actions
+from synapse.rest import admin
from synapse.rest.client import login, register, room
from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.synapse_rust.push import PushRuleEvaluator
+from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest
@@ -37,11 +39,8 @@ from tests.test_utils.event_injection import create_event, inject_member_event
class PushRuleEvaluatorTestCase(unittest.TestCase):
def _get_evaluator(
- self,
- content: JsonDict,
- relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None,
- relations_match_enabled: bool = False,
- ) -> PushRuleEvaluatorForEvent:
+ self, content: JsonDict, related_events=None
+ ) -> PushRuleEvaluator:
event = FrozenEvent(
{
"event_id": "$event_id",
@@ -56,13 +55,13 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
room_member_count = 0
sender_power_level = 0
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
- return PushRuleEvaluatorForEvent(
- event,
+ return PushRuleEvaluator(
+ _flatten_dict(event),
room_member_count,
sender_power_level,
- power_levels,
- relations or set(),
- relations_match_enabled,
+ power_levels.get("notifications", {}),
+ {} if related_events is None else related_events,
+ True,
)
def test_display_name(self) -> None:
@@ -293,77 +292,218 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
]
self.assertEqual(
- push_rule_evaluator.tweaks_for_actions(actions),
+ tweaks_for_actions(actions),
{"sound": "default", "highlight": True},
)
- def test_relation_match(self) -> None:
- """Test the relation_match push rule kind."""
-
- # Check if the experimental feature is disabled.
+ def test_related_event_match(self):
evaluator = self._get_evaluator(
- {}, {"m.annotation": {("@user:test", "m.reaction")}}
+ {
+ "m.relates_to": {
+ "event_id": "$parent_event_id",
+ "key": "😀",
+ "rel_type": "m.annotation",
+ "m.in_reply_to": {
+ "event_id": "$parent_event_id",
+ },
+ }
+ },
+ {
+ "m.in_reply_to": {
+ "event_id": "$parent_event_id",
+ "type": "m.room.message",
+ "sender": "@other_user:test",
+ "room_id": "!room:test",
+ "content.msgtype": "m.text",
+ "content.body": "Original message",
+ },
+ "m.annotation": {
+ "event_id": "$parent_event_id",
+ "type": "m.room.message",
+ "sender": "@other_user:test",
+ "room_id": "!room:test",
+ "content.msgtype": "m.text",
+ "content.body": "Original message",
+ },
+ },
+ )
+ self.assertTrue(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ "pattern": "@other_user:test",
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ "pattern": "@user:test",
+ },
+ "@other_user:test",
+ "display_name",
+ )
+ )
+ self.assertTrue(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.annotation",
+ "pattern": "@other_user:test",
+ },
+ "@other_user:test",
+ "display_name",
+ )
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+ self.assertTrue(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "rel_type": "m.in_reply_to",
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "rel_type": "m.replace",
+ },
+ "@other_user:test",
+ "display_name",
+ )
)
- condition = {"kind": "relation_match"}
- # Oddly, an unknown condition always matches.
- self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
- # A push rule evaluator with the experimental rule enabled.
+ def test_related_event_match_with_fallback(self):
evaluator = self._get_evaluator(
- {}, {"m.annotation": {("@user:test", "m.reaction")}}, True
+ {
+ "m.relates_to": {
+ "event_id": "$parent_event_id",
+ "key": "😀",
+ "rel_type": "m.thread",
+ "is_falling_back": True,
+ "m.in_reply_to": {
+ "event_id": "$parent_event_id",
+ },
+ }
+ },
+ {
+ "m.in_reply_to": {
+ "event_id": "$parent_event_id",
+ "type": "m.room.message",
+ "sender": "@other_user:test",
+ "room_id": "!room:test",
+ "content.msgtype": "m.text",
+ "content.body": "Original message",
+ "im.vector.is_falling_back": "",
+ },
+ "m.thread": {
+ "event_id": "$parent_event_id",
+ "type": "m.room.message",
+ "sender": "@other_user:test",
+ "room_id": "!room:test",
+ "content.msgtype": "m.text",
+ "content.body": "Original message",
+ },
+ },
+ )
+ self.assertTrue(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ "pattern": "@other_user:test",
+ "include_fallbacks": True,
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ "pattern": "@other_user:test",
+ "include_fallbacks": False,
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ "pattern": "@other_user:test",
+ },
+ "@user:test",
+ "display_name",
+ )
)
- # Check just relation type.
- condition = {
- "kind": "org.matrix.msc3772.relation_match",
- "rel_type": "m.annotation",
- }
- self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
-
- # Check relation type and sender.
- condition = {
- "kind": "org.matrix.msc3772.relation_match",
- "rel_type": "m.annotation",
- "sender": "@user:test",
- }
- self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
- condition = {
- "kind": "org.matrix.msc3772.relation_match",
- "rel_type": "m.annotation",
- "sender": "@other:test",
- }
- self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
-
- # Check relation type and event type.
- condition = {
- "kind": "org.matrix.msc3772.relation_match",
- "rel_type": "m.annotation",
- "type": "m.reaction",
- }
- self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
-
- # Check just sender, this fails since rel_type is required.
- condition = {
- "kind": "org.matrix.msc3772.relation_match",
- "sender": "@user:test",
- }
- self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
-
- # Check sender glob.
- condition = {
- "kind": "org.matrix.msc3772.relation_match",
- "rel_type": "m.annotation",
- "sender": "@*:test",
- }
- self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
-
- # Check event type glob.
- condition = {
- "kind": "org.matrix.msc3772.relation_match",
- "rel_type": "m.annotation",
- "event_type": "*.reaction",
- }
- self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+ def test_related_event_match_no_related_event(self):
+ evaluator = self._get_evaluator(
+ {"msgtype": "m.text", "body": "Message without related event"}
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ "pattern": "@other_user:test",
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "rel_type": "m.in_reply_to",
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
@@ -439,3 +579,80 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
)
self.assertEqual(len(users_with_push_actions), 0)
+
+
+class BulkPushRuleEvaluatorTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ self.main_store = homeserver.get_datastores().main
+
+ self.user_id1 = self.register_user("user1", "password")
+ self.tok1 = self.login(self.user_id1, "password")
+ self.user_id2 = self.register_user("user2", "password")
+ self.tok2 = self.login(self.user_id2, "password")
+
+ self.room_id = self.helper.create_room_as(tok=self.tok1)
+
+ # We want to test history visibility works correctly.
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.RoomHistoryVisibility,
+ {"history_visibility": HistoryVisibility.JOINED},
+ tok=self.tok1,
+ )
+
+ def get_notif_count(self, user_id: str) -> int:
+ return self.get_success(
+ self.main_store.db_pool.simple_select_one_onecol(
+ table="event_push_actions",
+ keyvalues={"user_id": user_id},
+ retcol="COALESCE(SUM(notif), 0)",
+ desc="get_staging_notif_count",
+ )
+ )
+
+ def test_plain_message(self) -> None:
+ """Test that sending a normal message in a room will trigger a
+ notification
+ """
+
+ # Have user2 join the room and cle
+ self.helper.join(self.room_id, self.user_id2, tok=self.tok2)
+
+ # They start off with no notifications, but get them when messages are
+ # sent.
+ self.assertEqual(self.get_notif_count(self.user_id2), 0)
+
+ user1 = UserID.from_string(self.user_id1)
+ self.create_and_send_event(self.room_id, user1)
+
+ self.assertEqual(self.get_notif_count(self.user_id2), 1)
+
+ def test_delayed_message(self) -> None:
+ """Test that a delayed message that was from before a user joined
+ doesn't cause a notification for the joined user.
+ """
+ user1 = UserID.from_string(self.user_id1)
+
+ # Send a message before user2 joins
+ event_id1 = self.create_and_send_event(self.room_id, user1)
+
+ # Have user2 join the room
+ self.helper.join(self.room_id, self.user_id2, tok=self.tok2)
+
+ # They start off with no notifications
+ self.assertEqual(self.get_notif_count(self.user_id2), 0)
+
+ # Send another message that references the event before the join to
+ # simulate a "delayed" event
+ self.create_and_send_event(self.room_id, user1, prev_event_ids=[event_id1])
+
+ # user2 should not be notified about it, because they can't see it.
+ self.assertEqual(self.get_notif_count(self.user_id2), 0)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 970d5e533b..3029a16dda 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -24,11 +24,11 @@ from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
-from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import (
- ReplicationStreamProtocolFactory,
+from synapse.replication.tcp.protocol import (
+ ClientReplicationStreamProtocol,
ServerReplicationStreamProtocol,
)
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from tests import unittest
@@ -220,15 +220,34 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests running multiple workers.
+ Enables Redis, providing a fake Redis server.
+
Automatically handle HTTP replication requests from workers to master,
unlike `BaseStreamTestCase`.
"""
+ if not hiredis:
+ skip = "Requires hiredis"
+
+ if not USE_POSTGRES_FOR_TESTS:
+ # Redis replication only takes place on Postgres
+ skip = "Requires Postgres"
+
+ def default_config(self) -> Dict[str, Any]:
+ """
+ Overrides the default config to enable Redis.
+ Even if the test only uses make_worker_hs, the main process needs Redis
+ enabled otherwise it won't create a Fake Redis server to listen on the
+ Redis port and accept fake TCP connections.
+ """
+ base = super().default_config()
+ base["redis"] = {"enabled": True}
+ return base
+
def setUp(self):
super().setUp()
# build a replication server
- self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()
# Fake in memory Redis server that servers can connect to.
@@ -247,15 +266,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# handling inbound HTTP requests to that instance.
self._hs_to_site = {self.hs: self.site}
- if self.hs.config.redis.redis_enabled:
- # Handle attempts to connect to fake redis server.
- self.reactor.add_tcp_client_callback(
- "localhost",
- 6379,
- self.connect_any_redis_attempts,
- )
+ # Handle attempts to connect to fake redis server.
+ self.reactor.add_tcp_client_callback(
+ "localhost",
+ 6379,
+ self.connect_any_redis_attempts,
+ )
- self.hs.get_replication_command_handler().start_replication(self.hs)
+ self.hs.get_replication_command_handler().start_replication(self.hs)
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
@@ -339,27 +357,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
store = worker_hs.get_datastores().main
store.db_pool._db_pool = self.database_pool._db_pool
- # Set up TCP replication between master and the new worker if we don't
- # have Redis support enabled.
- if not worker_hs.config.redis.redis_enabled:
- repl_handler = ReplicationCommandHandler(worker_hs)
- client = ClientReplicationStreamProtocol(
- worker_hs,
- "client",
- "test",
- self.clock,
- repl_handler,
- )
- server = self.server_factory.buildProtocol(
- IPv4Address("TCP", "127.0.0.1", 0)
- )
-
- client_transport = FakeTransport(server, self.reactor)
- client.makeConnection(client_transport)
-
- server_transport = FakeTransport(client, self.reactor)
- server.makeConnection(server_transport)
-
# Set up a resource for the worker
resource = ReplicationRestResource(worker_hs)
@@ -374,12 +371,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config=worker_hs.config.server.listeners[0],
resource=resource,
server_version_string="1",
- max_request_body_size=4096,
+ max_request_body_size=8192,
reactor=self.reactor,
)
- if worker_hs.config.redis.redis_enabled:
- worker_hs.get_replication_command_handler().start_replication(worker_hs)
+ worker_hs.get_replication_command_handler().start_replication(worker_hs)
return worker_hs
@@ -546,8 +542,13 @@ class FakeRedisPubSubProtocol(Protocol):
self.send("OK")
elif command == b"GET":
self.send(None)
+
+ # Connection keep-alives.
+ elif command == b"PING":
+ self.send("PONG")
+
else:
- raise Exception("Unknown command")
+ raise Exception(f"Unknown command: {command}")
def send(self, msg):
"""Send a message back to the client."""
@@ -582,27 +583,3 @@ class FakeRedisPubSubProtocol(Protocol):
def connectionLost(self, reason):
self._server.remove_subscriber(self)
-
-
-class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
- """
- A test case that enables Redis, providing a fake Redis server.
- """
-
- if not hiredis:
- skip = "Requires hiredis"
-
- if not USE_POSTGRES_FOR_TESTS:
- # Redis replication only takes place on Postgres
- skip = "Requires Postgres"
-
- def default_config(self) -> Dict[str, Any]:
- """
- Overrides the default config to enable Redis.
- Even if the test only uses make_worker_hs, the main process needs Redis
- enabled otherwise it won't create a Fake Redis server to listen on the
- Redis port and accept fake TCP connections.
- """
- base = super().default_config()
- base["redis"] = {"enabled": True}
- return base
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index 822a957c3a..936ab4504a 100644
--- a/tests/replication/http/test__base.py
+++ b/tests/replication/http/test__base.py
@@ -18,11 +18,12 @@ from typing import Tuple
from twisted.web.server import Request
from synapse.api.errors import Codes
-from synapse.http.server import JsonResource, cancellable
+from synapse.http.server import JsonResource
from synapse.replication.http import REPLICATION_PREFIX
from synapse.replication.http._base import ReplicationEndpoint
from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util.cancellation import cancellable
from tests import unittest
from tests.http.server._base import test_disconnect
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 531a0db2d0..dce71f7334 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -21,8 +21,11 @@ from synapse.api.constants import ReceiptTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.handlers.room import RoomEventSource
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.storage.databases.main.event_push_actions import (
+ NotifCounts,
+ RoomNotifCounts,
+)
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
from synapse.types import PersistedEventPosition
@@ -55,9 +58,9 @@ def patch__eq__(cls):
return unpatch
-class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
+class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
- STORE_TYPE = SlavedEventStore
+ STORE_TYPE = EventsWorkerStore
def setUp(self):
# Patch up the equality operator for events so that we can check
@@ -140,6 +143,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.persist(type="m.room.create", key="", creator=USER_ID)
self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
+ assert event.internal_metadata.stream_ordering is not None
self.replicate()
@@ -171,14 +175,16 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if send_receipt:
self.get_success(
self.master_store.insert_receipt(
- ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {}
+ ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {}
)
)
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2],
- NotifCounts(highlight_count=0, unread_count=0, notify_count=0),
+ RoomNotifCounts(
+ NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {}
+ ),
)
self.persist(
@@ -191,7 +197,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2],
- NotifCounts(highlight_count=0, unread_count=0, notify_count=1),
+ RoomNotifCounts(
+ NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {}
+ ),
)
self.persist(
@@ -206,7 +214,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2],
- NotifCounts(highlight_count=1, unread_count=0, notify_count=2),
+ RoomNotifCounts(
+ NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {}
+ ),
)
def test_get_rooms_for_user_with_stream_ordering(self):
@@ -221,6 +231,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
j2 = self.persist(
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
+ assert j2.internal_metadata.stream_ordering is not None
self.replicate()
expected_pos = PersistedEventPosition(
@@ -278,6 +289,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
)
self.replicate()
+ assert j2.internal_metadata.stream_ordering is not None
event_source = RoomEventSource(self.hs)
event_source.store = self.slaved_store
@@ -327,10 +339,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event_id = 0
- def persist(self, backfill=False, **kwargs):
+ def persist(self, backfill=False, **kwargs) -> FrozenEvent:
"""
Returns:
- synapse.events.FrozenEvent: The event that was persisted.
+ The event that was persisted.
"""
event, context = self.build_event(**kwargs)
@@ -404,6 +416,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event.event_id,
{user_id: actions for user_id, actions in push_actions},
False,
+ "main",
)
)
return event, context
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index eb00117845..ede6d0c118 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -33,7 +33,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
# tell the master to send a new receipt
self.get_success(
self.hs.get_datastores().main.insert_receipt(
- "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
+ "!room:blue",
+ "m.read",
+ USER_ID,
+ ["$event:blue"],
+ thread_id=None,
+ data={"a": 1},
)
)
self.replicate()
@@ -48,6 +53,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
self.assertEqual("$event:blue", row.event_id)
+ self.assertIsNone(row.thread_id)
self.assertEqual({"a": 1}, row.data)
# Now let's disconnect and insert some data.
@@ -57,7 +63,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
self.get_success(
self.hs.get_datastores().main.insert_receipt(
- "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
+ "!room2:blue",
+ "m.read",
+ USER_ID,
+ ["$event2:foo"],
+ thread_id=None,
+ data={"a": 2},
)
)
self.replicate()
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index e6a19eafd5..1e299d2d67 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.replication._base import BaseMultiWorkerStreamTestCase
-class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
+class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
def test_subscribed_to_enough_redis_channels(self) -> None:
# The default main process is subscribed to the USER_IP channel.
self.assertCountEqual(
diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py
new file mode 100644
index 0000000000..b93cae67d3
--- /dev/null
+++ b/tests/replication/test_module_cache_invalidation.py
@@ -0,0 +1,79 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import synapse
+from synapse.module_api import cached
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+FIRST_VALUE = "one"
+SECOND_VALUE = "two"
+
+KEY = "mykey"
+
+
+class TestCache:
+ current_value = FIRST_VALUE
+
+ @cached()
+ async def cached_function(self, user_id: str) -> str:
+ return self.current_value
+
+
+class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ ]
+
+ def test_module_cache_full_invalidation(self):
+ main_cache = TestCache()
+ self.hs.get_module_api().register_cached_function(main_cache.cached_function)
+
+ worker_hs = self.make_worker_hs("synapse.app.generic_worker")
+
+ worker_cache = TestCache()
+ worker_hs.get_module_api().register_cached_function(
+ worker_cache.cached_function
+ )
+
+ self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ self.assertEqual(
+ FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
+ )
+
+ main_cache.current_value = SECOND_VALUE
+ worker_cache.current_value = SECOND_VALUE
+ # No invalidation yet, should return the cached value on both the main process and the worker
+ self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ self.assertEqual(
+ FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
+ )
+
+ # Full invalidation on the main process, should be replicated on the worker that
+ # should returned the updated value too
+ self.get_success(
+ self.hs.get_module_api().invalidate_cache(
+ main_cache.cached_function, (KEY,)
+ )
+ )
+
+ self.assertEqual(
+ SECOND_VALUE, self.get_success(main_cache.cached_function(KEY))
+ )
+ self.assertEqual(
+ SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY))
+ )
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 13aa5eb51a..96cdf2c45b 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -15,8 +15,9 @@ import logging
import os
from typing import Optional, Tuple
+from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.protocol import Factory
-from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel
from twisted.web.server import Request
@@ -102,7 +103,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
)
# fish the test server back out of the server-side TLS protocol.
- http_server = server_tls_protocol.wrappedProtocol
+ http_server: HTTPChannel = server_tls_protocol.wrappedProtocol # type: ignore[assignment]
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
@@ -238,16 +239,15 @@ def get_connection_factory():
return test_server_connection_factory
-def _build_test_server(connection_creator):
+def _build_test_server(
+ connection_creator: IOpenSSLServerConnectionCreator,
+) -> TLSMemoryBIOProtocol:
"""Construct a test server
This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
Args:
- connection_creator (IOpenSSLServerConnectionCreator): thing to build
- SSL connections
- sanlist (list[bytes]): list of the SAN entries for the cert returned
- by the server
+ connection_creator: thing to build SSL connections
Returns:
TLSMemoryBIOProtocol
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 8f4f6688ce..59fea93e49 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
token_id = user_dict.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index a7ca68069e..541d390286 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -20,7 +20,6 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
-from tests.utils import USE_POSTGRES_FOR_TESTS
logger = logging.getLogger(__name__)
@@ -28,11 +27,6 @@ logger = logging.getLogger(__name__)
class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
"""Checks event persisting sharding works"""
- # Event persister sharding requires postgres (due to needing
- # `MultiWriterIdGenerator`).
- if not USE_POSTGRES_FOR_TESTS:
- skip = "Requires Postgres"
-
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -50,7 +44,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
def default_config(self):
conf = super().default_config()
- conf["redis"] = {"enabled": "true"}
conf["stream_writers"] = {"events": ["worker1", "worker2"]}
conf["instance_map"] = {
"worker1": {"host": "testserv", "port": 1001},
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 82ac5991e6..a8f6436836 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -13,7 +13,6 @@
# limitations under the License.
import urllib.parse
-from http import HTTPStatus
from parameterized import parameterized
@@ -42,7 +41,7 @@ class VersionTestCase(unittest.HomeserverTestCase):
def test_version_string(self) -> None:
channel = self.make_request("GET", self.url, shorthand=False)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
{"server_version", "python_version"}, set(channel.json_body.keys())
)
@@ -79,10 +78,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Should be quarantined
self.assertEqual(
- HTTPStatus.NOT_FOUND,
+ 404,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.NOT_FOUND on accessing quarantined media: %s"
+ "Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id
),
)
@@ -107,7 +106,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Expect a forbidden error
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg="Expected forbidden on quarantining media as a non-admin",
)
@@ -139,7 +138,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
)
# Should be successful
- self.assertEqual(HTTPStatus.OK, channel.code)
+ self.assertEqual(200, channel.code)
# Quarantine the media
url = "/_synapse/admin/v1/media/quarantine/%s/%s" % (
@@ -152,7 +151,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Attempt to access the media
self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
@@ -209,7 +208,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
)
@@ -251,7 +250,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
)
@@ -285,7 +284,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
channel = self.make_request("POST", url, access_token=admin_user_tok)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
@@ -297,7 +296,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body, {"num_quarantined": 1}, "Expected 1 quarantined item"
)
@@ -318,10 +317,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Shouldn't be quarantined
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.OK on accessing not-quarantined media: %s"
+ "Expected to receive a 200 on accessing not-quarantined media: %s"
% server_and_media_id_2
),
)
@@ -350,7 +349,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
def test_purge_history(self) -> None:
"""
Simple test of purge history API.
- Test only that is is possible to call, get status HTTPStatus.OK and purge_id.
+ Test only that is is possible to call, get status 200 and purge_id.
"""
channel = self.make_request(
@@ -360,7 +359,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("purge_id", channel.json_body)
purge_id = channel.json_body["purge_id"]
@@ -371,5 +370,5 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("complete", channel.json_body["status"])
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 6cf56b1e35..d507a3af8d 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import Collection
from parameterized import parameterized
@@ -51,7 +50,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
)
def test_requester_is_no_admin(self, method: str, url: str) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
self.register_user("user", "pass", admin=False)
@@ -64,7 +63,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -81,7 +80,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# job_name invalid
@@ -92,7 +91,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
def _register_bg_update(self) -> None:
@@ -125,7 +124,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Background updates should be enabled, but none should be running.
self.assertDictEqual(
@@ -147,7 +146,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Background updates should be enabled, and one should be running.
self.assertDictEqual(
@@ -181,7 +180,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/enabled",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": True})
# Disable the BG updates
@@ -191,7 +190,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
content={"enabled": False},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": False})
# Advance a bit and get the current status, note this will finish the in
@@ -204,7 +203,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual(
channel.json_body,
{
@@ -231,7 +230,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# There should be no change from the previous /status response.
self.assertDictEqual(
@@ -259,7 +258,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
content={"enabled": True},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": True})
@@ -270,7 +269,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Background updates should be enabled and making progress.
self.assertDictEqual(
@@ -325,7 +324,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# test that each background update is waiting now
for update in updates:
@@ -365,4 +364,4 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index f7080bda87..03f2112b07 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import urllib.parse
-from http import HTTPStatus
from parameterized import parameterized
@@ -20,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes
+from synapse.handlers.device import DeviceHandler
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
@@ -35,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.handler = handler
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -58,7 +60,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(method, self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -76,7 +78,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -85,7 +87,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
@parameterized.expand(["GET", "PUT", "DELETE"])
def test_user_does_not_exist(self, method: str) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = (
"/_synapse/admin/v2/users/@unknown_person:test/devices/%s"
@@ -98,13 +100,13 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(["GET", "PUT", "DELETE"])
def test_user_is_not_local(self, method: str) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = (
"/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s"
@@ -117,12 +119,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_unknown_device(self) -> None:
"""
- Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK.
+ Tests that a lookup for a device that does not exist returns either 404 or 200.
"""
url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote(
self.other_user
@@ -134,7 +136,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
channel = self.make_request(
@@ -143,7 +145,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
channel = self.make_request(
"DELETE",
@@ -151,8 +153,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- # Delete unknown device returns status HTTPStatus.OK
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ # Delete unknown device returns status 200
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_update_device_too_long_display_name(self) -> None:
"""
@@ -179,7 +181,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
content=update,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
# Ensure the display name was not updated.
@@ -189,12 +191,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
def test_update_no_display_name(self) -> None:
"""
- Tests that a update for a device without JSON returns a HTTPStatus.OK
+ Tests that a update for a device without JSON returns a 200
"""
# Set iniital display name.
update = {"display_name": "new display"}
@@ -210,7 +212,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Ensure the display name was not updated.
channel = self.make_request(
@@ -219,7 +221,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
def test_update_display_name(self) -> None:
@@ -234,7 +236,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
content={"display_name": "new displayname"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check new display_name
channel = self.make_request(
@@ -243,7 +245,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new displayname", channel.json_body["display_name"])
def test_get_device(self) -> None:
@@ -256,7 +258,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
# Check that all fields are available
self.assertIn("user_id", channel.json_body)
@@ -281,7 +283,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Ensure that the number of devices is decreased
res = self.get_success(self.handler.get_devices_by_user(self.other_user))
@@ -312,7 +314,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -331,7 +333,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -339,7 +341,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
channel = self.make_request(
@@ -348,12 +350,12 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
@@ -363,7 +365,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_user_has_no_devices(self) -> None:
@@ -379,7 +381,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["devices"]))
@@ -399,7 +401,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_devices, channel.json_body["total"])
self.assertEqual(number_devices, len(channel.json_body["devices"]))
self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"])
@@ -438,7 +440,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -457,7 +459,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -465,7 +467,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
channel = self.make_request(
@@ -474,12 +476,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
@@ -489,12 +491,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_unknown_devices(self) -> None:
"""
- Tests that a remove of a device that does not exist returns HTTPStatus.OK.
+ Tests that a remove of a device that does not exist returns 200.
"""
channel = self.make_request(
"POST",
@@ -503,8 +505,8 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
content={"devices": ["unknown_device1", "unknown_device2"]},
)
- # Delete unknown devices returns status HTTPStatus.OK
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ # Delete unknown devices returns status 200
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_delete_devices(self) -> None:
"""
@@ -533,7 +535,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
content={"devices": device_ids},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
res = self.get_success(self.handler.get_devices_by_user(self.other_user))
self.assertEqual(0, len(res))
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 4f89f8b534..8a4e5c3f77 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List
from twisted.test.proto_helpers import MemoryReactor
@@ -81,16 +80,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -99,11 +94,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_default_success(self) -> None:
@@ -117,7 +108,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body)
@@ -134,7 +125,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
@@ -151,7 +142,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -168,7 +159,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["event_reports"]), 10)
@@ -185,7 +176,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["event_reports"]), 10)
self.assertNotIn("next_token", channel.json_body)
@@ -205,7 +196,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["event_reports"]), 10)
self.assertNotIn("next_token", channel.json_body)
@@ -225,7 +216,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 5)
self.assertEqual(len(channel.json_body["event_reports"]), 5)
self.assertNotIn("next_token", channel.json_body)
@@ -247,7 +238,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
report = 1
@@ -265,7 +256,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
report = 1
@@ -278,7 +269,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
def test_invalid_search_order(self) -> None:
"""
- Testing that a invalid search order returns a HTTPStatus.BAD_REQUEST
+ Testing that a invalid search order returns a 400
"""
channel = self.make_request(
@@ -287,17 +278,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual("Unknown direction: bar", channel.json_body["error"])
def test_limit_is_negative(self) -> None:
"""
- Testing that a negative limit parameter returns a HTTPStatus.BAD_REQUEST
+ Testing that a negative limit parameter returns a 400
"""
channel = self.make_request(
@@ -306,16 +293,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_from_is_negative(self) -> None:
"""
- Testing that a negative from parameter returns a HTTPStatus.BAD_REQUEST
+ Testing that a negative from parameter returns a 400
"""
channel = self.make_request(
@@ -324,11 +307,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_next_token(self) -> None:
@@ -344,7 +323,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body)
@@ -357,7 +336,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body)
@@ -370,7 +349,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -384,7 +363,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -400,7 +379,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
{"score": -100, "reason": "this makes me sad"},
access_token=user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def _create_event_and_report_without_parameters(
self, room_id: str, user_tok: str
@@ -415,7 +394,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
{},
access_token=user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that all attributes are present in an event report"""
@@ -431,6 +410,33 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertIn("score", c)
self.assertIn("reason", c)
+ def test_count_correct_despite_table_deletions(self) -> None:
+ """
+ Tests that the count matches the number of rows, even if rows in joined tables
+ are missing.
+ """
+
+ # Delete rows from room_stats_state for one of our rooms.
+ self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_delete(
+ "room_stats_state", {"room_id": self.room_id1}, desc="_"
+ )
+ )
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ # The 'total' field is 10 because only 10 reports will actually
+ # be retrievable since we deleted the rows in the room_stats_state
+ # table.
+ self.assertEqual(channel.json_body["total"], 10)
+ # This is consistent with the number of rows actually returned.
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+
class EventReportDetailTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -466,16 +472,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -484,11 +486,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_default_success(self) -> None:
@@ -502,12 +500,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self._check_fields(channel.json_body)
def test_invalid_report_id(self) -> None:
"""
- Testing that an invalid `report_id` returns a HTTPStatus.BAD_REQUEST.
+ Testing that an invalid `report_id` returns a 400.
"""
# `report_id` is negative
@@ -517,11 +515,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -535,11 +529,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -553,11 +543,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -566,7 +552,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
def test_report_id_not_found(self) -> None:
"""
- Testing that a not existing `report_id` returns a HTTPStatus.NOT_FOUND.
+ Testing that a not existing `report_id` returns a 404.
"""
channel = self.make_request(
@@ -575,11 +561,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
self.assertEqual("Event report not found", channel.json_body["error"])
@@ -594,7 +576,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
{"score": -100, "reason": "this makes me sad"},
access_token=user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def _check_fields(self, content: JsonDict) -> None:
"""Checks that all attributes are present in a event report"""
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index 929bbdc37d..4c7864c629 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List, Optional
from parameterized import parameterized
@@ -64,7 +63,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -77,7 +76,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -87,7 +86,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# unkown order_by
@@ -97,7 +96,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -107,7 +106,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid destination
@@ -117,7 +116,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
# invalid destination
@@ -127,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_limit(self) -> None:
@@ -142,7 +141,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 5)
self.assertEqual(channel.json_body["next_token"], "5")
@@ -160,7 +159,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -178,7 +177,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(channel.json_body["next_token"], "15")
self.assertEqual(len(channel.json_body["destinations"]), 10)
@@ -198,7 +197,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), number_destinations)
self.assertNotIn("next_token", channel.json_body)
@@ -211,7 +210,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), number_destinations)
self.assertNotIn("next_token", channel.json_body)
@@ -224,7 +223,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 19)
self.assertEqual(channel.json_body["next_token"], "19")
@@ -238,7 +237,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -255,7 +254,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_destinations, len(channel.json_body["destinations"]))
self.assertEqual(number_destinations, channel.json_body["total"])
@@ -290,7 +289,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_destination_list))
returned_order = [
@@ -376,7 +375,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that destinations were returned
self.assertTrue("destinations" in channel.json_body)
@@ -418,7 +417,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("sub0.example.com", channel.json_body["destination"])
# Check that all fields are available
@@ -435,7 +434,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("sub0.example.com", channel.json_body["destination"])
self.assertEqual(0, channel.json_body["retry_last_ts"])
self.assertEqual(0, channel.json_body["retry_interval"])
@@ -452,7 +451,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
retry_timings = self.get_success(
self.store.get_destination_retry_timings("sub0.example.com")
@@ -469,7 +468,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"The retry timing does not need to be reset for this destination.",
channel.json_body["error"],
@@ -561,7 +560,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -574,7 +573,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -584,7 +583,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -594,7 +593,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid destination
@@ -604,7 +603,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_limit(self) -> None:
@@ -619,7 +618,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 3)
self.assertEqual(channel.json_body["next_token"], "3")
@@ -637,7 +636,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 5)
self.assertNotIn("next_token", channel.json_body)
@@ -655,7 +654,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(channel.json_body["next_token"], "8")
self.assertEqual(len(channel.json_body["rooms"]), 5)
@@ -673,7 +672,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel_asc.code, msg=channel_asc.json_body)
+ self.assertEqual(200, channel_asc.code, msg=channel_asc.json_body)
self.assertEqual(channel_asc.json_body["total"], number_rooms)
self.assertEqual(number_rooms, len(channel_asc.json_body["rooms"]))
self._check_fields(channel_asc.json_body["rooms"])
@@ -685,7 +684,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel_desc.code, msg=channel_desc.json_body)
+ self.assertEqual(200, channel_desc.code, msg=channel_desc.json_body)
self.assertEqual(channel_desc.json_body["total"], number_rooms)
self.assertEqual(number_rooms, len(channel_desc.json_body["rooms"]))
self._check_fields(channel_desc.json_body["rooms"])
@@ -711,7 +710,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), number_rooms)
self.assertNotIn("next_token", channel.json_body)
@@ -724,7 +723,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), number_rooms)
self.assertNotIn("next_token", channel.json_body)
@@ -737,7 +736,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 4)
self.assertEqual(channel.json_body["next_token"], "4")
@@ -751,7 +750,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -767,7 +766,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(number_rooms, len(channel.json_body["rooms"]))
self._check_fields(channel.json_body["rooms"])
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index e909e444ac..aadb31ca83 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
-from http import HTTPStatus
from parameterized import parameterized
@@ -60,7 +59,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
channel = self.make_request("DELETE", url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -81,16 +80,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_media_does_not_exist(self) -> None:
"""
- Tests that a lookup for a media that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a media that does not exist returns a 404
"""
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
@@ -100,12 +95,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_media_is_not_local(self) -> None:
"""
- Tests that a lookup for a media that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a media that is not a local returns a 400
"""
url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
@@ -115,7 +110,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
def test_delete_media(self) -> None:
@@ -131,7 +126,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -151,11 +146,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Should be successful
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.OK on accessing media: %s"
- % server_and_media_id
+ "Expected to receive a 200 on accessing media: %s" % server_and_media_id
),
)
@@ -172,7 +166,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
media_id,
@@ -189,10 +183,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
self.assertEqual(
- HTTPStatus.NOT_FOUND,
+ 404,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s"
+ "Expected to receive a 404 on accessing deleted media: %s"
% server_and_media_id
),
)
@@ -231,11 +225,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -251,16 +241,12 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_media_is_not_local(self) -> None:
"""
- Tests that a lookup for media that is not local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for media that is not local returns a 400
"""
url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
@@ -270,7 +256,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
def test_missing_parameter(self) -> None:
@@ -283,11 +269,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Missing integer query parameter 'before_ts'", channel.json_body["error"]
@@ -303,11 +285,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts must be a positive integer.",
@@ -320,11 +298,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts you provided is from the year 1970. "
@@ -338,11 +312,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter size_gt must be a string representing a positive integer.",
@@ -355,11 +325,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Boolean query parameter 'keep_profiles' must be one of ['true', 'false']",
@@ -388,7 +354,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
media_id,
@@ -413,7 +379,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -425,7 +391,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -449,7 +415,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&size_gt=67",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -460,7 +426,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&size_gt=66",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -485,7 +451,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
content={"avatar_url": "mxc://%s" % (server_and_media_id,)},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
channel = self.make_request(
@@ -493,7 +459,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -504,7 +470,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -530,7 +496,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
content={"url": "mxc://%s" % (server_and_media_id,)},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
channel = self.make_request(
@@ -538,7 +504,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -549,7 +515,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -569,7 +535,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -602,10 +568,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
if expect_success:
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.OK on accessing media: %s"
+ "Expected to receive a 200 on accessing media: %s"
% server_and_media_id
),
)
@@ -613,10 +579,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertTrue(os.path.exists(local_path))
else:
self.assertEqual(
- HTTPStatus.NOT_FOUND,
+ 404,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s"
+ "Expected to receive a 404 on accessing deleted media: %s"
% (server_and_media_id)
),
)
@@ -648,7 +614,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -668,11 +634,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
b"{}",
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["quarantine", "unquarantine"])
@@ -689,11 +651,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_quarantine_media(self) -> None:
@@ -712,7 +670,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -726,7 +684,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -753,7 +711,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
# verify that is not in quarantine
@@ -785,7 +743,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -801,11 +759,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url % (action, self.media_id), b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["protect", "unprotect"])
@@ -822,11 +776,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_protect_media(self) -> None:
@@ -845,7 +795,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -859,7 +809,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -895,7 +845,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -914,11 +864,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -931,11 +877,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts must be a positive integer.",
@@ -948,11 +890,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts you provided is from the year 1970. "
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 8354250ec2..8f8abc21c7 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -13,7 +13,6 @@
# limitations under the License.
import random
import string
-from http import HTTPStatus
from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
@@ -74,11 +73,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
def test_create_no_auth(self) -> None:
"""Try to create a token without authentication."""
channel = self.make_request("POST", self.url + "/new", {})
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_create_requester_not_admin(self) -> None:
@@ -89,11 +84,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_create_using_defaults(self) -> None:
@@ -105,7 +96,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["token"]), 16)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -129,7 +120,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["token"], token)
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
@@ -150,7 +141,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["token"]), 16)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -168,11 +159,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_token_invalid_chars(self) -> None:
@@ -188,11 +175,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_token_already_exists(self) -> None:
@@ -207,7 +190,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
data,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel1.code, msg=channel1.json_body)
+ self.assertEqual(200, channel1.code, msg=channel1.json_body)
channel2 = self.make_request(
"POST",
@@ -215,7 +198,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
data,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body)
+ self.assertEqual(400, channel2.code, msg=channel2.json_body)
self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_unable_to_generate_token(self) -> None:
@@ -251,7 +234,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 0)
# Should fail with negative integer
@@ -262,7 +245,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
self.assertEqual(
- HTTPStatus.BAD_REQUEST,
+ 400,
channel.code,
msg=channel.json_body,
)
@@ -275,11 +258,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_expiry_time(self) -> None:
@@ -291,11 +270,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": self.clock.time_msec() - 10000},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with float
@@ -305,11 +280,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": self.clock.time_msec() + 1000000.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_length(self) -> None:
@@ -321,7 +292,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 64},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["token"]), 64)
# Should fail with 0
@@ -331,11 +302,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a negative integer
@@ -345,11 +312,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": -5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a float
@@ -359,11 +322,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 8.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with 65
@@ -373,11 +332,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 65},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# UPDATING
@@ -389,11 +344,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_update_requester_not_admin(self) -> None:
@@ -404,11 +355,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_update_non_existent(self) -> None:
@@ -420,11 +367,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_update_uses_allowed(self) -> None:
@@ -439,7 +382,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertIsNone(channel.json_body["expiry_time"])
@@ -450,7 +393,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 0)
self.assertIsNone(channel.json_body["expiry_time"])
@@ -461,7 +404,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": None},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -472,11 +415,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a negative integer
@@ -486,11 +425,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": -5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_update_expiry_time(self) -> None:
@@ -506,7 +441,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": new_expiry_time},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
self.assertIsNone(channel.json_body["uses_allowed"])
@@ -517,7 +452,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": None},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIsNone(channel.json_body["expiry_time"])
self.assertIsNone(channel.json_body["uses_allowed"])
@@ -529,11 +464,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": past_time},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail a float
@@ -543,11 +474,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": new_expiry_time + 0.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_update_both(self) -> None:
@@ -568,7 +495,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
@@ -589,11 +516,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# DELETING
@@ -605,11 +528,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_delete_requester_not_admin(self) -> None:
@@ -620,11 +539,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_delete_non_existent(self) -> None:
@@ -636,11 +551,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_delete(self) -> None:
@@ -655,7 +566,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# GETTING ONE
@@ -666,11 +577,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_get_requester_not_admin(self) -> None:
@@ -682,7 +589,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -697,11 +604,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_get(self) -> None:
@@ -716,7 +619,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["token"], token)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -728,11 +631,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
def test_list_no_auth(self) -> None:
"""Try to list tokens without authentication."""
channel = self.make_request("GET", self.url, {})
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_list_requester_not_admin(self) -> None:
@@ -743,11 +642,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_list_all(self) -> None:
@@ -762,7 +657,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
token_info = channel.json_body["registration_tokens"][0]
self.assertEqual(token_info["token"], token)
@@ -780,11 +675,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
def _test_list_query_parameter(self, valid: str) -> None:
"""Helper used to test both valid=true and valid=false."""
@@ -816,7 +707,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
token_info_1 = channel.json_body["registration_tokens"][0]
token_info_2 = channel.json_body["registration_tokens"][1]
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 989cbdb5e2..e0f5d54aba 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
+import time
import urllib.parse
-from http import HTTPStatus
from typing import List, Optional
from unittest.mock import Mock
@@ -23,10 +24,11 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import EventTypes, Membership, RoomTypes
from synapse.api.errors import Codes
-from synapse.handlers.pagination import PaginationHandler
+from synapse.handlers.pagination import PaginationHandler, PurgeStatus
from synapse.rest.client import directory, events, login, room
from synapse.server import HomeServer
from synapse.util import Clock
+from synapse.util.stringutils import random_string
from tests import unittest
@@ -68,7 +70,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -78,7 +80,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_room_does_not_exist(self) -> None:
@@ -94,11 +96,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_room_is_not_valid(self) -> None:
"""
- Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
+ Check that invalid room names, return an error 400.
"""
url = "/_synapse/admin/v1/rooms/%s" % "invalidroom"
@@ -109,7 +111,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -127,7 +129,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("new_room_id", channel.json_body)
self.assertIn("kicked_users", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -145,7 +147,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"User must be our own: @not:exist.bla",
channel.json_body["error"],
@@ -163,7 +165,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_is_not_bool(self) -> None:
@@ -178,7 +180,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_room_and_block(self) -> None:
@@ -202,7 +204,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -233,7 +235,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -265,7 +267,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -296,7 +298,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
)
# The room is now blocked.
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self._is_blocked(room_id)
def test_shutdown_room_consent(self) -> None:
@@ -319,7 +321,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self.room_id,
body="foo",
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Test that room is not purged
@@ -337,7 +339,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -366,7 +368,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
{"history_visibility": "world_readable"},
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Test that room is not purged
with self.assertRaises(AssertionError):
@@ -383,7 +385,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -398,7 +400,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self._has_no_members(self.room_id)
# Assert we can no longer peek into the room
- self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
+ self._assert_peek(self.room_id, expect_code=403)
def _is_blocked(self, room_id: str, expect: bool = True) -> None:
"""Assert that the room is blocked or not"""
@@ -494,7 +496,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
)
def test_requester_is_no_admin(self, method: str, url: str) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -504,7 +506,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_room_does_not_exist(self) -> None:
@@ -522,7 +524,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -533,7 +535,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"])
@@ -546,7 +548,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
)
def test_room_is_not_valid(self, method: str, url: str) -> None:
"""
- Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
+ Check that invalid room names, return an error 400.
"""
channel = self.make_request(
@@ -556,7 +558,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -574,7 +576,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -592,7 +594,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"User must be our own: @not:exist.bla",
channel.json_body["error"],
@@ -610,7 +612,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_is_not_bool(self) -> None:
@@ -625,7 +627,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_delete_expired_status(self) -> None:
@@ -639,7 +641,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id1 = channel.json_body["delete_id"]
@@ -654,7 +656,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id2 = channel.json_body["delete_id"]
@@ -665,7 +667,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(2, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual("complete", channel.json_body["results"][1]["status"])
@@ -682,7 +684,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"])
@@ -696,7 +698,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_delete_same_room_twice(self) -> None:
@@ -722,9 +724,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body
- )
+ self.assertEqual(400, second_channel.code, msg=second_channel.json_body)
self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"])
self.assertEqual(
f"History purge already in progress for {self.room_id}",
@@ -733,7 +733,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
# get result of first call
first_channel.await_result()
- self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body)
+ self.assertEqual(200, first_channel.code, msg=first_channel.json_body)
self.assertIn("delete_id", first_channel.json_body)
# check status after finish the task
@@ -764,7 +764,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -795,7 +795,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -827,7 +827,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -858,7 +858,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.room_id,
body="foo",
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Test that room is not purged
@@ -876,7 +876,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -887,7 +887,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.url_status_by_room_id,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
# Test that member has moved to new room
@@ -914,7 +914,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
content={"history_visibility": "world_readable"},
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Test that room is not purged
with self.assertRaises(AssertionError):
@@ -931,7 +931,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -942,7 +942,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.url_status_by_room_id,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
# Test that member has moved to new room
@@ -955,7 +955,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self._has_no_members(self.room_id)
# Assert we can no longer peek into the room
- self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
+ self._assert_peek(self.room_id, expect_code=403)
def _is_blocked(self, room_id: str, expect: bool = True) -> None:
"""Assert that the room is blocked or not"""
@@ -1026,9 +1026,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.url_status_by_room_id,
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body
- )
+ self.assertEqual(200, channel_room_id.code, msg=channel_room_id.json_body)
self.assertEqual(1, len(channel_room_id.json_body["results"]))
self.assertEqual(
delete_id, channel_room_id.json_body["results"][0]["delete_id"]
@@ -1041,7 +1039,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel_delete_id.code,
msg=channel_delete_id.json_body,
)
@@ -1085,7 +1083,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
room_ids = []
for _ in range(total_rooms):
room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
+ self.admin_user,
+ tok=self.admin_user_tok,
+ is_public=True,
)
room_ids.append(room_id)
@@ -1100,7 +1100,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
# Check request completed successfully
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that response json body contains a "rooms" key
self.assertTrue(
@@ -1124,8 +1124,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("version", r)
self.assertIn("creator", r)
self.assertIn("encryption", r)
- self.assertIn("federatable", r)
- self.assertIn("public", r)
+ self.assertIs(r["federatable"], True)
+ self.assertIs(r["public"], True)
self.assertIn("join_rules", r)
self.assertIn("guest_access", r)
self.assertIn("history_visibility", r)
@@ -1186,7 +1186,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue("rooms" in channel.json_body)
for r in channel.json_body["rooms"]:
@@ -1226,7 +1226,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_correct_room_attributes(self) -> None:
"""Test the correct attributes for a room are returned"""
@@ -1253,7 +1253,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -1285,7 +1285,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that rooms were returned
self.assertTrue("rooms" in channel.json_body)
@@ -1341,7 +1341,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that rooms were returned
self.assertTrue("rooms" in channel.json_body)
@@ -1487,7 +1487,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
def _search_test(
expected_room_id: Optional[str],
search_term: str,
- expected_http_code: int = HTTPStatus.OK,
+ expected_http_code: int = 200,
) -> None:
"""Search for a room and check that the returned room's id is a match
@@ -1505,7 +1505,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
- if expected_http_code != HTTPStatus.OK:
+ if expected_http_code != 200:
return
# Check that rooms were returned
@@ -1548,7 +1548,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo")
_search_test(None, "bar")
- _search_test(None, "", expected_http_code=HTTPStatus.BAD_REQUEST)
+ _search_test(None, "", expected_http_code=400)
# Test that the whole room id returns the room
_search_test(room_id_1, room_id_1)
@@ -1585,15 +1585,19 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id"))
self.assertEqual("ж", channel.json_body["rooms"][0].get("name"))
def test_single_room(self) -> None:
"""Test that a single room can be requested correctly"""
# Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_1 = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok, is_public=True
+ )
+ room_id_2 = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok, is_public=False
+ )
room_name_1 = "something"
room_name_2 = "else"
@@ -1618,7 +1622,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("room_id", channel.json_body)
self.assertIn("name", channel.json_body)
@@ -1638,7 +1642,11 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("history_visibility", channel.json_body)
self.assertIn("state_events", channel.json_body)
self.assertIn("room_type", channel.json_body)
+ self.assertIn("forgotten", channel.json_body)
+
self.assertEqual(room_id_1, channel.json_body["room_id"])
+ self.assertIs(True, channel.json_body["federatable"])
+ self.assertIs(True, channel.json_body["public"])
def test_single_room_devices(self) -> None:
"""Test that `joined_local_devices` can be requested correctly"""
@@ -1650,7 +1658,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["joined_local_devices"])
# Have another user join the room
@@ -1664,7 +1672,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(2, channel.json_body["joined_local_devices"])
# leave room
@@ -1676,7 +1684,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["joined_local_devices"])
def test_room_members(self) -> None:
@@ -1707,7 +1715,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertCountEqual(
["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
@@ -1720,7 +1728,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertCountEqual(
["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
@@ -1738,7 +1746,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("state", channel.json_body)
# testing that the state events match is painful and not done here. We assume that
# the create_room already does the right thing, so no need to verify that we got
@@ -1755,7 +1763,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -1784,10 +1792,203 @@ class RoomTestCase(unittest.HomeserverTestCase):
# delete the rooms and get joined roomed membership
url = f"/_matrix/client/r0/rooms/{room_id}/joined_members"
channel = self.make_request("GET", url.encode("ascii"), access_token=user_tok)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+class RoomMessagesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.user = self.register_user("foo", "pass")
+ self.user_tok = self.login("foo", "pass")
+ self.room_id = self.helper.create_room_as(self.user, tok=self.user_tok)
+
+ def test_timestamp_to_event(self) -> None:
+ """Test that providing the current timestamp can get the last event."""
+ self.helper.send(self.room_id, body="message 1", tok=self.user_tok)
+ second_event_id = self.helper.send(
+ self.room_id, body="message 2", tok=self.user_tok
+ )["event_id"]
+ ts = str(round(time.time() * 1000))
+
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/timestamp_to_event?dir=b&ts=%s"
+ % (self.room_id, ts),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ self.assertIn("event_id", channel.json_body)
+ self.assertEqual(second_event_id, channel.json_body["event_id"])
+
+ def test_topo_token_is_accepted(self) -> None:
+ """Test Topo Token is accepted."""
+ token = "t1-0_0_0_0_0_0_0_0_0"
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ self.assertIn("start", channel.json_body)
+ self.assertEqual(token, channel.json_body["start"])
+ self.assertIn("chunk", channel.json_body)
+ self.assertIn("end", channel.json_body)
+
+ def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
+ """Test that stream token is accepted for forward pagination."""
+ token = "s0_0_0_0_0_0_0_0_0"
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ self.assertIn("start", channel.json_body)
+ self.assertEqual(token, channel.json_body["start"])
+ self.assertIn("chunk", channel.json_body)
+ self.assertIn("end", channel.json_body)
+
+ def test_room_messages_backward(self) -> None:
+ """Test room messages can be retrieved by an admin that isn't in the room."""
+ latest_event_id = self.helper.send(
+ self.room_id, body="message 1", tok=self.user_tok
+ )["event_id"]
+
+ # Check that we get the first and second message when querying /messages.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?dir=b" % (self.room_id,),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 6, [event["content"] for event in chunk])
+
+ # in backwards, this is the first event
+ self.assertEqual(chunk[0]["event_id"], latest_event_id)
+
+ def test_room_messages_forward(self) -> None:
+ """Test room messages can be retrieved by an admin that isn't in the room."""
+ latest_event_id = self.helper.send(
+ self.room_id, body="message 1", tok=self.user_tok
+ )["event_id"]
+
+ # Check that we get the first and second message when querying /messages.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?dir=f" % (self.room_id,),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 6, [event["content"] for event in chunk])
+
+ # in forward, this is the last event
+ self.assertEqual(chunk[5]["event_id"], latest_event_id)
+
+ def test_room_messages_purge(self) -> None:
+ """Test room messages can be retrieved by an admin that isn't in the room."""
+ store = self.hs.get_datastores().main
+ pagination_handler = self.hs.get_pagination_handler()
+
+ # Send a first message in the room, which will be removed by the purge.
+ first_event_id = self.helper.send(
+ self.room_id, body="message 1", tok=self.user_tok
+ )["event_id"]
+ first_token = self.get_success(
+ store.get_topological_token_for_event(first_event_id)
+ )
+ first_token_str = self.get_success(first_token.to_string(store))
+
+ # Send a second message in the room, which won't be removed, and which we'll
+ # use as the marker to purge events before.
+ second_event_id = self.helper.send(
+ self.room_id, body="message 2", tok=self.user_tok
+ )["event_id"]
+ second_token = self.get_success(
+ store.get_topological_token_for_event(second_event_id)
+ )
+ second_token_str = self.get_success(second_token.to_string(store))
+
+ # Send a third event in the room to ensure we don't fall under any edge case
+ # due to our marker being the latest forward extremity in the room.
+ self.helper.send(self.room_id, body="message 3", tok=self.user_tok)
+
+ # Check that we get the first and second message when querying /messages.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?from=%s&dir=b&filter=%s"
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
+
+ # Purge every event before the second event.
+ purge_id = random_string(16)
+ pagination_handler._purges_by_id[purge_id] = PurgeStatus()
+ self.get_success(
+ pagination_handler._purge_history(
+ purge_id=purge_id,
+ room_id=self.room_id,
+ token=second_token_str,
+ delete_local_events=True,
+ )
+ )
+
+ # Check that we only get the second message through /message now that the first
+ # has been purged.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?from=%s&dir=b&filter=%s"
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 1, [event["content"] for event in chunk])
+
+ # Check that we get no event, but also no error, when querying /messages with
+ # the token that was pointing at the first event, because we don't have it
+ # anymore.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?from=%s&dir=b&filter=%s"
+ % (
+ self.room_id,
+ first_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
+
+
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -1813,7 +2014,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -1823,7 +2024,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.second_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -1838,12 +2039,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
def test_local_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
channel = self.make_request(
@@ -1853,7 +2054,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_remote_user(self) -> None:
@@ -1868,7 +2069,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"This endpoint can only be used with local users",
channel.json_body["error"],
@@ -1876,7 +2077,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
def test_room_does_not_exist(self) -> None:
"""
- Check that unknown rooms/server return error HTTPStatus.NOT_FOUND.
+ Check that unknown rooms/server return error 404.
"""
url = "/_synapse/admin/v1/join/!unknown:test"
@@ -1887,7 +2088,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(
"Can't join remote room because no servers that are in the room have been provided.",
channel.json_body["error"],
@@ -1895,7 +2096,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
def test_room_is_not_valid(self) -> None:
"""
- Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
+ Check that invalid room names, return an error 400.
"""
url = "/_synapse/admin/v1/join/invalidroom"
@@ -1906,7 +2107,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom was not legal room ID or room alias",
channel.json_body["error"],
@@ -1924,7 +2125,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1934,7 +2135,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_not_member(self) -> None:
@@ -1954,7 +2155,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_join_private_room_if_member(self) -> None:
@@ -1982,7 +2183,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
# Join user to room.
@@ -1995,7 +2196,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
content={"user_id": self.second_user_id},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -2005,7 +2206,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_owner(self) -> None:
@@ -2025,7 +2226,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -2035,7 +2236,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_context_as_non_admin(self) -> None:
@@ -2069,7 +2270,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_context_as_admin(self) -> None:
@@ -2099,7 +2300,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body["event"]["event_id"], events[midway]["event_id"]
)
@@ -2158,7 +2359,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
@@ -2185,7 +2386,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room (we should have received an
# invite) and can ban a user.
@@ -2211,7 +2412,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
@@ -2245,11 +2446,11 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- # We expect this to fail with a HTTPStatus.BAD_REQUEST as there are no room admins.
+ # We expect this to fail with a 400 as there are no room admins.
#
# (Note we assert the error message to ensure that it's not denied for
# some other reason)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body["error"],
"No local admin user in room with power to update power levels.",
@@ -2279,7 +2480,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
@parameterized.expand([("PUT",), ("GET",)])
def test_requester_is_no_admin(self, method: str) -> None:
- """If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned."""
+ """If the user is not a server admin, an error 403 is returned."""
channel = self.make_request(
method,
@@ -2288,12 +2489,12 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand([("PUT",), ("GET",)])
def test_room_is_not_valid(self, method: str) -> None:
- """Check that invalid room names, return an error HTTPStatus.BAD_REQUEST."""
+ """Check that invalid room names, return an error 400."""
channel = self.make_request(
method,
@@ -2302,7 +2503,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -2319,7 +2520,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
# `block` is not set
@@ -2330,7 +2531,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# no content is send
@@ -2340,7 +2541,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
def test_block_room(self) -> None:
@@ -2354,7 +2555,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": True},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["block"])
self._is_blocked(room_id, expect=True)
@@ -2378,7 +2579,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": True},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["block"])
self._is_blocked(self.room_id, expect=True)
@@ -2394,7 +2595,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": False},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body["block"])
self._is_blocked(room_id, expect=False)
@@ -2418,7 +2619,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": False},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body["block"])
self._is_blocked(self.room_id, expect=False)
@@ -2433,7 +2634,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.url % room_id,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["block"])
self.assertEqual(self.other_user, channel.json_body["user_id"])
@@ -2457,7 +2658,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.url % room_id,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body["block"])
self.assertNotIn("user_id", channel.json_body)
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index dbcba2663c..a2f347f666 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List
from twisted.test.proto_helpers import MemoryReactor
@@ -57,7 +56,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url)
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -72,7 +71,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -80,7 +79,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_user_does_not_exist(self) -> None:
- """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND"""
+ """Tests that a lookup for a user that does not exist returns a 404"""
channel = self.make_request(
"POST",
self.url,
@@ -88,13 +87,13 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": "@unknown_person:test", "content": ""},
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
channel = self.make_request(
"POST",
@@ -106,7 +105,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"Server notices can only be sent to local users", channel.json_body["error"]
)
@@ -122,7 +121,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
# no content
@@ -133,7 +132,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# no body
@@ -144,7 +143,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user, "content": ""},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("'body' not in content", channel.json_body["error"])
@@ -156,10 +155,66 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user, "content": {"body": ""}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("'msgtype' not in content", channel.json_body["error"])
+ @override_config(
+ {
+ "server_notices": {
+ "system_mxid_localpart": "notices",
+ "system_mxid_avatar_url": "somthingwrong",
+ },
+ "max_avatar_size": "10M",
+ }
+ )
+ def test_invalid_avatar_url(self) -> None:
+ """If avatar url in homeserver.yaml is invalid and
+ "check avatar size and mime type" is set, an error is returned.
+ TODO: Should be checked when reading the configuration."""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg"},
+ },
+ )
+
+ self.assertEqual(500, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ @override_config(
+ {
+ "server_notices": {
+ "system_mxid_localpart": "notices",
+ "system_mxid_display_name": "test display name",
+ "system_mxid_avatar_url": None,
+ },
+ "max_avatar_size": "10M",
+ }
+ )
+ def test_displayname_is_set_avatar_is_none(self) -> None:
+ """
+ Tests that sending a server notices is successfully,
+ if a display_name is set, avatar_url is `None` and
+ "check avatar size and mime type" is set.
+ """
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ self._check_invite_and_join_status(self.other_user, 1, 0)
+
def test_server_notice_disabled(self) -> None:
"""Tests that server returns error if server notice is disabled"""
channel = self.make_request(
@@ -172,7 +227,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual(
"Server notices are not enabled on this server", channel.json_body["error"]
@@ -197,7 +252,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -226,7 +281,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has no new invites or memberships
self._check_invite_and_join_status(self.other_user, 0, 1)
@@ -260,7 +315,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -301,7 +356,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -341,7 +396,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -388,7 +443,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -538,7 +593,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "/_matrix/client/r0/sync", access_token=token
)
- self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.code, 200)
# Get the messages
room = channel.json_body["rooms"]["join"][room_id]
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 7cb8ec57ba..b60f16b914 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List, Optional
from twisted.test.proto_helpers import MemoryReactor
@@ -51,16 +50,12 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
"GET",
@@ -69,11 +64,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -87,11 +78,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -101,11 +88,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative limit
@@ -115,11 +98,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from_ts
@@ -129,11 +108,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative until_ts
@@ -143,11 +118,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# until_ts smaller from_ts
@@ -157,11 +128,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# empty search term
@@ -171,11 +138,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -185,11 +148,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_limit(self) -> None:
@@ -204,7 +163,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
@@ -222,7 +181,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -240,7 +199,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["users"]), 10)
@@ -262,7 +221,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -275,7 +234,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -288,7 +247,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -301,7 +260,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -318,7 +277,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["users"]))
@@ -415,7 +374,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media starting at `ts1` after creating first media
@@ -425,7 +384,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?from_ts=%s" % (ts1,),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 0)
self._create_media(self.other_user_tok, 3)
@@ -440,7 +399,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media until `ts2` and earlier
@@ -449,7 +408,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?until_ts=%s" % (ts2,),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 6)
def test_search_term(self) -> None:
@@ -461,7 +420,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
# filter user 1 and 10-19 by `user_id`
@@ -470,7 +429,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=foo_user_1",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 11)
# filter on this user in `displayname`
@@ -479,7 +438,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=bar_user_10",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10")
self.assertEqual(channel.json_body["total"], 1)
@@ -489,7 +448,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=foobar",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 0)
def _create_users_with_media(self, number_users: int, media_per_user: int) -> None:
@@ -515,7 +474,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
for _ in range(number_media):
# Upload some media into the room
self.helper.upload_media(
- upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK
+ upload_resource, SMALL_PNG, tok=user_token, expect_code=200
)
def _check_fields(self, content: List[JsonDict]) -> None:
@@ -549,7 +508,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_user_list))
returned_order = [row["user_id"] for row in channel.json_body["users"]]
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 12db68d564..e8c9457794 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1,4 +1,4 @@
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2018-2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@ import hmac
import os
import urllib.parse
from binascii import unhexlify
-from http import HTTPStatus
from typing import List, Optional
from unittest.mock import Mock, patch
@@ -26,13 +25,13 @@ from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.api.constants import UserTypes
+from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
-from synapse.rest.client import devices, login, logout, profile, room, sync
+from synapse.rest.client import devices, login, logout, profile, register, room, sync
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -42,14 +41,12 @@ from tests.unittest import override_config
class UserRegisterTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
profile.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
self.url = "/_synapse/admin/v1/register"
self.registration_handler = Mock()
@@ -79,7 +76,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"Shared secret registration is not enabled", channel.json_body["error"]
)
@@ -111,7 +108,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = {"nonce": nonce}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# 61 seconds
@@ -119,7 +116,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_register_incorrect_nonce(self) -> None:
@@ -142,7 +139,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("HMAC incorrect", channel.json_body["error"])
def test_register_correct_nonce(self) -> None:
@@ -169,7 +166,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
def test_nonce_reuse(self) -> None:
@@ -192,13 +189,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_missing_parts(self) -> None:
@@ -219,7 +216,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be an empty body present
channel = self.make_request("POST", self.url, {})
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("nonce must be specified", channel.json_body["error"])
#
@@ -229,28 +226,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
channel = self.make_request("POST", self.url, {"nonce": nonce()})
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# Must be a string
body = {"nonce": nonce(), "username": 1234}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "abcd\u0000"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "a" * 1000}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
#
@@ -261,28 +258,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = {"nonce": nonce(), "username": "a"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("password must be specified", channel.json_body["error"])
# Must be a string
body = {"nonce": nonce(), "username": "a", "password": 1234}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Super long
body = {"nonce": nonce(), "username": "a", "password": "A" * 1000}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
#
@@ -298,7 +295,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
def test_displayname(self) -> None:
@@ -323,11 +320,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob1:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob1:test/displayname")
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("bob1", channel.json_body["displayname"])
# displayname is None
@@ -347,11 +344,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob2:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob2:test/displayname")
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("bob2", channel.json_body["displayname"])
# displayname is empty
@@ -371,11 +368,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob3:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob3:test/displayname")
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
# set displayname
channel = self.make_request("GET", self.url)
@@ -394,11 +391,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob4:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob4:test/displayname")
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("Bob's Name", channel.json_body["displayname"])
@override_config(
@@ -442,12 +439,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
class UsersListTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -466,7 +462,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -478,7 +474,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", self.url, access_token=other_user_token)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_all_users(self) -> None:
@@ -494,7 +490,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(3, len(channel.json_body["users"]))
self.assertEqual(3, channel.json_body["total"])
@@ -508,7 +504,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
expected_user_id: Optional[str],
search_term: str,
search_field: Optional[str] = "name",
- expected_http_code: Optional[int] = HTTPStatus.OK,
+ expected_http_code: Optional[int] = 200,
) -> None:
"""Search for a user and check that the returned user's id is a match
@@ -530,7 +526,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
- if expected_http_code != HTTPStatus.OK:
+ if expected_http_code != 200:
return
# Check that users were returned
@@ -579,6 +575,16 @@ class UsersListTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo", "user_id")
_search_test(None, "bar", "user_id")
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": True,
+ }
+ }
+ }
+ )
def test_invalid_parameter(self) -> None:
"""
If parameters are invalid, an error is returned.
@@ -591,7 +597,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -601,7 +607,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid guests
@@ -611,7 +617,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid deactivated
@@ -621,7 +627,17 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # invalid approved
+ channel = self.make_request(
+ "GET",
+ self.url + "?approved=not_bool",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# unkown order_by
@@ -631,7 +647,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -641,7 +657,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_limit(self) -> None:
@@ -659,7 +675,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], "5")
@@ -680,7 +696,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -701,7 +717,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(channel.json_body["next_token"], "15")
self.assertEqual(len(channel.json_body["users"]), 10)
@@ -724,7 +740,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -737,7 +753,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -750,7 +766,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], "19")
@@ -764,7 +780,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -842,6 +858,99 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self._order_test([self.admin_user, user1, user2], "creation_ts", "f")
self._order_test([user2, user1, self.admin_user], "creation_ts", "b")
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": True,
+ }
+ }
+ }
+ )
+ def test_filter_out_approved(self) -> None:
+ """Tests that the endpoint can filter out approved users."""
+ # Create our users.
+ self._create_users(2)
+
+ # Get the list of users.
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, channel.result)
+
+ # Exclude the admin, because we don't want to accidentally un-approve the admin.
+ non_admin_user_ids = [
+ user["name"]
+ for user in channel.json_body["users"]
+ if user["name"] != self.admin_user
+ ]
+
+ self.assertEqual(2, len(non_admin_user_ids), non_admin_user_ids)
+
+ # Select a user and un-approve them. We do this rather than the other way around
+ # because, since these users are created by an admin, we consider them already
+ # approved.
+ not_approved_user = non_admin_user_ids[0]
+
+ channel = self.make_request(
+ "PUT",
+ f"/_synapse/admin/v2/users/{not_approved_user}",
+ {"approved": False},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, channel.result)
+
+ # Now get the list of users again, this time filtering out approved users.
+ channel = self.make_request(
+ "GET",
+ self.url + "?approved=false",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, channel.result)
+
+ non_admin_user_ids = [
+ user["name"]
+ for user in channel.json_body["users"]
+ if user["name"] != self.admin_user
+ ]
+
+ # We should only have our unapproved user now.
+ self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids)
+ self.assertEqual(not_approved_user, non_admin_user_ids[0])
+
+ def test_erasure_status(self) -> None:
+ # Create a new user.
+ user_id = self.register_user("eraseme", "eraseme")
+
+ # They should appear in the list users API, marked as not erased.
+ channel = self.make_request(
+ "GET",
+ self.url + "?deactivated=true",
+ access_token=self.admin_user_tok,
+ )
+ users = {user["name"]: user for user in channel.json_body["users"]}
+ self.assertIs(users[user_id]["erased"], False)
+
+ # Deactivate that user, requesting erasure.
+ deactivate_account_handler = self.hs.get_deactivate_account_handler()
+ self.get_success(
+ deactivate_account_handler.deactivate_account(
+ user_id, erase_data=True, requester=create_requester(user_id)
+ )
+ )
+
+ # Repeat the list users query. They should now be marked as erased.
+ channel = self.make_request(
+ "GET",
+ self.url + "?deactivated=true",
+ access_token=self.admin_user_tok,
+ )
+ users = {user["name"]: user for user in channel.json_body["users"]}
+ self.assertIs(users[user_id]["erased"], True)
+
def _order_test(
self,
expected_user_list: List[str],
@@ -867,7 +976,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_user_list))
returned_order = [row["name"] for row in channel.json_body["users"]]
@@ -905,11 +1014,100 @@ class UsersListTestCase(unittest.HomeserverTestCase):
)
-class DeactivateAccountTestCase(unittest.HomeserverTestCase):
+class UserDevicesTestCase(unittest.HomeserverTestCase):
+ """
+ Tests user device management-related Admin APIs.
+ """
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ # Set up an Admin user to query the Admin API with.
+ self.admin_user_id = self.register_user("admin", "pass", admin=True)
+ self.admin_user_token = self.login("admin", "pass")
+
+ # Set up a test user to query the devices of.
+ self.other_user_device_id = "TESTDEVICEID"
+ self.other_user_device_display_name = "My Test Device"
+ self.other_user_client_ip = "1.2.3.4"
+ self.other_user_user_agent = "EquestriaTechnology/123.0"
+
+ self.other_user_id = self.register_user("user", "pass", displayname="User1")
+ self.other_user_token = self.login(
+ "user",
+ "pass",
+ device_id=self.other_user_device_id,
+ additional_request_fields={
+ "initial_device_display_name": self.other_user_device_display_name,
+ },
+ )
+
+ # Have the "other user" make a request so that the "last_seen_*" fields are
+ # populated in the tests below.
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/sync",
+ access_token=self.other_user_token,
+ client_ip=self.other_user_client_ip,
+ custom_headers=[
+ ("User-Agent", self.other_user_user_agent),
+ ],
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ def test_list_user_devices(self) -> None:
+ """
+ Tests that a user's devices and attributes are listed correctly via the Admin API.
+ """
+ # Request all devices of "other user"
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v2/users/{self.other_user_id}/devices",
+ access_token=self.admin_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Double-check we got the single device expected
+ user_devices = channel.json_body["devices"]
+ self.assertEqual(len(user_devices), 1)
+ self.assertEqual(channel.json_body["total"], 1)
+
+ # Check that all the attributes of the device reported are as expected.
+ self._validate_attributes_of_device_response(user_devices[0])
+
+ # Request just a single device for "other user" by its ID
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v2/users/{self.other_user_id}/devices/"
+ f"{self.other_user_device_id}",
+ access_token=self.admin_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check that all the attributes of the device reported are as expected.
+ self._validate_attributes_of_device_response(channel.json_body)
+
+ def _validate_attributes_of_device_response(self, response: JsonDict) -> None:
+ # Check that all device expected attributes are present
+ self.assertEqual(response["user_id"], self.other_user_id)
+ self.assertEqual(response["device_id"], self.other_user_device_id)
+ self.assertEqual(response["display_name"], self.other_user_device_display_name)
+ self.assertEqual(response["last_seen_ip"], self.other_user_client_ip)
+ self.assertEqual(response["last_seen_user_agent"], self.other_user_user_agent)
+ self.assertIsInstance(response["last_seen_ts"], int)
+ self.assertGreater(response["last_seen_ts"], 0)
+
+
+class DeactivateAccountTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -941,7 +1139,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self) -> None:
@@ -952,7 +1150,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", url, access_token=self.other_user_token)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
@@ -962,12 +1160,12 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content=b"{}",
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
def test_user_does_not_exist(self) -> None:
"""
- Tests that deactivation for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that deactivation for a user that does not exist returns a 404
"""
channel = self.make_request(
@@ -976,7 +1174,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_erase_is_not_bool(self) -> None:
@@ -991,18 +1189,18 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that deactivation for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that deactivation for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain"
channel = self.make_request("POST", url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only deactivate local users", channel.json_body["error"])
def test_deactivate_user_erase_true(self) -> None:
@@ -1017,12 +1215,13 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User1", channel.json_body["displayname"])
+ self.assertFalse(channel.json_body["erased"])
# Deactivate and erase user
channel = self.make_request(
@@ -1032,7 +1231,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1041,12 +1240,13 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertIsNone(channel.json_body["avatar_url"])
self.assertIsNone(channel.json_body["displayname"])
+ self.assertTrue(channel.json_body["erased"])
self._is_erased("@user:test", True)
@@ -1066,7 +1266,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"erase": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self._is_erased("@user:test", True)
def test_deactivate_user_erase_false(self) -> None:
@@ -1081,7 +1281,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -1096,7 +1296,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": False},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1105,7 +1305,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -1135,7 +1335,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -1150,7 +1350,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1159,7 +1359,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -1178,11 +1378,11 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
class UserRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
+ register.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -1220,7 +1420,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
@@ -1230,12 +1430,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=b"{}",
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
channel = self.make_request(
@@ -1244,7 +1444,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -1259,7 +1459,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"admin": "not_bool"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
# deactivated not bool
@@ -1269,7 +1469,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": "not_bool"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# password not str
@@ -1279,7 +1479,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"password": True},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# password not length
@@ -1289,7 +1489,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"password": "x" * 513},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# user_type not valid
@@ -1299,7 +1499,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"user_type": "new type"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# external_ids not valid
@@ -1311,7 +1511,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"external_ids": {"auth_provider": "prov", "wrong_external_id": "id"}
},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
channel = self.make_request(
@@ -1320,7 +1520,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"external_ids": {"external_id": "id"}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# threepids not valid
@@ -1330,7 +1530,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": {"medium": "email", "wrong_address": "id"}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
channel = self.make_request(
@@ -1339,7 +1539,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": {"address": "value"}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
def test_get_user(self) -> None:
@@ -1352,7 +1552,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("User", channel.json_body["displayname"])
self._check_fields(channel.json_body)
@@ -1379,7 +1579,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1395,7 +1595,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1434,7 +1634,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1458,7 +1658,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1486,7 +1686,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# before limit of monthly active users is reached
channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok)
- if channel.code != HTTPStatus.OK:
+ if channel.code != 200:
raise HttpResponseException(
channel.code, channel.result["reason"], channel.result["body"]
)
@@ -1512,7 +1712,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123", "admin": False},
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
@@ -1550,7 +1750,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# Admin user is not blocked by mau anymore
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
@@ -1585,7 +1785,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@@ -1626,7 +1826,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@@ -1666,7 +1866,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
@@ -1684,7 +1884,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "hahaha"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self._check_fields(channel.json_body)
def test_set_displayname(self) -> None:
@@ -1700,7 +1900,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"displayname": "foobar"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
@@ -1711,7 +1911,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
@@ -1733,7 +1933,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
# result does not always have the same sort order, therefore it becomes sorted
@@ -1759,7 +1959,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1775,7 +1975,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1791,7 +1991,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": []},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
@@ -1818,7 +2018,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1837,7 +2037,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1859,7 +2059,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# other user has this two threepids
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
# result does not always have the same sort order, therefore it becomes sorted
@@ -1878,7 +2078,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url_first_user,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
@@ -1907,7 +2107,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
# result does not always have the same sort order, therefore it becomes sorted
@@ -1939,7 +2139,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -1958,7 +2158,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -1977,7 +2177,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"external_ids": []},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["external_ids"]))
@@ -2006,7 +2206,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2032,7 +2232,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2064,7 +2264,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# must fail
- self.assertEqual(HTTPStatus.CONFLICT, channel.code, msg=channel.json_body)
+ self.assertEqual(409, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("External id is already in use.", channel.json_body["error"])
@@ -2075,7 +2275,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2093,7 +2293,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2124,7 +2324,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -2139,7 +2339,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"deactivated": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -2158,7 +2358,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -2188,7 +2388,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"deactivated": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
@@ -2204,7 +2404,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"displayname": "Foobar"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertEqual("Foobar", channel.json_body["displayname"])
@@ -2228,7 +2428,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
# Reactivate the user.
channel = self.make_request(
@@ -2237,7 +2437,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self._is_erased("@user:test", False)
@@ -2261,7 +2461,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
@@ -2271,7 +2471,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self._is_erased("@user:test", False)
@@ -2295,7 +2495,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
@@ -2305,7 +2505,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self._is_erased("@user:test", False)
@@ -2326,7 +2526,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"admin": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
@@ -2337,7 +2537,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
@@ -2354,7 +2554,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"user_type": UserTypes.SUPPORT},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
@@ -2365,7 +2565,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
@@ -2377,7 +2577,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"user_type": None},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"])
@@ -2388,7 +2588,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"])
@@ -2407,7 +2607,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123"},
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
@@ -2418,7 +2618,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
self.assertEqual(0, channel.json_body["deactivated"])
@@ -2431,7 +2631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123", "deactivated": "false"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
# Check user is not deactivated
channel = self.make_request(
@@ -2440,13 +2640,111 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
# Ensure they're still alive
self.assertEqual(0, channel.json_body["deactivated"])
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": True,
+ }
+ }
+ }
+ )
+ def test_approve_account(self) -> None:
+ """Tests that approving an account correctly sets the approved flag for the user."""
+ url = self.url_prefix % "@bob:test"
+
+ # Create the user using the client-server API since otherwise the user will be
+ # marked as approved automatically.
+ channel = self.make_request(
+ "POST",
+ "register",
+ {
+ "username": "bob",
+ "password": "test",
+ "auth": {"type": LoginType.DUMMY},
+ },
+ )
+ self.assertEqual(403, channel.code, channel.result)
+ self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+ self.assertEqual(
+ ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+ )
+
+ # Get user
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIs(False, channel.json_body["approved"])
+
+ # Approve user
+ channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content={"approved": True},
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIs(True, channel.json_body["approved"])
+
+ # Check that the user is now approved
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIs(True, channel.json_body["approved"])
+
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": True,
+ }
+ }
+ }
+ )
+ def test_register_approved(self) -> None:
+ url = self.url_prefix % "@bob:test"
+
+ # Create user
+ channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content={"password": "abc123", "approved": True},
+ )
+
+ self.assertEqual(201, channel.code, msg=channel.json_body)
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual(1, channel.json_body["approved"])
+
+ # Get user
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual(1, channel.json_body["approved"])
+
def _is_erased(self, user_id: str, expect: bool) -> None:
"""Assert that the user is erased or not"""
d = self.store.is_user_erased(user_id)
@@ -2465,7 +2763,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["deactivated"])
self._is_erased(user_id, False)
d = self.store.mark_user_erased(user_id)
@@ -2486,11 +2784,13 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertIn("avatar_url", content)
self.assertIn("admin", content)
self.assertIn("deactivated", content)
+ self.assertIn("erased", content)
self.assertIn("shadow_banned", content)
self.assertIn("creation_ts", content)
self.assertIn("appservice_id", content)
self.assertIn("consent_server_notice_sent", content)
self.assertIn("consent_version", content)
+ self.assertIn("consent_ts", content)
self.assertIn("external_ids", content)
# This key was removed intentionally. Ensure it is not accidentally re-included.
@@ -2498,7 +2798,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
class UserMembershipRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -2520,7 +2819,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -2535,7 +2834,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self) -> None:
@@ -2549,7 +2848,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@@ -2565,7 +2864,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@@ -2581,7 +2880,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@@ -2602,7 +2901,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_rooms, channel.json_body["total"])
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
@@ -2649,13 +2948,12 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"])
class PushersRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -2678,7 +2976,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -2693,12 +2991,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
channel = self.make_request(
@@ -2707,12 +3005,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
@@ -2722,7 +3020,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_get_pushers(self) -> None:
@@ -2737,7 +3035,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
# Register the pusher
@@ -2749,7 +3047,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.other_user,
access_token=token_id,
kind="http",
@@ -2769,7 +3067,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
for p in channel.json_body["pushers"]:
@@ -2784,7 +3082,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
class UserMediaRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -2808,7 +3105,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""Try to list media of an user without authentication."""
channel = self.make_request(method, self.url, {})
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
@@ -2822,12 +3119,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
def test_user_does_not_exist(self, method: str) -> None:
- """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND"""
+ """Tests that a lookup for a user that does not exist returns a 404"""
url = "/_synapse/admin/v1/users/@unknown_person:test/media"
channel = self.make_request(
method,
@@ -2835,12 +3132,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
def test_user_is_not_local(self, method: str) -> None:
- """Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST"""
+ """Tests that a lookup for a user that is not a local returns a 400"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
channel = self.make_request(
@@ -2849,7 +3146,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_limit_GET(self) -> None:
@@ -2865,7 +3162,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
@@ -2884,7 +3181,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 5)
self.assertEqual(len(channel.json_body["deleted_media"]), 5)
@@ -2901,7 +3198,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -2920,7 +3217,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 15)
self.assertEqual(len(channel.json_body["deleted_media"]), 15)
@@ -2937,7 +3234,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["media"]), 10)
@@ -2956,7 +3253,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["deleted_media"]), 10)
@@ -2970,7 +3267,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -2980,7 +3277,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative limit
@@ -2990,7 +3287,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -3000,7 +3297,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_next_token(self) -> None:
@@ -3023,7 +3320,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body)
@@ -3036,7 +3333,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body)
@@ -3049,7 +3346,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -3063,7 +3360,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -3080,7 +3377,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["media"]))
@@ -3095,7 +3392,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["deleted_media"]))
@@ -3112,7 +3409,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"])
self.assertEqual(number_media, len(channel.json_body["media"]))
self.assertNotIn("next_token", channel.json_body)
@@ -3138,7 +3435,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"])
self.assertEqual(number_media, len(channel.json_body["deleted_media"]))
self.assertCountEqual(channel.json_body["deleted_media"], media_ids)
@@ -3283,7 +3580,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, image_data, user_token, filename, expect_code=HTTPStatus.OK
+ upload_resource, image_data, user_token, filename, expect_code=200
)
# Extract media ID from the response
@@ -3301,10 +3598,10 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel.code,
msg=(
- f"Expected to receive a HTTPStatus.OK on accessing media: {server_and_media_id}"
+ f"Expected to receive a 200 on accessing media: {server_and_media_id}"
),
)
@@ -3350,7 +3647,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_media_list))
returned_order = [row["media_id"] for row in channel.json_body["media"]]
@@ -3386,14 +3683,14 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", self.url, b"{}", access_token=self.admin_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
return channel.json_body["access_token"]
def test_no_auth(self) -> None:
"""Try to login as a user without authentication."""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_not_admin(self) -> None:
@@ -3402,7 +3699,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
"POST", self.url, b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
def test_send_event(self) -> None:
"""Test that sending event as a user works."""
@@ -3427,7 +3724,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# We should only see the one device (from the login in `prepare`)
self.assertEqual(len(channel.json_body["devices"]), 1)
@@ -3439,21 +3736,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout with the puppet token
channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_user_logout_all(self) -> None:
"""Tests that the target user calling `/logout/all` does *not* expire
@@ -3464,23 +3761,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout all with the real user token
channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should still work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# .. but the real user's tokens shouldn't
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
def test_admin_logout_all(self) -> None:
"""Tests that the admin user calling `/logout/all` does expire the
@@ -3491,23 +3788,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout all with the admin user token
channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.admin_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
@unittest.override_config(
{
@@ -3538,7 +3835,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
room_id,
"com.example.test",
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Login in as the user
@@ -3559,7 +3856,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
room_id,
user=self.other_user,
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Logging in as the other user and joining a room should work, even
@@ -3576,7 +3873,6 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
],
)
class WhoisRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -3594,7 +3890,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
Try to get information of an user without authentication.
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self) -> None:
@@ -3609,12 +3905,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=other_user2_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = self.url_prefix % "@unknown_person:unknown_domain" # type: ignore[attr-defined]
@@ -3623,7 +3919,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only whois a local user", channel.json_body["error"])
def test_get_whois_admin(self) -> None:
@@ -3635,7 +3931,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
@@ -3650,13 +3946,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
class ShadowBanRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -3680,7 +3975,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
Try to get information of an user without authentication.
"""
channel = self.make_request(method, self.url)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["POST", "DELETE"])
@@ -3691,18 +3986,18 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
other_user_token = self.login("user", "pass")
channel = self.make_request(method, self.url, access_token=other_user_token)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["POST", "DELETE"])
def test_user_is_not_local(self, method: str) -> None:
"""
- Tests that shadow-banning for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that shadow-banning for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
channel = self.make_request(method, url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
def test_success(self) -> None:
"""
@@ -3715,7 +4010,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
self.assertFalse(result.shadow_banned)
channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual({}, channel.json_body)
# Ensure the user is shadow-banned (and the cache was cleared).
@@ -3727,7 +4022,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE", self.url, access_token=self.admin_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual({}, channel.json_body)
# Ensure the user is no longer shadow-banned (and the cache was cleared).
@@ -3737,7 +4032,6 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
class RateLimitTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -3762,7 +4056,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(method, self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "POST", "DELETE"])
@@ -3778,13 +4072,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "POST", "DELETE"])
def test_user_does_not_exist(self, method: str) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit"
@@ -3794,7 +4088,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(
@@ -3806,7 +4100,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
)
def test_user_is_not_local(self, method: str, error_msg: str) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = (
"/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit"
@@ -3818,7 +4112,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(error_msg, channel.json_body["error"])
def test_invalid_parameter(self) -> None:
@@ -3833,7 +4127,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": "string"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# messages_per_second is negative
@@ -3844,7 +4138,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": -1},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is a string
@@ -3855,7 +4149,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": "string"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is negative
@@ -3866,7 +4160,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": -1},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_return_zero_when_null(self) -> None:
@@ -3891,7 +4185,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["messages_per_second"])
self.assertEqual(0, channel.json_body["burst_count"])
@@ -3905,7 +4199,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
@@ -3916,7 +4210,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"messages_per_second": 10, "burst_count": 11},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(10, channel.json_body["messages_per_second"])
self.assertEqual(11, channel.json_body["burst_count"])
@@ -3927,7 +4221,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"messages_per_second": 20, "burst_count": 21},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"])
@@ -3937,7 +4231,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"])
@@ -3947,7 +4241,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
@@ -3957,13 +4251,12 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
class AccountDataTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -3982,7 +4275,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
"""Try to get information of a user without authentication."""
channel = self.make_request("GET", self.url, {})
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -3995,7 +4288,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self) -> None:
@@ -4008,7 +4301,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
@@ -4021,7 +4314,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_success(self) -> None:
@@ -4042,7 +4335,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
{"a": 1}, channel.json_body["account_data"]["global"]["m.global"]
)
@@ -4050,3 +4343,183 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
{"b": 2},
channel.json_body["account_data"]["rooms"]["test_room"]["m.per_room"],
)
+
+
+class UsersByExternalIdTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.get_success(
+ self.store.record_user_external_id(
+ "the-auth-provider", "the-external-id", self.other_user
+ )
+ )
+ self.get_success(
+ self.store.record_user_external_id(
+ "another-auth-provider", "a:complex@external/id", self.other_user
+ )
+ )
+
+ def test_no_auth(self) -> None:
+ """Try to lookup a user without authentication."""
+ url = (
+ "/_synapse/admin/v1/auth_providers/the-auth-provider/users/the-external-id"
+ )
+
+ channel = self.make_request(
+ "GET",
+ url,
+ )
+
+ self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_binding_does_not_exist(self) -> None:
+ """Tests that a lookup for an external ID that does not exist returns a 404"""
+ url = "/_synapse/admin/v1/auth_providers/the-auth-provider/users/unknown-id"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_success(self) -> None:
+ """Tests a successful external ID lookup"""
+ url = (
+ "/_synapse/admin/v1/auth_providers/the-auth-provider/users/the-external-id"
+ )
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ {"user_id": self.other_user},
+ channel.json_body,
+ )
+
+ def test_success_urlencoded(self) -> None:
+ """Tests a successful external ID lookup with an url-encoded ID"""
+ url = "/_synapse/admin/v1/auth_providers/another-auth-provider/users/a%3Acomplex%40external%2Fid"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ {"user_id": self.other_user},
+ channel.json_body,
+ )
+
+
+class UsersByThreePidTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.get_success(
+ self.store.user_add_threepid(
+ self.other_user, "email", "user@email.com", 1, 1
+ )
+ )
+ self.get_success(
+ self.store.user_add_threepid(self.other_user, "msidn", "+1-12345678", 1, 1)
+ )
+
+ def test_no_auth(self) -> None:
+ """Try to look up a user without authentication."""
+ url = "/_synapse/admin/v1/threepid/email/users/user%40email.com"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ )
+
+ self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_medium_does_not_exist(self) -> None:
+ """Tests that both a lookup for a medium that does not exist and a user that
+ doesn't exist with that third party ID returns a 404"""
+ # test for unknown medium
+ url = "/_synapse/admin/v1/threepid/publickey/users/unknown-key"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ # test for unknown user with a known medium
+ url = "/_synapse/admin/v1/threepid/email/users/unknown"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_success(self) -> None:
+ """Tests a successful medium + address lookup"""
+ # test for email medium with encoded value of user@email.com
+ url = "/_synapse/admin/v1/threepid/email/users/user%40email.com"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ {"user_id": self.other_user},
+ channel.json_body,
+ )
+
+ # test for msidn medium with encoded value of +1-12345678
+ url = "/_synapse/admin/v1/threepid/msidn/users/%2B1-12345678"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ {"user_id": self.other_user},
+ channel.json_body,
+ )
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index b21f6d4689..30f12f1bff 100644
--- a/tests/rest/admin/test_username_available.py
+++ b/tests/rest/admin/test_username_available.py
@@ -11,9 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from http import HTTPStatus
-
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -40,7 +37,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
if username == "allowed":
return True
raise SynapseError(
- HTTPStatus.BAD_REQUEST,
+ 400,
"User ID already taken.",
errcode=Codes.USER_IN_USE,
)
@@ -50,27 +47,23 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
def test_username_available(self) -> None:
"""
- The endpoint should return a HTTPStatus.OK response if the username does not exist
+ The endpoint should return a 200 response if the username does not exist
"""
url = "%s?username=%s" % (self.url, "allowed")
channel = self.make_request("GET", url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["available"])
def test_username_unavailable(self) -> None:
"""
- The endpoint should return a HTTPStatus.OK response if the username does not exist
+ The endpoint should return a 200 response if the username does not exist
"""
url = "%s?username=%s" % (self.url, "disallowed")
channel = self.make_request("GET", url, access_token=self.admin_user_tok)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE")
self.assertEqual(channel.json_body["error"], "User ID already taken.")
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 7ae926dc9c..c1a7fb2f8a 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -488,7 +488,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, 200, channel.json_body)
class WhoamiTestCase(unittest.HomeserverTestCase):
@@ -641,21 +641,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_add_email_no_at(self) -> None:
self._request_token_invalid_email(
"address-without-at.bar",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
def test_add_email_two_at(self) -> None:
self._request_token_invalid_email(
"foo@foo@test.bar",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
def test_add_email_bad_format(self) -> None:
self._request_token_invalid_email(
"user@bad.example.net@good.example.com",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
@@ -1001,7 +1001,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(expected_errcode, channel.json_body["errcode"])
- self.assertEqual(expected_error, channel.json_body["error"])
+ self.assertIn(expected_error, channel.json_body["error"])
def _validate_token(self, link: str) -> None:
# Remove the host
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 05355c7fb6..208ec44829 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -20,7 +21,8 @@ from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
import synapse.rest.admin
-from synapse.api.constants import LoginType
+from synapse.api.constants import ApprovalNoticeMedium, LoginType
+from synapse.api.errors import Codes, SynapseError
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client import account, auth, devices, login, logout, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree
@@ -31,8 +33,8 @@ from synapse.util import Clock
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
-from tests.rest.client.utils import TEST_OIDC_CONFIG
-from tests.server import FakeChannel
+from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
+from tests.server import FakeChannel, make_request
from tests.unittest import override_config, skip_unless
@@ -464,9 +466,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
* checking that the original operation succeeds
"""
+ fake_oidc_server = self.helper.fake_oidc_server()
+
# log the user in
remote_user_id = UserID.from_string(self.user).localpart
- login_resp = self.helper.login_via_oidc(remote_user_id)
+ login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id)
self.assertEqual(login_resp["user_id"], self.user)
# initiate a UI Auth process by attempting to delete the device
@@ -480,8 +484,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
# run the UIA-via-SSO flow
session_id = channel.json_body["session"]
- channel = self.helper.auth_via_oidc(
- {"sub": remote_user_id}, ui_auth_session_id=session_id
+ channel, _ = self.helper.auth_via_oidc(
+ fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id
)
# that should serve a confirmation page
@@ -498,7 +502,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_does_not_offer_password_for_sso_user(self) -> None:
- login_resp = self.helper.login_via_oidc("username")
+ fake_oidc_server = self.helper.fake_oidc_server()
+ login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username")
user_tok = login_resp["access_token"]
device_id = login_resp["device_id"]
@@ -521,7 +526,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_offers_both_flows_for_upgraded_user(self) -> None:
"""A user that had a password and then logged in with SSO should get both flows"""
- login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+ fake_oidc_server = self.helper.fake_oidc_server()
+ login_resp, _ = self.helper.login_via_oidc(
+ fake_oidc_server, UserID.from_string(self.user).localpart
+ )
self.assertEqual(login_resp["user_id"], self.user)
channel = self.delete_device(
@@ -538,8 +546,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
"""If the user tries to authenticate with the wrong SSO user, they get an error"""
+
+ fake_oidc_server = self.helper.fake_oidc_server()
+
# log the user in
- login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+ login_resp, _ = self.helper.login_via_oidc(
+ fake_oidc_server, UserID.from_string(self.user).localpart
+ )
self.assertEqual(login_resp["user_id"], self.user)
# start a UI Auth flow by attempting to delete a device
@@ -552,8 +565,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
session_id = channel.json_body["session"]
# do the OIDC auth, but auth as the wrong user
- channel = self.helper.auth_via_oidc(
- {"sub": "wrong_user"}, ui_auth_session_id=session_id
+ channel, _ = self.helper.auth_via_oidc(
+ fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id
)
# that should return a failure message
@@ -567,6 +580,39 @@ class UIAuthTests(unittest.HomeserverTestCase):
body={"auth": {"session": session_id}},
)
+ @skip_unless(HAS_OIDC, "requires OIDC")
+ @override_config(
+ {
+ "oidc_config": TEST_OIDC_CONFIG,
+ "experimental_features": {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": True,
+ }
+ },
+ }
+ )
+ def test_sso_not_approved(self) -> None:
+ """Tests that if we register a user via SSO while requiring approval for new
+ accounts, we still raise the correct error before logging the user in.
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ login_resp, _ = self.helper.login_via_oidc(
+ fake_oidc_server, "username", expected_status=403
+ )
+
+ self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL)
+ self.assertEqual(
+ ApprovalNoticeMedium.NONE, login_resp["approval_notice_medium"]
+ )
+
+ # Check that we didn't register a device for the user during the login attempt.
+ devices = self.get_success(
+ self.hs.get_datastores().main.get_devices_by_user("@username:test")
+ )
+
+ self.assertEqual(len(devices), 0)
+
class RefreshAuthTests(unittest.HomeserverTestCase):
servlets = [
@@ -589,23 +635,10 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
"""
return self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": refresh_token},
)
- def is_access_token_valid(self, access_token: str) -> bool:
- """
- Checks whether an access token is valid, returning whether it is or not.
- """
- code = self.make_request(
- "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
- ).code
-
- # Either 200 or 401 is what we get back; anything else is a bug.
- assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED}
-
- return code == HTTPStatus.OK
-
def test_login_issue_refresh_token(self) -> None:
"""
A login response should include a refresh_token only if asked.
@@ -691,7 +724,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
@@ -732,7 +765,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
@@ -802,29 +835,37 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
self.reactor.advance(59.0)
# Both tokens should still be valid.
- self.assertTrue(self.is_access_token_valid(refreshable_access_token))
- self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+ self.helper.whoami(refreshable_access_token, expect_code=HTTPStatus.OK)
+ self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
# Advance to 61 s (just past 1 minute, the time of expiry)
self.reactor.advance(2.0)
# Only the non-refreshable token is still valid.
- self.assertFalse(self.is_access_token_valid(refreshable_access_token))
- self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+ self.helper.whoami(
+ refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+ )
+ self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
# Advance to 599 s (just shy of 10 minutes, the time of expiry)
self.reactor.advance(599.0 - 61.0)
# It's still the case that only the non-refreshable token is still valid.
- self.assertFalse(self.is_access_token_valid(refreshable_access_token))
- self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+ self.helper.whoami(
+ refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+ )
+ self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
# Advance to 601 s (just past 10 minutes, the time of expiry)
self.reactor.advance(2.0)
# Now neither token is valid.
- self.assertFalse(self.is_access_token_valid(refreshable_access_token))
- self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token))
+ self.helper.whoami(
+ refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+ )
+ self.helper.whoami(
+ nonrefreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+ )
@override_config(
{"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
@@ -961,7 +1002,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# This first refresh should work properly
first_refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(
@@ -971,7 +1012,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# This one as well, since the token in the first one was never used
second_refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(
@@ -981,7 +1022,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# This one should not, since the token from the first refresh is not valid anymore
third_refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": first_refresh_response.json_body["refresh_token"]},
)
self.assertEqual(
@@ -1015,7 +1056,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail
fourth_refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(
@@ -1027,7 +1068,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# But refreshing from the last valid refresh token still works
fifth_refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": second_refresh_response.json_body["refresh_token"]},
)
self.assertEqual(
@@ -1120,3 +1161,349 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# and no refresh token
self.assertEqual(_table_length("access_tokens"), 0)
self.assertEqual(_table_length("refresh_tokens"), 0)
+
+
+def oidc_config(
+ id: str, with_localpart_template: bool, **kwargs: Any
+) -> Dict[str, Any]:
+ """Sample OIDC provider config used in backchannel logout tests.
+
+ Args:
+ id: IDP ID for this provider
+ with_localpart_template: Set to `true` to have a default localpart_template in
+ the `user_mapping_provider` config and skip the user mapping session
+ **kwargs: rest of the config
+
+ Returns:
+ A dict suitable for the `oidc_config` or the `oidc_providers[]` parts of
+ the HS config
+ """
+ config: Dict[str, Any] = {
+ "idp_id": id,
+ "idp_name": id,
+ "issuer": TEST_OIDC_ISSUER,
+ "client_id": "test-client-id",
+ "client_secret": "test-client-secret",
+ "scopes": ["openid"],
+ }
+
+ if with_localpart_template:
+ config["user_mapping_provider"] = {
+ "config": {"localpart_template": "{{ user.sub }}"}
+ }
+ else:
+ config["user_mapping_provider"] = {"config": {}}
+
+ config.update(kwargs)
+
+ return config
+
+
+@skip_unless(HAS_OIDC, "Requires OIDC")
+class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
+ servlets = [
+ account.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self) -> Dict[str, Any]:
+ config = super().default_config()
+
+ # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
+ # False, so synapse will see the requested uri as http://..., so using http in
+ # the public_baseurl stops Synapse trying to redirect to https.
+ config["public_baseurl"] = "http://synapse.test"
+
+ return config
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ resource_dict = super().create_resource_dict()
+ resource_dict.update(build_synapse_client_resource_tree(self.hs))
+ return resource_dict
+
+ def submit_logout_token(self, logout_token: str) -> FakeChannel:
+ return self.make_request(
+ "POST",
+ "/_synapse/client/oidc/backchannel_logout",
+ content=f"logout_token={logout_token}",
+ content_is_form=True,
+ )
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ id="oidc",
+ with_localpart_template=True,
+ backchannel_logout_enabled=True,
+ )
+ ]
+ }
+ )
+ def test_simple_logout(self) -> None:
+ """
+ Receiving a logout token should logout the user
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ user = "john"
+
+ login_resp, first_grant = self.helper.login_via_oidc(
+ fake_oidc_server, user, with_sid=True
+ )
+ first_access_token: str = login_resp["access_token"]
+ self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
+
+ login_resp, second_grant = self.helper.login_via_oidc(
+ fake_oidc_server, user, with_sid=True
+ )
+ second_access_token: str = login_resp["access_token"]
+ self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+ self.assertNotEqual(first_grant.sid, second_grant.sid)
+ self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
+
+ # Logging out of the first session
+ logout_token = fake_oidc_server.generate_logout_token(first_grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
+ self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+ # Logging out of the second session
+ logout_token = fake_oidc_server.generate_logout_token(second_grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ id="oidc",
+ with_localpart_template=True,
+ backchannel_logout_enabled=True,
+ )
+ ]
+ }
+ )
+ def test_logout_during_login(self) -> None:
+ """
+ It should revoke login tokens when receiving a logout token
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ user = "john"
+
+ # Get an authentication, and logout before submitting the logout token
+ client_redirect_url = "https://x"
+ userinfo = {"sub": user}
+ channel, grant = self.helper.auth_via_oidc(
+ fake_oidc_server,
+ userinfo,
+ client_redirect_url,
+ with_sid=True,
+ )
+
+ # expect a confirmation page
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ # fish the matrix login token out of the body of the confirmation page
+ m = re.search(
+ 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+ channel.text_body,
+ )
+ assert m, channel.text_body
+ login_token = m.group(1)
+
+ # Submit a logout
+ logout_token = fake_oidc_server.generate_logout_token(grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ # Now try to exchange the login token
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ # It should have failed
+ self.assertEqual(channel.code, 403)
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ id="oidc",
+ with_localpart_template=False,
+ backchannel_logout_enabled=True,
+ )
+ ]
+ }
+ )
+ def test_logout_during_mapping(self) -> None:
+ """
+ It should stop ongoing user mapping session when receiving a logout token
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ user = "john"
+
+ # Get an authentication, and logout before submitting the logout token
+ client_redirect_url = "https://x"
+ userinfo = {"sub": user}
+ channel, grant = self.helper.auth_via_oidc(
+ fake_oidc_server,
+ userinfo,
+ client_redirect_url,
+ with_sid=True,
+ )
+
+ # Expect a user mapping page
+ self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+
+ # We should have a user_mapping_session cookie
+ cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
+ assert cookie_headers
+ cookies: Dict[str, str] = {}
+ for h in cookie_headers:
+ key, value = h.split(";")[0].split("=", maxsplit=1)
+ cookies[key] = value
+
+ user_mapping_session_id = cookies["username_mapping_session"]
+
+ # Getting that session should not raise
+ session = self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
+ self.assertIsNotNone(session)
+
+ # Submit a logout
+ logout_token = fake_oidc_server.generate_logout_token(grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ # Now it should raise
+ with self.assertRaises(SynapseError):
+ self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ id="oidc",
+ with_localpart_template=True,
+ backchannel_logout_enabled=False,
+ )
+ ]
+ }
+ )
+ def test_disabled(self) -> None:
+ """
+ Receiving a logout token should do nothing if it is disabled in the config
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ user = "john"
+
+ login_resp, grant = self.helper.login_via_oidc(
+ fake_oidc_server, user, with_sid=True
+ )
+ access_token: str = login_resp["access_token"]
+ self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+ # Logging out shouldn't work
+ logout_token = fake_oidc_server.generate_logout_token(grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 400)
+
+ # And the token should still be valid
+ self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ id="oidc",
+ with_localpart_template=True,
+ backchannel_logout_enabled=True,
+ )
+ ]
+ }
+ )
+ def test_no_sid(self) -> None:
+ """
+ Receiving a logout token without `sid` during the login should do nothing
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ user = "john"
+
+ login_resp, grant = self.helper.login_via_oidc(
+ fake_oidc_server, user, with_sid=False
+ )
+ access_token: str = login_resp["access_token"]
+ self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+ # Logging out shouldn't work
+ logout_token = fake_oidc_server.generate_logout_token(grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 400)
+
+ # And the token should still be valid
+ self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ "first",
+ issuer="https://first-issuer.com/",
+ with_localpart_template=True,
+ backchannel_logout_enabled=True,
+ ),
+ oidc_config(
+ "second",
+ issuer="https://second-issuer.com/",
+ with_localpart_template=True,
+ backchannel_logout_enabled=True,
+ ),
+ ]
+ }
+ )
+ def test_multiple_providers(self) -> None:
+ """
+ It should be able to distinguish login tokens from two different IdPs
+ """
+ first_server = self.helper.fake_oidc_server(issuer="https://first-issuer.com/")
+ second_server = self.helper.fake_oidc_server(
+ issuer="https://second-issuer.com/"
+ )
+ user = "john"
+
+ login_resp, first_grant = self.helper.login_via_oidc(
+ first_server, user, with_sid=True, idp_id="oidc-first"
+ )
+ first_access_token: str = login_resp["access_token"]
+ self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
+
+ login_resp, second_grant = self.helper.login_via_oidc(
+ second_server, user, with_sid=True, idp_id="oidc-second"
+ )
+ second_access_token: str = login_resp["access_token"]
+ self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+ # `sid` in the fake providers are generated by a counter, so the first grant of
+ # each provider should give the same SID
+ self.assertEqual(first_grant.sid, second_grant.sid)
+ self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
+
+ # Logging out of the first session
+ logout_token = first_server.generate_logout_token(first_grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
+ self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+ # Logging out of the second session
+ logout_token = second_server.generate_logout_token(second_grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ self.helper.whoami(second_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index aa98222434..d80eea17d3 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -200,3 +200,37 @@ class DevicesTestCase(unittest.HomeserverTestCase):
self.reactor.advance(43200)
self.get_success(self.handler.get_device(user_id, "abc"))
self.get_failure(self.handler.get_device(user_id, "def"), NotFoundError)
+
+
+class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ register.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def test_PUT(self) -> None:
+ """Sanity-check that we can PUT a dehydrated device.
+
+ Detects https://github.com/matrix-org/synapse/issues/14334.
+ """
+ alice = self.register_user("alice", "correcthorse")
+ token = self.login(alice, "correcthorse")
+
+ # Have alice update their device list
+ channel = self.make_request(
+ "PUT",
+ "_matrix/client/unstable/org.matrix.msc2697.v2/dehydrated_device",
+ {
+ "device_data": {
+ "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
+ "account": "dehydrated_device",
+ }
+ },
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+ device_id = channel.json_body.get("device_id")
+ self.assertIsInstance(device_id, str)
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 823e8ab8c4..afc8d641be 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -43,7 +43,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.EXAMPLE_FILTER_JSON,
)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"filter_id": "0"})
filter = self.get_success(
self.store.get_user_filter(user_localpart="apple", filter_id=0)
@@ -58,7 +58,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.EXAMPLE_FILTER_JSON,
)
- self.assertEqual(channel.result["code"], b"403")
+ self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_add_filter_non_local_user(self) -> None:
@@ -71,7 +71,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
)
self.hs.is_mine = _is_mine
- self.assertEqual(channel.result["code"], b"403")
+ self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_get_filter(self) -> None:
@@ -85,7 +85,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)
def test_get_filter_non_existant(self) -> None:
@@ -93,7 +93,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
)
- self.assertEqual(channel.result["code"], b"404")
+ self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode
@@ -103,7 +103,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
)
- self.assertEqual(channel.result["code"], b"400")
+ self.assertEqual(channel.code, 400)
# No ID also returns an invalid_id error
def test_get_filter_no_id(self) -> None:
@@ -111,4 +111,4 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
)
- self.assertEqual(channel.result["code"], b"400")
+ self.assertEqual(channel.code, 400)
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index dc17c9d113..b0c8215744 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -25,7 +25,6 @@ from tests import unittest
class IdentityTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -33,7 +32,6 @@ class IdentityTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
@@ -54,6 +52,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
"id_server": "testis",
"medium": "email",
"address": "test@example.com",
+ "id_access_token": tok,
}
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
channel = self.make_request(
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index bbc8e74243..741fecea77 100644
--- a/tests/rest/client/test_keys.py
+++ b/tests/rest/client/test_keys.py
@@ -19,6 +19,7 @@ from synapse.rest import admin
from synapse.rest.client import keys, login
from tests import unittest
+from tests.http.server._base import make_request_with_cancellation_test
class KeyQueryTestCase(unittest.HomeserverTestCase):
@@ -89,3 +90,31 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
Codes.BAD_JSON,
channel.result,
)
+
+ def test_key_query_cancellation(self) -> None:
+ """
+ Tests that /keys/query is cancellable and does not swallow the
+ CancelledError.
+ """
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+
+ bob = self.register_user("bob", "uncle")
+
+ channel = make_request_with_cancellation_test(
+ "test_key_query_cancellation",
+ self.reactor,
+ self.site,
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ # Empty list means we request keys for all bob's devices
+ bob: [],
+ },
+ },
+ token=alice_token,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertIn(bob, channel.json_body["device_keys"])
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index a2958f6959..ff5baa9f0a 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -13,7 +13,6 @@
# limitations under the License.
import time
import urllib.parse
-from http import HTTPStatus
from typing import Any, Dict, List, Optional
from unittest.mock import Mock
from urllib.parse import urlencode
@@ -24,6 +23,8 @@ from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
import synapse.rest.admin
+from synapse.api.constants import ApprovalNoticeMedium, LoginType
+from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client import devices, login, logout, register
from synapse.rest.client.account import WhoamiRestServlet
@@ -35,7 +36,7 @@ from synapse.util import Clock
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2
-from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless
@@ -95,6 +96,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
logout.register_servlets,
devices.register_servlets,
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
+ register.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -134,10 +136,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
@@ -152,7 +154,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config(
{
@@ -179,10 +181,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
@@ -197,7 +199,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config(
{
@@ -224,10 +226,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
@@ -242,7 +244,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
@override_config({"session_lifetime": "24h"})
def test_soft_logout(self) -> None:
@@ -250,7 +252,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we shouldn't be able to make requests without an access token
channel = self.make_request(b"GET", TEST_URL)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")
# log in as normal
@@ -261,20 +263,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
access_token = channel.json_body["access_token"]
device_id = channel.json_body["device_id"]
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@@ -288,7 +290,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# more requests with the expired token should still return a soft-logout
self.reactor.advance(3600)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@@ -296,7 +298,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self._delete_device(access_token_2, "kermit", "monkey", device_id)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], False)
@@ -307,7 +309,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
b"DELETE", "devices/" + device_id, access_token=access_token
)
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
# check it's a UI-Auth fail
self.assertEqual(
set(channel.json_body.keys()),
@@ -330,7 +332,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
access_token=access_token,
content={"auth": auth},
)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
@@ -341,20 +343,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
# Now try to hard logout this session
channel = self.make_request(b"POST", "/logout", access_token=access_token)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
@@ -367,20 +369,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
# Now try to hard log out all of the user's sessions
channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_with_overly_long_device_id_fails(self) -> None:
self.register_user("mickey", "cheese")
@@ -407,6 +409,44 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400)
self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": True,
+ }
+ }
+ }
+ )
+ def test_require_approval(self) -> None:
+ channel = self.make_request(
+ "POST",
+ "register",
+ {
+ "username": "kermit",
+ "password": "monkey",
+ "auth": {"type": LoginType.DUMMY},
+ },
+ )
+ self.assertEqual(403, channel.code, channel.result)
+ self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+ self.assertEqual(
+ ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+ )
+
+ params = {
+ "type": LoginType.PASSWORD,
+ "identifier": {"type": "m.id.user", "user": "kermit"},
+ "password": "monkey",
+ }
+ channel = self.make_request("POST", LOGIN_URL, params)
+ self.assertEqual(403, channel.code, channel.result)
+ self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+ self.assertEqual(
+ ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+ )
+
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):
@@ -466,7 +506,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_get_login_flows(self) -> None:
"""GET /login should return password and SSO flows"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
expected_flow_types = [
"m.login.cas",
@@ -494,14 +534,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"""/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker
channel = self._make_sso_redirect_request(None)
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+ self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
uri = location_headers[0]
# hitting that picker should give us some HTML
channel = self.make_request("GET", uri)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
# parse the form to check it has fields assumed elsewhere in this class
html = channel.result["body"].decode("utf-8")
@@ -530,7 +570,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=cas",
shorthand=False,
)
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+ self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
cas_uri = location_headers[0]
@@ -555,7 +595,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=saml",
)
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+ self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
saml_uri = location_headers[0]
@@ -572,21 +612,24 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_login_via_oidc(self) -> None:
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
- # pick the default OIDC provider
- channel = self.make_request(
- "GET",
- "/_synapse/client/pick_idp?redirectUrl="
- + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
- + "&idp=oidc",
- )
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+ fake_oidc_server = self.helper.fake_oidc_server()
+
+ with fake_oidc_server.patch_homeserver(hs=self.hs):
+ # pick the default OIDC provider
+ channel = self.make_request(
+ "GET",
+ "/_synapse/client/pick_idp?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ + "&idp=oidc",
+ )
+ self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
oidc_uri = location_headers[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
# it should redirect us to the auth page of the OIDC server
- self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+ self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint)
# ... and should have set a cookie including the redirect url
cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
@@ -603,10 +646,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
TEST_CLIENT_REDIRECT_URL,
)
- channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
+ channel, _ = self.helper.complete_oidc_auth(
+ fake_oidc_server, oidc_uri, cookies, {"sub": "user1"}
+ )
# that should serve a confirmation page
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
content_type_headers = channel.headers.getRawHeaders("Content-Type")
assert content_type_headers
self.assertTrue(content_type_headers[-1].startswith("text/html"))
@@ -634,7 +679,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"/login",
content={"type": "m.login.token", "token": login_token},
)
- self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
+ self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_to_unknown(self) -> None:
@@ -643,25 +688,28 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"GET",
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
)
- self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(channel.code, 400, channel.result)
def test_client_idp_redirect_to_unknown(self) -> None:
"""If the client tries to pick an unknown IdP, return a 404"""
channel = self._make_sso_redirect_request("xxx")
- self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_client_idp_redirect_to_oidc(self) -> None:
"""If the client pick a known IdP, redirect to it"""
- channel = self._make_sso_redirect_request("oidc")
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+ fake_oidc_server = self.helper.fake_oidc_server()
+
+ with fake_oidc_server.patch_homeserver(hs=self.hs):
+ channel = self._make_sso_redirect_request("oidc")
+ self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
oidc_uri = location_headers[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
# it should redirect us to the auth page of the OIDC server
- self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+ self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint)
def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel:
"""Send a request to /_matrix/client/r0/login/sso/redirect
@@ -765,7 +813,7 @@ class CASTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", cas_ticket_url)
# Test that the response is HTML.
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
content_type_header_value = ""
for header in channel.result.get("headers", []):
if header[0] == b"Content-Type":
@@ -878,17 +926,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_valid_registered(self) -> None:
self.register_user("kermit", "monkey")
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_valid_unregistered(self) -> None:
channel = self.jwt_login({"sub": "frog"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, "notsecret")
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -897,7 +945,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_expired(self) -> None:
channel = self.jwt_login({"sub": "frog", "exp": 864000})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -907,7 +955,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_not_before(self) -> None:
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -916,7 +964,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_no_sub(self) -> None:
channel = self.jwt_login({"username": "root"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
@@ -925,12 +973,12 @@ class JWTTestCase(unittest.HomeserverTestCase):
"""Test validating the issuer claim."""
# A valid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -939,7 +987,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
# Not providing an issuer.
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -949,7 +997,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_iss_no_config(self) -> None:
"""Test providing an issuer claim without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
@@ -957,12 +1005,12 @@ class JWTTestCase(unittest.HomeserverTestCase):
"""Test validating the audience claim."""
# A valid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -971,7 +1019,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
# Not providing an audience.
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -981,7 +1029,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_aud_no_config(self) -> None:
"""Test providing an audience without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -991,20 +1039,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_default_sub(self) -> None:
"""Test reading user ID from the default subject claim."""
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
def test_login_custom_sub(self) -> None:
"""Test reading user ID from a custom subject claim."""
channel = self.jwt_login({"username": "frog"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_no_token(self) -> None:
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -1086,12 +1134,12 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def test_login_jwt_valid(self) -> None:
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -1152,7 +1200,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_appservice_user_bot(self) -> None:
"""Test that the appservice bot can use /login"""
@@ -1166,7 +1214,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_appservice_wrong_user(self) -> None:
"""Test that non-as users cannot login with the as token"""
@@ -1180,7 +1228,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
def test_login_appservice_wrong_as(self) -> None:
"""Test that as users cannot login with wrong as token"""
@@ -1194,7 +1242,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.another_service.token
)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
def test_login_appservice_no_token(self) -> None:
"""Test that users must provide a token when using the appservice
@@ -1208,7 +1256,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
@skip_unless(HAS_OIDC, "requires OIDC")
@@ -1240,13 +1288,17 @@ class UsernamePickerTestCase(HomeserverTestCase):
def test_username_picker(self) -> None:
"""Test the happy path of a username picker flow."""
+ fake_oidc_server = self.helper.fake_oidc_server()
+
# do the start of the login flow
- channel = self.helper.auth_via_oidc(
- {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
+ channel, _ = self.helper.auth_via_oidc(
+ fake_oidc_server,
+ {"sub": "tester", "displayname": "Jonny"},
+ TEST_CLIENT_REDIRECT_URL,
)
# that should redirect to the username picker
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+ self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
picker_url = location_headers[0]
@@ -1290,7 +1342,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
("Content-Length", str(len(content))),
],
)
- self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
+ self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
assert location_headers
@@ -1300,7 +1352,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
path=location_headers[0],
custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
)
- self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
+ self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
assert location_headers
@@ -1325,5 +1377,5 @@ class UsernamePickerTestCase(HomeserverTestCase):
"/login",
content={"type": "m.login.token", "token": login_token},
)
- self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
+ self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
new file mode 100644
index 0000000000..c2e1e08811
--- /dev/null
+++ b/tests/rest/client/test_login_token_request.py
@@ -0,0 +1,134 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest import admin
+from synapse.rest.client import login, login_token_request
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.unittest import override_config
+
+endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
+
+
+class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ login.register_servlets,
+ admin.register_servlets,
+ login_token_request.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.hs = self.setup_test_homeserver()
+ self.hs.config.registration.enable_registration = True
+ self.hs.config.registration.registrations_require_3pid = []
+ self.hs.config.registration.auto_join_rooms = []
+ self.hs.config.captcha.enable_registration_captcha = False
+
+ return self.hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.user = "user123"
+ self.password = "password"
+
+ def test_disabled(self) -> None:
+ channel = self.make_request("POST", endpoint, {}, access_token=None)
+ self.assertEqual(channel.code, 400)
+
+ self.register_user(self.user, self.password)
+ token = self.login(self.user, self.password)
+
+ channel = self.make_request("POST", endpoint, {}, access_token=token)
+ self.assertEqual(channel.code, 400)
+
+ @override_config({"experimental_features": {"msc3882_enabled": True}})
+ def test_require_auth(self) -> None:
+ channel = self.make_request("POST", endpoint, {}, access_token=None)
+ self.assertEqual(channel.code, 401)
+
+ @override_config({"experimental_features": {"msc3882_enabled": True}})
+ def test_uia_on(self) -> None:
+ user_id = self.register_user(self.user, self.password)
+ token = self.login(self.user, self.password)
+
+ channel = self.make_request("POST", endpoint, {}, access_token=token)
+ self.assertEqual(channel.code, 401)
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ session = channel.json_body["session"]
+
+ uia = {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": self.user},
+ "password": self.password,
+ "session": session,
+ },
+ }
+
+ channel = self.make_request("POST", endpoint, uia, access_token=token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["expires_in"], 300)
+
+ login_token = channel.json_body["login_token"]
+
+ channel = self.make_request(
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.json_body["user_id"], user_id)
+
+ @override_config(
+ {"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}}
+ )
+ def test_uia_off(self) -> None:
+ user_id = self.register_user(self.user, self.password)
+ token = self.login(self.user, self.password)
+
+ channel = self.make_request("POST", endpoint, {}, access_token=token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["expires_in"], 300)
+
+ login_token = channel.json_body["login_token"]
+
+ channel = self.make_request(
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.json_body["user_id"], user_id)
+
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3882_enabled": True,
+ "msc3882_ui_auth": False,
+ "msc3882_token_timeout": "15s",
+ }
+ }
+ )
+ def test_expires_in(self) -> None:
+ self.register_user(self.user, self.password)
+ token = self.login(self.user, self.password)
+
+ channel = self.make_request("POST", endpoint, {}, access_token=token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["expires_in"], 15)
diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py
new file mode 100644
index 0000000000..0b8fcb0c47
--- /dev/null
+++ b/tests/rest/client/test_models.py
@@ -0,0 +1,76 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest as stdlib_unittest
+
+from pydantic import BaseModel, ValidationError
+from typing_extensions import Literal
+
+from synapse.rest.client.models import EmailRequestTokenBody
+
+
+class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase):
+ class Model(BaseModel):
+ medium: Literal["email", "msisdn"]
+
+ def test_accepts_valid_medium_string(self) -> None:
+ """Sanity check that Pydantic behaves sensibly with an enum-of-str
+
+ This is arguably more of a test of a class that inherits from str and Enum
+ simultaneously.
+ """
+ model = self.Model.parse_obj({"medium": "email"})
+ self.assertEqual(model.medium, "email")
+
+ def test_rejects_invalid_medium_value(self) -> None:
+ with self.assertRaises(ValidationError):
+ self.Model.parse_obj({"medium": "interpretive_dance"})
+
+ def test_rejects_invalid_medium_type(self) -> None:
+ with self.assertRaises(ValidationError):
+ self.Model.parse_obj({"medium": 123})
+
+
+class EmailRequestTokenBodyTestCase(stdlib_unittest.TestCase):
+ base_request = {
+ "client_secret": "hunter2",
+ "email": "alice@wonderland.com",
+ "send_attempt": 1,
+ }
+
+ def test_token_required_if_id_server_provided(self) -> None:
+ with self.assertRaises(ValidationError):
+ EmailRequestTokenBody.parse_obj(
+ {
+ **self.base_request,
+ "id_server": "identity.wonderland.com",
+ }
+ )
+ with self.assertRaises(ValidationError):
+ EmailRequestTokenBody.parse_obj(
+ {
+ **self.base_request,
+ "id_server": "identity.wonderland.com",
+ "id_access_token": None,
+ }
+ )
+
+ def test_token_typechecked_when_id_server_provided(self) -> None:
+ with self.assertRaises(ValidationError):
+ EmailRequestTokenBody.parse_obj(
+ {
+ **self.base_request,
+ "id_server": "identity.wonderland.com",
+ "id_access_token": 1337,
+ }
+ )
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index 7401b5e0c0..5dfe44defb 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -11,17 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List
+from typing import List, Optional
from twisted.test.proto_helpers import MemoryReactor
+from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
class RedactionsTestCase(HomeserverTestCase):
@@ -67,7 +68,12 @@ class RedactionsTestCase(HomeserverTestCase):
)
def _redact_event(
- self, access_token: str, room_id: str, event_id: str, expect_code: int = 200
+ self,
+ access_token: str,
+ room_id: str,
+ event_id: str,
+ expect_code: int = 200,
+ with_relations: Optional[List[str]] = None,
) -> JsonDict:
"""Helper function to send a redaction event.
@@ -75,13 +81,19 @@ class RedactionsTestCase(HomeserverTestCase):
"""
path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
- channel = self.make_request("POST", path, content={}, access_token=access_token)
- self.assertEqual(int(channel.result["code"]), expect_code)
+ request_content = {}
+ if with_relations:
+ request_content["org.matrix.msc3912.with_relations"] = with_relations
+
+ channel = self.make_request(
+ "POST", path, request_content, access_token=access_token
+ )
+ self.assertEqual(channel.code, expect_code)
return channel.json_body
def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]:
channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
room_sync = channel.json_body["rooms"]["join"][room_id]
return room_sync["timeline"]["events"]
@@ -201,3 +213,256 @@ class RedactionsTestCase(HomeserverTestCase):
# These should all succeed, even though this would be denied by
# the standard message ratelimiter
self._redact_event(self.mod_access_token, self.room_id, msg_id)
+
+ @override_config({"experimental_features": {"msc3912_enabled": True}})
+ def test_redact_relations(self) -> None:
+ """Tests that we can redact the relations of an event at the same time as the
+ event itself.
+ """
+ # Send a root event.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "hello"},
+ tok=self.mod_access_token,
+ )
+ root_event_id = res["event_id"]
+
+ # Send an edit to this root event.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "body": " * hello world",
+ "m.new_content": {
+ "body": "hello world",
+ "msgtype": "m.text",
+ },
+ "m.relates_to": {
+ "event_id": root_event_id,
+ "rel_type": RelationTypes.REPLACE,
+ },
+ "msgtype": "m.text",
+ },
+ tok=self.mod_access_token,
+ )
+ edit_event_id = res["event_id"]
+
+ # Also send a threaded message whose root is the same as the edit's.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "message 1",
+ "m.relates_to": {
+ "event_id": root_event_id,
+ "rel_type": RelationTypes.THREAD,
+ },
+ },
+ tok=self.mod_access_token,
+ )
+ threaded_event_id = res["event_id"]
+
+ # Also send a reaction, again with the same root.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Reaction,
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.ANNOTATION,
+ "event_id": root_event_id,
+ "key": "👍",
+ }
+ },
+ tok=self.mod_access_token,
+ )
+ reaction_event_id = res["event_id"]
+
+ # Redact the root event, specifying that we also want to delete events that
+ # relate to it with m.replace.
+ self._redact_event(
+ self.mod_access_token,
+ self.room_id,
+ root_event_id,
+ with_relations=[
+ RelationTypes.REPLACE,
+ RelationTypes.THREAD,
+ ],
+ )
+
+ # Check that the root event got redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, root_event_id, self.mod_access_token
+ )
+ self.assertIn("redacted_because", event_dict, event_dict)
+
+ # Check that the edit got redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, edit_event_id, self.mod_access_token
+ )
+ self.assertIn("redacted_because", event_dict, event_dict)
+
+ # Check that the threaded message got redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, threaded_event_id, self.mod_access_token
+ )
+ self.assertIn("redacted_because", event_dict, event_dict)
+
+ # Check that the reaction did not get redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, reaction_event_id, self.mod_access_token
+ )
+ self.assertNotIn("redacted_because", event_dict, event_dict)
+
+ @override_config({"experimental_features": {"msc3912_enabled": True}})
+ def test_redact_relations_no_perms(self) -> None:
+ """Tests that, when redacting a message along with its relations, if not all
+ the related messages can be redacted because of insufficient permissions, the
+ server still redacts all the ones that can be.
+ """
+ # Send a root event.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "root",
+ },
+ tok=self.other_access_token,
+ )
+ root_event_id = res["event_id"]
+
+ # Send a first threaded message, this one from the moderator. We do this for the
+ # first message with the m.thread relation (and not the last one) to ensure
+ # that, when the server fails to redact it, it doesn't stop there, and it
+ # instead goes on to redact the other one.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "message 1",
+ "m.relates_to": {
+ "event_id": root_event_id,
+ "rel_type": RelationTypes.THREAD,
+ },
+ },
+ tok=self.mod_access_token,
+ )
+ first_threaded_event_id = res["event_id"]
+
+ # Send a second threaded message, this time from the user who'll perform the
+ # redaction.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "message 2",
+ "m.relates_to": {
+ "event_id": root_event_id,
+ "rel_type": RelationTypes.THREAD,
+ },
+ },
+ tok=self.other_access_token,
+ )
+ second_threaded_event_id = res["event_id"]
+
+ # Redact the thread's root, and request that all threaded messages are also
+ # redacted. Send that request from the non-mod user, so that the first threaded
+ # event cannot be redacted.
+ self._redact_event(
+ self.other_access_token,
+ self.room_id,
+ root_event_id,
+ with_relations=[RelationTypes.THREAD],
+ )
+
+ # Check that the thread root got redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, root_event_id, self.other_access_token
+ )
+ self.assertIn("redacted_because", event_dict, event_dict)
+
+ # Check that the last message in the thread got redacted, despite failing to
+ # redact the one before it.
+ event_dict = self.helper.get_event(
+ self.room_id, second_threaded_event_id, self.other_access_token
+ )
+ self.assertIn("redacted_because", event_dict, event_dict)
+
+ # Check that the message that was sent into the tread by the mod user is not
+ # redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, first_threaded_event_id, self.other_access_token
+ )
+ self.assertIn("body", event_dict["content"], event_dict)
+ self.assertEqual("message 1", event_dict["content"]["body"])
+
+ @override_config({"experimental_features": {"msc3912_enabled": True}})
+ def test_redact_relations_txn_id_reuse(self) -> None:
+ """Tests that redacting a message using a transaction ID, then reusing the same
+ transaction ID but providing an additional list of relations to redact, is
+ effectively a no-op.
+ """
+ # Send a root event.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "root",
+ },
+ tok=self.mod_access_token,
+ )
+ root_event_id = res["event_id"]
+
+ # Send a first threaded message.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "I'm in a thread!",
+ "m.relates_to": {
+ "event_id": root_event_id,
+ "rel_type": RelationTypes.THREAD,
+ },
+ },
+ tok=self.mod_access_token,
+ )
+ threaded_event_id = res["event_id"]
+
+ # Send a first redaction request which redacts only the root event.
+ channel = self.make_request(
+ method="PUT",
+ path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo",
+ content={},
+ access_token=self.mod_access_token,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Send a second redaction request which redacts the root event as well as
+ # threaded messages.
+ channel = self.make_request(
+ method="PUT",
+ path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo",
+ content={"org.matrix.msc3912.with_relations": [RelationTypes.THREAD]},
+ access_token=self.mod_access_token,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Check that the root event got redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, root_event_id, self.mod_access_token
+ )
+ self.assertIn("redacted_because", event_dict)
+
+ # Check that the threaded message didn't get redacted (since that wasn't part of
+ # the original redaction).
+ event_dict = self.helper.get_event(
+ self.room_id, threaded_event_id, self.mod_access_token
+ )
+ self.assertIn("body", event_dict["content"], event_dict)
+ self.assertEqual("I'm in a thread!", event_dict["content"]["body"])
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index f8e64ce6ac..11cf3939d8 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -22,7 +22,11 @@ import pkg_resources
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
+from synapse.api.constants import (
+ APP_SERVICE_REGISTRATION_TYPE,
+ ApprovalNoticeMedium,
+ LoginType,
+)
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client import account, account_validity, login, logout, register, sync
@@ -70,7 +74,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
det_data = {"user_id": user_id, "home_server": self.hs.hostname}
self.assertDictContainsSubset(det_data, channel.json_body)
@@ -91,7 +95,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.assertEqual(channel.result["code"], b"400", channel.result)
+ self.assertEqual(channel.code, 400, msg=channel.result)
def test_POST_appservice_registration_invalid(self) -> None:
self.appservice = None # no application service exists
@@ -100,20 +104,20 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
def test_POST_bad_password(self) -> None:
request_data = {"username": "kermit", "password": 666}
channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"400", channel.result)
+ self.assertEqual(channel.code, 400, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Invalid password")
def test_POST_bad_username(self) -> None:
request_data = {"username": 777, "password": "monkey"}
channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"400", channel.result)
+ self.assertEqual(channel.code, 400, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Invalid username")
def test_POST_user_valid(self) -> None:
@@ -132,7 +136,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"home_server": self.hs.hostname,
"device_id": device_id,
}
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
@override_config({"enable_registration": False})
@@ -142,7 +146,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Registration has been disabled")
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -153,7 +157,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_disabled_guest_registration(self) -> None:
@@ -161,7 +165,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Guest access is disabled")
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
@@ -171,16 +175,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", url, b"{}")
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self) -> None:
@@ -194,16 +198,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url, request_data)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"registration_requires_token": True})
def test_POST_registration_requires_token(self) -> None:
@@ -231,7 +235,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Request without auth to get flows and session
channel = self.make_request(b"POST", self.url, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
# Synapse adds a dummy stage to differentiate flows where otherwise one
# flow would be a subset of another flow.
@@ -248,7 +252,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session,
}
channel = self.make_request(b"POST", self.url, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@@ -263,7 +267,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"home_server": self.hs.hostname,
"device_id": device_id,
}
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
# Check the `completed` counter has been incremented and pending is 0
@@ -293,21 +297,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session,
}
channel = self.make_request(b"POST", self.url, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
self.assertEqual(channel.json_body["completed"], [])
# Test with non-string (invalid)
params["auth"]["token"] = 1234
channel = self.make_request(b"POST", self.url, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
self.assertEqual(channel.json_body["completed"], [])
# Test with unknown token (invalid)
params["auth"]["token"] = "1234"
channel = self.make_request(b"POST", self.url, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -361,7 +365,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session2,
}
channel = self.make_request(b"POST", self.url, params2)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -381,7 +385,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Check auth still fails when using token with session2
channel = self.make_request(b"POST", self.url, params2)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -415,7 +419,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session,
}
channel = self.make_request(b"POST", self.url, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -570,7 +574,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_advertised_flows(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
# with the stock config, we only expect the dummy flow
@@ -586,14 +590,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"require_at_registration": True,
},
"account_threepid_delegates": {
- "email": "https://id_server",
"msisdn": "https://id_server",
},
+ "email": {"notif_from": "Synapse <synapse@example.com>"},
}
)
def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
self.assertCountEqual(
@@ -625,7 +629,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
def test_advertised_flows_no_msisdn_email_required(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
# with the stock config, we expect all four combinations of 3pid
@@ -765,6 +769,32 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE)
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": True,
+ }
+ }
+ }
+ )
+ def test_require_approval(self) -> None:
+ channel = self.make_request(
+ "POST",
+ "register",
+ {
+ "username": "kermit",
+ "password": "monkey",
+ "auth": {"type": LoginType.DUMMY},
+ },
+ )
+ self.assertEqual(403, channel.code, channel.result)
+ self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+ self.assertEqual(
+ ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+ )
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
@@ -797,13 +827,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
# endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)
@@ -823,12 +853,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/account_validity/validity"
request_data = {"user_id": user_id}
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_manual_expire(self) -> None:
user_id = self.register_user("kermit", "monkey")
@@ -844,12 +874,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
"enable_renewal_emails": False,
}
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)
@@ -868,18 +898,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
"enable_renewal_emails": False,
}
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Try to log the user out
channel = self.make_request(b"POST", "/logout", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Log the user in again (allowed for expired accounts)
tok = self.login("kermit", "monkey")
# Try to log out all of the user's sessions
channel = self.make_request(b"POST", "/logout/all", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
@@ -954,7 +984,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
channel = self.make_request(b"GET", url)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -972,7 +1002,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# Move 1 day forward. Try to renew with the same token again.
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
channel = self.make_request(b"GET", url)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -992,14 +1022,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# succeed.
self.reactor.advance(datetime.timedelta(days=3).total_seconds())
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_renewal_invalid_token(self) -> None:
# Hit the renewal endpoint with an invalid token and check that it behaves as
# expected, i.e. that it responds with 404 Not Found and the correct HTML.
url = "/_matrix/client/unstable/account_validity/renew?token=123"
channel = self.make_request(b"GET", url)
- self.assertEqual(channel.result["code"], b"404", channel.result)
+ self.assertEqual(channel.code, 404, msg=channel.result)
# Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -1023,7 +1053,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(len(self.email_attempts), 1)
@@ -1096,7 +1126,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(len(self.email_attempts), 1)
@@ -1176,7 +1206,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET",
f"{self.url}?token={token}",
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["valid"], True)
def test_GET_token_invalid(self) -> None:
@@ -1185,7 +1215,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET",
f"{self.url}?token={token}",
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["valid"], False)
@override_config(
@@ -1201,10 +1231,10 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
@@ -1212,4 +1242,4 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET",
f"{self.url}?token={token}",
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index d589f07314..b86f341ff5 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -654,6 +654,14 @@ class RelationsTestCase(BaseRelationsTestCase):
)
# We also expect to get the original event (the id of which is self.parent_id)
+ # when requesting the unstable endpoint.
+ self.assertNotIn("original_event", channel.json_body)
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
self.assertEqual(
channel.json_body["original_event"]["event_id"], self.parent_id
)
@@ -728,7 +736,6 @@ class RelationsTestCase(BaseRelationsTestCase):
class RelationPaginationTestCase(BaseRelationsTestCase):
- @unittest.override_config({"experimental_features": {"msc3715_enabled": True}})
def test_basic_paginate_relations(self) -> None:
"""Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@@ -756,11 +763,6 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
channel.json_body["chunk"][0],
)
- # We also expect to get the original event (the id of which is self.parent_id)
- self.assertEqual(
- channel.json_body["original_event"]["event_id"], self.parent_id
- )
-
# Make sure next_batch has something in it that looks like it could be a
# valid token.
self.assertIsInstance(
@@ -771,7 +773,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/_matrix/client/v1/rooms/{self.room}/relations"
- f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
+ f"/{self.parent_id}?limit=1&dir=f",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -809,7 +811,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
- f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+ f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=3{from_token}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -827,6 +829,32 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
found_event_ids.reverse()
self.assertEqual(found_event_ids, expected_event_ids)
+ # Test forward pagination.
+ prev_token = ""
+ found_event_ids = []
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?dir=f&limit=3{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ self.assertEqual(found_event_ids, expected_event_ids)
+
def test_pagination_from_sync_and_messages(self) -> None:
"""Pagination tokens from /sync and /messages can be used to paginate /relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
@@ -999,7 +1027,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6)
+ self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
def test_annotation_to_annotation(self) -> None:
"""Any relation to an annotation should be ignored."""
@@ -1035,7 +1063,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
def test_thread(self) -> None:
"""
@@ -1080,21 +1108,21 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated.
- self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
+ self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7)
# The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated.
#
# Note that this re-uses some cached values, so the total number of
# queries is much smaller.
self._test_bundled_aggregations(
- RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
+ RelationTypes.THREAD, _gen_assert(True), 3, access_token=self.user2_token
)
# A user with no interactions with the thread: they have not participated.
user3_id, user3_token = self._create_user("charlie")
self.helper.join(self.room, user=user3_id, tok=user3_token)
self._test_bundled_aggregations(
- RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
+ RelationTypes.THREAD, _gen_assert(False), 3, access_token=user3_token
)
def test_thread_with_bundled_aggregations_for_latest(self) -> None:
@@ -1142,7 +1170,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations["latest_event"].get("unsigned"),
)
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7)
def test_nested_thread(self) -> None:
"""
@@ -1495,6 +1523,26 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
)
self.assertEqual(200, channel.code, channel.json_body)
+ def _get_threads(self) -> List[Tuple[str, str]]:
+ """Request the threads in the room and returns a list of thread ID and latest event ID."""
+ # Request the threads in the room.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ threads = channel.json_body["chunk"]
+ return [
+ (
+ t["event_id"],
+ t["unsigned"]["m.relations"][RelationTypes.THREAD]["latest_event"][
+ "event_id"
+ ],
+ )
+ for t in threads
+ ]
+
def test_redact_relation_annotation(self) -> None:
"""
Test that annotations of an event are properly handled after the
@@ -1539,58 +1587,82 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
The redacted event should not be included in bundled aggregations or
the response to relations.
"""
- channel = self._send_relation(
- RelationTypes.THREAD,
- EventTypes.Message,
- content={"body": "reply 1", "msgtype": "m.text"},
- )
- unredacted_event_id = channel.json_body["event_id"]
+ # Create a thread with a few events in it.
+ thread_replies = []
+ for i in range(3):
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ content={"body": f"reply {i}", "msgtype": "m.text"},
+ )
+ thread_replies.append(channel.json_body["event_id"])
- # Note that the *last* event in the thread is redacted, as that gets
- # included in the bundled aggregation.
- channel = self._send_relation(
- RelationTypes.THREAD,
- EventTypes.Message,
- content={"body": "reply 2", "msgtype": "m.text"},
+ ##################################################
+ # Check the test data is configured as expected. #
+ ##################################################
+ self.assertEquals(self._get_related_events(), list(reversed(thread_replies)))
+ relations = self._get_bundled_aggregations()
+ self.assertDictContainsSubset(
+ {"count": 3, "current_user_participated": True},
+ relations[RelationTypes.THREAD],
+ )
+ # The latest event is the last sent event.
+ self.assertEqual(
+ relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+ thread_replies[-1],
)
- to_redact_event_id = channel.json_body["event_id"]
- # Both relations exist.
- event_ids = self._get_related_events()
+ # There should be one thread, the latest event is the event that will be redacted.
+ self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])])
+
+ ##########################
+ # Redact the last event. #
+ ##########################
+ self._redact(thread_replies.pop())
+
+ # The thread should still exist, but the latest event should be updated.
+ self.assertEquals(self._get_related_events(), list(reversed(thread_replies)))
relations = self._get_bundled_aggregations()
- self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id])
self.assertDictContainsSubset(
- {
- "count": 2,
- "current_user_participated": True,
- },
+ {"count": 2, "current_user_participated": True},
relations[RelationTypes.THREAD],
)
- # And the latest event returned is the event that will be redacted.
+ # And the latest event is the last unredacted event.
self.assertEqual(
relations[RelationTypes.THREAD]["latest_event"]["event_id"],
- to_redact_event_id,
+ thread_replies[-1],
)
+ self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])])
- # Redact one of the reactions.
- self._redact(to_redact_event_id)
+ ###########################################
+ # Redact the *first* event in the thread. #
+ ###########################################
+ self._redact(thread_replies.pop(0))
- # The unredacted relation should still exist.
- event_ids = self._get_related_events()
+ # Nothing should have changed (except the thread count).
+ self.assertEquals(self._get_related_events(), thread_replies)
relations = self._get_bundled_aggregations()
- self.assertEquals(event_ids, [unredacted_event_id])
self.assertDictContainsSubset(
- {
- "count": 1,
- "current_user_participated": True,
- },
+ {"count": 1, "current_user_participated": True},
relations[RelationTypes.THREAD],
)
- # And the latest event is now the unredacted event.
+ # And the latest event is the last unredacted event.
self.assertEqual(
relations[RelationTypes.THREAD]["latest_event"]["event_id"],
- unredacted_event_id,
+ thread_replies[-1],
)
+ self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])])
+
+ ####################################
+ # Redact the last remaining event. #
+ ####################################
+ self._redact(thread_replies.pop(0))
+ self.assertEquals(thread_replies, [])
+
+ # The event should no longer be considered a thread.
+ self.assertEquals(self._get_related_events(), [])
+ self.assertEquals(self._get_bundled_aggregations(), {})
+ self.assertEqual(self._get_threads(), [])
def test_redact_parent_edit(self) -> None:
"""Test that edits of an event are redacted when the original event
@@ -1649,7 +1721,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
{"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
)
- @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_redact_parent_thread(self) -> None:
"""
Test that thread replies are still available when the root event is redacted.
@@ -1679,3 +1750,165 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
relations[RelationTypes.THREAD]["latest_event"]["event_id"],
related_event_id,
)
+
+
+class ThreadsTestCase(BaseRelationsTestCase):
+ def _get_threads(self, body: JsonDict) -> List[Tuple[str, str]]:
+ return [
+ (
+ ev["event_id"],
+ ev["unsigned"]["m.relations"]["m.thread"]["latest_event"]["event_id"],
+ )
+ for ev in body["chunk"]
+ ]
+
+ def test_threads(self) -> None:
+ """Create threads and ensure the ordering is due to their latest event."""
+ # Create 2 threads.
+ thread_1 = self.parent_id
+ res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
+ thread_2 = res["event_id"]
+
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ reply_1 = channel.json_body["event_id"]
+ channel = self._send_relation(
+ RelationTypes.THREAD, "m.room.test", parent_id=thread_2
+ )
+ reply_2 = channel.json_body["event_id"]
+
+ # Request the threads in the room.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ threads = self._get_threads(channel.json_body)
+ self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)])
+
+ # Update the first thread, the ordering should swap.
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ reply_3 = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ # Tuple of (thread ID, latest event ID) for each thread.
+ threads = self._get_threads(channel.json_body)
+ self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)])
+
+ def test_pagination(self) -> None:
+ """Create threads and paginate through them."""
+ # Create 2 threads.
+ thread_1 = self.parent_id
+ res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
+ thread_2 = res["event_id"]
+
+ self._send_relation(RelationTypes.THREAD, "m.room.test")
+ self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+ # Request the threads in the room.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+ self.assertEqual(thread_roots, [thread_2])
+
+ # Make sure next_batch has something in it that looks like it could be a
+ # valid token.
+ next_batch = channel.json_body.get("next_batch")
+ self.assertIsInstance(next_batch, str, channel.json_body)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+ self.assertEqual(thread_roots, [thread_1], channel.json_body)
+
+ self.assertNotIn("next_batch", channel.json_body, channel.json_body)
+
+ def test_include(self) -> None:
+ """Filtering threads to all or participated in should work."""
+ # Thread 1 has the user as the root event.
+ thread_1 = self.parent_id
+ self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
+
+ # Thread 2 has the user replying.
+ res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
+ thread_2 = res["event_id"]
+ self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+ # Thread 3 has the user not participating in.
+ res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token)
+ thread_3 = res["event_id"]
+ self._send_relation(
+ RelationTypes.THREAD,
+ "m.room.test",
+ access_token=self.user2_token,
+ parent_id=thread_3,
+ )
+
+ # All threads in the room.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+ self.assertEqual(
+ thread_roots, [thread_3, thread_2, thread_1], channel.json_body
+ )
+
+ # Only participated threads.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+ self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)
+
+ def test_ignored_user(self) -> None:
+ """Events from ignored users should be ignored."""
+ # Thread 1 has a reply from an ignored user.
+ thread_1 = self.parent_id
+ self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
+
+ # Thread 2 is created by an ignored user.
+ res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
+ thread_2 = res["event_id"]
+ self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+ # Ignore user2.
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.user_id,
+ AccountDataTypes.IGNORED_USER_LIST,
+ {"ignored_users": {self.user2_id: {}}},
+ )
+ )
+
+ # Only thread 1 is returned.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+ self.assertEqual(thread_roots, [thread_1], channel.json_body)
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
new file mode 100644
index 0000000000..ad00a476e1
--- /dev/null
+++ b/tests/rest/client/test_rendezvous.py
@@ -0,0 +1,45 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest.client import rendezvous
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.unittest import override_config
+
+endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous"
+
+
+class RendezvousServletTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ rendezvous.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.hs = self.setup_test_homeserver()
+ return self.hs
+
+ def test_disabled(self) -> None:
+ channel = self.make_request("POST", endpoint, {}, access_token=None)
+ self.assertEqual(channel.code, 400)
+
+ @override_config({"experimental_features": {"msc3886_endpoint": "/asd"}})
+ def test_redirect(self) -> None:
+ channel = self.make_request("POST", endpoint, {}, access_token=None)
+ self.assertEqual(channel.code, 307)
+ self.assertEqual(channel.headers.getRawHeaders("Location"), ["/asd"])
diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index ad0d0209f7..7cb1017a4a 100644
--- a/tests/rest/client/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
@@ -77,6 +77,4 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", self.report_path, data, access_token=self.other_user_tok
)
- self.assertEqual(
- response_status, int(channel.result["code"]), msg=channel.result["body"]
- )
+ self.assertEqual(response_status, channel.code, msg=channel.result["body"])
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index ac9c113354..9c8c1889d3 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from synapse.visibility import filter_events_for_client
@@ -188,7 +188,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
message_handler.get_room_data(
- self.user_id, room_id, EventTypes.Create, state_key=""
+ create_requester(self.user_id), room_id, EventTypes.Create, state_key=""
)
)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index aa2f578441..e919e089cb 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -20,7 +20,7 @@
import json
from http import HTTPStatus
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
-from unittest.mock import Mock, call
+from unittest.mock import Mock, call, patch
from urllib import parse as urlparse
from parameterized import param, parameterized
@@ -35,13 +35,15 @@ from synapse.api.constants import (
EventTypes,
Membership,
PublicRoomsFilterFields,
- RelationTypes,
RoomTypes,
)
from synapse.api.errors import Codes, HttpResponseException
+from synapse.appservice import ApplicationService
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
-from synapse.rest.client import account, directory, login, profile, room, sync
+from synapse.rest.client import account, directory, login, profile, register, room, sync
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, UserID, create_requester
from synapse.util import Clock
@@ -49,7 +51,10 @@ from synapse.util.stringutils import random_string
from tests import unittest
from tests.http.server._base import make_request_with_cancellation_test
+from tests.storage.test_stream import PaginationTestCase
from tests.test_utils import make_awaitable
+from tests.test_utils.event_injection import create_event
+from tests.unittest import override_config
PATH_PREFIX = b"/_matrix/client/api/v1"
@@ -710,7 +715,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(44, channel.resource_usage.db_txn_count)
+ self.assertEqual(33, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@@ -723,7 +728,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(50, channel.resource_usage.db_txn_count)
+ self.assertEqual(36, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
@@ -867,6 +872,41 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
self.assertEqual(join_mock.call_count, 0)
+ def _create_basic_room(self) -> Tuple[int, object]:
+ """
+ Tries to create a basic room and returns the response code.
+ """
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ {},
+ )
+ return channel.code, channel.json_body
+
+ @override_config(
+ {
+ "rc_message": {"per_second": 0.2, "burst_count": 10},
+ }
+ )
+ def test_room_creation_ratelimiting(self) -> None:
+ """
+ Regression test for #14312, where ratelimiting was made too strict.
+ Clients should be able to create 10 rooms in a row
+ without hitting rate limits, using default rate limit config.
+ (We override rate limiting config back to its default value.)
+
+ To ensure we don't make ratelimiting too generous accidentally,
+ also check that we can't create an 11th room.
+ """
+
+ for _ in range(10):
+ code, json_body = self._create_basic_room()
+ self.assertEqual(code, HTTPStatus.OK, json_body)
+
+ # The 6th room hits the rate limit.
+ code, json_body = self._create_basic_room()
+ self.assertEqual(code, HTTPStatus.TOO_MANY_REQUESTS, json_body)
+
class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events."""
@@ -1252,6 +1292,120 @@ class RoomJoinTestCase(RoomBase):
)
+class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ room.register_servlets,
+ synapse.rest.admin.register_servlets,
+ register.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.appservice_user, _ = self.register_appservice_user(
+ "as_user_potato", self.appservice.token
+ )
+
+ # Create a room as the appservice user.
+ args = {
+ "access_token": self.appservice.token,
+ "user_id": self.appservice_user,
+ }
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/createRoom?{urlparse.urlencode(args)}",
+ content={"visibility": "public"},
+ )
+
+ assert channel.code == 200
+ self.room = channel.json_body["room_id"]
+
+ self.main_store = self.hs.get_datastores().main
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+
+ self.appservice = ApplicationService(
+ token="i_am_an_app_service",
+ id="1234",
+ namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ # Note: this user does not have to match the regex above
+ sender="@as_main:test",
+ )
+
+ mock_load_appservices = Mock(return_value=[self.appservice])
+ with patch(
+ "synapse.storage.databases.main.appservice.load_appservices",
+ mock_load_appservices,
+ ):
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def test_send_event_ts(self) -> None:
+ """Test sending a non-state event with a custom timestamp."""
+ ts = 1
+
+ url_params = {
+ "user_id": self.appservice_user,
+ "ts": ts,
+ }
+ channel = self.make_request(
+ "PUT",
+ path=f"/_matrix/client/r0/rooms/{self.room}/send/m.room.message/1234?"
+ + urlparse.urlencode(url_params),
+ content={"body": "test", "msgtype": "m.text"},
+ access_token=self.appservice.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ event_id = channel.json_body["event_id"]
+
+ # Ensure the event was persisted with the correct timestamp.
+ res = self.get_success(self.main_store.get_event(event_id))
+ self.assertEquals(ts, res.origin_server_ts)
+
+ def test_send_state_event_ts(self) -> None:
+ """Test sending a state event with a custom timestamp."""
+ ts = 1
+
+ url_params = {
+ "user_id": self.appservice_user,
+ "ts": ts,
+ }
+ channel = self.make_request(
+ "PUT",
+ path=f"/_matrix/client/r0/rooms/{self.room}/state/m.room.name?"
+ + urlparse.urlencode(url_params),
+ content={"name": "test"},
+ access_token=self.appservice.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ event_id = channel.json_body["event_id"]
+
+ # Ensure the event was persisted with the correct timestamp.
+ res = self.get_success(self.main_store.get_event(event_id))
+ self.assertEquals(ts, res.origin_server_ts)
+
+ def test_send_membership_event_ts(self) -> None:
+ """Test sending a membership event with a custom timestamp."""
+ ts = 1
+
+ url_params = {
+ "user_id": self.appservice_user,
+ "ts": ts,
+ }
+ channel = self.make_request(
+ "PUT",
+ path=f"/_matrix/client/r0/rooms/{self.room}/state/m.room.member/{self.appservice_user}?"
+ + urlparse.urlencode(url_params),
+ content={"membership": "join", "display_name": "test"},
+ access_token=self.appservice.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ event_id = channel.json_body["event_id"]
+
+ # Ensure the event was persisted with the correct timestamp.
+ res = self.get_success(self.main_store.get_event(event_id))
+ self.assertEquals(ts, res.origin_server_ts)
+
+
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
@@ -1272,10 +1426,22 @@ class RoomJoinRatelimitTestCase(RoomBase):
)
def test_join_local_ratelimit(self) -> None:
"""Tests that local joins are actually rate-limited."""
- for _ in range(3):
- self.helper.create_room_as(self.user_id)
+ # Create 4 rooms
+ room_ids = [
+ self.helper.create_room_as(self.user_id, is_public=True) for _ in range(4)
+ ]
- self.helper.create_room_as(self.user_id, expect_code=429)
+ joiner_user_id = self.register_user("joiner", "secret")
+ # Now make a new user try to join some of them.
+
+ # The user can join 3 rooms
+ for room_id in room_ids[0:3]:
+ self.helper.join(room_id, joiner_user_id)
+
+ # But the user cannot join a 4th room
+ self.helper.join(
+ room_ids[3], joiner_user_id, expect_code=HTTPStatus.TOO_MANY_REQUESTS
+ )
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
@@ -2098,14 +2264,17 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
)
def make_public_rooms_request(
- self, room_types: Union[List[Union[str, None]], None]
+ self,
+ room_types: Optional[List[Union[str, None]]],
+ instance_id: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
- channel = self.make_request(
- "POST",
- self.url,
- {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}},
- self.token,
- )
+ body: JsonDict = {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}}
+ if instance_id:
+ body["third_party_instance_id"] = "test|test"
+
+ channel = self.make_request("POST", self.url, body, self.token)
+ self.assertEqual(channel.code, 200)
+
chunk = channel.json_body["chunk"]
count = channel.json_body["total_room_count_estimate"]
@@ -2115,31 +2284,49 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None:
chunk, count = self.make_public_rooms_request(None)
-
self.assertEqual(count, 2)
+ # Also check if there's no filter property at all in the body.
+ channel = self.make_request("POST", self.url, {}, self.token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["chunk"]), 2)
+ self.assertEqual(channel.json_body["total_room_count_estimate"], 2)
+
+ chunk, count = self.make_public_rooms_request(None, "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_only_rooms_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request([None])
self.assertEqual(count, 1)
self.assertEqual(chunk[0].get("room_type", None), None)
+ chunk, count = self.make_public_rooms_request([None], "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_only_space_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request(["m.space"])
self.assertEqual(count, 1)
self.assertEqual(chunk[0].get("room_type", None), "m.space")
+ chunk, count = self.make_public_rooms_request(["m.space"], "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request(["m.space", None])
-
self.assertEqual(count, 2)
+ chunk, count = self.make_public_rooms_request(["m.space", None], "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None:
chunk, count = self.make_public_rooms_request([])
-
self.assertEqual(count, 2)
+ chunk, count = self.make_public_rooms_request([], "test|test")
+ self.assertEqual(count, 0)
+
class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
"""Test that we correctly fallback to local filtering if a remote server
@@ -2779,149 +2966,20 @@ class LabelsTestCase(unittest.HomeserverTestCase):
return event_id
-class RelationsTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- room.register_servlets,
- login.register_servlets,
- ]
-
- def default_config(self) -> Dict[str, Any]:
- config = super().default_config()
- config["experimental_features"] = {"msc3440_enabled": True}
- return config
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.user_id = self.register_user("test", "test")
- self.tok = self.login("test", "test")
- self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
-
- self.second_user_id = self.register_user("second", "test")
- self.second_tok = self.login("second", "test")
- self.helper.join(
- room=self.room_id, user=self.second_user_id, tok=self.second_tok
- )
-
- self.third_user_id = self.register_user("third", "test")
- self.third_tok = self.login("third", "test")
- self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
-
- # An initial event with a relation from second user.
- res = self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={"msgtype": "m.text", "body": "Message 1"},
- tok=self.tok,
- )
- self.event_id_1 = res["event_id"]
- self.helper.send_event(
- room_id=self.room_id,
- type="m.reaction",
- content={
- "m.relates_to": {
- "rel_type": RelationTypes.ANNOTATION,
- "event_id": self.event_id_1,
- "key": "👍",
- }
- },
- tok=self.second_tok,
- )
-
- # Another event with a relation from third user.
- res = self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={"msgtype": "m.text", "body": "Message 2"},
- tok=self.tok,
- )
- self.event_id_2 = res["event_id"]
- self.helper.send_event(
- room_id=self.room_id,
- type="m.reaction",
- content={
- "m.relates_to": {
- "rel_type": RelationTypes.REFERENCE,
- "event_id": self.event_id_2,
- }
- },
- tok=self.third_tok,
- )
-
- # An event with no relations.
- self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={"msgtype": "m.text", "body": "No relations"},
- tok=self.tok,
- )
-
- def _filter_messages(self, filter: JsonDict) -> List[JsonDict]:
+class RelationsTestCase(PaginationTestCase):
+ def _filter_messages(self, filter: JsonDict) -> List[str]:
"""Make a request to /messages with a filter, returns the chunk of events."""
+ from_token = self.get_success(
+ self.from_token.to_string(self.hs.get_datastores().main)
+ )
channel = self.make_request(
"GET",
- "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)),
+ f"/rooms/{self.room_id}/messages?filter={json.dumps(filter)}&dir=f&from={from_token}",
access_token=self.tok,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
- return channel.json_body["chunk"]
-
- def test_filter_relation_senders(self) -> None:
- # Messages which second user reacted to.
- filter = {"related_by_senders": [self.second_user_id]}
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0]["event_id"], self.event_id_1)
-
- # Messages which third user reacted to.
- filter = {"related_by_senders": [self.third_user_id]}
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0]["event_id"], self.event_id_2)
-
- # Messages which either user reacted to.
- filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 2, chunk)
- self.assertCountEqual(
- [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
- )
-
- def test_filter_relation_type(self) -> None:
- # Messages which have annotations.
- filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0]["event_id"], self.event_id_1)
-
- # Messages which have references.
- filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0]["event_id"], self.event_id_2)
-
- # Messages which have either annotations or references.
- filter = {
- "related_by_rel_types": [
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- ]
- }
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 2, chunk)
- self.assertCountEqual(
- [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
- )
-
- def test_filter_relation_senders_and_type(self) -> None:
- # Messages which second user reacted to.
- filter = {
- "related_by_senders": [self.second_user_id],
- "related_by_rel_types": [RelationTypes.ANNOTATION],
- }
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+ return [ev["event_id"] for ev in channel.json_body["chunk"]]
class ContextTestCase(unittest.HomeserverTestCase):
@@ -3461,3 +3519,83 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Also check that it stopped before calling _make_and_store_3pid_invite.
make_invite_mock.assert_called_once()
+
+ def test_400_missing_param_without_id_access_token(self) -> None:
+ """
+ Test that a 3pid invite request returns 400 M_MISSING_PARAM
+ if we do not include id_access_token.
+ """
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "medium": "email",
+ "address": "teresa@example.com",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 400)
+ self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM")
+
+
+class TimestampLookupTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = {"msc3030_enabled": True}
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self._storage_controllers = self.hs.get_storage_controllers()
+
+ self.room_owner = self.register_user("room_owner", "test")
+ self.room_owner_tok = self.login("room_owner", "test")
+
+ def _inject_outlier(self, room_id: str) -> EventBase:
+ event, _context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=room_id,
+ type="m.test",
+ sender="@test_remote_user:remote",
+ )
+ )
+
+ event.internal_metadata.outlier = True
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(
+ event, EventContext.for_outlier(self._storage_controllers)
+ )
+ )
+ return event
+
+ def test_no_outliers(self) -> None:
+ """
+ Test to make sure `/timestamp_to_event` does not return `outlier` events.
+ We're unable to determine whether an `outlier` is next to a gap so we
+ don't know whether it's actually the closest event. Instead, let's just
+ ignore `outliers` with this endpoint.
+
+ This test is really seeing that we choose the non-`outlier` event behind the
+ `outlier`. Since the gap checking logic considers the latest message in the room
+ as *not* next to a gap, asking over federation does not come into play here.
+ """
+ room_id = self.helper.create_room_as(self.room_owner, tok=self.room_owner_tok)
+
+ outlier_event = self._inject_outlier(room_id)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/org.matrix.msc3030/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}",
+ access_token=self.room_owner_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+ # Make sure the outlier event is not returned
+ self.assertNotEqual(channel.json_body["event_id"], outlier_event.event_id)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index d9bd8c4a28..c807a37bc2 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -26,7 +26,7 @@ from synapse.rest.client import (
room_upgrade_rest_servlet,
)
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -97,7 +97,12 @@ class RoomTestCase(_ShadowBannedBase):
channel = self.make_request(
"POST",
"/rooms/%s/invite" % (room_id,),
- {"id_server": "test", "medium": "email", "address": "test@test.test"},
+ {
+ "id_server": "test",
+ "medium": "email",
+ "address": "test@test.test",
+ "id_access_token": "anytoken",
+ },
access_token=self.banned_access_token,
)
self.assertEqual(200, channel.code, channel.result)
@@ -275,7 +280,7 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id,
+ create_requester(self.banned_user_id),
room_id,
"m.room.member",
self.banned_user_id,
@@ -310,7 +315,7 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id,
+ create_requester(self.banned_user_id),
room_id,
"m.room.member",
self.banned_user_id,
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index b085c50356..0af643ecd9 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -38,7 +38,6 @@ from tests.federation.transport.test_knocking import (
KnockingStrippedStateEventHelperMixin,
)
from tests.server import TimedOutException
-from tests.unittest import override_config
class FilterTestCase(unittest.HomeserverTestCase):
@@ -390,6 +389,11 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
]
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+
+ return self.setup_test_homeserver(config=config)
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.url = "/sync?since=%s"
self.next_batch = "s0"
@@ -408,7 +412,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Join the second user
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
- @override_config({"experimental_features": {"msc2285_enabled": True}})
def test_private_read_receipts(self) -> None:
# Send a message as the first user
res = self.helper.send(self.room_id, body="hello", tok=self.tok)
@@ -416,7 +419,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Send a private read receipt to tell the server the first user's message was read
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
{},
access_token=self.tok2,
)
@@ -425,7 +428,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that the first user can't see the other user's private read receipt
self.assertIsNone(self._get_read_receipt())
- @override_config({"experimental_features": {"msc2285_enabled": True}})
def test_public_receipt_can_override_private(self) -> None:
"""
Sending a public read receipt to the same event which has a private read
@@ -456,7 +458,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that we did override the private read receipt
self.assertNotEqual(self._get_read_receipt(), None)
- @override_config({"experimental_features": {"msc2285_enabled": True}})
def test_private_receipt_cannot_override_public(self) -> None:
"""
Sending a private read receipt to the same event which has a public read
@@ -543,7 +544,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
config = super().default_config()
config["experimental_features"] = {
"msc2654_enabled": True,
- "msc2285_enabled": True,
}
return config
@@ -624,7 +624,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Send a read receipt to tell the server we've read the latest event.
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
{},
access_token=self.tok,
)
@@ -700,7 +700,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(5)
res2 = self.helper.send(self.room_id, "hello", tok=self.tok2)
- # Make sure both m.read and org.matrix.msc2285.read.private advance
+ # Make sure both m.read and m.read.private advance
channel = self.make_request(
"POST",
f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}",
@@ -712,16 +712,21 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res2['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}",
{},
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
self._check_unread_count(0)
- # We test for both receipt types that influence notification counts
- @parameterized.expand([ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE])
- def test_read_receipts_only_go_down(self, receipt_type: ReceiptTypes) -> None:
+ # We test for all three receipt types that influence notification counts
+ @parameterized.expand(
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ]
+ )
+ def test_read_receipts_only_go_down(self, receipt_type: str) -> None:
# Join the new user
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
@@ -732,18 +737,18 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Read last event
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}",
{},
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
self._check_unread_count(0)
- # Make sure neither m.read nor org.matrix.msc2285.read.private make the
+ # Make sure neither m.read nor m.read.private make the
# read receipt go up to an older event
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res1['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res1['event_id']}",
{},
access_token=self.tok,
)
@@ -948,3 +953,24 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase):
self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"])
self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"])
+
+ def test_incremental_sync(self) -> None:
+ """Tests that activity in the room is properly filtered out of incremental
+ syncs.
+ """
+ channel = self.make_request("GET", "/sync", access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.result)
+ next_batch = channel.json_body["next_batch"]
+
+ self.helper.send(self.excluded_room_id, tok=self.tok)
+ self.helper.send(self.included_room_id, tok=self.tok)
+
+ channel = self.make_request(
+ "GET",
+ f"/sync?since={next_batch}",
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"])
+ self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"])
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 1083391b41..2e1b4753dc 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -156,7 +156,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
callback.assert_called_once()
@@ -174,7 +174,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, channel.result)
def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None:
"""
@@ -212,7 +212,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
access_token=self.tok,
)
# Check the error code
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, channel.result)
# Check the JSON body has had the `nasty` key injected
self.assertEqual(
channel.json_body,
@@ -261,7 +261,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{"x": "x"},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"]
# ... and check that it got modified
@@ -270,7 +270,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")
@@ -299,7 +299,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
orig_event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -316,7 +316,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
edited_event_id = channel.json_body["event_id"]
# ... and check that they both got modified
@@ -325,7 +325,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["body"], "ORIGINAL BODY")
@@ -334,7 +334,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["body"], "EDITED BODY")
@@ -380,7 +380,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"]
@@ -389,7 +389,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
self.assertIn("foo", channel.json_body["content"].keys())
self.assertEqual(channel.json_body["content"]["foo"], "bar")
diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py
index 61b66d7685..fdc433a8b5 100644
--- a/tests/rest/client/test_typing.py
+++ b/tests/rest/client/test_typing.py
@@ -59,7 +59,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.event_source.get_new_events(
user=UserID.from_string(self.user_id),
from_key=0,
- limit=None,
+ # Limit is unused.
+ limit=0,
room_ids=[self.room_id],
is_guest=False,
)
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 105d418698..8d6f2b6ff9 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -31,7 +31,6 @@ from typing import (
Tuple,
overload,
)
-from unittest.mock import patch
from urllib.parse import urlencode
import attr
@@ -46,8 +45,19 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from tests.server import FakeChannel, FakeSite, make_request
-from tests.test_utils import FakeResponse
from tests.test_utils.html_parsers import TestHtmlParser
+from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
+
+# an 'oidc_config' suitable for login_via_oidc.
+TEST_OIDC_ISSUER = "https://issuer.test/"
+TEST_OIDC_CONFIG = {
+ "enabled": True,
+ "issuer": TEST_OIDC_ISSUER,
+ "client_id": "test-client-id",
+ "client_secret": "test-client-secret",
+ "scopes": ["openid"],
+ "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
@attr.s(auto_attribs=True)
@@ -140,7 +150,7 @@ class RestHelper:
custom_headers=custom_headers,
)
- assert channel.result["code"] == b"%d" % expect_code, channel.result
+ assert channel.code == expect_code, channel.result
self.auth_user_id = temp_id
if expect_code == HTTPStatus.OK:
@@ -213,11 +223,9 @@ class RestHelper:
data,
)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
@@ -312,11 +320,9 @@ class RestHelper:
data,
)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
@@ -396,11 +402,46 @@ class RestHelper:
custom_headers=custom_headers,
)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
+ channel.result["body"],
+ )
+
+ return channel.json_body
+
+ def get_event(
+ self,
+ room_id: str,
+ event_id: str,
+ tok: Optional[str] = None,
+ expect_code: int = HTTPStatus.OK,
+ ) -> JsonDict:
+ """Request a specific event from the server.
+
+ Args:
+ room_id: the room in which the event was sent.
+ event_id: the event's ID.
+ tok: the token to request the event with.
+ expect_code: the expected HTTP status for the response.
+
+ Returns:
+ The event as a dict.
+ """
+ path = f"/_matrix/client/v3/rooms/{room_id}/event/{event_id}"
+ if tok:
+ path = path + f"?access_token={tok}"
+
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ path,
+ )
+
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ channel.code,
channel.result["body"],
)
@@ -449,11 +490,9 @@ class RestHelper:
channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
@@ -545,13 +584,62 @@ class RestHelper:
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
return channel.json_body
- def login_via_oidc(self, remote_user_id: str) -> JsonDict:
+ def whoami(
+ self,
+ access_token: str,
+ expect_code: Literal[HTTPStatus.OK, HTTPStatus.UNAUTHORIZED] = HTTPStatus.OK,
+ ) -> JsonDict:
+ """Perform a 'whoami' request, which can be a quick way to check for access
+ token validity
+
+ Args:
+ access_token: The user token to use during the request
+ expect_code: The return code to expect from attempting the whoami request
+ """
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ "account/whoami",
+ access_token=access_token,
+ )
+
+ assert channel.code == expect_code, "Exepcted: %d, got %d, resp: %r" % (
+ expect_code,
+ channel.code,
+ channel.result["body"],
+ )
+
+ return channel.json_body
+
+ def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer:
+ """Create a ``FakeOidcServer``.
+
+ This can be used in conjuction with ``login_via_oidc``::
+
+ fake_oidc_server = self.helper.fake_oidc_server()
+ login_data, _ = self.helper.login_via_oidc(fake_oidc_server, "user")
+ """
+
+ return FakeOidcServer(
+ clock=self.hs.get_clock(),
+ issuer=issuer,
+ )
+
+ def login_via_oidc(
+ self,
+ fake_server: FakeOidcServer,
+ remote_user_id: str,
+ with_sid: bool = False,
+ idp_id: Optional[str] = None,
+ expected_status: int = 200,
+ ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
"""Log in (as a new user) via OIDC
Returns the result of the final token login.
@@ -564,7 +652,14 @@ class RestHelper:
the normal places.
"""
client_redirect_url = "https://x"
- channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
+ userinfo = {"sub": remote_user_id}
+ channel, grant = self.auth_via_oidc(
+ fake_server,
+ userinfo,
+ client_redirect_url,
+ with_sid=with_sid,
+ idp_id=idp_id,
+ )
# expect a confirmation page
assert channel.code == HTTPStatus.OK, channel.result
@@ -586,15 +681,20 @@ class RestHelper:
"/login",
content={"type": "m.login.token", "token": login_token},
)
- assert channel.code == HTTPStatus.OK
- return channel.json_body
+ assert (
+ channel.code == expected_status
+ ), f"unexpected status in response: {channel.code}"
+ return channel.json_body, grant
def auth_via_oidc(
self,
+ fake_server: FakeOidcServer,
user_info_dict: JsonDict,
client_redirect_url: Optional[str] = None,
ui_auth_session_id: Optional[str] = None,
- ) -> FakeChannel:
+ with_sid: bool = False,
+ idp_id: Optional[str] = None,
+ ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
"""Perform an OIDC authentication flow via a mock OIDC provider.
This can be used for either login or user-interactive auth.
@@ -618,6 +718,8 @@ class RestHelper:
the login redirect endpoint
ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
of the UI auth.
+ with_sid: if True, generates a random `sid` (OIDC session ID)
+ idp_id: if set, explicitely chooses one specific IDP
Returns:
A FakeChannel containing the result of calling the OIDC callback endpoint.
@@ -627,14 +729,17 @@ class RestHelper:
cookies: Dict[str, str] = {}
- # if we're doing a ui auth, hit the ui auth redirect endpoint
- if ui_auth_session_id:
- # can't set the client redirect url for UI Auth
- assert client_redirect_url is None
- oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
- else:
- # otherwise, hit the login redirect endpoint
- oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+ with fake_server.patch_homeserver(hs=self.hs):
+ # if we're doing a ui auth, hit the ui auth redirect endpoint
+ if ui_auth_session_id:
+ # can't set the client redirect url for UI Auth
+ assert client_redirect_url is None
+ oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
+ else:
+ # otherwise, hit the login redirect endpoint
+ oauth_uri = self.initiate_sso_login(
+ client_redirect_url, cookies, idp_id=idp_id
+ )
# we now have a URI for the OIDC IdP, but we skip that and go straight
# back to synapse's OIDC callback resource. However, we do need the "state"
@@ -642,17 +747,21 @@ class RestHelper:
# that synapse passes to the client.
oauth_uri_path, _ = oauth_uri.split("?", 1)
- assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
+ assert oauth_uri_path == fake_server.authorization_endpoint, (
"unexpected SSO URI " + oauth_uri_path
)
- return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+ return self.complete_oidc_auth(
+ fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid
+ )
def complete_oidc_auth(
self,
+ fake_serer: FakeOidcServer,
oauth_uri: str,
cookies: Mapping[str, str],
user_info_dict: JsonDict,
- ) -> FakeChannel:
+ with_sid: bool = False,
+ ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
"""Mock out an OIDC authentication flow
Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
@@ -663,50 +772,37 @@ class RestHelper:
Requires the OIDC callback resource to be mounted at the normal place.
Args:
+ fake_server: the fake OIDC server with which the auth should be done
oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
from initiate_sso_login or initiate_sso_ui_auth).
cookies: the cookies set by synapse's redirect endpoint, which will be
sent back to the callback endpoint.
user_info_dict: the remote userinfo that the OIDC provider should present.
Typically this should be '{"sub": "<remote user id>"}'.
+ with_sid: if True, generates a random `sid` (OIDC session ID)
Returns:
A FakeChannel containing the result of calling the OIDC callback endpoint.
"""
_, oauth_uri_qs = oauth_uri.split("?", 1)
params = urllib.parse.parse_qs(oauth_uri_qs)
+
+ code, grant = fake_serer.start_authorization(
+ scope=params["scope"][0],
+ userinfo=user_info_dict,
+ client_id=params["client_id"][0],
+ redirect_uri=params["redirect_uri"][0],
+ nonce=params["nonce"][0],
+ with_sid=with_sid,
+ )
+ state = params["state"][0]
+
callback_uri = "%s?%s" % (
urllib.parse.urlparse(params["redirect_uri"][0]).path,
- urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
- )
-
- # before we hit the callback uri, stub out some methods in the http client so
- # that we don't have to handle full HTTPS requests.
- # (expected url, json response) pairs, in the order we expect them.
- expected_requests = [
- # first we get a hit to the token endpoint, which we tell to return
- # a dummy OIDC access token
- (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
- # and then one to the user_info endpoint, which returns our remote user id.
- (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
- ]
-
- async def mock_req(
- method: str,
- uri: str,
- data: Optional[dict] = None,
- headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
- ):
- (expected_uri, resp_obj) = expected_requests.pop(0)
- assert uri == expected_uri
- resp = FakeResponse(
- code=HTTPStatus.OK,
- phrase=b"OK",
- body=json.dumps(resp_obj).encode("utf-8"),
- )
- return resp
+ urllib.parse.urlencode({"state": state, "code": code}),
+ )
- with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+ with fake_serer.patch_homeserver(hs=self.hs):
# now hit the callback URI with the right params and a made-up code
channel = make_request(
self.hs.get_reactor(),
@@ -717,10 +813,13 @@ class RestHelper:
("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
],
)
- return channel
+ return channel, grant
def initiate_sso_login(
- self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
+ self,
+ client_redirect_url: Optional[str],
+ cookies: MutableMapping[str, str],
+ idp_id: Optional[str] = None,
) -> str:
"""Make a request to the login-via-sso redirect endpoint, and return the target
@@ -731,6 +830,7 @@ class RestHelper:
client_redirect_url: the client redirect URL to pass to the login redirect
endpoint
cookies: any cookies returned will be added to this dict
+ idp_id: if set, explicitely chooses one specific IDP
Returns:
the URI that the client gets redirected to (ie, the SSO server)
@@ -739,6 +839,12 @@ class RestHelper:
if client_redirect_url:
params["redirectUrl"] = client_redirect_url
+ uri = "/_matrix/client/r0/login/sso/redirect"
+ if idp_id is not None:
+ uri = f"{uri}/{idp_id}"
+
+ uri = f"{uri}?{urllib.parse.urlencode(params)}"
+
# hit the redirect url (which should redirect back to the redirect url. This
# is the easiest way of figuring out what the Host header ought to be set to
# to keep Synapse happy.
@@ -746,7 +852,7 @@ class RestHelper:
self.hs.get_reactor(),
self.site,
"GET",
- "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
+ uri,
)
assert channel.code == 302
@@ -808,21 +914,3 @@ class RestHelper:
assert len(p.links) == 1, "not exactly one link in confirmation page"
oauth_uri = p.links[0]
return oauth_uri
-
-
-# an 'oidc_config' suitable for login_via_oidc.
-TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
-TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
-TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
-TEST_OIDC_CONFIG = {
- "enabled": True,
- "discover": False,
- "issuer": "https://issuer.test",
- "client_id": "test-client-id",
- "client_secret": "test-client-secret",
- "scopes": ["profile"],
- "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
- "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
- "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
- "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
-}
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index ac0ac06b7e..7f1fba1086 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -26,7 +26,7 @@ from twisted.web.resource import NoResource, Resource
from synapse.crypto.keyring import PerspectivesKeyFetcher
from synapse.http.site import SynapseRequest
-from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.key.v2 import KeyResource
from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
@@ -46,7 +46,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
def create_test_resource(self) -> Resource:
return create_resource_tree(
- {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
+ {"/_matrix/key/v2": KeyResource(self.hs)}, root_resource=NoResource()
)
def expect_outgoing_key_request(
diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
index 14af07c5af..23f227aed6 100644
--- a/tests/rest/media/test_media_retention.py
+++ b/tests/rest/media/test_media_retention.py
@@ -13,7 +13,9 @@
# limitations under the License.
import io
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional
+
+from matrix_common.types.mxc_uri import MXCUri
from twisted.test.proto_helpers import MemoryReactor
@@ -63,9 +65,9 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
last_accessed_ms: Optional[int],
is_quarantined: Optional[bool] = False,
is_protected: Optional[bool] = False,
- ) -> str:
+ ) -> MXCUri:
# "Upload" some media to the local media store
- mxc_uri = self.get_success(
+ mxc_uri: MXCUri = self.get_success(
media_repository.create_content(
media_type="text/plain",
upload_name=None,
@@ -75,13 +77,11 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
)
)
- media_id = mxc_uri.split("/")[-1]
-
# Set the last recently accessed time for this media
if last_accessed_ms is not None:
self.get_success(
self.store.update_cached_last_access_time(
- local_media=(media_id,),
+ local_media=(mxc_uri.media_id,),
remote_media=(),
time_ms=last_accessed_ms,
)
@@ -92,7 +92,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
self.get_success(
self.store.quarantine_media_by_id(
server_name=self.hs.config.server.server_name,
- media_id=media_id,
+ media_id=mxc_uri.media_id,
quarantined_by="@theadmin:test",
)
)
@@ -101,18 +101,18 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
# Mark this media as protected from quarantine
self.get_success(
self.store.mark_local_media_as_safe(
- media_id=media_id,
+ media_id=mxc_uri.media_id,
safe=True,
)
)
- return media_id
+ return mxc_uri
def _cache_remote_media_and_set_attributes(
media_id: str,
last_accessed_ms: Optional[int],
is_quarantined: Optional[bool] = False,
- ) -> str:
+ ) -> MXCUri:
# Pretend to cache some remote media
self.get_success(
self.store.store_cached_remote_media(
@@ -146,7 +146,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
)
)
- return media_id
+ return MXCUri(self.remote_server_name, media_id)
# Start with the local media store
self.local_recently_accessed_media = _create_media_and_set_attributes(
@@ -214,28 +214,16 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
# Remote media should be unaffected.
self._assert_if_mxc_uris_purged(
purged=[
- (
- self.hs.config.server.server_name,
- self.local_not_recently_accessed_media,
- ),
- (self.hs.config.server.server_name, self.local_never_accessed_media),
+ self.local_not_recently_accessed_media,
+ self.local_never_accessed_media,
],
not_purged=[
- (self.hs.config.server.server_name, self.local_recently_accessed_media),
- (
- self.hs.config.server.server_name,
- self.local_not_recently_accessed_quarantined_media,
- ),
- (
- self.hs.config.server.server_name,
- self.local_not_recently_accessed_protected_media,
- ),
- (self.remote_server_name, self.remote_recently_accessed_media),
- (self.remote_server_name, self.remote_not_recently_accessed_media),
- (
- self.remote_server_name,
- self.remote_not_recently_accessed_quarantined_media,
- ),
+ self.local_recently_accessed_media,
+ self.local_not_recently_accessed_quarantined_media,
+ self.local_not_recently_accessed_protected_media,
+ self.remote_recently_accessed_media,
+ self.remote_not_recently_accessed_media,
+ self.remote_not_recently_accessed_quarantined_media,
],
)
@@ -261,49 +249,35 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
# Remote media accessed <30 days ago should still exist.
self._assert_if_mxc_uris_purged(
purged=[
- (self.remote_server_name, self.remote_not_recently_accessed_media),
+ self.remote_not_recently_accessed_media,
],
not_purged=[
- (self.remote_server_name, self.remote_recently_accessed_media),
- (self.hs.config.server.server_name, self.local_recently_accessed_media),
- (
- self.hs.config.server.server_name,
- self.local_not_recently_accessed_media,
- ),
- (
- self.hs.config.server.server_name,
- self.local_not_recently_accessed_quarantined_media,
- ),
- (
- self.hs.config.server.server_name,
- self.local_not_recently_accessed_protected_media,
- ),
- (
- self.remote_server_name,
- self.remote_not_recently_accessed_quarantined_media,
- ),
- (self.hs.config.server.server_name, self.local_never_accessed_media),
+ self.remote_recently_accessed_media,
+ self.local_recently_accessed_media,
+ self.local_not_recently_accessed_media,
+ self.local_not_recently_accessed_quarantined_media,
+ self.local_not_recently_accessed_protected_media,
+ self.remote_not_recently_accessed_quarantined_media,
+ self.local_never_accessed_media,
],
)
def _assert_if_mxc_uris_purged(
- self, purged: Iterable[Tuple[str, str]], not_purged: Iterable[Tuple[str, str]]
+ self, purged: Iterable[MXCUri], not_purged: Iterable[MXCUri]
) -> None:
- def _assert_mxc_uri_purge_state(
- server_name: str, media_id: str, expect_purged: bool
- ) -> None:
+ def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None:
"""Given an MXC URI, assert whether it has been purged or not."""
- if server_name == self.hs.config.server.server_name:
+ if mxc_uri.server_name == self.hs.config.server.server_name:
found_media_dict = self.get_success(
- self.store.get_local_media(media_id)
+ self.store.get_local_media(mxc_uri.media_id)
)
else:
found_media_dict = self.get_success(
- self.store.get_cached_remote_media(server_name, media_id)
+ self.store.get_cached_remote_media(
+ mxc_uri.server_name, mxc_uri.media_id
+ )
)
- mxc_uri = f"mxc://{server_name}/{media_id}"
-
if expect_purged:
self.assertIsNone(
found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
@@ -315,7 +289,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
)
# Assert that the given MXC URIs have either been correctly purged or not.
- for server_name, media_id in purged:
- _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=True)
- for server_name, media_id in not_purged:
- _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=False)
+ for mxc_uri in purged:
+ _assert_mxc_uri_purge_state(mxc_uri, expect_purged=True)
+ for mxc_uri in not_purged:
+ _assert_mxc_uri_purge_state(mxc_uri, expect_purged=False)
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py
index f38d7225f8..319ae8b1cc 100644
--- a/tests/rest/media/v1/test_oembed.py
+++ b/tests/rest/media/v1/test_oembed.py
@@ -14,6 +14,8 @@
import json
+from parameterized import parameterized
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult
@@ -23,8 +25,16 @@ from synapse.util import Clock
from tests.unittest import HomeserverTestCase
+try:
+ import lxml
+except ImportError:
+ lxml = None
+
class OEmbedTests(HomeserverTestCase):
+ if not lxml:
+ skip = "url preview feature requires lxml"
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.oembed = OEmbedProvider(hs)
@@ -36,7 +46,7 @@ class OEmbedTests(HomeserverTestCase):
def test_version(self) -> None:
"""Accept versions that are similar to 1.0 as a string or int (or missing)."""
for version in ("1.0", 1.0, 1):
- result = self.parse_response({"version": version, "type": "link"})
+ result = self.parse_response({"version": version})
# An empty Open Graph response is an error, ensure the URL is included.
self.assertIn("og:url", result.open_graph_result)
@@ -49,3 +59,94 @@ class OEmbedTests(HomeserverTestCase):
result = self.parse_response({"version": version, "type": "link"})
# An empty Open Graph response is an error, ensure the URL is included.
self.assertEqual({}, result.open_graph_result)
+
+ def test_cache_age(self) -> None:
+ """Ensure a cache-age is parsed properly."""
+ # Correct-ish cache ages are allowed.
+ for cache_age in ("1", 1.0, 1):
+ result = self.parse_response({"cache_age": cache_age})
+ self.assertEqual(result.cache_age, 1000)
+
+ # Invalid cache ages are ignored.
+ for cache_age in ("invalid", {}):
+ result = self.parse_response({"cache_age": cache_age})
+ self.assertIsNone(result.cache_age)
+
+ # Cache age is optional.
+ result = self.parse_response({})
+ self.assertIsNone(result.cache_age)
+
+ @parameterized.expand(
+ [
+ ("title", "title"),
+ ("provider_name", "site_name"),
+ ("thumbnail_url", "image"),
+ ],
+ name_func=lambda func, num, p: f"{func.__name__}_{p.args[0]}",
+ )
+ def test_property(self, oembed_property: str, open_graph_property: str) -> None:
+ """Test properties which must be strings."""
+ result = self.parse_response({oembed_property: "test"})
+ self.assertIn(f"og:{open_graph_property}", result.open_graph_result)
+ self.assertEqual(result.open_graph_result[f"og:{open_graph_property}"], "test")
+
+ result = self.parse_response({oembed_property: 1})
+ self.assertNotIn(f"og:{open_graph_property}", result.open_graph_result)
+
+ def test_author_name(self) -> None:
+ """Test the author_name property."""
+ result = self.parse_response({"author_name": "test"})
+ self.assertEqual(result.author_name, "test")
+
+ result = self.parse_response({"author_name": 1})
+ self.assertIsNone(result.author_name)
+
+ def test_rich(self) -> None:
+ """Test a type of rich."""
+ result = self.parse_response({"html": "test<img src='foo'>", "type": "rich"})
+ self.assertIn("og:description", result.open_graph_result)
+ self.assertIn("og:image", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:description"], "test")
+ self.assertEqual(result.open_graph_result["og:image"], "foo")
+
+ result = self.parse_response({"type": "rich"})
+ self.assertNotIn("og:description", result.open_graph_result)
+
+ result = self.parse_response({"html": 1, "type": "rich"})
+ self.assertNotIn("og:description", result.open_graph_result)
+
+ def test_photo(self) -> None:
+ """Test a type of photo."""
+ result = self.parse_response({"url": "test", "type": "photo"})
+ self.assertIn("og:image", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:image"], "test")
+
+ result = self.parse_response({"type": "photo"})
+ self.assertNotIn("og:image", result.open_graph_result)
+
+ result = self.parse_response({"url": 1, "type": "photo"})
+ self.assertNotIn("og:image", result.open_graph_result)
+
+ def test_video(self) -> None:
+ """Test a type of video."""
+ result = self.parse_response({"html": "test", "type": "video"})
+ self.assertIn("og:type", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:type"], "video.other")
+ self.assertIn("og:description", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:description"], "test")
+
+ result = self.parse_response({"type": "video"})
+ self.assertIn("og:type", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:type"], "video.other")
+ self.assertNotIn("og:description", result.open_graph_result)
+
+ result = self.parse_response({"url": 1, "type": "video"})
+ self.assertIn("og:type", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:type"], "video.other")
+ self.assertNotIn("og:description", result.open_graph_result)
+
+ def test_link(self) -> None:
+ """Test type of link."""
+ result = self.parse_response({"type": "link"})
+ self.assertIn("og:type", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:type"], "website")
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index da325955f8..c0a2501742 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
-
from synapse.rest.health import HealthResource
from tests import unittest
@@ -26,5 +24,5 @@ class HealthCheckTests(unittest.HomeserverTestCase):
def test_health(self) -> None:
channel = self.make_request("GET", "/health", shorthand=False)
- self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index d8faafec75..2091b08d89 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
-
from twisted.web.resource import Resource
from synapse.rest.well_known import well_known_resource
@@ -38,7 +36,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body,
{
@@ -57,7 +55,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
+ self.assertEqual(channel.code, 404)
@unittest.override_config(
{
@@ -71,7 +69,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body,
{
@@ -87,7 +85,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/server", shorthand=False
)
- self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body,
{"m.server": "test:443"},
@@ -97,4 +95,4 @@ class WellKnownTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "/.well-known/matrix/server", shorthand=False
)
- self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
+ self.assertEqual(channel.code, 404)
diff --git a/tests/server.py b/tests/server.py
index 9689e6a0cd..b1730fcc8d 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -61,6 +61,10 @@ from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
+from synapse.events.presence_router import load_legacy_presence_router
+from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.events.third_party_rules import load_legacy_third_party_event_rules
+from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.http.site import SynapseRequest
from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer
@@ -262,7 +266,12 @@ class FakeSite:
site_tag = "test"
access_logger = logging.getLogger("synapse.access.http.fake")
- def __init__(self, resource: IResource, reactor: IReactorTime):
+ def __init__(
+ self,
+ resource: IResource,
+ reactor: IReactorTime,
+ experimental_cors_msc3886: bool = False,
+ ):
"""
Args:
@@ -270,6 +279,7 @@ class FakeSite:
"""
self._resource = resource
self.reactor = reactor
+ self.experimental_cors_msc3886 = experimental_cors_msc3886
def getResourceFor(self, request):
return self._resource
@@ -352,6 +362,12 @@ def make_request(
# Twisted expects to be at the end of the content when parsing the request.
req.content.seek(0, SEEK_END)
+ # Old version of Twisted (<20.3.0) have issues with parsing x-www-form-urlencoded
+ # bodies if the Content-Length header is missing
+ req.requestHeaders.addRawHeader(
+ b"Content-Length", str(len(content)).encode("ascii")
+ )
+
if access_token:
req.requestHeaders.addRawHeader(
b"Authorization", b"Bearer " + access_token.encode("ascii")
@@ -913,4 +929,14 @@ def setup_test_homeserver(
# Make the threadpool and database transactions synchronous for testing.
_make_test_homeserver_synchronous(hs)
+ # Load any configured modules into the homeserver
+ module_api = hs.get_module_api()
+ for module, config in hs.config.modules.loaded_modules:
+ module(config=config, api=module_api)
+
+ load_legacy_spam_checkers(hs)
+ load_legacy_third_party_event_rules(hs)
+ load_legacy_presence_router(hs)
+ load_legacy_password_auth_providers(hs)
+
return hs
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index e07ae78fc4..7cbc40736c 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -11,16 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from typing import Tuple
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
from synapse.api.errors import ResourceLimitError
from synapse.rest import admin
from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
)
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -52,7 +56,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return config
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.server_notices_sender = self.hs.get_server_notices_sender()
# relying on [1] is far from ideal, but the only case where
@@ -251,7 +255,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
c["admin_contact"] = "mailto:user@test.com"
return c
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.server_notices_sender = self.hs.get_server_notices_sender()
self.server_notices_manager = self.hs.get_server_notices_manager()
@@ -347,14 +351,15 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.assertTrue(notice_in_room, "No server notice in room")
- def _trigger_notice_and_join(self):
+ def _trigger_notice_and_join(self) -> Tuple[str, str, str]:
"""Creates enough active users to hit the MAU limit and trigger a system notice
about it, then joins the system notices room with one of the users created.
Returns:
- user_id (str): The ID of the user that joined the room.
- tok (str): The access token of the user that joined the room.
- room_id (str): The ID of the room that's been joined.
+ A tuple of:
+ user_id: The ID of the user that joined the room.
+ tok: The access token of the user that joined the room.
+ room_id: The ID of the room that's been joined.
"""
user_id = None
tok = None
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 46d829b062..5773172ab8 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -35,66 +35,45 @@ from synapse.util import Clock
from synapse.util.async_helpers import yieldable_gather_results
from tests import unittest
+from tests.test_utils.event_injection import create_event, inject_event
class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
def prepare(self, reactor, clock, hs):
+ self.hs = hs
self.store: EventsWorkerStore = hs.get_datastores().main
- # insert some test data
- for rid in ("room1", "room2"):
- self.get_success(
- self.store.db_pool.simple_insert(
- "rooms",
- {"room_id": rid, "room_version": 4},
- )
- )
+ self.user = self.register_user("user", "pass")
+ self.token = self.login(self.user, "pass")
+ self.room_id = self.helper.create_room_as(self.user, tok=self.token)
self.event_ids: List[str] = []
- for idx, rid in enumerate(
- (
- "room1",
- "room1",
- "room1",
- "room2",
- )
- ):
- event_json = {"type": f"test {idx}", "room_id": rid}
- event = make_event_from_dict(event_json, room_version=RoomVersions.V4)
- event_id = event.event_id
-
- self.get_success(
- self.store.db_pool.simple_insert(
- "events",
- {
- "event_id": event_id,
- "room_id": rid,
- "topological_ordering": idx,
- "stream_ordering": idx,
- "type": event.type,
- "processed": True,
- "outlier": False,
- },
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- "event_json",
- {
- "event_id": event_id,
- "room_id": rid,
- "json": json.dumps(event_json),
- "internal_metadata": "{}",
- "format_version": 3,
- },
+ for i in range(3):
+ event = self.get_success(
+ inject_event(
+ hs,
+ room_version=RoomVersions.V7.identifier,
+ room_id=self.room_id,
+ sender=self.user,
+ type="test_event_type",
+ content={"body": f"foobarbaz{i}"},
)
)
- self.event_ids.append(event_id)
+
+ self.event_ids.append(event.event_id)
def test_simple(self):
with LoggingContext(name="test") as ctx:
res = self.get_success(
- self.store.have_seen_events("room1", [self.event_ids[0], "event19"])
+ self.store.have_seen_events(
+ self.room_id, [self.event_ids[0], "eventdoesnotexist"]
+ )
)
self.assertEqual(res, {self.event_ids[0]})
@@ -104,22 +83,87 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# a second lookup of the same events should cause no queries
with LoggingContext(name="test") as ctx:
res = self.get_success(
- self.store.have_seen_events("room1", [self.event_ids[0], "event19"])
+ self.store.have_seen_events(
+ self.room_id, [self.event_ids[0], "eventdoesnotexist"]
+ )
)
self.assertEqual(res, {self.event_ids[0]})
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
- def test_query_via_event_cache(self):
- # fetch an event into the event cache
- self.get_success(self.store.get_event(self.event_ids[0]))
+ def test_persisting_event_invalidates_cache(self):
+ """
+ Test to make sure that the `have_seen_event` cache
+ is invalidated after we persist an event and returns
+ the updated value.
+ """
+ event, event_context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=self.user,
+ type="test_event_type",
+ content={"body": "garply"},
+ )
+ )
- # looking it up should now cause no db hits
with LoggingContext(name="test") as ctx:
+ # First, check `have_seen_event` for an event we have not seen yet
+ # to prime the cache with a `false` value.
res = self.get_success(
- self.store.have_seen_events("room1", [self.event_ids[0]])
+ self.store.have_seen_events(event.room_id, [event.event_id])
)
- self.assertEqual(res, {self.event_ids[0]})
- self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
+ self.assertEqual(res, set())
+
+ # That should result in a single db query to lookup
+ self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
+ # Persist the event which should invalidate or prefill the
+ # `have_seen_event` cache so we don't return stale values.
+ persistence = self.hs.get_storage_controllers().persistence
+ self.get_success(
+ persistence.persist_event(
+ event,
+ event_context,
+ )
+ )
+
+ with LoggingContext(name="test") as ctx:
+ # Check `have_seen_event` again and we should see the updated fact
+ # that we have now seen the event after persisting it.
+ res = self.get_success(
+ self.store.have_seen_events(event.room_id, [event.event_id])
+ )
+ self.assertEqual(res, {event.event_id})
+
+ # That should result in a single db query to lookup
+ self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
+ def test_invalidate_cache_by_room_id(self):
+ """
+ Test to make sure that all events associated with the given `(room_id,)`
+ are invalidated in the `have_seen_event` cache.
+ """
+ with LoggingContext(name="test") as ctx:
+ # Prime the cache with some values
+ res = self.get_success(
+ self.store.have_seen_events(self.room_id, self.event_ids)
+ )
+ self.assertEqual(res, set(self.event_ids))
+
+ # That should result in a single db query to lookup
+ self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
+ # Clear the cache with any events associated with the `room_id`
+ self.store.have_seen_event.invalidate((self.room_id,))
+
+ with LoggingContext(name="test") as ctx:
+ res = self.get_success(
+ self.store.have_seen_events(self.room_id, self.event_ids)
+ )
+ self.assertEqual(res, set(self.event_ids))
+
+ # Since we cleared the cache, it should result in another db query to lookup
+ self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
class EventCacheTestCase(unittest.HomeserverTestCase):
@@ -254,7 +298,7 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase):
"room_id": self.room_id,
"json": json.dumps(event_json),
"internal_metadata": "{}",
- "format_version": EventFormatVersions.V3,
+ "format_version": EventFormatVersions.ROOM_V4_PLUS,
},
)
)
diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py
new file mode 100644
index 0000000000..c4f12d81d7
--- /dev/null
+++ b/tests/storage/databases/main/test_receipts.py
@@ -0,0 +1,209 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Sequence, Tuple
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.storage.database import LoggingTransaction
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+
+class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ self.store = hs.get_datastores().main
+ self.user_id = self.register_user("foo", "pass")
+ self.token = self.login("foo", "pass")
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ self.other_room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ def _test_background_receipts_unique_index(
+ self,
+ update_name: str,
+ index_name: str,
+ table: str,
+ receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]],
+ expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]],
+ ):
+ """Test that the background update to uniqueify non-thread receipts in
+ the given receipts table works properly.
+
+ Args:
+ update_name: The name of the background update to test.
+ index_name: The name of the index that the background update creates.
+ table: The table of receipts that the background update fixes.
+ receipts: The test data containing duplicate receipts.
+ A list of receipt rows to insert, grouped by
+ `(room_id, receipt_type, user_id)`.
+ expected_unique_receipts: A dictionary of `(room_id, receipt_type, user_id)`
+ keys and expected receipt key-values after duplicate receipts have been
+ removed.
+ """
+ # First, undo the background update.
+ def drop_receipts_unique_index(txn: LoggingTransaction) -> None:
+ txn.execute(f"DROP INDEX IF EXISTS {index_name}")
+
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "drop_receipts_unique_index",
+ drop_receipts_unique_index,
+ )
+ )
+
+ # Populate the receipts table, including duplicates.
+ for (room_id, receipt_type, user_id), rows in receipts.items():
+ for row in rows:
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ table,
+ {
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id,
+ "thread_id": None,
+ "data": "{}",
+ **row,
+ },
+ )
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": update_name,
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ self.store.db_pool.updates._all_done = False
+
+ self.wait_for_background_updates()
+
+ # Check that the remaining receipts match expectations.
+ for (
+ room_id,
+ receipt_type,
+ user_id,
+ ), expected_row in expected_unique_receipts.items():
+ # Include the receipt key in the returned columns, for more informative
+ # assertion messages.
+ columns = ["room_id", "receipt_type", "user_id"]
+ if expected_row is not None:
+ columns += expected_row.keys()
+
+ rows = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table=table,
+ keyvalues={
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id,
+ # `simple_select_onecol` does not support NULL filters,
+ # so skip the filter on `thread_id`.
+ },
+ retcols=columns,
+ desc="get_receipt",
+ )
+ )
+
+ if expected_row is not None:
+ self.assertEqual(
+ len(rows),
+ 1,
+ f"Background update did not leave behind latest receipt in {table}",
+ )
+ self.assertEqual(
+ rows[0],
+ {
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id,
+ **expected_row,
+ },
+ )
+ else:
+ self.assertEqual(
+ len(rows),
+ 0,
+ f"Background update did not remove all duplicate receipts from {table}",
+ )
+
+ def test_background_receipts_linearized_unique_index(self):
+ """Test that the background update to uniqueify non-thread receipts in
+ `receipts_linearized` works properly.
+ """
+ self._test_background_receipts_unique_index(
+ "receipts_linearized_unique_index",
+ "receipts_linearized_unique_index",
+ "receipts_linearized",
+ receipts={
+ (self.room_id, "m.read", self.user_id): [
+ {"stream_id": 5, "event_id": "$some_event"},
+ {"stream_id": 6, "event_id": "$some_event"},
+ ],
+ (self.other_room_id, "m.read", self.user_id): [
+ {"stream_id": 7, "event_id": "$some_event"}
+ ],
+ },
+ expected_unique_receipts={
+ (self.room_id, "m.read", self.user_id): {"stream_id": 6},
+ (self.other_room_id, "m.read", self.user_id): {"stream_id": 7},
+ },
+ )
+
+ def test_background_receipts_graph_unique_index(self):
+ """Test that the background update to uniqueify non-thread receipts in
+ `receipts_graph` works properly.
+ """
+ self._test_background_receipts_unique_index(
+ "receipts_graph_unique_index",
+ "receipts_graph_unique_index",
+ "receipts_graph",
+ receipts={
+ (self.room_id, "m.read", self.user_id): [
+ {
+ "event_ids": '["$some_event"]',
+ },
+ {
+ "event_ids": '["$some_event"]',
+ },
+ ],
+ (self.other_room_id, "m.read", self.user_id): [
+ {
+ "event_ids": '["$some_event"]',
+ }
+ ],
+ },
+ expected_unique_receipts={
+ (self.room_id, "m.read", self.user_id): None,
+ (self.other_room_id, "m.read", self.user_id): {
+ "event_ids": '["$some_event"]'
+ },
+ },
+ )
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index cce8e75c74..40e58f8199 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -54,7 +54,6 @@ class SQLBaseStoreTestCase(unittest.TestCase):
sqlite_config = {"name": "sqlite3"}
engine = create_engine(sqlite_config)
fake_engine = Mock(wraps=engine)
- fake_engine.can_native_upsert = False
fake_engine.in_transaction.return_value = False
db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index f37505b6cf..8e7db2c4ec 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -28,7 +28,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
"""
for device_id in device_ids:
- stream_id = self.get_success(
+ self.get_success(
self.store.add_device_change_to_streams(
user_id, [device_id], ["!some:room"]
)
@@ -39,7 +39,6 @@ class DeviceStoreTestCase(HomeserverTestCase):
user_id=user_id,
device_id=device_id,
room_id="!some:room",
- stream_id=stream_id,
hosts=[host],
context={},
)
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index a0ce077a99..de9f4af2de 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -531,7 +531,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)
self.get_success(
- event_handler.handle_new_client_event(self.requester, event, context)
+ event_handler.handle_new_client_event(
+ self.requester, events_and_context=[(event, context)]
+ )
)
state1 = set(self.get_success(context.get_current_state_ids()).values())
@@ -549,7 +551,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)
self.get_success(
- event_handler.handle_new_client_event(self.requester, event, context)
+ event_handler.handle_new_client_event(
+ self.requester, events_and_context=[(event, context)]
+ )
)
state2 = set(self.get_success(context.get_current_state_ids()).values())
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index d92a9ac5b7..853db930d6 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -12,25 +12,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Tuple, Union
+import datetime
+from typing import Dict, List, Tuple, Union
import attr
from parameterized import parameterized
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventTypes
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersion,
)
from synapse.events import _EventInternalMetadata
-from synapse.util import json_encoder
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.storage.database import LoggingTransaction
+from synapse.types import JsonDict
+from synapse.util import Clock, json_encoder
import tests.unittest
import tests.utils
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class _BackfillSetupInfo:
+ room_id: str
+ depth_map: Dict[str, int]
+
+
class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
def test_get_prev_events_for_room(self):
@@ -513,7 +534,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def prev_event_format(prev_event_id: str) -> Union[Tuple[str, dict], str]:
"""Account for differences in prev_events format across room versions"""
- if room_version.event_format == EventFormatVersions.V1:
+ if room_version.event_format == EventFormatVersions.ROOM_V1_V2:
return prev_event_id, {}
return prev_event_id
@@ -571,11 +592,600 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
self.assertEqual(count, 1)
- _, event_id = self.get_success(
+ next_staged_event_info = self.get_success(
self.store.get_next_staged_event_id_for_room(room_id)
)
+ assert next_staged_event_info
+ _, event_id = next_staged_event_info
self.assertEqual(event_id, "$fake_event_id_500")
+ def _setup_room_for_backfill_tests(self) -> _BackfillSetupInfo:
+ """
+ Sets up a room with various events and backward extremities to test
+ backfill functions against.
+
+ Returns:
+ _BackfillSetupInfo including the `room_id` to test against and
+ `depth_map` of events in the room
+ """
+ room_id = "!backfill-room-test:some-host"
+
+ # The silly graph we use to test grabbing backward extremities,
+ # where the top is the oldest events.
+ # 1 (oldest)
+ # |
+ # 2 ⹁
+ # | \
+ # | [b1, b2, b3]
+ # | |
+ # | A
+ # | /
+ # 3 {
+ # | \
+ # | [b4, b5, b6]
+ # | |
+ # | B
+ # | /
+ # 4 ´
+ # |
+ # 5 (newest)
+
+ event_graph: Dict[str, List[str]] = {
+ "1": [],
+ "2": ["1"],
+ "3": ["2", "A"],
+ "4": ["3", "B"],
+ "5": ["4"],
+ "A": ["b1", "b2", "b3"],
+ "b1": ["2"],
+ "b2": ["2"],
+ "b3": ["2"],
+ "B": ["b4", "b5", "b6"],
+ "b4": ["3"],
+ "b5": ["3"],
+ "b6": ["3"],
+ }
+
+ depth_map: Dict[str, int] = {
+ "1": 1,
+ "2": 2,
+ "b1": 3,
+ "b2": 3,
+ "b3": 3,
+ "A": 4,
+ "3": 5,
+ "b4": 6,
+ "b5": 6,
+ "b6": 6,
+ "B": 7,
+ "4": 8,
+ "5": 9,
+ }
+
+ # The events we have persisted on our server.
+ # The rest are events in the room but not backfilled tet.
+ our_server_events = {"5", "4", "B", "3", "A"}
+
+ complete_event_dict_map: Dict[str, JsonDict] = {}
+ stream_ordering = 0
+ for (event_id, prev_event_ids) in event_graph.items():
+ depth = depth_map[event_id]
+
+ complete_event_dict_map[event_id] = {
+ "event_id": event_id,
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": "@sender",
+ "prev_event_ids": prev_event_ids,
+ "auth_event_ids": [],
+ "origin_server_ts": stream_ordering,
+ "depth": depth,
+ "stream_ordering": stream_ordering,
+ "content": {"body": "event" + event_id},
+ }
+
+ stream_ordering += 1
+
+ def populate_db(txn: LoggingTransaction):
+ # Insert the room to satisfy the foreign key constraint of
+ # `event_failed_pull_attempts`
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ "rooms",
+ {
+ "room_id": room_id,
+ "creator": "room_creator_user_id",
+ "is_public": True,
+ "room_version": "6",
+ },
+ )
+
+ # Insert our server events
+ for event_id in our_server_events:
+ event_dict = complete_event_dict_map[event_id]
+
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "event_id": event_dict.get("event_id"),
+ "type": event_dict.get("type"),
+ "room_id": event_dict.get("room_id"),
+ "depth": event_dict.get("depth"),
+ "topological_ordering": event_dict.get("depth"),
+ "stream_ordering": event_dict.get("stream_ordering"),
+ "processed": True,
+ "outlier": False,
+ },
+ )
+
+ # Insert the event edges
+ for event_id in our_server_events:
+ for prev_event_id in event_graph[event_id]:
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="event_edges",
+ values={
+ "event_id": event_id,
+ "prev_event_id": prev_event_id,
+ "room_id": room_id,
+ },
+ )
+
+ # Insert the backward extremities
+ prev_events_of_our_events = {
+ prev_event_id
+ for our_server_event in our_server_events
+ for prev_event_id in complete_event_dict_map[our_server_event][
+ "prev_event_ids"
+ ]
+ }
+ backward_extremities = prev_events_of_our_events - our_server_events
+ for backward_extremity in backward_extremities:
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="event_backward_extremities",
+ values={
+ "event_id": backward_extremity,
+ "room_id": room_id,
+ },
+ )
+
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "_setup_room_for_backfill_tests_populate_db",
+ populate_db,
+ )
+ )
+
+ return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
+
+ def test_get_backfill_points_in_room(self):
+ """
+ Test to make sure only backfill points that are older and come before
+ the `current_depth` are returned.
+ """
+ setup_info = self._setup_room_for_backfill_tests()
+ room_id = setup_info.room_id
+ depth_map = setup_info.depth_map
+
+ # Try at "B"
+ backfill_points = self.get_success(
+ self.store.get_backfill_points_in_room(room_id, depth_map["B"], limit=100)
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ self.assertEqual(backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"])
+
+ # Try at "A"
+ backfill_points = self.get_success(
+ self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ # Event "2" has a depth of 2 but is not included here because we only
+ # know the approximate depth of 5 from our event "3".
+ self.assertListEqual(backfill_event_ids, ["b3", "b2", "b1"])
+
+ def test_get_backfill_points_in_room_excludes_events_we_have_attempted(
+ self,
+ ):
+ """
+ Test to make sure that events we have attempted to backfill (and within
+ backoff timeout duration) do not show up as an event to backfill again.
+ """
+ setup_info = self._setup_room_for_backfill_tests()
+ room_id = setup_info.room_id
+ depth_map = setup_info.depth_map
+
+ # Record some attempts to backfill these events which will make
+ # `get_backfill_points_in_room` exclude them because we
+ # haven't passed the backoff interval.
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b5", "fake cause")
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b4", "fake cause")
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b3", "fake cause")
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b2", "fake cause")
+ )
+
+ # No time has passed since we attempted to backfill ^
+
+ # Try at "B"
+ backfill_points = self.get_success(
+ self.store.get_backfill_points_in_room(room_id, depth_map["B"], limit=100)
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ # Only the backfill points that we didn't record earlier exist here.
+ self.assertEqual(backfill_event_ids, ["b6", "2", "b1"])
+
+ def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration(
+ self,
+ ):
+ """
+ Test to make sure after we fake attempt to backfill event "b3" many times,
+ we can see retry and see the "b3" again after the backoff timeout duration
+ has exceeded.
+ """
+ setup_info = self._setup_room_for_backfill_tests()
+ room_id = setup_info.room_id
+ depth_map = setup_info.depth_map
+
+ # Record some attempts to backfill these events which will make
+ # `get_backfill_points_in_room` exclude them because we
+ # haven't passed the backoff interval.
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b3", "fake cause")
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
+ )
+
+ # Now advance time by 2 hours and we should only be able to see "b3"
+ # because we have waited long enough for the single attempt (2^1 hours)
+ # but we still shouldn't see "b1" because we haven't waited long enough
+ # for this many attempts. We didn't do anything to "b2" so it should be
+ # visible regardless.
+ self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
+
+ # Try at "A" and make sure that "b1" is not in the list because we've
+ # already attempted many times
+ backfill_points = self.get_success(
+ self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ self.assertEqual(backfill_event_ids, ["b3", "b2"])
+
+ # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and
+ # see if we can now backfill it
+ self.reactor.advance(datetime.timedelta(hours=20).total_seconds())
+
+ # Try at "A" again after we advanced enough time and we should see "b3" again
+ backfill_points = self.get_success(
+ self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ self.assertEqual(backfill_event_ids, ["b3", "b2", "b1"])
+
+ def test_get_backfill_points_in_room_works_after_many_failed_pull_attempts_that_could_naively_overflow(
+ self,
+ ) -> None:
+ """
+ A test that reproduces #13929 (Postgres only).
+
+ Test to make sure we can still get backfill points after many failed pull
+ attempts that cause us to backoff to the limit. Even if the backoff formula
+ would tell us to wait for more seconds than can be expressed in a 32 bit
+ signed int.
+ """
+ setup_info = self._setup_room_for_backfill_tests()
+ room_id = setup_info.room_id
+ depth_map = setup_info.depth_map
+
+ # Pretend that we have tried and failed 10 times to backfill event b1.
+ for _ in range(10):
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
+ )
+
+ # If the backoff periods grow without limit:
+ # After the first failed attempt, we would have backed off for 1 << 1 = 2 hours.
+ # After the second failed attempt we would have backed off for 1 << 2 = 4 hours,
+ # so after the 10th failed attempt we should backoff for 1 << 10 == 1024 hours.
+ # Wait 1100 hours just so we have a nice round number.
+ self.reactor.advance(datetime.timedelta(hours=1100).total_seconds())
+
+ # 1024 hours in milliseconds is 1024 * 3600000, which exceeds the largest 32 bit
+ # signed integer. The bug we're reproducing is that this overflow causes an
+ # error in postgres preventing us from fetching a set of backwards extremities
+ # to retry fetching.
+ backfill_points = self.get_success(
+ self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
+ )
+
+ # We should aim to fetch all backoff points: b1's latest backoff period has
+ # expired, and we haven't tried the rest.
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ self.assertEqual(backfill_event_ids, ["b3", "b2", "b1"])
+
+ def _setup_room_for_insertion_backfill_tests(self) -> _BackfillSetupInfo:
+ """
+ Sets up a room with various insertion event backward extremities to test
+ backfill functions against.
+
+ Returns:
+ _BackfillSetupInfo including the `room_id` to test against and
+ `depth_map` of events in the room
+ """
+ room_id = "!backfill-room-test:some-host"
+
+ depth_map: Dict[str, int] = {
+ "1": 1,
+ "2": 2,
+ "insertion_eventA": 3,
+ "3": 4,
+ "insertion_eventB": 5,
+ "4": 6,
+ "5": 7,
+ }
+
+ def populate_db(txn: LoggingTransaction):
+ # Insert the room to satisfy the foreign key constraint of
+ # `event_failed_pull_attempts`
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ "rooms",
+ {
+ "room_id": room_id,
+ "creator": "room_creator_user_id",
+ "is_public": True,
+ "room_version": "6",
+ },
+ )
+
+ # Insert our server events
+ stream_ordering = 0
+ for event_id, depth in depth_map.items():
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "event_id": event_id,
+ "type": EventTypes.MSC2716_INSERTION
+ if event_id.startswith("insertion_event")
+ else "test_regular_type",
+ "room_id": room_id,
+ "depth": depth,
+ "topological_ordering": depth,
+ "stream_ordering": stream_ordering,
+ "processed": True,
+ "outlier": False,
+ },
+ )
+
+ if event_id.startswith("insertion_event"):
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="insertion_event_extremities",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ },
+ )
+
+ stream_ordering += 1
+
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "_setup_room_for_insertion_backfill_tests_populate_db",
+ populate_db,
+ )
+ )
+
+ return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
+
+ def test_get_insertion_event_backward_extremities_in_room(self):
+ """
+ Test to make sure only insertion event backward extremities that are
+ older and come before the `current_depth` are returned.
+ """
+ setup_info = self._setup_room_for_insertion_backfill_tests()
+ room_id = setup_info.room_id
+ depth_map = setup_info.depth_map
+
+ # Try at "insertion_eventB"
+ backfill_points = self.get_success(
+ self.store.get_insertion_event_backward_extremities_in_room(
+ room_id, depth_map["insertion_eventB"], limit=100
+ )
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ self.assertEqual(backfill_event_ids, ["insertion_eventB", "insertion_eventA"])
+
+ # Try at "insertion_eventA"
+ backfill_points = self.get_success(
+ self.store.get_insertion_event_backward_extremities_in_room(
+ room_id, depth_map["insertion_eventA"], limit=100
+ )
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ # Event "2" has a depth of 2 but is not included here because we only
+ # know the approximate depth of 5 from our event "3".
+ self.assertListEqual(backfill_event_ids, ["insertion_eventA"])
+
+ def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
+ self,
+ ):
+ """
+ Test to make sure that insertion events we have attempted to backfill
+ (and within backoff timeout duration) do not show up as an event to
+ backfill again.
+ """
+ setup_info = self._setup_room_for_insertion_backfill_tests()
+ room_id = setup_info.room_id
+ depth_map = setup_info.depth_map
+
+ # Record some attempts to backfill these events which will make
+ # `get_insertion_event_backward_extremities_in_room` exclude them
+ # because we haven't passed the backoff interval.
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "insertion_eventA", "fake cause"
+ )
+ )
+
+ # No time has passed since we attempted to backfill ^
+
+ # Try at "insertion_eventB"
+ backfill_points = self.get_success(
+ self.store.get_insertion_event_backward_extremities_in_room(
+ room_id, depth_map["insertion_eventB"], limit=100
+ )
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ # Only the backfill points that we didn't record earlier exist here.
+ self.assertEqual(backfill_event_ids, ["insertion_eventB"])
+
+ def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
+ self,
+ ):
+ """
+ Test to make sure after we fake attempt to backfill event
+ "insertion_eventA" many times, we can see retry and see the
+ "insertion_eventA" again after the backoff timeout duration has
+ exceeded.
+ """
+ setup_info = self._setup_room_for_insertion_backfill_tests()
+ room_id = setup_info.room_id
+ depth_map = setup_info.depth_map
+
+ # Record some attempts to backfill these events which will make
+ # `get_backfill_points_in_room` exclude them because we
+ # haven't passed the backoff interval.
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "insertion_eventB", "fake cause"
+ )
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "insertion_eventA", "fake cause"
+ )
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "insertion_eventA", "fake cause"
+ )
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "insertion_eventA", "fake cause"
+ )
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "insertion_eventA", "fake cause"
+ )
+ )
+
+ # Now advance time by 2 hours and we should only be able to see
+ # "insertion_eventB" because we have waited long enough for the single
+ # attempt (2^1 hours) but we still shouldn't see "insertion_eventA"
+ # because we haven't waited long enough for this many attempts.
+ self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
+
+ # Try at "insertion_eventA" and make sure that "insertion_eventA" is not
+ # in the list because we've already attempted many times
+ backfill_points = self.get_success(
+ self.store.get_insertion_event_backward_extremities_in_room(
+ room_id, depth_map["insertion_eventA"], limit=100
+ )
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ self.assertEqual(backfill_event_ids, [])
+
+ # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and
+ # see if we can now backfill it
+ self.reactor.advance(datetime.timedelta(hours=20).total_seconds())
+
+ # Try at "insertion_eventA" again after we advanced enough time and we
+ # should see "insertion_eventA" again
+ backfill_points = self.get_success(
+ self.store.get_insertion_event_backward_extremities_in_room(
+ room_id, depth_map["insertion_eventA"], limit=100
+ )
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ self.assertEqual(backfill_event_ids, ["insertion_eventA"])
+
+ def test_get_event_ids_to_not_pull_from_backoff(
+ self,
+ ):
+ """
+ Test to make sure only event IDs we should backoff from are returned.
+ """
+ # Create the room
+ user_id = self.register_user("alice", "test")
+ tok = self.login("alice", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "$failed_event_id", "fake cause"
+ )
+ )
+
+ event_ids_to_backoff = self.get_success(
+ self.store.get_event_ids_to_not_pull_from_backoff(
+ room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"]
+ )
+ )
+
+ self.assertEqual(event_ids_to_backoff, ["$failed_event_id"])
+
+ def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration(
+ self,
+ ):
+ """
+ Test to make sure no event IDs are returned after the backoff duration has
+ elapsed.
+ """
+ # Create the room
+ user_id = self.register_user("alice", "test")
+ tok = self.login("alice", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "$failed_event_id", "fake cause"
+ )
+ )
+
+ # Now advance time by 2 hours so we wait long enough for the single failed
+ # attempt (2^1 hours).
+ self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
+
+ event_ids_to_backoff = self.get_success(
+ self.store.get_event_ids_to_not_pull_from_backoff(
+ room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"]
+ )
+ )
+ # Since this function only returns events we should backoff from, time has
+ # elapsed past the backoff range so there is no events to backoff from.
+ self.assertEqual(event_ids_to_backoff, [])
+
@attr.s
class FakeEvent:
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 088fbb247b..6f1135eef4 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from prometheus_client import generate_latest
-from synapse.metrics import REGISTRY, generate_latest
+from synapse.metrics import REGISTRY
from synapse.types import UserID, create_requester
from tests.unittest import HomeserverTestCase
@@ -53,8 +54,8 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
items = list(
filter(
- lambda x: b"synapse_forward_extremities_" in x,
- generate_latest(REGISTRY, emit_help=False).split(b"\n"),
+ lambda x: b"synapse_forward_extremities_" in x and b"# HELP" not in x,
+ generate_latest(REGISTRY).split(b"\n"),
)
)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index ba40124c8a..ee48920f84 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -12,18 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional, Tuple
+
from twisted.test.proto_helpers import MemoryReactor
+from synapse.api.constants import MAIN_TIMELINE, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.types import JsonDict
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
-USER_ID = "@user:example.com"
-
class EventPushActionsStoreTestCase(HomeserverTestCase):
servlets = [
@@ -38,21 +40,13 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
assert persist_events_store is not None
self.persist_events_store = persist_events_store
- def test_get_unread_push_actions_for_user_in_range_for_http(self) -> None:
- self.get_success(
- self.store.get_unread_push_actions_for_user_in_range_for_http(
- USER_ID, 0, 1000, 20
- )
- )
+ def _create_users_and_room(self) -> Tuple[str, str, str, str, str]:
+ """
+ Creates two users and a shared room.
- def test_get_unread_push_actions_for_user_in_range_for_email(self) -> None:
- self.get_success(
- self.store.get_unread_push_actions_for_user_in_range_for_email(
- USER_ID, 0, 1000, 20
- )
- )
-
- def test_count_aggregation(self) -> None:
+ Returns:
+ Tuple of (user 1 ID, user 1 token, user 2 ID, user 2 token, room ID).
+ """
# Create a user to receive notifications and send receipts.
user_id = self.register_user("user1235", "pass")
token = self.login("user1235", "pass")
@@ -65,11 +59,104 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
room_id = self.helper.create_room_as(user_id, tok=token)
self.helper.join(room_id, other_id, tok=other_token)
+ return user_id, token, other_id, other_token, room_id
+
+ def test_get_unread_push_actions_for_user_in_range(self) -> None:
+ """Test getting unread push actions for HTTP and email pushers."""
+ user_id, token, _, other_token, room_id = self._create_users_and_room()
+
+ # Create two events, one of which is a highlight.
+ first_event_id = self.helper.send_event(
+ room_id,
+ type="m.room.message",
+ content={"msgtype": "m.text", "body": "msg"},
+ tok=other_token,
+ )["event_id"]
+ second_event_id = self.helper.send_event(
+ room_id,
+ type="m.room.message",
+ content={
+ "msgtype": "m.text",
+ "body": user_id,
+ "m.relates_to": {
+ "rel_type": RelationTypes.THREAD,
+ "event_id": first_event_id,
+ },
+ },
+ tok=other_token,
+ )["event_id"]
+
+ # Fetch unread actions for HTTP pushers.
+ http_actions = self.get_success(
+ self.store.get_unread_push_actions_for_user_in_range_for_http(
+ user_id, 0, 1000, 20
+ )
+ )
+ self.assertEqual(2, len(http_actions))
+
+ # Fetch unread actions for email pushers.
+ email_actions = self.get_success(
+ self.store.get_unread_push_actions_for_user_in_range_for_email(
+ user_id, 0, 1000, 20
+ )
+ )
+ self.assertEqual(2, len(email_actions))
+
+ # Send a receipt, which should clear the first action.
+ self.get_success(
+ self.store.insert_receipt(
+ room_id,
+ "m.read",
+ user_id=user_id,
+ event_ids=[first_event_id],
+ thread_id=None,
+ data={},
+ )
+ )
+ http_actions = self.get_success(
+ self.store.get_unread_push_actions_for_user_in_range_for_http(
+ user_id, 0, 1000, 20
+ )
+ )
+ self.assertEqual(1, len(http_actions))
+ email_actions = self.get_success(
+ self.store.get_unread_push_actions_for_user_in_range_for_email(
+ user_id, 0, 1000, 20
+ )
+ )
+ self.assertEqual(1, len(email_actions))
+
+ # Send a thread receipt to clear the thread action.
+ self.get_success(
+ self.store.insert_receipt(
+ room_id,
+ "m.read",
+ user_id=user_id,
+ event_ids=[second_event_id],
+ thread_id=first_event_id,
+ data={},
+ )
+ )
+ http_actions = self.get_success(
+ self.store.get_unread_push_actions_for_user_in_range_for_http(
+ user_id, 0, 1000, 20
+ )
+ )
+ self.assertEqual([], http_actions)
+ email_actions = self.get_success(
+ self.store.get_unread_push_actions_for_user_in_range_for_email(
+ user_id, 0, 1000, 20
+ )
+ )
+ self.assertEqual([], email_actions)
+
+ def test_count_aggregation(self) -> None:
+ # Create a user to receive notifications and send receipts.
+ user_id, token, _, other_token, room_id = self._create_users_and_room()
+
last_event_id: str
- def _assert_counts(
- noitf_count: int, unread_count: int, highlight_count: int
- ) -> None:
+ def _assert_counts(noitf_count: int, highlight_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
@@ -79,13 +166,14 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
)
)
self.assertEqual(
- counts,
+ counts.main_timeline,
NotifCounts(
notify_count=noitf_count,
- unread_count=unread_count,
+ unread_count=0,
highlight_count=highlight_count,
),
)
+ self.assertEqual(counts.threads, {})
def _create_event(highlight: bool = False) -> str:
result = self.helper.send_event(
@@ -108,63 +196,518 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
"m.read",
user_id=user_id,
event_ids=[event_id],
+ thread_id=None,
data={},
)
)
- _assert_counts(0, 0, 0)
+ _assert_counts(0, 0)
_create_event()
- _assert_counts(1, 1, 0)
+ _assert_counts(1, 0)
_rotate()
- _assert_counts(1, 1, 0)
+ _assert_counts(1, 0)
event_id = _create_event()
- _assert_counts(2, 2, 0)
+ _assert_counts(2, 0)
_rotate()
- _assert_counts(2, 2, 0)
+ _assert_counts(2, 0)
_create_event()
_mark_read(event_id)
- _assert_counts(1, 1, 0)
+ _assert_counts(1, 0)
_mark_read(last_event_id)
- _assert_counts(0, 0, 0)
+ _assert_counts(0, 0)
_create_event()
+ _assert_counts(1, 0)
_rotate()
- _assert_counts(1, 1, 0)
+ _assert_counts(1, 0)
# Delete old event push actions, this should not affect the (summarised) count.
+ #
+ # All event push actions are kept for 24 hours, so need to move forward
+ # in time.
+ self.pump(60 * 60 * 24)
self.get_success(self.store._remove_old_push_actions_that_have_rotated())
- _assert_counts(1, 1, 0)
+ # Double check that the event push actions have been cleared (i.e. that
+ # any results *must* come from the summary).
+ result = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="event_push_actions",
+ keyvalues={"1": 1},
+ retcols=("event_id",),
+ desc="",
+ )
+ )
+ self.assertEqual(result, [])
+ _assert_counts(1, 0)
_mark_read(last_event_id)
- _assert_counts(0, 0, 0)
+ _assert_counts(0, 0)
event_id = _create_event(True)
- _assert_counts(1, 1, 1)
+ _assert_counts(1, 1)
_rotate()
- _assert_counts(1, 1, 1)
+ _assert_counts(1, 1)
# Check that adding another notification and rotating after highlight
# works.
_create_event()
_rotate()
- _assert_counts(2, 2, 1)
+ _assert_counts(2, 1)
# Check that sending read receipts at different points results in the
# right counts.
_mark_read(event_id)
- _assert_counts(1, 1, 0)
+ _assert_counts(1, 0)
_mark_read(last_event_id)
- _assert_counts(0, 0, 0)
+ _assert_counts(0, 0)
_create_event(True)
- _assert_counts(1, 1, 1)
+ _assert_counts(1, 1)
_mark_read(last_event_id)
- _assert_counts(0, 0, 0)
+ _assert_counts(0, 0)
_rotate()
- _assert_counts(0, 0, 0)
+ _assert_counts(0, 0)
+
+ def test_count_aggregation_threads(self) -> None:
+ """
+ This is essentially the same test as test_count_aggregation, but adds
+ events to the main timeline and to a thread.
+ """
+
+ user_id, token, _, other_token, room_id = self._create_users_and_room()
+ thread_id: str
+
+ last_event_id: str
+
+ def _assert_counts(
+ noitf_count: int,
+ highlight_count: int,
+ thread_notif_count: int,
+ thread_highlight_count: int,
+ ) -> None:
+ counts = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get-unread-counts",
+ self.store._get_unread_counts_by_receipt_txn,
+ room_id,
+ user_id,
+ )
+ )
+ self.assertEqual(
+ counts.main_timeline,
+ NotifCounts(
+ notify_count=noitf_count,
+ unread_count=0,
+ highlight_count=highlight_count,
+ ),
+ )
+ if thread_notif_count or thread_highlight_count:
+ self.assertEqual(
+ counts.threads,
+ {
+ thread_id: NotifCounts(
+ notify_count=thread_notif_count,
+ unread_count=0,
+ highlight_count=thread_highlight_count,
+ ),
+ },
+ )
+ else:
+ self.assertEqual(counts.threads, {})
+
+ def _create_event(
+ highlight: bool = False, thread_id: Optional[str] = None
+ ) -> str:
+ content: JsonDict = {
+ "msgtype": "m.text",
+ "body": user_id if highlight else "msg",
+ }
+ if thread_id:
+ content["m.relates_to"] = {
+ "rel_type": "m.thread",
+ "event_id": thread_id,
+ }
+
+ result = self.helper.send_event(
+ room_id,
+ type="m.room.message",
+ content=content,
+ tok=other_token,
+ )
+ nonlocal last_event_id
+ last_event_id = result["event_id"]
+ return last_event_id
+
+ def _rotate() -> None:
+ self.get_success(self.store._rotate_notifs())
+
+ def _mark_read(event_id: str, thread_id: str = MAIN_TIMELINE) -> None:
+ self.get_success(
+ self.store.insert_receipt(
+ room_id,
+ "m.read",
+ user_id=user_id,
+ event_ids=[event_id],
+ thread_id=thread_id,
+ data={},
+ )
+ )
+
+ _assert_counts(0, 0, 0, 0)
+ thread_id = _create_event()
+ _assert_counts(1, 0, 0, 0)
+ _rotate()
+ _assert_counts(1, 0, 0, 0)
+
+ _create_event(thread_id=thread_id)
+ _assert_counts(1, 0, 1, 0)
+ _rotate()
+ _assert_counts(1, 0, 1, 0)
+
+ _create_event()
+ _assert_counts(2, 0, 1, 0)
+ _rotate()
+ _assert_counts(2, 0, 1, 0)
+
+ event_id = _create_event(thread_id=thread_id)
+ _assert_counts(2, 0, 2, 0)
+ _rotate()
+ _assert_counts(2, 0, 2, 0)
+
+ _create_event()
+ _create_event(thread_id=thread_id)
+ _mark_read(event_id)
+ _assert_counts(1, 0, 3, 0)
+ _mark_read(event_id, thread_id)
+ _assert_counts(1, 0, 1, 0)
+
+ _mark_read(last_event_id)
+ _mark_read(last_event_id, thread_id)
+ _assert_counts(0, 0, 0, 0)
+
+ _create_event()
+ _create_event(thread_id=thread_id)
+ _assert_counts(1, 0, 1, 0)
+ _rotate()
+ _assert_counts(1, 0, 1, 0)
+
+ # Delete old event push actions, this should not affect the (summarised) count.
+ self.get_success(self.store._remove_old_push_actions_that_have_rotated())
+ _assert_counts(1, 0, 1, 0)
+
+ _mark_read(last_event_id)
+ _mark_read(last_event_id, thread_id)
+ _assert_counts(0, 0, 0, 0)
+
+ _create_event(True)
+ _assert_counts(1, 1, 0, 0)
+ _rotate()
+ _assert_counts(1, 1, 0, 0)
+
+ event_id = _create_event(True, thread_id)
+ _assert_counts(1, 1, 1, 1)
+ _rotate()
+ _assert_counts(1, 1, 1, 1)
+
+ # Check that adding another notification and rotating after highlight
+ # works.
+ _create_event()
+ _rotate()
+ _assert_counts(2, 1, 1, 1)
+
+ _create_event(thread_id=thread_id)
+ _rotate()
+ _assert_counts(2, 1, 2, 1)
+
+ # Check that sending read receipts at different points results in the
+ # right counts.
+ _mark_read(event_id)
+ _assert_counts(1, 0, 2, 1)
+ _mark_read(event_id, thread_id)
+ _assert_counts(1, 0, 1, 0)
+ _mark_read(last_event_id)
+ _assert_counts(0, 0, 1, 0)
+ _mark_read(last_event_id, thread_id)
+ _assert_counts(0, 0, 0, 0)
+
+ _create_event(True)
+ _create_event(True, thread_id)
+ _assert_counts(1, 1, 1, 1)
+ _mark_read(last_event_id)
+ _mark_read(last_event_id, thread_id)
+ _assert_counts(0, 0, 0, 0)
+ _rotate()
+ _assert_counts(0, 0, 0, 0)
+
+ def test_count_aggregation_mixed(self) -> None:
+ """
+ This is essentially the same test as test_count_aggregation_threads, but
+ sends both unthreaded and threaded receipts.
+ """
+
+ user_id, token, _, other_token, room_id = self._create_users_and_room()
+ thread_id: str
+
+ last_event_id: str
+
+ def _assert_counts(
+ noitf_count: int,
+ highlight_count: int,
+ thread_notif_count: int,
+ thread_highlight_count: int,
+ ) -> None:
+ counts = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get-unread-counts",
+ self.store._get_unread_counts_by_receipt_txn,
+ room_id,
+ user_id,
+ )
+ )
+ self.assertEqual(
+ counts.main_timeline,
+ NotifCounts(
+ notify_count=noitf_count,
+ unread_count=0,
+ highlight_count=highlight_count,
+ ),
+ )
+ if thread_notif_count or thread_highlight_count:
+ self.assertEqual(
+ counts.threads,
+ {
+ thread_id: NotifCounts(
+ notify_count=thread_notif_count,
+ unread_count=0,
+ highlight_count=thread_highlight_count,
+ ),
+ },
+ )
+ else:
+ self.assertEqual(counts.threads, {})
+
+ def _create_event(
+ highlight: bool = False, thread_id: Optional[str] = None
+ ) -> str:
+ content: JsonDict = {
+ "msgtype": "m.text",
+ "body": user_id if highlight else "msg",
+ }
+ if thread_id:
+ content["m.relates_to"] = {
+ "rel_type": "m.thread",
+ "event_id": thread_id,
+ }
+
+ result = self.helper.send_event(
+ room_id,
+ type="m.room.message",
+ content=content,
+ tok=other_token,
+ )
+ nonlocal last_event_id
+ last_event_id = result["event_id"]
+ return last_event_id
+
+ def _rotate() -> None:
+ self.get_success(self.store._rotate_notifs())
+
+ def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None:
+ self.get_success(
+ self.store.insert_receipt(
+ room_id,
+ "m.read",
+ user_id=user_id,
+ event_ids=[event_id],
+ thread_id=thread_id,
+ data={},
+ )
+ )
+
+ _assert_counts(0, 0, 0, 0)
+ thread_id = _create_event()
+ _assert_counts(1, 0, 0, 0)
+ _rotate()
+ _assert_counts(1, 0, 0, 0)
+
+ _create_event(thread_id=thread_id)
+ _assert_counts(1, 0, 1, 0)
+ _rotate()
+ _assert_counts(1, 0, 1, 0)
+
+ _create_event()
+ _assert_counts(2, 0, 1, 0)
+ _rotate()
+ _assert_counts(2, 0, 1, 0)
+
+ event_id = _create_event(thread_id=thread_id)
+ _assert_counts(2, 0, 2, 0)
+ _rotate()
+ _assert_counts(2, 0, 2, 0)
+
+ _create_event()
+ _create_event(thread_id=thread_id)
+ _mark_read(event_id)
+ _assert_counts(1, 0, 1, 0)
+
+ _mark_read(last_event_id, MAIN_TIMELINE)
+ _mark_read(last_event_id, thread_id)
+ _assert_counts(0, 0, 0, 0)
+
+ _create_event()
+ _create_event(thread_id=thread_id)
+ _assert_counts(1, 0, 1, 0)
+ _rotate()
+ _assert_counts(1, 0, 1, 0)
+
+ # Delete old event push actions, this should not affect the (summarised) count.
+ self.get_success(self.store._remove_old_push_actions_that_have_rotated())
+ _assert_counts(1, 0, 1, 0)
+
+ _mark_read(last_event_id)
+ _assert_counts(0, 0, 0, 0)
+
+ _create_event(True)
+ _assert_counts(1, 1, 0, 0)
+ _rotate()
+ _assert_counts(1, 1, 0, 0)
+
+ event_id = _create_event(True, thread_id)
+ _assert_counts(1, 1, 1, 1)
+ _rotate()
+ _assert_counts(1, 1, 1, 1)
+
+ # Check that adding another notification and rotating after highlight
+ # works.
+ _create_event()
+ _rotate()
+ _assert_counts(2, 1, 1, 1)
+
+ _create_event(thread_id=thread_id)
+ _rotate()
+ _assert_counts(2, 1, 2, 1)
+
+ # Check that sending read receipts at different points results in the
+ # right counts.
+ _mark_read(event_id)
+ _assert_counts(1, 0, 1, 0)
+ _mark_read(event_id, MAIN_TIMELINE)
+ _assert_counts(1, 0, 1, 0)
+ _mark_read(last_event_id, MAIN_TIMELINE)
+ _assert_counts(0, 0, 1, 0)
+ _mark_read(last_event_id, thread_id)
+ _assert_counts(0, 0, 0, 0)
+
+ _create_event(True)
+ _create_event(True, thread_id)
+ _assert_counts(1, 1, 1, 1)
+ _mark_read(last_event_id)
+ _assert_counts(0, 0, 0, 0)
+ _rotate()
+ _assert_counts(0, 0, 0, 0)
+
+ def test_recursive_thread(self) -> None:
+ """
+ Events related to events in a thread should still be considered part of
+ that thread.
+ """
+
+ # Create a user to receive notifications and send receipts.
+ user_id = self.register_user("user1235", "pass")
+ token = self.login("user1235", "pass")
+
+ # And another users to send events.
+ other_id = self.register_user("other", "pass")
+ other_token = self.login("other", "pass")
+
+ # Create a room and put both users in it.
+ room_id = self.helper.create_room_as(user_id, tok=token)
+ self.helper.join(room_id, other_id, tok=other_token)
+
+ # Update the user's push rules to care about reaction events.
+ self.get_success(
+ self.store.add_push_rule(
+ user_id,
+ "related_events",
+ priority_class=5,
+ conditions=[
+ {"kind": "event_match", "key": "type", "pattern": "m.reaction"}
+ ],
+ actions=["notify"],
+ )
+ )
+
+ def _create_event(type: str, content: JsonDict) -> str:
+ result = self.helper.send_event(
+ room_id, type=type, content=content, tok=other_token
+ )
+ return result["event_id"]
+
+ def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
+ counts = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get-unread-counts",
+ self.store._get_unread_counts_by_receipt_txn,
+ room_id,
+ user_id,
+ )
+ )
+ self.assertEqual(
+ counts.main_timeline,
+ NotifCounts(
+ notify_count=noitf_count, unread_count=0, highlight_count=0
+ ),
+ )
+ if thread_notif_count:
+ self.assertEqual(
+ counts.threads,
+ {
+ thread_id: NotifCounts(
+ notify_count=thread_notif_count,
+ unread_count=0,
+ highlight_count=0,
+ ),
+ },
+ )
+ else:
+ self.assertEqual(counts.threads, {})
+
+ # Create a root event.
+ thread_id = _create_event(
+ "m.room.message", {"msgtype": "m.text", "body": "msg"}
+ )
+ _assert_counts(1, 0)
+
+ # Reply, creating a thread.
+ reply_id = _create_event(
+ "m.room.message",
+ {
+ "msgtype": "m.text",
+ "body": "msg",
+ "m.relates_to": {
+ "rel_type": "m.thread",
+ "event_id": thread_id,
+ },
+ },
+ )
+ _assert_counts(1, 1)
+
+ # Create an event related to a thread event, this should still appear in
+ # the thread.
+ _create_event(
+ type="m.reaction",
+ content={
+ "m.relates_to": {
+ "rel_type": "m.annotation",
+ "event_id": reply_id,
+ "key": "A",
+ }
+ },
+ )
+ _assert_counts(1, 2)
def test_find_first_stream_ordering_after_ts(self) -> None:
def add_event(so: int, ts: int) -> None:
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 2d8d1f860f..d6a2b8d274 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -16,15 +16,157 @@ from typing import List, Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import IncorrectDatabaseSetup
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS
+class StreamIdGeneratorTestCase(HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.db_pool: DatabasePool = self.store.db_pool
+
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn: LoggingTransaction) -> None:
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+ txn.execute("INSERT INTO foobar VALUES (123, 'hello world');")
+
+ def _create_id_generator(self) -> StreamIdGenerator:
+ def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
+ return StreamIdGenerator(
+ db_conn=conn,
+ table="foobar",
+ column="stream_id",
+ )
+
+ return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+
+ def test_initial_value(self) -> None:
+ """Check that we read the current token from the DB."""
+ id_gen = self._create_id_generator()
+ self.assertEqual(id_gen.get_current_token(), 123)
+
+ def test_single_gen_next(self) -> None:
+ """Check that we correctly increment the current token from the DB."""
+ id_gen = self._create_id_generator()
+
+ async def test_gen_next() -> None:
+ async with id_gen.get_next() as next_id:
+ # We haven't persisted `next_id` yet; current token is still 123
+ self.assertEqual(id_gen.get_current_token(), 123)
+ # But we did learn what the next value is
+ self.assertEqual(next_id, 124)
+
+ # Once the context manager closes we assume that the `next_id` has been
+ # written to the DB.
+ self.assertEqual(id_gen.get_current_token(), 124)
+
+ self.get_success(test_gen_next())
+
+ def test_multiple_gen_nexts(self) -> None:
+ """Check that we handle overlapping calls to gen_next sensibly."""
+ id_gen = self._create_id_generator()
+
+ async def test_gen_next() -> None:
+ ctx1 = id_gen.get_next()
+ ctx2 = id_gen.get_next()
+ ctx3 = id_gen.get_next()
+
+ # Request three new stream IDs.
+ self.assertEqual(await ctx1.__aenter__(), 124)
+ self.assertEqual(await ctx2.__aenter__(), 125)
+ self.assertEqual(await ctx3.__aenter__(), 126)
+
+ # None are persisted: current token unchanged.
+ self.assertEqual(id_gen.get_current_token(), 123)
+
+ # Persist each in turn.
+ await ctx1.__aexit__(None, None, None)
+ self.assertEqual(id_gen.get_current_token(), 124)
+ await ctx2.__aexit__(None, None, None)
+ self.assertEqual(id_gen.get_current_token(), 125)
+ await ctx3.__aexit__(None, None, None)
+ self.assertEqual(id_gen.get_current_token(), 126)
+
+ self.get_success(test_gen_next())
+
+ def test_multiple_gen_nexts_closed_in_different_order(self) -> None:
+ """Check that we handle overlapping calls to gen_next, even when their IDs
+ created and persisted in different orders."""
+ id_gen = self._create_id_generator()
+
+ async def test_gen_next() -> None:
+ ctx1 = id_gen.get_next()
+ ctx2 = id_gen.get_next()
+ ctx3 = id_gen.get_next()
+
+ # Request three new stream IDs.
+ self.assertEqual(await ctx1.__aenter__(), 124)
+ self.assertEqual(await ctx2.__aenter__(), 125)
+ self.assertEqual(await ctx3.__aenter__(), 126)
+
+ # None are persisted: current token unchanged.
+ self.assertEqual(id_gen.get_current_token(), 123)
+
+ # Persist them in a different order, starting with 126 from ctx3.
+ await ctx3.__aexit__(None, None, None)
+ # We haven't persisted 124 from ctx1 yet---current token is still 123.
+ self.assertEqual(id_gen.get_current_token(), 123)
+
+ # Now persist 124 from ctx1.
+ await ctx1.__aexit__(None, None, None)
+ # Current token is then 124, waiting for 125 to be persisted.
+ self.assertEqual(id_gen.get_current_token(), 124)
+
+ # Finally persist 125 from ctx2.
+ await ctx2.__aexit__(None, None, None)
+ # Current token is then 126 (skipping over 125).
+ self.assertEqual(id_gen.get_current_token(), 126)
+
+ self.get_success(test_gen_next())
+
+ def test_gen_next_while_still_waiting_for_persistence(self) -> None:
+ """Check that we handle overlapping calls to gen_next."""
+ id_gen = self._create_id_generator()
+
+ async def test_gen_next() -> None:
+ ctx1 = id_gen.get_next()
+ ctx2 = id_gen.get_next()
+ ctx3 = id_gen.get_next()
+
+ # Request two new stream IDs.
+ self.assertEqual(await ctx1.__aenter__(), 124)
+ self.assertEqual(await ctx2.__aenter__(), 125)
+
+ # Persist ctx2 first.
+ await ctx2.__aexit__(None, None, None)
+ # Still waiting on ctx1's ID to be persisted.
+ self.assertEqual(id_gen.get_current_token(), 123)
+
+ # Now request a third stream ID. It should be 126 (the smallest ID that
+ # we've not yet handed out.)
+ self.assertEqual(await ctx3.__aenter__(), 126)
+
+ self.get_success(test_gen_next())
+
+
class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
@@ -48,9 +190,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
)
def _create_id_generator(
- self, instance_name="master", writers: Optional[List[str]] = None
+ self, instance_name: str = "master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
- def _create(conn):
+ def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
conn,
self.db_pool,
@@ -446,7 +588,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("master", 3)
# Now we add a row *without* updating the stream ID
- def _insert(txn):
+ def _insert(txn: Cursor) -> None:
txn.execute("INSERT INTO foobar VALUES (26, 'master')")
self.get_success(self.db_pool.runInteraction("_insert", _insert))
@@ -481,9 +623,9 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
)
def _create_id_generator(
- self, instance_name="master", writers: Optional[List[str]] = None
+ self, instance_name: str = "master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
- def _create(conn):
+ def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
conn,
self.db_pool,
@@ -617,9 +759,9 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
)
def _create_id_generator(
- self, instance_name="master", writers: Optional[List[str]] = None
+ self, instance_name: str = "master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
- def _create(conn):
+ def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
conn,
self.db_pool,
@@ -641,7 +783,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
instance_name: str,
number: int,
update_stream_table: bool = True,
- ):
+ ) -> None:
"""Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence.
"""
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index e8b4a5644b..c55c4db970 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -96,8 +96,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# Test each of the registered users is marked as active
timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1))
+ # Mypy notes that one shouldn't compare Optional[int] to 0 with assertGreater.
+ # Check that timestamp really is an int.
+ assert timestamp is not None
self.assertGreater(timestamp, 0)
timestamp = self.get_success(self.store.user_last_seen_monthly_active(user2))
+ assert timestamp is not None
self.assertGreater(timestamp, 0)
# Test that users with reserved 3pids are not removed from the MAU table
@@ -166,10 +170,11 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.upsert_monthly_active_user(user_id2))
result = self.get_success(self.store.user_last_seen_monthly_active(user_id1))
+ assert result is not None
self.assertGreater(result, 0)
result = self.get_success(self.store.user_last_seen_monthly_active(user_id3))
- self.assertNotEqual(result, 0)
+ self.assertIsNone(result)
@override_config({"max_mau_value": 5})
def test_reap_monthly_active_users(self):
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index b1a8f8bba7..81253d0361 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Collection, Optional
+
from synapse.api.constants import ReceiptTypes
from synapse.types import UserID, create_requester
@@ -23,7 +25,7 @@ OUR_USER_ID = "@our:test"
class ReceiptTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor, clock, homeserver) -> None:
super().prepare(reactor, clock, homeserver)
self.store = homeserver.get_datastores().main
@@ -83,10 +85,41 @@ class ReceiptTestCase(HomeserverTestCase):
)
)
- def test_return_empty_with_no_data(self):
+ def get_last_unthreaded_receipt(
+ self, receipt_types: Collection[str], room_id: Optional[str] = None
+ ) -> Optional[str]:
+ """
+ Fetch the event ID for the latest unthreaded receipt in the test room for the test user.
+
+ Args:
+ receipt_types: The receipt types to fetch.
+
+ Returns:
+ The latest receipt, if one exists.
+ """
+ result = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get_last_receipt_event_id_for_user",
+ self.store.get_last_unthreaded_receipt_for_user_txn,
+ OUR_USER_ID,
+ room_id or self.room_id1,
+ receipt_types,
+ )
+ )
+ if not result:
+ return None
+
+ event_id, _ = result
+ return event_id
+
+ def test_return_empty_with_no_data(self) -> None:
res = self.get_success(
self.store.get_receipts_for_user(
- OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+ OUR_USER_ID,
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ],
)
)
self.assertEqual(res, {})
@@ -94,21 +127,21 @@ class ReceiptTestCase(HomeserverTestCase):
res = self.get_success(
self.store.get_receipts_for_user_with_orderings(
OUR_USER_ID,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ],
)
)
self.assertEqual(res, {})
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID,
- self.room_id1,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
- )
+ res = self.get_last_unthreaded_receipt(
+ [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
+
self.assertEqual(res, None)
- def test_get_receipts_for_user(self):
+ def test_get_receipts_for_user(self) -> None:
# Send some events into the first room
event1_1_id = self.create_and_send_event(
self.room_id1, UserID.from_string(OTHER_USER_ID)
@@ -120,13 +153,18 @@ class ReceiptTestCase(HomeserverTestCase):
# Send public read receipt for the first event
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {}
)
)
# Send private read receipt for the second event
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
+ self.room_id1,
+ ReceiptTypes.READ_PRIVATE,
+ OUR_USER_ID,
+ [event1_2_id],
+ None,
+ {},
)
)
@@ -153,7 +191,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Test receipt updating
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
)
)
res = self.get_success(
@@ -169,7 +207,12 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns
self.get_success(
self.store.insert_receipt(
- self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
+ self.room_id2,
+ ReceiptTypes.READ_PRIVATE,
+ OUR_USER_ID,
+ [event2_1_id],
+ None,
+ {},
)
)
res = self.get_success(
@@ -179,7 +222,7 @@ class ReceiptTestCase(HomeserverTestCase):
)
self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id})
- def test_get_last_receipt_event_id_for_user(self):
+ def test_get_last_receipt_event_id_for_user(self) -> None:
# Send some events into the first room
event1_1_id = self.create_and_send_event(
self.room_id1, UserID.from_string(OTHER_USER_ID)
@@ -191,53 +234,42 @@ class ReceiptTestCase(HomeserverTestCase):
# Send public read receipt for the first event
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {}
)
)
# Send private read receipt for the second event
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
+ self.room_id1,
+ ReceiptTypes.READ_PRIVATE,
+ OUR_USER_ID,
+ [event1_2_id],
+ None,
+ {},
)
)
# Test we get the latest event when we want both private and public receipts
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID,
- self.room_id1,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
- )
+ res = self.get_last_unthreaded_receipt(
+ [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
self.assertEqual(res, event1_2_id)
# Test we get the older event when we want only public receipt
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
- )
- )
+ res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
self.assertEqual(res, event1_1_id)
# Test we get the latest event when we want only the private receipt
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE]
- )
- )
+ res = self.get_last_unthreaded_receipt([ReceiptTypes.READ_PRIVATE])
self.assertEqual(res, event1_2_id)
# Test receipt updating
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
- )
- )
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
)
)
+ res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
self.assertEqual(res, event1_2_id)
# Send some events into the second room
@@ -248,14 +280,15 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns
self.get_success(
self.store.insert_receipt(
- self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
- )
- )
- res = self.get_success(
- self.store.get_last_receipt_event_id_for_user(
- OUR_USER_ID,
self.room_id2,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+ ReceiptTypes.READ_PRIVATE,
+ OUR_USER_ID,
+ [event2_1_id],
+ None,
+ {},
)
)
+ res = self.get_last_unthreaded_receipt(
+ [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], room_id=self.room_id2
+ )
self.assertEqual(res, event2_1_id)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index a49ac1525e..05ea802008 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -11,15 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
class RegistrationStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = "@my-user:test"
@@ -27,7 +31,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
self.pwhash = "{xx1}123456789"
self.device_id = "akgjhdjklgshg"
- def test_register(self):
+ def test_register(self) -> None:
self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.assertEqual(
@@ -38,17 +42,32 @@ class RegistrationStoreTestCase(HomeserverTestCase):
"admin": 0,
"is_guest": 0,
"consent_version": None,
+ "consent_ts": None,
"consent_server_notice_sent": None,
"appservice_id": None,
"creation_ts": 0,
"user_type": None,
"deactivated": 0,
"shadow_banned": 0,
+ "approved": 1,
},
(self.get_success(self.store.get_user_by_id(self.user_id))),
)
- def test_add_tokens(self):
+ def test_consent(self) -> None:
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
+ before_consent = self.clock.time_msec()
+ self.reactor.advance(5)
+ self.get_success(self.store.user_set_consent_version(self.user_id, "1"))
+ self.reactor.advance(5)
+
+ user = self.get_success(self.store.get_user_by_id(self.user_id))
+ assert user
+ self.assertEqual(user["consent_version"], "1")
+ self.assertGreater(user["consent_ts"], before_consent)
+ self.assertLess(user["consent_ts"], self.clock.time_msec())
+
+ def test_add_tokens(self) -> None:
self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.get_success(
self.store.add_access_token_to_user(
@@ -58,11 +77,12 @@ class RegistrationStoreTestCase(HomeserverTestCase):
result = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
+ assert result
self.assertEqual(result.user_id, self.user_id)
self.assertEqual(result.device_id, self.device_id)
self.assertIsNotNone(result.token_id)
- def test_user_delete_access_tokens(self):
+ def test_user_delete_access_tokens(self) -> None:
# add some tokens
self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.get_success(
@@ -87,6 +107,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
# check the one not associated with the device was not deleted
user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
+ assert user
self.assertEqual(self.user_id, user.user_id)
# now delete the rest
@@ -95,11 +116,11 @@ class RegistrationStoreTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.assertIsNone(user, "access token was not deleted without device_id")
- def test_is_support_user(self):
+ def test_is_support_user(self) -> None:
TEST_USER = "@test:test"
SUPPORT_USER = "@support:test"
- res = self.get_success(self.store.is_support_user(None))
+ res = self.get_success(self.store.is_support_user(None)) # type: ignore[arg-type]
self.assertFalse(res)
self.get_success(
self.store.register_user(user_id=TEST_USER, password_hash=None)
@@ -115,7 +136,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
res = self.get_success(self.store.is_support_user(SUPPORT_USER))
self.assertTrue(res)
- def test_3pid_inhibit_invalid_validation_session_error(self):
+ def test_3pid_inhibit_invalid_validation_session_error(self) -> None:
"""Tests that enabling the configuration option to inhibit 3PID errors on
/requestToken also inhibits validation errors caused by an unknown session ID.
"""
@@ -147,3 +168,101 @@ class RegistrationStoreTestCase(HomeserverTestCase):
ThreepidValidationError,
)
self.assertEqual(e.value.msg, "Validation token not found or has expired", e)
+
+
+class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+
+ # If there's already some config for this feature in the default config, it
+ # means we're overriding it with @override_config. In this case we don't want
+ # to do anything more with it.
+ msc3866_config = config.get("experimental_features", {}).get("msc3866")
+ if msc3866_config is not None:
+ return config
+
+ # Require approval for all new accounts.
+ config["experimental_features"] = {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": True,
+ }
+ }
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.user_id = "@my-user:test"
+ self.pwhash = "{xx1}123456789"
+
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": False,
+ }
+ }
+ }
+ )
+ def test_approval_not_required(self) -> None:
+ """Tests that if we don't require approval for new accounts, newly created
+ accounts are automatically marked as approved.
+ """
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
+
+ user = self.get_success(self.store.get_user_by_id(self.user_id))
+ assert user is not None
+ self.assertTrue(user["approved"])
+
+ approved = self.get_success(self.store.is_user_approved(self.user_id))
+ self.assertTrue(approved)
+
+ def test_approval_required(self) -> None:
+ """Tests that if we require approval for new accounts, newly created accounts
+ are not automatically marked as approved.
+ """
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
+
+ user = self.get_success(self.store.get_user_by_id(self.user_id))
+ assert user is not None
+ self.assertFalse(user["approved"])
+
+ approved = self.get_success(self.store.is_user_approved(self.user_id))
+ self.assertFalse(approved)
+
+ def test_override(self) -> None:
+ """Tests that if we require approval for new accounts, but we explicitly say the
+ new user should be considered approved, they're marked as approved.
+ """
+ self.get_success(
+ self.store.register_user(
+ self.user_id,
+ self.pwhash,
+ approved=True,
+ )
+ )
+
+ user = self.get_success(self.store.get_user_by_id(self.user_id))
+ self.assertIsNotNone(user)
+ assert user is not None
+ self.assertEqual(user["approved"], 1)
+
+ approved = self.get_success(self.store.is_user_approved(self.user_id))
+ self.assertTrue(approved)
+
+ def test_approve_user(self) -> None:
+ """Tests that approving the user updates their approval status."""
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
+
+ approved = self.get_success(self.store.is_user_approved(self.user_id))
+ self.assertFalse(approved)
+
+ self.get_success(
+ self.store.update_user_approval_status(
+ UserID.from_string(self.user_id), True
+ )
+ )
+
+ approved = self.get_success(self.store.is_user_approved(self.user_id))
+ self.assertTrue(approved)
diff --git a/tests/storage/test_relations.py b/tests/storage/test_relations.py
new file mode 100644
index 0000000000..cd1d00208b
--- /dev/null
+++ b/tests/storage/test_relations.py
@@ -0,0 +1,111 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import MAIN_TIMELINE
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+
+
+class RelationsStoreTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ """
+ Creates a DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ F <--[m.annotation]-- G
+
+ """
+ self._main_store = self.hs.get_datastores().main
+
+ self._create_relation("A", "B", "m.thread")
+ self._create_relation("B", "C", "m.annotation")
+ self._create_relation("A", "D", "m.reference")
+ self._create_relation("D", "E", "m.annotation")
+ self._create_relation("F", "G", "m.annotation")
+
+ def _create_relation(self, parent_id: str, event_id: str, rel_type: str) -> None:
+ self.get_success(
+ self._main_store.db_pool.simple_insert(
+ table="event_relations",
+ values={
+ "event_id": event_id,
+ "relates_to_id": parent_id,
+ "relation_type": rel_type,
+ },
+ )
+ )
+
+ def test_get_thread_id(self) -> None:
+ """
+ Ensure that get_thread_id only searches up the tree for threads.
+ """
+ # The thread itself and children of it return the thread.
+ thread_id = self.get_success(self._main_store.get_thread_id("B"))
+ self.assertEqual("A", thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id("C"))
+ self.assertEqual("A", thread_id)
+
+ # But the root and events related to the root do not.
+ thread_id = self.get_success(self._main_store.get_thread_id("A"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id("D"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id("E"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ # Events which are not related to a thread at all should return the
+ # main timeline.
+ thread_id = self.get_success(self._main_store.get_thread_id("F"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id("G"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ def test_get_thread_id_for_receipts(self) -> None:
+ """
+ Ensure that get_thread_id_for_receipts searches up and down the tree for a thread.
+ """
+ # All of the events are considered related to this thread.
+ thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("A"))
+ self.assertEqual("A", thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("B"))
+ self.assertEqual("A", thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("C"))
+ self.assertEqual("A", thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("D"))
+ self.assertEqual("A", thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("E"))
+ self.assertEqual("A", thread_id)
+
+ # Events which are not related to a thread at all should return the
+ # main timeline.
+ thread_id = self.get_success(self._main_store.get_thread_id("F"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id("G"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index e747c6b50e..ef850daa73 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -12,11 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Tuple
+from unittest.case import SkipTest
+
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.storage.databases.main import DataStore
+from synapse.storage.databases.main.search import Phrase, SearchToken, _tokenize_query
from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines.sqlite import Sqlite3Engine
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase, skip_unless
from tests.utils import USE_POSTGRES_FOR_TESTS
@@ -187,3 +197,179 @@ class EventSearchInsertionTest(HomeserverTestCase):
),
)
self.assertCountEqual(values, ["hi", "2"])
+
+
+class MessageSearchTest(HomeserverTestCase):
+ """
+ Check message search.
+
+ A powerful way to check the behaviour is to run the following in Postgres >= 11:
+
+ # SELECT websearch_to_tsquery('english', <your string>);
+
+ The result can be compared to the tokenized version for SQLite and Postgres < 11.
+
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ PHRASE = "the quick brown fox jumps over the lazy dog"
+
+ # Each entry is a search query, followed by a boolean of whether it is in the phrase.
+ COMMON_CASES = [
+ ("nope", False),
+ ("brown", True),
+ ("quick brown", True),
+ ("brown quick", True),
+ ("quick \t brown", True),
+ ("jump", True),
+ ("brown nope", False),
+ ('"brown quick"', False),
+ ('"jumps over"', True),
+ ('"quick fox"', False),
+ ("nope OR doublenope", False),
+ ("furphy OR fox", True),
+ ("fox -nope", True),
+ ("fox -brown", False),
+ ('"fox" quick', True),
+ ('"quick brown', True),
+ ('" quick "', True),
+ ('" nope"', False),
+ ]
+ # TODO Test non-ASCII cases.
+
+ # Case that fail on SQLite.
+ POSTGRES_CASES = [
+ # SQLite treats NOT as a binary operator.
+ ("- fox", False),
+ ("- nope", True),
+ ('"-fox quick', False),
+ # PostgreSQL skips stop words.
+ ('"the quick brown"', True),
+ ('"over lazy"', True),
+ ]
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ # Register a user and create a room, create some messages
+ self.register_user("alice", "password")
+ self.access_token = self.login("alice", "password")
+ self.room_id = self.helper.create_room_as("alice", tok=self.access_token)
+
+ # Send the phrase as a message and check it was created
+ response = self.helper.send(self.room_id, self.PHRASE, tok=self.access_token)
+ self.assertIn("event_id", response)
+
+ # The behaviour of a missing trailing double quote changed in PostgreSQL 14
+ # from ignoring the initial double quote to treating it as a phrase.
+ main_store = homeserver.get_datastores().main
+ found = False
+ if isinstance(main_store.database_engine, PostgresEngine):
+ assert main_store.database_engine._version is not None
+ found = main_store.database_engine._version < 140000
+ self.COMMON_CASES.append(('"fox quick', found))
+
+ def test_tokenize_query(self) -> None:
+ """Test the custom logic to tokenize a user's query."""
+ cases = (
+ ("brown", ["brown"]),
+ ("quick brown", ["quick", SearchToken.And, "brown"]),
+ ("quick \t brown", ["quick", SearchToken.And, "brown"]),
+ ('"brown quick"', [Phrase(["brown", "quick"])]),
+ ("furphy OR fox", ["furphy", SearchToken.Or, "fox"]),
+ ("fox -brown", ["fox", SearchToken.Not, "brown"]),
+ ("- fox", [SearchToken.Not, "fox"]),
+ ('"fox" quick', [Phrase(["fox"]), SearchToken.And, "quick"]),
+ # No trailing double quote.
+ ('"fox quick', [Phrase(["fox", "quick"])]),
+ ('"-fox quick', [Phrase(["-fox", "quick"])]),
+ ('" quick "', [Phrase(["quick"])]),
+ (
+ 'q"uick brow"n',
+ [
+ "q",
+ SearchToken.And,
+ Phrase(["uick", "brow"]),
+ SearchToken.And,
+ "n",
+ ],
+ ),
+ (
+ '-"quick brown"',
+ [SearchToken.Not, Phrase(["quick", "brown"])],
+ ),
+ )
+
+ for query, expected in cases:
+ tokenized = _tokenize_query(query)
+ self.assertEqual(
+ tokenized, expected, f"{tokenized} != {expected} for {query}"
+ )
+
+ def _check_test_cases(
+ self, store: DataStore, cases: List[Tuple[str, bool]]
+ ) -> None:
+ # Run all the test cases versus search_msgs
+ for query, expect_to_contain in cases:
+ result = self.get_success(
+ store.search_msgs([self.room_id], query, ["content.body"])
+ )
+ self.assertEquals(
+ result["count"],
+ 1 if expect_to_contain else 0,
+ f"expected '{query}' to match '{self.PHRASE}'"
+ if expect_to_contain
+ else f"'{query}' unexpectedly matched '{self.PHRASE}'",
+ )
+ self.assertEquals(
+ len(result["results"]),
+ 1 if expect_to_contain else 0,
+ "results array length should match count",
+ )
+
+ # Run them again versus search_rooms
+ for query, expect_to_contain in cases:
+ result = self.get_success(
+ store.search_rooms([self.room_id], query, ["content.body"], 10)
+ )
+ self.assertEquals(
+ result["count"],
+ 1 if expect_to_contain else 0,
+ f"expected '{query}' to match '{self.PHRASE}'"
+ if expect_to_contain
+ else f"'{query}' unexpectedly matched '{self.PHRASE}'",
+ )
+ self.assertEquals(
+ len(result["results"]),
+ 1 if expect_to_contain else 0,
+ "results array length should match count",
+ )
+
+ def test_postgres_web_search_for_phrase(self):
+ """
+ Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery.
+ This test is skipped unless the postgres instance supports websearch_to_tsquery.
+
+ See https://www.postgresql.org/docs/current/textsearch-controls.html
+ """
+
+ store = self.hs.get_datastores().main
+ if not isinstance(store.database_engine, PostgresEngine):
+ raise SkipTest("Test only applies when postgres is used as the database")
+
+ self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES)
+
+ def test_sqlite_search(self):
+ """
+ Test sqlite searching for phrases.
+ """
+ store = self.hs.get_datastores().main
+ if not isinstance(store.database_engine, Sqlite3Engine):
+ raise SkipTest("Test only applies when sqlite is used as the database")
+
+ self._check_test_cases(store, self.COMMON_CASES)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 240b02cb9f..8794401823 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -23,6 +23,7 @@ from synapse.util import Clock
from tests import unittest
from tests.server import TestHomeServer
+from tests.test_utils import event_injection
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
@@ -157,6 +158,75 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# Check that alice's display name is now None
self.assertEqual(row[0]["display_name"], None)
+ def test_room_is_locally_forgotten(self) -> None:
+ """Test that when the last local user has forgotten a room it is known as forgotten."""
+ # join two local and one remote user
+ self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+ self.get_success(
+ event_injection.inject_member_event(self.hs, self.room, self.u_bob, "join")
+ )
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_charlie.to_string(), "join"
+ )
+ )
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # local users leave the room and the room is not forgotten
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_alice, "leave"
+ )
+ )
+ self.get_success(
+ event_injection.inject_member_event(self.hs, self.room, self.u_bob, "leave")
+ )
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # first user forgets the room, room is not forgotten
+ self.get_success(self.store.forget(self.u_alice, self.room))
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # second (last local) user forgets the room and the room is forgotten
+ self.get_success(self.store.forget(self.u_bob, self.room))
+ self.assertTrue(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ def test_join_locally_forgotten_room(self) -> None:
+ """Tests if a user joins a forgotten room the room is not forgotten anymore."""
+ self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # after leaving and forget the room, it is forgotten
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_alice, "leave"
+ )
+ )
+ self.get_success(self.store.forget(self.u_alice, self.room))
+ self.assertTrue(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # after rejoin the room is not forgotten anymore
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_alice, "join"
+ )
+ )
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index 78663a53fe..34fa810cf6 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -16,7 +16,6 @@ from typing import List
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.filtering import Filter
-from synapse.events import EventBase
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.types import JsonDict
@@ -40,7 +39,7 @@ class PaginationTestCase(HomeserverTestCase):
def default_config(self):
config = super().default_config()
- config["experimental_features"] = {"msc3440_enabled": True}
+ config["experimental_features"] = {"msc3874_enabled": True}
return config
def prepare(self, reactor, clock, homeserver):
@@ -58,6 +57,11 @@ class PaginationTestCase(HomeserverTestCase):
self.third_tok = self.login("third", "test")
self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
+ # Store a token which is after all the room creation events.
+ self.from_token = self.get_success(
+ self.hs.get_event_sources().get_current_token_for_pagination(self.room_id)
+ )
+
# An initial event with a relation from second user.
res = self.helper.send_event(
room_id=self.room_id,
@@ -66,7 +70,7 @@ class PaginationTestCase(HomeserverTestCase):
tok=self.tok,
)
self.event_id_1 = res["event_id"]
- self.helper.send_event(
+ res = self.helper.send_event(
room_id=self.room_id,
type="m.reaction",
content={
@@ -78,6 +82,7 @@ class PaginationTestCase(HomeserverTestCase):
},
tok=self.second_tok,
)
+ self.event_id_annotation = res["event_id"]
# Another event with a relation from third user.
res = self.helper.send_event(
@@ -87,7 +92,7 @@ class PaginationTestCase(HomeserverTestCase):
tok=self.tok,
)
self.event_id_2 = res["event_id"]
- self.helper.send_event(
+ res = self.helper.send_event(
room_id=self.room_id,
type="m.reaction",
content={
@@ -98,68 +103,59 @@ class PaginationTestCase(HomeserverTestCase):
},
tok=self.third_tok,
)
+ self.event_id_reference = res["event_id"]
# An event with no relations.
- self.helper.send_event(
+ res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={"msgtype": "m.text", "body": "No relations"},
tok=self.tok,
)
+ self.event_id_none = res["event_id"]
- def _filter_messages(self, filter: JsonDict) -> List[EventBase]:
+ def _filter_messages(self, filter: JsonDict) -> List[str]:
"""Make a request to /messages with a filter, returns the chunk of events."""
- from_token = self.get_success(
- self.hs.get_event_sources().get_current_token_for_pagination(self.room_id)
- )
-
events, next_key = self.get_success(
self.hs.get_datastores().main.paginate_room_events(
room_id=self.room_id,
- from_key=from_token.room_key,
+ from_key=self.from_token.room_key,
to_key=None,
- direction="b",
+ direction="f",
limit=10,
event_filter=Filter(self.hs, filter),
)
)
- return events
+ return [ev.event_id for ev in events]
def test_filter_relation_senders(self):
# Messages which second user reacted to.
filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0].event_id, self.event_id_1)
+ self.assertEqual(chunk, [self.event_id_1])
# Messages which third user reacted to.
filter = {"related_by_senders": [self.third_user_id]}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0].event_id, self.event_id_2)
+ self.assertEqual(chunk, [self.event_id_2])
# Messages which either user reacted to.
filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 2, chunk)
- self.assertCountEqual(
- [c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
- )
+ self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
def test_filter_relation_type(self):
# Messages which have annotations.
filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0].event_id, self.event_id_1)
+ self.assertEqual(chunk, [self.event_id_1])
# Messages which have references.
filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0].event_id, self.event_id_2)
+ self.assertEqual(chunk, [self.event_id_2])
# Messages which have either annotations or references.
filter = {
@@ -169,10 +165,7 @@ class PaginationTestCase(HomeserverTestCase):
]
}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 2, chunk)
- self.assertCountEqual(
- [c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
- )
+ self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
def test_filter_relation_senders_and_type(self):
# Messages which second user reacted to.
@@ -181,8 +174,7 @@ class PaginationTestCase(HomeserverTestCase):
"related_by_rel_types": [RelationTypes.ANNOTATION],
}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0].event_id, self.event_id_1)
+ self.assertEqual(chunk, [self.event_id_1])
def test_duplicate_relation(self):
"""An event should only be returned once if there are multiple relations to it."""
@@ -201,5 +193,65 @@ class PaginationTestCase(HomeserverTestCase):
filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0].event_id, self.event_id_1)
+ self.assertEqual(chunk, [self.event_id_1])
+
+ def test_filter_rel_types(self) -> None:
+ # Messages which are annotations.
+ filter = {"org.matrix.msc3874.rel_types": [RelationTypes.ANNOTATION]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(chunk, [self.event_id_annotation])
+
+ # Messages which are references.
+ filter = {"org.matrix.msc3874.rel_types": [RelationTypes.REFERENCE]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(chunk, [self.event_id_reference])
+
+ # Messages which are either annotations or references.
+ filter = {
+ "org.matrix.msc3874.rel_types": [
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ ]
+ }
+ chunk = self._filter_messages(filter)
+ self.assertCountEqual(
+ chunk,
+ [self.event_id_annotation, self.event_id_reference],
+ )
+
+ def test_filter_not_rel_types(self) -> None:
+ # Messages which are not annotations.
+ filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.ANNOTATION]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(
+ chunk,
+ [
+ self.event_id_1,
+ self.event_id_2,
+ self.event_id_reference,
+ self.event_id_none,
+ ],
+ )
+
+ # Messages which are not references.
+ filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.REFERENCE]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(
+ chunk,
+ [
+ self.event_id_1,
+ self.event_id_annotation,
+ self.event_id_2,
+ self.event_id_none,
+ ],
+ )
+
+ # Messages which are neither annotations or references.
+ filter = {
+ "org.matrix.msc3874.not_rel_types": [
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ ]
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(chunk, [self.event_id_1, self.event_id_2, self.event_id_none])
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index e42d7b9ba0..f4d9fba0a1 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -821,7 +821,7 @@ def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase:
def _build_auth_dict_for_room_version(
room_version: RoomVersion, auth_events: Iterable[EventBase]
) -> List:
- if room_version.event_format == EventFormatVersions.V1:
+ if room_version.event_format == EventFormatVersions.ROOM_V1_V2:
return [(e.event_id, "not_used") for e in auth_events]
else:
return [e.event_id for e in auth_events]
@@ -871,7 +871,7 @@ event_count = 0
def _maybe_get_event_id_dict_for_room_version(room_version: RoomVersion) -> dict:
"""If this room version needs it, generate an event id"""
- if room_version.event_format != EventFormatVersions.V1:
+ if room_version.event_format != EventFormatVersions.ROOM_V1_V2:
return {}
global event_count
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 779fad1f63..80e5c590d8 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -86,8 +86,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
federation_event_handler._check_event_auth = _check_event_auth
self.client = self.homeserver.get_federation_client()
- self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
- pdus
+ self.client._check_sigs_and_hash_for_pulled_events_and_fetch = (
+ lambda dest, pdus, **k: succeed(pdus)
)
# Send the join, it should return None (which is not an error)
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index b01cae6e5d..cc1a98f1c4 100644
--- a/tests/test_phone_home.py
+++ b/tests/test_phone_home.py
@@ -15,8 +15,14 @@
import resource
from unittest import mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.app.phone_stats_home import phone_stats_home
+from synapse.rest import admin
+from synapse.rest.client import login, sync
+from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -47,5 +53,43 @@ class PhoneHomeStatsTestCase(HomeserverTestCase):
stats: JsonDict = {}
self.reactor.advance(1)
# `old_resource` has type `Mock` instead of `struct_rusage`
- self.get_success(phone_stats_home(self.hs, stats, past_stats)) # type: ignore[arg-type]
+ self.get_success(
+ phone_stats_home(self.hs, stats, past_stats) # type: ignore[arg-type]
+ )
self.assertApproximates(stats["cpu_average"], 100, tolerance=2.5)
+
+
+class CommonMetricsTestCase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.metrics_manager = hs.get_common_usage_metrics_manager()
+ self.get_success(self.metrics_manager.setup())
+
+ def test_dau(self) -> None:
+ """Tests that the daily active users count is correctly updated."""
+ self._assert_metric_value("daily_active_users", 0)
+
+ self.register_user("user", "password")
+ tok = self.login("user", "password")
+ self.make_request("GET", "/sync", access_token=tok)
+
+ self.pump(1)
+
+ self._assert_metric_value("daily_active_users", 1)
+
+ def _assert_metric_value(self, metric_name: str, expected: int) -> None:
+ """Compare the given value to the current value of the common usage metric with
+ the given name.
+
+ Args:
+ metric_name: The metric to look up.
+ expected: Expected value for this metric.
+ """
+ metrics = self.get_success(self.metrics_manager.get_metrics())
+ value = getattr(metrics, metric_name)
+ self.assertEqual(value, expected)
diff --git a/tests/test_rust.py b/tests/test_rust.py
new file mode 100644
index 0000000000..55d8b6b28c
--- /dev/null
+++ b/tests/test_rust.py
@@ -0,0 +1,11 @@
+from synapse.synapse_rust import sum_as_string
+
+from tests import unittest
+
+
+class RustTestCase(unittest.TestCase):
+ """Basic tests to ensure that we can call into Rust code."""
+
+ def test_basic(self):
+ result = sum_as_string(1, 2)
+ self.assertEqual("3", result)
diff --git a/tests/test_server.py b/tests/test_server.py
index 2fe4411401..2d9a0257d4 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -26,12 +26,12 @@ from synapse.http.server import (
DirectServeJsonResource,
JsonResource,
OptionsResource,
- cancellable,
)
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict
from synapse.util import Clock
+from synapse.util.cancellation import cancellable
from tests import unittest
from tests.http.server._base import test_disconnect
@@ -104,7 +104,7 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
)
- self.assertEqual(channel.result["code"], b"500")
+ self.assertEqual(channel.code, 500)
def test_callback_indirect_exception(self) -> None:
"""
@@ -130,7 +130,7 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
)
- self.assertEqual(channel.result["code"], b"500")
+ self.assertEqual(channel.code, 500)
def test_callback_synapseerror(self) -> None:
"""
@@ -150,7 +150,7 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
)
- self.assertEqual(channel.result["code"], b"403")
+ self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -174,7 +174,7 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar"
)
- self.assertEqual(channel.result["code"], b"400")
+ self.assertEqual(channel.code, 400)
self.assertEqual(channel.json_body["error"], "Unrecognized request")
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
@@ -203,7 +203,7 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/_matrix/foo"
)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
self.assertNotIn("body", channel.result)
@@ -222,13 +222,22 @@ class OptionsResourceTests(unittest.TestCase):
self.resource = OptionsResource()
self.resource.putChild(b"res", DummyResource())
- def _make_request(self, method: bytes, path: bytes) -> FakeChannel:
+ def _make_request(
+ self, method: bytes, path: bytes, experimental_cors_msc3886: bool = False
+ ) -> FakeChannel:
"""Create a request from the method/path and return a channel with the response."""
# Create a site and query for the resource.
site = SynapseSite(
"test",
"site_tag",
- parse_listener_def({"type": "http", "port": 0}),
+ parse_listener_def(
+ 0,
+ {
+ "type": "http",
+ "port": 0,
+ "experimental_cors_msc3886": experimental_cors_msc3886,
+ },
+ ),
self.resource,
"1.0",
max_request_body_size=4096,
@@ -239,55 +248,86 @@ class OptionsResourceTests(unittest.TestCase):
channel = make_request(self.reactor, site, method, path, shorthand=False)
return channel
+ def _check_cors_standard_headers(self, channel: FakeChannel) -> None:
+ # Ensure the correct CORS headers have been added
+ # as per https://spec.matrix.org/v1.4/client-server-api/#web-browser-clients
+ self.assertEqual(
+ channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"),
+ [b"*"],
+ "has correct CORS Origin header",
+ )
+ self.assertEqual(
+ channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"),
+ [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"], # HEAD isn't in the spec
+ "has correct CORS Methods header",
+ )
+ self.assertEqual(
+ channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"),
+ [b"X-Requested-With, Content-Type, Authorization, Date"],
+ "has correct CORS Headers header",
+ )
+
+ def _check_cors_msc3886_headers(self, channel: FakeChannel) -> None:
+ # Ensure the correct CORS headers have been added
+ # as per https://github.com/matrix-org/matrix-spec-proposals/blob/hughns/simple-rendezvous-capability/proposals/3886-simple-rendezvous-capability.md#cors
+ self.assertEqual(
+ channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"),
+ [b"*"],
+ "has correct CORS Origin header",
+ )
+ self.assertEqual(
+ channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"),
+ [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"], # HEAD isn't in the spec
+ "has correct CORS Methods header",
+ )
+ self.assertEqual(
+ channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"),
+ [
+ b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match"
+ ],
+ "has correct CORS Headers header",
+ )
+ self.assertEqual(
+ channel.headers.getRawHeaders(b"Access-Control-Expose-Headers"),
+ [b"ETag, Location, X-Max-Bytes"],
+ "has correct CORS Expose Headers header",
+ )
+
def test_unknown_options_request(self) -> None:
"""An OPTIONS requests to an unknown URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/foo/")
- self.assertEqual(channel.result["code"], b"204")
+ self.assertEqual(channel.code, 204)
self.assertNotIn("body", channel.result)
- # Ensure the correct CORS headers have been added
- self.assertTrue(
- channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
- "has CORS Origin header",
- )
- self.assertTrue(
- channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
- "has CORS Methods header",
- )
- self.assertTrue(
- channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
- "has CORS Headers header",
- )
+ self._check_cors_standard_headers(channel)
def test_known_options_request(self) -> None:
"""An OPTIONS requests to an known URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/res/")
- self.assertEqual(channel.result["code"], b"204")
+ self.assertEqual(channel.code, 204)
self.assertNotIn("body", channel.result)
- # Ensure the correct CORS headers have been added
- self.assertTrue(
- channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
- "has CORS Origin header",
- )
- self.assertTrue(
- channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
- "has CORS Methods header",
- )
- self.assertTrue(
- channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
- "has CORS Headers header",
+ self._check_cors_standard_headers(channel)
+
+ def test_known_options_request_msc3886(self) -> None:
+ """An OPTIONS requests to an known URL still returns 204 No Content."""
+ channel = self._make_request(
+ b"OPTIONS", b"/res/", experimental_cors_msc3886=True
)
+ self.assertEqual(channel.code, 204)
+ self.assertNotIn("body", channel.result)
+
+ self._check_cors_msc3886_headers(channel)
def test_unknown_request(self) -> None:
"""A non-OPTIONS request to an unknown URL should 404."""
channel = self._make_request(b"GET", b"/foo/")
- self.assertEqual(channel.result["code"], b"404")
+ self.assertEqual(channel.code, 404)
def test_known_request(self) -> None:
"""A non-OPTIONS request to an known URL should query the proper resource."""
channel = self._make_request(b"GET", b"/res/")
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"/res/")
@@ -314,7 +354,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
body = channel.result["body"]
self.assertEqual(body, b"response")
@@ -334,7 +374,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
)
- self.assertEqual(channel.result["code"], b"301")
+ self.assertEqual(channel.code, 301)
headers = channel.result["headers"]
location_headers = [v for k, v in headers if k == b"Location"]
self.assertEqual(location_headers, [b"/look/an/eagle"])
@@ -357,7 +397,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
)
- self.assertEqual(channel.result["code"], b"304")
+ self.assertEqual(channel.code, 304)
headers = channel.result["headers"]
location_headers = [v for k, v in headers if k == b"Location"]
self.assertEqual(location_headers, [b"/no/over/there"])
@@ -378,7 +418,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/path"
)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
self.assertNotIn("body", channel.result)
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index d3c13cf14c..abd7459a8c 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -53,7 +53,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
request_data = {"username": "kermit", "password": "monkey"}
channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
self.assertTrue(channel.json_body is not None)
self.assertIsInstance(channel.json_body["session"], str)
@@ -96,7 +96,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
# We don't bother checking that the response is correct - we'll leave that to
# other tests. We just want to make sure we're on the right path.
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
# Finish the UI auth for terms
request_data = {
@@ -112,7 +112,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
# We're interested in getting a response that looks like a successful
# registration, not so much that the details are exactly what we want.
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
self.assertTrue(channel.json_body is not None)
self.assertIsInstance(channel.json_body["user_id"], str)
diff --git a/tests/test_types.py b/tests/test_types.py
index d8d82a517e..1111169384 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -13,11 +13,35 @@
# limitations under the License.
from synapse.api.errors import SynapseError
-from synapse.types import RoomAlias, UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ RoomAlias,
+ UserID,
+ get_domain_from_id,
+ get_localpart_from_id,
+ map_username_to_mxid_localpart,
+)
from tests import unittest
+class IsMineIDTests(unittest.HomeserverTestCase):
+ def test_is_mine_id(self) -> None:
+ self.assertTrue(self.hs.is_mine_id("@user:test"))
+ self.assertTrue(self.hs.is_mine_id("#room:test"))
+ self.assertTrue(self.hs.is_mine_id("invalid:test"))
+
+ self.assertFalse(self.hs.is_mine_id("@user:test\0"))
+ self.assertFalse(self.hs.is_mine_id("@user"))
+
+ def test_two_colons(self) -> None:
+ """Test handling of IDs containing more than one colon."""
+ # The domain starts after the first colon.
+ # These functions must interpret things consistently.
+ self.assertFalse(self.hs.is_mine_id("@user:test:test"))
+ self.assertEqual("user", get_localpart_from_id("@user:test:test"))
+ self.assertEqual("test:test", get_domain_from_id("@user:test:test"))
+
+
class UserIDTestCase(unittest.HomeserverTestCase):
def test_parse(self):
user = UserID.from_string("@1234abcd:test")
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 0d0d6faf0d..e62ebcc6a5 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -15,17 +15,24 @@
"""
Utilities for running the unit tests
"""
+import json
import sys
import warnings
from asyncio import Future
from binascii import unhexlify
-from typing import Awaitable, Callable, TypeVar
+from typing import Awaitable, Callable, Tuple, TypeVar
from unittest.mock import Mock
import attr
+import zope.interface
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
+from twisted.web.http import RESPONSES
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
+
+from synapse.types import JsonDict
TV = TypeVar("TV")
@@ -97,27 +104,44 @@ def simple_async_mock(return_value=None, raises=None) -> Mock:
return Mock(side_effect=cb)
-@attr.s
-class FakeResponse:
+# Type ignore: it does not fully implement IResponse, but is good enough for tests
+@zope.interface.implementer(IResponse)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FakeResponse: # type: ignore[misc]
"""A fake twisted.web.IResponse object
there is a similar class at treq.test.test_response, but it lacks a `phrase`
attribute, and didn't support deliverBody until recently.
"""
- # HTTP response code
- code = attr.ib(type=int)
+ version: Tuple[bytes, int, int] = (b"HTTP", 1, 1)
- # HTTP response phrase (eg b'OK' for a 200)
- phrase = attr.ib(type=bytes)
+ # HTTP response code
+ code: int = 200
# body of the response
- body = attr.ib(type=bytes)
+ body: bytes = b""
+
+ headers: Headers = attr.Factory(Headers)
+
+ @property
+ def phrase(self):
+ return RESPONSES.get(self.code, b"Unknown Status")
+
+ @property
+ def length(self):
+ return len(self.body)
def deliverBody(self, protocol):
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))
+ @classmethod
+ def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse":
+ headers = Headers({"Content-Type": ["application/json"]})
+ body = json.dumps(payload).encode("utf-8")
+ return cls(code=code, body=body, headers=headers)
+
# A small image used in some tests.
#
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
new file mode 100644
index 0000000000..1461d23ee8
--- /dev/null
+++ b/tests/test_utils/oidc.py
@@ -0,0 +1,348 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+from typing import Any, Dict, List, Optional, Tuple
+from unittest.mock import Mock, patch
+from urllib.parse import parse_qs
+
+import attr
+
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
+
+from synapse.server import HomeServer
+from synapse.util import Clock
+from synapse.util.stringutils import random_string
+
+from tests.test_utils import FakeResponse
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FakeAuthorizationGrant:
+ userinfo: dict
+ client_id: str
+ redirect_uri: str
+ scope: str
+ nonce: Optional[str]
+ sid: Optional[str]
+
+
+class FakeOidcServer:
+ """A fake OpenID Connect Provider."""
+
+ # All methods here are mocks, so we can track when they are called, and override
+ # their values
+ request: Mock
+ get_jwks_handler: Mock
+ get_metadata_handler: Mock
+ get_userinfo_handler: Mock
+ post_token_handler: Mock
+
+ sid_counter: int = 0
+
+ def __init__(self, clock: Clock, issuer: str):
+ from authlib.jose import ECKey, KeySet
+
+ self._clock = clock
+ self.issuer = issuer
+
+ self.request = Mock(side_effect=self._request)
+ self.get_jwks_handler = Mock(side_effect=self._get_jwks_handler)
+ self.get_metadata_handler = Mock(side_effect=self._get_metadata_handler)
+ self.get_userinfo_handler = Mock(side_effect=self._get_userinfo_handler)
+ self.post_token_handler = Mock(side_effect=self._post_token_handler)
+
+ # A code -> grant mapping
+ self._authorization_grants: Dict[str, FakeAuthorizationGrant] = {}
+ # An access token -> grant mapping
+ self._sessions: Dict[str, FakeAuthorizationGrant] = {}
+
+ # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for
+ # signing JWTs. ECDSA keys are really quick to generate compared to RSA.
+ self._key = ECKey.generate_key(crv="P-256", is_private=True)
+ self._jwks = KeySet([ECKey.import_key(self._key.as_pem(is_private=False))])
+
+ self._id_token_overrides: Dict[str, Any] = {}
+
+ def reset_mocks(self):
+ self.request.reset_mock()
+ self.get_jwks_handler.reset_mock()
+ self.get_metadata_handler.reset_mock()
+ self.get_userinfo_handler.reset_mock()
+ self.post_token_handler.reset_mock()
+
+ def patch_homeserver(self, hs: HomeServer):
+ """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
+
+ This patch should be used whenever the HS is expected to perform request to the
+ OIDC provider, e.g.::
+
+ fake_oidc_server = self.helper.fake_oidc_server()
+ with fake_oidc_server.patch_homeserver(hs):
+ self.make_request("GET", "/_matrix/client/r0/login/sso/redirect")
+ """
+ return patch.object(hs.get_proxied_http_client(), "request", self.request)
+
+ @property
+ def authorization_endpoint(self) -> str:
+ return self.issuer + "authorize"
+
+ @property
+ def token_endpoint(self) -> str:
+ return self.issuer + "token"
+
+ @property
+ def userinfo_endpoint(self) -> str:
+ return self.issuer + "userinfo"
+
+ @property
+ def metadata_endpoint(self) -> str:
+ return self.issuer + ".well-known/openid-configuration"
+
+ @property
+ def jwks_uri(self) -> str:
+ return self.issuer + "jwks"
+
+ def get_metadata(self) -> dict:
+ return {
+ "issuer": self.issuer,
+ "authorization_endpoint": self.authorization_endpoint,
+ "token_endpoint": self.token_endpoint,
+ "jwks_uri": self.jwks_uri,
+ "userinfo_endpoint": self.userinfo_endpoint,
+ "response_types_supported": ["code"],
+ "subject_types_supported": ["public"],
+ "id_token_signing_alg_values_supported": ["ES256"],
+ }
+
+ def get_jwks(self) -> dict:
+ return self._jwks.as_dict()
+
+ def get_userinfo(self, access_token: str) -> Optional[dict]:
+ """Given an access token, get the userinfo of the associated session."""
+ session = self._sessions.get(access_token, None)
+ if session is None:
+ return None
+ return session.userinfo
+
+ def _sign(self, payload: dict) -> str:
+ from authlib.jose import JsonWebSignature
+
+ jws = JsonWebSignature()
+ kid = self.get_jwks()["keys"][0]["kid"]
+ protected = {"alg": "ES256", "kid": kid}
+ json_payload = json.dumps(payload)
+ return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8")
+
+ def generate_id_token(self, grant: FakeAuthorizationGrant) -> str:
+ now = int(self._clock.time())
+ id_token = {
+ **grant.userinfo,
+ "iss": self.issuer,
+ "aud": grant.client_id,
+ "iat": now,
+ "nbf": now,
+ "exp": now + 600,
+ }
+
+ if grant.nonce is not None:
+ id_token["nonce"] = grant.nonce
+
+ if grant.sid is not None:
+ id_token["sid"] = grant.sid
+
+ id_token.update(self._id_token_overrides)
+
+ return self._sign(id_token)
+
+ def generate_logout_token(self, grant: FakeAuthorizationGrant) -> str:
+ now = int(self._clock.time())
+ logout_token = {
+ "iss": self.issuer,
+ "aud": grant.client_id,
+ "iat": now,
+ "jti": random_string(10),
+ "events": {
+ "http://schemas.openid.net/event/backchannel-logout": {},
+ },
+ }
+
+ if grant.sid is not None:
+ logout_token["sid"] = grant.sid
+
+ if "sub" in grant.userinfo:
+ logout_token["sub"] = grant.userinfo["sub"]
+
+ return self._sign(logout_token)
+
+ def id_token_override(self, overrides: dict):
+ """Temporarily patch the ID token generated by the token endpoint."""
+ return patch.object(self, "_id_token_overrides", overrides)
+
+ def start_authorization(
+ self,
+ client_id: str,
+ scope: str,
+ redirect_uri: str,
+ userinfo: dict,
+ nonce: Optional[str] = None,
+ with_sid: bool = False,
+ ) -> Tuple[str, FakeAuthorizationGrant]:
+ """Start an authorization request, and get back the code to use on the authorization endpoint."""
+ code = random_string(10)
+ sid = None
+ if with_sid:
+ sid = str(self.sid_counter)
+ self.sid_counter += 1
+
+ grant = FakeAuthorizationGrant(
+ userinfo=userinfo,
+ scope=scope,
+ redirect_uri=redirect_uri,
+ nonce=nonce,
+ client_id=client_id,
+ sid=sid,
+ )
+ self._authorization_grants[code] = grant
+
+ return code, grant
+
+ def exchange_code(self, code: str) -> Optional[Dict[str, Any]]:
+ grant = self._authorization_grants.pop(code, None)
+ if grant is None:
+ return None
+
+ access_token = random_string(10)
+ self._sessions[access_token] = grant
+
+ token = {
+ "token_type": "Bearer",
+ "access_token": access_token,
+ "expires_in": 3600,
+ "scope": grant.scope,
+ }
+
+ if "openid" in grant.scope:
+ token["id_token"] = self.generate_id_token(grant)
+
+ return dict(token)
+
+ def buggy_endpoint(
+ self,
+ *,
+ jwks: bool = False,
+ metadata: bool = False,
+ token: bool = False,
+ userinfo: bool = False,
+ ):
+ """A context which makes a set of endpoints return a 500 error.
+
+ Args:
+ jwks: If True, makes the JWKS endpoint return a 500 error.
+ metadata: If True, makes the OIDC Discovery endpoint return a 500 error.
+ token: If True, makes the token endpoint return a 500 error.
+ userinfo: If True, makes the userinfo endpoint return a 500 error.
+ """
+ buggy = FakeResponse(code=500, body=b"Internal server error")
+
+ patches = {}
+ if jwks:
+ patches["get_jwks_handler"] = Mock(return_value=buggy)
+ if metadata:
+ patches["get_metadata_handler"] = Mock(return_value=buggy)
+ if token:
+ patches["post_token_handler"] = Mock(return_value=buggy)
+ if userinfo:
+ patches["get_userinfo_handler"] = Mock(return_value=buggy)
+
+ return patch.multiple(self, **patches)
+
+ async def _request(
+ self,
+ method: str,
+ uri: str,
+ data: Optional[bytes] = None,
+ headers: Optional[Headers] = None,
+ ) -> IResponse:
+ """The override of the SimpleHttpClient#request() method"""
+ access_token: Optional[str] = None
+
+ if headers is None:
+ headers = Headers()
+
+ # Try to find the access token in the headers if any
+ auth_headers = headers.getRawHeaders(b"Authorization")
+ if auth_headers:
+ parts = auth_headers[0].split(b" ")
+ if parts[0] == b"Bearer" and len(parts) == 2:
+ access_token = parts[1].decode("ascii")
+
+ if method == "POST":
+ # If the method is POST, assume it has an url-encoded body
+ if data is None or headers.getRawHeaders(b"Content-Type") != [
+ b"application/x-www-form-urlencoded"
+ ]:
+ return FakeResponse.json(code=400, payload={"error": "invalid_request"})
+
+ params = parse_qs(data.decode("utf-8"))
+
+ if uri == self.token_endpoint:
+ # Even though this endpoint should be protected, this does not check
+ # for client authentication. We're not checking it for simplicity,
+ # and because client authentication is tested in other standalone tests.
+ return self.post_token_handler(params)
+
+ elif method == "GET":
+ if uri == self.jwks_uri:
+ return self.get_jwks_handler()
+ elif uri == self.metadata_endpoint:
+ return self.get_metadata_handler()
+ elif uri == self.userinfo_endpoint:
+ return self.get_userinfo_handler(access_token=access_token)
+
+ return FakeResponse(code=404, body=b"404 not found")
+
+ # Request handlers
+ def _get_jwks_handler(self) -> IResponse:
+ """Handles requests to the JWKS URI."""
+ return FakeResponse.json(payload=self.get_jwks())
+
+ def _get_metadata_handler(self) -> IResponse:
+ """Handles requests to the OIDC well-known document."""
+ return FakeResponse.json(payload=self.get_metadata())
+
+ def _get_userinfo_handler(self, access_token: Optional[str]) -> IResponse:
+ """Handles requests to the userinfo endpoint."""
+ if access_token is None:
+ return FakeResponse(code=401)
+ user_info = self.get_userinfo(access_token)
+ if user_info is None:
+ return FakeResponse(code=401)
+
+ return FakeResponse.json(payload=user_info)
+
+ def _post_token_handler(self, params: Dict[str, List[str]]) -> IResponse:
+ """Handles requests to the token endpoint."""
+ code = params.get("code", [])
+
+ if len(code) != 1:
+ return FakeResponse.json(code=400, payload={"error": "invalid_request"})
+
+ grant = self.exchange_code(code=code[0])
+ if grant is None:
+ return FakeResponse.json(code=400, payload={"error": "invalid_grant"})
+
+ return FakeResponse.json(payload=grant)
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index c385b2f8d4..d0b9ad5454 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -61,7 +61,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
filtered = self.get_success(
filter_events_for_server(
- self._storage_controllers, "test_server", events_to_filter
+ self._storage_controllers, "test_server", "hs", events_to_filter
)
)
@@ -83,7 +83,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_server(
- self._storage_controllers, "remote_hs", [outlier]
+ self._storage_controllers, "remote_hs", "hs", [outlier]
)
),
[outlier],
@@ -94,7 +94,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
filtered = self.get_success(
filter_events_for_server(
- self._storage_controllers, "remote_hs", [outlier, evt]
+ self._storage_controllers, "remote_hs", "local_hs", [outlier, evt]
)
)
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
@@ -106,7 +106,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# be redacted)
filtered = self.get_success(
filter_events_for_server(
- self._storage_controllers, "other_server", [outlier, evt]
+ self._storage_controllers, "other_server", "local_hs", [outlier, evt]
)
)
self.assertEqual(filtered[0], outlier)
@@ -141,7 +141,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... and the filtering happens.
filtered = self.get_success(
filter_events_for_server(
- self._storage_controllers, "test_server", events_to_filter
+ self._storage_controllers, "test_server", "local_hs", events_to_filter
)
)
diff --git a/tests/unittest.py b/tests/unittest.py
index bec4a3d023..a120c2976c 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -300,47 +300,31 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"):
if self.hijack_auth:
assert self.helper.auth_user_id is not None
+ token = "some_fake_token"
# We need a valid token ID to satisfy foreign key constraints.
token_id = self.get_success(
self.hs.get_datastores().main.add_access_token_to_user(
self.helper.auth_user_id,
- "some_fake_token",
+ token,
None,
None,
)
)
- async def get_user_by_access_token(
- token: Optional[str] = None, allow_guest: bool = False
- ) -> JsonDict:
- assert self.helper.auth_user_id is not None
- return {
- "user": UserID.from_string(self.helper.auth_user_id),
- "token_id": token_id,
- "is_guest": False,
- }
-
- async def get_user_by_req(
- request: SynapseRequest,
- allow_guest: bool = False,
- allow_expired: bool = False,
- ) -> Requester:
+ # This has to be a function and not just a Mock, because
+ # `self.helper.auth_user_id` is temporarily reassigned in some tests
+ async def get_requester(*args, **kwargs) -> Requester:
assert self.helper.auth_user_id is not None
return create_requester(
- UserID.from_string(self.helper.auth_user_id),
- token_id,
- False,
- False,
- None,
+ user_id=UserID.from_string(self.helper.auth_user_id),
+ access_token_id=token_id,
)
# Type ignore: mypy doesn't like us assigning to methods.
- self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment]
- self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment]
- self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment]
- return_value="1234"
- )
+ self.hs.get_auth().get_user_by_req = get_requester # type: ignore[assignment]
+ self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[assignment]
+ self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[assignment]
if self.needs_threadpool:
self.reactor.threadpool = ThreadPool() # type: ignore[assignment]
@@ -376,13 +360,13 @@ class HomeserverTestCase(TestCase):
store.db_pool.updates.do_next_background_update(False), by=0.1
)
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock):
"""
Make and return a homeserver.
Args:
reactor: A Twisted Reactor, or something that pretends to be one.
- clock (synapse.util.Clock): The Clock, associated with the reactor.
+ clock: The Clock, associated with the reactor.
Returns:
A homeserver suitable for testing.
@@ -442,9 +426,8 @@ class HomeserverTestCase(TestCase):
Args:
reactor: A Twisted Reactor, or something that pretends to be one.
- clock (synapse.util.Clock): The Clock, associated with the reactor.
- homeserver (synapse.server.HomeServer): The HomeServer to test
- against.
+ clock: The Clock, associated with the reactor.
+ homeserver: The HomeServer to test against.
Function to optionally be overridden in subclasses.
"""
@@ -468,11 +451,10 @@ class HomeserverTestCase(TestCase):
given content.
Args:
- method (bytes/unicode): The HTTP request method ("verb").
- path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
- escaped UTF-8 & spaces and such).
- content (bytes or dict): The body of the request. JSON-encoded, if
- a dict.
+ method: The HTTP request method ("verb").
+ path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces
+ and such). content (bytes or dict): The body of the request.
+ JSON-encoded, if a dict.
shorthand: Whether to try and be helpful and prefix the given URL
with the usual REST API path, if it doesn't contain it.
federation_auth_origin: if set to not-None, we will add a fake
@@ -677,14 +659,29 @@ class HomeserverTestCase(TestCase):
username: str,
password: str,
device_id: Optional[str] = None,
+ additional_request_fields: Optional[Dict[str, str]] = None,
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
) -> str:
"""
Log in a user, and get an access token. Requires the Login API be registered.
+
+ Args:
+ username: The localpart to assign to the new user.
+ password: The password to assign to the new user.
+ device_id: An optional device ID to assign to the new device created during
+ login.
+ additional_request_fields: A dictionary containing any additional /login
+ request fields and their values.
+ custom_headers: Custom HTTP headers and values to add to the /login request.
+
+ Returns:
+ The newly registered user's Matrix ID.
"""
body = {"type": "m.login.password", "user": username, "password": password}
if device_id:
body["device_id"] = device_id
+ if additional_request_fields:
+ body.update(additional_request_fields)
channel = self.make_request(
"POST",
@@ -735,7 +732,9 @@ class HomeserverTestCase(TestCase):
event.internal_metadata.soft_failed = True
self.get_success(
- event_creator.handle_new_client_event(requester, event, context)
+ event_creator.handle_new_client_event(
+ requester, events_and_context=[(event, context)]
+ )
)
return event.event_id
diff --git a/tests/util/caches/test_cached_call.py b/tests/util/caches/test_cached_call.py
index 80b97167ba..9266f12590 100644
--- a/tests/util/caches/test_cached_call.py
+++ b/tests/util/caches/test_cached_call.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import NoReturn
from unittest.mock import Mock
from twisted.internet import defer
@@ -23,14 +24,14 @@ from tests.unittest import TestCase
class CachedCallTestCase(TestCase):
- def test_get(self):
+ def test_get(self) -> None:
"""
Happy-path test case: makes a couple of calls and makes sure they behave
correctly
"""
- d = Deferred()
+ d: "Deferred[int]" = Deferred()
- async def f():
+ async def f() -> int:
return await d
slow_call = Mock(side_effect=f)
@@ -43,7 +44,7 @@ class CachedCallTestCase(TestCase):
# now fire off a couple of calls
completed_results = []
- async def r():
+ async def r() -> None:
res = await cached_call.get()
completed_results.append(res)
@@ -69,12 +70,12 @@ class CachedCallTestCase(TestCase):
self.assertEqual(r3, 123)
slow_call.assert_not_called()
- def test_fast_call(self):
+ def test_fast_call(self) -> None:
"""
Test the behaviour when the underlying function completes immediately
"""
- async def f():
+ async def f() -> int:
return 12
fast_call = Mock(side_effect=f)
@@ -92,12 +93,12 @@ class CachedCallTestCase(TestCase):
class RetryOnExceptionCachedCallTestCase(TestCase):
- def test_get(self):
+ def test_get(self) -> None:
# set up the RetryOnExceptionCachedCall around a function which will fail
# (after a while)
- d = Deferred()
+ d: "Deferred[int]" = Deferred()
- async def f1():
+ async def f1() -> NoReturn:
await d
raise ValueError("moo")
@@ -110,7 +111,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase):
# now fire off a couple of calls
completed_results = []
- async def r():
+ async def r() -> None:
try:
await cached_call.get()
except Exception as e1:
@@ -137,7 +138,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase):
# to the getter
d = Deferred()
- async def f2():
+ async def f2() -> int:
return await d
slow_call.reset_mock()
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index 02b99b466a..f74d82b1dc 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial
+from typing import List, Tuple
from twisted.internet import defer
@@ -22,20 +23,20 @@ from tests.unittest import TestCase
class DeferredCacheTestCase(TestCase):
- def test_empty(self):
- cache = DeferredCache("test")
+ def test_empty(self) -> None:
+ cache: DeferredCache[str, int] = DeferredCache("test")
with self.assertRaises(KeyError):
cache.get("foo")
- def test_hit(self):
- cache = DeferredCache("test")
+ def test_hit(self) -> None:
+ cache: DeferredCache[str, int] = DeferredCache("test")
cache.prefill("foo", 123)
self.assertEqual(self.successResultOf(cache.get("foo")), 123)
- def test_hit_deferred(self):
- cache = DeferredCache("test")
- origin_d = defer.Deferred()
+ def test_hit_deferred(self) -> None:
+ cache: DeferredCache[str, int] = DeferredCache("test")
+ origin_d: "defer.Deferred[int]" = defer.Deferred()
set_d = cache.set("k1", origin_d)
# get should return an incomplete deferred
@@ -43,7 +44,7 @@ class DeferredCacheTestCase(TestCase):
self.assertFalse(get_d.called)
# add a callback that will make sure that the set_d gets called before the get_d
- def check1(r):
+ def check1(r: str) -> str:
self.assertTrue(set_d.called)
return r
@@ -55,16 +56,16 @@ class DeferredCacheTestCase(TestCase):
self.assertEqual(self.successResultOf(set_d), 99)
self.assertEqual(self.successResultOf(get_d), 99)
- def test_callbacks(self):
+ def test_callbacks(self) -> None:
"""Invalidation callbacks are called at the right time"""
- cache = DeferredCache("test")
+ cache: DeferredCache[str, int] = DeferredCache("test")
callbacks = set()
# start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
# now replace that entry with a pending result
- origin_d = defer.Deferred()
+ origin_d: "defer.Deferred[int]" = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
# ... and also make a get request
@@ -89,15 +90,15 @@ class DeferredCacheTestCase(TestCase):
cache.prefill("k1", 30)
self.assertEqual(callbacks, {"set", "get"})
- def test_set_fail(self):
- cache = DeferredCache("test")
+ def test_set_fail(self) -> None:
+ cache: DeferredCache[str, int] = DeferredCache("test")
callbacks = set()
# start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
# now replace that entry with a pending result
- origin_d = defer.Deferred()
+ origin_d: defer.Deferred = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
# ... and also make a get request
@@ -126,9 +127,9 @@ class DeferredCacheTestCase(TestCase):
cache.prefill("k1", 30)
self.assertEqual(callbacks, {"prefill", "get2"})
- def test_get_immediate(self):
- cache = DeferredCache("test")
- d1 = defer.Deferred()
+ def test_get_immediate(self) -> None:
+ cache: DeferredCache[str, int] = DeferredCache("test")
+ d1: "defer.Deferred[int]" = defer.Deferred()
cache.set("key1", d1)
# get_immediate should return default
@@ -142,27 +143,27 @@ class DeferredCacheTestCase(TestCase):
v = cache.get_immediate("key1", 1)
self.assertEqual(v, 2)
- def test_invalidate(self):
- cache = DeferredCache("test")
+ def test_invalidate(self) -> None:
+ cache: DeferredCache[Tuple[str], int] = DeferredCache("test")
cache.prefill(("foo",), 123)
cache.invalidate(("foo",))
with self.assertRaises(KeyError):
cache.get(("foo",))
- def test_invalidate_all(self):
- cache = DeferredCache("testcache")
+ def test_invalidate_all(self) -> None:
+ cache: DeferredCache[str, str] = DeferredCache("testcache")
callback_record = [False, False]
- def record_callback(idx):
+ def record_callback(idx: int) -> None:
callback_record[idx] = True
# add a couple of pending entries
- d1 = defer.Deferred()
+ d1: "defer.Deferred[str]" = defer.Deferred()
cache.set("key1", d1, partial(record_callback, 0))
- d2 = defer.Deferred()
+ d2: "defer.Deferred[str]" = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return pending deferreds
@@ -193,8 +194,8 @@ class DeferredCacheTestCase(TestCase):
with self.assertRaises(KeyError):
cache.get("key1", None)
- def test_eviction(self):
- cache = DeferredCache(
+ def test_eviction(self) -> None:
+ cache: DeferredCache[int, str] = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)
@@ -208,8 +209,8 @@ class DeferredCacheTestCase(TestCase):
cache.get(2)
cache.get(3)
- def test_eviction_lru(self):
- cache = DeferredCache(
+ def test_eviction_lru(self) -> None:
+ cache: DeferredCache[int, str] = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)
@@ -227,8 +228,8 @@ class DeferredCacheTestCase(TestCase):
cache.get(1)
cache.get(3)
- def test_eviction_iterable(self):
- cache = DeferredCache(
+ def test_eviction_iterable(self) -> None:
+ cache: DeferredCache[int, List[str]] = DeferredCache(
"test",
max_entries=3,
apply_cache_factor_from_config=False,
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 48e616ac74..13f1edd533 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Set
+from typing import Iterable, Set, Tuple, cast
from unittest import mock
from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError, Deferred
+from twisted.internet.interfaces import IReactorTime
from synapse.api.errors import SynapseError
from synapse.logging.context import (
@@ -28,7 +29,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached, cachedList, lru_cache
+from synapse.util.caches.descriptors import cached, cachedList
from tests import unittest
from tests.test_utils import get_awaitable_result
@@ -36,41 +37,9 @@ from tests.test_utils import get_awaitable_result
logger = logging.getLogger(__name__)
-class LruCacheDecoratorTestCase(unittest.TestCase):
- def test_base(self):
- class Cls:
- def __init__(self):
- self.mock = mock.Mock()
-
- @lru_cache()
- def fn(self, arg1, arg2):
- return self.mock(arg1, arg2)
-
- obj = Cls()
- obj.mock.return_value = "fish"
- r = obj.fn(1, 2)
- self.assertEqual(r, "fish")
- obj.mock.assert_called_once_with(1, 2)
- obj.mock.reset_mock()
-
- # a call with different params should call the mock again
- obj.mock.return_value = "chips"
- r = obj.fn(1, 3)
- self.assertEqual(r, "chips")
- obj.mock.assert_called_once_with(1, 3)
- obj.mock.reset_mock()
-
- # the two values should now be cached
- r = obj.fn(1, 2)
- self.assertEqual(r, "fish")
- r = obj.fn(1, 3)
- self.assertEqual(r, "chips")
- obj.mock.assert_not_called()
-
-
def run_on_reactor():
- d = defer.Deferred()
- reactor.callLater(0, d.callback, 0)
+ d: "Deferred[int]" = defer.Deferred()
+ cast(IReactorTime, reactor).callLater(0, d.callback, 0)
return make_deferred_yieldable(d)
@@ -256,7 +225,8 @@ class DescriptorTestCase(unittest.TestCase):
callbacks: Set[str] = set()
# set off an asynchronous request
- obj.result = origin_d = defer.Deferred()
+ origin_d: Deferred = defer.Deferred()
+ obj.result = origin_d
d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
self.assertFalse(d1.called)
@@ -294,7 +264,7 @@ class DescriptorTestCase(unittest.TestCase):
"""Check that logcontexts are set and restored correctly when
using the cache."""
- complete_lookup = defer.Deferred()
+ complete_lookup: Deferred = defer.Deferred()
class Cls:
@descriptors.cached()
@@ -478,10 +448,10 @@ class DescriptorTestCase(unittest.TestCase):
@cached(cache_context=True)
async def func2(self, key, cache_context):
- return self.func3(key, on_invalidate=cache_context.invalidate)
+ return await self.func3(key, on_invalidate=cache_context.invalidate)
- @lru_cache(cache_context=True)
- def func3(self, key, cache_context):
+ @cached(cache_context=True)
+ async def func3(self, key, cache_context):
self.invalidate = cache_context.invalidate
return 42
@@ -804,10 +774,14 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2):
- assert current_context().name == "c1"
+ context = current_context()
+ assert isinstance(context, LoggingContext)
+ assert context.name == "c1"
# we want this to behave like an asynchronous function
await run_on_reactor()
- assert current_context().name == "c1"
+ context = current_context()
+ assert isinstance(context, LoggingContext)
+ assert context.name == "c1"
return self.mock(args1, arg2)
with LoggingContext("c1") as c1:
@@ -866,7 +840,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
return self.mock(args1)
obj = Cls()
- deferred_result = Deferred()
+ deferred_result: "Deferred[dict]" = Deferred()
obj.mock.return_value = deferred_result
# start off several concurrent lookups of the same key
@@ -1008,3 +982,34 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj.inner_context_was_finished, "Tried to restart a finished logcontext"
)
self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ def test_num_args_mismatch(self):
+ """
+ Make sure someone does not accidentally use @cachedList on a method with
+ a mismatch in the number args to the underlying single cache method.
+ """
+
+ class Cls:
+ @descriptors.cached(tree=True)
+ def fn(self, room_id, event_id):
+ pass
+
+ # This is wrong ❌. `@cachedList` expects to be given the same number
+ # of arguments as the underlying cached function, just with one of
+ # the arguments being an iterable
+ @descriptors.cachedList(cached_method_name="fn", list_name="keys")
+ def list_fn(self, keys: Iterable[Tuple[str, str]]):
+ pass
+
+ # Corrected syntax ✅
+ #
+ # @cachedList(cached_method_name="fn", list_name="event_ids")
+ # async def list_fn(
+ # self, room_id: str, event_ids: Collection[str],
+ # )
+
+ obj = Cls()
+
+ # Make sure this raises an error about the arg mismatch
+ with self.assertRaises(TypeError):
+ obj.list_fn([("foo", "bar")])
diff --git a/tests/util/caches/test_response_cache.py b/tests/util/caches/test_response_cache.py
index 025b73e32f..f09eeecada 100644
--- a/tests/util/caches/test_response_cache.py
+++ b/tests/util/caches/test_response_cache.py
@@ -35,7 +35,7 @@ class ResponseCacheTestCase(TestCase):
(These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock)
"""
- def setUp(self):
+ def setUp(self) -> None:
self.reactor, self.clock = get_clock()
def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
@@ -49,7 +49,7 @@ class ResponseCacheTestCase(TestCase):
await self.clock.sleep(1)
return o
- def test_cache_hit(self):
+ def test_cache_hit(self) -> None:
cache = self.with_cache("keeping_cache", ms=9001)
expected_result = "howdy"
@@ -74,7 +74,7 @@ class ResponseCacheTestCase(TestCase):
"cache should still have the result",
)
- def test_cache_miss(self):
+ def test_cache_miss(self) -> None:
cache = self.with_cache("trashing_cache", ms=0)
expected_result = "howdy"
@@ -90,7 +90,7 @@ class ResponseCacheTestCase(TestCase):
)
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
- def test_cache_expire(self):
+ def test_cache_expire(self) -> None:
cache = self.with_cache("short_cache", ms=1000)
expected_result = "howdy"
@@ -115,7 +115,7 @@ class ResponseCacheTestCase(TestCase):
self.reactor.pump((2,))
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
- def test_cache_wait_hit(self):
+ def test_cache_wait_hit(self) -> None:
cache = self.with_cache("neutral_cache")
expected_result = "howdy"
@@ -131,7 +131,7 @@ class ResponseCacheTestCase(TestCase):
self.assertEqual(expected_result, self.successResultOf(wrap_d))
- def test_cache_wait_expire(self):
+ def test_cache_wait_expire(self) -> None:
cache = self.with_cache("medium_cache", ms=3000)
expected_result = "howdy"
@@ -162,7 +162,7 @@ class ResponseCacheTestCase(TestCase):
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
@parameterized.expand([(True,), (False,)])
- def test_cache_context_nocache(self, should_cache: bool):
+ def test_cache_context_nocache(self, should_cache: bool) -> None:
"""If the callback clears the should_cache bit, the result should not be cached"""
cache = self.with_cache("medium_cache", ms=3000)
@@ -170,7 +170,7 @@ class ResponseCacheTestCase(TestCase):
call_count = 0
- async def non_caching(o: str, cache_context: ResponseCacheContext[int]):
+ async def non_caching(o: str, cache_context: ResponseCacheContext[int]) -> str:
nonlocal call_count
call_count += 1
await self.clock.sleep(1)
diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py
index fe8314057d..679d1eb36b 100644
--- a/tests/util/caches/test_ttlcache.py
+++ b/tests/util/caches/test_ttlcache.py
@@ -20,11 +20,11 @@ from tests import unittest
class CacheTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.mock_timer = Mock(side_effect=lambda: 100.0)
- self.cache = TTLCache("test_cache", self.mock_timer)
+ self.cache: TTLCache[str, str] = TTLCache("test_cache", self.mock_timer)
- def test_get(self):
+ def test_get(self) -> None:
"""simple set/get tests"""
self.cache.set("one", "1", 10)
self.cache.set("two", "2", 20)
@@ -59,7 +59,7 @@ class CacheTestCase(unittest.TestCase):
self.assertEqual(self.cache._metrics.hits, 4)
self.assertEqual(self.cache._metrics.misses, 5)
- def test_expiry(self):
+ def test_expiry(self) -> None:
self.cache.set("one", "1", 10)
self.cache.set("two", "2", 20)
self.cache.set("three", "3", 30)
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index 5d1aa025d1..6913de24b9 100644
--- a/tests/util/test_check_dependencies.py
+++ b/tests/util/test_check_dependencies.py
@@ -40,7 +40,10 @@ class TestDependencyChecker(TestCase):
def mock_installed_package(
self, distribution: Optional[DummyDistribution]
) -> Generator[None, None, None]:
- """Pretend that looking up any distribution yields the given `distribution`."""
+ """Pretend that looking up any package yields the given `distribution`.
+
+ If `distribution = None`, we pretend that the package is not installed.
+ """
def mock_distribution(name: str):
if distribution is None:
@@ -81,7 +84,7 @@ class TestDependencyChecker(TestCase):
self.assertRaises(DependencyException, check_requirements)
def test_checks_ignore_dev_dependencies(self) -> None:
- """Bot generic and per-extra checks should ignore dev dependencies."""
+ """Both generic and per-extra checks should ignore dev dependencies."""
with patch(
"synapse.util.check_dependencies.metadata.requires",
return_value=["dummypkg >= 1; extra == 'mypy'"],
@@ -142,3 +145,16 @@ class TestDependencyChecker(TestCase):
with self.mock_installed_package(new_release_candidate):
# should not raise
check_requirements()
+
+ def test_setuptools_rust_ignored(self) -> None:
+ """Test a workaround for a `poetry build` problem. Reproduces #13926."""
+ with patch(
+ "synapse.util.check_dependencies.metadata.requires",
+ return_value=["setuptools_rust >= 1.3"],
+ ):
+ with self.mock_installed_package(None):
+ # should not raise, even if setuptools_rust is not installed
+ check_requirements()
+ with self.mock_installed_package(old):
+ # We also ignore old versions of setuptools_rust
+ check_requirements()
diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py
index 32125f7bb7..40754a4711 100644
--- a/tests/util/test_macaroons.py
+++ b/tests/util/test_macaroons.py
@@ -84,34 +84,6 @@ class MacaroonGeneratorTestCase(TestCase):
)
self.assertEqual(user_id, "@user:tesths")
- def test_short_term_login_token(self):
- """Test the generation and verification of short-term login tokens"""
- token = self.macaroon_generator.generate_short_term_login_token(
- user_id="@user:tesths",
- auth_provider_id="oidc",
- auth_provider_session_id="sid",
- duration_in_ms=2 * 60 * 1000,
- )
-
- info = self.macaroon_generator.verify_short_term_login_token(token)
- self.assertEqual(info.user_id, "@user:tesths")
- self.assertEqual(info.auth_provider_id, "oidc")
- self.assertEqual(info.auth_provider_session_id, "sid")
-
- # Raises with another secret key
- with self.assertRaises(MacaroonVerificationFailedException):
- self.other_macaroon_generator.verify_short_term_login_token(token)
-
- # Wait a minute
- self.reactor.pump([60])
- # Shouldn't raise
- self.macaroon_generator.verify_short_term_login_token(token)
- # Wait another minute
- self.reactor.pump([60])
- # Should raise since it expired
- with self.assertRaises(MacaroonVerificationFailedException):
- self.macaroon_generator.verify_short_term_login_token(token)
-
def test_oidc_session_token(self):
"""Test the generation and verification of OIDC session cookies"""
state = "arandomstate"
diff --git a/tests/utils.py b/tests/utils.py
index d2c6d1e852..045a8b5fa7 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -135,7 +135,6 @@ def default_config(
"enable_registration_captcha": False,
"macaroon_secret_key": "not even a little secret",
"password_providers": [],
- "worker_replication_url": "",
"worker_app": None,
"block_non_admin_invites": False,
"federation_domain_whitelist": None,
@@ -271,9 +270,7 @@ class MockClock:
*args: P.args,
**kwargs: P.kwargs,
) -> None:
- # This type-ignore should be redundant once we use a mypy release with
- # https://github.com/python/mypy/pull/12668.
- self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) # type: ignore[arg-type]
+ self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs))
def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None:
if timer.expired:
|