summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py95
1 files changed, 92 insertions, 3 deletions
diff --git a/tests/unittest.py b/tests/unittest.py
index a3d39920db..4d40bdb6a5 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import hashlib
+import hmac
 import logging
 
 from mock import Mock
@@ -26,11 +28,13 @@ from twisted.internet.defer import Deferred
 from twisted.trial import unittest
 
 from synapse.http.server import JsonResource
+from synapse.http.site import SynapseRequest
 from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
 from synapse.util.logcontext import LoggingContextFilter
 
 from tests.server import get_clock, make_request, render, setup_test_homeserver
+from tests.utils import default_config
 
 # Set up putting Synapse's logs into Trial's.
 rootLogger = logging.getLogger()
@@ -142,6 +146,13 @@ def DEBUG(target):
     return target
 
 
+def INFO(target):
+    """A decorator to set the .loglevel attribute to logging.INFO.
+    Can apply to either a TestCase or an individual test method."""
+    target.loglevel = logging.INFO
+    return target
+
+
 class HomeserverTestCase(TestCase):
     """
     A base TestCase that reduces boilerplate for HomeServer-using test cases.
@@ -219,7 +230,17 @@ class HomeserverTestCase(TestCase):
 
         Function to be overridden in subclasses.
         """
-        raise NotImplementedError()
+        hs = self.setup_test_homeserver()
+        return hs
+
+    def default_config(self, name="test"):
+        """
+        Get a default HomeServer config object.
+
+        Args:
+            name (str): The homeserver name/domain.
+        """
+        return default_config(name)
 
     def prepare(self, reactor, clock, homeserver):
         """
@@ -236,7 +257,9 @@ class HomeserverTestCase(TestCase):
         Function to optionally be overridden in subclasses.
         """
 
-    def make_request(self, method, path, content=b""):
+    def make_request(
+        self, method, path, content=b"", access_token=None, request=SynapseRequest
+    ):
         """
         Create a SynapseRequest at the path using the method and containing the
         given content.
@@ -254,7 +277,7 @@ class HomeserverTestCase(TestCase):
         if isinstance(content, dict):
             content = json.dumps(content).encode('utf8')
 
-        return make_request(method, path, content)
+        return make_request(method, path, content, access_token, request)
 
     def render(self, request):
         """
@@ -293,3 +316,69 @@ class HomeserverTestCase(TestCase):
             return d
         self.pump()
         return self.successResultOf(d)
+
+    def register_user(self, username, password, admin=False):
+        """
+        Register a user. Requires the Admin API be registered.
+
+        Args:
+            username (bytes/unicode): The user part of the new user.
+            password (bytes/unicode): The password of the new user.
+            admin (bool): Whether the user should be created as an admin
+            or not.
+
+        Returns:
+            The MXID of the new user (unicode).
+        """
+        self.hs.config.registration_shared_secret = u"shared"
+
+        # Create the user
+        request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
+        self.render(request)
+        nonce = channel.json_body["nonce"]
+
+        want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+        nonce_str = b"\x00".join([username.encode('utf8'), password.encode('utf8')])
+        if admin:
+            nonce_str += b"\x00admin"
+        else:
+            nonce_str += b"\x00notadmin"
+        want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str)
+        want_mac = want_mac.hexdigest()
+
+        body = json.dumps(
+            {
+                "nonce": nonce,
+                "username": username,
+                "password": password,
+                "admin": admin,
+                "mac": want_mac,
+            }
+        )
+        request, channel = self.make_request(
+            "POST", "/_matrix/client/r0/admin/register", body.encode('utf8')
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        user_id = channel.json_body["user_id"]
+        return user_id
+
+    def login(self, username, password, device_id=None):
+        """
+        Log in a user, and get an access token. Requires the Login API be
+        registered.
+
+        """
+        body = {"type": "m.login.password", "user": username, "password": password}
+        if device_id:
+            body["device_id"] = device_id
+
+        request, channel = self.make_request(
+            "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8')
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        access_token = channel.json_body["access_token"]
+        return access_token