diff --git a/tests/utils.py b/tests/utils.py
index 52405502e9..2d0bd205fd 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,7 +20,6 @@ from synapse.storage.prepare_database import prepare_database
from synapse.storage.engines import create_engine
from synapse.server import HomeServer
from synapse.federation.transport import server
-from synapse.types import Requester
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.logcontext import LoggingContext
@@ -49,11 +48,17 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.event_cache_size = 1
config.enable_registration = True
config.macaroon_secret_key = "not even a little secret"
- config.server_name = "server.under.test"
+ config.expire_access_token = False
+ config.server_name = name
config.trusted_third_party_id_servers = []
config.room_invite_state_types = []
+ config.password_providers = []
+ config.worker_replication_url = ""
+ config.worker_app = None
+ config.use_frozen_dicts = True
config.database_config = {"name": "sqlite3"}
+ config.ldap_enabled = False
if "clock" not in kargs:
kargs["clock"] = MockClock()
@@ -64,8 +69,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer(
name, db_pool=db_pool, config=config,
version_string="Synapse/tests",
- database_engine=create_engine(config),
+ database_engine=create_engine(config.database_config),
get_db_conn=db_pool.get_db_conn,
+ room_list_handler=object(),
+ tls_server_context_factory=Mock(),
**kargs
)
hs.setup()
@@ -73,21 +80,18 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests",
- database_engine=create_engine(config),
+ database_engine=create_engine(config.database_config),
+ room_list_handler=object(),
+ tls_server_context_factory=Mock(),
**kargs
)
# bcrypt is far too slow to be doing in unit tests
- def swap_out_hash_for_testing(old_build_handlers):
- def build_handlers():
- handlers = old_build_handlers()
- auth_handler = handlers.auth_handler
- auth_handler.hash = lambda p: hashlib.md5(p).hexdigest()
- auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
- return handlers
- return build_handlers
-
- hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
+ # Need to let the HS build an auth handler and then mess with it
+ # because AuthHandler's constructor requires the HS, so we can't make one
+ # beforehand and pass it in to the HS's constructor (chicken / egg)
+ hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest()
+ hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
fed = kargs.get("resource_for_federation", None)
if fed:
@@ -116,6 +120,15 @@ def get_mock_call_args(pattern_func, mock_func):
return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
+def mock_getRawHeaders(headers=None):
+ headers = headers if headers is not None else {}
+
+ def getRawHeaders(name, default=None):
+ return headers.get(name, default)
+
+ return getRawHeaders
+
+
# This is a mock /resource/ not an entire server
class MockHttpResource(HttpServer):
@@ -128,7 +141,7 @@ class MockHttpResource(HttpServer):
@patch('twisted.web.http.Request')
@defer.inlineCallbacks
- def trigger(self, http_method, path, content, mock_request):
+ def trigger(self, http_method, path, content, mock_request, federation_auth=False):
""" Fire an HTTP event.
Args:
@@ -156,9 +169,10 @@ class MockHttpResource(HttpServer):
mock_request.getClientIP.return_value = "-"
- mock_request.requestHeaders.getRawHeaders.return_value = [
- "X-Matrix origin=test,key=,sig="
- ]
+ headers = {}
+ if federation_auth:
+ headers["Authorization"] = ["X-Matrix origin=test,key=,sig="]
+ mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it
mock_request.path = path
@@ -189,7 +203,7 @@ class MockHttpResource(HttpServer):
)
defer.returnValue((code, response))
except CodeMessageException as e:
- defer.returnValue((e.code, cs_error(e.msg)))
+ defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
raise KeyError("No event can handle %s" % path)
@@ -221,6 +235,7 @@ class MockClock(object):
# list of lists of [absolute_time, callback, expired] in no particular
# order
self.timers = []
+ self.loopers = []
def time(self):
return self.now
@@ -241,7 +256,7 @@ class MockClock(object):
return t
def looping_call(self, function, interval):
- pass
+ self.loopers.append([function, interval / 1000., self.now])
def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]:
@@ -270,6 +285,12 @@ class MockClock(object):
else:
self.timers.append(t)
+ for looped in self.loopers:
+ func, interval, last = looped
+ if last + interval < self.now:
+ func()
+ looped[2] = self.now
+
def advance_time_msec(self, ms):
self.advance_time(ms / 1000.)
@@ -298,7 +319,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
return conn
def create_engine(self):
- return create_engine(self.config)
+ return create_engine(self.config.database_config)
class MemoryDataStore(object):
@@ -512,7 +533,3 @@ class DeferredMockCallable(object):
"call(%s)" % _format_call(c[0], c[1]) for c in calls
])
)
-
-
-def requester_for_user(user):
- return Requester(user, None, False)
|