diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index e5fc2fcd15..0342aed416 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -13,24 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import hashlib
-import hmac
import json
+import os
+import urllib.parse
+from binascii import unhexlify
+from typing import List, Optional
from mock import Mock
+from twisted.internet.defer import Deferred
+
import synapse.rest.admin
-from synapse.api.constants import UserTypes
from synapse.http.server import JsonResource
+from synapse.logging.context import make_deferred_yieldable
from synapse.rest.admin import VersionServlet
-from synapse.rest.client.v1 import events, login, room
+from synapse.rest.client.v1 import directory, events, login, room
from synapse.rest.client.v2_alpha import groups
from tests import unittest
class VersionTestCase(unittest.HomeserverTestCase):
- url = '/_synapse/admin/v1/server_version'
+ url = "/_synapse/admin/v1/server_version"
def create_test_json_resource(self):
resource = JsonResource(self.hs)
@@ -43,304 +47,8 @@ class VersionTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
- {'server_version', 'python_version'}, set(channel.json_body.keys())
- )
-
-
-class UserRegisterTestCase(unittest.HomeserverTestCase):
-
- servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
-
- def make_homeserver(self, reactor, clock):
-
- self.url = "/_matrix/client/r0/admin/register"
-
- self.registration_handler = Mock()
- self.identity_handler = Mock()
- self.login_handler = Mock()
- self.device_handler = Mock()
- self.device_handler.check_device_registered = Mock(return_value="FAKE")
-
- self.datastore = Mock(return_value=Mock())
- self.datastore.get_current_state_deltas = Mock(return_value=[])
-
- self.secrets = Mock()
-
- self.hs = self.setup_test_homeserver()
-
- self.hs.config.registration_shared_secret = u"shared"
-
- self.hs.get_media_repository = Mock()
- self.hs.get_deactivate_account_handler = Mock()
-
- return self.hs
-
- def test_disabled(self):
- """
- If there is no shared secret, registration through this method will be
- prevented.
- """
- self.hs.config.registration_shared_secret = None
-
- request, channel = self.make_request("POST", self.url, b'{}')
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(
- 'Shared secret registration is not enabled', channel.json_body["error"]
- )
-
- def test_get_nonce(self):
- """
- Calling GET on the endpoint will return a randomised nonce, using the
- homeserver's secrets provider.
- """
- secrets = Mock()
- secrets.token_hex = Mock(return_value="abcd")
-
- self.hs.get_secrets = Mock(return_value=secrets)
-
- request, channel = self.make_request("GET", self.url)
- self.render(request)
-
- self.assertEqual(channel.json_body, {"nonce": "abcd"})
-
- def test_expired_nonce(self):
- """
- Calling GET on the endpoint will return a randomised nonce, which will
- only last for SALT_TIMEOUT (60s).
- """
- request, channel = self.make_request("GET", self.url)
- self.render(request)
- nonce = channel.json_body["nonce"]
-
- # 59 seconds
- self.reactor.advance(59)
-
- body = json.dumps({"nonce": nonce})
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual('username must be specified', channel.json_body["error"])
-
- # 61 seconds
- self.reactor.advance(2)
-
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual('unrecognised nonce', channel.json_body["error"])
-
- def test_register_incorrect_nonce(self):
- """
- Only the provided nonce can be used, as it's checked in the MAC.
- """
- request, channel = self.make_request("GET", self.url)
- self.render(request)
- nonce = channel.json_body["nonce"]
-
- want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
- want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
- want_mac = want_mac.hexdigest()
-
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob",
- "password": "abc123",
- "admin": True,
- "mac": want_mac,
- }
- )
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("HMAC incorrect", channel.json_body["error"])
-
- def test_register_correct_nonce(self):
- """
- When the correct nonce is provided, and the right key is provided, the
- user is registered.
- """
- request, channel = self.make_request("GET", self.url)
- self.render(request)
- nonce = channel.json_body["nonce"]
-
- want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
- want_mac.update(
- nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin\x00support"
- )
- want_mac = want_mac.hexdigest()
-
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob",
- "password": "abc123",
- "admin": True,
- "user_type": UserTypes.SUPPORT,
- "mac": want_mac,
- }
- )
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("@bob:test", channel.json_body["user_id"])
-
- def test_nonce_reuse(self):
- """
- A valid unrecognised nonce.
- """
- request, channel = self.make_request("GET", self.url)
- self.render(request)
- nonce = channel.json_body["nonce"]
-
- want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
- want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
- want_mac = want_mac.hexdigest()
-
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob",
- "password": "abc123",
- "admin": True,
- "mac": want_mac,
- }
- )
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("@bob:test", channel.json_body["user_id"])
-
- # Now, try and reuse it
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual('unrecognised nonce', channel.json_body["error"])
-
- def test_missing_parts(self):
- """
- Synapse will complain if you don't give nonce, username, password, and
- mac. Admin and user_types are optional. Additional checks are done for length
- and type.
- """
-
- def nonce():
- request, channel = self.make_request("GET", self.url)
- self.render(request)
- return channel.json_body["nonce"]
-
- #
- # Nonce check
- #
-
- # Must be present
- body = json.dumps({})
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual('nonce must be specified', channel.json_body["error"])
-
- #
- # Username checks
- #
-
- # Must be present
- body = json.dumps({"nonce": nonce()})
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual('username must be specified', channel.json_body["error"])
-
- # Must be a string
- body = json.dumps({"nonce": nonce(), "username": 1234})
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual('Invalid username', channel.json_body["error"])
-
- # Must not have null bytes
- body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"})
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual('Invalid username', channel.json_body["error"])
-
- # Must not have null bytes
- body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual('Invalid username', channel.json_body["error"])
-
- #
- # Password checks
- #
-
- # Must be present
- body = json.dumps({"nonce": nonce(), "username": "a"})
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual('password must be specified', channel.json_body["error"])
-
- # Must be a string
- body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual('Invalid password', channel.json_body["error"])
-
- # Must not have null bytes
- body = json.dumps(
- {"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}
- )
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual('Invalid password', channel.json_body["error"])
-
- # Super long
- body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual('Invalid password', channel.json_body["error"])
-
- #
- # user_type check
- #
-
- # Invalid user_type
- body = json.dumps(
- {
- "nonce": nonce(),
- "username": "a",
- "password": "1234",
- "user_type": "invalid",
- }
+ {"server_version", "python_version"}, set(channel.json_body.keys())
)
- request, channel = self.make_request("POST", self.url, body.encode('utf8'))
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual('Invalid user type', channel.json_body["error"])
class ShutdownRoomTestCase(unittest.HomeserverTestCase):
@@ -396,7 +104,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
url = "admin/shutdown_room/" + room_id
request, channel = self.make_request(
"POST",
- url.encode('ascii'),
+ url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
access_token=self.admin_user_tok,
)
@@ -421,7 +129,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
request, channel = self.make_request(
"PUT",
- url.encode('ascii'),
+ url.encode("ascii"),
json.dumps({"history_visibility": "world_readable"}),
access_token=self.other_user_token,
)
@@ -432,7 +140,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
url = "admin/shutdown_room/" + room_id
request, channel = self.make_request(
"POST",
- url.encode('ascii'),
+ url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
access_token=self.admin_user_tok,
)
@@ -449,7 +157,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
url = "rooms/%s/initialSync" % (room_id,)
request, channel = self.make_request(
- "GET", url.encode('ascii'), access_token=self.admin_user_tok
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
)
self.render(request)
self.assertEqual(
@@ -458,7 +166,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
url = "events?timeout=0&room_id=" + room_id
request, channel = self.make_request(
- "GET", url.encode('ascii'), access_token=self.admin_user_tok
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
)
self.render(request)
self.assertEqual(
@@ -486,7 +194,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
# Create a new group
request, channel = self.make_request(
"POST",
- "/create_group".encode('ascii'),
+ "/create_group".encode("ascii"),
access_token=self.admin_user_tok,
content={"localpart": "test"},
)
@@ -502,14 +210,14 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
request, channel = self.make_request(
- "PUT", url.encode('ascii'), access_token=self.admin_user_tok, content={}
+ "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={}
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
url = "/groups/%s/self/accept_invite" % (group_id,)
request, channel = self.make_request(
- "PUT", url.encode('ascii'), access_token=self.other_user_token, content={}
+ "PUT", url.encode("ascii"), access_token=self.other_user_token, content={}
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -522,7 +230,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
url = "/admin/delete_group/" + group_id
request, channel = self.make_request(
"POST",
- url.encode('ascii'),
+ url.encode("ascii"),
access_token=self.admin_user_tok,
content={"localpart": "test"},
)
@@ -544,7 +252,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
url = "/groups/%s/profile" % (group_id,)
request, channel = self.make_request(
- "GET", url.encode('ascii'), access_token=self.admin_user_tok
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
)
self.render(request)
@@ -556,10 +264,816 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
"""Returns the list of groups the user is in (given their access token)
"""
request, channel = self.make_request(
- "GET", "/joined_groups".encode('ascii'), access_token=access_token
+ "GET", "/joined_groups".encode("ascii"), access_token=access_token
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
return channel.json_body["groups"]
+
+
+class PurgeRoomTestCase(unittest.HomeserverTestCase):
+ """Test /purge_room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_purge_room(self):
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # All users have to have left the room.
+ self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/purge_room"
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the following tables have been purged of all rows related to the room.
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "local_invites",
+ "room_account_data",
+ "room_tags",
+ # "state_groups", # Current impl leaves orphaned state groups around.
+ "state_groups_state",
+ ):
+ count = self.get_success(
+ self.store.db.simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+
+class QuarantineMediaTestCase(unittest.HomeserverTestCase):
+ """Test /quarantine_media admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_media_repo,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.hs = hs
+
+ # Allow for uploading and downloading to/from the media repo
+ self.media_repo = hs.get_media_repository_resource()
+ self.download_resource = self.media_repo.children[b"download"]
+ self.upload_resource = self.media_repo.children[b"upload"]
+ self.image_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ def make_homeserver(self, reactor, clock):
+
+ self.fetches = []
+
+ def get_file(destination, path, output_stream, args=None, max_size=None):
+ """
+ Returns tuple[int,dict,str,int] of file length, response headers,
+ absolute URI, and response code.
+ """
+
+ def write_to(r):
+ data, response = r
+ output_stream.write(data)
+ return response
+
+ d = Deferred()
+ d.addCallback(write_to)
+ self.fetches.append((d, destination, path, args))
+ return make_deferred_yieldable(d)
+
+ client = Mock()
+ client.get_file = get_file
+
+ self.storage_path = self.mktemp()
+ self.media_store_path = self.mktemp()
+ os.mkdir(self.storage_path)
+ os.mkdir(self.media_store_path)
+
+ config = self.default_config()
+ config["media_store_path"] = self.media_store_path
+ config["thumbnail_requirements"] = {}
+ config["max_image_pixels"] = 2000000
+
+ provider_config = {
+ "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+ config["media_storage_providers"] = [provider_config]
+
+ hs = self.setup_test_homeserver(config=config, http_client=client)
+
+ return hs
+
+ def test_quarantine_media_requires_admin(self):
+ self.register_user("nonadmin", "pass", admin=False)
+ non_admin_user_tok = self.login("nonadmin", "pass")
+
+ # Attempt quarantine media APIs as non-admin
+ url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
+ request, channel = self.make_request(
+ "POST", url.encode("ascii"), access_token=non_admin_user_tok,
+ )
+ self.render(request)
+
+ # Expect a forbidden error
+ self.assertEqual(
+ 403,
+ int(channel.result["code"]),
+ msg="Expected forbidden on quarantining media as a non-admin",
+ )
+
+ # And the roomID/userID endpoint
+ url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
+ request, channel = self.make_request(
+ "POST", url.encode("ascii"), access_token=non_admin_user_tok,
+ )
+ self.render(request)
+
+ # Expect a forbidden error
+ self.assertEqual(
+ 403,
+ int(channel.result["code"]),
+ msg="Expected forbidden on quarantining media as a non-admin",
+ )
+
+ def test_quarantine_media_by_id(self):
+ self.register_user("id_admin", "pass", admin=True)
+ admin_user_tok = self.login("id_admin", "pass")
+
+ self.register_user("id_nonadmin", "pass", admin=False)
+ non_admin_user_tok = self.login("id_nonadmin", "pass")
+
+ # Upload some media into the room
+ response = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=admin_user_tok
+ )
+
+ # Extract media ID from the response
+ server_name_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
+ server_name, media_id = server_name_and_media_id.split("/")
+
+ # Attempt to access the media
+ request, channel = self.make_request(
+ "GET",
+ server_name_and_media_id,
+ shorthand=False,
+ access_token=non_admin_user_tok,
+ )
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be successful
+ self.assertEqual(200, int(channel.code), msg=channel.result["body"])
+
+ # Quarantine the media
+ url = "/_synapse/admin/v1/media/quarantine/%s/%s" % (
+ urllib.parse.quote(server_name),
+ urllib.parse.quote(media_id),
+ )
+ request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+ self.render(request)
+ self.pump(1.0)
+ self.assertEqual(200, int(channel.code), msg=channel.result["body"])
+
+ # Attempt to access the media
+ request, channel = self.make_request(
+ "GET",
+ server_name_and_media_id,
+ shorthand=False,
+ access_token=admin_user_tok,
+ )
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be quarantined
+ self.assertEqual(
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_name_and_media_id
+ ),
+ )
+
+ def test_quarantine_all_media_in_room(self, override_url_template=None):
+ self.register_user("room_admin", "pass", admin=True)
+ admin_user_tok = self.login("room_admin", "pass")
+
+ non_admin_user = self.register_user("room_nonadmin", "pass", admin=False)
+ non_admin_user_tok = self.login("room_nonadmin", "pass")
+
+ room_id = self.helper.create_room_as(non_admin_user, tok=admin_user_tok)
+ self.helper.join(room_id, non_admin_user, tok=non_admin_user_tok)
+
+ # Upload some media
+ response_1 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
+ )
+ response_2 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
+ )
+
+ # Extract mxcs
+ mxc_1 = response_1["content_uri"]
+ mxc_2 = response_2["content_uri"]
+
+ # Send it into the room
+ self.helper.send_event(
+ room_id,
+ "m.room.message",
+ content={"body": "image-1", "msgtype": "m.image", "url": mxc_1},
+ txn_id="111",
+ tok=non_admin_user_tok,
+ )
+ self.helper.send_event(
+ room_id,
+ "m.room.message",
+ content={"body": "image-2", "msgtype": "m.image", "url": mxc_2},
+ txn_id="222",
+ tok=non_admin_user_tok,
+ )
+
+ # Quarantine all media in the room
+ if override_url_template:
+ url = override_url_template % urllib.parse.quote(room_id)
+ else:
+ url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
+ room_id
+ )
+ request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+ self.render(request)
+ self.pump(1.0)
+ self.assertEqual(200, int(channel.code), msg=channel.result["body"])
+ self.assertEqual(
+ json.loads(channel.result["body"].decode("utf-8")),
+ {"num_quarantined": 2},
+ "Expected 2 quarantined items",
+ )
+
+ # Convert mxc URLs to server/media_id strings
+ server_and_media_id_1 = mxc_1[6:]
+ server_and_media_id_2 = mxc_2[6:]
+
+ # Test that we cannot download any of the media anymore
+ request, channel = self.make_request(
+ "GET",
+ server_and_media_id_1,
+ shorthand=False,
+ access_token=non_admin_user_tok,
+ )
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be quarantined
+ self.assertEqual(
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_and_media_id_1
+ ),
+ )
+
+ request, channel = self.make_request(
+ "GET",
+ server_and_media_id_2,
+ shorthand=False,
+ access_token=non_admin_user_tok,
+ )
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be quarantined
+ self.assertEqual(
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_and_media_id_2
+ ),
+ )
+
+ def test_quaraantine_all_media_in_room_deprecated_api_path(self):
+ # Perform the above test with the deprecated API path
+ self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s")
+
+ def test_quarantine_all_media_by_user(self):
+ self.register_user("user_admin", "pass", admin=True)
+ admin_user_tok = self.login("user_admin", "pass")
+
+ non_admin_user = self.register_user("user_nonadmin", "pass", admin=False)
+ non_admin_user_tok = self.login("user_nonadmin", "pass")
+
+ # Upload some media
+ response_1 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
+ )
+ response_2 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
+ )
+
+ # Extract media IDs
+ server_and_media_id_1 = response_1["content_uri"][6:]
+ server_and_media_id_2 = response_2["content_uri"][6:]
+
+ # Quarantine all media by this user
+ url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
+ non_admin_user
+ )
+ request, channel = self.make_request(
+ "POST", url.encode("ascii"), access_token=admin_user_tok,
+ )
+ self.render(request)
+ self.pump(1.0)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ json.loads(channel.result["body"].decode("utf-8")),
+ {"num_quarantined": 2},
+ "Expected 2 quarantined items",
+ )
+
+ # Attempt to access each piece of media
+ request, channel = self.make_request(
+ "GET",
+ server_and_media_id_1,
+ shorthand=False,
+ access_token=non_admin_user_tok,
+ )
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be quarantined
+ self.assertEqual(
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_and_media_id_1,
+ ),
+ )
+
+ # Attempt to access each piece of media
+ request, channel = self.make_request(
+ "GET",
+ server_and_media_id_2,
+ shorthand=False,
+ access_token=non_admin_user_tok,
+ )
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be quarantined
+ self.assertEqual(
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_and_media_id_2
+ ),
+ )
+
+
+class RoomTestCase(unittest.HomeserverTestCase):
+ """Test /room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_list_rooms(self):
+ """Test that we can list rooms"""
+ # Create 3 test rooms
+ total_rooms = 3
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ # Check request completed successfully
+ self.assertEqual(200, int(channel.code), msg=channel.json_body)
+
+ # Check that response json body contains a "rooms" key
+ self.assertTrue(
+ "rooms" in channel.json_body,
+ msg="Response body does not " "contain a 'rooms' key",
+ )
+
+ # Check that 3 rooms were returned
+ self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
+
+ # Check their room_ids match
+ returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
+ self.assertEqual(room_ids, returned_room_ids)
+
+ # Check that all fields are available
+ for r in channel.json_body["rooms"]:
+ self.assertIn("name", r)
+ self.assertIn("canonical_alias", r)
+ self.assertIn("joined_members", r)
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # Should be 0 as we aren't paginating
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that the prev_batch parameter is not present
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # We shouldn't receive a next token here as there's no further rooms to show
+ self.assertNotIn("next_batch", channel.json_body)
+
+ def test_list_rooms_pagination(self):
+ """Test that we can get a full list of rooms through pagination"""
+ # Create 5 test rooms
+ total_rooms = 5
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Set the name of the rooms so we get a consistent returned ordering
+ for idx, room_id in enumerate(room_ids):
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ returned_room_ids = []
+ start = 0
+ limit = 2
+
+ run_count = 0
+ should_repeat = True
+ while should_repeat:
+ run_count += 1
+
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
+ start,
+ limit,
+ "alphabetical",
+ )
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ self.assertTrue("rooms" in channel.json_body)
+ for r in channel.json_body["rooms"]:
+ returned_room_ids.append(r["room_id"])
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # We're only getting 2 rooms each page, so should be 2 * last run_count
+ self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
+
+ if run_count > 1:
+ # Check the value of prev_batch is correct
+ self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
+
+ if "next_batch" not in channel.json_body:
+ # We have reached the end of the list
+ should_repeat = False
+ else:
+ # Make another query with an updated start value
+ start = channel.json_body["next_batch"]
+
+ # We should've queried the endpoint 3 times
+ self.assertEqual(
+ run_count,
+ 3,
+ msg="Should've queried 3 times for 5 rooms with limit 2 per query",
+ )
+
+ # Check that we received all of the room ids
+ self.assertEqual(room_ids, returned_room_ids)
+
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_correct_room_attributes(self):
+ """Test the correct attributes for a room are returned"""
+ # Create a test room
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ test_alias = "#test:test"
+ test_room_name = "something"
+
+ # Have another user join the room
+ user_2 = self.register_user("user4", "pass")
+ user_tok_2 = self.login("user4", "pass")
+ self.helper.join(room_id, user_2, tok=user_tok_2)
+
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Set this new alias as the canonical alias for this room
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=self.admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=self.admin_user_tok,
+ )
+
+ # Set a name for the room
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that only one room was returned
+ self.assertEqual(len(rooms), 1)
+
+ # And that the value of the total_rooms key was correct
+ self.assertEqual(channel.json_body["total_rooms"], 1)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that all provided attributes are set
+ r = rooms[0]
+ self.assertEqual(room_id, r["room_id"])
+ self.assertEqual(test_room_name, r["name"])
+ self.assertEqual(test_alias, r["canonical_alias"])
+
+ def test_room_list_sort_order(self):
+ """Test room list sort ordering. alphabetical versus number of members,
+ reversing the order, etc.
+ """
+ # Create 3 test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
+ )
+
+ # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3
+ user_1 = self.register_user("bob1", "pass")
+ user_1_tok = self.login("bob1", "pass")
+ self.helper.join(room_id_2, user_1, tok=user_1_tok)
+
+ user_2 = self.register_user("bob2", "pass")
+ user_2_tok = self.login("bob2", "pass")
+ self.helper.join(room_id_3, user_2, tok=user_2_tok)
+
+ user_3 = self.register_user("bob3", "pass")
+ user_3_tok = self.login("bob3", "pass")
+ self.helper.join(room_id_3, user_3, tok=user_3_tok)
+
+ def _order_test(
+ order_type: str, expected_room_list: List[str], reverse: bool = False,
+ ):
+ """Request the list of rooms in a certain order. Assert that order is what
+ we expect
+
+ Args:
+ order_type: The type of ordering to give the server
+ expected_room_list: The list of room_ids in the order we expect to get
+ back from the server
+ """
+ # Request the list of rooms in the given order
+ url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
+ if reverse:
+ url += "&dir=b"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check for the correct total_rooms value
+ self.assertEqual(channel.json_body["total_rooms"], 3)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that rooms were returned in alphabetical order
+ returned_order = [r["room_id"] for r in rooms]
+ self.assertListEqual(expected_room_list, returned_order) # order is checked
+
+ # Test different sort orders, with forward and reverse directions
+ _order_test("alphabetical", [room_id_1, room_id_2, room_id_3])
+ _order_test("alphabetical", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("size", [room_id_3, room_id_2, room_id_1])
+ _order_test("size", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ def test_search_term(self):
+ """Test that searching for a room works correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ def _search_test(
+ expected_room_id: Optional[str],
+ search_term: str,
+ expected_http_code: int = 200,
+ ):
+ """Search for a room and check that the returned room's id is a match
+
+ Args:
+ expected_room_id: The room_id expected to be returned by the API. Set
+ to None to expect zero results for the search
+ search_term: The term to search for room names with
+ expected_http_code: The expected http code for the request
+ """
+ url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
+
+ if expected_http_code != 200:
+ return
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that the expected number of rooms were returned
+ expected_room_count = 1 if expected_room_id else 0
+ self.assertEqual(len(rooms), expected_room_count)
+ self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ if expected_room_id:
+ # Check that the first returned room id is correct
+ r = rooms[0]
+ self.assertEqual(expected_room_id, r["room_id"])
+
+ # Perform search tests
+ _search_test(room_id_1, "something")
+ _search_test(room_id_1, "thing")
+
+ _search_test(room_id_2, "else")
+ _search_test(room_id_2, "se")
+
+ _search_test(None, "foo")
+ _search_test(None, "bar")
+ _search_test(None, "", expected_http_code=400)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
new file mode 100644
index 0000000000..6416fb5d2a
--- /dev/null
+++ b/tests/rest/admin/test_user.py
@@ -0,0 +1,722 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hashlib
+import hmac
+import json
+import urllib.parse
+
+from mock import Mock
+
+import synapse.rest.admin
+from synapse.api.constants import UserTypes
+from synapse.rest.client.v1 import login
+
+from tests import unittest
+
+
+class UserRegisterTestCase(unittest.HomeserverTestCase):
+
+ servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
+
+ def make_homeserver(self, reactor, clock):
+
+ self.url = "/_matrix/client/r0/admin/register"
+
+ self.registration_handler = Mock()
+ self.identity_handler = Mock()
+ self.login_handler = Mock()
+ self.device_handler = Mock()
+ self.device_handler.check_device_registered = Mock(return_value="FAKE")
+
+ self.datastore = Mock(return_value=Mock())
+ self.datastore.get_current_state_deltas = Mock(return_value=(0, []))
+
+ self.secrets = Mock()
+
+ self.hs = self.setup_test_homeserver()
+
+ self.hs.config.registration_shared_secret = "shared"
+
+ self.hs.get_media_repository = Mock()
+ self.hs.get_deactivate_account_handler = Mock()
+
+ return self.hs
+
+ def test_disabled(self):
+ """
+ If there is no shared secret, registration through this method will be
+ prevented.
+ """
+ self.hs.config.registration_shared_secret = None
+
+ request, channel = self.make_request("POST", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "Shared secret registration is not enabled", channel.json_body["error"]
+ )
+
+ def test_get_nonce(self):
+ """
+ Calling GET on the endpoint will return a randomised nonce, using the
+ homeserver's secrets provider.
+ """
+ secrets = Mock()
+ secrets.token_hex = Mock(return_value="abcd")
+
+ self.hs.get_secrets = Mock(return_value=secrets)
+
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+
+ self.assertEqual(channel.json_body, {"nonce": "abcd"})
+
+ def test_expired_nonce(self):
+ """
+ Calling GET on the endpoint will return a randomised nonce, which will
+ only last for SALT_TIMEOUT (60s).
+ """
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+ nonce = channel.json_body["nonce"]
+
+ # 59 seconds
+ self.reactor.advance(59)
+
+ body = json.dumps({"nonce": nonce})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("username must be specified", channel.json_body["error"])
+
+ # 61 seconds
+ self.reactor.advance(2)
+
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("unrecognised nonce", channel.json_body["error"])
+
+ def test_register_incorrect_nonce(self):
+ """
+ Only the provided nonce can be used, as it's checked in the MAC.
+ """
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "mac": want_mac,
+ }
+ )
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("HMAC incorrect", channel.json_body["error"])
+
+ def test_register_correct_nonce(self):
+ """
+ When the correct nonce is provided, and the right key is provided, the
+ user is registered.
+ """
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(
+ nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support"
+ )
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "user_type": UserTypes.SUPPORT,
+ "mac": want_mac,
+ }
+ )
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["user_id"])
+
+ def test_nonce_reuse(self):
+ """
+ A valid unrecognised nonce.
+ """
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin")
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "mac": want_mac,
+ }
+ )
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["user_id"])
+
+ # Now, try and reuse it
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("unrecognised nonce", channel.json_body["error"])
+
+ def test_missing_parts(self):
+ """
+ Synapse will complain if you don't give nonce, username, password, and
+ mac. Admin and user_types are optional. Additional checks are done for length
+ and type.
+ """
+
+ def nonce():
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+ return channel.json_body["nonce"]
+
+ #
+ # Nonce check
+ #
+
+ # Must be present
+ body = json.dumps({})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("nonce must be specified", channel.json_body["error"])
+
+ #
+ # Username checks
+ #
+
+ # Must be present
+ body = json.dumps({"nonce": nonce()})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("username must be specified", channel.json_body["error"])
+
+ # Must be a string
+ body = json.dumps({"nonce": nonce(), "username": 1234})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("Invalid username", channel.json_body["error"])
+
+ # Must not have null bytes
+ body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("Invalid username", channel.json_body["error"])
+
+ # Must not have null bytes
+ body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("Invalid username", channel.json_body["error"])
+
+ #
+ # Password checks
+ #
+
+ # Must be present
+ body = json.dumps({"nonce": nonce(), "username": "a"})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("password must be specified", channel.json_body["error"])
+
+ # Must be a string
+ body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("Invalid password", channel.json_body["error"])
+
+ # Must not have null bytes
+ body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("Invalid password", channel.json_body["error"])
+
+ # Super long
+ body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("Invalid password", channel.json_body["error"])
+
+ #
+ # user_type check
+ #
+
+ # Invalid user_type
+ body = json.dumps(
+ {
+ "nonce": nonce(),
+ "username": "a",
+ "password": "1234",
+ "user_type": "invalid",
+ }
+ )
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("Invalid user type", channel.json_body["error"])
+
+
+class UsersListTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+ url = "/_synapse/admin/v2/users"
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.register_user("user1", "pass1", admin=False)
+ self.register_user("user2", "pass2", admin=False)
+
+ def test_no_auth(self):
+ """
+ Try to list users without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"])
+
+ def test_all_users(self):
+ """
+ List all users, including deactivated users.
+ """
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?deactivated=true",
+ b"{}",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(3, len(channel.json_body["users"]))
+
+
+class UserRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+ self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ self.hs.config.registration_shared_secret = None
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+ request, channel = self.make_request(
+ "PUT", url, access_token=self.other_user_token, content=b"{}",
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ self.hs.config.registration_shared_secret = None
+
+ request, channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v2/users/@unknown_person:test",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
+
+ def test_create_server_admin(self):
+ """
+ Check that a new admin user is created successfully.
+ """
+ self.hs.config.registration_shared_secret = None
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user (server admin)
+ body = json.dumps(
+ {
+ "password": "abc123",
+ "admin": True,
+ "displayname": "Bob's name",
+ "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ }
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("Bob's name", channel.json_body["displayname"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual(True, channel.json_body["admin"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("Bob's name", channel.json_body["displayname"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual(True, channel.json_body["admin"])
+ self.assertEqual(False, channel.json_body["is_guest"])
+ self.assertEqual(False, channel.json_body["deactivated"])
+
+ def test_create_user(self):
+ """
+ Check that a new regular user is created successfully.
+ """
+ self.hs.config.registration_shared_secret = None
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps(
+ {
+ "password": "abc123",
+ "admin": False,
+ "displayname": "Bob's name",
+ "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ }
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("Bob's name", channel.json_body["displayname"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual(False, channel.json_body["admin"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("Bob's name", channel.json_body["displayname"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual(False, channel.json_body["admin"])
+ self.assertEqual(False, channel.json_body["is_guest"])
+ self.assertEqual(False, channel.json_body["deactivated"])
+
+ def test_set_password(self):
+ """
+ Test setting a new password for another user.
+ """
+ self.hs.config.registration_shared_secret = None
+
+ # Change password
+ body = json.dumps({"password": "hahaha"})
+
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_set_displayname(self):
+ """
+ Test setting the displayname of another user.
+ """
+ self.hs.config.registration_shared_secret = None
+
+ # Modify user
+ body = json.dumps({"displayname": "foobar"})
+
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual("foobar", channel.json_body["displayname"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual("foobar", channel.json_body["displayname"])
+
+ def test_set_threepid(self):
+ """
+ Test setting threepid for an other user.
+ """
+ self.hs.config.registration_shared_secret = None
+
+ # Delete old and add new threepid to user
+ body = json.dumps(
+ {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+
+ def test_deactivate_user(self):
+ """
+ Test deactivating another user.
+ """
+
+ # Deactivate user
+ body = json.dumps({"deactivated": True})
+
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["deactivated"])
+ # the user is deactivated, the threepid will be deleted
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["deactivated"])
+
+ def test_set_user_as_admin(self):
+ """
+ Test setting the admin flag on a user.
+ """
+ self.hs.config.registration_shared_secret = None
+
+ # Set a user as an admin
+ body = json.dumps({"admin": True})
+
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["admin"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["admin"])
+
+ def test_accidental_deactivation_prevention(self):
+ """
+ Ensure an account can't accidentally be deactivated by using a str value
+ for the deactivated body parameter
+ """
+ self.hs.config.registration_shared_secret = None
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps({"password": "abc123"})
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("bob", channel.json_body["displayname"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("bob", channel.json_body["displayname"])
+ self.assertEqual(0, channel.json_body["deactivated"])
+
+ # Change password (and use a str for deactivate instead of a bool)
+ body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops!
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Check user is not deactivated
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("bob", channel.json_body["displayname"])
+
+ # Ensure they're still alive
+ self.assertEqual(0, channel.json_body["deactivated"])
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index efc5a99db3..6803b372ac 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -42,17 +42,17 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# Make some temporary templates...
temp_consent_path = self.mktemp()
os.mkdir(temp_consent_path)
- os.mkdir(os.path.join(temp_consent_path, 'en'))
+ os.mkdir(os.path.join(temp_consent_path, "en"))
config["user_consent"] = {
"version": "1",
"template_dir": os.path.abspath(temp_consent_path),
}
- with open(os.path.join(temp_consent_path, "en/1.html"), 'w') as f:
+ with open(os.path.join(temp_consent_path, "en/1.html"), "w") as f:
f.write("{{version}},{{has_consented}}")
- with open(os.path.join(temp_consent_path, "en/success.html"), 'w') as f:
+ with open(os.path.join(temp_consent_path, "en/success.html"), "w") as f:
f.write("yay!")
hs = self.setup_test_homeserver(config=config)
@@ -88,7 +88,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
# Get the version from the body, and whether we've consented
- version, consented = channel.result["body"].decode('ascii').split(",")
+ version, consented = channel.result["body"].decode("ascii").split(",")
self.assertEqual(consented, "False")
# POST to the consent page, saying we've agreed
@@ -111,6 +111,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# Get the version from the body, and check that it's the version we
# agreed to, and that we've consented to it.
- version, consented = channel.result["body"].decode('ascii').split(",")
+ version, consented = channel.result["body"].decode("ascii").split(",")
self.assertEqual(consented, "True")
self.assertEqual(version, "1")
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
new file mode 100644
index 0000000000..5e9c07ebf3
--- /dev/null
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.api.constants import EventContentFields, EventTypes
+from synapse.rest import admin
+from synapse.rest.client.v1 import room
+
+from tests import unittest
+
+
+class EphemeralMessageTestCase(unittest.HomeserverTestCase):
+
+ user_id = "@user:test"
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ config["enable_ephemeral_messages"] = True
+
+ self.hs = self.setup_test_homeserver(config=config)
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.room_id = self.helper.create_room_as(self.user_id)
+
+ def test_message_expiry_no_delay(self):
+ """Tests that sending a message sent with a m.self_destruct_after field set to the
+ past results in that event being deleted right away.
+ """
+ # Send a message in the room that has expired. From here, the reactor clock is
+ # at 200ms, so 0 is in the past, and even if that wasn't the case and the clock
+ # is at 0ms the code path is the same if the event's expiry timestamp is the
+ # current timestamp.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "hello",
+ EventContentFields.SELF_DESTRUCT_AFTER: 0,
+ },
+ )
+ event_id = res["event_id"]
+
+ # Check that we can't retrieve the content of the event.
+ event_content = self.get_event(self.room_id, event_id)["content"]
+ self.assertFalse(bool(event_content), event_content)
+
+ def test_message_expiry_delay(self):
+ """Tests that sending a message with a m.self_destruct_after field set to the
+ future results in that event not being deleted right away, but advancing the
+ clock to after that expiry timestamp causes the event to be deleted.
+ """
+ # Send a message in the room that'll expire in 1s.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "hello",
+ EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000,
+ },
+ )
+ event_id = res["event_id"]
+
+ # Check that we can retrieve the content of the event before it has expired.
+ event_content = self.get_event(self.room_id, event_id)["content"]
+ self.assertTrue(bool(event_content), event_content)
+
+ # Advance the clock to after the deletion.
+ self.reactor.advance(1)
+
+ # Check that we can't retrieve the content of the event anymore.
+ event_content = self.get_event(self.room_id, event_id)["content"]
+ self.assertFalse(bool(event_content), event_content)
+
+ def get_event(self, room_id, event_id, expected_code=200):
+ url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+ request, channel = self.make_request("GET", url)
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c9b9eff83e..4224b0a92e 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -39,9 +39,7 @@ class IdentityDisabledTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
- config["trusted_third_party_id_servers"] = [
- "testis",
- ]
+ config["trusted_third_party_id_servers"] = ["testis"]
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
@@ -53,7 +51,7 @@ class IdentityDisabledTestCase(unittest.HomeserverTestCase):
def test_3pid_invite_disabled(self):
request, channel = self.make_request(
- b"POST", "/createRoom", b"{}", access_token=self.tok,
+ b"POST", "/createRoom", b"{}", access_token=self.tok
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -65,18 +63,18 @@ class IdentityDisabledTestCase(unittest.HomeserverTestCase):
"address": "test@example.com",
}
request_data = json.dumps(params)
- request_url = (
- "/rooms/%s/invite" % (room_id)
- ).encode('ascii')
+ request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
request, channel = self.make_request(
- b"POST", request_url, request_data, access_token=self.tok,
+ b"POST", request_url, request_data, access_token=self.tok
)
self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
def test_3pid_lookup_disabled(self):
- url = ("/_matrix/client/unstable/account/3pid/lookup"
- "?id_server=testis&medium=email&address=foo@bar.baz")
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
request, channel = self.make_request("GET", url, access_token=self.tok)
self.render(request)
self.assertEqual(channel.result["code"], b"403", channel.result)
@@ -85,20 +83,11 @@ class IdentityDisabledTestCase(unittest.HomeserverTestCase):
url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
data = {
"id_server": "testis",
- "threepids": [
- [
- "email",
- "foo@bar.baz"
- ],
- [
- "email",
- "john.doe@matrix.org"
- ]
- ]
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
}
request_data = json.dumps(data)
request, channel = self.make_request(
- "POST", url, request_data, access_token=self.tok,
+ "POST", url, request_data, access_token=self.tok
)
self.render(request)
self.assertEqual(channel.result["code"], b"403", channel.result)
@@ -118,22 +107,21 @@ class IdentityEnabledTestCase(unittest.HomeserverTestCase):
config = self.default_config()
config["enable_3pid_lookup"] = True
- config["trusted_third_party_id_servers"] = [
- "testis",
- ]
-
- mock_http_client = Mock(spec=[
- "get_json",
- "post_json_get_json",
- ])
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
mock_http_client.get_json.return_value = defer.succeed((200, "{}"))
mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
self.hs = self.setup_test_homeserver(
- config=config,
- simple_http_client=mock_http_client,
+ config=config, simple_http_client=mock_http_client
)
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_handlers().identity_handler.http_client = mock_http_client
+
return self.hs
def prepare(self, reactor, clock, hs):
@@ -142,82 +130,66 @@ class IdentityEnabledTestCase(unittest.HomeserverTestCase):
def test_3pid_invite_enabled(self):
request, channel = self.make_request(
- b"POST", "/createRoom", b"{}", access_token=self.tok,
+ b"POST", "/createRoom", b"{}", access_token=self.tok
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
room_id = channel.json_body["room_id"]
+ # Replace the blacklisting SimpleHttpClient with our mock
+ self.hs.get_room_member_handler().simple_http_client = Mock(
+ spec=["get_json", "post_json_get_json"]
+ )
+ self.hs.get_room_member_handler().simple_http_client.get_json.return_value = defer.succeed(
+ (200, "{}")
+ )
+
params = {
"id_server": "testis",
"medium": "email",
"address": "test@example.com",
}
request_data = json.dumps(params)
- request_url = ("/rooms/%s/invite" % (room_id)).encode('ascii')
+ request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
request, channel = self.make_request(
- b"POST", request_url, request_data, access_token=self.tok,
+ b"POST", request_url, request_data, access_token=self.tok
)
self.render(request)
- get_json = self.hs.get_simple_http_client().get_json
+ get_json = self.hs.get_handlers().identity_handler.http_client.get_json
get_json.assert_called_once_with(
"https://testis/_matrix/identity/api/v1/lookup",
- {
- "address": "test@example.com",
- "medium": "email",
- },
+ {"address": "test@example.com", "medium": "email"},
)
def test_3pid_lookup_enabled(self):
- url = ("/_matrix/client/unstable/account/3pid/lookup"
- "?id_server=testis&medium=email&address=foo@bar.baz")
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
request, channel = self.make_request("GET", url, access_token=self.tok)
self.render(request)
get_json = self.hs.get_simple_http_client().get_json
get_json.assert_called_once_with(
"https://testis/_matrix/identity/api/v1/lookup",
- {
- "address": "foo@bar.baz",
- "medium": "email",
- },
+ {"address": "foo@bar.baz", "medium": "email"},
)
def test_3pid_bulk_lookup_enabled(self):
url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
data = {
"id_server": "testis",
- "threepids": [
- [
- "email",
- "foo@bar.baz"
- ],
- [
- "email",
- "john.doe@matrix.org"
- ]
- ]
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
}
request_data = json.dumps(data)
request, channel = self.make_request(
- "POST", url, request_data, access_token=self.tok,
+ "POST", url, request_data, access_token=self.tok
)
self.render(request)
post_json = self.hs.get_simple_http_client().post_json_get_json
post_json.assert_called_once_with(
"https://testis/_matrix/identity/api/v1/bulk_lookup",
- {
- "threepids": [
- [
- "email",
- "foo@bar.baz"
- ],
- [
- "email",
- "john.doe@matrix.org"
- ]
- ],
- },
+ {"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]]},
)
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
new file mode 100644
index 0000000000..d2bcf256fa
--- /dev/null
+++ b/tests/rest/client/test_redactions.py
@@ -0,0 +1,204 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import sync
+
+from tests.unittest import HomeserverTestCase
+
+
+class RedactionsTestCase(HomeserverTestCase):
+ """Tests that various redaction events are handled correctly"""
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ config["rc_message"] = {"per_second": 0.2, "burst_count": 10}
+ config["rc_admin_redaction"] = {"per_second": 1, "burst_count": 100}
+
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor, clock, hs):
+ # register a couple of users
+ self.mod_user_id = self.register_user("user1", "pass")
+ self.mod_access_token = self.login("user1", "pass")
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ # Create a room
+ self.room_id = self.helper.create_room_as(
+ self.mod_user_id, tok=self.mod_access_token
+ )
+
+ # Invite the other user
+ self.helper.invite(
+ room=self.room_id,
+ src=self.mod_user_id,
+ tok=self.mod_access_token,
+ targ=self.other_user_id,
+ )
+ # The other user joins
+ self.helper.join(
+ room=self.room_id, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ def _redact_event(self, access_token, room_id, event_id, expect_code=200):
+ """Helper function to send a redaction event.
+
+ Returns the json body.
+ """
+ path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
+
+ request, channel = self.make_request(
+ "POST", path, content={}, access_token=access_token
+ )
+ self.render(request)
+ self.assertEqual(int(channel.result["code"]), expect_code)
+ return channel.json_body
+
+ def _sync_room_timeline(self, access_token, room_id):
+ request, channel = self.make_request(
+ "GET", "sync", access_token=self.mod_access_token
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"200")
+ room_sync = channel.json_body["rooms"]["join"][room_id]
+ return room_sync["timeline"]["events"]
+
+ def test_redact_event_as_moderator(self):
+ # as a regular user, send a message to redact
+ b = self.helper.send(room_id=self.room_id, tok=self.other_access_token)
+ msg_id = b["event_id"]
+
+ # as the moderator, send a redaction
+ b = self._redact_event(self.mod_access_token, self.room_id, msg_id)
+ redaction_id = b["event_id"]
+
+ # now sync
+ timeline = self._sync_room_timeline(self.mod_access_token, self.room_id)
+
+ # the last event should be the redaction
+ self.assertEqual(timeline[-1]["event_id"], redaction_id)
+ self.assertEqual(timeline[-1]["redacts"], msg_id)
+
+ # and the penultimate should be the redacted original
+ self.assertEqual(timeline[-2]["event_id"], msg_id)
+ self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id)
+ self.assertEqual(timeline[-2]["content"], {})
+
+ def test_redact_event_as_normal(self):
+ # as a regular user, send a message to redact
+ b = self.helper.send(room_id=self.room_id, tok=self.other_access_token)
+ normal_msg_id = b["event_id"]
+
+ # also send one as the admin
+ b = self.helper.send(room_id=self.room_id, tok=self.mod_access_token)
+ admin_msg_id = b["event_id"]
+
+ # as a normal, try to redact the admin's event
+ self._redact_event(
+ self.other_access_token, self.room_id, admin_msg_id, expect_code=403
+ )
+
+ # now try to redact our own event
+ b = self._redact_event(self.other_access_token, self.room_id, normal_msg_id)
+ redaction_id = b["event_id"]
+
+ # now sync
+ timeline = self._sync_room_timeline(self.other_access_token, self.room_id)
+
+ # the last event should be the redaction of the normal event
+ self.assertEqual(timeline[-1]["event_id"], redaction_id)
+ self.assertEqual(timeline[-1]["redacts"], normal_msg_id)
+
+ # the penultimate should be the unredacted one from the admin
+ self.assertEqual(timeline[-2]["event_id"], admin_msg_id)
+ self.assertNotIn("redacted_by", timeline[-2]["unsigned"])
+ self.assertTrue(timeline[-2]["content"]["body"], {})
+
+ # and the antepenultimate should be the redacted normal
+ self.assertEqual(timeline[-3]["event_id"], normal_msg_id)
+ self.assertEqual(timeline[-3]["unsigned"]["redacted_by"], redaction_id)
+ self.assertEqual(timeline[-3]["content"], {})
+
+ def test_redact_nonexistent_event(self):
+ # control case: an existing event
+ b = self.helper.send(room_id=self.room_id, tok=self.other_access_token)
+ msg_id = b["event_id"]
+ b = self._redact_event(self.other_access_token, self.room_id, msg_id)
+ redaction_id = b["event_id"]
+
+ # room moderators can send redactions for non-existent events
+ self._redact_event(self.mod_access_token, self.room_id, "$zzz")
+
+ # ... but normals cannot
+ self._redact_event(
+ self.other_access_token, self.room_id, "$zzz", expect_code=404
+ )
+
+ # when we sync, we should see only the valid redaction
+ timeline = self._sync_room_timeline(self.other_access_token, self.room_id)
+ self.assertEqual(timeline[-1]["event_id"], redaction_id)
+ self.assertEqual(timeline[-1]["redacts"], msg_id)
+
+ # and the penultimate should be the redacted original
+ self.assertEqual(timeline[-2]["event_id"], msg_id)
+ self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id)
+ self.assertEqual(timeline[-2]["content"], {})
+
+ def test_redact_create_event(self):
+ # control case: an existing event
+ b = self.helper.send(room_id=self.room_id, tok=self.mod_access_token)
+ msg_id = b["event_id"]
+ self._redact_event(self.mod_access_token, self.room_id, msg_id)
+
+ # sync the room, to get the id of the create event
+ timeline = self._sync_room_timeline(self.other_access_token, self.room_id)
+ create_event_id = timeline[0]["event_id"]
+
+ # room moderators cannot send redactions for create events
+ self._redact_event(
+ self.mod_access_token, self.room_id, create_event_id, expect_code=403
+ )
+
+ # and nor can normals
+ self._redact_event(
+ self.other_access_token, self.room_id, create_event_id, expect_code=403
+ )
+
+ def test_redact_event_as_moderator_ratelimit(self):
+ """Tests that the correct ratelimiting is applied to redactions
+ """
+
+ message_ids = []
+ # as a regular user, send messages to redact
+ for _ in range(20):
+ b = self.helper.send(room_id=self.room_id, tok=self.other_access_token)
+ message_ids.append(b["event_id"])
+ self.reactor.advance(10) # To get around ratelimits
+
+ # as the moderator, send a bunch of redactions
+ for msg_id in message_ids:
+ # These should all succeed, even though this would be denied by
+ # the standard message ratelimiter
+ self._redact_event(self.mod_access_token, self.room_id, msg_id)
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index d0deff5a3b..9e549d8a91 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -61,9 +61,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={
- "max_lifetime": one_day_ms * 4,
- },
+ body={"max_lifetime": one_day_ms * 4},
tok=self.token,
expect_code=400,
)
@@ -71,9 +69,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={
- "max_lifetime": one_hour_ms,
- },
+ body={"max_lifetime": one_hour_ms},
tok=self.token,
expect_code=400,
)
@@ -89,9 +85,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={
- "max_lifetime": lifetime,
- },
+ body={"max_lifetime": lifetime},
tok=self.token,
)
@@ -110,52 +104,37 @@ class RetentionTestCase(unittest.HomeserverTestCase):
outdated events
"""
store = self.hs.get_datastore()
+ storage = self.hs.get_storage()
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
events = []
# Send a first event, which should be filtered out at the end of the test.
- resp = self.helper.send(
- room_id=room_id,
- body="1",
- tok=self.token,
- )
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
# Get the event from the store so that we end up with a FrozenEvent that we can
# give to filter_events_for_client. We need to do this now because the event won't
# be in the database anymore after it has expired.
- events.append(self.get_success(
- store.get_event(
- resp.get("event_id")
- )
- ))
+ events.append(self.get_success(store.get_event(resp.get("event_id"))))
# Advance the time by 2 days. We're using the default retention policy, therefore
# after this the first event will still be valid.
self.reactor.advance(one_day_ms * 2 / 1000)
# Send another event, which shouldn't get filtered out.
- resp = self.helper.send(
- room_id=room_id,
- body="2",
- tok=self.token,
- )
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
valid_event_id = resp.get("event_id")
- events.append(self.get_success(
- store.get_event(
- valid_event_id
- )
- ))
+ events.append(self.get_success(store.get_event(valid_event_id)))
# Advance the time by anothe 2 days. After this, the first event should be
# outdated but not the second one.
self.reactor.advance(one_day_ms * 2 / 1000)
# Run filter_events_for_client with our list of FrozenEvents.
- filtered_events = self.get_success(filter_events_for_client(
- store, self.user_id, events
- ))
+ filtered_events = self.get_success(
+ filter_events_for_client(storage, self.user_id, events)
+ )
# We should only get one event back.
self.assertEqual(len(filtered_events), 1, filtered_events)
@@ -171,28 +150,22 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send a first event to the room. This is the event we'll want to be purged at the
# end of the test.
- resp = self.helper.send(
- room_id=room_id,
- body="1",
- tok=self.token,
- )
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
expired_event_id = resp.get("event_id")
# Check that we can retrieve the event.
expired_event = self.get_event(room_id, expired_event_id)
- self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event)
+ self.assertEqual(
+ expired_event.get("content", {}).get("body"), "1", expired_event
+ )
# Advance the time.
self.reactor.advance(increment / 1000)
# Send another event. We need this because the purge job won't purge the most
# recent event in the room.
- resp = self.helper.send(
- room_id=room_id,
- body="2",
- tok=self.token,
- )
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
valid_event_id = resp.get("event_id")
@@ -239,8 +212,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
mock_federation_client = Mock(spec=["backfill"])
self.hs = self.setup_test_homeserver(
- config=config,
- federation_client=mock_federation_client,
+ config=config, federation_client=mock_federation_client,
)
return self.hs
@@ -267,9 +239,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={
- "max_lifetime": one_day_ms * 35,
- },
+ body={"max_lifetime": one_day_ms * 35},
tok=self.token,
)
@@ -278,28 +248,22 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
def _test_retention(self, room_id, expected_code_for_first_event=200):
# Send a first event to the room. This is the event we'll want to be purged at the
# end of the test.
- resp = self.helper.send(
- room_id=room_id,
- body="1",
- tok=self.token,
- )
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
first_event_id = resp.get("event_id")
# Check that we can retrieve the event.
expired_event = self.get_event(room_id, first_event_id)
- self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event)
+ self.assertEqual(
+ expired_event.get("content", {}).get("body"), "1", expired_event
+ )
# Advance the time by a month.
self.reactor.advance(one_day_ms * 30 / 1000)
# Send another event. We need this because the purge job won't purge the most
# recent event in the room.
- resp = self.helper.send(
- room_id=room_id,
- body="2",
- tok=self.token,
- )
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
second_event_id = resp.get("event_id")
@@ -312,7 +276,9 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
)
if expected_code_for_first_event == 200:
- self.assertEqual(first_event.get("content", {}).get("body"), "1", first_event)
+ self.assertEqual(
+ first_event.get("content", {}).get("body"), "1", first_event
+ )
# Check that the event that hasn't been purged can still be retrieved.
second_event = self.get_event(room_id, second_event_id)
diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py
index 13caea3b01..7da0ef4e18 100644
--- a/tests/rest/client/test_room_access_rules.py
+++ b/tests/rest/client/test_room_access_rules.py
@@ -49,15 +49,11 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
config["third_party_event_rules"] = {
"module": "synapse.third_party_rules.access_rules.RoomAccessRules",
"config": {
- "domains_forbidden_when_restricted": [
- "forbidden_domain"
- ],
+ "domains_forbidden_when_restricted": ["forbidden_domain"],
"id_server": "testis",
- }
+ },
}
- config["trusted_third_party_id_servers"] = [
- "testis",
- ]
+ config["trusted_third_party_id_servers"] = ["testis"]
def send_invite(destination, room_id, event_id, pdu):
return defer.succeed(pdu)
@@ -66,42 +62,46 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
address_domain = args["address"].split("@")[1]
return defer.succeed({"hs": address_domain})
- def post_urlencoded_get_json(uri, args={}, headers=None):
- token = ''.join(random.choice(string.ascii_letters) for _ in range(10))
- return defer.succeed({
- "token": token,
- "public_keys": [
- {
- "public_key": "serverpublickey",
- "key_validity_url": "https://testis/pubkey/isvalid",
- },
- {
- "public_key": "phemeralpublickey",
- "key_validity_url": "https://testis/pubkey/ephemeral/isvalid",
- },
- ],
- "display_name": "f...@b...",
- })
-
- mock_federation_client = Mock(spec=[
- "send_invite",
- ])
+ def post_json_get_json(uri, post_json, args={}, headers=None):
+ token = "".join(random.choice(string.ascii_letters) for _ in range(10))
+ return defer.succeed(
+ {
+ "token": token,
+ "public_keys": [
+ {
+ "public_key": "serverpublickey",
+ "key_validity_url": "https://testis/pubkey/isvalid",
+ },
+ {
+ "public_key": "phemeralpublickey",
+ "key_validity_url": "https://testis/pubkey/ephemeral/isvalid",
+ },
+ ],
+ "display_name": "f...@b...",
+ }
+ )
+
+ mock_federation_client = Mock(spec=["send_invite"])
mock_federation_client.send_invite.side_effect = send_invite
- mock_http_client = Mock(spec=[
- "get_json",
- "post_urlencoded_get_json"
- ])
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"],)
# Mocking the response for /info on the IS API.
mock_http_client.get_json.side_effect = get_json
# Mocking the response for /store-invite on the IS API.
- mock_http_client.post_urlencoded_get_json.side_effect = post_urlencoded_get_json
+ mock_http_client.post_json_get_json.side_effect = post_json_get_json
self.hs = self.setup_test_homeserver(
config=config,
federation_client=mock_federation_client,
simple_http_client=mock_http_client,
)
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_handlers().identity_handler.blacklisting_http_client = (
+ mock_http_client
+ )
+
return self.hs
def prepare(self, reactor, clock, homeserver):
@@ -164,74 +164,80 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
# rule to restricted.
preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT)
self.assertEqual(
- self.current_rule_in_room(preset_room_id), ACCESS_RULE_RESTRICTED,
+ self.current_rule_in_room(preset_room_id), ACCESS_RULE_RESTRICTED
)
# Creating a room with the public join rule in its initial state should succeed
# and set the access rule to restricted.
- init_state_room_id = self.create_room(initial_state=[{
- "type": "m.room.join_rules",
- "content": {
- "join_rule": JoinRules.PUBLIC,
- },
- }])
+ init_state_room_id = self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ]
+ )
self.assertEqual(
- self.current_rule_in_room(init_state_room_id), ACCESS_RULE_RESTRICTED,
+ self.current_rule_in_room(init_state_room_id), ACCESS_RULE_RESTRICTED
)
# Changing access rule to unrestricted should fail.
self.change_rule_in_room(
- preset_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403,
+ preset_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403
)
self.change_rule_in_room(
- init_state_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403,
+ init_state_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403
)
# Changing access rule to direct should fail.
+ self.change_rule_in_room(preset_room_id, ACCESS_RULE_DIRECT, expected_code=403)
self.change_rule_in_room(
- preset_room_id, ACCESS_RULE_DIRECT, expected_code=403,
- )
- self.change_rule_in_room(
- init_state_room_id, ACCESS_RULE_DIRECT, expected_code=403,
+ init_state_room_id, ACCESS_RULE_DIRECT, expected_code=403
)
# Changing join rule to public in an unrestricted room should fail.
self.change_join_rule_in_room(
- self.unrestricted_room, JoinRules.PUBLIC, expected_code=403,
+ self.unrestricted_room, JoinRules.PUBLIC, expected_code=403
)
# Changing join rule to public in an direct room should fail.
self.change_join_rule_in_room(
- self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403,
+ self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403
)
# Creating a new room with the public_chat preset and an access rule that isn't
# restricted should fail.
self.create_room(
- preset=RoomCreationPreset.PUBLIC_CHAT, rule=ACCESS_RULE_UNRESTRICTED,
+ preset=RoomCreationPreset.PUBLIC_CHAT,
+ rule=ACCESS_RULE_UNRESTRICTED,
expected_code=400,
)
self.create_room(
- preset=RoomCreationPreset.PUBLIC_CHAT, rule=ACCESS_RULE_DIRECT,
+ preset=RoomCreationPreset.PUBLIC_CHAT,
+ rule=ACCESS_RULE_DIRECT,
expected_code=400,
)
# Creating a room with the public join rule in its initial state and an access
# rule that isn't restricted should fail.
self.create_room(
- initial_state=[{
- "type": "m.room.join_rules",
- "content": {
- "join_rule": JoinRules.PUBLIC,
- },
- }], rule=ACCESS_RULE_UNRESTRICTED, expected_code=400,
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ],
+ rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=400,
)
self.create_room(
- initial_state=[{
- "type": "m.room.join_rules",
- "content": {
- "join_rule": JoinRules.PUBLIC,
- },
- }], rule=ACCESS_RULE_DIRECT, expected_code=400,
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ],
+ rule=ACCESS_RULE_DIRECT,
+ expected_code=400,
)
def test_restricted(self):
@@ -405,12 +411,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
self.helper.send_state(
room_id=self.unrestricted_room,
event_type=EventTypes.PowerLevels,
- body={
- "users": {
- self.user_id: 100,
- "@test:not_forbidden_domain": 10,
- },
- },
+ body={"users": {self.user_id: 100, "@test:not_forbidden_domain": 10}},
tok=self.tok,
expect_code=200,
)
@@ -421,10 +422,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
room_id=self.unrestricted_room,
event_type=EventTypes.PowerLevels,
body={
- "users": {
- self.user_id: 100,
- "@test:not_forbidden_domain": 10,
- },
+ "users": {self.user_id: 100, "@test:not_forbidden_domain": 10},
"users_default": 10,
},
tok=self.tok,
@@ -436,12 +434,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
self.helper.send_state(
room_id=self.unrestricted_room,
event_type=EventTypes.PowerLevels,
- body={
- "users": {
- self.user_id: 100,
- "@test:forbidden_domain": 10,
- },
- },
+ body={"users": {self.user_id: 100, "@test:forbidden_domain": 10}},
tok=self.tok,
expect_code=403,
)
@@ -459,9 +452,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
# We can't change the rule from restricted to direct.
self.change_rule_in_room(
- room_id=self.restricted_room,
- new_rule=ACCESS_RULE_DIRECT,
- expected_code=403,
+ room_id=self.restricted_room, new_rule=ACCESS_RULE_DIRECT, expected_code=403
)
# We can't change the rule from unrestricted to restricted.
@@ -498,12 +489,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
"""
avatar_content = {
- "info": {
- "h": 398,
- "mimetype": "image/jpeg",
- "size": 31037,
- "w": 394
- },
+ "info": {"h": 398, "mimetype": "image/jpeg", "size": 31037, "w": 394},
"url": "mxc://example.org/JWEIFJgwEIhweiWJE",
}
@@ -536,9 +522,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
chat, in which case it's forbidden.
"""
- name_content = {
- "name": "My super room",
- }
+ name_content = {"name": "My super room"}
self.helper.send_state(
room_id=self.restricted_room,
@@ -569,9 +553,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
direct chat, in which case it's forbidden.
"""
- topic_content = {
- "topic": "Welcome to this room",
- }
+ topic_content = {"topic": "Welcome to this room"}
self.helper.send_state(
room_id=self.restricted_room,
@@ -608,15 +590,15 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
"public_keys": [
{
"key_validity_url": "https://validity_url",
- "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA"
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
},
{
"key_validity_url": "https://validity_url",
- "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I"
- }
+ "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I",
+ },
],
"key_validity_url": "https://validity_url",
- "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA"
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
}
self.send_state_with_state_key(
@@ -646,22 +628,19 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
)
def create_room(
- self, direct=False, rule=None, preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
- initial_state=None, expected_code=200,
+ self,
+ direct=False,
+ rule=None,
+ preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ initial_state=None,
+ expected_code=200,
):
- content = {
- "is_direct": direct,
- "preset": preset,
- }
+ content = {"is_direct": direct, "preset": preset}
if rule:
- content["initial_state"] = [{
- "type": ACCESS_RULES_TYPE,
- "state_key": "",
- "content": {
- "rule": rule,
- }
- }]
+ content["initial_state"] = [
+ {"type": ACCESS_RULES_TYPE, "state_key": "", "content": {"rule": rule}}
+ ]
if initial_state:
if "initial_state" not in content:
@@ -694,9 +673,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
return channel.json_body["rule"]
def change_rule_in_room(self, room_id, new_rule, expected_code=200):
- data = {
- "rule": new_rule,
- }
+ data = {"rule": new_rule}
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
@@ -708,9 +685,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, expected_code, channel.result)
def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200):
- data = {
- "join_rule": new_join_rule,
- }
+ data = {"join_rule": new_join_rule}
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules),
@@ -722,11 +697,7 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, expected_code, channel.result)
def send_threepid_invite(self, address, room_id, expected_code=200):
- params = {
- "id_server": "testis",
- "medium": "email",
- "address": address,
- }
+ params = {"id_server": "testis", "medium": "email", "address": address}
request, channel = self.make_request(
"POST",
@@ -741,7 +712,9 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
self, room_id, event_type, state_key, body, tok, expect_code=200
):
path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
- room_id, event_type, state_key
+ room_id,
+ event_type,
+ state_key,
)
request, channel = self.make_request(
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 708dc26e61..a3d7e3c046 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -2,9 +2,9 @@ from mock import Mock, call
from twisted.internet import defer, reactor
+from synapse.logging.context import LoggingContext
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache
from synapse.util import Clock
-from synapse.util.logcontext import LoggingContext
from tests import unittest
from tests.utils import MockClock
@@ -46,7 +46,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def cb():
yield Clock(reactor).sleep(0)
- defer.returnValue("yay")
+ return "yay"
@defer.inlineCallbacks
def test():
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index f340b7e851..ffb2de1505 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -134,3 +134,30 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
# someone else set topic, expect 6 (join,send,topic,join,send,topic)
pass
+
+
+class GetEventsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ events.register_servlets,
+ room.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ def prepare(self, hs, reactor, clock):
+
+ # register an account
+ self.user_id = self.register_user("sid1", "pass")
+ self.token = self.login(self.user_id, "pass")
+
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ def test_get_event_via_events(self):
+ resp = self.helper.send(self.room_id, tok=self.token)
+ event_id = resp["event_id"]
+
+ request, channel = self.make_request(
+ "GET", "/events/" + event_id, access_token=self.token,
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 0397f91a9e..da2c9bfa1e 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -1,11 +1,18 @@
import json
+import urllib.parse
+
+from mock import Mock
import synapse.rest.admin
from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import devices
+from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
from tests import unittest
+from tests.unittest import override_config
LOGIN_URL = b"/_matrix/client/r0/login"
+TEST_URL = b"/_matrix/client/r0/account/whoami"
class LoginRestServletTestCase(unittest.HomeserverTestCase):
@@ -13,6 +20,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
+ devices.register_servlets,
+ lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
]
def make_homeserver(self, reactor, clock):
@@ -144,3 +153,213 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ @override_config({"session_lifetime": "24h"})
+ def test_soft_logout(self):
+ self.register_user("kermit", "monkey")
+
+ # we shouldn't be able to make requests without an access token
+ request, channel = self.make_request(b"GET", TEST_URL)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN")
+
+ # log in as normal
+ params = {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": "kermit"},
+ "password": "monkey",
+ }
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ self.render(request)
+
+ self.assertEquals(channel.code, 200, channel.result)
+ access_token = channel.json_body["access_token"]
+ device_id = channel.json_body["device_id"]
+
+ # we should now be able to make requests with the access token
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.result)
+
+ # time passes
+ self.reactor.advance(24 * 3600)
+
+ # ... and we should be soft-logouted
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], True)
+
+ #
+ # test behaviour after deleting the expired device
+ #
+
+ # we now log in as a different device
+ access_token_2 = self.login("kermit", "monkey")
+
+ # more requests with the expired token should still return a soft-logout
+ self.reactor.advance(3600)
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], True)
+
+ # ... but if we delete that device, it will be a proper logout
+ self._delete_device(access_token_2, "kermit", "monkey", device_id)
+
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], False)
+
+ def _delete_device(self, access_token, user_id, password, device_id):
+ """Perform the UI-Auth to delete a device"""
+ request, channel = self.make_request(
+ b"DELETE", "devices/" + device_id, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ # check it's a UI-Auth fail
+ self.assertEqual(
+ set(channel.json_body.keys()),
+ {"flows", "params", "session"},
+ channel.result,
+ )
+
+ auth = {
+ "type": "m.login.password",
+ # https://github.com/matrix-org/synapse/issues/5665
+ # "identifier": {"type": "m.id.user", "user": user_id},
+ "user": user_id,
+ "password": password,
+ "session": channel.json_body["session"],
+ }
+
+ request, channel = self.make_request(
+ b"DELETE",
+ "devices/" + device_id,
+ access_token=access_token,
+ content={"auth": auth},
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.result)
+
+
+class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.base_url = "https://matrix.goodserver.com/"
+ self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
+
+ config = self.default_config()
+ config["cas_config"] = {
+ "enabled": True,
+ "server_url": "https://fake.test",
+ "service_url": "https://matrix.goodserver.com:8448",
+ }
+
+ async def get_raw(uri, args):
+ """Return an example response payload from a call to the `/proxyValidate`
+ endpoint of a CAS server, copied from
+ https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20
+
+ This needs to be returned by an async function (as opposed to set as the
+ mock's return value) because the corresponding Synapse code awaits on it.
+ """
+ return """
+ <cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
+ <cas:authenticationSuccess>
+ <cas:user>username</cas:user>
+ <cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket>
+ <cas:proxies>
+ <cas:proxy>https://proxy2/pgtUrl</cas:proxy>
+ <cas:proxy>https://proxy1/pgtUrl</cas:proxy>
+ </cas:proxies>
+ </cas:authenticationSuccess>
+ </cas:serviceResponse>
+ """
+
+ mocked_http_client = Mock(spec=["get_raw"])
+ mocked_http_client.get_raw.side_effect = get_raw
+
+ self.hs = self.setup_test_homeserver(
+ config=config, proxied_http_client=mocked_http_client,
+ )
+
+ return self.hs
+
+ def test_cas_redirect_confirm(self):
+ """Tests that the SSO login flow serves a confirmation page before redirecting a
+ user to the redirect URL.
+ """
+ base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl"
+ redirect_url = "https://dodgy-site.com/"
+
+ url_parts = list(urllib.parse.urlparse(base_url))
+ query = dict(urllib.parse.parse_qsl(url_parts[4]))
+ query.update({"redirectUrl": redirect_url})
+ query.update({"ticket": "ticket"})
+ url_parts[4] = urllib.parse.urlencode(query)
+ cas_ticket_url = urllib.parse.urlunparse(url_parts)
+
+ # Get Synapse to call the fake CAS and serve the template.
+ request, channel = self.make_request("GET", cas_ticket_url)
+ self.render(request)
+
+ # Test that the response is HTML.
+ self.assertEqual(channel.code, 200)
+ content_type_header_value = ""
+ for header in channel.result.get("headers", []):
+ if header[0] == b"Content-Type":
+ content_type_header_value = header[1].decode("utf8")
+
+ self.assertTrue(content_type_header_value.startswith("text/html"))
+
+ # Test that the body isn't empty.
+ self.assertTrue(len(channel.result["body"]) > 0)
+
+ # And that it contains our redirect link
+ self.assertIn(redirect_url, channel.result["body"].decode("UTF-8"))
+
+ @override_config(
+ {
+ "sso": {
+ "client_whitelist": [
+ "https://legit-site.com/",
+ "https://other-site.com/",
+ ]
+ }
+ }
+ )
+ def test_cas_redirect_whitelisted(self):
+ """Tests that the SSO login flow serves a redirect to a whitelisted url
+ """
+ redirect_url = "https://legit-site.com/"
+ cas_ticket_url = (
+ "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
+ % (urllib.parse.quote(redirect_url))
+ )
+
+ # Get Synapse to call the fake CAS and serve the template.
+ request, channel = self.make_request("GET", cas_ticket_url)
+ self.render(request)
+
+ self.assertEqual(channel.code, 302)
+ location_headers = channel.headers.getRawHeaders("Location")
+ self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 66c2b68707..0fdff79aa7 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -15,6 +15,8 @@
from mock import Mock
+from twisted.internet import defer
+
from synapse.rest.client.v1 import presence
from synapse.types import UserID
@@ -36,6 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
hs.presence_handler = Mock()
+ hs.presence_handler.set_state.return_value = defer.succeed(None)
return hs
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 6958430608..8df58b4a63 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -52,6 +52,14 @@ class MockHandlerProfileTestCase(unittest.TestCase):
]
)
+ self.mock_handler.get_displayname.return_value = defer.succeed(Mock())
+ self.mock_handler.set_displayname.return_value = defer.succeed(Mock())
+ self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock())
+ self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock())
+ self.mock_handler.check_profile_query_allowed.return_value = defer.succeed(
+ Mock()
+ )
+
hs = yield setup_test_homeserver(
self.addCleanup,
"test",
@@ -63,7 +71,7 @@ class MockHandlerProfileTestCase(unittest.TestCase):
)
def _get_user_by_req(request=None, allow_guest=False):
- return synapse.types.create_requester(myid)
+ return defer.succeed(synapse.types.create_requester(myid))
hs.get_auth().get_user_by_req = _get_user_by_req
@@ -183,7 +191,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
def test_set_displayname(self):
request, channel = self.make_request(
"PUT",
- "/profile/%s/displayname" % (self.owner, ),
+ "/profile/%s/displayname" % (self.owner,),
content=json.dumps({"displayname": "test"}),
access_token=self.owner_tok,
)
@@ -197,7 +205,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
"""Attempts to set a stupid displayname should get a 400"""
request, channel = self.make_request(
"PUT",
- "/profile/%s/displayname" % (self.owner, ),
+ "/profile/%s/displayname" % (self.owner,),
content=json.dumps({"displayname": "test" * 100}),
access_token=self.owner_tok,
)
@@ -209,8 +217,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
def get_displayname(self):
request, channel = self.make_request(
- "GET",
- "/profile/%s/displayname" % (self.owner, ),
+ "GET", "/profile/%s/displayname" % (self.owner,)
)
self.render(request)
self.assertEqual(channel.code, 200, channel.result)
@@ -230,7 +237,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
config = self.default_config()
config["require_auth_for_profile_requests"] = True
- config["limit_profile_requests_to_known_users"] = True
+ config["limit_profile_requests_to_users_who_share_rooms"] = True
self.hs = self.setup_test_homeserver(config=config)
return self.hs
@@ -303,6 +310,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
config["require_auth_for_profile_requests"] = True
+ config["limit_profile_requests_to_users_who_share_rooms"] = True
self.hs = self.setup_test_homeserver(config=config)
return self.hs
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 2d64b338be..7dd86d0c27 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,8 +26,12 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer
import synapse.rest.admin
-from synapse.api.constants import Membership
-from synapse.rest.client.v1 import login, profile, room
+from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.handlers.pagination import PurgeStatus
+from synapse.rest.client.v1 import directory, login, profile, room
+from synapse.rest.client.v2_alpha import account
+from synapse.types import JsonDict, RoomAlias
+from synapse.util.stringutils import random_string
from tests import unittest
@@ -79,7 +85,7 @@ class RoomPermissionsTestCase(RoomBase):
# send a message in one of the rooms
self.created_rmid_msg_path = (
"rooms/%s/send/m.room.message/a1" % (self.created_rmid)
- ).encode('ascii')
+ ).encode("ascii")
request, channel = self.make_request(
"PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
)
@@ -89,7 +95,7 @@ class RoomPermissionsTestCase(RoomBase):
# set topic for public room
request, channel = self.make_request(
"PUT",
- ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'),
+ ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"),
b'{"topic":"Public Room Topic"}',
)
self.render(request)
@@ -193,7 +199,7 @@ class RoomPermissionsTestCase(RoomBase):
request, channel = self.make_request("GET", topic_path)
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
- self.assert_dict(json.loads(topic_content.decode('utf8')), channel.json_body)
+ self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body)
# set/get topic in created PRIVATE room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
@@ -484,6 +490,15 @@ class RoomsCreateTestCase(RoomBase):
self.render(request)
self.assertEquals(400, channel.code)
+ def test_post_room_invitees_invalid_mxid(self):
+ # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
+ # Note the trailing space in the MXID here!
+ request, channel = self.make_request(
+ "POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
+ )
+ self.render(request)
+ self.assertEquals(400, channel.code)
+
class RoomTopicTestCase(RoomBase):
""" Tests /rooms/$room_id/topic REST events. """
@@ -497,7 +512,7 @@ class RoomTopicTestCase(RoomBase):
def test_invalid_puts(self):
# missing keys or invalid json
- request, channel = self.make_request("PUT", self.path, '{}')
+ request, channel = self.make_request("PUT", self.path, "{}")
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
@@ -515,11 +530,11 @@ class RoomTopicTestCase(RoomBase):
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", self.path, 'text only')
+ request, channel = self.make_request("PUT", self.path, "text only")
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", self.path, '')
+ request, channel = self.make_request("PUT", self.path, "")
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
@@ -572,7 +587,7 @@ class RoomMemberStateTestCase(RoomBase):
def test_invalid_puts(self):
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json
- request, channel = self.make_request("PUT", path, '{}')
+ request, channel = self.make_request("PUT", path, "{}")
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
@@ -590,11 +605,11 @@ class RoomMemberStateTestCase(RoomBase):
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, 'text only')
+ request, channel = self.make_request("PUT", path, "text only")
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, '')
+ request, channel = self.make_request("PUT", path, "")
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
@@ -604,7 +619,7 @@ class RoomMemberStateTestCase(RoomBase):
Membership.JOIN,
Membership.LEAVE,
)
- request, channel = self.make_request("PUT", path, content.encode('ascii'))
+ request, channel = self.make_request("PUT", path, content.encode("ascii"))
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
@@ -616,7 +631,7 @@ class RoomMemberStateTestCase(RoomBase):
# valid join message (NOOP since we made the room)
content = '{"membership":"%s"}' % Membership.JOIN
- request, channel = self.make_request("PUT", path, content.encode('ascii'))
+ request, channel = self.make_request("PUT", path, content.encode("ascii"))
self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
@@ -678,7 +693,7 @@ class RoomMessagesTestCase(RoomBase):
def test_invalid_puts(self):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json
- request, channel = self.make_request("PUT", path, b'{}')
+ request, channel = self.make_request("PUT", path, b"{}")
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
@@ -696,11 +711,11 @@ class RoomMessagesTestCase(RoomBase):
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, b'text only')
+ request, channel = self.make_request("PUT", path, b"text only")
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = self.make_request("PUT", path, b'')
+ request, channel = self.make_request("PUT", path, b"")
self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
@@ -786,7 +801,7 @@ class RoomMessageListTestCase(RoomBase):
self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("start" in channel.json_body)
- self.assertEquals(token, channel.json_body['start'])
+ self.assertEquals(token, channel.json_body["start"])
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body)
@@ -798,10 +813,82 @@ class RoomMessageListTestCase(RoomBase):
self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("start" in channel.json_body)
- self.assertEquals(token, channel.json_body['start'])
+ self.assertEquals(token, channel.json_body["start"])
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body)
+ def test_room_messages_purge(self):
+ store = self.hs.get_datastore()
+ pagination_handler = self.hs.get_pagination_handler()
+
+ # Send a first message in the room, which will be removed by the purge.
+ first_event_id = self.helper.send(self.room_id, "message 1")["event_id"]
+ first_token = self.get_success(
+ store.get_topological_token_for_event(first_event_id)
+ )
+
+ # Send a second message in the room, which won't be removed, and which we'll
+ # use as the marker to purge events before.
+ second_event_id = self.helper.send(self.room_id, "message 2")["event_id"]
+ second_token = self.get_success(
+ store.get_topological_token_for_event(second_event_id)
+ )
+
+ # Send a third event in the room to ensure we don't fall under any edge case
+ # due to our marker being the latest forward extremity in the room.
+ self.helper.send(self.room_id, "message 3")
+
+ # Check that we get the first and second message when querying /messages.
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
+ % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
+
+ # Purge every event before the second event.
+ purge_id = random_string(16)
+ pagination_handler._purges_by_id[purge_id] = PurgeStatus()
+ self.get_success(
+ pagination_handler._purge_history(
+ purge_id=purge_id,
+ room_id=self.room_id,
+ token=second_token,
+ delete_local_events=True,
+ )
+ )
+
+ # Check that we only get the second message through /message now that the first
+ # has been purged.
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
+ % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 1, [event["content"] for event in chunk])
+
+ # Check that we get no event, but also no error, when querying /messages with
+ # the token that was pointing at the first event, because we don't have it
+ # anymore.
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
+ % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})),
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
+
class RoomSearchTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -961,9 +1048,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
# Set a profile for the test user
self.displayname = "test user"
- data = {
- "displayname": self.displayname,
- }
+ data = {"displayname": self.displayname}
request_data = json.dumps(data)
request, channel = self.make_request(
"PUT",
@@ -977,16 +1062,12 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
def test_per_room_profile_forbidden(self):
- data = {
- "membership": "join",
- "displayname": "other test user"
- }
+ data = {"membership": "join", "displayname": "other test user"}
request_data = json.dumps(data)
request, channel = self.make_request(
"PUT",
- "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
- self.room_id, self.user_id,
- ),
+ "/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
+ % (self.room_id, self.user_id),
request_data,
access_token=self.tok,
)
@@ -1004,3 +1085,899 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
res_displayname = channel.json_body["content"]["displayname"]
self.assertEqual(res_displayname, self.displayname, channel.result)
+
+
+class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
+ """Tests that clients can add a "reason" field to membership events and
+ that they get correctly added to the generated events and propagated.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.creator = self.register_user("creator", "test")
+ self.creator_tok = self.login("creator", "test")
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+
+ self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok)
+
+ def test_join_reason(self):
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/join".format(self.room_id),
+ content={"reason": reason},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_leave_reason(self):
+ self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+ content={"reason": reason},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_kick_reason(self):
+ self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/kick".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_ban_reason(self):
+ self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/ban".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_unban_reason(self):
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/unban".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_invite_reason(self):
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/invite".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_reject_invite_reason(self):
+ self.helper.invite(
+ self.room_id,
+ src=self.creator,
+ targ=self.second_user_id,
+ tok=self.creator_tok,
+ )
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+ content={"reason": reason},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def _check_for_reason(self, reason):
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format(
+ self.room_id, self.second_user_id
+ ),
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ event_content = channel.json_body
+
+ self.assertEqual(event_content.get("reason"), reason, channel.result)
+
+
+class LabelsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ profile.register_servlets,
+ ]
+
+ # Filter that should only catch messages with the label "#fun".
+ FILTER_LABELS = {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#fun"],
+ }
+ # Filter that should only catch messages without the label "#fun".
+ FILTER_NOT_LABELS = {
+ "types": [EventTypes.Message],
+ "org.matrix.not_labels": ["#fun"],
+ }
+ # Filter that should only catch messages with the label "#work" but without the label
+ # "#notfun".
+ FILTER_LABELS_NOT_LABELS = {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#work"],
+ "org.matrix.not_labels": ["#notfun"],
+ }
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("test", "test")
+ self.tok = self.login("test", "test")
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ def test_context_filter_labels(self):
+ """Test that we can filter by a label on a /context request."""
+ event_id = self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/context/%s?filter=%s"
+ % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ events_before = channel.json_body["events_before"]
+
+ self.assertEqual(
+ len(events_before), 1, [event["content"] for event in events_before]
+ )
+ self.assertEqual(
+ events_before[0]["content"]["body"], "with right label", events_before[0]
+ )
+
+ events_after = channel.json_body["events_before"]
+
+ self.assertEqual(
+ len(events_after), 1, [event["content"] for event in events_after]
+ )
+ self.assertEqual(
+ events_after[0]["content"]["body"], "with right label", events_after[0]
+ )
+
+ def test_context_filter_not_labels(self):
+ """Test that we can filter by the absence of a label on a /context request."""
+ event_id = self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/context/%s?filter=%s"
+ % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ events_before = channel.json_body["events_before"]
+
+ self.assertEqual(
+ len(events_before), 1, [event["content"] for event in events_before]
+ )
+ self.assertEqual(
+ events_before[0]["content"]["body"], "without label", events_before[0]
+ )
+
+ events_after = channel.json_body["events_after"]
+
+ self.assertEqual(
+ len(events_after), 2, [event["content"] for event in events_after]
+ )
+ self.assertEqual(
+ events_after[0]["content"]["body"], "with wrong label", events_after[0]
+ )
+ self.assertEqual(
+ events_after[1]["content"]["body"], "with two wrong labels", events_after[1]
+ )
+
+ def test_context_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label on a
+ /context request.
+ """
+ event_id = self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/context/%s?filter=%s"
+ % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ events_before = channel.json_body["events_before"]
+
+ self.assertEqual(
+ len(events_before), 0, [event["content"] for event in events_before]
+ )
+
+ events_after = channel.json_body["events_after"]
+
+ self.assertEqual(
+ len(events_after), 1, [event["content"] for event in events_after]
+ )
+ self.assertEqual(
+ events_after[0]["content"]["body"], "with wrong label", events_after[0]
+ )
+
+ def test_messages_filter_labels(self):
+ """Test that we can filter by a label on a /messages request."""
+ self._send_labelled_messages_in_room()
+
+ token = "s0_0_0_0_0_0_0_0_0"
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
+ % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)),
+ )
+ self.render(request)
+
+ events = channel.json_body["chunk"]
+
+ self.assertEqual(len(events), 2, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
+
+ def test_messages_filter_not_labels(self):
+ """Test that we can filter by the absence of a label on a /messages request."""
+ self._send_labelled_messages_in_room()
+
+ token = "s0_0_0_0_0_0_0_0_0"
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
+ % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)),
+ )
+ self.render(request)
+
+ events = channel.json_body["chunk"]
+
+ self.assertEqual(len(events), 4, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "without label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "without label", events[1])
+ self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2])
+ self.assertEqual(
+ events[3]["content"]["body"], "with two wrong labels", events[3]
+ )
+
+ def test_messages_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label on a
+ /messages request.
+ """
+ self._send_labelled_messages_in_room()
+
+ token = "s0_0_0_0_0_0_0_0_0"
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
+ % (
+ self.room_id,
+ self.tok,
+ token,
+ json.dumps(self.FILTER_LABELS_NOT_LABELS),
+ ),
+ )
+ self.render(request)
+
+ events = channel.json_body["chunk"]
+
+ self.assertEqual(len(events), 1, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
+
+ def test_search_filter_labels(self):
+ """Test that we can filter by a label on a /search request."""
+ request_data = json.dumps(
+ {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_LABELS,
+ }
+ }
+ }
+ )
+
+ self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "POST", "/search?access_token=%s" % self.tok, request_data
+ )
+ self.render(request)
+
+ results = channel.json_body["search_categories"]["room_events"]["results"]
+
+ self.assertEqual(
+ len(results), 2, [result["result"]["content"] for result in results],
+ )
+ self.assertEqual(
+ results[0]["result"]["content"]["body"],
+ "with right label",
+ results[0]["result"]["content"]["body"],
+ )
+ self.assertEqual(
+ results[1]["result"]["content"]["body"],
+ "with right label",
+ results[1]["result"]["content"]["body"],
+ )
+
+ def test_search_filter_not_labels(self):
+ """Test that we can filter by the absence of a label on a /search request."""
+ request_data = json.dumps(
+ {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_NOT_LABELS,
+ }
+ }
+ }
+ )
+
+ self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "POST", "/search?access_token=%s" % self.tok, request_data
+ )
+ self.render(request)
+
+ results = channel.json_body["search_categories"]["room_events"]["results"]
+
+ self.assertEqual(
+ len(results), 4, [result["result"]["content"] for result in results],
+ )
+ self.assertEqual(
+ results[0]["result"]["content"]["body"],
+ "without label",
+ results[0]["result"]["content"]["body"],
+ )
+ self.assertEqual(
+ results[1]["result"]["content"]["body"],
+ "without label",
+ results[1]["result"]["content"]["body"],
+ )
+ self.assertEqual(
+ results[2]["result"]["content"]["body"],
+ "with wrong label",
+ results[2]["result"]["content"]["body"],
+ )
+ self.assertEqual(
+ results[3]["result"]["content"]["body"],
+ "with two wrong labels",
+ results[3]["result"]["content"]["body"],
+ )
+
+ def test_search_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label on a
+ /search request.
+ """
+ request_data = json.dumps(
+ {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_LABELS_NOT_LABELS,
+ }
+ }
+ }
+ )
+
+ self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "POST", "/search?access_token=%s" % self.tok, request_data
+ )
+ self.render(request)
+
+ results = channel.json_body["search_categories"]["room_events"]["results"]
+
+ self.assertEqual(
+ len(results), 1, [result["result"]["content"] for result in results],
+ )
+ self.assertEqual(
+ results[0]["result"]["content"]["body"],
+ "with wrong label",
+ results[0]["result"]["content"]["body"],
+ )
+
+ def _send_labelled_messages_in_room(self):
+ """Sends several messages to a room with different labels (or without any) to test
+ filtering by label.
+ Returns:
+ The ID of the event to use if we're testing filtering on /context.
+ """
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=self.tok,
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "without label"},
+ tok=self.tok,
+ )
+
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "without label"},
+ tok=self.tok,
+ )
+ # Return this event's ID when we test filtering in /context requests.
+ event_id = res["event_id"]
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with wrong label",
+ EventContentFields.LABELS: ["#work"],
+ },
+ tok=self.tok,
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with two wrong labels",
+ EventContentFields.LABELS: ["#work", "#notfun"],
+ },
+ tok=self.tok,
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=self.tok,
+ )
+
+ return event_id
+
+
+class ContextTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ account.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("user", "password")
+ self.tok = self.login("user", "password")
+ self.room_id = self.helper.create_room_as(
+ self.user_id, tok=self.tok, is_public=False
+ )
+
+ self.other_user_id = self.register_user("user2", "password")
+ self.other_tok = self.login("user2", "password")
+
+ self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok)
+ self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok)
+
+ def test_erased_sender(self):
+ """Test that an erasure request results in the requester's events being hidden
+ from any new member of the room.
+ """
+
+ # Send a bunch of events in the room.
+
+ self.helper.send(self.room_id, "message 1", tok=self.tok)
+ self.helper.send(self.room_id, "message 2", tok=self.tok)
+ event_id = self.helper.send(self.room_id, "message 3", tok=self.tok)["event_id"]
+ self.helper.send(self.room_id, "message 4", tok=self.tok)
+ self.helper.send(self.room_id, "message 5", tok=self.tok)
+
+ # Check that we can still see the messages before the erasure request.
+
+ request, channel = self.make_request(
+ "GET",
+ '/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
+ % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ events_before = channel.json_body["events_before"]
+
+ self.assertEqual(len(events_before), 2, events_before)
+ self.assertEqual(
+ events_before[0].get("content", {}).get("body"),
+ "message 2",
+ events_before[0],
+ )
+ self.assertEqual(
+ events_before[1].get("content", {}).get("body"),
+ "message 1",
+ events_before[1],
+ )
+
+ self.assertEqual(
+ channel.json_body["event"].get("content", {}).get("body"),
+ "message 3",
+ channel.json_body["event"],
+ )
+
+ events_after = channel.json_body["events_after"]
+
+ self.assertEqual(len(events_after), 2, events_after)
+ self.assertEqual(
+ events_after[0].get("content", {}).get("body"),
+ "message 4",
+ events_after[0],
+ )
+ self.assertEqual(
+ events_after[1].get("content", {}).get("body"),
+ "message 5",
+ events_after[1],
+ )
+
+ # Deactivate the first account and erase the user's data.
+
+ deactivate_account_handler = self.hs.get_deactivate_account_handler()
+ self.get_success(
+ deactivate_account_handler.deactivate_account(self.user_id, erase_data=True)
+ )
+
+ # Invite another user in the room. This is needed because messages will be
+ # pruned only if the user wasn't a member of the room when the messages were
+ # sent.
+
+ invited_user_id = self.register_user("user3", "password")
+ invited_tok = self.login("user3", "password")
+
+ self.helper.invite(
+ self.room_id, self.other_user_id, invited_user_id, tok=self.other_tok
+ )
+ self.helper.join(self.room_id, invited_user_id, tok=invited_tok)
+
+ # Check that a user that joined the room after the erasure request can't see
+ # the messages anymore.
+
+ request, channel = self.make_request(
+ "GET",
+ '/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
+ % (self.room_id, event_id),
+ access_token=invited_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ events_before = channel.json_body["events_before"]
+
+ self.assertEqual(len(events_before), 2, events_before)
+ self.assertDictEqual(events_before[0].get("content"), {}, events_before[0])
+ self.assertDictEqual(events_before[1].get("content"), {}, events_before[1])
+
+ self.assertDictEqual(
+ channel.json_body["event"].get("content"), {}, channel.json_body["event"]
+ )
+
+ events_after = channel.json_body["events_after"]
+
+ self.assertEqual(len(events_after), 2, events_after)
+ self.assertDictEqual(events_after[0].get("content"), {}, events_after[0])
+ self.assertEqual(events_after[1].get("content"), {}, events_after[1])
+
+
+class RoomAliasListTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.room_owner = self.register_user("room_owner", "test")
+ self.room_owner_tok = self.login("room_owner", "test")
+
+ self.room_id = self.helper.create_room_as(
+ self.room_owner, tok=self.room_owner_tok
+ )
+
+ def test_no_aliases(self):
+ res = self._get_aliases(self.room_owner_tok)
+ self.assertEqual(res["aliases"], [])
+
+ def test_not_in_room(self):
+ self.register_user("user", "test")
+ user_tok = self.login("user", "test")
+ res = self._get_aliases(user_tok, expected_code=403)
+ self.assertEqual(res["errcode"], "M_FORBIDDEN")
+
+ def test_admin_user(self):
+ alias1 = self._random_alias()
+ self._set_alias_via_directory(alias1)
+
+ self.register_user("user", "test", admin=True)
+ user_tok = self.login("user", "test")
+
+ res = self._get_aliases(user_tok)
+ self.assertEqual(res["aliases"], [alias1])
+
+ def test_with_aliases(self):
+ alias1 = self._random_alias()
+ alias2 = self._random_alias()
+
+ self._set_alias_via_directory(alias1)
+ self._set_alias_via_directory(alias2)
+
+ res = self._get_aliases(self.room_owner_tok)
+ self.assertEqual(set(res["aliases"]), {alias1, alias2})
+
+ def test_peekable_room(self):
+ alias1 = self._random_alias()
+ self._set_alias_via_directory(alias1)
+
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.RoomHistoryVisibility,
+ body={"history_visibility": "world_readable"},
+ tok=self.room_owner_tok,
+ )
+
+ self.register_user("user", "test")
+ user_tok = self.login("user", "test")
+
+ res = self._get_aliases(user_tok)
+ self.assertEqual(res["aliases"], [alias1])
+
+ def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict:
+ """Calls the endpoint under test. returns the json response object."""
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2432/rooms/%s/aliases"
+ % (self.room_id,),
+ access_token=access_token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+ res = channel.json_body
+ self.assertIsInstance(res, dict)
+ if expected_code == 200:
+ self.assertIsInstance(res["aliases"], list)
+ return res
+
+ def _random_alias(self) -> str:
+ return RoomAlias(random_string(5), self.hs.hostname).to_string()
+
+ def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
+ url = "/_matrix/client/r0/directory/room/" + alias
+ data = {"room_id": self.room_id}
+ request_data = json.dumps(data)
+
+ request, channel = self.make_request(
+ "PUT", url, request_data, access_token=self.room_owner_tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+
+class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.room_owner = self.register_user("room_owner", "test")
+ self.room_owner_tok = self.login("room_owner", "test")
+
+ self.room_id = self.helper.create_room_as(
+ self.room_owner, tok=self.room_owner_tok
+ )
+
+ self.alias = "#alias:test"
+ self._set_alias_via_directory(self.alias)
+
+ def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
+ url = "/_matrix/client/r0/directory/room/" + alias
+ data = {"room_id": self.room_id}
+ request_data = json.dumps(data)
+
+ request, channel = self.make_request(
+ "PUT", url, request_data, access_token=self.room_owner_tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
+ """Calls the endpoint under test. returns the json response object."""
+ request, channel = self.make_request(
+ "GET",
+ "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
+ access_token=self.room_owner_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+ res = channel.json_body
+ self.assertIsInstance(res, dict)
+ return res
+
+ def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
+ """Calls the endpoint under test. returns the json response object."""
+ request, channel = self.make_request(
+ "PUT",
+ "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
+ json.dumps(content),
+ access_token=self.room_owner_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+ res = channel.json_body
+ self.assertIsInstance(res, dict)
+ return res
+
+ def test_canonical_alias(self):
+ """Test a basic alias message."""
+ # There is no canonical alias to start with.
+ self._get_canonical_alias(expected_code=404)
+
+ # Create an alias.
+ self._set_canonical_alias({"alias": self.alias})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias})
+
+ # Now remove the alias.
+ self._set_canonical_alias({})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {})
+
+ def test_alt_aliases(self):
+ """Test a canonical alias message with alt_aliases."""
+ # Create an alias.
+ self._set_canonical_alias({"alt_aliases": [self.alias]})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alt_aliases": [self.alias]})
+
+ # Now remove the alt_aliases.
+ self._set_canonical_alias({})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {})
+
+ def test_alias_alt_aliases(self):
+ """Test a canonical alias message with an alias and alt_aliases."""
+ # Create an alias.
+ self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Now remove the alias and alt_aliases.
+ self._set_canonical_alias({})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {})
+
+ def test_partial_modify(self):
+ """Test removing only the alt_aliases."""
+ # Create an alias.
+ self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Now remove the alt_aliases.
+ self._set_canonical_alias({"alias": self.alias})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias})
+
+ def test_add_alias(self):
+ """Test removing only the alt_aliases."""
+ # Create an additional alias.
+ second_alias = "#second:test"
+ self._set_alias_via_directory(second_alias)
+
+ # Add the canonical alias.
+ self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Then add the second alias.
+ self._set_canonical_alias(
+ {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
+ )
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(
+ res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
+ )
+
+ def test_bad_data(self):
+ """Invalid data for alt_aliases should cause errors."""
+ self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": 0}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": 1}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": False}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
+
+ def test_bad_alias(self):
+ """An alias which does not point to the room raises a SynapseError."""
+ self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 30fb77bac8..4bc3aaf02d 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -109,7 +109,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code)
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+ events = self.get_success(
+ self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+ )
self.assertEquals(
events[0],
[
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 449a69183f..873d5ef99c 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +21,8 @@ import time
import attr
+from twisted.web.resource import Resource
+
from synapse.api.constants import Membership
from tests.server import make_request, render
@@ -44,7 +49,7 @@ class RestHelper(object):
path = path + "?access_token=%s" % tok
request, channel = make_request(
- self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8')
+ self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8")
)
render(request, self.resource, self.hs.get_reactor())
@@ -93,7 +98,7 @@ class RestHelper(object):
data = {"membership": membership}
request, channel = make_request(
- self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8')
+ self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8")
)
render(request, self.resource, self.hs.get_reactor())
@@ -106,18 +111,27 @@ class RestHelper(object):
self.auth_user_id = temp_id
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
- if txn_id is None:
- txn_id = "m%s" % (str(time.time()))
if body is None:
body = "body_text_here"
- path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
content = {"msgtype": "m.text", "body": body}
+
+ return self.send_event(
+ room_id, "m.room.message", content, txn_id, tok, expect_code
+ )
+
+ def send_event(
+ self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200
+ ):
+ if txn_id is None:
+ txn_id = "m%s" % (str(time.time()))
+
+ path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id)
if tok:
path = path + "?access_token=%s" % tok
request, channel = make_request(
- self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8')
+ self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8")
)
render(request, self.resource, self.hs.get_reactor())
@@ -148,3 +162,38 @@ class RestHelper(object):
)
return channel.json_body
+
+ def upload_media(
+ self,
+ resource: Resource,
+ image_data: bytes,
+ tok: str,
+ filename: str = "test.png",
+ expect_code: int = 200,
+ ) -> dict:
+ """Upload a piece of test media to the media repo
+ Args:
+ resource: The resource that will handle the upload request
+ image_data: The image data to upload
+ tok: The user token to use during the upload
+ filename: The filename of the media to be uploaded
+ expect_code: The return code to expect from attempting to upload the media
+ """
+ image_length = len(image_data)
+ path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
+ request, channel = make_request(
+ self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok
+ )
+ request.requestHeaders.addRawHeader(
+ b"Content-Length", str(image_length).encode("UTF-8")
+ )
+ request.render(resource)
+ self.hs.get_reactor().pump([100])
+
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
+ )
+
+ return channel.json_body
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 32adf88c35..c3facc00eb 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -135,9 +135,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertEquals(len(self.email_attempts), 1)
# Attempt to reset password without clicking the link
- self._reset_password(
- new_password, session_id, client_secret, expected_code=401,
- )
+ self._reset_password(new_password, session_id, client_secret, expected_code=401)
# Assert we can log in with the old password
self.login("kermit", old_password)
@@ -172,9 +170,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
session_id = "weasle"
# Attempt to reset password without even requesting an email
- self._reset_password(
- new_password, session_id, client_secret, expected_code=401,
- )
+ self._reset_password(new_password, session_id, client_secret, expected_code=401)
# Assert we can log in with the old password
self.login("kermit", old_password)
@@ -252,29 +248,14 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver()
- return hs
+ self.hs = self.setup_test_homeserver()
+ return self.hs
def test_deactivate_account(self):
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
- request_data = json.dumps({
- "auth": {
- "type": "m.login.password",
- "user": user_id,
- "password": "test",
- },
- "erase": False,
- })
- request, channel = self.make_request(
- "POST",
- "account/deactivate",
- request_data,
- access_token=tok,
- )
- self.render(request)
- self.assertEqual(request.code, 200)
+ self.deactivate(user_id, tok)
store = self.hs.get_datastore()
@@ -304,7 +285,9 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
)
# Make sure the invite is here.
- pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+ pending_invites = self.get_success(
+ store.get_invited_rooms_for_local_user(invitee_id)
+ )
self.assertEqual(len(pending_invites), 1, pending_invites)
self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
@@ -312,12 +295,16 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.deactivate(invitee_id, invitee_tok)
# Check that the invite isn't there anymore.
- pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+ pending_invites = self.get_success(
+ store.get_invited_rooms_for_local_user(invitee_id)
+ )
self.assertEqual(len(pending_invites), 0, pending_invites)
# Check that the membership of @invitee:test in the room is now "leave".
memberships = self.get_success(
- store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE])
+ store.get_rooms_for_local_user_where_membership_is(
+ invitee_id, [Membership.LEAVE]
+ )
)
self.assertEqual(len(memberships), 1, memberships)
self.assertEqual(memberships[0].room_id, room_id, memberships)
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index b9ef46e8fb..b6df1396ad 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -18,11 +18,22 @@ from twisted.internet.defer import succeed
import synapse.rest.admin
from synapse.api.constants import LoginType
+from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client.v2_alpha import auth, register
from tests import unittest
+class DummyRecaptchaChecker(UserInteractiveAuthChecker):
+ def __init__(self, hs):
+ super().__init__(hs)
+ self.recaptcha_attempts = []
+
+ def check_auth(self, authdict, clientip):
+ self.recaptcha_attempts.append((authdict, clientip))
+ return succeed(True)
+
+
class FallbackAuthTests(unittest.HomeserverTestCase):
servlets = [
@@ -44,15 +55,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor, clock, hs):
+ self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
-
- self.recaptcha_attempts = []
-
- def _recaptcha(authdict, clientip):
- self.recaptcha_attempts.append((authdict, clientip))
- return succeed(True)
-
- auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha
+ auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
@unittest.INFO
def test_fallback_captcha(self):
@@ -89,8 +94,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
self.assertEqual(request.code, 200)
# The recaptcha handler is called with the response given
- self.assertEqual(len(self.recaptcha_attempts), 1)
- self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
+ attempts = self.recaptcha_checker.recaptcha_attempts
+ self.assertEqual(len(attempts), 1)
+ self.assertEqual(attempts[0][0]["response"], "a")
# also complete the dummy auth
request, channel = self.make_request(
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index bce5b0cf4c..b9e01c9418 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -47,15 +47,15 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request("GET", self.url, access_token=access_token)
self.render(request)
- capabilities = channel.json_body['capabilities']
+ capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
- for room_version in capabilities['m.room_versions']['available'].keys():
+ for room_version in capabilities["m.room_versions"]["available"].keys():
self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version)
self.assertEqual(
self.config.default_room_version.identifier,
- capabilities['m.room_versions']['default'],
+ capabilities["m.room_versions"]["default"],
)
def test_get_change_password_capabilities(self):
@@ -66,16 +66,16 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request("GET", self.url, access_token=access_token)
self.render(request)
- capabilities = channel.json_body['capabilities']
+ capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
# Test case where password is handled outside of Synapse
- self.assertTrue(capabilities['m.change_password']['enabled'])
+ self.assertTrue(capabilities["m.change_password"]["enabled"])
self.get_success(self.store.user_set_password_hash(user, None))
request, channel = self.make_request("GET", self.url, access_token=access_token)
self.render(request)
- capabilities = channel.json_body['capabilities']
+ capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
- self.assertFalse(capabilities['m.change_password']['enabled'])
+ self.assertFalse(capabilities["m.change_password"]["enabled"])
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index f42a8efbf4..e0e9e94fbf 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -92,7 +92,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
)
self.render(request)
- self.assertEqual(channel.result["code"], b"400")
+ self.assertEqual(channel.result["code"], b"404")
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
index 17c22fe751..37f970c6b0 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -60,9 +60,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
}
config = self.default_config()
- config["password_config"] = {
- "policy": self.policy,
- }
+ config["password_config"] = {"policy": self.policy}
hs = self.setup_test_homeserver(config=config)
return hs
@@ -70,17 +68,23 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_get_policy(self):
"""Tests if the /password_policy endpoint returns the configured policy."""
- request, channel = self.make_request("GET", "/_matrix/client/r0/password_policy")
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/password_policy"
+ )
self.render(request)
self.assertEqual(channel.code, 200, channel.result)
- self.assertEqual(channel.json_body, {
- "m.minimum_length": 10,
- "m.require_digit": True,
- "m.require_symbol": True,
- "m.require_lowercase": True,
- "m.require_uppercase": True,
- }, channel.result)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "m.minimum_length": 10,
+ "m.require_digit": True,
+ "m.require_symbol": True,
+ "m.require_lowercase": True,
+ "m.require_uppercase": True,
+ },
+ channel.result,
+ )
def test_password_too_short(self):
request_data = json.dumps({"username": "kermit", "password": "shorty"})
@@ -89,9 +93,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"],
- Codes.PASSWORD_TOO_SHORT,
- channel.result,
+ channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result
)
def test_password_no_digit(self):
@@ -101,9 +103,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"],
- Codes.PASSWORD_NO_DIGIT,
- channel.result,
+ channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result
)
def test_password_no_symbol(self):
@@ -113,9 +113,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"],
- Codes.PASSWORD_NO_SYMBOL,
- channel.result,
+ channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result
)
def test_password_no_uppercase(self):
@@ -125,9 +123,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"],
- Codes.PASSWORD_NO_UPPERCASE,
- channel.result,
+ channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result
)
def test_password_no_lowercase(self):
@@ -137,9 +133,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"],
- Codes.PASSWORD_NO_LOWERCASE,
- channel.result,
+ channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result
)
def test_password_compliant(self):
@@ -161,14 +155,16 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
user_id = self.register_user("kermit", compliant_password)
tok = self.login("kermit", compliant_password)
- request_data = json.dumps({
- "new_password": not_compliant_password,
- "auth": {
- "password": compliant_password,
- "type": LoginType.PASSWORD,
- "user": user_id,
+ request_data = json.dumps(
+ {
+ "new_password": not_compliant_password,
+ "auth": {
+ "password": compliant_password,
+ "type": LoginType.PASSWORD,
+ "user": user_id,
+ },
}
- })
+ )
request, channel = self.make_request(
"POST",
"/_matrix/client/r0/account/password",
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index a5c7aaa9c0..d99b100d0f 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -38,19 +38,12 @@ from tests import unittest
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
servlets = [register.register_servlets]
+ url = b"/_matrix/client/r0/register"
- def make_homeserver(self, reactor, clock):
-
- self.url = b"/_matrix/client/r0/register"
-
- self.hs = self.setup_test_homeserver()
- self.hs.config.enable_registration = True
- self.hs.config.registrations_require_3pid = []
- self.hs.config.auto_join_rooms = []
- self.hs.config.enable_registration_captcha = False
- self.hs.config.allow_guest_access = True
-
- return self.hs
+ def default_config(self, name="test"):
+ config = super().default_config(name)
+ config["allow_guest_access"] = True
+ return config
def test_POST_appservice_registration_valid(self):
user_id = "@as_user_kermit:test"
@@ -203,12 +196,78 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ def test_advertised_flows(self):
+ request, channel = self.make_request(b"POST", self.url, b"{}")
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+
+ # with the stock config, we only expect the dummy flow
+ self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows))
+
+ @unittest.override_config(
+ {
+ "public_baseurl": "https://test_server",
+ "enable_registration_captcha": True,
+ "user_consent": {
+ "version": "1",
+ "template_dir": "/",
+ "require_at_registration": True,
+ },
+ "account_threepid_delegates": {
+ "email": "https://id_server",
+ "msisdn": "https://id_server",
+ },
+ }
+ )
+ def test_advertised_flows_captcha_and_terms_and_3pids(self):
+ request, channel = self.make_request(b"POST", self.url, b"{}")
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+
+ self.assertCountEqual(
+ [
+ ["m.login.recaptcha", "m.login.terms", "m.login.dummy"],
+ ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"],
+ ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"],
+ [
+ "m.login.recaptcha",
+ "m.login.terms",
+ "m.login.msisdn",
+ "m.login.email.identity",
+ ],
+ ],
+ (f["stages"] for f in flows),
+ )
+
+ @unittest.override_config(
+ {
+ "public_baseurl": "https://test_server",
+ "registrations_require_3pid": ["email"],
+ "disable_msisdn_registration": True,
+ "email": {
+ "smtp_host": "mail_server",
+ "smtp_port": 2525,
+ "notif_from": "sender@host",
+ },
+ }
+ )
+ def test_advertised_flows_no_msisdn_email_required(self):
+ request, channel = self.make_request(b"POST", self.url, b"{}")
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+
+ # with the stock config, we expect all four combinations of 3pid
+ self.assertCountEqual(
+ [["m.login.email.identity"]], (f["stages"] for f in flows)
+ )
+
class RegisterHideProfileTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- ]
+ servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
def make_homeserver(self, reactor, clock):
@@ -219,15 +278,11 @@ class RegisterHideProfileTestCase(unittest.HomeserverTestCase):
config["show_users_in_user_directory"] = False
config["replicate_user_profiles_to"] = ["fakeserver"]
- mock_http_client = Mock(spec=[
- "get_json",
- "post_json_get_json",
- ])
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
self.hs = self.setup_test_homeserver(
- config=config,
- simple_http_client=mock_http_client,
+ config=config, simple_http_client=mock_http_client
)
return self.hs
@@ -376,14 +431,11 @@ class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
config["replicate_user_profiles_to"] = "test.is"
# Mock homeserver requests to an identity server
- mock_http_client = Mock(spec=[
- "post_json_get_json",
- ])
+ mock_http_client = Mock(spec=["post_json_get_json"])
mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
self.hs = self.setup_test_homeserver(
- config=config,
- simple_http_client=mock_http_client,
+ config=config, simple_http_client=mock_http_client
)
return self.hs
@@ -517,14 +569,13 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# Email config.
self.email_attempts = []
- def sendmail(*args, **kwargs):
+ async def sendmail(*args, **kwargs):
self.email_attempts.append((args, kwargs))
- return
config["email"] = {
"enable_notifs": True,
"template_dir": os.path.abspath(
- pkg_resources.resource_filename('synapse', 'res/templates')
+ pkg_resources.resource_filename("synapse", "res/templates")
),
"expiry_template_html": "notice_expiry.html",
"expiry_template_text": "notice_expiry.txt",
@@ -624,19 +675,18 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
(user_id, tok) = self.create_user()
- request_data = json.dumps({
- "auth": {
- "type": "m.login.password",
- "user": user_id,
- "password": "monkey",
- },
- "erase": False,
- })
+ request_data = json.dumps(
+ {
+ "auth": {
+ "type": "m.login.password",
+ "user": user_id,
+ "password": "monkey",
+ },
+ "erase": False,
+ }
+ )
request, channel = self.make_request(
- "POST",
- "account/deactivate",
- request_data,
- access_token=tok,
+ "POST", "account/deactivate", request_data, access_token=tok
)
self.render(request)
self.assertEqual(request.code, 200, channel.result)
@@ -660,7 +710,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
added_at=now,
)
)
- return (user_id, tok)
+ return user_id, tok
def test_manual_email_send_expired_account(self):
user_id = self.register_user("kermit", "monkey")
@@ -700,20 +750,16 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- ]
+ servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
def make_homeserver(self, reactor, clock):
self.validity_period = 10
- self.max_delta = self.validity_period * 10. / 100.
+ self.max_delta = self.validity_period * 10.0 / 100.0
config = self.default_config()
config["enable_registration"] = True
- config["account_validity"] = {
- "enabled": False,
- }
+ config["account_validity"] = {"enabled": False}
self.hs = self.setup_test_homeserver(config=config)
self.hs.config.account_validity.period = self.validity_period
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 43b3049daa..c7e5859970 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -56,7 +56,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
creates the right shape of event.
"""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key=u"👍")
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
self.assertEquals(200, channel.code, channel.json_body)
event_id = channel.json_body["event_id"]
@@ -76,7 +76,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"content": {
"m.relates_to": {
"event_id": self.parent_id,
- "key": u"👍",
+ "key": "👍",
"rel_type": RelationTypes.ANNOTATION,
}
},
@@ -93,7 +93,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
def test_deny_double_react(self):
"""Test that we deny relations on membership events
"""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
self.assertEquals(200, channel.code, channel.json_body)
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@@ -126,6 +126,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel.json_body["chunk"][0],
)
+ # We also expect to get the original event (the id of which is self.parent_id)
+ self.assertEquals(
+ channel.json_body["original_event"]["event_id"], self.parent_id
+ )
+
# Make sure next_batch has something in it that looks like it could be a
# valid token.
self.assertIsInstance(
@@ -187,7 +192,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_tokens.append(token)
idx = 0
- sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1}
+ sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
for key in itertools.chain.from_iterable(
itertools.repeat(key, num) for key, num in sent_groups.items()
):
@@ -259,7 +264,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(
RelationTypes.ANNOTATION,
"m.reaction",
- key=u"👍",
+ key="👍",
access_token=access_tokens[idx],
)
self.assertEquals(200, channel.code, channel.json_body)
@@ -273,7 +278,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
prev_token = None
found_event_ids = []
- encoded_key = six.moves.urllib.parse.quote_plus(u"👍".encode("utf-8"))
+ encoded_key = six.moves.urllib.parse.quote_plus("👍".encode("utf-8"))
for _ in range(20):
from_token = ""
if prev_token:
@@ -466,9 +471,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["content"], new_body)
- self.assertEquals(
- channel.json_body["unsigned"].get("m.relations"),
- {RelationTypes.REPLACE: {"event_id": edit_event_id}},
+ relations_dict = channel.json_body["unsigned"].get("m.relations")
+ self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+ m_replace_dict = relations_dict[RelationTypes.REPLACE]
+ for key in ["event_id", "sender", "origin_server_ts"]:
+ self.assertIn(key, m_replace_dict)
+
+ self.assert_dict(
+ {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
def test_multi_edit(self):
@@ -518,19 +529,133 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["content"], new_body)
- self.assertEquals(
- channel.json_body["unsigned"].get("m.relations"),
- {RelationTypes.REPLACE: {"event_id": edit_event_id}},
+ relations_dict = channel.json_body["unsigned"].get("m.relations")
+ self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+ m_replace_dict = relations_dict[RelationTypes.REPLACE]
+ for key in ["event_id", "sender", "origin_server_ts"]:
+ self.assertIn(key, m_replace_dict)
+
+ self.assert_dict(
+ {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+ )
+
+ def test_relations_redaction_redacts_edits(self):
+ """Test that edits of an event are redacted when the original event
+ is redacted.
+ """
+ # Send a new event
+ res = self.helper.send(self.room, body="Heyo!", tok=self.user_token)
+ original_event_id = res["event_id"]
+
+ # Add a relation
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ parent_id=original_event_id,
+ content={
+ "msgtype": "m.text",
+ "body": "Wibble",
+ "m.new_content": {"msgtype": "m.text", "body": "First edit"},
+ },
)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # Check the relation is returned
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
+ % (self.room, original_event_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertIn("chunk", channel.json_body)
+ self.assertEquals(len(channel.json_body["chunk"]), 1)
+
+ # Redact the original event
+ request, channel = self.make_request(
+ "PUT",
+ "/rooms/%s/redact/%s/%s"
+ % (self.room, original_event_id, "test_relations_redaction_redacts_edits"),
+ access_token=self.user_token,
+ content="{}",
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # Try to check for remaining m.replace relations
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
+ % (self.room, original_event_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # Check that no relations are returned
+ self.assertIn("chunk", channel.json_body)
+ self.assertEquals(channel.json_body["chunk"], [])
+
+ def test_aggregations_redaction_prevents_access_to_aggregations(self):
+ """Test that annotations of an event are redacted when the original event
+ is redacted.
+ """
+ # Send a new event
+ res = self.helper.send(self.room, body="Hello!", tok=self.user_token)
+ original_event_id = res["event_id"]
+
+ # Add a relation
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # Redact the original
+ request, channel = self.make_request(
+ "PUT",
+ "/rooms/%s/redact/%s/%s"
+ % (
+ self.room,
+ original_event_id,
+ "test_aggregations_redaction_prevents_access_to_aggregations",
+ ),
+ access_token=self.user_token,
+ content="{}",
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # Check that aggregations returns zero
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction"
+ % (self.room, original_event_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertIn("chunk", channel.json_body)
+ self.assertEquals(channel.json_body["chunk"], [])
def _send_relation(
- self, relation_type, event_type, key=None, content={}, access_token=None
+ self,
+ relation_type,
+ event_type,
+ key=None,
+ content={},
+ access_token=None,
+ parent_id=None,
):
"""Helper function to send a relation pointing at `self.parent_id`
Args:
relation_type (str): One of `RelationTypes`
event_type (str): The type of the event to create
+ parent_id (str): The event_id this relation relates to. If None, then self.parent_id
key (str|None): The aggregation key used for m.annotation relation
type.
content(dict|None): The content of the created event.
@@ -547,10 +672,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
if key:
query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8"))
+ original_id = parent_id if parent_id else self.parent_id
+
request, channel = self.make_request(
"POST",
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
- % (self.room, self.parent_id, relation_type, event_type, query),
+ % (self.room, original_id, relation_type, event_type, query),
json.dumps(content).encode("utf-8"),
access_token=access_token,
)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 71895094bd..fa3a3ec1bd 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,10 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from mock import Mock
+import json
import synapse.rest.admin
+from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
@@ -26,14 +27,12 @@ from tests.server import TimedOutException
class FilterTestCase(unittest.HomeserverTestCase):
user_id = "@apple:test"
- servlets = [sync.register_servlets]
-
- def make_homeserver(self, reactor, clock):
-
- hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock()
- )
- return hs
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
def test_sync_argless(self):
request, channel = self.make_request("GET", "/sync")
@@ -41,16 +40,14 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertTrue(
- set(
- [
- "next_batch",
- "rooms",
- "presence",
- "account_data",
- "to_device",
- "device_lists",
- ]
- ).issubset(set(channel.json_body.keys()))
+ {
+ "next_batch",
+ "rooms",
+ "presence",
+ "account_data",
+ "to_device",
+ "device_lists",
+ }.issubset(set(channel.json_body.keys()))
)
def test_sync_presence_disabled(self):
@@ -64,11 +61,149 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertTrue(
- set(
- ["next_batch", "rooms", "account_data", "to_device", "device_lists"]
- ).issubset(set(channel.json_body.keys()))
+ {
+ "next_batch",
+ "rooms",
+ "account_data",
+ "to_device",
+ "device_lists",
+ }.issubset(set(channel.json_body.keys()))
+ )
+
+
+class SyncFilterTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def test_sync_filter_labels(self):
+ """Test that we can filter by a label."""
+ sync_filter = json.dumps(
+ {
+ "room": {
+ "timeline": {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#fun"],
+ }
+ }
+ }
+ )
+
+ events = self._test_sync_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 2, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
+
+ def test_sync_filter_not_labels(self):
+ """Test that we can filter by the absence of a label."""
+ sync_filter = json.dumps(
+ {
+ "room": {
+ "timeline": {
+ "types": [EventTypes.Message],
+ "org.matrix.not_labels": ["#fun"],
+ }
+ }
+ }
+ )
+
+ events = self._test_sync_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 3, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "without label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
+ self.assertEqual(
+ events[2]["content"]["body"], "with two wrong labels", events[2]
+ )
+
+ def test_sync_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label."""
+ sync_filter = json.dumps(
+ {
+ "room": {
+ "timeline": {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#work"],
+ "org.matrix.not_labels": ["#notfun"],
+ }
+ }
+ }
+ )
+
+ events = self._test_sync_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 1, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
+
+ def _test_sync_filter_labels(self, sync_filter):
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ room_id = self.helper.create_room_as(user_id, tok=tok)
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "without label"},
+ tok=tok,
)
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with wrong label",
+ EventContentFields.LABELS: ["#work"],
+ },
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with two wrong labels",
+ EventContentFields.LABELS: ["#work", "#notfun"],
+ },
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=tok,
+ )
+
+ request, channel = self.make_request(
+ "GET", "/sync?filter=%s" % sync_filter, access_token=tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"]
+
class SyncTypingTests(unittest.HomeserverTestCase):
diff --git a/tests/rest/key/__init__.py b/tests/rest/key/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/rest/key/__init__.py
diff --git a/tests/rest/key/v2/__init__.py b/tests/rest/key/v2/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/rest/key/v2/__init__.py
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
new file mode 100644
index 0000000000..6776a56cad
--- /dev/null
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -0,0 +1,257 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import urllib.parse
+from io import BytesIO, StringIO
+
+from mock import Mock
+
+import signedjson.key
+from canonicaljson import encode_canonical_json
+from nacl.signing import SigningKey
+from signedjson.sign import sign_json
+
+from twisted.web.resource import NoResource
+
+from synapse.crypto.keyring import PerspectivesKeyFetcher
+from synapse.http.site import SynapseRequest
+from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.storage.keys import FetchKeyResult
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.stringutils import random_string
+
+from tests import unittest
+from tests.server import FakeChannel, wait_until_result
+from tests.utils import default_config
+
+
+class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ self.http_client = Mock()
+ return self.setup_test_homeserver(http_client=self.http_client)
+
+ def create_test_json_resource(self):
+ return create_resource_tree(
+ {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
+ )
+
+ def expect_outgoing_key_request(
+ self, server_name: str, signing_key: SigningKey
+ ) -> None:
+ """
+ Tell the mock http client to expect an outgoing GET request for the given key
+ """
+
+ def get_json(destination, path, ignore_backoff=False, **kwargs):
+ self.assertTrue(ignore_backoff)
+ self.assertEqual(destination, server_name)
+ key_id = "%s:%s" % (signing_key.alg, signing_key.version)
+ self.assertEqual(
+ path, "/_matrix/key/v2/server/%s" % (urllib.parse.quote(key_id),)
+ )
+
+ response = {
+ "server_name": server_name,
+ "old_verify_keys": {},
+ "valid_until_ts": 200 * 1000,
+ "verify_keys": {
+ key_id: {
+ "key": signedjson.key.encode_verify_key_base64(
+ signing_key.verify_key
+ )
+ }
+ },
+ }
+ sign_json(response, server_name, signing_key)
+ return response
+
+ self.http_client.get_json.side_effect = get_json
+
+
+class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
+ def make_notary_request(self, server_name: str, key_id: str) -> dict:
+ """Send a GET request to the test server requesting the given key.
+
+ Checks that the response is a 200 and returns the decoded json body.
+ """
+ channel = FakeChannel(self.site, self.reactor)
+ req = SynapseRequest(channel)
+ req.content = BytesIO(b"")
+ req.requestReceived(
+ b"GET",
+ b"/_matrix/key/v2/query/%s/%s"
+ % (server_name.encode("utf-8"), key_id.encode("utf-8")),
+ b"1.1",
+ )
+ wait_until_result(self.reactor, req)
+ self.assertEqual(channel.code, 200)
+ resp = channel.json_body
+ return resp
+
+ def test_get_key(self):
+ """Fetch a remote key"""
+ SERVER_NAME = "remote.server"
+ testkey = signedjson.key.generate_signing_key("ver1")
+ self.expect_outgoing_key_request(SERVER_NAME, testkey)
+
+ resp = self.make_notary_request(SERVER_NAME, "ed25519:ver1")
+ keys = resp["server_keys"]
+ self.assertEqual(len(keys), 1)
+
+ self.assertIn("ed25519:ver1", keys[0]["verify_keys"])
+ self.assertEqual(len(keys[0]["verify_keys"]), 1)
+
+ # it should be signed by both the origin server and the notary
+ self.assertIn(SERVER_NAME, keys[0]["signatures"])
+ self.assertIn(self.hs.hostname, keys[0]["signatures"])
+
+ def test_get_own_key(self):
+ """Fetch our own key"""
+ testkey = signedjson.key.generate_signing_key("ver1")
+ self.expect_outgoing_key_request(self.hs.hostname, testkey)
+
+ resp = self.make_notary_request(self.hs.hostname, "ed25519:ver1")
+ keys = resp["server_keys"]
+ self.assertEqual(len(keys), 1)
+
+ # it should be signed by both itself, and the notary signing key
+ sigs = keys[0]["signatures"]
+ self.assertEqual(len(sigs), 1)
+ self.assertIn(self.hs.hostname, sigs)
+ oursigs = sigs[self.hs.hostname]
+ self.assertEqual(len(oursigs), 2)
+
+ # the requested key should be present in the verify_keys section
+ self.assertIn("ed25519:ver1", keys[0]["verify_keys"])
+
+
+class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
+ """End-to-end tests of the perspectives fetch case
+
+ The idea here is to actually wire up a PerspectivesKeyFetcher to the notary
+ endpoint, to check that the two implementations are compatible.
+ """
+
+ def default_config(self, *args, **kwargs):
+ config = super().default_config(*args, **kwargs)
+
+ # replace the signing key with our own
+ self.hs_signing_key = signedjson.key.generate_signing_key("kssk")
+ strm = StringIO()
+ signedjson.key.write_signing_keys(strm, [self.hs_signing_key])
+ config["signing_key"] = strm.getvalue()
+
+ return config
+
+ def prepare(self, reactor, clock, homeserver):
+ # make a second homeserver, configured to use the first one as a key notary
+ self.http_client2 = Mock()
+ config = default_config(name="keyclient")
+ config["trusted_key_servers"] = [
+ {
+ "server_name": self.hs.hostname,
+ "verify_keys": {
+ "ed25519:%s"
+ % (
+ self.hs_signing_key.version,
+ ): signedjson.key.encode_verify_key_base64(
+ self.hs_signing_key.verify_key
+ )
+ },
+ }
+ ]
+ self.hs2 = self.setup_test_homeserver(
+ http_client=self.http_client2, config=config
+ )
+
+ # wire up outbound POST /key/v2/query requests from hs2 so that they
+ # will be forwarded to hs1
+ def post_json(destination, path, data):
+ self.assertEqual(destination, self.hs.hostname)
+ self.assertEqual(
+ path, "/_matrix/key/v2/query",
+ )
+
+ channel = FakeChannel(self.site, self.reactor)
+ req = SynapseRequest(channel)
+ req.content = BytesIO(encode_canonical_json(data))
+
+ req.requestReceived(
+ b"POST", path.encode("utf-8"), b"1.1",
+ )
+ wait_until_result(self.reactor, req)
+ self.assertEqual(channel.code, 200)
+ resp = channel.json_body
+ return resp
+
+ self.http_client2.post_json.side_effect = post_json
+
+ def test_get_key(self):
+ """Fetch a key belonging to a random server"""
+ # make up a key to be fetched.
+ testkey = signedjson.key.generate_signing_key("abc")
+
+ # we expect hs1 to make a regular key request to the target server
+ self.expect_outgoing_key_request("targetserver", testkey)
+ keyid = "ed25519:%s" % (testkey.version,)
+
+ fetcher = PerspectivesKeyFetcher(self.hs2)
+ d = fetcher.get_keys({"targetserver": {keyid: 1000}})
+ res = self.get_success(d)
+ self.assertIn("targetserver", res)
+ keyres = res["targetserver"][keyid]
+ assert isinstance(keyres, FetchKeyResult)
+ self.assertEqual(
+ signedjson.key.encode_verify_key_base64(keyres.verify_key),
+ signedjson.key.encode_verify_key_base64(testkey.verify_key),
+ )
+
+ def test_get_notary_key(self):
+ """Fetch a key belonging to the notary server"""
+ # make up a key to be fetched. We randomise the keyid to try to get it to
+ # appear before the key server signing key sometimes (otherwise we bail out
+ # before fetching its signature)
+ testkey = signedjson.key.generate_signing_key(random_string(5))
+
+ # we expect hs1 to make a regular key request to itself
+ self.expect_outgoing_key_request(self.hs.hostname, testkey)
+ keyid = "ed25519:%s" % (testkey.version,)
+
+ fetcher = PerspectivesKeyFetcher(self.hs2)
+ d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}})
+ res = self.get_success(d)
+ self.assertIn(self.hs.hostname, res)
+ keyres = res[self.hs.hostname][keyid]
+ assert isinstance(keyres, FetchKeyResult)
+ self.assertEqual(
+ signedjson.key.encode_verify_key_base64(keyres.verify_key),
+ signedjson.key.encode_verify_key_base64(testkey.verify_key),
+ )
+
+ def test_get_notary_keyserver_key(self):
+ """Fetch the notary's keyserver key"""
+ # we expect hs1 to make a regular key request to itself
+ self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key)
+ keyid = "ed25519:%s" % (self.hs_signing_key.version,)
+
+ fetcher = PerspectivesKeyFetcher(self.hs2)
+ d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}})
+ res = self.get_success(d)
+ self.assertIn(self.hs.hostname, res)
+ keyres = res[self.hs.hostname][keyid]
+ assert isinstance(keyres, FetchKeyResult)
+ self.assertEqual(
+ signedjson.key.encode_verify_key_base64(keyres.verify_key),
+ signedjson.key.encode_verify_key_base64(self.hs_signing_key.verify_key),
+ )
diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py
index 00688a7325..ebd7869208 100644
--- a/tests/rest/media/v1/test_base.py
+++ b/tests/rest/media/v1/test_base.py
@@ -21,17 +21,17 @@ from tests import unittest
class GetFileNameFromHeadersTests(unittest.TestCase):
# input -> expected result
TEST_CASES = {
- b"inline; filename=abc.txt": u"abc.txt",
- b'inline; filename="azerty"': u"azerty",
- b'inline; filename="aze%20rty"': u"aze%20rty",
- b'inline; filename="aze\"rty"': u'aze"rty',
- b'inline; filename="azer;ty"': u"azer;ty",
- b"inline; filename*=utf-8''foo%C2%A3bar": u"foo£bar",
+ b"inline; filename=abc.txt": "abc.txt",
+ b'inline; filename="azerty"': "azerty",
+ b'inline; filename="aze%20rty"': "aze%20rty",
+ b'inline; filename="aze"rty"': 'aze"rty',
+ b'inline; filename="azer;ty"': "azer;ty",
+ b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar",
}
def tests(self):
for hdr, expected in self.TEST_CASES.items():
- res = get_filename_from_headers({b'Content-Disposition': [hdr]})
+ res = get_filename_from_headers({b"Content-Disposition": [hdr]})
self.assertEqual(
res,
expected,
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 1069a44145..1809ceb839 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -22,27 +22,28 @@ from binascii import unhexlify
from mock import Mock
from six.moves.urllib import parse
-from twisted.internet import defer, reactor
from twisted.internet.defer import Deferred
+from synapse.logging.context import make_deferred_yieldable
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
-from synapse.util.logcontext import make_deferred_yieldable
from tests import unittest
-class MediaStorageTests(unittest.TestCase):
- def setUp(self):
+class MediaStorageTests(unittest.HomeserverTestCase):
+
+ needs_threadpool = True
+
+ def prepare(self, reactor, clock, hs):
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
+ self.addCleanup(shutil.rmtree, self.test_dir)
self.primary_base_path = os.path.join(self.test_dir, "primary")
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
- hs = Mock()
- hs.get_reactor = Mock(return_value=reactor)
hs.config.media_store_path = self.primary_base_path
storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)]
@@ -52,10 +53,6 @@ class MediaStorageTests(unittest.TestCase):
hs, self.primary_base_path, self.filepaths, storage_providers
)
- def tearDown(self):
- shutil.rmtree(self.test_dir)
-
- @defer.inlineCallbacks
def test_ensure_media_is_in_local_cache(self):
media_id = "some_media_id"
test_body = "Test\n"
@@ -73,7 +70,15 @@ class MediaStorageTests(unittest.TestCase):
# Now we run ensure_media_is_in_local_cache, which should copy the file
# to the local cache.
file_info = FileInfo(None, media_id)
- local_path = yield self.media_storage.ensure_media_is_in_local_cache(file_info)
+
+ # This uses a real blocking threadpool so we have to wait for it to be
+ # actually done :/
+ x = self.media_storage.ensure_media_is_in_local_cache(file_info)
+
+ # Hotloop until the threadpool does its job...
+ self.wait_on_thread(x)
+
+ local_path = self.get_success(x)
self.assertTrue(os.path.exists(local_path))
@@ -143,7 +148,8 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.media_repo = hs.get_media_repository_resource()
- self.download_resource = self.media_repo.children[b'download']
+ self.download_resource = self.media_repo.children[b"download"]
+ self.thumbnail_resource = self.media_repo.children[b"thumbnail"]
# smol png
self.end_content = unhexlify(
@@ -152,11 +158,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
b"0a2db40000000049454e44ae426082"
)
+ self.media_id = "example.com/12345"
+
def _req(self, content_disposition):
- request, channel = self.make_request(
- "GET", "example.com/12345", shorthand=False
- )
+ request, channel = self.make_request("GET", self.media_id, shorthand=False)
request.render(self.download_resource)
self.pump()
@@ -165,13 +171,13 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(len(self.fetches), 1)
self.assertEqual(self.fetches[0][1], "example.com")
self.assertEqual(
- self.fetches[0][2], "/_matrix/media/v1/download/example.com/12345"
+ self.fetches[0][2], "/_matrix/media/v1/download/" + self.media_id
)
self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
headers = {
b"Content-Length": [b"%d" % (len(self.end_content))],
- b"Content-Type": [b'image/png'],
+ b"Content-Type": [b"image/png"],
}
if content_disposition:
headers[b"Content-Disposition"] = [content_disposition]
@@ -204,7 +210,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
correctly decode it as the UTF-8 string, and use filename* in the
response.
"""
- filename = parse.quote(u"\u2603".encode('utf8')).encode('ascii')
+ filename = parse.quote("\u2603".encode("utf8")).encode("ascii")
channel = self._req(b"inline; filename*=utf-8''" + filename + b".png")
headers = channel.headers
@@ -224,3 +230,42 @@ class MediaRepoTests(unittest.HomeserverTestCase):
headers = channel.headers
self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
+
+ def test_thumbnail_crop(self):
+ expected_body = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000020000000200806"
+ b"000000737a7af40000001a49444154789cedc101010000008220"
+ b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
+ b"44ae426082"
+ )
+
+ self._test_thumbnail("crop", expected_body)
+
+ def test_thumbnail_scale(self):
+ expected_body = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000d49444154789c636060606000000005"
+ b"0001a5f645400000000049454e44ae426082"
+ )
+
+ self._test_thumbnail("scale", expected_body)
+
+ def _test_thumbnail(self, method, expected_body):
+ params = "?width=32&height=32&method=" + method
+ request, channel = self.make_request(
+ "GET", self.media_id + params, shorthand=False
+ )
+ request.render(self.thumbnail_resource)
+ self.pump()
+
+ headers = {
+ b"Content-Length": [b"%d" % (len(self.end_content))],
+ b"Content-Type": [b"image/png"],
+ }
+ self.fetches[0][0].callback(
+ (self.end_content, (len(self.end_content), headers))
+ )
+ self.pump()
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.result["body"], expected_body, channel.result["body"])
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 1ab0f7293a..852b8ab11c 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -55,10 +55,10 @@ class URLPreviewTests(unittest.HomeserverTestCase):
hijack_auth = True
user_id = "@test:user"
end_content = (
- b'<html><head>'
+ b"<html><head>"
b'<meta property="og:title" content="~matrix~" />'
b'<meta property="og:description" content="hi" />'
- b'</head></html>'
+ b"</head></html>"
)
def make_homeserver(self, reactor, clock):
@@ -98,7 +98,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.media_repo = hs.get_media_repository_resource()
- self.preview_url = self.media_repo.children[b'preview_url']
+ self.preview_url = self.media_repo.children[b"preview_url"]
self.lookups = {}
@@ -109,7 +109,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
hostName,
portNumber=0,
addressTypes=None,
- transportSemantics='TCP',
+ transportSemantics="TCP",
):
resolution = HostResolution(hostName)
@@ -118,7 +118,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
raise DNSLookupError("OH NO")
for i in self.lookups[hostName]:
- resolutionReceiver.addressResolved(i[0]('TCP', i[1], portNumber))
+ resolutionReceiver.addressResolved(i[0]("TCP", i[1], portNumber))
resolutionReceiver.resolutionComplete()
return resolutionReceiver
@@ -184,11 +184,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
end_content = (
- b'<html><head>'
+ b"<html><head>"
b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>'
b'<meta property="og:title" content="\xe4\xea\xe0" />'
b'<meta property="og:description" content="hi" />'
- b'</head></html>'
+ b"</head></html>"
)
request, channel = self.make_request(
@@ -204,7 +204,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
client.dataReceived(
(
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b"Content-Type: text/html; charset=\"utf8\"\r\n\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
)
% (len(end_content),)
+ end_content
@@ -212,16 +212,16 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.pump()
self.assertEqual(channel.code, 200)
- self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
+ self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
def test_non_ascii_preview_content_type(self):
self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
end_content = (
- b'<html><head>'
+ b"<html><head>"
b'<meta property="og:title" content="\xe4\xea\xe0" />'
b'<meta property="og:description" content="hi" />'
- b'</head></html>'
+ b"</head></html>"
)
request, channel = self.make_request(
@@ -237,7 +237,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
client.dataReceived(
(
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b"Content-Type: text/html; charset=\"windows-1251\"\r\n\r\n"
+ b'Content-Type: text/html; charset="windows-1251"\r\n\r\n'
)
% (len(end_content),)
+ end_content
@@ -245,7 +245,42 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.pump()
self.assertEqual(channel.code, 200)
- self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
+ self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
+
+ def test_overlong_title(self):
+ self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+
+ end_content = (
+ b"<html><head>"
+ b"<title>" + b"x" * 2000 + b"</title>"
+ b'<meta property="og:description" content="hi" />'
+ b"</head></html>"
+ )
+
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://matrix.org", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="windows-1251"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ res = channel.json_body
+ # We should only see the `og:description` field, as `title` is too long and should be stripped out
+ self.assertCountEqual(["og:description"], res.keys())
def test_ipaddr(self):
"""
@@ -293,8 +328,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(
channel.json_body,
{
- 'errcode': 'M_UNKNOWN',
- 'error': 'DNS resolution failure during URL preview generation',
+ "errcode": "M_UNKNOWN",
+ "error": "DNS resolution failure during URL preview generation",
},
)
@@ -314,8 +349,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(
channel.json_body,
{
- 'errcode': 'M_UNKNOWN',
- 'error': 'DNS resolution failure during URL preview generation',
+ "errcode": "M_UNKNOWN",
+ "error": "DNS resolution failure during URL preview generation",
},
)
@@ -334,8 +369,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(
channel.json_body,
{
- 'errcode': 'M_UNKNOWN',
- 'error': 'IP address blocked by IP blacklist entry',
+ "errcode": "M_UNKNOWN",
+ "error": "IP address blocked by IP blacklist entry",
},
)
self.assertEqual(channel.code, 403)
@@ -354,8 +389,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(
channel.json_body,
{
- 'errcode': 'M_UNKNOWN',
- 'error': 'IP address blocked by IP blacklist entry',
+ "errcode": "M_UNKNOWN",
+ "error": "IP address blocked by IP blacklist entry",
},
)
@@ -396,7 +431,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
non-blacklisted one, it will be rejected.
"""
# Hardcode the URL resolving to the IP we want.
- self.lookups[u"example.com"] = [
+ self.lookups["example.com"] = [
(IPv4Address, "1.1.1.2"),
(IPv4Address, "8.8.8.8"),
]
@@ -410,8 +445,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(
channel.json_body,
{
- 'errcode': 'M_UNKNOWN',
- 'error': 'DNS resolution failure during URL preview generation',
+ "errcode": "M_UNKNOWN",
+ "error": "DNS resolution failure during URL preview generation",
},
)
@@ -435,8 +470,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(
channel.json_body,
{
- 'errcode': 'M_UNKNOWN',
- 'error': 'DNS resolution failure during URL preview generation',
+ "errcode": "M_UNKNOWN",
+ "error": "DNS resolution failure during URL preview generation",
},
)
@@ -456,7 +491,19 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(
channel.json_body,
{
- 'errcode': 'M_UNKNOWN',
- 'error': 'DNS resolution failure during URL preview generation',
+ "errcode": "M_UNKNOWN",
+ "error": "DNS resolution failure during URL preview generation",
},
)
+
+ def test_OPTIONS(self):
+ """
+ OPTIONS returns the OPTIONS.
+ """
+ request, channel = self.make_request(
+ "OPTIONS", "url_preview?url=http://example.com", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body, {})
|