diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index f44c91a373..b7fc33dc94 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -15,6 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from unittest.mock import patch
+
import jsonschema
from synapse.api.constants import EventContentFields
@@ -51,9 +53,8 @@ class FilteringTestCase(unittest.HomeserverTestCase):
{"presence": {"senders": ["@bar;pik.test.com"]}},
]
for filter in invalid_filters:
- with self.assertRaises(SynapseError) as check_filter_error:
+ with self.assertRaises(SynapseError):
self.filtering.check_valid_filter(filter)
- self.assertIsInstance(check_filter_error.exception, SynapseError)
def test_valid_filters(self):
valid_filters = [
@@ -119,12 +120,12 @@ class FilteringTestCase(unittest.HomeserverTestCase):
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_types_works_with_wildcards(self):
definition = {"types": ["m.*", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_types_works_with_unknowns(self):
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
@@ -133,24 +134,24 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="now.for.something.completely.different",
room_id="!foo:bar",
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_types_works_with_literals(self):
definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_types_works_with_wildcards(self):
definition = {"not_types": ["m.room.message", "org.matrix.*"]}
event = MockEvent(
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_types_works_with_unknowns(self):
definition = {"not_types": ["m.*", "org.*"]}
event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar")
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_not_types_takes_priority_over_types(self):
definition = {
@@ -158,35 +159,35 @@ class FilteringTestCase(unittest.HomeserverTestCase):
"types": ["m.room.message", "m.room.topic"],
}
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_senders_works_with_literals(self):
definition = {"senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
)
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_senders_works_with_unknowns(self):
definition = {"senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_senders_works_with_literals(self):
definition = {"not_senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_senders_works_with_unknowns(self):
definition = {"not_senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
)
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_not_senders_takes_priority_over_senders(self):
definition = {
@@ -196,14 +197,14 @@ class FilteringTestCase(unittest.HomeserverTestCase):
event = MockEvent(
sender="@misspiggy:muppets", type="m.room.topic", room_id="!foo:bar"
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_rooms_works_with_literals(self):
definition = {"rooms": ["!secretbase:unknown"]}
event = MockEvent(
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
)
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_rooms_works_with_unknowns(self):
definition = {"rooms": ["!secretbase:unknown"]}
@@ -212,7 +213,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message",
room_id="!anothersecretbase:unknown",
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_rooms_works_with_literals(self):
definition = {"not_rooms": ["!anothersecretbase:unknown"]}
@@ -221,7 +222,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message",
room_id="!anothersecretbase:unknown",
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_rooms_works_with_unknowns(self):
definition = {"not_rooms": ["!secretbase:unknown"]}
@@ -230,7 +231,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message",
room_id="!anothersecretbase:unknown",
)
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_not_rooms_takes_priority_over_rooms(self):
definition = {
@@ -240,7 +241,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
event = MockEvent(
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_combined_event(self):
definition = {
@@ -256,7 +257,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message", # yup
room_id="!stage:unknown", # yup
)
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_combined_event_bad_sender(self):
definition = {
@@ -272,7 +273,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message", # yup
room_id="!stage:unknown", # yup
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_combined_event_bad_room(self):
definition = {
@@ -288,7 +289,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message", # yup
room_id="!piggyshouse:muppets", # nope
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_combined_event_bad_type(self):
definition = {
@@ -304,7 +305,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="muppets.misspiggy.kisses", # nope
room_id="!stage:unknown", # yup
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_filter_labels(self):
definition = {"org.matrix.labels": ["#fun"]}
@@ -315,7 +316,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
content={EventContentFields.LABELS: ["#fun"]},
)
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
event = MockEvent(
sender="@foo:bar",
@@ -324,7 +325,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
content={EventContentFields.LABELS: ["#notfun"]},
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_filter_not_labels(self):
definition = {"org.matrix.not_labels": ["#fun"]}
@@ -335,7 +336,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
content={EventContentFields.LABELS: ["#fun"]},
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
event = MockEvent(
sender="@foo:bar",
@@ -344,7 +345,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
content={EventContentFields.LABELS: ["#notfun"]},
)
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
@@ -362,7 +363,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
)
- results = user_filter.filter_presence(events=events)
+ results = self.get_success(user_filter.filter_presence(events=events))
self.assertEquals(events, results)
def test_filter_presence_no_match(self):
@@ -386,7 +387,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
)
- results = user_filter.filter_presence(events=events)
+ results = self.get_success(user_filter.filter_presence(events=events))
self.assertEquals([], results)
def test_filter_room_state_match(self):
@@ -405,7 +406,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
)
- results = user_filter.filter_room_state(events=events)
+ results = self.get_success(user_filter.filter_room_state(events=events))
self.assertEquals(events, results)
def test_filter_room_state_no_match(self):
@@ -426,7 +427,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
)
- results = user_filter.filter_room_state(events)
+ results = self.get_success(user_filter.filter_room_state(events))
self.assertEquals([], results)
def test_filter_rooms(self):
@@ -441,10 +442,52 @@ class FilteringTestCase(unittest.HomeserverTestCase):
"!not_included:example.com", # Disallowed because not in rooms.
]
- filtered_room_ids = list(Filter(definition).filter_rooms(room_ids))
+ filtered_room_ids = list(Filter(self.hs, definition).filter_rooms(room_ids))
self.assertEquals(filtered_room_ids, ["!allowed:example.com"])
+ @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+ def test_filter_relations(self):
+ events = [
+ # An event without a relation.
+ MockEvent(
+ event_id="$no_relation",
+ sender="@foo:bar",
+ type="org.matrix.custom.event",
+ room_id="!foo:bar",
+ ),
+ # An event with a relation.
+ MockEvent(
+ event_id="$with_relation",
+ sender="@foo:bar",
+ type="org.matrix.custom.event",
+ room_id="!foo:bar",
+ ),
+ # Non-EventBase objects get passed through.
+ {},
+ ]
+
+ # For the following tests we patch the datastore method (intead of injecting
+ # events). This is a bit cheeky, but tests the logic of _check_event_relations.
+
+ # Filter for a particular sender.
+ definition = {
+ "io.element.relation_senders": ["@foo:bar"],
+ }
+
+ async def events_have_relations(*args, **kwargs):
+ return ["$with_relation"]
+
+ with patch.object(
+ self.datastore, "events_have_relations", new=events_have_relations
+ ):
+ filtered_events = list(
+ self.get_success(
+ Filter(self.hs, definition)._check_event_relations(events)
+ )
+ )
+ self.assertEquals(filtered_events, events[1:])
+
def test_add_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 1f6a924452..d6f14e2dba 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -272,7 +272,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
make_awaitable(([event], None))
)
- self.handler.notify_interested_services_ephemeral("receipt_key", 580)
+ self.handler.notify_interested_services_ephemeral(
+ "receipt_key", 580, ["@fakerecipient:example.com"]
+ )
self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with(
interested_service, [event]
)
@@ -300,7 +302,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
make_awaitable(([event], None))
)
- self.handler.notify_interested_services_ephemeral("receipt_key", 579)
+ self.handler.notify_interested_services_ephemeral(
+ "receipt_key", 580, ["@fakerecipient:example.com"]
+ )
self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called()
def _mkservice(self, is_interested, protocols=None):
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 7dd4a5a367..08e9730d4d 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -31,7 +31,10 @@ from tests.unittest import override_config
# (possibly experimental) login flows we expect to appear in the list after the normal
# ones
-ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+ADDITIONAL_LOGIN_FLOWS = [
+ {"type": "m.login.application_service"},
+ {"type": "uk.half-shot.msc2778.login.application_service"},
+]
# a mock instance which the dummy auth providers delegate to, so we can see what's going
# on
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index db691c4c1c..cd6f2c77ae 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -193,7 +193,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self):
- self.store.count_monthly_users = Mock(
+ # Type ignore: mypy doesn't like us assigning to methods.
+ self.store.count_monthly_users = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
# Ensure does not throw exception
@@ -201,7 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_blocked(self):
- self.store.get_monthly_active_count = Mock(
+ # Type ignore: mypy doesn't like us assigning to methods.
+ self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.lots_of_users)
)
self.get_failure(
@@ -209,7 +211,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
- self.store.get_monthly_active_count = Mock(
+ # Type ignore: mypy doesn't like us assigning to methods.
+ self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
self.get_failure(
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index d3d0bf1ac5..7b95844b55 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -14,6 +14,8 @@
from typing import Any, Iterable, List, Optional, Tuple
from unittest import mock
+from twisted.internet.defer import ensureDeferred
+
from synapse.api.constants import (
EventContentFields,
EventTypes,
@@ -316,6 +318,59 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
AuthError,
)
+ def test_room_hierarchy_cache(self) -> None:
+ """In-flight room hierarchy requests are deduplicated."""
+ # Run two `get_room_hierarchy` calls up until they block.
+ deferred1 = ensureDeferred(
+ self.handler.get_room_hierarchy(self.user, self.space)
+ )
+ deferred2 = ensureDeferred(
+ self.handler.get_room_hierarchy(self.user, self.space)
+ )
+
+ # Complete the two calls.
+ result1 = self.get_success(deferred1)
+ result2 = self.get_success(deferred2)
+
+ # Both `get_room_hierarchy` calls should return the same result.
+ expected = [(self.space, [self.room]), (self.room, ())]
+ self._assert_hierarchy(result1, expected)
+ self._assert_hierarchy(result2, expected)
+ self.assertIs(result1, result2)
+
+ # A subsequent `get_room_hierarchy` call should not reuse the result.
+ result3 = self.get_success(
+ self.handler.get_room_hierarchy(self.user, self.space)
+ )
+ self._assert_hierarchy(result3, expected)
+ self.assertIsNot(result1, result3)
+
+ def test_room_hierarchy_cache_sharing(self) -> None:
+ """Room hierarchy responses for different users are not shared."""
+ user2 = self.register_user("user2", "pass")
+
+ # Make the room within the space invite-only.
+ self.helper.send_state(
+ self.room,
+ event_type=EventTypes.JoinRules,
+ body={"join_rule": JoinRules.INVITE},
+ tok=self.token,
+ )
+
+ # Run two `get_room_hierarchy` calls for different users up until they block.
+ deferred1 = ensureDeferred(
+ self.handler.get_room_hierarchy(self.user, self.space)
+ )
+ deferred2 = ensureDeferred(self.handler.get_room_hierarchy(user2, self.space))
+
+ # Complete the two calls.
+ result1 = self.get_success(deferred1)
+ result2 = self.get_success(deferred2)
+
+ # The `get_room_hierarchy` calls should return different results.
+ self._assert_hierarchy(result1, [(self.space, [self.room]), (self.room, ())])
+ self._assert_hierarchy(result2, [(self.space, [self.room])])
+
def _create_room_with_join_rule(
self, join_rule: str, room_version: Optional[str] = None, **extra_content
) -> str:
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 339c039914..638186f173 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -13,10 +13,11 @@
# limitations under the License.
from typing import Optional
+from unittest.mock import Mock
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError
-from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
+from synapse.api.filtering import Filtering
from synapse.api.room_versions import RoomVersions
from synapse.handlers.sync import SyncConfig
from synapse.rest import admin
@@ -197,7 +198,7 @@ def generate_sync_config(
_request_key += 1
return SyncConfig(
user=UserID.from_string(user_id),
- filter_collection=DEFAULT_FILTER_COLLECTION,
+ filter_collection=Filtering(Mock()).DEFAULT_FILTER_COLLECTION,
is_guest=False,
request_key=("request_key", _request_key),
device_id=device_id,
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index eac4664b41..cb02eddf07 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from twisted.internet.protocol import Protocol
from twisted.web.resource import Resource
from synapse.app.generic_worker import GenericWorkerServer
-from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler
@@ -220,8 +219,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
unlike `BaseStreamTestCase`.
"""
- servlets: List[Callable[[HomeServer, JsonResource], None]] = []
-
def setUp(self):
super().setUp()
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 192073c520..af849bd471 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -474,3 +474,51 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
% server_and_media_id_2
),
)
+
+
+class PurgeHistoryTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.room_id = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok
+ )
+ self.url = f"/_synapse/admin/v1/purge_history/{self.room_id}"
+ self.url_status = "/_synapse/admin/v1/purge_history_status/"
+
+ def test_purge_history(self):
+ """
+ Simple test of purge history API.
+ Test only that is is possible to call, get status 200 and purge_id.
+ """
+
+ channel = self.make_request(
+ "POST",
+ self.url,
+ content={"delete_local_events": True, "purge_up_to_ts": 0},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIn("purge_id", channel.json_body)
+ purge_id = channel.json_body["purge_id"]
+
+ # get status
+ channel = self.make_request(
+ "GET",
+ self.url_status + purge_id,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("complete", channel.json_body["status"])
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 46116644ce..b48fc12e5f 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -14,12 +14,16 @@
import json
import urllib.parse
+from http import HTTPStatus
from typing import List, Optional
from unittest.mock import Mock
+from parameterized import parameterized
+
import synapse.rest.admin
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import Codes
+from synapse.handlers.pagination import PaginationHandler
from synapse.rest.client import directory, events, login, room
from tests import unittest
@@ -68,11 +72,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- json.dumps({}),
+ {},
access_token=self.other_user_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_room_does_not_exist(self):
@@ -84,11 +88,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
url,
- json.dumps({}),
+ {},
access_token=self.admin_user_tok,
)
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_room_is_not_valid(self):
@@ -100,11 +104,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
url,
- json.dumps({}),
+ {},
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -119,11 +123,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("new_room_id", channel.json_body)
self.assertIn("kicked_users", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -138,11 +142,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"User must be our own: @not:exist.bla",
channel.json_body["error"],
@@ -157,11 +161,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_is_not_bool(self):
@@ -173,11 +177,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_room_and_block(self):
@@ -199,11 +203,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url.encode("ascii"),
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -232,11 +236,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url.encode("ascii"),
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -266,11 +270,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url.encode("ascii"),
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -281,6 +285,31 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self._is_blocked(self.room_id, expect=True)
self._has_no_members(self.room_id)
+ @parameterized.expand([(True,), (False,)])
+ def test_block_unknown_room(self, purge: bool) -> None:
+ """
+ We can block an unknown room. In this case, the `purge` argument
+ should be ignored.
+ """
+ room_id = "!unknown:test"
+
+ # The room isn't already in the blocked rooms table
+ self._is_blocked(room_id, expect=False)
+
+ # Request the room be blocked.
+ channel = self.make_request(
+ "DELETE",
+ f"/_synapse/admin/v1/rooms/{room_id}",
+ {"block": True, "purge": purge},
+ access_token=self.admin_user_tok,
+ )
+
+ # The room is now blocked.
+ self.assertEqual(
+ HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"]
+ )
+ self._is_blocked(room_id)
+
def test_shutdown_room_consent(self):
"""Test that we can shutdown rooms with local users who have not
yet accepted the privacy policy. This used to fail when we tried to
@@ -316,7 +345,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -345,7 +374,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
json.dumps({"history_visibility": "world_readable"}),
access_token=self.other_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Test that room is not purged
with self.assertRaises(AssertionError):
@@ -362,7 +391,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -418,18 +447,617 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
+ self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
+ url = "events?timeout=0&room_id=" + room_id
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
+
+class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ events.register_servlets,
+ room.register_servlets,
+ room.register_deprecated_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.event_creation_handler = hs.get_event_creation_handler()
+ hs.config.consent.user_consent_version = "1"
+
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ # Mark the admin user as having consented
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+ self.room_id = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok
+ )
+ self.url = f"/_synapse/admin/v2/rooms/{self.room_id}"
+ self.url_status_by_room_id = (
+ f"/_synapse/admin/v2/rooms/{self.room_id}/delete_status"
+ )
+ self.url_status_by_delete_id = "/_synapse/admin/v2/rooms/delete_status/"
+
+ @parameterized.expand(
+ [
+ ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+ ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+ ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
+ ]
+ )
+ def test_requester_is_no_admin(self, method: str, url: str):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ channel = self.make_request(
+ method,
+ url % self.room_id,
+ content={},
+ access_token=self.other_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ @parameterized.expand(
+ [
+ ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+ ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+ ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
+ ]
+ )
+ def test_room_does_not_exist(self, method: str, url: str):
+ """
+ Check that unknown rooms/server return error 404.
+ """
+
+ channel = self.make_request(
+ method,
+ url % "!unknown:test",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ @parameterized.expand(
+ [
+ ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+ ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+ ]
+ )
+ def test_room_is_not_valid(self, method: str, url: str):
+ """
+ Check that invalid room names, return an error 400.
+ """
+
+ channel = self.make_request(
+ method,
+ url % "invalidroom",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ "invalidroom is not a legal room ID",
+ channel.json_body["error"],
+ )
+
+ def test_new_room_user_does_not_exist(self):
+ """
+ Tests that the user ID must be from local server but it does not have to exist.
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"new_room_user_id": "@unknown:test"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user, expect_new_room=True)
+
+ def test_new_room_user_is_not_local(self):
+ """
+ Check that only local users can create new room to move members.
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"new_room_user_id": "@not:exist.bla"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "User must be our own: @not:exist.bla",
+ channel.json_body["error"],
+ )
+
+ def test_block_is_not_bool(self):
+ """
+ If parameter `block` is not boolean, return an error
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"block": "NotBool"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ def test_purge_is_not_bool(self):
+ """
+ If parameter `purge` is not boolean, return an error
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"purge": "NotBool"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ def test_delete_expired_status(self):
+ """Test that the task status is removed after expiration."""
+
+ # first task, do not purge, that we can create a second task
+ channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content={"purge": False},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id1 = channel.json_body["delete_id"]
+
+ # go ahead
+ self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+
+ # second task
+ channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content={"purge": True},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id2 = channel.json_body["delete_id"]
+
+ # get status
+ channel = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(2, len(channel.json_body["results"]))
+ self.assertEqual("complete", channel.json_body["results"][0]["status"])
+ self.assertEqual("complete", channel.json_body["results"][1]["status"])
+ self.assertEqual(delete_id1, channel.json_body["results"][0]["delete_id"])
+ self.assertEqual(delete_id2, channel.json_body["results"][1]["delete_id"])
+
+ # get status after more than clearing time for first task
+ # second task is not cleared
+ self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+
+ channel = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+ self.assertEqual("complete", channel.json_body["results"][0]["status"])
+ self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"])
+
+ # get status after more than clearing time for all tasks
+ self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+
+ channel = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_delete_same_room_twice(self):
+ """Test that the call for delete a room at second time gives an exception."""
+
+ body = {"new_room_user_id": self.admin_user}
+
+ # first call to delete room
+ # and do not wait for finish the task
+ first_channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content=body,
+ access_token=self.admin_user_tok,
+ await_result=False,
+ )
+
+ # second call to delete room
+ second_channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content=body,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body
+ )
+ self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"])
+ self.assertEqual(
+ f"History purge already in progress for {self.room_id}",
+ second_channel.json_body["error"],
+ )
+
+ # get result of first call
+ first_channel.await_result()
+ self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body)
+ self.assertIn("delete_id", first_channel.json_body)
+
+ # check status after finish the task
+ self._test_result(
+ first_channel.json_body["delete_id"],
+ self.other_user,
+ expect_new_room=True,
+ )
+
+ def test_purge_room_and_block(self):
+ """Test to purge a room and block it.
+ Members will not be moved to a new room and will not receive a message.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content={"block": True, "purge": True},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user)
+
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=True)
+ self._has_no_members(self.room_id)
+
+ def test_purge_room_and_not_block(self):
+ """Test to purge a room and do not block it.
+ Members will not be moved to a new room and will not receive a message.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content={"block": False, "purge": True},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user)
+
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=False)
+ self._has_no_members(self.room_id)
+
+ def test_block_room_and_not_purge(self):
+ """Test to block a room without purging it.
+ Members will not be moved to a new room and will not receive a message.
+ The room will not be purged.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content={"block": True, "purge": False},
+ access_token=self.admin_user_tok,
)
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user)
+
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=True)
+ self._has_no_members(self.room_id)
+
+ def test_shutdown_room_consent(self):
+ """Test that we can shutdown rooms with local users who have not
+ yet accepted the privacy policy. This used to fail when we tried to
+ force part the user from the old room.
+ Members will be moved to a new room and will receive a message.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ # Assert one user in room
+ users_in_room = self.get_success(self.store.get_users_in_room(self.room_id))
+ self.assertEqual([self.other_user], users_in_room)
+
+ # Enable require consent to send events
+ self.event_creation_handler._block_events_without_consent_error = "Error"
+
+ # Assert that the user is getting consent error
+ self.helper.send(
+ self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
+ )
+
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ # Test that the admin can still send shutdown
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"new_room_user_id": self.admin_user},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user, expect_new_room=True)
+
+ channel = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+
+ # Test that member has moved to new room
+ self._is_member(
+ room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"],
+ user_id=self.other_user,
+ )
+
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
+ def test_shutdown_room_block_peek(self):
+ """Test that a world_readable room can no longer be peeked into after
+ it has been shut down.
+ Members will be moved to a new room and will receive a message.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ # Enable world readable
+ url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
+ channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ content={"history_visibility": "world_readable"},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ # Test that the admin can still send shutdown
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"new_room_user_id": self.admin_user},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user, expect_new_room=True)
+
+ channel = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+
+ # Test that member has moved to new room
+ self._is_member(
+ room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"],
+ user_id=self.other_user,
+ )
+
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
+ # Assert we can no longer peek into the room
+ self._assert_peek(self.room_id, expect_code=403)
+
+ def _is_blocked(self, room_id: str, expect: bool = True) -> None:
+ """Assert that the room is blocked or not"""
+ d = self.store.is_room_blocked(room_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertIsNone(self.get_success(d))
+
+ def _has_no_members(self, room_id: str) -> None:
+ """Assert there is now no longer anyone in the room"""
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([], users_in_room)
+
+ def _is_member(self, room_id: str, user_id: str) -> None:
+ """Test that user is member of the room"""
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertIn(user_id, users_in_room)
+
+ def _is_purged(self, room_id: str) -> None:
+ """Test that the following tables have been purged of all rows related to the room."""
+ for table in PURGE_TABLES:
+ count = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
+
+ def _assert_peek(self, room_id: str, expect_code: int) -> None:
+ """Assert that the admin user can (or cannot) peek into the room."""
+
+ url = f"rooms/{room_id}/initialSync"
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
url = "events?timeout=0&room_id=" + room_id
channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
+ self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
+ def _test_result(
+ self,
+ delete_id: str,
+ kicked_user: str,
+ expect_new_room: bool = False,
+ ) -> None:
+ """
+ Test that the result is the expected.
+ Uses both APIs (status by room_id and delete_id)
+
+ Args:
+ delete_id: id of this purge
+ kicked_user: a user_id which is kicked from the room
+ expect_new_room: if we expect that a new room was created
+ """
+
+ # get information by room_id
+ channel_room_id = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(
+ HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body
+ )
+ self.assertEqual(1, len(channel_room_id.json_body["results"]))
self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ delete_id, channel_room_id.json_body["results"][0]["delete_id"]
)
+ # get information by delete_id
+ channel_delete_id = self.make_request(
+ "GET",
+ self.url_status_by_delete_id + delete_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(
+ HTTPStatus.OK,
+ channel_delete_id.code,
+ msg=channel_delete_id.json_body,
+ )
+
+ # test values that are the same in both responses
+ for content in [
+ channel_room_id.json_body["results"][0],
+ channel_delete_id.json_body,
+ ]:
+ self.assertEqual("complete", content["status"])
+ self.assertEqual(kicked_user, content["shutdown_room"]["kicked_users"][0])
+ self.assertIn("failed_to_kick_users", content["shutdown_room"])
+ self.assertIn("local_aliases", content["shutdown_room"])
+ self.assertNotIn("error", content)
+
+ if expect_new_room:
+ self.assertIsNotNone(content["shutdown_room"]["new_room_id"])
+ else:
+ self.assertIsNone(content["shutdown_room"]["new_room_id"])
+
class RoomTestCase(unittest.HomeserverTestCase):
"""Test /room admin API."""
@@ -466,7 +1094,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
# Check request completed successfully
- self.assertEqual(200, int(channel.code), msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that response json body contains a "rooms" key
self.assertTrue(
@@ -550,9 +1178,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"]
- )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue("rooms" in channel.json_body)
for r in channel.json_body["rooms"]:
@@ -592,7 +1218,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_correct_room_attributes(self):
"""Test the correct attributes for a room are returned"""
@@ -615,7 +1241,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -647,7 +1273,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that rooms were returned
self.assertTrue("rooms" in channel.json_body)
@@ -1107,7 +1733,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -1157,11 +1783,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.second_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self):
@@ -1173,11 +1799,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
def test_local_user_does_not_exist(self):
@@ -1189,11 +1815,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_remote_user(self):
@@ -1205,11 +1831,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"This endpoint can only be used with local users",
channel.json_body["error"],
@@ -1225,11 +1851,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual("No known servers", channel.json_body["error"])
def test_room_is_not_valid(self):
@@ -1242,11 +1868,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom was not legal room ID or room alias",
channel.json_body["error"],
@@ -1261,11 +1887,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1275,7 +1901,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_not_member(self):
@@ -1292,11 +1918,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_join_private_room_if_member(self):
@@ -1324,7 +1950,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.admin_user_tok,
)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
# Join user to room.
@@ -1335,10 +1961,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1348,7 +1974,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_owner(self):
@@ -1365,11 +1991,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1379,7 +2005,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_context_as_non_admin(self):
@@ -1413,9 +2039,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=tok,
)
- self.assertEquals(
- 403, int(channel.result["code"]), msg=channel.result["body"]
- )
+ self.assertEquals(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_context_as_admin(self):
@@ -1445,7 +2069,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=self.admin_user_tok,
)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEquals(
channel.json_body["event"]["event_id"], events[midway]["event_id"]
)
@@ -1504,7 +2128,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
@@ -1531,7 +2155,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room (we should have received an
# invite) and can ban a user.
@@ -1557,7 +2181,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
@@ -1595,7 +2219,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
#
# (Note we assert the error message to ensure that it's not denied for
# some other reason)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body["error"],
"No local admin user in room with power to update power levels.",
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 25e8d6cf27..c9fe0f06c2 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -3592,31 +3592,34 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
self.other_user
)
- def test_no_auth(self):
+ @parameterized.expand(["POST", "DELETE"])
+ def test_no_auth(self, method: str):
"""
Try to get information of an user without authentication.
"""
- channel = self.make_request("POST", self.url)
+ channel = self.make_request(method, self.url)
self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_not_admin(self):
+ @parameterized.expand(["POST", "DELETE"])
+ def test_requester_is_not_admin(self, method: str):
"""
If the user is not a server admin, an error is returned.
"""
other_user_token = self.login("user", "pass")
- channel = self.make_request("POST", self.url, access_token=other_user_token)
+ channel = self.make_request(method, self.url, access_token=other_user_token)
self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_user_is_not_local(self):
+ @parameterized.expand(["POST", "DELETE"])
+ def test_user_is_not_local(self, method: str):
"""
Tests that shadow-banning for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
- channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+ channel = self.make_request(method, url, access_token=self.admin_user_tok)
self.assertEqual(400, channel.code, msg=channel.json_body)
def test_success(self):
@@ -3636,6 +3639,17 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
self.assertTrue(result.shadow_banned)
+ # Un-shadow-ban the user.
+ channel = self.make_request(
+ "DELETE", self.url, access_token=self.admin_user_tok
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual({}, channel.json_body)
+
+ # Ensure the user is no longer shadow-banned (and the cache was cleared).
+ result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+ self.assertFalse(result.shadow_banned)
+
class RateLimitTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index d2181ea907..aca03afd0e 100644
--- a/tests/rest/client/test_directory.py
+++ b/tests/rest/client/test_directory.py
@@ -11,12 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
+from http import HTTPStatus
+
+from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
from synapse.rest.client import directory, login, room
+from synapse.server import HomeServer
from synapse.types import RoomAlias
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
@@ -32,7 +36,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["require_membership_for_aliases"] = True
@@ -40,7 +44,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
return self.hs
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ """Create two local users and access tokens for them.
+ One of them creates a room."""
self.room_owner = self.register_user("room_owner", "test")
self.room_owner_tok = self.login("room_owner", "test")
@@ -51,39 +59,39 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.user = self.register_user("user", "test")
self.user_tok = self.login("user", "test")
- def test_state_event_not_in_room(self):
+ def test_state_event_not_in_room(self) -> None:
self.ensure_user_left_room()
- self.set_alias_via_state_event(403)
+ self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
- def test_directory_endpoint_not_in_room(self):
+ def test_directory_endpoint_not_in_room(self) -> None:
self.ensure_user_left_room()
- self.set_alias_via_directory(403)
+ self.set_alias_via_directory(HTTPStatus.FORBIDDEN)
- def test_state_event_in_room_too_long(self):
+ def test_state_event_in_room_too_long(self) -> None:
self.ensure_user_joined_room()
- self.set_alias_via_state_event(400, alias_length=256)
+ self.set_alias_via_state_event(HTTPStatus.BAD_REQUEST, alias_length=256)
- def test_directory_in_room_too_long(self):
+ def test_directory_in_room_too_long(self) -> None:
self.ensure_user_joined_room()
- self.set_alias_via_directory(400, alias_length=256)
+ self.set_alias_via_directory(HTTPStatus.BAD_REQUEST, alias_length=256)
@override_config({"default_room_version": 5})
- def test_state_event_user_in_v5_room(self):
+ def test_state_event_user_in_v5_room(self) -> None:
"""Test that a regular user can add alias events before room v6"""
self.ensure_user_joined_room()
- self.set_alias_via_state_event(200)
+ self.set_alias_via_state_event(HTTPStatus.OK)
@override_config({"default_room_version": 6})
- def test_state_event_v6_room(self):
+ def test_state_event_v6_room(self) -> None:
"""Test that a regular user can *not* add alias events from room v6"""
self.ensure_user_joined_room()
- self.set_alias_via_state_event(403)
+ self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
- def test_directory_in_room(self):
+ def test_directory_in_room(self) -> None:
self.ensure_user_joined_room()
- self.set_alias_via_directory(200)
+ self.set_alias_via_directory(HTTPStatus.OK)
- def test_room_creation_too_long(self):
+ def test_room_creation_too_long(self) -> None:
url = "/_matrix/client/r0/createRoom"
# We use deliberately a localpart under the length threshold so
@@ -93,9 +101,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
- def test_room_creation(self):
+ def test_room_creation(self) -> None:
url = "/_matrix/client/r0/createRoom"
# Check with an alias of allowed length. There should already be
@@ -106,9 +114,46 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ def test_deleting_alias_via_directory(self) -> None:
+ # Add an alias for the room. We must be joined to do so.
+ self.ensure_user_joined_room()
+ alias = self.set_alias_via_directory(HTTPStatus.OK)
+
+ # Then try to remove the alias
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/r0/directory/room/{alias}",
+ access_token=self.user_tok,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ def test_deleting_nonexistant_alias(self) -> None:
+ # Check that no alias exists
+ alias = "#potato:test"
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/r0/directory/room/{alias}",
+ access_token=self.user_tok,
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertIn("error", channel.json_body, channel.json_body)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body)
+
+ # Then try to remove the alias
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/r0/directory/room/{alias}",
+ access_token=self.user_tok,
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertIn("error", channel.json_body, channel.json_body)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body)
- def set_alias_via_state_event(self, expected_code, alias_length=5):
+ def set_alias_via_state_event(
+ self, expected_code: HTTPStatus, alias_length: int = 5
+ ) -> None:
url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
self.room_id,
self.hs.hostname,
@@ -122,8 +167,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, expected_code, channel.result)
- def set_alias_via_directory(self, expected_code, alias_length=5):
- url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length)
+ def set_alias_via_directory(
+ self, expected_code: HTTPStatus, alias_length: int = 5
+ ) -> str:
+ alias = self.random_alias(alias_length)
+ url = "/_matrix/client/r0/directory/room/%s" % alias
data = {"room_id": self.room_id}
request_data = json.dumps(data)
@@ -131,17 +179,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
"PUT", url, request_data, access_token=self.user_tok
)
self.assertEqual(channel.code, expected_code, channel.result)
+ return alias
- def random_alias(self, length):
+ def random_alias(self, length: int) -> str:
return RoomAlias(random_string(length), self.hs.hostname).to_string()
- def ensure_user_left_room(self):
+ def ensure_user_left_room(self) -> None:
self.ensure_membership("leave")
- def ensure_user_joined_room(self):
+ def ensure_user_joined_room(self) -> None:
self.ensure_membership("join")
- def ensure_membership(self, membership):
+ def ensure_membership(self, membership: str) -> None:
try:
if membership == "leave":
self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index a63f04bd41..0b90e3f803 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -79,7 +79,10 @@ EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fΓΆ&=o"')]
# (possibly experimental) login flows we expect to appear in the list after the normal
# ones
-ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+ADDITIONAL_LOGIN_FLOWS = [
+ {"type": "m.login.application_service"},
+ {"type": "uk.half-shot.msc2778.login.application_service"},
+]
class LoginRestServletTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 376853fd65..10a4a4dc5e 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -25,7 +25,12 @@ from urllib import parse as urlparse
from twisted.internet import defer
import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ Membership,
+ RelationTypes,
+)
from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
@@ -2157,6 +2162,153 @@ class LabelsTestCase(unittest.HomeserverTestCase):
return event_id
+class RelationsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+ config["experimental_features"] = {"msc3440_enabled": True}
+ return config
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("test", "test")
+ self.tok = self.login("test", "test")
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+ self.helper.join(
+ room=self.room_id, user=self.second_user_id, tok=self.second_tok
+ )
+
+ self.third_user_id = self.register_user("third", "test")
+ self.third_tok = self.login("third", "test")
+ self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
+
+ # An initial event with a relation from second user.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "Message 1"},
+ tok=self.tok,
+ )
+ self.event_id_1 = res["event_id"]
+ self.helper.send_event(
+ room_id=self.room_id,
+ type="m.reaction",
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.ANNOTATION,
+ "event_id": self.event_id_1,
+ "key": "π",
+ }
+ },
+ tok=self.second_tok,
+ )
+
+ # Another event with a relation from third user.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "Message 2"},
+ tok=self.tok,
+ )
+ self.event_id_2 = res["event_id"]
+ self.helper.send_event(
+ room_id=self.room_id,
+ type="m.reaction",
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.REFERENCE,
+ "event_id": self.event_id_2,
+ }
+ },
+ tok=self.third_tok,
+ )
+
+ # An event with no relations.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "No relations"},
+ tok=self.tok,
+ )
+
+ def _filter_messages(self, filter: JsonDict) -> List[JsonDict]:
+ """Make a request to /messages with a filter, returns the chunk of events."""
+ channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ return channel.json_body["chunk"]
+
+ def test_filter_relation_senders(self):
+ # Messages which second user reacted to.
+ filter = {"io.element.relation_senders": [self.second_user_id]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+
+ # Messages which third user reacted to.
+ filter = {"io.element.relation_senders": [self.third_user_id]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0]["event_id"], self.event_id_2)
+
+ # Messages which either user reacted to.
+ filter = {
+ "io.element.relation_senders": [self.second_user_id, self.third_user_id]
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 2, chunk)
+ self.assertCountEqual(
+ [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
+ )
+
+ def test_filter_relation_type(self):
+ # Messages which have annotations.
+ filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+
+ # Messages which have references.
+ filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0]["event_id"], self.event_id_2)
+
+ # Messages which have either annotations or references.
+ filter = {
+ "io.element.relation_types": [
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ ]
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 2, chunk)
+ self.assertCountEqual(
+ [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
+ )
+
+ def test_filter_relation_senders_and_type(self):
+ # Messages which second user reacted to.
+ filter = {
+ "io.element.relation_senders": [self.second_user_id],
+ "io.element.relation_types": [RelationTypes.ANNOTATION],
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+
+
class ContextTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index ec0979850b..1af5e5cee5 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -19,10 +19,21 @@ import json
import re
import time
import urllib.parse
-from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
+from typing import (
+ Any,
+ AnyStr,
+ Dict,
+ Iterable,
+ Mapping,
+ MutableMapping,
+ Optional,
+ Tuple,
+ overload,
+)
from unittest.mock import patch
import attr
+from typing_extensions import Literal
from twisted.web.resource import Resource
from twisted.web.server import Site
@@ -45,6 +56,32 @@ class RestHelper:
site = attr.ib(type=Site)
auth_user_id = attr.ib()
+ @overload
+ def create_room_as(
+ self,
+ room_creator: Optional[str] = ...,
+ is_public: Optional[bool] = ...,
+ room_version: Optional[str] = ...,
+ tok: Optional[str] = ...,
+ expect_code: Literal[200] = ...,
+ extra_content: Optional[Dict] = ...,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
+ ) -> str:
+ ...
+
+ @overload
+ def create_room_as(
+ self,
+ room_creator: Optional[str] = ...,
+ is_public: Optional[bool] = ...,
+ room_version: Optional[str] = ...,
+ tok: Optional[str] = ...,
+ expect_code: int = ...,
+ extra_content: Optional[Dict] = ...,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
+ ) -> Optional[str]:
+ ...
+
def create_room_as(
self,
room_creator: Optional[str] = None,
@@ -53,10 +90,8 @@ class RestHelper:
tok: Optional[str] = None,
expect_code: int = 200,
extra_content: Optional[Dict] = None,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
- ) -> str:
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+ ) -> Optional[str]:
"""
Create a room.
@@ -99,6 +134,8 @@ class RestHelper:
if expect_code == 200:
return channel.json_body["room_id"]
+ else:
+ return None
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership(
@@ -168,7 +205,7 @@ class RestHelper:
extra_data: Optional[dict] = None,
tok: Optional[str] = None,
expect_code: int = 200,
- expect_errcode: str = None,
+ expect_errcode: Optional[str] = None,
) -> None:
"""
Send a membership state event into a room.
@@ -227,9 +264,7 @@ class RestHelper:
txn_id=None,
tok=None,
expect_code=200,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
if body is None:
body = "body_text_here"
@@ -254,9 +289,7 @@ class RestHelper:
txn_id=None,
tok=None,
expect_code=200,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -418,7 +451,7 @@ class RestHelper:
path,
content=image_data,
access_token=tok,
- custom_headers=[(b"Content-Length", str(image_length))],
+ custom_headers=[("Content-Length", str(image_length))],
)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
@@ -503,7 +536,7 @@ class RestHelper:
went.
"""
- cookies = {}
+ cookies: Dict[str, str] = {}
# if we're doing a ui auth, hit the ui auth redirect endpoint
if ui_auth_session_id:
@@ -625,7 +658,13 @@ class RestHelper:
# hit the redirect url again with the right Host header, which should now issue
# a cookie and redirect to the SSO provider.
- location = channel.headers.getRawHeaders("Location")[0]
+ def get_location(channel: FakeChannel) -> str:
+ location_values = channel.headers.getRawHeaders("Location")
+ # Keep mypy happy by asserting that location_values is nonempty
+ assert location_values
+ return location_values[0]
+
+ location = get_location(channel)
parts = urllib.parse.urlsplit(location)
channel = make_request(
self.hs.get_reactor(),
@@ -639,7 +678,7 @@ class RestHelper:
assert channel.code == 302
channel.extract_cookies(cookies)
- return channel.headers.getRawHeaders("Location")[0]
+ return get_location(channel)
def initiate_sso_ui_auth(
self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
diff --git a/tests/server.py b/tests/server.py
index 103351b487..40cf5b12c3 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -16,7 +16,17 @@ import json
import logging
from collections import deque
from io import SEEK_END, BytesIO
-from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union
+from typing import (
+ AnyStr,
+ Callable,
+ Dict,
+ Iterable,
+ MutableMapping,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
import attr
from typing_extensions import Deque
@@ -217,14 +227,12 @@ def make_request(
path: Union[bytes, str],
content: Union[bytes, str, JsonDict] = b"",
access_token: Optional[str] = None,
- request: Request = SynapseRequest,
+ request: Type[Request] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index a1ba99ff14..d37736edf8 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -11,19 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+from synapse.server import HomeServer
from synapse.types import UserID
+from synapse.util import Clock
from tests import unittest
class ProfileStoreTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
self.u_frank = UserID.from_string("@frank:test")
- def test_displayname(self):
+ def test_displayname(self) -> None:
self.get_success(self.store.create_profile(self.u_frank.localpart))
self.get_success(
@@ -48,7 +51,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
)
- def test_avatar_url(self):
+ def test_avatar_url(self) -> None:
self.get_success(self.store.create_profile(self.u_frank.localpart))
self.get_success(
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 2873e22ccf..fccab733c0 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -161,6 +161,54 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
+ def test__null_byte_in_display_name_properly_handled(self):
+ room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+
+ res = self.get_success(
+ self.store.db_pool.simple_select_list(
+ "room_memberships",
+ {"user_id": "@alice:test"},
+ ["display_name", "event_id"],
+ )
+ )
+ # Check that we only got one result back
+ self.assertEqual(len(res), 1)
+
+ # Check that alice's display name is "alice"
+ self.assertEqual(res[0]["display_name"], "alice")
+
+ # Grab the event_id to use later
+ event_id = res[0]["event_id"]
+
+ # Create a profile with the offending null byte in the display name
+ new_profile = {"displayname": "ali\u0000ce"}
+
+ # Ensure that the change goes smoothly and does not fail due to the null byte
+ self.helper.change_membership(
+ room,
+ self.u_alice,
+ self.u_alice,
+ "join",
+ extra_data=new_profile,
+ tok=self.t_alice,
+ )
+
+ res2 = self.get_success(
+ self.store.db_pool.simple_select_list(
+ "room_memberships",
+ {"user_id": "@alice:test"},
+ ["display_name", "event_id"],
+ )
+ )
+ # Check that we only have two results
+ self.assertEqual(len(res2), 2)
+
+ # Filter out the previous event using the event_id we grabbed above
+ row = [row for row in res2 if row["event_id"] != event_id]
+
+ # Check that alice's display name is now None
+ self.assertEqual(row[0]["display_name"], None)
+
class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
new file mode 100644
index 0000000000..ce782c7e1d
--- /dev/null
+++ b/tests/storage/test_stream.py
@@ -0,0 +1,207 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.filtering import Filter
+from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.types import JsonDict
+
+from tests.unittest import HomeserverTestCase
+
+
+class PaginationTestCase(HomeserverTestCase):
+ """
+ Test the pre-filtering done in the pagination code.
+
+ This is similar to some of the tests in tests.rest.client.test_rooms but here
+ we ensure that the filtering done in the database is applied successfully.
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+ config["experimental_features"] = {"msc3440_enabled": True}
+ return config
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("test", "test")
+ self.tok = self.login("test", "test")
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+ self.helper.join(
+ room=self.room_id, user=self.second_user_id, tok=self.second_tok
+ )
+
+ self.third_user_id = self.register_user("third", "test")
+ self.third_tok = self.login("third", "test")
+ self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
+
+ # An initial event with a relation from second user.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "Message 1"},
+ tok=self.tok,
+ )
+ self.event_id_1 = res["event_id"]
+ self.helper.send_event(
+ room_id=self.room_id,
+ type="m.reaction",
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.ANNOTATION,
+ "event_id": self.event_id_1,
+ "key": "π",
+ }
+ },
+ tok=self.second_tok,
+ )
+
+ # Another event with a relation from third user.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "Message 2"},
+ tok=self.tok,
+ )
+ self.event_id_2 = res["event_id"]
+ self.helper.send_event(
+ room_id=self.room_id,
+ type="m.reaction",
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.REFERENCE,
+ "event_id": self.event_id_2,
+ }
+ },
+ tok=self.third_tok,
+ )
+
+ # An event with no relations.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "No relations"},
+ tok=self.tok,
+ )
+
+ def _filter_messages(self, filter: JsonDict) -> List[EventBase]:
+ """Make a request to /messages with a filter, returns the chunk of events."""
+
+ from_token = self.get_success(
+ self.hs.get_event_sources().get_current_token_for_pagination()
+ )
+
+ events, next_key = self.get_success(
+ self.hs.get_datastore().paginate_room_events(
+ room_id=self.room_id,
+ from_key=from_token.room_key,
+ to_key=None,
+ direction="b",
+ limit=10,
+ event_filter=Filter(self.hs, filter),
+ )
+ )
+
+ return events
+
+ def test_filter_relation_senders(self):
+ # Messages which second user reacted to.
+ filter = {"io.element.relation_senders": [self.second_user_id]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0].event_id, self.event_id_1)
+
+ # Messages which third user reacted to.
+ filter = {"io.element.relation_senders": [self.third_user_id]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0].event_id, self.event_id_2)
+
+ # Messages which either user reacted to.
+ filter = {
+ "io.element.relation_senders": [self.second_user_id, self.third_user_id]
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 2, chunk)
+ self.assertCountEqual(
+ [c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
+ )
+
+ def test_filter_relation_type(self):
+ # Messages which have annotations.
+ filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0].event_id, self.event_id_1)
+
+ # Messages which have references.
+ filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0].event_id, self.event_id_2)
+
+ # Messages which have either annotations or references.
+ filter = {
+ "io.element.relation_types": [
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ ]
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 2, chunk)
+ self.assertCountEqual(
+ [c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
+ )
+
+ def test_filter_relation_senders_and_type(self):
+ # Messages which second user reacted to.
+ filter = {
+ "io.element.relation_senders": [self.second_user_id],
+ "io.element.relation_types": [RelationTypes.ANNOTATION],
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0].event_id, self.event_id_1)
+
+ def test_duplicate_relation(self):
+ """An event should only be returned once if there are multiple relations to it."""
+ self.helper.send_event(
+ room_id=self.room_id,
+ type="m.reaction",
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.ANNOTATION,
+ "event_id": self.event_id_1,
+ "key": "A",
+ }
+ },
+ tok=self.second_tok,
+ )
+
+ filter = {"io.element.relation_senders": [self.second_user_id]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0].event_id, self.event_id_1)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 24fc77d7a7..3eef1c4c05 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -81,8 +81,6 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
origin,
event,
context,
- state=None,
- backfilled=False,
):
return context
diff --git a/tests/unittest.py b/tests/unittest.py
index a9b60b7eeb..c9a08a3420 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,20 @@ import inspect
import logging
import secrets
import time
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
+from typing import (
+ Any,
+ AnyStr,
+ Callable,
+ ClassVar,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
from unittest.mock import Mock, patch
from canonicaljson import json
@@ -31,6 +44,7 @@ from twisted.python.threadpool import ThreadPool
from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse import events
from synapse.api.constants import EventTypes, Membership
@@ -45,6 +59,7 @@ from synapse.logging.context import (
current_context,
set_current_context,
)
+from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@@ -81,16 +96,13 @@ def around(target):
return _around
-T = TypeVar("T")
-
-
class TestCase(unittest.TestCase):
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
attributes on both itself and its individual test methods, to override the
root logger's logging level while that test (case|method) runs."""
- def __init__(self, methodName, *args, **kwargs):
- super().__init__(methodName, *args, **kwargs)
+ def __init__(self, methodName: str):
+ super().__init__(methodName)
method = getattr(self, methodName)
@@ -204,18 +216,18 @@ class HomeserverTestCase(TestCase):
config dict.
Attributes:
- servlets (list[function]): List of servlet registration function.
+ servlets: List of servlet registration function.
user_id (str): The user ID to assume if auth is hijacked.
- hijack_auth (bool): Whether to hijack auth to return the user specified
+ hijack_auth: Whether to hijack auth to return the user specified
in user_id.
"""
- servlets = []
- hijack_auth = True
- needs_threadpool = False
+ hijack_auth: ClassVar[bool] = True
+ needs_threadpool: ClassVar[bool] = False
+ servlets: ClassVar[List[RegisterServletsFunc]] = []
- def __init__(self, methodName, *args, **kwargs):
- super().__init__(methodName, *args, **kwargs)
+ def __init__(self, methodName: str):
+ super().__init__(methodName)
# see if we have any additional config for this test
method = getattr(self, methodName)
@@ -287,9 +299,10 @@ class HomeserverTestCase(TestCase):
None,
)
- self.hs.get_auth().get_user_by_req = get_user_by_req
- self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
- self.hs.get_auth().get_access_token_from_request = Mock(
+ # Type ignore: mypy doesn't like us assigning to methods.
+ self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment]
+ self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment]
+ self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment]
return_value="1234"
)
@@ -403,14 +416,12 @@ class HomeserverTestCase(TestCase):
path: Union[bytes, str],
content: Union[bytes, str, JsonDict] = b"",
access_token: Optional[str] = None,
- request: Type[T] = SynapseRequest,
+ request: Type[Request] = SynapseRequest,
shorthand: bool = True,
- federation_auth_origin: str = None,
+ federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
@@ -425,7 +436,7 @@ class HomeserverTestCase(TestCase):
a dict.
shorthand: Whether to try and be helpful and prefix the given URL
with the usual REST API path, if it doesn't contain it.
- federation_auth_origin (bytes|None): if set to not-None, we will add a fake
+ federation_auth_origin: if set to not-None, we will add a fake
Authorization header pretenting to be the given server name.
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
@@ -584,7 +595,7 @@ class HomeserverTestCase(TestCase):
nonce_str += b"\x00notadmin"
want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
- want_mac = want_mac.hexdigest()
+ want_mac_digest = want_mac.hexdigest()
body = json.dumps(
{
@@ -593,7 +604,7 @@ class HomeserverTestCase(TestCase):
"displayname": displayname,
"password": password,
"admin": admin,
- "mac": want_mac,
+ "mac": want_mac_digest,
"inhibit_login": True,
}
)
@@ -639,9 +650,7 @@ class HomeserverTestCase(TestCase):
username,
password,
device_id=None,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
"""
Log in a user, and get an access token. Requires the Login API be
|