diff --git a/tests/unittest.py b/tests/unittest.py
index b15b06726b..092c930396 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,14 +13,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import gc
+import hashlib
+import hmac
import logging
+from mock import Mock
+
+from canonicaljson import json
+
import twisted
import twisted.logger
+from twisted.internet.defer import Deferred
from twisted.trial import unittest
-from synapse.util.logcontext import LoggingContextFilter
+from synapse.http.server import JsonResource
+from synapse.http.site import SynapseRequest
+from synapse.server import HomeServer
+from synapse.types import UserID, create_requester
+from synapse.util.logcontext import LoggingContext, LoggingContextFilter
+
+from tests.server import get_clock, make_request, render, setup_test_homeserver
+from tests.utils import default_config, setupdb
+
+setupdb()
# Set up putting Synapse's logs into Trial's.
rootLogger = logging.getLogger()
@@ -56,6 +73,7 @@ def around(target):
def method_name(orig, *args, **kwargs):
return orig(*args, **kwargs)
"""
+
def _around(code):
name = code.__name__
orig = getattr(target, name)
@@ -86,9 +104,18 @@ class TestCase(unittest.TestCase):
# traceback when a unit test exits leaving things on the reactor.
twisted.internet.base.DelayedCall.debug = True
- old_level = logging.getLogger().level
+ # if we're not starting in the sentinel logcontext, then to be honest
+ # all future bets are off.
+ if LoggingContext.current_context() is not LoggingContext.sentinel:
+ self.fail(
+ "Test starting with non-sentinel logging context %s" % (
+ LoggingContext.current_context(),
+ )
+ )
+ old_level = logging.getLogger().level
if old_level != level:
+
@around(self)
def tearDown(orig):
ret = orig()
@@ -98,6 +125,16 @@ class TestCase(unittest.TestCase):
logging.getLogger().setLevel(level)
return orig()
+ @around(self)
+ def tearDown(orig):
+ ret = orig()
+ # force a GC to workaround problems with deferreds leaking logcontexts when
+ # they are GCed (see the logcontext docs)
+ gc.collect()
+ LoggingContext.set_current_context(LoggingContext.sentinel)
+
+ return ret
+
def assertObjectHasAttributes(self, attrs, obj):
"""Asserts that the given object has each of the attributes given, and
that the value of each matches according to assertEquals."""
@@ -117,8 +154,9 @@ class TestCase(unittest.TestCase):
actual (dict): The test result. Extra keys will not be checked.
"""
for key in required:
- self.assertEquals(required[key], actual[key],
- msg="%s mismatch. %s" % (key, actual))
+ self.assertEquals(
+ required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
+ )
def DEBUG(target):
@@ -126,3 +164,251 @@ def DEBUG(target):
Can apply to either a TestCase or an individual test method."""
target.loglevel = logging.DEBUG
return target
+
+
+def INFO(target):
+ """A decorator to set the .loglevel attribute to logging.INFO.
+ Can apply to either a TestCase or an individual test method."""
+ target.loglevel = logging.INFO
+ return target
+
+
+class HomeserverTestCase(TestCase):
+ """
+ A base TestCase that reduces boilerplate for HomeServer-using test cases.
+
+ Attributes:
+ servlets (list[function]): List of servlet registration function.
+ user_id (str): The user ID to assume if auth is hijacked.
+ hijack_auth (bool): Whether to hijack auth to return the user specified
+ in user_id.
+ """
+
+ servlets = []
+ hijack_auth = True
+
+ def setUp(self):
+ """
+ Set up the TestCase by calling the homeserver constructor, optionally
+ hijacking the authentication system to return a fixed user, and then
+ calling the prepare function.
+ """
+ self.reactor, self.clock = get_clock()
+ self._hs_args = {"clock": self.clock, "reactor": self.reactor}
+ self.hs = self.make_homeserver(self.reactor, self.clock)
+
+ if self.hs is None:
+ raise Exception("No homeserver returned from make_homeserver.")
+
+ if not isinstance(self.hs, HomeServer):
+ raise Exception("A homeserver wasn't returned, but %r" % (self.hs,))
+
+ # Register the resources
+ self.resource = JsonResource(self.hs)
+
+ for servlet in self.servlets:
+ servlet(self.hs, self.resource)
+
+ from tests.rest.client.v1.utils import RestHelper
+
+ self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
+
+ if hasattr(self, "user_id"):
+ if self.hijack_auth:
+
+ def get_user_by_access_token(token=None, allow_guest=False):
+ return {
+ "user": UserID.from_string(self.helper.auth_user_id),
+ "token_id": 1,
+ "is_guest": False,
+ }
+
+ def get_user_by_req(request, allow_guest=False, rights="access"):
+ return create_requester(
+ UserID.from_string(self.helper.auth_user_id), 1, False, None
+ )
+
+ self.hs.get_auth().get_user_by_req = get_user_by_req
+ self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
+ self.hs.get_auth().get_access_token_from_request = Mock(
+ return_value="1234"
+ )
+
+ if hasattr(self, "prepare"):
+ self.prepare(self.reactor, self.clock, self.hs)
+
+ def make_homeserver(self, reactor, clock):
+ """
+ Make and return a homeserver.
+
+ Args:
+ reactor: A Twisted Reactor, or something that pretends to be one.
+ clock (synapse.util.Clock): The Clock, associated with the reactor.
+
+ Returns:
+ A homeserver (synapse.server.HomeServer) suitable for testing.
+
+ Function to be overridden in subclasses.
+ """
+ hs = self.setup_test_homeserver()
+ return hs
+
+ def default_config(self, name="test"):
+ """
+ Get a default HomeServer config object.
+
+ Args:
+ name (str): The homeserver name/domain.
+ """
+ return default_config(name)
+
+ def prepare(self, reactor, clock, homeserver):
+ """
+ Prepare for the test. This involves things like mocking out parts of
+ the homeserver, or building test data common across the whole test
+ suite.
+
+ Args:
+ reactor: A Twisted Reactor, or something that pretends to be one.
+ clock (synapse.util.Clock): The Clock, associated with the reactor.
+ homeserver (synapse.server.HomeServer): The HomeServer to test
+ against.
+
+ Function to optionally be overridden in subclasses.
+ """
+
+ def make_request(
+ self,
+ method,
+ path,
+ content=b"",
+ access_token=None,
+ request=SynapseRequest,
+ shorthand=True,
+ ):
+ """
+ Create a SynapseRequest at the path using the method and containing the
+ given content.
+
+ Args:
+ method (bytes/unicode): The HTTP request method ("verb").
+ path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
+ escaped UTF-8 & spaces and such).
+ content (bytes or dict): The body of the request. JSON-encoded, if
+ a dict.
+ shorthand: Whether to try and be helpful and prefix the given URL
+ with the usual REST API path, if it doesn't contain it.
+
+ Returns:
+ A synapse.http.site.SynapseRequest.
+ """
+ if isinstance(content, dict):
+ content = json.dumps(content).encode('utf8')
+
+ return make_request(
+ self.reactor, method, path, content, access_token, request, shorthand
+ )
+
+ def render(self, request):
+ """
+ Render a request against the resources registered by the test class's
+ servlets.
+
+ Args:
+ request (synapse.http.site.SynapseRequest): The request to render.
+ """
+ render(request, self.resource, self.reactor)
+
+ def setup_test_homeserver(self, *args, **kwargs):
+ """
+ Set up the test homeserver, meant to be called by the overridable
+ make_homeserver. It automatically passes through the test class's
+ clock & reactor.
+
+ Args:
+ See tests.utils.setup_test_homeserver.
+
+ Returns:
+ synapse.server.HomeServer
+ """
+ kwargs = dict(kwargs)
+ kwargs.update(self._hs_args)
+ return setup_test_homeserver(self.addCleanup, *args, **kwargs)
+
+ def pump(self, by=0.0):
+ """
+ Pump the reactor enough that Deferreds will fire.
+ """
+ self.reactor.pump([by] * 100)
+
+ def get_success(self, d):
+ if not isinstance(d, Deferred):
+ return d
+ self.pump()
+ return self.successResultOf(d)
+
+ def register_user(self, username, password, admin=False):
+ """
+ Register a user. Requires the Admin API be registered.
+
+ Args:
+ username (bytes/unicode): The user part of the new user.
+ password (bytes/unicode): The password of the new user.
+ admin (bool): Whether the user should be created as an admin
+ or not.
+
+ Returns:
+ The MXID of the new user (unicode).
+ """
+ self.hs.config.registration_shared_secret = u"shared"
+
+ # Create the user
+ request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
+ self.render(request)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ nonce_str = b"\x00".join([username.encode('utf8'), password.encode('utf8')])
+ if admin:
+ nonce_str += b"\x00admin"
+ else:
+ nonce_str += b"\x00notadmin"
+ want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str)
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": username,
+ "password": password,
+ "admin": admin,
+ "mac": want_mac,
+ }
+ )
+ request, channel = self.make_request(
+ "POST", "/_matrix/client/r0/admin/register", body.encode('utf8')
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ user_id = channel.json_body["user_id"]
+ return user_id
+
+ def login(self, username, password, device_id=None):
+ """
+ Log in a user, and get an access token. Requires the Login API be
+ registered.
+
+ """
+ body = {"type": "m.login.password", "user": username, "password": password}
+ if device_id:
+ body["device_id"] = device_id
+
+ request, channel = self.make_request(
+ "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8')
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ access_token = channel.json_body["access_token"]
+ return access_token
|