diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py
index 1a553fa3f9..e38eb628a9 100644
--- a/tests/rest/client/v1/test_admin.py
+++ b/tests/rest/client/v1/test_admin.py
@@ -19,24 +19,17 @@ import json
from mock import Mock
-from synapse.http.server import JsonResource
from synapse.rest.client.v1.admin import register_servlets
-from synapse.util import Clock
from tests import unittest
-from tests.server import (
- ThreadedMemoryReactorClock,
- make_request,
- render,
- setup_test_homeserver,
-)
-class UserRegisterTestCase(unittest.TestCase):
- def setUp(self):
+class UserRegisterTestCase(unittest.HomeserverTestCase):
+
+ servlets = [register_servlets]
+
+ def make_homeserver(self, reactor, clock):
- self.clock = ThreadedMemoryReactorClock()
- self.hs_clock = Clock(self.clock)
self.url = "/_matrix/client/r0/admin/register"
self.registration_handler = Mock()
@@ -50,17 +43,14 @@ class UserRegisterTestCase(unittest.TestCase):
self.secrets = Mock()
- self.hs = setup_test_homeserver(
- self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
- )
+ self.hs = self.setup_test_homeserver()
self.hs.config.registration_shared_secret = u"shared"
self.hs.get_media_repository = Mock()
self.hs.get_deactivate_account_handler = Mock()
- self.resource = JsonResource(self.hs)
- register_servlets(self.hs, self.resource)
+ return self.hs
def test_disabled(self):
"""
@@ -69,8 +59,8 @@ class UserRegisterTestCase(unittest.TestCase):
"""
self.hs.config.registration_shared_secret = None
- request, channel = make_request("POST", self.url, b'{}')
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, b'{}')
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -87,8 +77,8 @@ class UserRegisterTestCase(unittest.TestCase):
self.hs.get_secrets = Mock(return_value=secrets)
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
self.assertEqual(channel.json_body, {"nonce": "abcd"})
@@ -97,25 +87,25 @@ class UserRegisterTestCase(unittest.TestCase):
Calling GET on the endpoint will return a randomised nonce, which will
only last for SALT_TIMEOUT (60s).
"""
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
nonce = channel.json_body["nonce"]
# 59 seconds
- self.clock.advance(59)
+ self.reactor.advance(59)
body = json.dumps({"nonce": nonce})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])
# 61 seconds
- self.clock.advance(2)
+ self.reactor.advance(2)
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
@@ -124,8 +114,8 @@ class UserRegisterTestCase(unittest.TestCase):
"""
Only the provided nonce can be used, as it's checked in the MAC.
"""
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -141,8 +131,8 @@ class UserRegisterTestCase(unittest.TestCase):
"mac": want_mac,
}
)
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("HMAC incorrect", channel.json_body["error"])
@@ -152,8 +142,8 @@ class UserRegisterTestCase(unittest.TestCase):
When the correct nonce is provided, and the right key is provided, the
user is registered.
"""
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -169,8 +159,8 @@ class UserRegisterTestCase(unittest.TestCase):
"mac": want_mac,
}
)
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -179,8 +169,8 @@ class UserRegisterTestCase(unittest.TestCase):
"""
A valid unrecognised nonce.
"""
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -196,15 +186,15 @@ class UserRegisterTestCase(unittest.TestCase):
"mac": want_mac,
}
)
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
@@ -217,8 +207,8 @@ class UserRegisterTestCase(unittest.TestCase):
"""
def nonce():
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
return channel.json_body["nonce"]
#
@@ -227,8 +217,8 @@ class UserRegisterTestCase(unittest.TestCase):
# Must be present
body = json.dumps({})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('nonce must be specified', channel.json_body["error"])
@@ -239,32 +229,32 @@ class UserRegisterTestCase(unittest.TestCase):
# Must be present
body = json.dumps({"nonce": nonce()})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": 1234})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
@@ -275,16 +265,16 @@ class UserRegisterTestCase(unittest.TestCase):
# Must be present
body = json.dumps({"nonce": nonce(), "username": "a"})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('password must be specified', channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
@@ -293,16 +283,16 @@ class UserRegisterTestCase(unittest.TestCase):
body = json.dumps(
{"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}
)
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
# Super long
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
index 6b7ff813d5..f973eff8cf 100644
--- a/tests/rest/client/v1/test_register.py
+++ b/tests/rest/client/v1/test_register.py
@@ -45,11 +45,11 @@ class CreateUserServletTestCase(unittest.TestCase):
)
handlers = Mock(registration_handler=self.registration_handler)
- self.clock = MemoryReactorClock()
- self.hs_clock = Clock(self.clock)
+ self.reactor = MemoryReactorClock()
+ self.hs_clock = Clock(self.reactor)
self.hs = self.hs = setup_test_homeserver(
- self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
+ self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
)
self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.get_handlers = Mock(return_value=handlers)
@@ -76,8 +76,8 @@ class CreateUserServletTestCase(unittest.TestCase):
return_value=(user_id, token)
)
- request, channel = make_request(b"POST", url, request_data)
- render(request, res, self.clock)
+ request, channel = make_request(self.reactor, b"POST", url, request_data)
+ render(request, res, self.reactor)
self.assertEquals(channel.result["code"], b"200")
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 359f7777ff..a824be9a62 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -23,7 +23,7 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer
from synapse.api.constants import Membership
-from synapse.rest.client.v1 import room
+from synapse.rest.client.v1 import admin, login, room
from tests import unittest
@@ -799,3 +799,107 @@ class RoomMessageListTestCase(RoomBase):
self.assertEquals(token, channel.json_body['start'])
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body)
+
+
+class RoomSearchTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+ user_id = True
+ hijack_auth = False
+
+ def prepare(self, reactor, clock, hs):
+
+ # Register the user who does the searching
+ self.user_id = self.register_user("user", "pass")
+ self.access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ # Create a room
+ self.room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+
+ # Invite the other person
+ self.helper.invite(
+ room=self.room,
+ src=self.user_id,
+ tok=self.access_token,
+ targ=self.other_user_id,
+ )
+
+ # The other user joins
+ self.helper.join(
+ room=self.room, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ def test_finds_message(self):
+ """
+ The search functionality will search for content in messages if asked to
+ do so.
+ """
+ # The other user sends some messages
+ self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
+ self.helper.send(self.room, body="There!", tok=self.other_access_token)
+
+ request, channel = self.make_request(
+ "POST",
+ "/search?access_token=%s" % (self.access_token,),
+ {
+ "search_categories": {
+ "room_events": {"keys": ["content.body"], "search_term": "Hi"}
+ }
+ },
+ )
+ self.render(request)
+
+ # Check we get the results we expect -- one search result, of the sent
+ # messages
+ self.assertEqual(channel.code, 200)
+ results = channel.json_body["search_categories"]["room_events"]
+ self.assertEqual(results["count"], 1)
+ self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!")
+
+ # No context was requested, so we should get none.
+ self.assertEqual(results["results"][0]["context"], {})
+
+ def test_include_context(self):
+ """
+ When event_context includes include_profile, profile information will be
+ included in the search response.
+ """
+ # The other user sends some messages
+ self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
+ self.helper.send(self.room, body="There!", tok=self.other_access_token)
+
+ request, channel = self.make_request(
+ "POST",
+ "/search?access_token=%s" % (self.access_token,),
+ {
+ "search_categories": {
+ "room_events": {
+ "keys": ["content.body"],
+ "search_term": "Hi",
+ "event_context": {"include_profile": True},
+ }
+ }
+ },
+ )
+ self.render(request)
+
+ # Check we get the results we expect -- one search result, of the sent
+ # messages
+ self.assertEqual(channel.code, 200)
+ results = channel.json_body["search_categories"]["room_events"]
+ self.assertEqual(results["count"], 1)
+ self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!")
+
+ # We should get context info, like the two users, and the display names.
+ context = results["results"][0]["context"]
+ self.assertEqual(len(context["profile_info"].keys()), 2)
+ self.assertEqual(
+ context["profile_info"][self.other_user_id]["displayname"], "otheruser"
+ )
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 530dc8ba6d..9c401bf300 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -169,7 +169,7 @@ class RestHelper(object):
path = path + "?access_token=%s" % tok
request, channel = make_request(
- "POST", path, json.dumps(content).encode('utf8')
+ self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8')
)
render(request, self.resource, self.hs.get_reactor())
@@ -217,7 +217,9 @@ class RestHelper(object):
data = {"membership": membership}
- request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))
+ request, channel = make_request(
+ self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8')
+ )
render(request, self.resource, self.hs.get_reactor())
@@ -228,18 +230,6 @@ class RestHelper(object):
self.auth_user_id = temp_id
- @defer.inlineCallbacks
- def register(self, user_id):
- (code, response) = yield self.mock_resource.trigger(
- "POST",
- "/_matrix/client/r0/register",
- json.dumps(
- {"user": user_id, "password": "test", "type": "m.login.password"}
- ),
- )
- self.assertEquals(200, code)
- defer.returnValue(response)
-
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -251,7 +241,9 @@ class RestHelper(object):
if tok:
path = path + "?access_token=%s" % tok
- request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
+ request, channel = make_request(
+ self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8')
+ )
render(request, self.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
|