summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py102
1 files changed, 99 insertions, 3 deletions
diff --git a/tests/unittest.py b/tests/unittest.py
index d852e2465a..a59291cc60 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import hashlib
+import hmac
 import logging
 
 from mock import Mock
@@ -22,14 +24,17 @@ from canonicaljson import json
 
 import twisted
 import twisted.logger
+from twisted.internet.defer import Deferred
 from twisted.trial import unittest
 
 from synapse.http.server import JsonResource
+from synapse.http.site import SynapseRequest
 from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
 from synapse.util.logcontext import LoggingContextFilter
 
 from tests.server import get_clock, make_request, render, setup_test_homeserver
+from tests.utils import default_config
 
 # Set up putting Synapse's logs into Trial's.
 rootLogger = logging.getLogger()
@@ -151,6 +156,7 @@ class HomeserverTestCase(TestCase):
         hijack_auth (bool): Whether to hijack auth to return the user specified
         in user_id.
     """
+
     servlets = []
     hijack_auth = True
 
@@ -217,7 +223,17 @@ class HomeserverTestCase(TestCase):
 
         Function to be overridden in subclasses.
         """
-        raise NotImplementedError()
+        hs = self.setup_test_homeserver()
+        return hs
+
+    def default_config(self, name="test"):
+        """
+        Get a default HomeServer config object.
+
+        Args:
+            name (str): The homeserver name/domain.
+        """
+        return default_config(name)
 
     def prepare(self, reactor, clock, homeserver):
         """
@@ -234,7 +250,9 @@ class HomeserverTestCase(TestCase):
         Function to optionally be overridden in subclasses.
         """
 
-    def make_request(self, method, path, content=b""):
+    def make_request(
+        self, method, path, content=b"", access_token=None, request=SynapseRequest
+    ):
         """
         Create a SynapseRequest at the path using the method and containing the
         given content.
@@ -252,7 +270,7 @@ class HomeserverTestCase(TestCase):
         if isinstance(content, dict):
             content = json.dumps(content).encode('utf8')
 
-        return make_request(method, path, content)
+        return make_request(method, path, content, access_token, request)
 
     def render(self, request):
         """
@@ -279,3 +297,81 @@ class HomeserverTestCase(TestCase):
         kwargs = dict(kwargs)
         kwargs.update(self._hs_args)
         return setup_test_homeserver(self.addCleanup, *args, **kwargs)
+
+    def pump(self, by=0.0):
+        """
+        Pump the reactor enough that Deferreds will fire.
+        """
+        self.reactor.pump([by] * 100)
+
+    def get_success(self, d):
+        if not isinstance(d, Deferred):
+            return d
+        self.pump()
+        return self.successResultOf(d)
+
+    def register_user(self, username, password, admin=False):
+        """
+        Register a user. Requires the Admin API be registered.
+
+        Args:
+            username (bytes/unicode): The user part of the new user.
+            password (bytes/unicode): The password of the new user.
+            admin (bool): Whether the user should be created as an admin
+            or not.
+
+        Returns:
+            The MXID of the new user (unicode).
+        """
+        self.hs.config.registration_shared_secret = u"shared"
+
+        # Create the user
+        request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
+        self.render(request)
+        nonce = channel.json_body["nonce"]
+
+        want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+        nonce_str = b"\x00".join([username.encode('utf8'), password.encode('utf8')])
+        if admin:
+            nonce_str += b"\x00admin"
+        else:
+            nonce_str += b"\x00notadmin"
+        want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str)
+        want_mac = want_mac.hexdigest()
+
+        body = json.dumps(
+            {
+                "nonce": nonce,
+                "username": username,
+                "password": password,
+                "admin": admin,
+                "mac": want_mac,
+            }
+        )
+        request, channel = self.make_request(
+            "POST", "/_matrix/client/r0/admin/register", body.encode('utf8')
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        user_id = channel.json_body["user_id"]
+        return user_id
+
+    def login(self, username, password, device_id=None):
+        """
+        Log in a user, and get an access token. Requires the Login API be
+        registered.
+
+        """
+        body = {"type": "m.login.password", "user": username, "password": password}
+        if device_id:
+            body["device_id"] = device_id
+
+        request, channel = self.make_request(
+            "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8')
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        access_token = channel.json_body["access_token"].encode('ascii')
+        return access_token