diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index cc52c3dfac..1a3ccb263d 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -321,3 +321,102 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
per_dest_queue._last_successful_stream_ordering,
event_5.internal_metadata.stream_ordering,
)
+
+ @override_config({"send_federation": True})
+ def test_catch_up_on_synapse_startup(self):
+ """
+ Tests the behaviour of get_catch_up_outstanding_destinations and
+ _wake_destinations_needing_catchup.
+ """
+
+ # list of sorted server names (note that there are more servers than the batch
+ # size used in get_catch_up_outstanding_destinations).
+ server_names = ["server%02d" % number for number in range(42)] + ["zzzerver"]
+
+ # ARRANGE:
+ # - a local user (u1)
+ # - a room which u1 is joined to (and remote users @user:serverXX are
+ # joined to)
+
+ # mark the remotes as online
+ self.is_online = True
+
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room_id = self.helper.create_room_as("u1", tok=u1_token)
+
+ for server_name in server_names:
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, room_id, "@user:%s" % server_name, "join"
+ )
+ )
+
+ # create an event
+ self.helper.send(room_id, "deary me!", tok=u1_token)
+
+ # ASSERT:
+ # - All servers are up to date so none should have outstanding catch-up
+ outstanding_when_successful = self.get_success(
+ self.hs.get_datastore().get_catch_up_outstanding_destinations(None)
+ )
+ self.assertEqual(outstanding_when_successful, [])
+
+ # ACT:
+ # - Make the remote servers unreachable
+ self.is_online = False
+
+ # - Mark zzzerver as being backed-off from
+ now = self.clock.time_msec()
+ self.get_success(
+ self.hs.get_datastore().set_destination_retry_timings(
+ "zzzerver", now, now, 24 * 60 * 60 * 1000 # retry in 1 day
+ )
+ )
+
+ # - Send an event
+ self.helper.send(room_id, "can anyone hear me?", tok=u1_token)
+
+ # ASSERT (get_catch_up_outstanding_destinations):
+ # - all remotes are outstanding
+ # - they are returned in batches of 25, in order
+ outstanding_1 = self.get_success(
+ self.hs.get_datastore().get_catch_up_outstanding_destinations(None)
+ )
+
+ self.assertEqual(len(outstanding_1), 25)
+ self.assertEqual(outstanding_1, server_names[0:25])
+
+ outstanding_2 = self.get_success(
+ self.hs.get_datastore().get_catch_up_outstanding_destinations(
+ outstanding_1[-1]
+ )
+ )
+ self.assertNotIn("zzzerver", outstanding_2)
+ self.assertEqual(len(outstanding_2), 17)
+ self.assertEqual(outstanding_2, server_names[25:-1])
+
+ # ACT: call _wake_destinations_needing_catchup
+
+ # patch wake_destination to just count the destinations instead
+ woken = []
+
+ def wake_destination_track(destination):
+ woken.append(destination)
+
+ self.hs.get_federation_sender().wake_destination = wake_destination_track
+
+ # cancel the pre-existing timer for _wake_destinations_needing_catchup
+ # this is because we are calling it manually rather than waiting for it
+ # to be called automatically
+ self.hs.get_federation_sender()._catchup_after_startup_timer.cancel()
+
+ self.get_success(
+ self.hs.get_federation_sender()._wake_destinations_needing_catchup(), by=5.0
+ )
+
+ # ASSERT (_wake_destinations_needing_catchup):
+ # - all remotes are woken up, save for zzzerver
+ self.assertNotIn("zzzerver", woken)
+ # - all destinations are woken exactly once; they appear once in woken.
+ self.assertCountEqual(woken, server_names[:-1])
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 6aa322bf3a..969d44c787 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -35,6 +35,17 @@ class DeviceTestCase(unittest.HomeserverTestCase):
# These tests assume that it starts 1000 seconds in.
self.reactor.advance(1000)
+ def test_device_is_created_with_invalid_name(self):
+ self.get_failure(
+ self.handler.check_device_registered(
+ user_id="@boris:foo",
+ device_id="foo",
+ initial_device_display_name="a"
+ * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1),
+ ),
+ synapse.api.errors.SynapseError,
+ )
+
def test_device_is_created_if_doesnt_exist(self):
res = self.get_success(
self.handler.check_device_registered(
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 210ddcbb88..366dcfb670 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -30,7 +30,7 @@ from tests import unittest, utils
class E2eKeysHandlerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
- super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 3362050ce0..7adde9b9de 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -47,7 +47,7 @@ room_keys = {
class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
- super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 561258a356..bc578411d6 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -58,7 +58,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
# Patch up the equality operator for events so that we can check
# whether lists of events match using assertEquals
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
- return super(SlavedEventStoreTestCase, self).setUp()
+ return super().setUp()
def prepare(self, *args, **kwargs):
super().prepare(*args, **kwargs)
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index faa7f381a9..92c9058887 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -221,7 +221,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
# Ensure the display name was not updated.
request, channel = self.make_request(
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
new file mode 100644
index 0000000000..bf79086f78
--- /dev/null
+++ b/tests/rest/admin/test_event_reports.py
@@ -0,0 +1,382 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import report_event
+
+from tests import unittest
+
+
+class EventReportsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ report_event.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.room_id1 = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok, is_public=True
+ )
+ self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok)
+
+ self.room_id2 = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok, is_public=True
+ )
+ self.helper.join(self.room_id2, user=self.admin_user, tok=self.admin_user_tok)
+
+ # Two rooms and two users. Every user sends and reports every room event
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id1, user_tok=self.other_user_tok,
+ )
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id2, user_tok=self.other_user_tok,
+ )
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id1, user_tok=self.admin_user_tok,
+ )
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id2, user_tok=self.admin_user_tok,
+ )
+
+ self.url = "/_synapse/admin/v1/event_reports"
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.other_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_default_success(self):
+ """
+ Testing list of reported events
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_limit(self):
+ """
+ Testing list of reported events with limit
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 5)
+ self.assertEqual(channel.json_body["next_token"], 5)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_from(self):
+ """
+ Testing list of reported events with a defined starting point (from)
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 15)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_limit_and_from(self):
+ """
+ Testing list of reported events with a defined starting point and limit
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(channel.json_body["next_token"], 15)
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_filter_room(self):
+ """
+ Testing list of reported events with a filter of room
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?room_id=%s" % self.room_id1,
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 10)
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ for report in channel.json_body["event_reports"]:
+ self.assertEqual(report["room_id"], self.room_id1)
+
+ def test_filter_user(self):
+ """
+ Testing list of reported events with a filter of user
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?user_id=%s" % self.other_user,
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 10)
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ for report in channel.json_body["event_reports"]:
+ self.assertEqual(report["user_id"], self.other_user)
+
+ def test_filter_user_and_room(self):
+ """
+ Testing list of reported events with a filter of user and room
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 5)
+ self.assertEqual(len(channel.json_body["event_reports"]), 5)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ for report in channel.json_body["event_reports"]:
+ self.assertEqual(report["user_id"], self.other_user)
+ self.assertEqual(report["room_id"], self.room_id1)
+
+ def test_valid_search_order(self):
+ """
+ Testing search order. Order by timestamps.
+ """
+
+ # fetch the most recent first, largest timestamp
+ request, channel = self.make_request(
+ "GET", self.url + "?dir=b", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ report = 1
+ while report < len(channel.json_body["event_reports"]):
+ self.assertGreaterEqual(
+ channel.json_body["event_reports"][report - 1]["received_ts"],
+ channel.json_body["event_reports"][report]["received_ts"],
+ )
+ report += 1
+
+ # fetch the oldest first, smallest timestamp
+ request, channel = self.make_request(
+ "GET", self.url + "?dir=f", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ report = 1
+ while report < len(channel.json_body["event_reports"]):
+ self.assertLessEqual(
+ channel.json_body["event_reports"][report - 1]["received_ts"],
+ channel.json_body["event_reports"][report]["received_ts"],
+ )
+ report += 1
+
+ def test_invalid_search_order(self):
+ """
+ Testing that a invalid search order returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual("Unknown direction: bar", channel.json_body["error"])
+
+ def test_limit_is_negative(self):
+ """
+ Testing that a negative list parameter returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_from_is_negative(self):
+ """
+ Testing that a negative from parameter returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_next_token(self):
+ """
+ Testing that `next_token` appears at the right place
+ """
+
+ # `next_token` does not appear
+ # Number of results is the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does not appear
+ # Number of max results is larger than the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does appear
+ # Number of max results is smaller than the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 19)
+ self.assertEqual(channel.json_body["next_token"], 19)
+
+ # Check
+ # Set `from` to value of `next_token` for request remaining entries
+ # `next_token` does not appear
+ request, channel = self.make_request(
+ "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 1)
+ self.assertNotIn("next_token", channel.json_body)
+
+ def _create_event_and_report(self, room_id, user_tok):
+ """Create and report events
+ """
+ resp = self.helper.send(room_id, tok=user_tok)
+ event_id = resp["event_id"]
+
+ request, channel = self.make_request(
+ "POST",
+ "rooms/%s/report/%s" % (room_id, event_id),
+ json.dumps({"score": -100, "reason": "this makes me sad"}),
+ access_token=user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def _check_fields(self, content):
+ """Checks that all attributes are present in a event report
+ """
+ for c in content:
+ self.assertIn("id", c)
+ self.assertIn("received_ts", c)
+ self.assertIn("room_id", c)
+ self.assertIn("event_id", c)
+ self.assertIn("user_id", c)
+ self.assertIn("reason", c)
+ self.assertIn("content", c)
+ self.assertIn("sender", c)
+ self.assertIn("room_alias", c)
+ self.assertIn("event_json", c)
+ self.assertIn("score", c["content"])
+ self.assertIn("reason", c["content"])
+ self.assertIn("auth_events", c["event_json"])
+ self.assertIn("type", c["event_json"])
+ self.assertIn("room_id", c["event_json"])
+ self.assertIn("sender", c["event_json"])
+ self.assertIn("content", c["event_json"])
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index b8b7758d24..98d0623734 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -22,8 +22,8 @@ from mock import Mock
import synapse.rest.admin
from synapse.api.constants import UserTypes
-from synapse.api.errors import HttpResponseException, ResourceLimitError
-from synapse.rest.client.v1 import login
+from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
+from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
from tests import unittest
@@ -874,6 +874,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self._is_erased("@user:test", False)
+ d = self.store.mark_user_erased("@user:test")
+ self.assertIsNone(self.get_success(d))
+ self._is_erased("@user:test", True)
# Attempt to reactivate the user (without a password).
request, channel = self.make_request(
@@ -906,6 +910,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
+ self._is_erased("@user:test", False)
def test_set_user_as_admin(self):
"""
@@ -995,3 +1000,104 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Ensure they're still alive
self.assertEqual(0, channel.json_body["deactivated"])
+
+ def _is_erased(self, user_id, expect):
+ """Assert that the user is erased or not
+ """
+ d = self.store.is_user_erased(user_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertFalse(self.get_success(d))
+
+
+class UserMembershipRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.url = "/_synapse/admin/v1/users/%s/joined_rooms" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to list rooms of an user without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_get_rooms(self):
+ """
+ Tests that a normal lookup for rooms is successfully
+ """
+ # Create rooms and join
+ other_user_tok = self.login("user", "pass")
+ number_rooms = 5
+ for n in range(number_rooms):
+ self.helper.create_room_as(self.other_user, tok=other_user_tok)
+
+ # Get rooms
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(number_rooms, channel.json_body["total"])
+ self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 2668662c9e..5d987a30c7 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -7,8 +7,9 @@ from mock import Mock
import jwt
import synapse.rest.admin
+from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login, logout
-from synapse.rest.client.v2_alpha import devices
+from synapse.rest.client.v2_alpha import devices, register
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
from tests import unittest
@@ -748,3 +749,134 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
channel.json_body["error"],
"JWT validation failed: Signature verification failed",
)
+
+
+AS_USER = "as_user_alice"
+
+
+class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ register.register_servlets,
+ ]
+
+ def register_as_user(self, username):
+ request, channel = self.make_request(
+ b"POST",
+ "/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
+ {"username": username},
+ )
+ self.render(request)
+
+ def make_homeserver(self, reactor, clock):
+ self.hs = self.setup_test_homeserver()
+
+ self.service = ApplicationService(
+ id="unique_identifier",
+ token="some_token",
+ hostname="example.com",
+ sender="@asbot:example.com",
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {"regex": r"@as_user.*", "exclusive": False}
+ ],
+ ApplicationService.NS_ROOMS: [],
+ ApplicationService.NS_ALIASES: [],
+ },
+ )
+ self.another_service = ApplicationService(
+ id="another__identifier",
+ token="another_token",
+ hostname="example.com",
+ sender="@as2bot:example.com",
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {"regex": r"@as2_user.*", "exclusive": False}
+ ],
+ ApplicationService.NS_ROOMS: [],
+ ApplicationService.NS_ALIASES: [],
+ },
+ )
+
+ self.hs.get_datastore().services_cache.append(self.service)
+ self.hs.get_datastore().services_cache.append(self.another_service)
+ return self.hs
+
+ def test_login_appservice_user(self):
+ """Test that an appservice user can use /login
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": AS_USER},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def test_login_appservice_user_bot(self):
+ """Test that the appservice bot can use /login
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": self.service.sender},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def test_login_appservice_wrong_user(self):
+ """Test that non-as users cannot login with the as token
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": "fibble_wibble"},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ def test_login_appservice_wrong_as(self):
+ """Test that as users cannot login with wrong as token
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": AS_USER},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.another_service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ def test_login_appservice_no_token(self):
+ """Test that users must provide a token when using the appservice
+ login method
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": AS_USER},
+ }
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index b090bb974c..dcd65c2a50 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -21,7 +21,7 @@ from tests import unittest
class WellKnownTests(unittest.HomeserverTestCase):
def setUp(self):
- super(WellKnownTests, self).setUp()
+ super().setUp()
# replace the JsonResource with a WellKnownResource
self.resource = WellKnownResource(self.hs)
diff --git a/tests/server.py b/tests/server.py
index 61ec670155..b404ad4e2a 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -260,7 +260,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return succeed(lookups[name])
self.nameResolver = SimpleResolverComplexifier(FakeResolver())
- super(ThreadedMemoryReactorClock, self).__init__()
+ super().__init__()
def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
p = udp.Port(port, protocol, interface, maxPacketSize, self)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index cb808d4de4..46f94914ff 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -413,7 +413,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(TestTransactionStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 34ae8c9da7..ecb00f4e02 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -23,7 +23,7 @@ import tests.utils
class DeviceStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
- super(DeviceStoreTestCase, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
diff --git a/tests/test_state.py b/tests/test_state.py
index 2d58467932..80b0ccbc40 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -125,7 +125,7 @@ class StateGroupStore:
class DictObj(dict):
def __init__(self, **kwargs):
- super(DictObj, self).__init__(kwargs)
+ super().__init__(kwargs)
self.__dict__ = self
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 2d96b0fa8d..fdfb840b62 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -29,8 +29,7 @@ class ToTwistedHandler(logging.Handler):
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
self.tx_log.emit(
- twisted.logger.LogLevel.levelWithName(log_level),
- log_entry.replace("{", r"(").replace("}", r")"),
+ twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
)
diff --git a/tests/unittest.py b/tests/unittest.py
index 128dd4e19c..dabf69cff4 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -92,7 +92,7 @@ class TestCase(unittest.TestCase):
root logger's logging level while that test (case|method) runs."""
def __init__(self, methodName, *args, **kwargs):
- super(TestCase, self).__init__(methodName, *args, **kwargs)
+ super().__init__(methodName, *args, **kwargs)
method = getattr(self, methodName)
|