diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 27946febff..e00b5c171c 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -89,6 +89,17 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
+ def attempt_wrong_password_login(self, username: str, password: str) -> None:
+ """Attempts to login as the user with the given password, asserting
+ that the attempt *fails*.
+ """
+ body = {"type": "m.login.password", "user": username, "password": password}
+
+ channel = self.make_request(
+ "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+
def test_basic_password_reset(self) -> None:
"""Test basic password reset flow"""
old_password = "monkey"
diff --git a/tests/rest/client/test_account_data.py b/tests/rest/client/test_account_data.py
new file mode 100644
index 0000000000..d5b0640e7a
--- /dev/null
+++ b/tests/rest/client/test_account_data.py
@@ -0,0 +1,75 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock
+
+from synapse.rest import admin
+from synapse.rest.client import account_data, login, room
+
+from tests import unittest
+from tests.test_utils import make_awaitable
+
+
+class AccountDataTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ account_data.register_servlets,
+ ]
+
+ def test_on_account_data_updated_callback(self) -> None:
+ """Tests that the on_account_data_updated module callback is called correctly when
+ a user's account data changes.
+ """
+ mocked_callback = Mock(return_value=make_awaitable(None))
+ self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append(
+ mocked_callback
+ )
+
+ user_id = self.register_user("user", "password")
+ tok = self.login("user", "password")
+ account_data_type = "org.matrix.foo"
+ account_data_content = {"bar": "baz"}
+
+ # Change the user's global account data.
+ channel = self.make_request(
+ "PUT",
+ f"/user/{user_id}/account_data/{account_data_type}",
+ account_data_content,
+ access_token=tok,
+ )
+
+ # Test that the callback is called with the user ID, the new account data, and
+ # None as the room ID.
+ self.assertEqual(channel.code, 200, channel.result)
+ mocked_callback.assert_called_once_with(
+ user_id, None, account_data_type, account_data_content
+ )
+
+ # Change the user's room-specific account data.
+ room_id = self.helper.create_room_as(user_id, tok=tok)
+ channel = self.make_request(
+ "PUT",
+ f"/user/{user_id}/rooms/{room_id}/account_data/{account_data_type}",
+ account_data_content,
+ access_token=tok,
+ )
+
+ # Test that the callback is called with the user ID, the room ID and the new
+ # account data.
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(mocked_callback.call_count, 2)
+ mocked_callback.assert_called_with(
+ user_id, room_id, account_data_type, account_data_content
+ )
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index fe97a0b3dd..419eef166a 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import urllib.parse
from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch
@@ -145,16 +144,6 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
return channel.json_body["unsigned"].get("m.relations", {})
- def _get_aggregations(self) -> List[JsonDict]:
- """Request /aggregations on the parent ID and includes the returned chunk."""
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- return channel.json_body["chunk"]
-
def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
"""
Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
@@ -264,43 +253,6 @@ class RelationsTestCase(BaseRelationsTestCase):
expected_response_code=400,
)
- def test_aggregation(self) -> None:
- """Test that annotations get correctly aggregated."""
-
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
- )
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(
- channel.json_body,
- {
- "chunk": [
- {"type": "m.reaction", "key": "a", "count": 2},
- {"type": "m.reaction", "key": "b", "count": 1},
- ]
- },
- )
-
- def test_aggregation_must_be_annotation(self) -> None:
- """Test that aggregations must be annotations."""
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations"
- f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(400, channel.code, channel.json_body)
-
def test_ignore_invalid_room(self) -> None:
"""Test that we ignore invalid relations over federation."""
# Create another room and send a message in it.
@@ -394,15 +346,6 @@ class RelationsTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
- # And when fetching aggregations.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(channel.json_body["chunk"], [])
-
# And for bundled aggregations.
channel = self.make_request(
"GET",
@@ -717,15 +660,6 @@ class RelationsTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
self.assertNotIn("m.relations", channel.json_body["unsigned"])
- # But unknown relations can be directly queried.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(channel.json_body["chunk"], [])
-
def test_background_update(self) -> None:
"""Test the event_arbitrary_relations background update."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
@@ -941,131 +875,6 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
)
- def test_aggregation_pagination_groups(self) -> None:
- """Test that we can paginate annotation groups correctly."""
-
- # We need to create ten separate users to send each reaction.
- access_tokens = [self.user_token, self.user2_token]
- idx = 0
- while len(access_tokens) < 10:
- user_id, token = self._create_user("test" + str(idx))
- idx += 1
-
- self.helper.join(self.room, user=user_id, tok=token)
- access_tokens.append(token)
-
- idx = 0
- sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
- for key in itertools.chain.from_iterable(
- itertools.repeat(key, num) for key, num in sent_groups.items()
- ):
- self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key=key,
- access_token=access_tokens[idx],
- )
-
- idx += 1
- idx %= len(access_tokens)
-
- prev_token: Optional[str] = None
- found_groups: Dict[str, int] = {}
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- for groups in channel.json_body["chunk"]:
- # We only expect reactions
- self.assertEqual(groups["type"], "m.reaction", channel.json_body)
-
- # We should only see each key once
- self.assertNotIn(groups["key"], found_groups, channel.json_body)
-
- found_groups[groups["key"]] = groups["count"]
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- self.assertEqual(sent_groups, found_groups)
-
- def test_aggregation_pagination_within_group(self) -> None:
- """Test that we can paginate within an annotation group."""
-
- # We need to create ten separate users to send each reaction.
- access_tokens = [self.user_token, self.user2_token]
- idx = 0
- while len(access_tokens) < 10:
- user_id, token = self._create_user("test" + str(idx))
- idx += 1
-
- self.helper.join(self.room, user=user_id, tok=token)
- access_tokens.append(token)
-
- idx = 0
- expected_event_ids = []
- for _ in range(10):
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key="👍",
- access_token=access_tokens[idx],
- )
- expected_event_ids.append(channel.json_body["event_id"])
-
- idx += 1
-
- # Also send a different type of reaction so that we test we don't see it
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
-
- prev_token = ""
- found_event_ids: List[str] = []
- encoded_key = urllib.parse.quote_plus("👍".encode())
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}"
- f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
- f"/m.reaction/{encoded_key}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
class BundledAggregationsTestCase(BaseRelationsTestCase):
"""
@@ -1453,10 +1262,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
{"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
)
- # Both relations appear in the aggregation.
- chunk = self._get_aggregations()
- self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}])
-
# Redact one of the reactions.
self._redact(to_redact_event_id)
@@ -1469,10 +1274,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
{"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
)
- # The unredacted aggregation should still exist.
- chunk = self._get_aggregations()
- self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}])
-
def test_redact_relation_thread(self) -> None:
"""
Test that thread replies are properly handled after the thread reply redacted.
@@ -1578,10 +1379,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self.assertEqual(len(event_ids), 1)
self.assertIn(RelationTypes.ANNOTATION, relations)
- # The aggregation should exist.
- chunk = self._get_aggregations()
- self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}])
-
# Redact the original event.
self._redact(self.parent_id)
@@ -1594,10 +1391,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
{"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
)
- # There's nothing to aggregate.
- chunk = self._get_aggregations()
- self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}])
-
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_redact_parent_thread(self) -> None:
"""
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 3a9617d6da..6ff79b9e2e 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -982,7 +982,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
super().prepare(reactor, clock, hs)
# profile changes expect that the user is actually registered
user = UserID.from_string(self.user_id)
- self.get_success(self.register_user(user.localpart, "supersecretpassword"))
+ self.register_user(user.localpart, "supersecretpassword")
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 4351013952..773c16a54c 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -341,7 +341,6 @@ class SyncKnockTestCase(
hs, self.room_id, self.user_id
)
- @override_config({"experimental_features": {"msc2403_enabled": True}})
def test_knock_room_state(self) -> None:
"""Tests that /sync returns state from a room after knocking on it."""
# Knock on a room
@@ -497,6 +496,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
receipts.register_servlets,
]
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = {"msc2654_enabled": True}
+ return config
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.url = "/sync?since=%s"
self.next_batch = "s0"
@@ -772,3 +776,65 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
self.assertIn(
self.user_id, device_list_changes, incremental_sync_channel.json_body
)
+
+
+class ExcludeRoomTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ self.user_id = self.register_user("user", "password")
+ self.tok = self.login("user", "password")
+
+ self.excluded_room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+ self.included_room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ # We need to manually append the room ID, because we can't know the ID before
+ # creating the room, and we can't set the config after starting the homeserver.
+ self.hs.get_sync_handler().rooms_to_exclude.append(self.excluded_room_id)
+
+ def test_join_leave(self) -> None:
+ """Tests that rooms are correctly excluded from the 'join' and 'leave' sections of
+ sync responses.
+ """
+ channel = self.make_request("GET", "/sync", access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"])
+ self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"])
+
+ self.helper.leave(self.excluded_room_id, self.user_id, tok=self.tok)
+ self.helper.leave(self.included_room_id, self.user_id, tok=self.tok)
+
+ channel = self.make_request(
+ "GET",
+ "/sync?since=" + channel.json_body["next_batch"],
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["leave"])
+ self.assertIn(self.included_room_id, channel.json_body["rooms"]["leave"])
+
+ def test_invite(self) -> None:
+ """Tests that rooms are correctly excluded from the 'invite' section of sync
+ responses.
+ """
+ invitee = self.register_user("invitee", "password")
+ invitee_tok = self.login("invitee", "password")
+
+ self.helper.invite(self.excluded_room_id, self.user_id, invitee, tok=self.tok)
+ self.helper.invite(self.included_room_id, self.user_id, invitee, tok=self.tok)
+
+ channel = self.make_request("GET", "/sync", access_token=invitee_tok)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"])
+ self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"])
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index e7de67e3a3..5eb0f243f7 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -896,3 +896,44 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Check that the mock was called with the right room ID
self.assertEqual(args[1], self.room_id)
+
+ def test_on_threepid_bind(self) -> None:
+ """Tests that the on_threepid_bind module callback is called correctly after
+ associating a 3PID to an account.
+ """
+ # Register a mocked callback.
+ threepid_bind_mock = Mock(return_value=make_awaitable(None))
+ third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock)
+
+ # Register an admin user.
+ self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Also register a normal user we can modify.
+ user_id = self.register_user("user", "password")
+
+ # Add a 3PID to the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [
+ {
+ "medium": "email",
+ "address": "foo@example.com",
+ },
+ ],
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the shutdown was blocked
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that the mock was called once.
+ threepid_bind_mock.assert_called_once()
+ args = threepid_bind_mock.call_args[0]
+
+ # Check that the mock was called with the right parameters
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 28663826fc..a0788b1bb0 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -88,7 +88,7 @@ class RestHelper:
def create_room_as(
self,
room_creator: Optional[str] = None,
- is_public: Optional[bool] = None,
+ is_public: Optional[bool] = True,
room_version: Optional[str] = None,
tok: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
@@ -101,9 +101,12 @@ class RestHelper:
Args:
room_creator: The user ID to create the room with.
is_public: If True, the `visibility` parameter will be set to
- "public". If False, it will be set to "private". If left
- unspecified, the server will set it to an appropriate default
- (which should be "private" as per the CS spec).
+ "public". If False, it will be set to "private".
+ If None, doesn't specify the `visibility` parameter in which
+ case the server is supposed to make the room private according to
+ the CS API.
+ Defaults to public, since that is commonly needed in tests
+ for convenience where room privacy is not a problem.
room_version: The room version to create the room as. Defaults to Synapse's
default room version.
tok: The access token to use in the request.
|