diff --git a/tests/utils.py b/tests/utils.py
index 358b5b72b7..52405502e9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -19,6 +19,9 @@ from synapse.api.constants import EventTypes
from synapse.storage.prepare_database import prepare_database
from synapse.storage.engines import create_engine
from synapse.server import HomeServer
+from synapse.federation.transport import server
+from synapse.types import Requester
+from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.logcontext import LoggingContext
@@ -44,9 +47,13 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config = Mock()
config.signing_key = [MockKey()]
config.event_cache_size = 1
- config.disable_registration = False
+ config.enable_registration = True
config.macaroon_secret_key = "not even a little secret"
config.server_name = "server.under.test"
+ config.trusted_third_party_id_servers = []
+ config.room_invite_state_types = []
+
+ config.database_config = {"name": "sqlite3"}
if "clock" not in kargs:
kargs["clock"] = MockClock()
@@ -57,14 +64,16 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer(
name, db_pool=db_pool, config=config,
version_string="Synapse/tests",
- database_engine=create_engine("sqlite3"),
+ database_engine=create_engine(config),
+ get_db_conn=db_pool.get_db_conn,
**kargs
)
+ hs.setup()
else:
hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests",
- database_engine=create_engine("sqlite3"),
+ database_engine=create_engine(config),
**kargs
)
@@ -80,6 +89,22 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
+ fed = kargs.get("resource_for_federation", None)
+ if fed:
+ server.register_servlets(
+ hs,
+ resource=fed,
+ authenticator=server.Authenticator(hs),
+ ratelimiter=FederationRateLimiter(
+ hs.get_clock(),
+ window_size=hs.config.federation_rc_window_size,
+ sleep_limit=hs.config.federation_rc_sleep_limit,
+ sleep_msec=hs.config.federation_rc_sleep_delay,
+ reject_limit=hs.config.federation_rc_reject_limit,
+ concurrent_requests=hs.config.federation_rc_concurrent
+ ),
+ )
+
defer.returnValue(hs)
@@ -131,7 +156,7 @@ class MockHttpResource(HttpServer):
mock_request.getClientIP.return_value = "-"
- mock_request.requestHeaders.getRawHeaders.return_value=[
+ mock_request.requestHeaders.getRawHeaders.return_value = [
"X-Matrix origin=test,key=,sig="
]
@@ -203,12 +228,12 @@ class MockClock(object):
def time_msec(self):
return self.time() * 1000
- def call_later(self, delay, callback):
+ def call_later(self, delay, callback, *args, **kwargs):
current_context = LoggingContext.current_context()
def wrapped_callback():
LoggingContext.thread_local.current_context = current_context
- callback()
+ callback(*args, **kwargs)
t = [self.now + delay, wrapped_callback, False]
self.timers.append(t)
@@ -218,9 +243,10 @@ class MockClock(object):
def looping_call(self, function, interval):
pass
- def cancel_call_later(self, timer):
+ def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]:
- raise Exception("Cannot cancel an expired timer")
+ if not ignore_errs:
+ raise Exception("Cannot cancel an expired timer")
timer[2] = True
self.timers = [t for t in self.timers if t != timer]
@@ -256,12 +282,24 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
cp_max=1,
)
+ self.config = Mock()
+ self.config.database_config = {"name": "sqlite3"}
+
def prepare(self):
- engine = create_engine("sqlite3")
+ engine = self.create_engine()
return self.runWithConnection(
- lambda conn: prepare_database(conn, engine)
+ lambda conn: prepare_database(conn, engine, self.config)
)
+ def get_db_conn(self):
+ conn = self.connect()
+ engine = self.create_engine()
+ prepare_database(conn, engine, self.config)
+ return conn
+
+ def create_engine(self):
+ return create_engine(self.config)
+
class MemoryDataStore(object):
@@ -333,13 +371,12 @@ class MemoryDataStore(object):
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
return [
- self.members[r].get(user_id) for r in self.members
- if user_id in self.members[r] and
- self.members[r][user_id].membership in membership_list
+ m[user_id] for m in self.members.values()
+ if user_id in m and m[user_id].membership in membership_list
]
def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
- limit=0, with_feedback=False):
+ limit=0, with_feedback=False):
return ([], from_key) # TODO
def get_joined_hosts_for_room(self, room_id):
@@ -349,7 +386,6 @@ class MemoryDataStore(object):
if event.type == EventTypes.Member:
room_id = event.room_id
user = event.state_key
- membership = event.membership
self.members.setdefault(room_id, {})[user] = event
if hasattr(event, "state_key"):
@@ -429,9 +465,9 @@ class DeferredMockCallable(object):
d.callback(None)
return result
- failure = AssertionError("Was not expecting call(%s)" %
+ failure = AssertionError("Was not expecting call(%s)" % (
_format_call(args, kwargs)
- )
+ ))
for _, _, d in self.expectations:
try:
@@ -452,14 +488,12 @@ class DeferredMockCallable(object):
)
timer = reactor.callLater(
- timeout/1000,
+ timeout / 1000,
deferred.errback,
- AssertionError(
- "%d pending calls left: %s"% (
- len([e for e in self.expectations if not e[2].called]),
- [e for e in self.expectations if not e[2].called]
- )
- )
+ AssertionError("%d pending calls left: %s" % (
+ len([e for e in self.expectations if not e[2].called]),
+ [e for e in self.expectations if not e[2].called]
+ ))
)
yield deferred
@@ -473,8 +507,12 @@ class DeferredMockCallable(object):
calls = self.calls
self.calls = []
- raise AssertionError("Expected not to received any calls, got:\n" +
- "\n".join([
+ raise AssertionError(
+ "Expected not to received any calls, got:\n" + "\n".join([
"call(%s)" % _format_call(c[0], c[1]) for c in calls
])
)
+
+
+def requester_for_user(user):
+ return Requester(user, None, False)
|